├── .gitignore ├── .idea ├── Wave-U-Net-for-Speech-Enhancement-master.iml ├── encodings.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── other.xml └── vcs.xml ├── LICENSE ├── README.md ├── config ├── enhancement │ └── unet_basic.json └── train │ └── train.json ├── dataset ├── __init__.py ├── waveform_dataset.py └── waveform_dataset_enhancement.py ├── doc ├── audio.png └── tensorboard.png ├── enhancement.py ├── model ├── __init__.py ├── conv_tas_net.py ├── loss.py └── unet_basic.py ├── trainer ├── __init__.py ├── base_trainer.py └── trainer.py └── util ├── __init__.py ├── utils.py └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### JetBrains template 3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 5 | 6 | # User-specific stuff 7 | .idea/**/workspace.xml 8 | .idea/**/tasks.xml 9 | .idea/**/dictionaries 10 | .idea/**/shelf 11 | .vscode 12 | 13 | # Sensitive or high-churn files 14 | .idea/**/dataSources/ 15 | .idea/**/dataSources.ids 16 | .idea/**/dataSources.local.xml 17 | .idea/**/sqlDataSources.xml 18 | .idea/**/dynamic.xml 19 | .idea/**/uiDesigner.xml 20 | .idea/**/dbnavigator.xml 21 | 22 | # Gradle 23 | .idea/**/gradle.xml 24 | .idea/**/libraries 25 | 26 | # CMake 27 | cmake-build-debug/ 28 | cmake-build-release/ 29 | 30 | # Mongo Explorer plugin 31 | .idea/**/mongoSettings.xml 32 | 33 | # File-based project format 34 | *.iws 35 | 36 | # IntelliJ 37 | out/ 38 | 39 | # mpeltonen/sbt-idea plugin 40 | .idea_modules/ 41 | 42 | # JIRA plugin 43 | atlassian-ide-plugin.xml 44 | 45 | # Cursive Clojure plugin 46 | .idea/replstate.xml 47 | 48 | # Crashlytics plugin (for Android Studio and IntelliJ) 49 | com_crashlytics_export_strings.xml 50 | crashlytics.properties 51 | crashlytics-build.properties 52 | fabric.properties 53 | 54 | # Editor-based Rest Client 55 | .idea/httpRequests 56 | ### Python template 57 | # Byte-compiled / optimized / DLL files 58 | __pycache__/ 59 | *.py[cod] 60 | *$py.class 61 | 62 | # C extensions 63 | *.so 64 | 65 | # Distribution / packaging 66 | .Python 67 | build/ 68 | develop-eggs/ 69 | dist/ 70 | downloads/ 71 | eggs/ 72 | .eggs/ 73 | lib/ 74 | lib64/ 75 | parts/ 76 | sdist/ 77 | var/ 78 | wheels/ 79 | *.egg-info/ 80 | .installed.cfg 81 | *.egg 82 | MANIFEST 83 | 84 | # PyInstaller 85 | # Usually these files are written by a python script from a template 86 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
87 | *.manifest 88 | *.spec 89 | 90 | # Installer logs 91 | pip-log.txt 92 | pip-delete-this-directory.txt 93 | 94 | # Unit test / coverage reports 95 | htmlcov/ 96 | .tox/ 97 | .coverage 98 | .coverage.* 99 | .cache 100 | nosetests.xml 101 | coverage.xml 102 | *.cover 103 | .hypothesis/ 104 | .pytest_cache/ 105 | 106 | # Translations 107 | *.mo 108 | *.pot 109 | 110 | # Django stuff: 111 | *.log 112 | local_settings.py 113 | db.sqlite3 114 | 115 | # Flask stuff: 116 | instance/ 117 | .webassets-cache 118 | 119 | # Scrapy stuff: 120 | .scrapy 121 | 122 | # Sphinx documentation 123 | docs/_build/ 124 | 125 | # PyBuilder 126 | target/ 127 | 128 | # Jupyter Notebook 129 | .ipynb_checkpoints 130 | 131 | # pyenv 132 | .python-version 133 | 134 | # celery beat schedule file 135 | celerybeat-schedule 136 | 137 | # SageMath parsed files 138 | *.sage.py 139 | 140 | # Environments 141 | .env 142 | .venv 143 | env/ 144 | venv/ 145 | ENV/ 146 | env.bak/ 147 | venv.bak/ 148 | 149 | # Spyder project settings 150 | .spyderproject 151 | .spyproject 152 | 153 | # Rope project settings 154 | .ropeproject 155 | 156 | # mkdocs documentation 157 | /site 158 | 159 | # mypy 160 | .mypy_cache/ 161 | 162 | /config/train_config* 163 | /config/test_config* 164 | 165 | enhanced 166 | -------------------------------------------------------------------------------- /.idea/Wave-U-Net-for-Speech-Enhancement-master.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 王治愚 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included 
in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Wave-U-Net-for-Speech-Enhancement
2 | 
3 | A PyTorch implementation of [Wave-U-Net](https://arxiv.org/abs/1806.03185), applied to speech enhancement.
4 | 
5 | ![](./doc/tensorboard.png)
6 | ![](./doc/audio.png)
7 | 
8 | ## Environment and Dependencies
9 | 
10 | 
11 | ```shell
12 | # Make sure the CUDA bin directory is added to the PATH environment variable
13 | # Make CUPTI, which ships with CUDA, available by appending to the LD_LIBRARY_PATH environment variable
14 | export PATH="/usr/local/cuda-10.0/bin:$PATH"
15 | export LD_LIBRARY_PATH="/usr/local/cuda-10.0/lib64:$LD_LIBRARY_PATH"
16 | 
17 | # Install Anaconda (using the Tsinghua mirror and Python 3.6.5 as an example)
18 | wget https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-5.2.0-Linux-x86_64.sh
19 | chmod a+x Anaconda3-5.2.0-Linux-x86_64.sh
20 | ./Anaconda3-5.2.0-Linux-x86_64.sh # Press f to page through the license; it installs to ~/anaconda by default and prompts you to update PATH
21 | 
22 | # Create env
23 | conda create -n wave-u-net python=3
24 | conda activate wave-u-net
25 | 
26 | # Install deps
27 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
28 | conda install tensorflow-gpu
29 | conda install matplotlib
30 | pip install tqdm librosa
31 | pip install pystoi # for STOI metric
32 | pip install pesq # for PESQ metric
33 | 
34 | # After the environment and dependencies are set up, clone the code
35 | git clone https://github.com/haoxiangsnr/Wave-U-Net-for-Speech-Enhancement.git
36 | ```
37 | 
38 | ## Usage
39 | 
40 | The project currently has three entry scripts:
41 | 
42 | - Entry script for training the model: `train.py`
43 | - Entry script for enhancing noisy speech: `enhancement.py`
44 | - Entry script for testing the denoising ability of a trained model (TODO): `test.py`
45 | 
46 | ### Training
47 | 
48 | Use `train.py` to train the model. It accepts three command-line arguments:
49 | 
50 | - `-h`: show help information
51 | - `-C, --config`: the configuration file required for training
52 | - `-R, --resume`: resume training from the most recently saved model checkpoint
53 | 
54 | Syntax: `python train.py [-h] -C CONFIG [-R]`
55 | 
56 | For example:
57 | 
58 | ```shell script
59 | python train.py -C config/train.json
60 | # The configuration file used for training is config/train.json
61 | # Use all GPUs for training
62 | 
63 | python train.py -C config/train.json -R
64 | # The configuration file used for training is config/train.json
65 | # Use all GPUs and resume training from the most recently saved model checkpoint
66 | 
67 | CUDA_VISIBLE_DEVICES=1,2 python train.py -C config/train.json
68 | # The configuration file used for training is config/train.json
69 | # Use the GPUs with indices 1 and 2 for training
70 | 
71 | CUDA_VISIBLE_DEVICES=-1 python train.py -C config/train.json
72 | # The configuration file used for training is config/train.json
73 | # Use the CPU for training
74 | ```
75 | 
76 | Notes:
77 | - Training configuration files are usually placed in the `config/train/` directory
78 | - The parameters of the training configuration file are described in the "Parameters" section
79 | - The file name of the configuration file is the experiment name
80 | 
81 | ### Enhancement
82 | 
83 | Use `enhancement.py` to enhance noisy speech. It accepts the following arguments:
84 | 
85 | - `-h, --help`: show help information
86 | - `-C, --config`: the configuration file specifying the model used for enhancement and the dataset to be enhanced
87 | - `-O, --output_dir`: where to store the enhanced speech; make sure this directory exists beforehand
88 | - `-M, --model_checkpoint_path`: path of the model checkpoint, with a .tar or .pth extension
89 | 
90 | Syntax: `python enhancement.py [-h] -C CONFIG -O OUTPUT_DIR -M MODEL_CHECKPOINT_PATH`
91 | 
92 | For example:
93 | 
94 | ```shell script
95 | CUDA_VISIBLE_DEVICES=1 python enhancement.py -C config/enhancement/unet_basic.json -O enhanced -M /media/imucs/DataDisk/haoxiang/Experiment/Wave-U-Net-for-Speech-Enhancement/smooth_l1_loss/checkpoints/model_0020.pth
96 | # The configuration file used for enhancement is config/enhancement/unet_basic.json, which specifies the model and the dataset used for enhancement
97 | # Use the GPU with index 1
98 | # The output directory is enhanced/, which must be created beforehand
99 | # The path of the model checkpoint is given with -M
100 | 
101 | python enhancement.py -C config/enhancement/unet_basic.json -O enhanced -M /media/imucs/DataDisk/haoxiang/Experiment/Wave-U-Net-for-Speech-Enhancement/smooth_l1_loss/checkpoints/model_0020.pth
102 | # The configuration file used for enhancement is config/enhancement/unet_basic.json, which specifies the model and the dataset used for enhancement
103 | # Use all GPUs for enhancement by default
104 | 
105 | CUDA_VISIBLE_DEVICES=-1 python enhancement.py -C config/enhancement/unet_basic.json -O enhanced -M /media/imucs/DataDisk/haoxiang/Experiment/Wave-U-Net-for-Speech-Enhancement/smooth_l1_loss/checkpoints/model_0020.tar
106 | # Use the CPU to enhance the speech
107 | ```
108 | 
109 | Notes:
110 | - Enhancement configuration files are usually placed in the `config/enhancement/` directory
111 | - The parameters of the enhancement configuration file are described in the "Parameters" section
112 | 
113 | ### Testing
114 | 
115 | TODO
116 | 
117 | 
118 | ## Visualization
119 | 
120 | All log information produced during training is stored in the `config["root_dir"]/<experiment name>/` directory. Suppose the configuration file used for training is `config/train/sample_16384.json` and the value of the `root_dir` parameter in `sample_16384.json` is `/home/UNet/`; the logs produced while training this experiment are then stored in the `/home/UNet/sample_16384/` directory.
121 | This directory contains the following:
122 | 
123 | - `logs/` directory: Tensorboard-related data, including loss curves, waveforms, audio files, etc.
124 | - `checkpoints/` directory: all checkpoints of the model, from which training can later be resumed or speech enhancement can be performed
125 | - `config.json` file: a backup of the training configuration file
126 | 
127 | During training, `tensorboard` can be used to start a static front-end server that visualizes the log data in the corresponding directory:
128 | 
129 | ```shell script
130 | tensorboard --logdir config["root_dir"]/<experiment name>/
131 | 
132 | # --port can be used to specify the port on which the tensorboard static server starts
133 | tensorboard --logdir config["root_dir"]/<experiment name>/ --port <port>
134 | 
135 | # For example, if the "root_dir" parameter in the configuration file is "/home/happy/Experiments", the configuration file is named "train_config.json", and the default port is changed to 6000,
136 | # the following command can be used:
137 | tensorboard --logdir /home/happy/Experiments/train_config --port 6000
138 | ```
139 | 
140 | ## Directory Description
141 | 
142 | While the project runs, several directories with different purposes are produced:
143 | 
144 | - Main directory: the directory containing this README.md, which stores all the source code
145 | - Training directory: the `config["root_dir"]` directory from the training configuration file, which stores all experiment logs and model checkpoints of the project
146 | - Experiment directory: the `config["root_dir"]/<experiment name>/` directory, which stores the logs of a single experiment
147 | 
148 | 
149 | ## Parameters
150 | ### Training
151 | 
152 | `config/train/<experiment name>.json`; the logs produced during training are stored in the `config["root_dir"]/<experiment name>/` directory
153 | 
154 | ```json5
155 | {
156 |     "seed": 0, // random seed to keep the experiment reproducible
157 |     "description": "...", // experiment description, shown later in Tensorboard
158 |     "root_dir": "~/Experiments/Wave-U-Net", // directory where experiment results are stored
159 |     "cudnn_deterministic": false,
160 |     "trainer": { // training process
161 |         "module": "trainer.trainer", // file containing the trainer class
162 |         "main": "Trainer", // concrete trainer class
163 |         "epochs": 1200, // upper limit on the number of training epochs
164 |         "save_checkpoint_interval": 10, // interval (in epochs) for saving model checkpoints
165 |         "validation": {
166 |             "interval": 10, // validation interval (in epochs)
167 |             "find_max": true, // when find_max is true, an extra copy of the current epoch's checkpoint is cached whenever the computed metric is the best value seen so far
168 |             "custom": {
169 |                 "visualize_audio_limit": 20, // limit on the number of audio clips visualized during validation; this parameter exists because visualizing audio is slow
170 |                 "visualize_waveform_limit": 20, // limit on the number of waveforms visualized during validation; this parameter exists because visualizing waveforms is slow
171 |                 "visualize_spectrogram_limit": 20, // limit on the number of spectrograms visualized during validation; this parameter exists because visualizing spectrograms is slow
172 |                 "sample_length": 16384 // number of samples per input segment
173 |             }
174 |         }
175 |     },
176 |     "model": {
177 |         "module": "model.unet_basic", // file containing the model used for training
178 |         "main": "Model", // concrete model class
179 |         "args": {} // parameters passed to the model class
180 |     },
181 |     "loss_function": {
182 |         "module": "model.loss", // file containing the loss function
183 |         "main": "mse_loss", // concrete loss function
184 |         "args": {} // parameters passed to the loss function
185 |     },
186 |     "optimizer": {
187 |         "lr": 0.001,
188 |         "beta1": 0.9,
189 |         "beta2": 0.999
190 |     },
191 |     "train_dataset": {
192 |         "module": "dataset.waveform_dataset", // file containing the training dataset class
193 |         "main": "Dataset", // concrete training dataset class
194 |         "args": { // parameters passed to the training dataset class; see the dataset class for details
195 |             "dataset": "~/Datasets/SEGAN_Dataset/train_dataset.txt",
196 |             "limit": null,
197 |             "offset": 0,
198 |             "sample_length": 16384,
199 |             "mode": "train"
200 |         }
201 |     },
202 |     "validation_dataset": {
203 |         "module": "dataset.waveform_dataset",
204 |         "main": "Dataset",
205 |         "args": {
206 |             "dataset": "~/Datasets/SEGAN_Dataset/test_dataset.txt",
207 |             "limit": 400,
208 |             "offset": 0,
209 |             "mode": "validation"
210 |         }
211 |     },
212 |     "train_dataloader": {
213 |         "batch_size": 120,
214 |         "num_workers": 40, // number of workers used to preprocess the data
215 |         "shuffle": true,
216 |         "pin_memory": true
217 |     }
218 | }
219 | ```
220 | 
221 | ### Enhancement
222 | 
223 | `config/enhancement/*.json`
224 | 
225 | ```json5
226 | {
227 |     "model": {
228 |         "module": "model.unet_basic", // file containing the model
229 |         "main": "Model", // concrete model class in that file
230 |         "args": {} // parameters passed to the model class
231 |     },
232 |     "dataset": {
233 |         "module": "dataset.waveform_dataset_enhancement", // file containing the dataset class used for enhancement
234 |         "main": "WaveformDataset", // concrete dataset class
235 |         "args": { // parameters passed to the dataset class; see the dataset class for details
236 |             "dataset": "/home/imucs/Datasets/2019-09-03-timit_train-900_test-50/enhancement.txt",
237 |             "limit": 400,
238 |             "offset": 0,
239 |             "sample_length": 16384
240 |         }
241 |     }
242 | }
243 | ```
244 | 
245 | For enhancement, the txt file that stores the dataset paths only needs to list the paths of the noisy speech, like this:
246 | 
247 | ```text
248 | # enhancement.txt
249 | 
250 | /home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Clean.wav
251 | /home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Enhanced_Inpainting_200.wav
252 | /home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Enhanced_Inpainting_270.wav
253 | /home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Enhanced_UNet.wav
254 | /home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Mixture.wav
255 | ```
256 | 
257 | ## TODO
258 | 
259 | - [x] Validation with full-length speech
260 | - [x] Enhancement script
261 | - [ ] Test script
262 | 
-------------------------------------------------------------------------------- /config/enhancement/unet_basic.json: --------------------------------------------------------------------------------
1 | {
2 |     "model": {
3 |         "module": "model.unet_basic",
4 |         "main": "Model",
5 |         "args": {}
6 |     },
7 |     "dataset": {
8 |         "module": "dataset.waveform_dataset_enhancement",
9 |         "main": "WaveformDataset",
10 |         "args": {
11 |             "dataset": "/home/imucs/tmp/UNet_and_Inpainting/data.txt",
12 |             "limit": 400,
13 |             "offset": 0,
14 |             "sample_length": 16384
15 |         }
16 |     }
17 | }
-------------------------------------------------------------------------------- /config/train/train.json: --------------------------------------------------------------------------------
1 | {
2 |     "seed": 0,
3 |     "description": "test",
4 |     "root_dir": "E:/Experiments/Wave-U-Net",
5 |     "cudnn_deterministic": false,
6 |     "trainer": {
7 |         "module": "trainer.trainer",
8 |         "main": "Trainer",
9 |         "epochs": 600,
10 |         "save_checkpoint_interval": 10,
11 |         "validation": {
12 |             "interval": 10,
13 |             "find_max": true,
14 |             "custom": {
15 |                 "visualize_audio_limit": 20,
16 |                 "visualize_waveform_limit": 20,
17 |                 "visualize_spectrogram_limit": 20,
18 |                 "sample_length": 16000
19 |             }
20 |         }
21 |     },
22 |     "model": {
23 |         "module": "model.conv_tas_net",
24 |         "main": "Model",
25 |         "args": {}
26 |     },
27 |     "loss_function": {
28 |         "module": "model.loss",
29 |         "main": "mse_loss",
30 |         "args": {}
31 |     },
32 |     "optimizer": {
33 |         "lr": 0.001,
34 |         "beta1": 0.9,
35 |         "beta2": 0.999
36 |     },
37 |     "train_dataset": {
38 |         "module": "dataset.waveform_dataset",
39 |         "main": "Dataset",
40 |         "args": {
41 |             "dataset": "E:/train_dataset.txt",
42 |             "limit": null,
43 |             "offset": 0,
44 |             "sample_length": 16000,
45 |             "mode": "train"
46 |         }
47 |     },
48 |     "validation_dataset": {
49 |         "module": "dataset.waveform_dataset",
50 |         "main": "Dataset",
51 |         "args": {
52 |             "dataset": "E:/test_dataset.txt",
53 |             "limit": 400,
54 |             "offset": 0,
55 |             "mode": "validation"
56 |         }
57 |     },
58 |     "train_dataloader": {
59 |         "batch_size": 4,
60 |         "num_workers": 4,
61 |         "shuffle": true,
62 |         "pin_memory": true
63 |     }
64 | }
-------------------------------------------------------------------------------- /dataset/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhiyuyu/Wave-U-Net-for-SpeechEnhancement/308a2be9d91d931f09e873b19c3d67315b9b5962/dataset/__init__.py
-------------------------------------------------------------------------------- /dataset/waveform_dataset.py: --------------------------------------------------------------------------------
1 | import os
2 | from torch.utils import data
3 | import librosa
4 | from util.utils import sample_fixed_length_data_aligned
5 | 
6 | 
7 | class Dataset(data.Dataset):
8 |     def __init__(self, dataset, limit=None, offset=0, sample_length=16384, mode="train"):
9 |         """
10 |         Construct the training dataset.
11 |         Args:
12 |             dataset (str): path of the speech dataset list (a txt file), see the Notes section
13 |             limit (int): upper limit on the number of items taken from the dataset
14 |             offset (int): offset of the starting position within the dataset
15 |             sample_length(int): the model only supports fixed-length input; this parameter sets the size of each input fed to the model
16 |             mode(str): "train" means the speech is cut into fixed-length segments; "validation" means no cutting, the full-length speech is returned directly.
17 | 
18 |         Notes:
19 |             The speech dataset list has the following format:
20 |             <path of noisy speech 1><space><path of clean speech 1>
21 |             <path of noisy speech 2><space><path of clean speech 2>
22 |             ...
23 |             <path of noisy speech n><space><path of clean speech n>
24 | 
25 |             e.g.:
26 |             /train/noisy/a.wav /train/clean/a.wav
27 |             /train/noisy/b.wav /train/clean/b.wav
28 |             ...
29 | 
30 |         Return:
31 |             (mixture signals, clean signals, file name)
32 |         """
33 |         super(Dataset, self).__init__()
34 |         dataset_list = [line.rstrip('\n') for line in open(os.path.abspath(os.path.expanduser(dataset)), "r")]
35 | 
36 |         dataset_list = dataset_list[offset:]
37 |         if limit:
38 |             dataset_list = dataset_list[:limit]
39 | 
40 |         assert mode in ("train", "validation"), "Mode must be one of train or validation."
41 | 
42 |         self.length = len(dataset_list)
43 |         self.dataset_list = dataset_list
44 |         self.sample_length = sample_length
45 |         self.mode = mode
46 | 
47 |     def __len__(self):
48 |         return self.length
49 | 
50 |     def __getitem__(self, item):
51 |         mixture_path, clean_path = self.dataset_list[item].split(" ")
52 |         name = os.path.splitext(os.path.basename(mixture_path))[0]
53 |         mixture, _ = librosa.load(os.path.abspath(os.path.expanduser(mixture_path)), sr=None)
54 |         clean, _ = librosa.load(os.path.abspath(os.path.expanduser(clean_path)), sr=None)
55 | 
56 |         if self.mode == "train":
57 |             # The input of model should be fixed length.
58 |             mixture, clean = sample_fixed_length_data_aligned(mixture, clean, self.sample_length)
59 |             return mixture.reshape(1, -1), clean.reshape(1, -1), name
60 |         else:
61 |             return mixture.reshape(1, -1), clean.reshape(1, -1), name
62 | 
-------------------------------------------------------------------------------- /dataset/waveform_dataset_enhancement.py: --------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.data import Dataset
3 | import librosa
4 | 
5 | 
6 | class WaveformDataset(Dataset):
7 |     def __init__(self, dataset, limit=None, offset=0, sample_length=16384):
8 |         """
9 |         Construct the enhancement dataset.
10 |         Args:
11 |             dataset (str): path of the speech dataset list (a txt file), see the Notes section
12 |             limit (int): upper limit on the number of items taken from the dataset
13 |             offset (int): offset of the starting position within the dataset
14 |             sample_length(int): the model only supports fixed-length input; this parameter sets the size of each input fed to the model
15 | 
16 |         Notes:
17 |             The speech dataset list has the following format:
18 |             <path of noisy speech 1>
19 |             <path of noisy speech 2>
20 |             ...
21 |             <path of noisy speech n>
22 | 
23 |             e.g.:
24 |             /enhancement/noisy/a.wav
25 |             /enhancement/noisy/b.wav
26 |             ...
27 | 
28 |         Return:
29 |             (mixture signals, file name)
30 |         """
31 |         super(WaveformDataset, self).__init__()
32 |         dataset_list = [line.rstrip('\n') for line in open(os.path.abspath(os.path.expanduser(dataset)), "r")]
33 | 
34 |         dataset_list = dataset_list[offset:]
35 |         if limit:
36 |             dataset_list = dataset_list[:limit]
37 | 
38 |         self.length = len(dataset_list)
39 |         self.dataset_list = dataset_list
40 |         self.sample_length = sample_length
41 | 
42 |     def __len__(self):
43 |         return self.length
44 | 
45 |     def __getitem__(self, item):
46 |         mixture_path = self.dataset_list[item]
47 |         name = os.path.splitext(os.path.basename(mixture_path))[0]
48 | 
49 |         mixture, _ = librosa.load(os.path.abspath(os.path.expanduser(mixture_path)), sr=None)
50 | 
51 |         return mixture.reshape(1, -1), name
52 | 
-------------------------------------------------------------------------------- /doc/audio.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhiyuyu/Wave-U-Net-for-SpeechEnhancement/308a2be9d91d931f09e873b19c3d67315b9b5962/doc/audio.png
-------------------------------------------------------------------------------- /doc/tensorboard.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhiyuyu/Wave-U-Net-for-SpeechEnhancement/308a2be9d91d931f09e873b19c3d67315b9b5962/doc/tensorboard.png
-------------------------------------------------------------------------------- /enhancement.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | 
5 | import librosa
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 | 
10 | from util.utils import initialize_config, load_checkpoint
11 | 
12 | """
13 | Parameters
14 | """
15 | parser = argparse.ArgumentParser("Wave-U-Net: Speech Enhancement")
16 | parser.add_argument("-C", "--config", type=str, required=True, help="Model and dataset for enhancement (*.json).")
17 | parser.add_argument("-D", "--device", default="-1", type=str, help="GPU for speech enhancement.
default: CPU") 18 | parser.add_argument("-O", "--output_dir", type=str, required=True, help="Where are audio save.") 19 | parser.add_argument("-M", "--model_checkpoint_path", type=str, required=True, help="Checkpoint.") 20 | args = parser.parse_args() 21 | 22 | """ 23 | Preparation 24 | """ 25 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 26 | config = json.load(open(args.config)) 27 | model_checkpoint_path = args.model_checkpoint_path 28 | output_dir = args.output_dir 29 | assert os.path.exists(output_dir), "Enhanced directory should be exist." 30 | 31 | """ 32 | DataLoader 33 | """ 34 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 35 | dataloader = DataLoader(dataset=initialize_config(config["dataset"]), batch_size=1, num_workers=0) 36 | 37 | """ 38 | Model 39 | """ 40 | model = initialize_config(config["model"]) 41 | model.load_state_dict(load_checkpoint(model_checkpoint_path, device)) 42 | model.to(device) 43 | model.eval() 44 | 45 | """ 46 | Enhancement 47 | """ 48 | sample_length = dataloader.dataset.sample_length 49 | for mixture, name in tqdm(dataloader): 50 | assert len(name) == 1, "Only support batch size is 1 in enhancement stage." 51 | name = name[0] 52 | 53 | mixture = mixture.to(device) 54 | mixture_chunks = torch.split(mixture, sample_length, dim=2) 55 | if mixture_chunks[-1].shape[-1] != sample_length: 56 | mixture_chunks = mixture_chunks[:-1] 57 | 58 | enhance_chunks = [] 59 | for chunk in mixture_chunks: 60 | enhance_chunks.append((model(chunk).detach().cpu())) 61 | 62 | enhanced = torch.cat(enhance_chunks, dim=2) 63 | enhanced = enhanced.numpy().reshape(-1) 64 | 65 | output_path = os.path.join(output_dir, f"{name}.wav") 66 | librosa.output.write_wav(output_path, enhanced, sr=16000) 67 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhiyuyu/Wave-U-Net-for-SpeechEnhancement/308a2be9d91d931f09e873b19c3d67315b9b5962/model/__init__.py -------------------------------------------------------------------------------- /model/conv_tas_net.py: -------------------------------------------------------------------------------- 1 | # wujian@2018 2 | 3 | import torch as th 4 | import torch.nn as nn 5 | 6 | import torch.nn.functional as F 7 | 8 | 9 | def param(nnet, Mb=True): 10 | """ 11 | Return number parameters(not bytes) in nnet 12 | """ 13 | neles = sum([param.nelement() for param in nnet.parameters()]) 14 | return neles / 10**6 if Mb else neles 15 | 16 | 17 | class ChannelWiseLayerNorm(nn.LayerNorm): 18 | """ 19 | Channel wise layer normalization 20 | """ 21 | 22 | def __init__(self, *args, **kwargs): 23 | super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs) 24 | 25 | def forward(self, x): 26 | """ 27 | x: N x C x T 28 | """ 29 | if x.dim() != 3: 30 | raise RuntimeError("{} accept 3D tensor as input".format( 31 | self.__name__)) 32 | # N x C x T => N x T x C 33 | x = th.transpose(x, 1, 2) 34 | # LN 35 | x = super().forward(x) 36 | # N x C x T => N x T x C 37 | x = th.transpose(x, 1, 2) 38 | return x 39 | 40 | 41 | class GlobalChannelLayerNorm(nn.Module): 42 | """ 43 | Global channel layer normalization 44 | """ 45 | 46 | def __init__(self, dim, eps=1e-05, elementwise_affine=True): 47 | super(GlobalChannelLayerNorm, self).__init__() 48 | self.eps = eps 49 | self.normalized_dim = dim 50 | self.elementwise_affine = elementwise_affine 51 | if elementwise_affine: 52 | 
self.beta = nn.Parameter(th.zeros(dim, 1)) 53 | self.gamma = nn.Parameter(th.ones(dim, 1)) 54 | else: 55 | self.register_parameter("weight", None) 56 | self.register_parameter("bias", None) 57 | 58 | def forward(self, x): 59 | """ 60 | x: N x C x T 61 | """ 62 | if x.dim() != 3: 63 | raise RuntimeError("{} accept 3D tensor as input".format( 64 | self.__name__)) 65 | # N x 1 x 1 66 | mean = th.mean(x, (1, 2), keepdim=True) 67 | var = th.mean((x - mean)**2, (1, 2), keepdim=True) 68 | # N x T x C 69 | if self.elementwise_affine: 70 | x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta 71 | else: 72 | x = (x - mean) / th.sqrt(var + self.eps) 73 | return x 74 | 75 | def extra_repr(self): 76 | return "{normalized_dim}, eps={eps}, " \ 77 | "elementwise_affine={elementwise_affine}".format(**self.__dict__) 78 | 79 | 80 | def build_norm(norm, dim): 81 | """ 82 | Build normalize layer 83 | LN cost more memory than BN 84 | """ 85 | if norm not in ["cLN", "gLN", "BN"]: 86 | raise RuntimeError("Unsupported normalize layer: {}".format(norm)) 87 | if norm == "cLN": 88 | return ChannelWiseLayerNorm(dim, elementwise_affine=True) 89 | elif norm == "BN": 90 | return nn.BatchNorm1d(dim) 91 | else: 92 | return GlobalChannelLayerNorm(dim, elementwise_affine=True) 93 | 94 | 95 | class Conv1D(nn.Conv1d): 96 | """ 97 | 1D conv in ConvTasNet 98 | """ 99 | 100 | def __init__(self, *args, **kwargs): 101 | super(Conv1D, self).__init__(*args, **kwargs) 102 | 103 | def forward(self, x, squeeze=False): 104 | """ 105 | x: N x L or N x C x L 106 | """ 107 | if x.dim() not in [2, 3]: 108 | raise RuntimeError("{} accept 2/3D tensor as input".format( 109 | self.__name__)) 110 | x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1)) 111 | if squeeze: 112 | x = th.squeeze(x) 113 | return x 114 | 115 | 116 | class ConvTrans1D(nn.ConvTranspose1d): 117 | """ 118 | 1D conv transpose in ConvTasNet 119 | """ 120 | 121 | def __init__(self, *args, **kwargs): 122 | super(ConvTrans1D, self).__init__(*args, **kwargs) 123 | 124 | def forward(self, x, squeeze=False): 125 | """ 126 | x: N x L or N x C x L 127 | """ 128 | if x.dim() not in [2, 3]: 129 | raise RuntimeError("{} accept 2/3D tensor as input".format( 130 | self.__name__)) 131 | x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1)) 132 | if squeeze: 133 | x = th.squeeze(x) 134 | return x 135 | 136 | 137 | class Conv1DBlock(nn.Module): 138 | """ 139 | 1D convolutional block: 140 | Conv1x1 - PReLU - Norm - DConv - PReLU - Norm - SConv 141 | """ 142 | 143 | def __init__(self, 144 | in_channels=256, 145 | conv_channels=512, 146 | kernel_size=3, 147 | dilation=1, 148 | norm="cLN", 149 | causal=False): 150 | super(Conv1DBlock, self).__init__() 151 | # 1x1 conv 152 | self.conv1x1 = Conv1D(in_channels, conv_channels, 1) 153 | self.prelu1 = nn.PReLU() 154 | self.lnorm1 = build_norm(norm, conv_channels) 155 | dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else ( 156 | dilation * (kernel_size - 1)) 157 | # depthwise conv 158 | self.dconv = nn.Conv1d( 159 | conv_channels, 160 | conv_channels, 161 | kernel_size, 162 | groups=conv_channels, 163 | padding=dconv_pad, 164 | dilation=dilation, 165 | bias=True) 166 | self.prelu2 = nn.PReLU() 167 | self.lnorm2 = build_norm(norm, conv_channels) 168 | # 1x1 conv cross channel 169 | self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True) 170 | # different padding way 171 | self.causal = causal 172 | self.dconv_pad = dconv_pad 173 | 174 | def forward(self, x): 175 | y = self.conv1x1(x) 176 | y = 
self.lnorm1(self.prelu1(y)) 177 | y = self.dconv(y) 178 | if self.causal: 179 | y = y[:, :, :-self.dconv_pad] 180 | y = self.lnorm2(self.prelu2(y)) 181 | y = self.sconv(y) 182 | x = x + y 183 | return x 184 | 185 | 186 | class Model(nn.Module): 187 | def __init__(self, 188 | L=20, 189 | N=256, 190 | X=8, 191 | R=4, 192 | B=256, 193 | H=512, 194 | P=3, 195 | norm="cLN", 196 | num_spks=1, 197 | non_linear="relu", 198 | causal=False): 199 | super(Model, self).__init__() 200 | supported_nonlinear = { 201 | "relu": F.relu, 202 | "sigmoid": th.sigmoid, 203 | "softmax": F.softmax 204 | } 205 | if non_linear not in supported_nonlinear: 206 | raise RuntimeError("Unsupported non-linear function: {}", 207 | format(non_linear)) 208 | self.non_linear_type = non_linear 209 | self.non_linear = supported_nonlinear[non_linear] 210 | # n x S => n x N x T, S = 4s*8000 = 32000 211 | self.encoder_1d = Conv1D(1, N, L, stride=L // 2, padding=0) 212 | # keep T not change 213 | # T = int((xlen - L) / (L // 2)) + 1 214 | # before repeat blocks, always cLN 215 | self.ln = ChannelWiseLayerNorm(N) 216 | # n x N x T => n x B x T 217 | self.proj = Conv1D(N, B, 1) 218 | # repeat blocks 219 | # n x B x T => n x B x T 220 | self.repeats = self._build_repeats( 221 | R, 222 | X, 223 | in_channels=B, 224 | conv_channels=H, 225 | kernel_size=P, 226 | norm=norm, 227 | causal=causal) 228 | # output 1x1 conv 229 | # n x B x T => n x N x T 230 | # NOTE: using ModuleList not python list 231 | # self.conv1x1_2 = th.nn.ModuleList( 232 | # [Conv1D(B, N, 1) for _ in range(num_spks)]) 233 | # n x B x T => n x 2N x T 234 | self.mask = Conv1D(B, num_spks * N, 1) 235 | # using ConvTrans1D: n x N x T => n x 1 x To 236 | # To = (T - 1) * L // 2 + L 237 | self.decoder_1d = ConvTrans1D( 238 | N, 1, kernel_size=L, stride=L // 2, bias=True) 239 | self.num_spks = num_spks 240 | 241 | def _build_blocks(self, num_blocks, **block_kwargs): 242 | """ 243 | Build Conv1D block 244 | """ 245 | blocks = [ 246 | Conv1DBlock(**block_kwargs, dilation=(2**b)) 247 | for b in range(num_blocks) 248 | ] 249 | return nn.Sequential(*blocks) 250 | 251 | def _build_repeats(self, num_repeats, num_blocks, **block_kwargs): 252 | """ 253 | Build Conv1D block repeats 254 | """ 255 | repeats = [ 256 | self._build_blocks(num_blocks, **block_kwargs) 257 | for r in range(num_repeats) 258 | ] 259 | return nn.Sequential(*repeats) 260 | 261 | def forward(self, x): 262 | x = th.reshape(x, [x.shape[0], -1]) 263 | if x.dim() >= 3: 264 | raise RuntimeError( 265 | "{} accept 1/2D tensor as input, but got {:d}".format( 266 | self.__name__, x.dim())) 267 | # when inference, only one utt 268 | if x.dim() == 1: 269 | x = th.unsqueeze(x, 0) 270 | # n x 1 x S => n x N x T 271 | w = F.relu(self.encoder_1d(x)) 272 | # n x B x T 273 | y = self.proj(self.ln(w)) 274 | # n x B x T 275 | y = self.repeats(y) 276 | # n x 2N x T 277 | e = th.chunk(self.mask(y), self.num_spks, 1) 278 | # n x N x T 279 | if self.non_linear_type == "softmax": 280 | m = self.non_linear(th.stack(e, dim=0), dim=0) 281 | else: 282 | m = self.non_linear(th.stack(e, dim=0)) 283 | # spks x [n x N x T] 284 | s = [w * m[n] for n in range(self.num_spks)] 285 | # spks x n x S 286 | out = th.stack([(self.decoder_1d(x, squeeze=True)) for x in s], axis=1) 287 | # print(out.shape) 288 | return out 289 | # [batch, num] 290 | # return [self.decoder_1d(x, squeeze=True) for x in s] 291 | 292 | 293 | 294 | def foo_conv1d_block(): 295 | nnet = Conv1DBlock(256, 512, 3, 20) 296 | print(param(nnet)) 297 | 298 | 299 | def foo_layernorm(): 300 
| C, T = 256, 20 301 | nnet1 = nn.LayerNorm([C, T], elementwise_affine=True) 302 | print(param(nnet1, Mb=False)) 303 | nnet2 = nn.LayerNorm([C, T], elementwise_affine=False) 304 | print(param(nnet2, Mb=False)) 305 | 306 | 307 | def foo_conv_tas_net(): 308 | x = th.rand(4, 16000) 309 | nnet = Model(norm="cLN", causal=False) 310 | # print(nnet) 311 | print("ConvTasNet #param: {:.2f}".format(param(nnet))) 312 | x = nnet(x) 313 | print(x.shape) 314 | 315 | if __name__ == "__main__": 316 | foo_conv_tas_net() 317 | # foo_conv1d_block() 318 | # foo_layernorm() 319 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def mse_loss(): 4 | return torch.nn.MSELoss() 5 | 6 | def l1_loss(): 7 | return torch.nn.L1Loss() 8 | 9 | def bce_loss(): 10 | return torch.nn.BCEWithLogitsLoss() # output 0~1 -------------------------------------------------------------------------------- /model/unet_basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DownSamplingLayer(nn.Module): 7 | def __init__(self, channel_in, channel_out, dilation=1, kernel_size=15, stride=1, padding=7): 8 | super(DownSamplingLayer, self).__init__() 9 | self.main = nn.Sequential( 10 | nn.Conv1d(channel_in, channel_out, kernel_size=kernel_size, 11 | stride=stride, padding=padding, dilation=dilation), 12 | nn.BatchNorm1d(channel_out), 13 | nn.LeakyReLU(negative_slope=0.1) 14 | ) 15 | 16 | def forward(self, ipt): 17 | return self.main(ipt) 18 | 19 | class UpSamplingLayer(nn.Module): 20 | def __init__(self, channel_in, channel_out, kernel_size=5, stride=1, padding=2): 21 | super(UpSamplingLayer, self).__init__() 22 | self.main = nn.Sequential( 23 | nn.Conv1d(channel_in, channel_out, kernel_size=kernel_size, 24 | stride=stride, padding=padding), 25 | nn.BatchNorm1d(channel_out), 26 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 27 | ) 28 | 29 | def forward(self, ipt): 30 | return self.main(ipt) 31 | 32 | class Model(nn.Module): 33 | 34 | def __init__(self, n_layers=12, channels_interval=24): 35 | super(Model, self).__init__() 36 | 37 | self.n_layers = n_layers 38 | self.channels_interval = channels_interval 39 | encoder_in_channels_list = [1] + [i * self.channels_interval for i in range(1, self.n_layers)] 40 | encoder_out_channels_list = [i * self.channels_interval for i in range(1, self.n_layers + 1)] 41 | 42 | # 1 => 2 => 3 => 4 => 5 => 6 => 7 => 8 => 9 => 10 => 11 =>12 43 | # 16384 => 8192 => 4096 => 2048 => 1024 => 512 => 256 => 128 => 64 => 32 => 16 => 8 => 4 44 | self.encoder = nn.ModuleList() 45 | for i in range(self.n_layers): 46 | self.encoder.append( 47 | DownSamplingLayer( 48 | channel_in=encoder_in_channels_list[i], 49 | channel_out=encoder_out_channels_list[i] 50 | ) 51 | ) 52 | 53 | self.middle = nn.Sequential( 54 | nn.Conv1d(self.n_layers * self.channels_interval, self.n_layers * self.channels_interval, 15, stride=1, 55 | padding=7), 56 | nn.BatchNorm1d(self.n_layers * self.channels_interval), 57 | nn.LeakyReLU(negative_slope=0.1, inplace=True) 58 | ) 59 | 60 | decoder_in_channels_list = [(2 * i + 1) * self.channels_interval for i in range(1, self.n_layers)] + [ 61 | 2 * self.n_layers * self.channels_interval] 62 | decoder_in_channels_list = decoder_in_channels_list[::-1] 63 | decoder_out_channels_list = encoder_out_channels_list[::-1] 64 | 
self.decoder = nn.ModuleList() 65 | for i in range(self.n_layers): 66 | self.decoder.append( 67 | UpSamplingLayer( 68 | channel_in=decoder_in_channels_list[i], 69 | channel_out=decoder_out_channels_list[i] 70 | ) 71 | ) 72 | 73 | self.out = nn.Sequential( 74 | nn.Conv1d(1 + self.channels_interval, 1, kernel_size=1, stride=1), 75 | nn.Tanh() 76 | ) 77 | 78 | def forward(self, input): 79 | tmp = [] 80 | o = input 81 | 82 | # Up Sampling 83 | for i in range(self.n_layers): 84 | o = self.encoder[i](o) 85 | tmp.append(o) 86 | # [batch_size, T // 2, channels] 87 | o = o[:, :, ::2] 88 | 89 | o = self.middle(o) 90 | 91 | # Down Sampling 92 | for i in range(self.n_layers): 93 | # [batch_size, T * 2, channels] 94 | o = F.interpolate(o, scale_factor=2, mode="linear", align_corners=True) 95 | # Skip Connection 96 | o = torch.cat([o, tmp[self.n_layers - i - 1]], dim=1) 97 | o = self.decoder[i](o) 98 | 99 | o = torch.cat([o, input], dim=1) 100 | o = self.out(o) 101 | return o 102 | 103 | 104 | # n_layers = 12, channels_interval = 24 105 | # UpSamplingLayer(288 + 288, 288), 106 | # UpSamplingLayer(264 + 288, 264), # 同水平层的降采样后维度为 264 107 | # UpSamplingLayer(240 + 264, 240), 108 | # 109 | # UpSamplingLayer(216 + 240, 216), 110 | # UpSamplingLayer(192 + 216, 192), 111 | # UpSamplingLayer(168 + 192, 168), 112 | # 113 | # UpSamplingLayer(144 + 168, 144), 114 | # UpSamplingLayer(120 + 144, 120), 115 | # UpSamplingLayer(96 + 120, 96), 116 | # 117 | # UpSamplingLayer(72 + 96, 72), 118 | # UpSamplingLayer(48 + 72, 48), 119 | # UpSamplingLayer(24 + 48, 24), 120 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhiyuyu/Wave-U-Net-for-SpeechEnhancement/308a2be9d91d931f09e873b19c3d67315b9b5962/trainer/__init__.py -------------------------------------------------------------------------------- /trainer/base_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | 4 | import json5 5 | import numpy as np 6 | import torch 7 | from torch.optim.lr_scheduler import StepLR 8 | from util import visualization 9 | from util.utils import prepare_empty_dir, ExecutionTime 10 | 11 | class BaseTrainer: 12 | def __init__(self, config, resume: bool, model, loss_function, optimizer): 13 | self.n_gpu = torch.cuda.device_count() 14 | self.device = self._prepare_device(self.n_gpu, cudnn_deterministic=config["cudnn_deterministic"]) 15 | 16 | self.optimizer = optimizer 17 | self.loss_function = loss_function 18 | 19 | self.model = model.to(self.device) 20 | 21 | if self.n_gpu > 1: 22 | self.model = torch.nn.DataParallel(self.model, device_ids=list(range(self.n_gpu))) 23 | 24 | # Trainer 25 | self.epochs = config["trainer"]["epochs"] 26 | self.save_checkpoint_interval = config["trainer"]["save_checkpoint_interval"] 27 | self.validation_config = config["trainer"]["validation"] 28 | self.validation_interval = self.validation_config["interval"] 29 | self.find_max = self.validation_config["find_max"] 30 | self.validation_custom_config = self.validation_config["custom"] 31 | 32 | # The following args is not in the config file, We will update it if resume is True in later. 
33 | self.start_epoch = 1 34 | self.best_score = -np.inf if self.find_max else np.inf 35 | self.root_dir = Path(config["root_dir"]).expanduser().absolute() / config["experiment_name"] 36 | self.checkpoints_dir = self.root_dir / "checkpoints" 37 | self.logs_dir = self.root_dir / "logs" 38 | prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume=resume) 39 | 40 | self.writer = visualization.writer(self.logs_dir.as_posix()) 41 | self.writer.add_text( 42 | tag="Configuration", 43 | text_string=f"
<pre>  \n{json5.dumps(config, indent=4, sort_keys=False)}  \n</pre>
", 44 | global_step=1 45 | ) 46 | 47 | if resume: self._resume_checkpoint() 48 | 49 | print("Configurations are as follows: ") 50 | print(json5.dumps(config, indent=2, sort_keys=False)) 51 | 52 | with open((self.root_dir / f"{time.strftime('%Y-%m-%d-%H-%M-%S')}.json").as_posix(), "w") as handle: 53 | json5.dump(config, handle, indent=2, sort_keys=False) 54 | 55 | self._print_networks([self.model]) 56 | 57 | def _resume_checkpoint(self): 58 | """Resume experiment from latest checkpoint. 59 | Notes: 60 | To be careful at Loading model. if model is an instance of DataParallel, we need to set model.module.* 61 | """ 62 | latest_model_path = self.checkpoints_dir.expanduser().absolute() / "latest_model.tar" 63 | assert latest_model_path.exists(), f"{latest_model_path} does not exist, can not load latest checkpoint." 64 | 65 | checkpoint = torch.load(latest_model_path.as_posix(), map_location=self.device) 66 | 67 | self.start_epoch = checkpoint["epoch"] + 1 68 | self.best_score = checkpoint["best_score"] 69 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 70 | 71 | if isinstance(self.model, torch.nn.DataParallel): 72 | self.model.module.load_state_dict(checkpoint["model"]) 73 | else: 74 | self.model.load_state_dict(checkpoint["model"]) 75 | 76 | print(f"Model checkpoint loaded. Training will begin in {self.start_epoch} epoch.") 77 | 78 | def _save_checkpoint(self, epoch, is_best=False): 79 | """Save checkpoint to /checkpoints directory, which contains: 80 | - current epoch 81 | - best score in history 82 | - optimizer parameters 83 | - model parameters 84 | Args: 85 | is_best(bool): if current checkpoint got the best score, it also will be saved in /checkpoints/best_model.tar. 86 | """ 87 | print(f"\t Saving {epoch} epoch model checkpoint...") 88 | 89 | # Construct checkpoint tar package 90 | state_dict = { 91 | "epoch": epoch, 92 | "best_score": self.best_score, 93 | "optimizer": self.optimizer.state_dict() 94 | } 95 | 96 | if isinstance(self.model, torch.nn.DataParallel): # Parallel 97 | state_dict["model"] = self.model.module.cpu().state_dict() 98 | else: 99 | state_dict["model"] = self.model.cpu().state_dict() 100 | 101 | """ 102 | Notes: 103 | - latest_model.tar: 104 | Contains all checkpoint information, including optimizer parameters, model parameters, etc. New checkpoint will overwrite old one. 105 | - model_.pth: 106 | The parameters of the model. Follow-up we can specify epoch to inference. 107 | - best_model.tar: 108 | Like latest_model, but only saved when is True. 109 | """ 110 | torch.save(state_dict, (self.checkpoints_dir / "latest_model.tar").as_posix()) 111 | torch.save(state_dict["model"], (self.checkpoints_dir / f"model_{str(epoch).zfill(4)}.pth").as_posix()) 112 | if is_best: 113 | print(f"\t Found best score in {epoch} epoch, saving...") 114 | torch.save(state_dict, (self.checkpoints_dir / "best_model.tar").as_posix()) 115 | 116 | # Use model.cpu() or model.to("cpu") will migrate the model to CPU, at which point we need re-migrate model back. 117 | # No matter tensor.cuda() or tensor.to("cuda"), if tensor in CPU, the tensor will not be migrated to GPU, but the model will. 118 | self.model.to(self.device) 119 | 120 | @staticmethod 121 | def _prepare_device(n_gpu: int, cudnn_deterministic=False): 122 | """Choose to use CPU or GPU depend on "n_gpu". 123 | Args: 124 | n_gpu(int): the number of GPUs used in the experiment. 125 | if n_gpu is 0, use CPU; 126 | if n_gpu > 1, use GPU. 
127 | cudnn_deterministic (bool): repeatability 128 | cudnn.benchmark will find algorithms to optimize training. if we need to consider the repeatability of experiment, set use_cudnn_deterministic to True 129 | """ 130 | if n_gpu == 0: 131 | print("Using CPU in the experiment.") 132 | device = torch.device("cpu") 133 | else: 134 | if cudnn_deterministic: 135 | print("Using CuDNN deterministic mode in the experiment.") 136 | torch.backends.cudnn.deterministic = True 137 | torch.backends.cudnn.benchmark = False 138 | 139 | device = torch.device("cuda:0") 140 | 141 | return device 142 | 143 | def _is_best(self, score, find_max=True): 144 | """Check if the current model is the best model 145 | """ 146 | if find_max and score >= self.best_score: 147 | self.best_score = score 148 | return True 149 | elif not find_max and score <= self.best_score: 150 | self.best_score = score 151 | return True 152 | else: 153 | return False 154 | 155 | @staticmethod 156 | def _transform_pesq_range(pesq_score): 157 | """transform [-0.5 ~ 4.5] to [0 ~ 1] 158 | """ 159 | return (pesq_score + 0.5) / 5 160 | 161 | @staticmethod 162 | def _print_networks(nets: list): 163 | print(f"This project contains {len(nets)} networks, the number of the parameters: ") 164 | params_of_all_networks = 0 165 | for i, net in enumerate(nets, start=1): 166 | params_of_network = 0 167 | for param in net.parameters(): 168 | params_of_network += param.numel() 169 | 170 | print(f"\tNetwork {i}: {params_of_network / 1e6} million.") 171 | params_of_all_networks += params_of_network 172 | 173 | print(f"The amount of parameters in the project is {params_of_all_networks / 1e6} million.") 174 | 175 | def _set_models_to_train_mode(self): 176 | self.model.train() 177 | 178 | def _set_models_to_eval_mode(self): 179 | self.model.eval() 180 | 181 | def train(self): 182 | for epoch in range(self.start_epoch, self.epochs + 1): 183 | print(f"============== {epoch} epoch ==============") 184 | print("[0 seconds] Begin training...") 185 | timer = ExecutionTime() 186 | 187 | self._set_models_to_train_mode() 188 | self._train_epoch(epoch) 189 | 190 | if self.save_checkpoint_interval != 0 and (epoch % self.save_checkpoint_interval == 0): 191 | self._save_checkpoint(epoch) 192 | 193 | if self.validation_interval != 0 and epoch % self.validation_interval == 0: 194 | print(f"[{timer.duration()} seconds] Training is over, Validation is in progress...") 195 | 196 | self._set_models_to_eval_mode() 197 | score = self._validation_epoch(epoch) 198 | 199 | if self._is_best(score, find_max=self.find_max): 200 | self._save_checkpoint(epoch, is_best=True) 201 | 202 | print(f"[{timer.duration()} seconds] End this epoch.") 203 | 204 | def _train_epoch(self, epoch): 205 | raise NotImplementedError 206 | 207 | def _validation_epoch(self, epoch): 208 | raise NotImplementedError -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import librosa.display 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | 7 | from trainer.base_trainer import BaseTrainer 8 | from util.utils import compute_STOI, compute_PESQ 9 | plt.switch_backend('agg') 10 | 11 | 12 | class Trainer(BaseTrainer): 13 | def __init__( 14 | self, 15 | config, 16 | resume: bool, 17 | model, 18 | loss_function, 19 | optimizer, 20 | train_dataloader, 21 | validation_dataloader, 22 | ): 23 | super(Trainer, self).__init__(config, resume, model, 
loss_function, optimizer) 24 | self.train_data_loader = train_dataloader 25 | self.validation_data_loader = validation_dataloader 26 | 27 | def _train_epoch(self, epoch): 28 | loss_total = 0.0 29 | 30 | for i, (mixture, clean, name) in enumerate(self.train_data_loader): 31 | mixture = mixture.to(self.device) 32 | clean = clean.to(self.device) 33 | 34 | self.optimizer.zero_grad() 35 | enhanced = self.model(mixture) 36 | loss = self.loss_function(clean, enhanced) 37 | loss.backward() 38 | self.optimizer.step() 39 | 40 | loss_total += loss.item() 41 | 42 | dl_len = len(self.train_data_loader) 43 | self.writer.add_scalar(f"Train/Loss", loss_total / dl_len, epoch) 44 | 45 | @torch.no_grad() 46 | def _validation_epoch(self, epoch): 47 | visualize_audio_limit = self.validation_custom_config["visualize_audio_limit"] 48 | visualize_waveform_limit = self.validation_custom_config["visualize_waveform_limit"] 49 | visualize_spectrogram_limit = self.validation_custom_config["visualize_spectrogram_limit"] 50 | 51 | sample_length = self.validation_custom_config["sample_length"] 52 | 53 | stoi_c_n = [] # clean and noisy 54 | stoi_c_d = [] # clean and denoisy 55 | pesq_c_n = [] 56 | pesq_c_d = [] 57 | 58 | for i, (mixture, clean, name) in enumerate(self.validation_data_loader): 59 | assert len(name) == 1, "Only support batch size is 1 in enhancement stage." 60 | name = name[0] 61 | 62 | # [1, 1, T] 63 | mixture = mixture.to(self.device) 64 | clean = clean.to(self.device) 65 | 66 | # Input of model should fixed length 67 | mixture_chunks = torch.split(mixture, sample_length, dim=2) 68 | if mixture_chunks[-1].shape[-1] != sample_length: 69 | mixture_chunks = mixture_chunks[:-1] 70 | 71 | enhanced_chunks = [] 72 | for chunk in mixture_chunks: 73 | enhanced_chunks.append(self.model(chunk).detach().cpu()) 74 | 75 | enhanced = torch.cat(enhanced_chunks, dim=2) 76 | 77 | # Back to numpy array 78 | mixture = mixture.cpu().numpy().reshape(-1) 79 | enhanced = enhanced.numpy().reshape(-1) 80 | clean = clean.cpu().numpy().reshape(-1) 81 | 82 | min_len = min(len(mixture), len(clean), len(enhanced)) 83 | 84 | mixture = mixture[:min_len] 85 | clean = clean[:min_len] 86 | enhanced = enhanced[:min_len] 87 | 88 | # Visualize audio 89 | if i <= visualize_audio_limit: 90 | self.writer.add_audio(f"Speech/{name}_Noisy", mixture, epoch, sample_rate=16000) 91 | self.writer.add_audio(f"Speech/{name}_Enhanced", enhanced, epoch, sample_rate=16000) 92 | self.writer.add_audio(f"Speech/{name}_Clean", clean, epoch, sample_rate=16000) 93 | 94 | # Visualize waveform 95 | if i <= visualize_waveform_limit: 96 | fig, ax = plt.subplots(3, 1) 97 | for j, y in enumerate([mixture, enhanced, clean]): 98 | ax[j].set_title("mean: {:.3f}, std: {:.3f}, max: {:.3f}, min: {:.3f}".format( 99 | np.mean(y), 100 | np.std(y), 101 | np.max(y), 102 | np.min(y) 103 | )) 104 | librosa.display.waveplot(y, sr=16000, ax=ax[j]) 105 | plt.tight_layout() 106 | self.writer.add_figure(f"Waveform/{name}", fig, epoch) 107 | 108 | # Visualize spectrogram 109 | noisy_mag, _ = librosa.magphase(librosa.stft(mixture, n_fft=320, hop_length=160, win_length=320)) 110 | enhanced_mag, _ = librosa.magphase(librosa.stft(enhanced, n_fft=320, hop_length=160, win_length=320)) 111 | clean_mag, _ = librosa.magphase(librosa.stft(clean, n_fft=320, hop_length=160, win_length=320)) 112 | 113 | if i <= visualize_spectrogram_limit: 114 | fig, axes = plt.subplots(3, 1, figsize=(6, 6)) 115 | for k, mag in enumerate([ 116 | noisy_mag, 117 | enhanced_mag, 118 | clean_mag, 119 | ]): 120 | 
axes[k].set_title(f"mean: {np.mean(mag):.3f}, " 121 | f"std: {np.std(mag):.3f}, " 122 | f"max: {np.max(mag):.3f}, " 123 | f"min: {np.min(mag):.3f}") 124 | librosa.display.specshow(librosa.amplitude_to_db(mag), cmap="magma", y_axis="linear", ax=axes[k], sr=16000) 125 | plt.tight_layout() 126 | self.writer.add_figure(f"Spectrogram/{name}", fig, epoch) 127 | 128 | # Metric 129 | stoi_c_n.append(compute_STOI(clean, mixture, sr=16000)) 130 | stoi_c_d.append(compute_STOI(clean, enhanced, sr=16000)) 131 | pesq_c_n.append(compute_PESQ(clean, mixture, sr=16000)) 132 | pesq_c_d.append(compute_PESQ(clean, enhanced, sr=16000)) 133 | 134 | get_metrics_ave = lambda metrics: np.sum(metrics) / len(metrics) 135 | self.writer.add_scalars(f"评价指标均值/STOI", { 136 | "clean 与 noisy": get_metrics_ave(stoi_c_n), 137 | "clean 与 denoisy": get_metrics_ave(stoi_c_d) 138 | }, epoch) 139 | self.writer.add_scalars(f"评价指标均值/PESQ", { 140 | "clean 与 noisy": get_metrics_ave(pesq_c_n), 141 | "clean 与 denoisy": get_metrics_ave(pesq_c_d) 142 | }, epoch) 143 | 144 | score = (get_metrics_ave(stoi_c_d) + self._transform_pesq_range(get_metrics_ave(pesq_c_d))) / 2 145 | return score 146 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wzhiyuyu/Wave-U-Net-for-SpeechEnhancement/308a2be9d91d931f09e873b19c3d67315b9b5962/util/__init__.py -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import time 3 | import os 4 | 5 | import torch 6 | from pesq import pesq 7 | import numpy as np 8 | from pystoi.stoi import stoi 9 | 10 | 11 | def load_checkpoint(checkpoint_path, device): 12 | _, ext = os.path.splitext(os.path.basename(checkpoint_path)) 13 | assert ext in (".pth", ".tar"), "Only support ext and tar extensions of model checkpoint." 14 | model_checkpoint = torch.load(checkpoint_path, map_location=device) 15 | 16 | if ext == ".pth": 17 | print(f"Loading {checkpoint_path}.") 18 | return model_checkpoint 19 | else: # tar 20 | print(f"Loading {checkpoint_path}, epoch = {model_checkpoint['epoch']}.") 21 | return model_checkpoint["model"] 22 | 23 | 24 | def prepare_empty_dir(dirs, resume=False): 25 | """ 26 | if resume experiment, assert the dirs exist, 27 | if not resume experiment, make dirs. 28 | 29 | Args: 30 | dirs (list): directors list 31 | resume (bool): whether to resume experiment, default is False 32 | """ 33 | for dir_path in dirs: 34 | if resume: 35 | assert dir_path.exists() 36 | else: 37 | dir_path.mkdir(parents=True, exist_ok=True) 38 | 39 | 40 | class ExecutionTime: 41 | """ 42 | Usage: 43 | timer = ExecutionTime() 44 | 45 | print(f'Finished in {timer.duration()} seconds.') 46 | """ 47 | 48 | def __init__(self): 49 | self.start_time = time.time() 50 | 51 | def duration(self): 52 | return int(time.time() - self.start_time) 53 | 54 | 55 | def initialize_config(module_cfg, pass_args=True): 56 | """ 57 | According to config items, load specific module dynamically with params. 58 | eg,config items as follow: 59 | module_cfg = { 60 | "module": "model.model", 61 | "main": "Model", 62 | "args": {...} 63 | } 64 | 1. Load the module corresponding to the "module" param. 65 | 2. Call function (or instantiate class) corresponding to the "main" param. 66 | 3. 
Send the param (in "args") into the function (or class) when calling ( or instantiating) 67 | """ 68 | module = importlib.import_module(module_cfg["module"]) 69 | 70 | if pass_args: 71 | return getattr(module, module_cfg["main"])(**module_cfg["args"]) 72 | else: 73 | return getattr(module, module_cfg["main"]) 74 | 75 | 76 | 77 | def compute_PESQ(clean_signal, noisy_signal, sr=16000): 78 | return pesq(sr, clean_signal, noisy_signal, "wb") 79 | 80 | 81 | def z_score(m): 82 | mean = np.mean(m) 83 | std_var = np.std(m) 84 | return (m - mean) / std_var, mean, std_var 85 | 86 | 87 | def reverse_z_score(m, mean, std_var): 88 | return m * std_var + mean 89 | 90 | 91 | def min_max(m): 92 | m_max = np.max(m) 93 | m_min = np.min(m) 94 | 95 | return (m - m_min) / (m_max - m_min), m_max, m_min 96 | 97 | 98 | def reverse_min_max(m, m_max, m_min): 99 | return m * (m_max - m_min) + m_min 100 | 101 | 102 | def sample_fixed_length_data_aligned(data_a, data_b, sample_length): 103 | """ 104 | sample with fixed length from two dataset 105 | """ 106 | assert len(data_a) == len(data_b), "Inconsistent dataset length, unable to sampling" 107 | assert len(data_a) >= sample_length, f"len(data_a) is {len(data_a)}, sample_length is {sample_length}." 108 | 109 | frames_total = len(data_a) 110 | 111 | start = np.random.randint(frames_total - sample_length + 1) 112 | # print(f"Random crop from: {start}") 113 | end = start + sample_length 114 | 115 | return data_a[start:end], data_b[start:end] 116 | 117 | 118 | def compute_STOI(clean_signal, noisy_signal, sr=16000): 119 | return stoi(clean_signal, noisy_signal, sr, extended=False) 120 | 121 | 122 | def print_tensor_info(tensor, flag="Tensor"): 123 | floor_tensor = lambda float_tensor: int(float(float_tensor) * 1000) / 1000 124 | print(flag) 125 | print( 126 | f"\tmax: {floor_tensor(torch.max(tensor))}, min: {float(torch.min(tensor))}, mean: {floor_tensor(torch.mean(tensor))}, std: {floor_tensor(torch.std(tensor))}") 127 | -------------------------------------------------------------------------------- /util/visualization.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | 3 | 4 | def writer(logs_dir): 5 | return SummaryWriter(log_dir=logs_dir, max_queue=5, flush_secs=30) --------------------------------------------------------------------------------
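
Note that `train.py` is referenced throughout the README but does not appear in the directory tree above. The following is a minimal, hypothetical sketch of such an entry script, wired only to interfaces that are present in this dump (`util.utils.initialize_config`, `trainer.trainer.Trainer`, and the keys of `config/train/train.json`); the flag names and the way the experiment name is derived are assumptions, not the author's actual code:

```python
import argparse
import os

import json5
import torch
from torch.utils.data import DataLoader

from trainer.trainer import Trainer
from util.utils import initialize_config


def main(config, resume):
    # Seed for reproducibility; BaseTrainer additionally handles cudnn_deterministic.
    torch.manual_seed(config["seed"])

    # Dynamically build model, loss, and datasets from the config, as enhancement.py does.
    model = initialize_config(config["model"])
    loss_function = initialize_config(config["loss_function"])

    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=config["optimizer"]["lr"],
        betas=(config["optimizer"]["beta1"], config["optimizer"]["beta2"])
    )

    train_dataloader = DataLoader(
        dataset=initialize_config(config["train_dataset"]),
        **config["train_dataloader"]
    )
    # Validation returns full-length utterances, so Trainer._validation_epoch expects batch size 1.
    validation_dataloader = DataLoader(
        dataset=initialize_config(config["validation_dataset"]),
        batch_size=1,
        num_workers=1
    )

    trainer = Trainer(
        config=config,
        resume=resume,
        model=model,
        loss_function=loss_function,
        optimizer=optimizer,
        train_dataloader=train_dataloader,
        validation_dataloader=validation_dataloader
    )
    trainer.train()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Wave-U-Net for Speech Enhancement")
    parser.add_argument("-C", "--config", required=True, type=str, help="Training configuration file (*.json).")
    parser.add_argument("-R", "--resume", action="store_true", help="Resume from the latest checkpoint.")
    args = parser.parse_args()

    # json5 tolerates the comments used in the annotated configuration examples in the README.
    with open(args.config) as f:
        config = json5.load(f)

    # BaseTrainer expects config["experiment_name"]; per the README, the config file name is the experiment name.
    config["experiment_name"] = os.path.splitext(os.path.basename(args.config))[0]

    main(config, resume=args.resume)
```

Under these assumptions, `python train.py -C config/train/train.json` (optionally with `-R`) would reproduce the training workflow described in the README.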