134 |
135 | # For example, if the "root_dir" parameter in the config file is "/home/happy/Experiments", the config file is named "train_config.json", and the default port is changed to 6000,
136 | # the following command can be used:
137 | tensorboard --logdir /home/happy/Experiments/train_config --port 6000
138 | ```
139 |
140 | ## Directory Overview
141 |
142 | While the project runs, several directories are created, each serving a different purpose:
143 |
144 | - Main directory: the directory containing this README.md, which stores all of the source code
145 | - Training directory: the `config["root_dir"]` directory from the training config, which stores all experiment logs and model checkpoints of the current project
146 | - Experiment directory: the `config["root_dir"]/<experiment name>/` directory, which stores the logs of a single experiment
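
For example, with `root_dir` set to `~/Experiments/Wave-U-Net` and an experiment named `unet_basic` (a hypothetical name; in practice it is derived from the config file name), the trainer in `trainer/base_trainer.py` produces a layout like this:

```text
~/Experiments/Wave-U-Net/unet_basic/
├── checkpoints/
│   ├── latest_model.tar        # full training state, overwritten at every save
│   ├── best_model.tar          # saved whenever the validation score is the best so far
│   └── model_0010.pth          # model weights only, one file per saved epoch
├── logs/                       # TensorBoard event files
└── 2019-01-01-00-00-00.json    # snapshot of the config used for this run
```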
147 |
148 |
149 | ## Parameter Description
150 | ### Training
151 |
152 | `config/train/<experiment name>.json`. The logs produced during training are stored in the `config["root_dir"]/<experiment name>/` directory.
153 |
154 | ```json5
155 | {
156 |     "seed": 0, // Random seed to ensure experiment reproducibility
157 |     "description": "...", // Description of the experiment, shown later in TensorBoard
158 |     "root_dir": "~/Experiments/Wave-U-Net", // Directory where experiment results are stored
159 |     "cudnn_deterministic": false, // When true, cuDNN runs deterministically for reproducibility (may be slower)
160 |     "trainer": { // Training loop settings
161 |         "module": "trainer.trainer", // File containing the trainer class
162 |         "main": "Trainer", // The trainer class within that file
163 |         "epochs": 1200, // Upper limit on the number of training epochs
164 |         "save_checkpoint_interval": 10, // Interval (in epochs) between checkpoint saves
165 |         "validation": {
166 |             "interval": 10, // Interval (in epochs) between validations
167 |             "find_max": true, // When true and the computed metric is the best seen so far, the current epoch's checkpoint is additionally cached as the best model
168 |             "custom": {
169 |                 "visualize_audio_limit": 20, // Upper limit on audio clips visualized during validation; set because visualizing audio is slow
170 |                 "visualize_waveform_limit": 20, // Upper limit on waveforms visualized during validation; set because visualizing waveforms is slow
171 |                 "visualize_spectrogram_limit": 20, // Upper limit on spectrograms visualized during validation; set because visualizing spectrograms is slow
172 |                 "sample_length": 16384 // Number of samples per fixed-length chunk
173 |             }
174 |         }
175 | },
176 | "model": {
177 |         "module": "model.unet_basic", // File containing the model used for training
178 |         "main": "Model", // The model class within that file
179 |         "args": {} // Arguments passed to the model class
180 | },
181 | "loss_function": {
182 |         "module": "model.loss", // File containing the loss function
183 |         "main": "mse_loss", // The loss function within that file
184 |         "args": {} // Arguments passed to the loss function
185 | },
186 | "optimizer": {
187 | "lr": 0.001,
188 | "beta1": 0.9,
189 |         "beta2": 0.999
190 | },
191 | "train_dataset": {
192 |         "module": "dataset.waveform_dataset", // File containing the training dataset class
193 |         "main": "Dataset", // The dataset class within that file
194 |         "args": { // Arguments passed to the dataset class; see the class itself for details
195 | "dataset": "~/Datasets/SEGAN_Dataset/train_dataset.txt",
196 | "limit": null,
197 | "offset": 0,
198 | "sample_length": 16384,
199 | "mode":"train"
200 | }
201 | },
202 | "validation_dataset": {
203 | "module": "dataset.waveform_dataset",
204 | "main": "Dataset",
205 | "args": {
206 | "dataset": "~/Datasets/SEGAN_Dataset/test_dataset.txt",
207 | "limit": 400,
208 | "offset": 0,
209 | "mode":"validation"
210 | }
211 | },
212 | "train_dataloader": {
213 | "batch_size": 120,
214 |         "num_workers": 40, // Number of worker processes used to preprocess the data
215 | "shuffle": true,
216 | "pin_memory":true
217 | }
218 | }
219 | ```
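
Each `module` / `main` / `args` triplet above is resolved at runtime by `initialize_config` in `util/utils.py`: the module is imported dynamically, and the named class (or function) inside it is called with the configured arguments. The mechanism boils down to:

```python
import importlib

def initialize_config(module_cfg, pass_args=True):
    # e.g. module_cfg = {"module": "model.unet_basic", "main": "Model", "args": {}}
    module = importlib.import_module(module_cfg["module"])
    if pass_args:
        # Instantiate the class (or call the function) with the configured args
        return getattr(module, module_cfg["main"])(**module_cfg["args"])
    return getattr(module, module_cfg["main"])

# model = initialize_config(config["model"])                  # -> model.unet_basic.Model()
# loss_function = initialize_config(config["loss_function"])  # -> model.loss.mse_loss()
```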
220 |
221 | ### Enhancement
222 |
223 | `config/enhancement/*.json`
224 |
225 | ```json5
226 | {
227 | "model": {
228 |         "module": "model.unet_basic", // File containing the model
229 |         "main": "Model", // The model class within that file
230 |         "args": {} // Arguments passed to the model class
231 | },
232 | "dataset": {
233 |         "module": "dataset.waveform_dataset_enhancement", // File containing the dataset class used for enhancement
234 |         "main": "WaveformDataset", // The dataset class within that file; see the class itself for the meaning of its arguments
235 | "args": {
236 | "dataset": "/home/imucs/Datasets/2019-09-03-timit_train-900_test-50/enhancement.txt",
237 | "limit": 400,
238 | "offset": 0,
239 | "sample_length": 16384
240 | }
241 | }
242 | }
243 | ```
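
Given such a config, enhancement can be launched with `enhancement.py` (flags as defined by its argparse setup). The output directory must already exist; the output and checkpoint paths below are placeholders:

```bash
# -D selects the GPU via CUDA_VISIBLE_DEVICES; the default of -1 runs on the CPU
python enhancement.py \
    -C config/enhancement/unet_basic.json \
    -D 0 \
    -O /path/to/enhanced \
    -M /path/to/checkpoints/best_model.tar
```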
244 |
245 | During enhancement, the txt file that stores the dataset paths only needs to list the paths of the noisy speech, like this:
246 |
247 | ```text
248 | # enhancement.txt
249 |
250 | /home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Clean.wav
251 | /home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Enhanced_Inpainting_200.wav
252 | /home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Enhanced_Inpainting_270.wav
253 | /home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Enhanced_UNet.wav
254 | /home/imucs/tmp/UNet_and_Inpainting/0001_babble_-7dB_Mixture.wav
255 | ```
256 | ```
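
Note that during enhancement each utterance is split into fixed-length chunks of `sample_length` samples, and a trailing chunk shorter than `sample_length` is discarded, so the enhanced output can be slightly shorter than the input. The core of the loop in `enhancement.py`:

```python
import torch

# mixture: a [1, 1, T] tensor; sample_length as set in the config (e.g. 16384)
mixture_chunks = torch.split(mixture, sample_length, dim=2)
if mixture_chunks[-1].shape[-1] != sample_length:
    mixture_chunks = mixture_chunks[:-1]  # drop the trailing partial chunk

enhanced = torch.cat([model(chunk).detach().cpu() for chunk in mixture_chunks], dim=2)
```
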
257 | ## TODO
258 |
259 | - [x] Validate using full-length speech
260 | - [x] Enhancement script
261 | - [ ] Test script
262 |
--------------------------------------------------------------------------------
/config/enhancement/unet_basic.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": {
3 | "module": "model.unet_basic",
4 |     "main": "Model",
5 | "args": {}
6 | },
7 | "dataset": {
8 | "module": "dataset.waveform_dataset_enhancement",
9 | "main": "WaveformDataset",
10 | "args": {
11 | "dataset": "/home/imucs/tmp/UNet_and_Inpainting/data.txt",
12 | "limit": 400,
13 | "offset": 0,
14 | "sample_length": 16384
15 | }
16 | }
17 | }
--------------------------------------------------------------------------------
/config/train/train.json:
--------------------------------------------------------------------------------
1 | {
2 | "seed": 0,
3 | "description": "test",
4 | "root_dir": "E:/Experiments/Wave-U-Net",
5 | "cudnn_deterministic": false,
6 | "trainer": {
7 | "module": "trainer.trainer",
8 | "main": "Trainer",
9 | "epochs": 600,
10 | "save_checkpoint_interval": 10,
11 | "validation": {
12 | "interval": 10,
13 | "find_max": true,
14 | "custom": {
15 | "visualize_audio_limit": 20,
16 | "visualize_waveform_limit": 20,
17 | "visualize_spectrogram_limit": 20,
18 | "sample_length": 16000
19 | }
20 | }
21 | },
22 | "model": {
23 | "module": "model.conv_tas_net",
24 | "main": "Model",
25 | "args": {}
26 | },
27 | "loss_function": {
28 | "module": "model.loss",
29 | "main": "mse_loss",
30 | "args": {}
31 | },
32 | "optimizer": {
33 | "lr": 0.001,
34 | "beta1": 0.9,
35 | "beta2": 0.999
36 | },
37 | "train_dataset": {
38 | "module": "dataset.waveform_dataset",
39 | "main": "Dataset",
40 | "args": {
41 | "dataset": "E:/train_dataset.txt",
42 | "limit": null,
43 | "offset": 0,
44 | "sample_length": 16000,
45 | "mode": "train"
46 | }
47 | },
48 | "validation_dataset": {
49 | "module": "dataset.waveform_dataset",
50 | "main": "Dataset",
51 | "args": {
52 | "dataset": "E:/test_dataset.txt",
53 | "limit": 400,
54 | "offset": 0,
55 | "mode": "validation"
56 | }
57 | },
58 | "train_dataloader": {
59 | "batch_size": 4,
60 | "num_workers": 4,
61 | "shuffle": true,
62 | "pin_memory": true
63 | }
64 | }
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhiyuyu/Wave-U-Net-for-SpeechEnhancement/308a2be9d91d931f09e873b19c3d67315b9b5962/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/waveform_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils import data
3 | import librosa
4 | from util.utils import sample_fixed_length_data_aligned
5 |
6 |
7 | class Dataset(data.Dataset):
8 | def __init__(self, dataset, limit=None, offset=0, sample_length=16384, mode="train"):
9 | """
10 |         Construct the training dataset.
11 |         Args:
12 |             dataset (str): Path to the speech dataset list file (txt extension); see the Notes section
13 |             limit (int): Upper limit on the number of dataset entries
14 |             offset (int): Offset of the starting position in the dataset
15 |             sample_length(int): The model only supports fixed-length input; this sets the size of each model input
16 |             mode(str): When "train", the speech is cut into fixed-length segments; when "validation", the full-length speech is returned as-is.
17 |
18 |         Notes:
19 |             The format of the dataset list file is as follows:
20 |             <path of noisy speech 1><space><path of clean speech 1>
21 |             <path of noisy speech 2><space><path of clean speech 2>
22 |             ...
23 |             <path of noisy speech n><space><path of clean speech n>
24 |
25 | eg:
26 | /train/noisy/a.wav /train/clean/a.wav
27 | /train/noisy/b.wav /train/clean/b.wav
28 | ...
29 |
30 | Return:
31 | (mixture signals, clean signals, file name)
32 | """
33 | super(Dataset, self).__init__()
34 | dataset_list = [line.rstrip('\n') for line in open(os.path.abspath(os.path.expanduser(dataset)), "r")]
35 |
36 | dataset_list = dataset_list[offset:]
37 | if limit:
38 | dataset_list = dataset_list[:limit]
39 |
40 | assert mode in ("train", "validation"), "Mode must be one of train or validation."
41 |
42 | self.length = len(dataset_list)
43 | self.dataset_list = dataset_list
44 | self.sample_length = sample_length
45 | self.mode = mode
46 |
47 | def __len__(self):
48 | return self.length
49 |
50 | def __getitem__(self, item):
51 | mixture_path, clean_path = self.dataset_list[item].split(" ")
52 | name = os.path.splitext(os.path.basename(mixture_path))[0]
53 | mixture, _ = librosa.load(os.path.abspath(os.path.expanduser(mixture_path)), sr=None)
54 | clean, _ = librosa.load(os.path.abspath(os.path.expanduser(clean_path)), sr=None)
55 |
56 | if self.mode == "train":
57 | # The input of model should be fixed length.
58 | mixture, clean = sample_fixed_length_data_aligned(mixture, clean, self.sample_length)
59 | return mixture.reshape(1, -1), clean.reshape(1, -1), name
60 | else:
61 | return mixture.reshape(1, -1), clean.reshape(1, -1), name
62 |
--------------------------------------------------------------------------------
/dataset/waveform_dataset_enhancement.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.data import Dataset
3 | import librosa
4 |
5 |
6 | class WaveformDataset(Dataset):
7 | def __init__(self, dataset, limit=None, offset=0, sample_length=16384):
8 | """
9 |         Construct the enhancement dataset.
10 |         Args:
11 |             dataset (str): Path to the speech dataset list file (txt extension); see the Notes section
12 |             limit (int): Upper limit on the number of dataset entries
13 |             offset (int): Offset of the starting position in the dataset
14 |             sample_length(int): The model only supports fixed-length input; this sets the size of each model input
15 |
16 |         Notes:
17 |             The format of the dataset list file is as follows:
18 |             <path of noisy speech 1>
19 |             <path of noisy speech 2>
20 |             ...
21 |             <path of noisy speech n>
22 |
23 | eg:
24 | /enhancement/noisy/a.wav
25 | /enhancement/noisy/b.wav
26 | ...
27 |
28 | Return:
29 |             (mixture signal, file name)
30 | """
31 | super(WaveformDataset, self).__init__()
32 | dataset_list = [line.rstrip('\n') for line in open(os.path.abspath(os.path.expanduser(dataset)), "r")]
33 |
34 | dataset_list = dataset_list[offset:]
35 | if limit:
36 | dataset_list = dataset_list[:limit]
37 |
38 | self.length = len(dataset_list)
39 | self.dataset_list = dataset_list
40 | self.sample_length = sample_length
41 |
42 | def __len__(self):
43 | return self.length
44 |
45 | def __getitem__(self, item):
46 | mixture_path = self.dataset_list[item]
47 | name = os.path.splitext(os.path.basename(mixture_path))[0]
48 |
49 | mixture, _ = librosa.load(os.path.abspath(os.path.expanduser(mixture_path)), sr=None)
50 |
51 | return mixture.reshape(1, -1), name
52 |
--------------------------------------------------------------------------------
/doc/audio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhiyuyu/Wave-U-Net-for-SpeechEnhancement/308a2be9d91d931f09e873b19c3d67315b9b5962/doc/audio.png
--------------------------------------------------------------------------------
/doc/tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhiyuyu/Wave-U-Net-for-SpeechEnhancement/308a2be9d91d931f09e873b19c3d67315b9b5962/doc/tensorboard.png
--------------------------------------------------------------------------------
/enhancement.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import librosa
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 |
10 | from util.utils import initialize_config, load_checkpoint
11 |
12 | """
13 | Parameters
14 | """
15 | parser = argparse.ArgumentParser("Wave-U-Net: Speech Enhancement")
16 | parser.add_argument("-C", "--config", type=str, required=True, help="Model and dataset for enhancement (*.json).")
17 | parser.add_argument("-D", "--device", default="-1", type=str, help="GPU for speech enhancement. default: CPU")
18 | parser.add_argument("-O", "--output_dir", type=str, required=True, help="Where are audio save.")
19 | parser.add_argument("-M", "--model_checkpoint_path", type=str, required=True, help="Checkpoint.")
20 | args = parser.parse_args()
21 |
22 | """
23 | Preparation
24 | """
25 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device
26 | config = json.load(open(args.config))
27 | model_checkpoint_path = args.model_checkpoint_path
28 | output_dir = args.output_dir
29 | assert os.path.exists(output_dir), "The output directory must already exist."
30 |
31 | """
32 | DataLoader
33 | """
34 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
35 | dataloader = DataLoader(dataset=initialize_config(config["dataset"]), batch_size=1, num_workers=0)
36 |
37 | """
38 | Model
39 | """
40 | model = initialize_config(config["model"])
41 | model.load_state_dict(load_checkpoint(model_checkpoint_path, device))
42 | model.to(device)
43 | model.eval()
44 |
45 | """
46 | Enhancement
47 | """
48 | sample_length = dataloader.dataset.sample_length
49 | for mixture, name in tqdm(dataloader):
50 |     assert len(name) == 1, "Only a batch size of 1 is supported in the enhancement stage."
51 | name = name[0]
52 |
53 | mixture = mixture.to(device)
54 |     mixture_chunks = torch.split(mixture, sample_length, dim=2)
55 |     if mixture_chunks[-1].shape[-1] != sample_length:
56 |         mixture_chunks = mixture_chunks[:-1]  # drop the trailing partial chunk, so the output can be slightly shorter than the input
57 |
58 | enhance_chunks = []
59 | for chunk in mixture_chunks:
60 | enhance_chunks.append((model(chunk).detach().cpu()))
61 |
62 | enhanced = torch.cat(enhance_chunks, dim=2)
63 | enhanced = enhanced.numpy().reshape(-1)
64 |
65 | output_path = os.path.join(output_dir, f"{name}.wav")
66 |     librosa.output.write_wav(output_path, enhanced, sr=16000)  # librosa.output was removed in librosa 0.8; this call requires librosa < 0.8
67 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhiyuyu/Wave-U-Net-for-SpeechEnhancement/308a2be9d91d931f09e873b19c3d67315b9b5962/model/__init__.py
--------------------------------------------------------------------------------
/model/conv_tas_net.py:
--------------------------------------------------------------------------------
1 | # wujian@2018
2 |
3 | import torch as th
4 | import torch.nn as nn
5 |
6 | import torch.nn.functional as F
7 |
8 |
9 | def param(nnet, Mb=True):
10 | """
11 |     Return the number of parameters of nnet (in millions if Mb is True)
12 | """
13 | neles = sum([param.nelement() for param in nnet.parameters()])
14 | return neles / 10**6 if Mb else neles
15 |
16 |
17 | class ChannelWiseLayerNorm(nn.LayerNorm):
18 | """
19 | Channel wise layer normalization
20 | """
21 |
22 | def __init__(self, *args, **kwargs):
23 | super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs)
24 |
25 | def forward(self, x):
26 | """
27 | x: N x C x T
28 | """
29 | if x.dim() != 3:
30 |             raise RuntimeError("{} accepts 3D tensor as input".format(
31 |                 self.__class__.__name__))
32 | # N x C x T => N x T x C
33 | x = th.transpose(x, 1, 2)
34 | # LN
35 | x = super().forward(x)
36 |         # N x T x C => N x C x T
37 | x = th.transpose(x, 1, 2)
38 | return x
39 |
40 |
41 | class GlobalChannelLayerNorm(nn.Module):
42 | """
43 | Global channel layer normalization
44 | """
45 |
46 | def __init__(self, dim, eps=1e-05, elementwise_affine=True):
47 | super(GlobalChannelLayerNorm, self).__init__()
48 | self.eps = eps
49 | self.normalized_dim = dim
50 | self.elementwise_affine = elementwise_affine
51 | if elementwise_affine:
52 | self.beta = nn.Parameter(th.zeros(dim, 1))
53 | self.gamma = nn.Parameter(th.ones(dim, 1))
54 | else:
55 | self.register_parameter("weight", None)
56 | self.register_parameter("bias", None)
57 |
58 | def forward(self, x):
59 | """
60 | x: N x C x T
61 | """
62 | if x.dim() != 3:
63 |             raise RuntimeError("{} accepts 3D tensor as input".format(
64 |                 self.__class__.__name__))
65 | # N x 1 x 1
66 | mean = th.mean(x, (1, 2), keepdim=True)
67 | var = th.mean((x - mean)**2, (1, 2), keepdim=True)
68 |         # N x C x T
69 | if self.elementwise_affine:
70 | x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta
71 | else:
72 | x = (x - mean) / th.sqrt(var + self.eps)
73 | return x
74 |
75 | def extra_repr(self):
76 | return "{normalized_dim}, eps={eps}, " \
77 | "elementwise_affine={elementwise_affine}".format(**self.__dict__)
78 |
79 |
80 | def build_norm(norm, dim):
81 | """
82 | Build normalize layer
83 | LN cost more memory than BN
84 | """
85 | if norm not in ["cLN", "gLN", "BN"]:
86 | raise RuntimeError("Unsupported normalize layer: {}".format(norm))
87 | if norm == "cLN":
88 | return ChannelWiseLayerNorm(dim, elementwise_affine=True)
89 | elif norm == "BN":
90 | return nn.BatchNorm1d(dim)
91 | else:
92 | return GlobalChannelLayerNorm(dim, elementwise_affine=True)
93 |
94 |
95 | class Conv1D(nn.Conv1d):
96 | """
97 | 1D conv in ConvTasNet
98 | """
99 |
100 | def __init__(self, *args, **kwargs):
101 | super(Conv1D, self).__init__(*args, **kwargs)
102 |
103 | def forward(self, x, squeeze=False):
104 | """
105 | x: N x L or N x C x L
106 | """
107 | if x.dim() not in [2, 3]:
108 |             raise RuntimeError("{} accepts 2/3D tensor as input".format(
109 |                 self.__class__.__name__))
110 | x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
111 | if squeeze:
112 | x = th.squeeze(x)
113 | return x
114 |
115 |
116 | class ConvTrans1D(nn.ConvTranspose1d):
117 | """
118 | 1D conv transpose in ConvTasNet
119 | """
120 |
121 | def __init__(self, *args, **kwargs):
122 | super(ConvTrans1D, self).__init__(*args, **kwargs)
123 |
124 | def forward(self, x, squeeze=False):
125 | """
126 | x: N x L or N x C x L
127 | """
128 | if x.dim() not in [2, 3]:
129 |             raise RuntimeError("{} accepts 2/3D tensor as input".format(
130 |                 self.__class__.__name__))
131 | x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
132 | if squeeze:
133 | x = th.squeeze(x)
134 | return x
135 |
136 |
137 | class Conv1DBlock(nn.Module):
138 | """
139 | 1D convolutional block:
140 | Conv1x1 - PReLU - Norm - DConv - PReLU - Norm - SConv
141 | """
142 |
143 | def __init__(self,
144 | in_channels=256,
145 | conv_channels=512,
146 | kernel_size=3,
147 | dilation=1,
148 | norm="cLN",
149 | causal=False):
150 | super(Conv1DBlock, self).__init__()
151 | # 1x1 conv
152 | self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
153 | self.prelu1 = nn.PReLU()
154 | self.lnorm1 = build_norm(norm, conv_channels)
155 | dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
156 | dilation * (kernel_size - 1))
157 | # depthwise conv
158 | self.dconv = nn.Conv1d(
159 | conv_channels,
160 | conv_channels,
161 | kernel_size,
162 | groups=conv_channels,
163 | padding=dconv_pad,
164 | dilation=dilation,
165 | bias=True)
166 | self.prelu2 = nn.PReLU()
167 | self.lnorm2 = build_norm(norm, conv_channels)
168 | # 1x1 conv cross channel
169 | self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
170 | # different padding way
171 | self.causal = causal
172 | self.dconv_pad = dconv_pad
173 |
174 | def forward(self, x):
175 | y = self.conv1x1(x)
176 | y = self.lnorm1(self.prelu1(y))
177 | y = self.dconv(y)
178 | if self.causal:
179 | y = y[:, :, :-self.dconv_pad]
180 | y = self.lnorm2(self.prelu2(y))
181 | y = self.sconv(y)
182 | x = x + y
183 | return x
184 |
185 |
186 | class Model(nn.Module):
187 | def __init__(self,
188 | L=20,
189 | N=256,
190 | X=8,
191 | R=4,
192 | B=256,
193 | H=512,
194 | P=3,
195 | norm="cLN",
196 | num_spks=1,
197 | non_linear="relu",
198 | causal=False):
199 | super(Model, self).__init__()
200 | supported_nonlinear = {
201 | "relu": F.relu,
202 | "sigmoid": th.sigmoid,
203 | "softmax": F.softmax
204 | }
205 | if non_linear not in supported_nonlinear:
206 |             raise RuntimeError(
207 |                 "Unsupported non-linear function: {}".format(non_linear))
208 | self.non_linear_type = non_linear
209 | self.non_linear = supported_nonlinear[non_linear]
210 | # n x S => n x N x T, S = 4s*8000 = 32000
211 | self.encoder_1d = Conv1D(1, N, L, stride=L // 2, padding=0)
212 | # keep T not change
213 | # T = int((xlen - L) / (L // 2)) + 1
214 | # before repeat blocks, always cLN
215 | self.ln = ChannelWiseLayerNorm(N)
216 | # n x N x T => n x B x T
217 | self.proj = Conv1D(N, B, 1)
218 | # repeat blocks
219 | # n x B x T => n x B x T
220 | self.repeats = self._build_repeats(
221 | R,
222 | X,
223 | in_channels=B,
224 | conv_channels=H,
225 | kernel_size=P,
226 | norm=norm,
227 | causal=causal)
228 | # output 1x1 conv
229 | # n x B x T => n x N x T
230 | # NOTE: using ModuleList not python list
231 | # self.conv1x1_2 = th.nn.ModuleList(
232 | # [Conv1D(B, N, 1) for _ in range(num_spks)])
233 | # n x B x T => n x 2N x T
234 | self.mask = Conv1D(B, num_spks * N, 1)
235 | # using ConvTrans1D: n x N x T => n x 1 x To
236 | # To = (T - 1) * L // 2 + L
237 | self.decoder_1d = ConvTrans1D(
238 | N, 1, kernel_size=L, stride=L // 2, bias=True)
239 | self.num_spks = num_spks
240 |
241 | def _build_blocks(self, num_blocks, **block_kwargs):
242 | """
243 | Build Conv1D block
244 | """
245 | blocks = [
246 | Conv1DBlock(**block_kwargs, dilation=(2**b))
247 | for b in range(num_blocks)
248 | ]
249 | return nn.Sequential(*blocks)
250 |
251 | def _build_repeats(self, num_repeats, num_blocks, **block_kwargs):
252 | """
253 | Build Conv1D block repeats
254 | """
255 | repeats = [
256 | self._build_blocks(num_blocks, **block_kwargs)
257 | for r in range(num_repeats)
258 | ]
259 | return nn.Sequential(*repeats)
260 |
261 | def forward(self, x):
262 | x = th.reshape(x, [x.shape[0], -1])
263 | if x.dim() >= 3:
264 | raise RuntimeError(
265 |                 "{} accepts 1/2D tensor as input, but got {:d}".format(
266 |                     self.__class__.__name__, x.dim()))
267 | # when inference, only one utt
268 | if x.dim() == 1:
269 | x = th.unsqueeze(x, 0)
270 | # n x 1 x S => n x N x T
271 | w = F.relu(self.encoder_1d(x))
272 | # n x B x T
273 | y = self.proj(self.ln(w))
274 | # n x B x T
275 | y = self.repeats(y)
276 | # n x 2N x T
277 | e = th.chunk(self.mask(y), self.num_spks, 1)
278 | # n x N x T
279 | if self.non_linear_type == "softmax":
280 | m = self.non_linear(th.stack(e, dim=0), dim=0)
281 | else:
282 | m = self.non_linear(th.stack(e, dim=0))
283 | # spks x [n x N x T]
284 | s = [w * m[n] for n in range(self.num_spks)]
285 | # spks x n x S
286 |         out = th.stack([(self.decoder_1d(x, squeeze=True)) for x in s], dim=1)
287 | # print(out.shape)
288 | return out
289 | # [batch, num]
290 | # return [self.decoder_1d(x, squeeze=True) for x in s]
291 |
292 |
293 |
294 | def foo_conv1d_block():
295 | nnet = Conv1DBlock(256, 512, 3, 20)
296 | print(param(nnet))
297 |
298 |
299 | def foo_layernorm():
300 | C, T = 256, 20
301 | nnet1 = nn.LayerNorm([C, T], elementwise_affine=True)
302 | print(param(nnet1, Mb=False))
303 | nnet2 = nn.LayerNorm([C, T], elementwise_affine=False)
304 | print(param(nnet2, Mb=False))
305 |
306 |
307 | def foo_conv_tas_net():
308 | x = th.rand(4, 16000)
309 | nnet = Model(norm="cLN", causal=False)
310 | # print(nnet)
311 | print("ConvTasNet #param: {:.2f}".format(param(nnet)))
312 | x = nnet(x)
313 | print(x.shape)
314 |
315 | if __name__ == "__main__":
316 | foo_conv_tas_net()
317 | # foo_conv1d_block()
318 | # foo_layernorm()
319 |
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def mse_loss():
4 | return torch.nn.MSELoss()
5 |
6 | def l1_loss():
7 | return torch.nn.L1Loss()
8 |
9 | def bce_loss():
10 |     return torch.nn.BCEWithLogitsLoss()  # expects raw logits as input; targets should lie in [0, 1]
--------------------------------------------------------------------------------
/model/unet_basic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class DownSamplingLayer(nn.Module):
7 | def __init__(self, channel_in, channel_out, dilation=1, kernel_size=15, stride=1, padding=7):
8 | super(DownSamplingLayer, self).__init__()
9 | self.main = nn.Sequential(
10 | nn.Conv1d(channel_in, channel_out, kernel_size=kernel_size,
11 | stride=stride, padding=padding, dilation=dilation),
12 | nn.BatchNorm1d(channel_out),
13 | nn.LeakyReLU(negative_slope=0.1)
14 | )
15 |
16 | def forward(self, ipt):
17 | return self.main(ipt)
18 |
19 | class UpSamplingLayer(nn.Module):
20 | def __init__(self, channel_in, channel_out, kernel_size=5, stride=1, padding=2):
21 | super(UpSamplingLayer, self).__init__()
22 | self.main = nn.Sequential(
23 | nn.Conv1d(channel_in, channel_out, kernel_size=kernel_size,
24 | stride=stride, padding=padding),
25 | nn.BatchNorm1d(channel_out),
26 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
27 | )
28 |
29 | def forward(self, ipt):
30 | return self.main(ipt)
31 |
32 | class Model(nn.Module):
33 |
34 | def __init__(self, n_layers=12, channels_interval=24):
35 | super(Model, self).__init__()
36 |
37 | self.n_layers = n_layers
38 | self.channels_interval = channels_interval
39 | encoder_in_channels_list = [1] + [i * self.channels_interval for i in range(1, self.n_layers)]
40 | encoder_out_channels_list = [i * self.channels_interval for i in range(1, self.n_layers + 1)]
41 |
42 | # 1 => 2 => 3 => 4 => 5 => 6 => 7 => 8 => 9 => 10 => 11 =>12
43 | # 16384 => 8192 => 4096 => 2048 => 1024 => 512 => 256 => 128 => 64 => 32 => 16 => 8 => 4
44 | self.encoder = nn.ModuleList()
45 | for i in range(self.n_layers):
46 | self.encoder.append(
47 | DownSamplingLayer(
48 | channel_in=encoder_in_channels_list[i],
49 | channel_out=encoder_out_channels_list[i]
50 | )
51 | )
52 |
53 | self.middle = nn.Sequential(
54 | nn.Conv1d(self.n_layers * self.channels_interval, self.n_layers * self.channels_interval, 15, stride=1,
55 | padding=7),
56 | nn.BatchNorm1d(self.n_layers * self.channels_interval),
57 | nn.LeakyReLU(negative_slope=0.1, inplace=True)
58 | )
59 |
60 | decoder_in_channels_list = [(2 * i + 1) * self.channels_interval for i in range(1, self.n_layers)] + [
61 | 2 * self.n_layers * self.channels_interval]
62 | decoder_in_channels_list = decoder_in_channels_list[::-1]
63 | decoder_out_channels_list = encoder_out_channels_list[::-1]
64 | self.decoder = nn.ModuleList()
65 | for i in range(self.n_layers):
66 | self.decoder.append(
67 | UpSamplingLayer(
68 | channel_in=decoder_in_channels_list[i],
69 | channel_out=decoder_out_channels_list[i]
70 | )
71 | )
72 |
73 | self.out = nn.Sequential(
74 | nn.Conv1d(1 + self.channels_interval, 1, kernel_size=1, stride=1),
75 | nn.Tanh()
76 | )
77 |
78 | def forward(self, input):
79 | tmp = []
80 | o = input
81 |
82 |         # Down sampling (encoder path)
83 |         for i in range(self.n_layers):
84 |             o = self.encoder[i](o)
85 |             tmp.append(o)
86 |             # [batch_size, channels, T // 2]
87 |             o = o[:, :, ::2]
88 |
89 | o = self.middle(o)
90 |
91 |         # Up sampling (decoder path)
92 |         for i in range(self.n_layers):
93 |             # [batch_size, channels, T * 2]
94 | o = F.interpolate(o, scale_factor=2, mode="linear", align_corners=True)
95 | # Skip Connection
96 | o = torch.cat([o, tmp[self.n_layers - i - 1]], dim=1)
97 | o = self.decoder[i](o)
98 |
99 | o = torch.cat([o, input], dim=1)
100 | o = self.out(o)
101 | return o
102 |
103 |
104 | # n_layers = 12, channels_interval = 24
105 | # UpSamplingLayer(288 + 288, 288),
106 | # UpSamplingLayer(264 + 288, 264), # the down-sampled feature at the same level has 264 channels
107 | # UpSamplingLayer(240 + 264, 240),
108 | #
109 | # UpSamplingLayer(216 + 240, 216),
110 | # UpSamplingLayer(192 + 216, 192),
111 | # UpSamplingLayer(168 + 192, 168),
112 | #
113 | # UpSamplingLayer(144 + 168, 144),
114 | # UpSamplingLayer(120 + 144, 120),
115 | # UpSamplingLayer(96 + 120, 96),
116 | #
117 | # UpSamplingLayer(72 + 96, 72),
118 | # UpSamplingLayer(48 + 72, 48),
119 | # UpSamplingLayer(24 + 48, 24),
120 |
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhiyuyu/Wave-U-Net-for-SpeechEnhancement/308a2be9d91d931f09e873b19c3d67315b9b5962/trainer/__init__.py
--------------------------------------------------------------------------------
/trainer/base_trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 | from pathlib import Path
3 |
4 | import json5
5 | import numpy as np
6 | import torch
7 | from torch.optim.lr_scheduler import StepLR
8 | from util import visualization
9 | from util.utils import prepare_empty_dir, ExecutionTime
10 |
11 | class BaseTrainer:
12 | def __init__(self, config, resume: bool, model, loss_function, optimizer):
13 | self.n_gpu = torch.cuda.device_count()
14 | self.device = self._prepare_device(self.n_gpu, cudnn_deterministic=config["cudnn_deterministic"])
15 |
16 | self.optimizer = optimizer
17 | self.loss_function = loss_function
18 |
19 | self.model = model.to(self.device)
20 |
21 | if self.n_gpu > 1:
22 | self.model = torch.nn.DataParallel(self.model, device_ids=list(range(self.n_gpu)))
23 |
24 | # Trainer
25 | self.epochs = config["trainer"]["epochs"]
26 | self.save_checkpoint_interval = config["trainer"]["save_checkpoint_interval"]
27 | self.validation_config = config["trainer"]["validation"]
28 | self.validation_interval = self.validation_config["interval"]
29 | self.find_max = self.validation_config["find_max"]
30 | self.validation_custom_config = self.validation_config["custom"]
31 |
32 |         # The following args are not in the config file; they will be updated later if resume is True.
33 | self.start_epoch = 1
34 | self.best_score = -np.inf if self.find_max else np.inf
35 | self.root_dir = Path(config["root_dir"]).expanduser().absolute() / config["experiment_name"]
36 | self.checkpoints_dir = self.root_dir / "checkpoints"
37 | self.logs_dir = self.root_dir / "logs"
38 | prepare_empty_dir([self.checkpoints_dir, self.logs_dir], resume=resume)
39 |
40 | self.writer = visualization.writer(self.logs_dir.as_posix())
41 | self.writer.add_text(
42 | tag="Configuration",
43 | text_string=f" \n{json5.dumps(config, indent=4, sort_keys=False)} \n
",
44 | global_step=1
45 | )
46 |
47 | if resume: self._resume_checkpoint()
48 |
49 | print("Configurations are as follows: ")
50 | print(json5.dumps(config, indent=2, sort_keys=False))
51 |
52 | with open((self.root_dir / f"{time.strftime('%Y-%m-%d-%H-%M-%S')}.json").as_posix(), "w") as handle:
53 | json5.dump(config, handle, indent=2, sort_keys=False)
54 |
55 | self._print_networks([self.model])
56 |
57 | def _resume_checkpoint(self):
58 | """Resume experiment from latest checkpoint.
59 | Notes:
60 |             Be careful when loading the model: if it is an instance of DataParallel, we need to access it via model.module.*
61 | """
62 | latest_model_path = self.checkpoints_dir.expanduser().absolute() / "latest_model.tar"
63 | assert latest_model_path.exists(), f"{latest_model_path} does not exist, can not load latest checkpoint."
64 |
65 | checkpoint = torch.load(latest_model_path.as_posix(), map_location=self.device)
66 |
67 | self.start_epoch = checkpoint["epoch"] + 1
68 | self.best_score = checkpoint["best_score"]
69 | self.optimizer.load_state_dict(checkpoint["optimizer"])
70 |
71 | if isinstance(self.model, torch.nn.DataParallel):
72 | self.model.module.load_state_dict(checkpoint["model"])
73 | else:
74 | self.model.load_state_dict(checkpoint["model"])
75 |
76 | print(f"Model checkpoint loaded. Training will begin in {self.start_epoch} epoch.")
77 |
78 | def _save_checkpoint(self, epoch, is_best=False):
79 | """Save checkpoint to /checkpoints directory, which contains:
80 | - current epoch
81 | - best score in history
82 | - optimizer parameters
83 | - model parameters
84 | Args:
85 | is_best(bool): if current checkpoint got the best score, it also will be saved in /checkpoints/best_model.tar.
86 | """
87 | print(f"\t Saving {epoch} epoch model checkpoint...")
88 |
89 | # Construct checkpoint tar package
90 | state_dict = {
91 | "epoch": epoch,
92 | "best_score": self.best_score,
93 | "optimizer": self.optimizer.state_dict()
94 | }
95 |
96 | if isinstance(self.model, torch.nn.DataParallel): # Parallel
97 | state_dict["model"] = self.model.module.cpu().state_dict()
98 | else:
99 | state_dict["model"] = self.model.cpu().state_dict()
100 |
101 | """
102 | Notes:
103 | - latest_model.tar:
104 | Contains all checkpoint information, including optimizer parameters, model parameters, etc. New checkpoint will overwrite old one.
105 |             - model_<epoch>.pth:
106 |                 The parameters of the model only. Later, a specific epoch can be chosen for inference.
107 | - best_model.tar:
108 |                 Like latest_model.tar, but only saved when is_best is True.
109 | """
110 | torch.save(state_dict, (self.checkpoints_dir / "latest_model.tar").as_posix())
111 | torch.save(state_dict["model"], (self.checkpoints_dir / f"model_{str(epoch).zfill(4)}.pth").as_posix())
112 | if is_best:
113 | print(f"\t Found best score in {epoch} epoch, saving...")
114 | torch.save(state_dict, (self.checkpoints_dir / "best_model.tar").as_posix())
115 |
116 |         # Calling model.cpu() or model.to("cpu") migrates the model to the CPU, so we need to move it back afterwards.
117 |         # Note that tensor.cuda() / tensor.to("cuda") are not in-place: they return a new tensor, whereas moving a model is done in place.
118 | self.model.to(self.device)
119 |
120 | @staticmethod
121 | def _prepare_device(n_gpu: int, cudnn_deterministic=False):
122 | """Choose to use CPU or GPU depend on "n_gpu".
123 | Args:
124 | n_gpu(int): the number of GPUs used in the experiment.
125 | if n_gpu is 0, use CPU;
126 | if n_gpu > 1, use GPU.
127 | cudnn_deterministic (bool): repeatability
128 | cudnn.benchmark will find algorithms to optimize training. if we need to consider the repeatability of experiment, set use_cudnn_deterministic to True
129 | """
130 | if n_gpu == 0:
131 | print("Using CPU in the experiment.")
132 | device = torch.device("cpu")
133 | else:
134 | if cudnn_deterministic:
135 | print("Using CuDNN deterministic mode in the experiment.")
136 | torch.backends.cudnn.deterministic = True
137 | torch.backends.cudnn.benchmark = False
138 |
139 | device = torch.device("cuda:0")
140 |
141 | return device
142 |
143 | def _is_best(self, score, find_max=True):
144 | """Check if the current model is the best model
145 | """
146 | if find_max and score >= self.best_score:
147 | self.best_score = score
148 | return True
149 | elif not find_max and score <= self.best_score:
150 | self.best_score = score
151 | return True
152 | else:
153 | return False
154 |
155 | @staticmethod
156 | def _transform_pesq_range(pesq_score):
157 | """transform [-0.5 ~ 4.5] to [0 ~ 1]
158 | """
159 | return (pesq_score + 0.5) / 5
160 |
161 | @staticmethod
162 | def _print_networks(nets: list):
163 | print(f"This project contains {len(nets)} networks, the number of the parameters: ")
164 | params_of_all_networks = 0
165 | for i, net in enumerate(nets, start=1):
166 | params_of_network = 0
167 | for param in net.parameters():
168 | params_of_network += param.numel()
169 |
170 | print(f"\tNetwork {i}: {params_of_network / 1e6} million.")
171 | params_of_all_networks += params_of_network
172 |
173 | print(f"The amount of parameters in the project is {params_of_all_networks / 1e6} million.")
174 |
175 | def _set_models_to_train_mode(self):
176 | self.model.train()
177 |
178 | def _set_models_to_eval_mode(self):
179 | self.model.eval()
180 |
181 | def train(self):
182 | for epoch in range(self.start_epoch, self.epochs + 1):
183 | print(f"============== {epoch} epoch ==============")
184 | print("[0 seconds] Begin training...")
185 | timer = ExecutionTime()
186 |
187 | self._set_models_to_train_mode()
188 | self._train_epoch(epoch)
189 |
190 | if self.save_checkpoint_interval != 0 and (epoch % self.save_checkpoint_interval == 0):
191 | self._save_checkpoint(epoch)
192 |
193 | if self.validation_interval != 0 and epoch % self.validation_interval == 0:
194 | print(f"[{timer.duration()} seconds] Training is over, Validation is in progress...")
195 |
196 | self._set_models_to_eval_mode()
197 | score = self._validation_epoch(epoch)
198 |
199 | if self._is_best(score, find_max=self.find_max):
200 | self._save_checkpoint(epoch, is_best=True)
201 |
202 | print(f"[{timer.duration()} seconds] End this epoch.")
203 |
204 | def _train_epoch(self, epoch):
205 | raise NotImplementedError
206 |
207 | def _validation_epoch(self, epoch):
208 | raise NotImplementedError
--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import librosa.display
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import torch
6 |
7 | from trainer.base_trainer import BaseTrainer
8 | from util.utils import compute_STOI, compute_PESQ
9 | plt.switch_backend('agg')
10 |
11 |
12 | class Trainer(BaseTrainer):
13 | def __init__(
14 | self,
15 | config,
16 | resume: bool,
17 | model,
18 | loss_function,
19 | optimizer,
20 | train_dataloader,
21 | validation_dataloader,
22 | ):
23 | super(Trainer, self).__init__(config, resume, model, loss_function, optimizer)
24 | self.train_data_loader = train_dataloader
25 | self.validation_data_loader = validation_dataloader
26 |
27 | def _train_epoch(self, epoch):
28 | loss_total = 0.0
29 |
30 | for i, (mixture, clean, name) in enumerate(self.train_data_loader):
31 | mixture = mixture.to(self.device)
32 | clean = clean.to(self.device)
33 |
34 | self.optimizer.zero_grad()
35 | enhanced = self.model(mixture)
36 | loss = self.loss_function(clean, enhanced)
37 | loss.backward()
38 | self.optimizer.step()
39 |
40 | loss_total += loss.item()
41 |
42 | dl_len = len(self.train_data_loader)
43 | self.writer.add_scalar(f"Train/Loss", loss_total / dl_len, epoch)
44 |
45 | @torch.no_grad()
46 | def _validation_epoch(self, epoch):
47 | visualize_audio_limit = self.validation_custom_config["visualize_audio_limit"]
48 | visualize_waveform_limit = self.validation_custom_config["visualize_waveform_limit"]
49 | visualize_spectrogram_limit = self.validation_custom_config["visualize_spectrogram_limit"]
50 |
51 | sample_length = self.validation_custom_config["sample_length"]
52 |
53 | stoi_c_n = [] # clean and noisy
54 |         stoi_c_d = []  # clean and denoised
55 | pesq_c_n = []
56 | pesq_c_d = []
57 |
58 | for i, (mixture, clean, name) in enumerate(self.validation_data_loader):
59 |             assert len(name) == 1, "Only a batch size of 1 is supported in the validation stage."
60 | name = name[0]
61 |
62 | # [1, 1, T]
63 | mixture = mixture.to(self.device)
64 | clean = clean.to(self.device)
65 |
66 |             # The model input must be fixed-length
67 | mixture_chunks = torch.split(mixture, sample_length, dim=2)
68 | if mixture_chunks[-1].shape[-1] != sample_length:
69 | mixture_chunks = mixture_chunks[:-1]
70 |
71 | enhanced_chunks = []
72 | for chunk in mixture_chunks:
73 | enhanced_chunks.append(self.model(chunk).detach().cpu())
74 |
75 | enhanced = torch.cat(enhanced_chunks, dim=2)
76 |
77 | # Back to numpy array
78 | mixture = mixture.cpu().numpy().reshape(-1)
79 | enhanced = enhanced.numpy().reshape(-1)
80 | clean = clean.cpu().numpy().reshape(-1)
81 |
82 | min_len = min(len(mixture), len(clean), len(enhanced))
83 |
84 | mixture = mixture[:min_len]
85 | clean = clean[:min_len]
86 | enhanced = enhanced[:min_len]
87 |
88 | # Visualize audio
89 | if i <= visualize_audio_limit:
90 | self.writer.add_audio(f"Speech/{name}_Noisy", mixture, epoch, sample_rate=16000)
91 | self.writer.add_audio(f"Speech/{name}_Enhanced", enhanced, epoch, sample_rate=16000)
92 | self.writer.add_audio(f"Speech/{name}_Clean", clean, epoch, sample_rate=16000)
93 |
94 | # Visualize waveform
95 | if i <= visualize_waveform_limit:
96 | fig, ax = plt.subplots(3, 1)
97 | for j, y in enumerate([mixture, enhanced, clean]):
98 | ax[j].set_title("mean: {:.3f}, std: {:.3f}, max: {:.3f}, min: {:.3f}".format(
99 | np.mean(y),
100 | np.std(y),
101 | np.max(y),
102 | np.min(y)
103 | ))
104 |                     librosa.display.waveplot(y, sr=16000, ax=ax[j])  # waveplot was removed in librosa 0.10; newer versions use waveshow
105 | plt.tight_layout()
106 | self.writer.add_figure(f"Waveform/{name}", fig, epoch)
107 |
108 | # Visualize spectrogram
109 | noisy_mag, _ = librosa.magphase(librosa.stft(mixture, n_fft=320, hop_length=160, win_length=320))
110 | enhanced_mag, _ = librosa.magphase(librosa.stft(enhanced, n_fft=320, hop_length=160, win_length=320))
111 | clean_mag, _ = librosa.magphase(librosa.stft(clean, n_fft=320, hop_length=160, win_length=320))
112 |
113 | if i <= visualize_spectrogram_limit:
114 | fig, axes = plt.subplots(3, 1, figsize=(6, 6))
115 | for k, mag in enumerate([
116 | noisy_mag,
117 | enhanced_mag,
118 | clean_mag,
119 | ]):
120 | axes[k].set_title(f"mean: {np.mean(mag):.3f}, "
121 | f"std: {np.std(mag):.3f}, "
122 | f"max: {np.max(mag):.3f}, "
123 | f"min: {np.min(mag):.3f}")
124 | librosa.display.specshow(librosa.amplitude_to_db(mag), cmap="magma", y_axis="linear", ax=axes[k], sr=16000)
125 | plt.tight_layout()
126 | self.writer.add_figure(f"Spectrogram/{name}", fig, epoch)
127 |
128 | # Metric
129 | stoi_c_n.append(compute_STOI(clean, mixture, sr=16000))
130 | stoi_c_d.append(compute_STOI(clean, enhanced, sr=16000))
131 | pesq_c_n.append(compute_PESQ(clean, mixture, sr=16000))
132 | pesq_c_d.append(compute_PESQ(clean, enhanced, sr=16000))
133 |
134 | get_metrics_ave = lambda metrics: np.sum(metrics) / len(metrics)
135 | self.writer.add_scalars(f"评价指标均值/STOI", {
136 | "clean 与 noisy": get_metrics_ave(stoi_c_n),
137 | "clean 与 denoisy": get_metrics_ave(stoi_c_d)
138 | }, epoch)
139 | self.writer.add_scalars(f"评价指标均值/PESQ", {
140 | "clean 与 noisy": get_metrics_ave(pesq_c_n),
141 | "clean 与 denoisy": get_metrics_ave(pesq_c_d)
142 | }, epoch)
143 |
144 | score = (get_metrics_ave(stoi_c_d) + self._transform_pesq_range(get_metrics_ave(pesq_c_d))) / 2
145 | return score
146 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wzhiyuyu/Wave-U-Net-for-SpeechEnhancement/308a2be9d91d931f09e873b19c3d67315b9b5962/util/__init__.py
--------------------------------------------------------------------------------
/util/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import time
3 | import os
4 |
5 | import torch
6 | from pesq import pesq
7 | import numpy as np
8 | from pystoi.stoi import stoi
9 |
10 |
11 | def load_checkpoint(checkpoint_path, device):
12 | _, ext = os.path.splitext(os.path.basename(checkpoint_path))
13 |     assert ext in (".pth", ".tar"), "Only .pth and .tar model checkpoint extensions are supported."
14 | model_checkpoint = torch.load(checkpoint_path, map_location=device)
15 |
16 | if ext == ".pth":
17 | print(f"Loading {checkpoint_path}.")
18 | return model_checkpoint
19 | else: # tar
20 | print(f"Loading {checkpoint_path}, epoch = {model_checkpoint['epoch']}.")
21 | return model_checkpoint["model"]
22 |
23 |
24 | def prepare_empty_dir(dirs, resume=False):
25 | """
26 | if resume experiment, assert the dirs exist,
27 | if not resume experiment, make dirs.
28 |
29 | Args:
30 |         dirs (list): list of directories
31 | resume (bool): whether to resume experiment, default is False
32 | """
33 | for dir_path in dirs:
34 | if resume:
35 | assert dir_path.exists()
36 | else:
37 | dir_path.mkdir(parents=True, exist_ok=True)
38 |
39 |
40 | class ExecutionTime:
41 | """
42 | Usage:
43 | timer = ExecutionTime()
44 |
45 | print(f'Finished in {timer.duration()} seconds.')
46 | """
47 |
48 | def __init__(self):
49 | self.start_time = time.time()
50 |
51 | def duration(self):
52 | return int(time.time() - self.start_time)
53 |
54 |
55 | def initialize_config(module_cfg, pass_args=True):
56 | """
57 |     According to the config items, dynamically load the specified module with its params.
58 |     e.g., config items as follows:
59 | module_cfg = {
60 | "module": "model.model",
61 | "main": "Model",
62 | "args": {...}
63 | }
64 | 1. Load the module corresponding to the "module" param.
65 |     2. Call the function (or instantiate the class) corresponding to the "main" param.
66 |     3. Send the params (in "args") into the function (or class) when calling (or instantiating) it.
67 | """
68 | module = importlib.import_module(module_cfg["module"])
69 |
70 | if pass_args:
71 | return getattr(module, module_cfg["main"])(**module_cfg["args"])
72 | else:
73 | return getattr(module, module_cfg["main"])
74 |
75 |
76 |
77 | def compute_PESQ(clean_signal, noisy_signal, sr=16000):
78 | return pesq(sr, clean_signal, noisy_signal, "wb")
79 |
80 |
81 | def z_score(m):
82 | mean = np.mean(m)
83 | std_var = np.std(m)
84 | return (m - mean) / std_var, mean, std_var
85 |
86 |
87 | def reverse_z_score(m, mean, std_var):
88 | return m * std_var + mean
89 |
90 |
91 | def min_max(m):
92 | m_max = np.max(m)
93 | m_min = np.min(m)
94 |
95 | return (m - m_min) / (m_max - m_min), m_max, m_min
96 |
97 |
98 | def reverse_min_max(m, m_max, m_min):
99 | return m * (m_max - m_min) + m_min
100 |
101 |
102 | def sample_fixed_length_data_aligned(data_a, data_b, sample_length):
103 | """
104 |     Sample aligned fixed-length segments from two pieces of data
105 | """
106 |     assert len(data_a) == len(data_b), "Inconsistent data length, unable to sample."
107 | assert len(data_a) >= sample_length, f"len(data_a) is {len(data_a)}, sample_length is {sample_length}."
108 |
109 | frames_total = len(data_a)
110 |
111 | start = np.random.randint(frames_total - sample_length + 1)
112 | # print(f"Random crop from: {start}")
113 | end = start + sample_length
114 |
115 | return data_a[start:end], data_b[start:end]
116 |
117 |
118 | def compute_STOI(clean_signal, noisy_signal, sr=16000):
119 | return stoi(clean_signal, noisy_signal, sr, extended=False)
120 |
121 |
122 | def print_tensor_info(tensor, flag="Tensor"):
123 | floor_tensor = lambda float_tensor: int(float(float_tensor) * 1000) / 1000
124 | print(flag)
125 | print(
126 | f"\tmax: {floor_tensor(torch.max(tensor))}, min: {float(torch.min(tensor))}, mean: {floor_tensor(torch.mean(tensor))}, std: {floor_tensor(torch.std(tensor))}")
127 |
--------------------------------------------------------------------------------
/util/visualization.py:
--------------------------------------------------------------------------------
1 | from torch.utils.tensorboard import SummaryWriter
2 |
3 |
4 | def writer(logs_dir):
5 | return SummaryWriter(log_dir=logs_dir, max_queue=5, flush_secs=30)
--------------------------------------------------------------------------------