├── README.md
├── data
│   ├── dev
│   │   ├── auxs1.scp
│   │   ├── mix_clean.scp
│   │   └── ref.scp
│   ├── test
│   │   ├── auxs1.scp
│   │   ├── mix_clean.scp
│   │   └── ref.scp
│   └── train
│       ├── auxs1.scp
│       ├── mix_clean.scp
│       └── ref.scp
├── eval.sh
├── nnet
│   ├── SEF_PNet_pse.py
│   ├── __pycache__
│   │   ├── SEF_PNet_pse.cpython-39.pyc
│   │   └── conf_unet_tse_32ms.cpython-39.pyc
│   ├── conf_unet_tse_32ms.py
│   ├── evaluate.py
│   ├── libs
│   │   ├── __init__.py
│   │   ├── __init__.pyc
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── audio.cpython-39.pyc
│   │   │   ├── conv_stft.cpython-39.pyc
│   │   │   ├── dataset_tse.cpython-39.pyc
│   │   │   ├── trainer_unet_tse_steplr_clip.cpython-39.pyc
│   │   │   └── utils.cpython-39.pyc
│   │   ├── audio.py
│   │   ├── conv_stft.py
│   │   ├── dataset_tse.py
│   │   ├── metric.py
│   │   ├── trainer_unet_tse_steplr_clip.py
│   │   └── utils.py
│   ├── memonger
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── checkpoint.cpython-39.pyc
│   │   │   └── memonger.cpython-39.pyc
│   │   ├── checkpoint.py
│   │   ├── memonger.py
│   │   └── resnet.py
│   ├── separate.py
│   └── train_unet_tse_steplr_clip.py
├── requirements.txt
├── separate.sh
└── train.sh
/README.md:
--------------------------------------------------------------------------------
1 | # SEF-PNet
2 |
3 | Official PyTorch implementation of the paper "[SEF-PNet: Speaker Encoder-Free Personalized Speech Enhancement with Local and Global Contexts Aggregation](https://arxiv.org/abs/2501.11274)" in ICASSP 2025.
4 |
5 | ## Dataset
6 | [Libri2Mix](https://github.com/JorisCos/LibriMix) dataset, min mode, 8 kHz wav. The `data` folder contains three subfolders: `train`, `dev`, and `test`. Each subfolder includes three files:
7 | - `mix_clean.scp`: Clean mixtures of 2 speakers.
8 | - `ref.scp`: Target speaker’s speech.
9 | - `auxs1.scp`: Enrollment speech from the target speaker (a different utterance from the target speaker's speech in the mixture).
10 |
11 | The `mix_clean.scp` corresponds to the **2-speaker** scenario in the results section.
12 | Note that in this dataset, only the first speaker in the mixed speech is considered the target speaker.
13 | Make sure to update the file paths in the `scp` files to match your local data locations. Also, remember to update the data paths in `conf_unet_tse_32ms.py` accordingly.
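Each `.scp` file follows the Kaldi convention of one `key /path/to/wav` entry per line (see `parse_scripts` in `nnet/libs/audio.py`); the keys and paths below are illustrative only:

```text
utt1 /your/path/Libri2Mix/wav8k/min/train-100/mix_clean/utt1.wav
utt2 /your/path/Libri2Mix/wav8k/min/train-100/mix_clean/utt2.wav
```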
14 |
15 | ## Training
16 | - **`train.sh`**: Shell script that initiates training by setting parameters (e.g., epochs, batch size, GPU settings) and calling the Python script (`train_unet_tse_steplr_clip.py`). To train the model, run:
17 | ```bash
18 | ./train.sh
19 | ```
20 | - **`train_unet_tse_steplr_clip.py`**: Main Python script for training. It initializes the model, sets up data loaders, and manages the training loop.
21 |
22 | - **`conf_unet_tse_32ms.py`**: Configuration file containing model architecture, data paths, and training hyperparameters.
23 |
24 | - **`SEF_PNet_pse.py`**: Defines the `SEF_PNet` model used by the training script; a minimal smoke test follows below.
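A minimal smoke test of the model, mirroring the `__main__` block of `nnet/SEF_PNet_pse.py` (run from inside `nnet/` so the imports resolve):

```python
import torch as th
from SEF_PNet_pse import SEF_PNet

nnet = SEF_PNet()        # defaults: 32 ms window, 8 ms hop, 8 kHz audio
mix = th.randn(2, 8000)  # batch of two 1-second mixtures
aux = th.randn(2, 8000)  # matching enrollment utterances
est = nnet(mix, aux)     # estimated target-speaker waveforms
```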
25 |
26 | ## Evaluation
27 |
28 | To evaluate the model, use the provided `eval.sh` script. It sets the necessary parameters (e.g., model checkpoint, GPU ID, data paths) and calls `evaluate.py` for performance evaluation.
29 |
30 | - **`eval.sh`**: Runs the evaluation by setting paths and calling `evaluate.py`.
31 | - Usage:
32 | ```bash
33 | ./eval.sh
34 | ```
35 |
36 | - **`evaluate.py`**: Evaluates the model on the test set, computing metrics such as SDR, SI-SNR, PESQ, and STOI; the SI-SNR definition used is sketched below.
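For reference, the SI-SNR computed by `cal_SISNR` in `nnet/evaluate.py` boils down to the following (a restatement of that function, not a separate implementation):

```python
import numpy as np

def si_snr(est, ref, eps=1e-8):
    """SI-SNR in dB between zero-mean 1-D signals."""
    est, ref = est - est.mean(), ref - ref.mean()
    # project the estimate onto the reference to get the target component
    target = np.inner(est, ref) * ref / (np.linalg.norm(ref) ** 2 + eps)
    return 20 * np.log10(eps + np.linalg.norm(target) / (np.linalg.norm(est - target) + eps))
```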
37 |
38 | ## Results
39 |
40 | Condition-wise results on three Libri2Mix PSE tasks:
41 |
42 | | Condition | Method | SI-SDR | PESQ | STOI |
43 | |-----------|--------|--------|------|------|
44 | | 1-speaker+noise | Mixture | 3.27 | 1.75 | 79.51 |
45 | | 1-speaker+noise | sDPCCN | 14.49 | 3.04 | 92.47 |
46 | | 1-speaker+noise | SEF-PNet | 14.50 | 3.05 | 92.47 |
47 | | 2-speaker | Mixture | -0.03 | 1.60 | 71.38 |
48 | | 2-speaker | sDPCCN | 11.62 | 2.76 | 87.19 |
49 | | 2-speaker | SEF-PNet | 13.00 | 3.05 | 89.71 |
50 | | 2-speaker+noise | Mixture | -2.03 | 1.43 | 64.65 |
51 | | 2-speaker+noise | sDPCCN | 6.93 | 2.12 | 79.32 |
52 | | 2-speaker+noise | SEF-PNet | 7.54 | 2.14 | 80.58 |
53 |
116 | ### GPU Setup
117 | This code is designed to run on a single GPU. By default, in the `train.sh` script, the `gpuid` is set to `0`.
118 |
119 | To use multiple GPUs, modify `gpuid=0,1,2,...` in `train.sh`.
120 |
121 | Additionally, for multi-GPU setups, comment out the line:
122 | ```python
123 | from memonger import SublinearSequential
124 | ```
125 | and replace `SublinearSequential` with `nn.Sequential` in `SEF_PNet_pse.py` to avoid memory issues, e.g. with the minimal patch below.
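A minimal sketch of that patch, assuming nothing else in `SEF_PNet_pse.py` changes:

```python
# from memonger import SublinearSequential   # single-GPU default, commented out
import torch.nn as nn

SublinearSequential = nn.Sequential  # drop-in alias; all later uses stay intact
```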
126 |
127 | ### Create SCP
128 | The SCP files provided here come from [DPCCN](https://github.com/jyhan03/icassp22-dataset/tree/main/lst/libri2mix) and use only the first speaker as the target. To match the MC-Spex results for the 2-speaker condition on Libri2Mix, you need twice as much data, with the two speakers taking turns as the target, so you will have to recreate the SCP files for training, validation, and testing; the script in the link can serve as a reference, and a sketch is given below.
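A minimal, hypothetical sketch of the doubling step (not the official script; `utts` and all paths are placeholders, and `auxs1.scp` must be rebuilt analogously with one enrollment utterance per target):

```python
def write_double_scp(utts, out_dir):
    """utts: list of (key, mix_path, s1_path, s2_path) tuples."""
    with open(f"{out_dir}/mix_clean.scp", "w") as f_mix, \
         open(f"{out_dir}/ref.scp", "w") as f_ref:
        for key, mix_path, s1_path, s2_path in utts:
            # write each mixture twice, once per target speaker
            for tag, ref_path in (("t1", s1_path), ("t2", s2_path)):
                f_mix.write(f"{key}_{tag} {mix_path}\n")
                f_ref.write(f"{key}_{tag} {ref_path}\n")
```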
129 |
130 | If you run into any problems, contact me at hzlkycg111@163.com and I will reply promptly.
131 |
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -eu
3 |
4 | checkpoint=/node/hzl/expriment/libri2mix_min_wav8k/SEF_PNet
5 | gpuid=0
6 |
7 | data_root=/node/hzl/data/data_libri2mix_s1_min_wav8k/test
8 |
9 | mix_scp=$data_root/mix_clean.scp
10 | spk1_scp=$data_root/s1.scp
11 | aux_scp=$data_root/auxs1.scp
12 |
13 | cal_sdr=1
14 |
15 | ./nnet/evaluate.py \
16 | --checkpoint $checkpoint \
17 | --gpuid $gpuid \
18 | --mix_scp $mix_scp \
19 | --ref_scp $spk1_scp \
20 | --aux_scp $aux_scp \
21 | --cal_sdr $cal_sdr \
22 | > eval.log 2>&1
23 |
24 | echo "eval done!"
25 |
--------------------------------------------------------------------------------
/nnet/SEF_PNet_pse.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Sun June 2 2024
3 | @author: Ziling Huang
4 | """
5 | import torch as th
6 | import torch.nn as nn
7 | import torch.nn.functional as nn_f
8 | from typing import Tuple, List
9 | from memonger import SublinearSequential
10 | from libs.conv_stft import ConvSTFT, ConviSTFT
11 |
12 | def param(nnet, Mb=True):
13 | """
14 | Return the number of parameters in nnet (in millions if Mb is True, not bytes)
15 | """
16 | neles = sum([param.nelement() for param in nnet.parameters()])
17 | return neles / 10**6 if Mb else neles
18 |
19 | class Conv1D(nn.Conv1d):
20 | """
21 | 1D conv in ConvTasNet
22 | """
23 |
24 | def __init__(self, *args, **kwargs):
25 | super(Conv1D, self).__init__(*args, **kwargs)
26 |
27 | def forward(self, x, squeeze=False):
28 | """
29 | x: N x L or N x C x L
30 | """
31 | if x.dim() not in [2, 3]:
32 | raise RuntimeError("{} accepts 2/3D tensor as input".format(
33 | type(self).__name__))
34 | x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
35 | if squeeze:
36 | x = th.squeeze(x)
37 | return x
38 |
39 | class ChannelWiseLayerNorm(nn.LayerNorm):
40 | """
41 | Channel wise layer normalization
42 | """
43 |
44 | def __init__(self, *args, **kwargs):
45 | super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs)
46 |
47 | def forward(self, x):
48 | """
49 | x: N x C x T
50 | """
51 | if x.dim() != 3:
52 | raise RuntimeError("{} accepts 3D tensor as input".format(
53 | type(self).__name__))
54 | # N x C x T => N x T x C
55 | x = th.transpose(x, 1, 2)
56 | # LN
57 | x = super().forward(x)
58 | # N x T x C => N x C x T
59 | x = th.transpose(x, 1, 2)
60 | return x
61 |
62 | class GlobalChannelLayerNorm(nn.Module):
63 | """
64 | Global channel layer normalization
65 | """
66 |
67 | def __init__(self, dim, eps=1e-05, elementwise_affine=True):
68 | super(GlobalChannelLayerNorm, self).__init__()
69 | self.eps = eps
70 | self.normalized_dim = dim
71 | self.elementwise_affine = elementwise_affine
72 | if elementwise_affine:
73 | self.beta = nn.Parameter(th.zeros(dim, 1))
74 | self.gamma = nn.Parameter(th.ones(dim, 1))
75 | else:
76 | self.register_parameter("weight", None)
77 | self.register_parameter("bias", None)
78 |
79 | def forward(self, x):
80 | """
81 | x: N x C x T
82 | """
83 | if x.dim() != 3:
84 | raise RuntimeError("{} accepts 3D tensor as input".format(
85 | type(self).__name__))
86 | # N x 1 x 1
87 | mean = th.mean(x, (1, 2), keepdim=True)
88 | var = th.mean((x - mean)**2, (1, 2), keepdim=True)
89 | # N x C x T
90 | if self.elementwise_affine:
91 | x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta
92 | else:
93 | x = (x - mean) / th.sqrt(var + self.eps)
94 | return x
95 |
96 | def extra_repr(self):
97 | return "{normalized_dim}, eps={eps}, " \
98 | "elementwise_affine={elementwise_affine}".format(**self.__dict__)
99 |
100 | def build_norm(norm, dim):
101 | """
102 | Build normalize layer
103 | LN cost more memory than BN
104 | """
105 | if norm not in ["cLN", "gLN", "BN"]:
106 | raise RuntimeError("Unsupported normalize layer: {}".format(norm))
107 | if norm == "cLN":
108 | return ChannelWiseLayerNorm(dim, elementwise_affine=True)
109 | elif norm == "BN":
110 | return nn.BatchNorm1d(dim)
111 | else:
112 | return GlobalChannelLayerNorm(dim, elementwise_affine=True)
113 |
114 | class Conv2dBlock(nn.Module):
115 | def __init__(self,
116 | in_dims: int = 16,
117 | out_dims: int = 32,
118 | kernel_size: Tuple[int] = (3, 3),
119 | stride: Tuple[int] = (1, 1),
120 | padding: Tuple[int] = (1, 1)) -> None:
121 | super(Conv2dBlock, self).__init__()
122 | self.conv2d = nn.Conv2d(in_dims, out_dims, kernel_size, stride, padding)
123 | self.elu = nn.ELU()
124 | self.norm = nn.InstanceNorm2d(out_dims)
125 |
126 | def forward(self, x: th.Tensor) -> th.Tensor:
127 | x = self.conv2d(x)
128 | x = self.elu(x)
129 | return self.norm(x)
130 |
131 | class ConvTrans2dBlock(nn.Module):
132 | def __init__(self,
133 | in_dims: int = 32,
134 | out_dims: int = 16,
135 | kernel_size: Tuple[int] = (3, 3),
136 | stride: Tuple[int] = (1, 2),
137 | padding: Tuple[int] = (1, 0),
138 | output_padding: Tuple[int] = (0, 0)) -> None:
139 | super(ConvTrans2dBlock, self).__init__()
140 | self.convtrans2d = nn.ConvTranspose2d(in_dims, out_dims, kernel_size, stride, padding, output_padding)
141 | self.elu = nn.ELU()
142 | self.norm = nn.InstanceNorm2d(out_dims)
143 |
144 | def forward(self, x: th.Tensor) -> th.Tensor:
145 | x = self.convtrans2d(x)
146 | x = self.elu(x)
147 | return self.norm(x)
148 |
149 | class DenseBlock(nn.Module):
150 | def __init__(self, in_dims, out_dims, mode = "enc", **kargs):
151 | super(DenseBlock, self).__init__()
152 | if mode not in ["enc", "dec"]:
153 | raise RuntimeError("The mode option must be 'enc' or 'dec'!")
154 |
155 | n = 1 if mode == "enc" else 2
156 | self.conv1 = Conv2dBlock(in_dims=in_dims*n, out_dims=in_dims, **kargs)
157 | self.conv2 = Conv2dBlock(in_dims=in_dims*(n+1), out_dims=in_dims, **kargs)
158 | self.conv3 = Conv2dBlock(in_dims=in_dims*(n+2), out_dims=in_dims, **kargs)
159 | self.conv4 = Conv2dBlock(in_dims=in_dims*(n+3), out_dims=in_dims, **kargs)
160 | self.conv5 = Conv2dBlock(in_dims=in_dims*(n+4), out_dims=out_dims, **kargs)
161 |
162 | def forward(self, x: th.Tensor) -> th.Tensor:
163 | y1 = self.conv1(x)
164 | y2 = self.conv2(th.cat([x, y1], 1))
165 | y3 = self.conv3(th.cat([x, y1, y2], 1))
166 | y4 = self.conv4(th.cat([x, y1, y2, y3], 1))
167 | y5 = self.conv5(th.cat([x, y1, y2, y3, y4], 1))
168 | return y5
169 |
170 | class TCNBlock(nn.Module):
171 | """
172 | TCN block:
173 | IN - ELU - Conv1D - IN - ELU - Conv1D
174 | """
175 |
176 | def __init__(self,
177 | in_dims: int = 384,
178 | out_dims: int = 384,
179 | kernel_size: int = 3,
180 | stride: int = 1,
181 | paddings: int = 1,
182 | dilation: int = 1,
183 | causal: bool = False) -> None:
184 | super(TCNBlock, self).__init__()
185 | self.norm1 = nn.InstanceNorm1d(in_dims)
186 | self.elu1 = nn.ELU()
187 | dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
188 | dilation * (kernel_size - 1))
189 | # dilated conv
190 | self.dconv1 = nn.Conv1d(
191 | in_dims,
192 | out_dims,
193 | kernel_size,
194 | padding=dconv_pad,
195 | dilation=dilation,
196 | groups=in_dims,
197 | bias=True)
198 |
199 | self.norm2 = nn.InstanceNorm1d(in_dims)
200 | self.elu2 = nn.ELU()
201 | self.dconv2 = nn.Conv1d(in_dims, out_dims, 1, bias=True)
202 |
203 | # different padding way
204 | self.causal = causal
205 | self.dconv_pad = dconv_pad
206 |
207 | def forward(self, x: th.Tensor) -> th.Tensor:
208 | y = self.elu1(self.norm1(x))
209 | y = self.dconv1(y)
210 | if self.causal:
211 | y = y[:, :, :-self.dconv_pad]
212 | y = self.elu2(self.norm2(y))
213 | y = self.dconv2(y)
214 | x = x + y
215 |
216 | return x
217 |
218 |
219 | class LCA(nn.Module):
220 | def __init__(self, channels=64, r=4):
221 | super(LCA, self).__init__()
222 | inter_channels = int(channels // r)
223 |
224 | self.local_att = nn.Sequential(
225 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
226 | nn.BatchNorm2d(inter_channels),
227 | nn.ReLU(inplace=True),
228 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
229 | nn.BatchNorm2d(channels),
230 | )
231 |
232 | self.global_att = nn.Sequential(
233 | nn.AdaptiveAvgPool2d(1),
234 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
235 | nn.BatchNorm2d(inter_channels),
236 | nn.ReLU(inplace=True),
237 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
238 | nn.BatchNorm2d(channels),
239 | )
240 |
241 | self.sigmoid = nn.Sigmoid()
242 |
243 | def forward(self, x):
244 | xl = self.local_att(x)
245 | xg = self.global_att(x)
246 | xlg = xl + xg
247 | wei = self.sigmoid(xlg)
248 | return x * wei
249 |
250 | class IFI(nn.Module):
251 |
252 | def __init__(self, channels=64, r=4):
253 | super(IFI, self).__init__()
254 | inter_channels = int(channels // r)
255 |
256 | self.local_att = nn.Sequential(
257 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
258 | nn.BatchNorm2d(inter_channels),
259 | nn.ReLU(inplace=True),
260 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
261 | nn.BatchNorm2d(channels),
262 | )
263 |
264 | self.global_att = nn.Sequential(
265 | nn.AdaptiveAvgPool2d(1),
266 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
267 | nn.BatchNorm2d(inter_channels),
268 | nn.ReLU(inplace=True),
269 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
270 | nn.BatchNorm2d(channels),
271 | )
272 |
273 | self.local_att2 = nn.Sequential(
274 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
275 | nn.BatchNorm2d(inter_channels),
276 | nn.ReLU(inplace=True),
277 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
278 | nn.BatchNorm2d(channels),
279 | )
280 |
281 | self.global_att2 = nn.Sequential(
282 | nn.AdaptiveAvgPool2d(1),
283 | nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
284 | nn.BatchNorm2d(inter_channels),
285 | nn.ReLU(inplace=True),
286 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
287 | nn.BatchNorm2d(channels),
288 | )
289 |
290 | self.sigmoid = nn.Sigmoid()
291 |
292 | def forward(self, x, residual):
293 | xa = x + residual
294 | xl = self.local_att(xa)
295 | xg = self.global_att(xa)
296 | xlg = xl + xg
297 | wei = self.sigmoid(xlg)
298 | xi = x * wei + residual * (1 - wei)
299 |
300 | xl2 = self.local_att2(xi)
301 | xg2 = self.global_att2(xi)
302 | xlg2 = xl2 + xg2
303 | wei2 = self.sigmoid(xlg2)
304 | xo = x * wei2 + residual * (1 - wei2)
305 | return xo
306 |
307 | class SEF_PNet(nn.Module):
308 | def __init__(self,
309 | win_len: int = 256, # 32 ms
310 | win_inc: int = 64, # 8 ms
311 | fft_len: int = 256,
312 | win_type: str = "sqrthann",
313 | kernel_size: Tuple[int] = (3, 3),
314 | stride1: Tuple[int] = (1, 1),
315 | stride2: Tuple[int] = (1, 2),
316 | paddings: Tuple[int] = (1, 0),
317 | output_padding: Tuple[int] = (0, 0),
318 | tcn_dims: int = 384,
319 | tcn_blocks: int = 10,
320 | tcn_layers: int = 2,
321 | causal: bool = False,
322 | pool_size: Tuple[int] = (4, 8, 16, 32),
323 | num_spks: int = 1,
324 | L: int = 20) -> None:
325 | super(SEF_PNet, self).__init__()
326 |
327 | self.L = L
328 | self.fft_len = fft_len
329 | self.num_spks = num_spks
330 | self.stft = ConvSTFT(win_len, win_inc, fft_len, win_type, 'complex')
331 | self.softmax = nn.Softmax(dim=-2)
332 | self.ifi = IFI(channels=2, r=1/32)
333 | self.upconv1 = nn.Conv2d(4, 64, 1, 1, 0)
334 | self.lca = LCA(64)
335 | self.conv2d = nn.Conv2d(64, 16, (1, 3), 1, 0)
336 | self.relu = nn.ReLU()
337 | self.encoder = self._build_encoder(
338 | kernel_size=kernel_size,
339 | stride=stride2,
340 | padding=paddings
341 | )
342 | self.tcn_layers = self._build_tcn_layers(
343 | tcn_layers,
344 | tcn_blocks,
345 | in_dims=tcn_dims,
346 | out_dims=tcn_dims,
347 | causal=causal
348 | )
349 | self.decoder = self._build_decoder(
350 | kernel_size=kernel_size,
351 | stride=stride2,
352 | padding=paddings,
353 | output_padding=output_padding
354 | )
355 | self.avg_pool = self._build_avg_pool(pool_size)
356 | self.avg_proj = nn.Conv2d(64, 32, 1, 1)
357 | self.deconv2d = nn.ConvTranspose2d(32, 2*num_spks, kernel_size, stride1, paddings)
358 | self.istft = ConviSTFT(win_len, win_inc, fft_len, win_type, 'complex')
359 |
360 | def _build_encoder(self, **enc_kargs):
361 | """
362 | Build encoder layers
363 | """
364 | encoder = nn.ModuleList()
365 | encoder.append(SublinearSequential(DenseBlock(16, 16, "enc"),LCA(16)))
366 |
367 | for i in range(3):
368 | encoder.append(
369 | SublinearSequential(
370 | Conv2dBlock(in_dims=16 if i==0 else 32,
371 | out_dims=32, **enc_kargs),
372 | DenseBlock(32, 32, "enc"),
373 | LCA(32)
374 | )
375 | )
376 | encoder.append(
377 | SublinearSequential(
378 | Conv2dBlock(in_dims=32, out_dims=64, **enc_kargs),
379 | LCA(64)
380 | )
381 | )
382 | encoder.append(
383 | SublinearSequential(
384 | Conv2dBlock(in_dims=64, out_dims=128, **enc_kargs),
385 | LCA(128)
386 | )
387 | )
388 | encoder.append(
389 | SublinearSequential(
390 | Conv2dBlock(in_dims=128, out_dims=384, **enc_kargs),
391 | LCA(384)
392 | )
393 | )
394 |
395 | return encoder
396 |
397 | def _build_decoder(self, **dec_kargs):
398 | """
399 | Build decoder layers
400 | """
401 | decoder = nn.ModuleList()
402 | decoder.append(ConvTrans2dBlock(in_dims=384*2, out_dims=128, **dec_kargs))
403 | decoder.append(ConvTrans2dBlock(in_dims=128*2, out_dims=64, **dec_kargs))
404 | decoder.append(ConvTrans2dBlock(in_dims=64*2, out_dims=32, **dec_kargs))
405 | for i in range(3):
406 | decoder.append(
407 | SublinearSequential(
408 | DenseBlock(32, 64, "dec"),
409 | ConvTrans2dBlock(in_dims=64,
410 | out_dims=32 if i!=2 else 16,
411 | **dec_kargs)
412 | )
413 | )
414 | decoder.append(DenseBlock(16, 32, "dec"))
415 |
416 | return decoder
417 |
418 | def _build_tcn_blocks(self, tcn_blocks, **tcn_kargs):
419 | """
420 | Build TCN blocks in each repeat (layer)
421 | """
422 | blocks = [
423 | TCNBlock(**tcn_kargs, dilation=(2**b))
424 | for b in range(tcn_blocks)
425 | ]
426 |
427 | return SublinearSequential(*blocks)
428 |
429 | def _build_tcn_layers(self, tcn_layers, tcn_blocks, **tcn_kargs):
430 | """
431 | Build TCN layers
432 | """
433 | layers = [
434 | self._build_tcn_blocks(tcn_blocks, **tcn_kargs)
435 | for _ in range(tcn_layers)
436 | ]
437 |
438 | return SublinearSequential(*layers)
439 |
440 | def _build_avg_pool(self, pool_size):
441 | """
442 | Build avg pooling layers
443 | """
444 | avg_pool = nn.ModuleList()
445 | for sz in pool_size:
446 | avg_pool.append(
447 | SublinearSequential(
448 | nn.AvgPool2d(sz),
449 | nn.Conv2d(32, 8, 1, 1)
450 | )
451 | )
452 |
453 | return avg_pool
454 |
455 | def wav2spec(self, x: th.Tensor, mags: bool = False) -> th.Tensor:
456 | """
457 | convert waveform to spectrogram
458 | """
459 | # print(x.shape)
460 | assert x.dim() == 2
461 | # x = x / th.std(x, -1, keepdims=True) # variance normalization
462 | specs = self.stft(x)
463 | real = specs[:,:self.fft_len//2+1]
464 | imag = specs[:,self.fft_len//2+1:]
465 | spec = th.stack([real,imag], 1) #[B,2,F,T]
466 | # spec = th.einsum("hijk->hikj", spec) # batchsize, 2, T, F
467 | if mags:
468 | return th.sqrt(real**2+imag**2+1e-8)
469 | else:
470 | return spec
471 |
472 | def FeaCompression(self, input, factor=0.5):
473 | input_change = input.float()
474 | complex_spectrum = th.complex(input_change[:, 0, :, :], input_change[:, 1, :, :])
475 | magnitude = th.abs(complex_spectrum).unsqueeze(1) ** factor
476 | phase = th.angle(complex_spectrum).unsqueeze(1)
477 |
478 | real = magnitude * th.cos(phase)
479 | imag = magnitude * th.sin(phase)
480 | output = th.cat((real, imag), dim=1)
481 |
482 | return output
483 |
484 | def FeaDecompression(self, input, factor=0.5):
485 | input_change = input.float()
486 | complex_spectrum = th.complex(input_change[:, 0, :, :], input_change[:, 1, :, :])
487 | magnitude = th.abs(complex_spectrum).unsqueeze(1) ** (1 / factor)
488 | phase = th.angle(complex_spectrum).unsqueeze(1)
489 |
490 | real = magnitude * th.cos(phase)
491 | imag = magnitude * th.sin(phase)
492 | output = th.cat((real, imag), dim=1)
493 |
494 | return output
495 |
496 | def ComputeSimilarity(self, input, enrollment):
497 | att = enrollment.transpose(-2, -1) @ input
498 | att = self.softmax(att)
499 | output = enrollment @ att
500 |
501 | return output.unsqueeze(0).unsqueeze(0)
502 |
503 | def sep(self, spec: th.Tensor) -> List[th.Tensor]:
504 | """
505 | spec: (batchsize, 2, F, T)
506 | return time-domain waveform(s) via iSTFT
507 | """
508 | # spec = th.einsum("hijk->hikj", spec) # (batchsize, 2, F, T)
509 | B, N, F, T = spec.shape
510 | est = th.chunk(spec, 2, 1) # [(B, 1, F, T), (B, 1, F, T)]
511 | est = th.cat(est, 2).reshape(B, -1, T) # B, 2F, T
512 | return th.squeeze(self.istft(est))
513 |
514 | def forward(self,
515 | mix: th.Tensor,
516 | enrollment: th.Tensor) -> th.Tensor:
517 | """
518 | mix: mixture waveform(s), [B, S] or [S]
519 | enrollment: enrollment waveform(s), [B, S'] or [S']; returns estimated target waveform(s)
520 | """
521 | if mix.dim() == 1:
522 | mix = th.unsqueeze(mix, 0)
523 | enrollment = th.unsqueeze(enrollment, 0)
524 | batch_size = mix.shape[0]
525 | mix_spec = self.wav2spec(mix, False)
526 | mix_spec_change = self.FeaCompression(mix_spec) #[B,2,F,T]
527 | similarity = []
528 | aux_drc = []
529 | for i in range(batch_size):
530 | aux = self.wav2spec(enrollment[i].unsqueeze(0), False)
531 | aux_spec_change = self.FeaCompression(aux)
532 | aux_drc.append(aux_spec_change)
533 | similarity.append(th.cat([self.ComputeSimilarity(mix_spec_change[i, 0, ...], aux_spec_change[0, 0, ...]), self.ComputeSimilarity(mix_spec_change[i, 1, ...], aux_spec_change[0, 1, ...])], dim=1))
534 | similarity = th.cat(similarity, dim=0)
535 | aux_drc = th.cat(aux_drc, dim=0)
536 | aux_drc = th.mean(aux_drc, dim=-1).unsqueeze(-1).repeat(1, 1, 1, similarity.shape[-1])
537 | similarity = self.ifi(similarity, aux_drc)
538 | fus = th.cat((mix_spec_change, similarity), dim=1) # [B, 4, F, T]
539 | fus = self.upconv1(fus)
540 | fus = self.lca(fus)
541 | # speech separation
542 | fus = fus.permute(0, 1, 3, 2)
543 | out = self.relu(self.conv2d(fus))
544 | out_list = []
545 | out = self.encoder[0](out)
546 | out_list.append(out)
547 | for idx, enc in enumerate(self.encoder[1:]):
548 | out = enc(out)
549 | out_list.append(out)
550 |
551 | B, N, T, F = out.shape
552 | out = out.reshape(B, N, T*F)
553 | out = self.tcn_layers(out)
554 | out = th.unsqueeze(out, -1)
555 |
556 | out_list = out_list[::-1]
557 | for idx, dec in enumerate(self.decoder):
558 | decinput = th.cat([out_list[idx], out], 1)
559 | out = dec(decinput)
560 |
561 | # Pyramidal pooling
562 | B, N, T, F = out.shape
563 | upsample = nn.Upsample(size=(T, F), mode='bilinear')
564 | pool_list = []
565 | for avg in self.avg_pool:
566 | pool_list.append(upsample(avg(out)))
567 | out = th.cat([out, *pool_list], 1)
568 | out = self.avg_proj(out)
569 | out = self.deconv2d(out)
570 | out = out.permute(0, 1, 3, 2)
571 | out = self.FeaDecompression(out)
572 | out = self.sep(out)
573 | return out
574 |
575 |
576 | def test_conv2d_block():
577 | x = th.randn(2, 16, 257, 200)
578 | conv = Conv2dBlock()
579 | y = conv(x)
580 | convtrans = ConvTrans2dBlock()
581 | z = convtrans(y)
582 |
583 | def test_dense_block():
584 | x = th.randn(2, 16, 257, 200)
585 | dense = DenseBlock(16, 32, "enc")
586 | y = dense(x)
587 |
588 | def test_tcn_block():
589 | x = th.randn(2, 384, 1000)
590 | tcn1 = TCNBlock(dilation=128)
591 | y = tcn1(x)
592 | if __name__ == "__main__":
593 | from thop import profile, clever_format
594 | nnet = SEF_PNet()
595 | mix = th.randn(2, 8000)
596 | aux = th.randn(2, 8000)
597 | est = nnet(mix, aux)
598 | macs, params = profile(nnet, inputs=(mix,aux))
599 | macs, params = clever_format([macs, params], "%.3f")
600 | print(macs, params)
601 |
--------------------------------------------------------------------------------
/nnet/__pycache__/SEF_PNet_pse.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/__pycache__/SEF_PNet_pse.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/__pycache__/conf_unet_tse_32ms.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/__pycache__/conf_unet_tse_32ms.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/conf_unet_tse_32ms.py:
--------------------------------------------------------------------------------
1 | fs = 8000
2 | chunk_len = 4 # (s)
3 | chunk_size = chunk_len * fs
4 |
5 | nnet_conf = {
6 | "win_len": 256,
7 | "win_inc": 64,
8 | "fft_len": 256,
9 | "win_type": "sqrthann",
10 | "kernel_size": (3, 3),
11 | "stride1": (1, 1),
12 | "stride2": (1, 2),
13 | "paddings": (1, 0),
14 | "output_padding": (0, 0),
15 | "tcn_dims": 384,
16 | "tcn_blocks": 10,
17 | "tcn_layers": 2,
18 | "causal": False,
19 | "num_spks": 1
20 | }
21 |
22 |
23 | # data configure:
24 | train_dir = "/node/hzl/expriment/SEF_PNet_icassp2025_github/data/train/"
25 | dev_dir = "/node/hzl/expriment/SEF_PNet_icassp2025_github/data/dev/"
26 |
27 | train_data = {
28 | "mix_scp": train_dir + "mix_clean.scp",
29 | "ref_scp": train_dir + "ref.scp",
30 | "aux_scp": train_dir + "auxs1.scp",
31 | "sample_rate": fs,
32 | }
33 |
34 | dev_data = {
35 | "mix_scp": dev_dir + "mix_clean.scp",
36 | "ref_scp": dev_dir + "ref.scp",
37 | "aux_scp": dev_dir + "auxs1.scp",
38 | "sample_rate": fs,
39 | }
40 |
41 | # trainer config
42 | adam_kwargs = {
43 | "lr": 0.5e-3,
44 | "weight_decay": 1e-5,
45 | }
46 |
47 | trainer_conf = {
48 | "optimizer": "adam",
49 | "optimizer_kwargs": adam_kwargs,
50 | "min_lr": 1e-8,
51 | "patience": 2,
52 | "factor": 0.5,
53 | "logging_period": 200
54 | }
55 |
--------------------------------------------------------------------------------
/nnet/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import time
5 | import argparse
6 | import torch as th
7 | import numpy as np
8 | from mir_eval.separation import bss_eval_sources
9 | from pesq import pesq as pesq2
10 | from pypesq import pesq as pesq1
11 | from pystoi.stoi import stoi
12 | from SEF_PNet_pse import SEF_PNet
13 | from libs.utils import load_json, get_logger
14 | from libs.dataset_tse import Dataset
15 |
16 | def evaluate(args, model_file, logger):
17 | start = time.time()
18 | total_SISNR = 0
19 | total_SISNRi = 0
20 | total_PESQ = 0
21 | total_PESQi = 0
22 | total_PESQ2 = 0
23 | total_PESQi2 = 0
24 | total_STOI = 0
25 | total_STOIi = 0
26 | total_SDR = 0
27 | total_cnt = 0
28 |
29 | # Load model
30 | nnet_conf = load_json(args.checkpoint, "mdl.json")
31 | nnet = SEF_PNet(**nnet_conf)
32 | cpt_fname = os.path.join(args.checkpoint, model_file)
33 | cpt = th.load(cpt_fname, map_location="cpu")
34 | nnet.load_state_dict(cpt["model_state_dict"])
35 | logger.info("Loaded checkpoint from {}, epoch {:d}".format(
36 | cpt_fname, cpt["epoch"]))
37 |
38 | device = th.device(
39 | "cuda:{}".format(args.gpuid)) if args.gpuid >= 0 else th.device("cpu")
40 | nnet = nnet.to(device) if args.gpuid >= 0 else nnet
41 | nnet.eval()
42 |
43 | # Load data
44 | dataset = Dataset(mix_scp=args.mix_scp, ref_scp=args.ref_scp, aux_scp=args.aux_scp, sample_rate=8000)
45 |
46 | with th.no_grad():
47 | for i, data in enumerate(dataset):
48 | mix = th.tensor(data['mix'], dtype=th.float32, device=device)
49 | aux = th.tensor(data['aux'], dtype=th.float32, device=device)
50 |
51 | if args.gpuid >= 0:
52 | mix = mix.unsqueeze(0).to(device)
53 | aux = aux.unsqueeze(0).to(device)
54 |
55 | # Forward
56 | ref = data['ref']
57 | key = data['key']
58 | ests = nnet(mix, aux)
59 | ests = ests.cpu().numpy()
60 | mix = mix.squeeze(0).cpu().numpy()
61 | if ests.size != ref.size:
62 | end = min(ests.size, ref.size)
63 | ests = ests[:end]
64 | ref = ref[:end]
65 | mix = mix[:end]
66 |
67 | # Compute metrics
68 | if args.cal_sdr == 1:
69 | SDR, sir, sar, popt = bss_eval_sources(ref, ests)
70 | total_SDR += SDR[0]
71 | SISNR, delta = cal_SISNRi(ests, ref, mix)
72 | PESQ, PESQi, PESQ2, PESQi2 = cal_PESQi(ests, ref, mix)
73 | STOI, STOIi = cal_STOIi(ests, ref, mix)
74 | if args.cal_sdr == 1:
75 | logger.info("Utt={:d} | SDR={:.2f} | SI-SNR={:.2f} | SI-SNRi={:.2f} | PESQ={:.2f} | PESQi={:.2f} | PESQ2={:.2f} | PESQi2={:.2f} | STOI={:.2f} | STOIi={:.2f}".format(
76 | total_cnt+1, SDR[0], SISNR, delta, PESQ, PESQi, PESQ2, PESQi2, STOI, STOIi))
77 | else:
78 | logger.info("Utt={:d} | SI-SNR={:.2f} | SI-SNRi={:.2f} | PESQ={:.2f} | PESQi={:.2f} | PESQ2={:.2f} | PESQi2={:.2f} | STOI={:.2f} | STOIi={:.2f}".format(
79 | total_cnt+1, SISNR, delta, PESQ, PESQi, PESQ2, PESQi2, STOI, STOIi))
80 | total_SISNR += SISNR
81 | total_SISNRi += delta
82 | total_PESQ += PESQ
83 | total_PESQi += PESQi
84 | total_PESQ2 += PESQ2
85 | total_PESQi2 += PESQi2
86 | total_STOI += STOI
87 | total_STOIi += STOIi
88 | total_cnt += 1
89 | end = time.time()
90 |
91 | logger.info('Time Elapsed: {:.1f}s'.format(end-start))
92 | if args.cal_sdr == 1:
93 | logger.info("Average SDR: {0:.2f}".format(total_SDR / total_cnt))
94 | logger.info("Average SI-SNR: {:.2f}".format(total_SISNR / total_cnt))
95 | logger.info("Average SI-SNRi: {:.2f}".format(total_SISNRi / total_cnt))
96 | logger.info("Average PESQ: {:.2f}".format(total_PESQ / total_cnt))
97 | logger.info("Average PESQi: {:.2f}".format(total_PESQi / total_cnt))
98 | logger.info("Average PESQ2: {:.2f}".format(total_PESQ2 / total_cnt))
99 | logger.info("Average PESQi2: {:.2f}".format(total_PESQi2 / total_cnt))
100 | logger.info("Average STOI: {:.2f}".format(total_STOI / total_cnt))
101 | logger.info("Average STOIi: {:.2f}".format(total_STOIi / total_cnt))
102 |
103 | def cal_SISNR(est, ref, eps=1e-8):
104 | """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
105 | Args:
106 | est: separated signal, numpy.ndarray, [T]
107 | ref: reference signal, numpy.ndarray, [T]
108 | Returns:
109 | SISNR
110 | """
111 | assert len(est) == len(ref)
112 | est_zm = est - np.mean(est)
113 | ref_zm = ref - np.mean(ref)
114 |
115 | t = np.sum(est_zm * ref_zm) * ref_zm / (np.linalg.norm(ref_zm)**2 + eps)
116 |
117 | return 20 * np.log10(eps + np.linalg.norm(t) / (np.linalg.norm(est_zm - t) + eps))
118 |
119 | def cal_SISNRi(est, ref, mix, eps=1e-8):
120 | """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
121 | Args:
122 | est: separated signal, numpy.ndarray, [T]
123 | ref: reference signal, numpy.ndarray, [T]
124 | Returns:
125 | SISNR
126 | """
127 | assert len(est) == len(ref) == len(mix)
128 | sisnr1 = cal_SISNR(est, ref)
129 | sisnr2 = cal_SISNR(mix, ref)
130 |
131 | return sisnr1, sisnr1 - sisnr2
132 |
133 | def cal_PESQ(est, ref):
134 | assert len(est) == len(ref)
135 | mode = 'nb'
136 | p = pesq1(ref, est, 8000)
137 | p2 = pesq2(8000, ref, est, mode)
138 | return p, p2
139 |
140 | def cal_PESQi(est, ref, mix):
141 | """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
142 | Args:
143 | est: separated signal, numpy.ndarray, [T]
144 | ref: reference signal, numpy.ndarray, [T]
145 | Returns:
146 | SISNR
147 | """
148 | assert len(est) == len(ref) == len(mix)
149 | est_p, est_p2 = cal_PESQ(est, ref)
150 | mix_p, mix_p2 = cal_PESQ(mix, ref)
151 |
152 | return est_p, est_p - mix_p, est_p2, est_p2 - mix_p2
153 |
154 | def cal_STOI(est, ref):
155 | assert len(est) == len(ref)
156 | p = stoi(ref, est, 8000)
157 | return p
158 |
159 | def cal_STOIi(est, ref, mix):
160 | """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
161 | Args:
162 | est: separated signal, numpy.ndarray, [T]
163 | ref: reference signal, numpy.ndarray, [T]
164 | Returns:
165 | SISNR
166 | """
167 | assert len(est) == len(ref) == len(mix)
168 | stoi1 = cal_STOI(est, ref)*100
169 | stoi2 = cal_STOI(mix, ref)*100
170 |
171 | return stoi1, stoi1 - stoi2
172 |
173 | if __name__ == '__main__':
174 | parser = argparse.ArgumentParser('Evaluate separation performance using SEF-PNet')
175 | parser.add_argument('--checkpoint', type=str,
176 | default='/node/hzl/expriment/SEF_PNet_icassp2025_github/demo',
177 | help='Path to model directory containing checkpoints')
178 | parser.add_argument('--gpuid', type=int, default=0,
179 | help="GPU device to offload model to, -1 means running on CPU")
180 | parser.add_argument('--mix_scp', type=str,
181 | default='/node/hzl/expriment/SEF_PNet_icassp2025_github/data/test/mix_clean.scp',
182 | help='mix scp')
183 | parser.add_argument('--ref_scp', type=str,
184 | default='/node/hzl/expriment/SEF_PNet_icassp2025_github/data/test/ref.scp',
185 | help='ref scp')
186 | parser.add_argument('--aux_scp', type=str,
187 | default='/node/hzl/expriment/SEF_PNet_icassp2025_github/data/test/auxs1.scp',
188 | help='aux scp')
189 | parser.add_argument('--cal_sdr', type=int, default=None,
190 | help='Whether to calculate SDR (set to 1 to enable); optional because SDR computation is very slow')
191 |
192 | args = parser.parse_args()
193 |
194 |
195 | # eval best.pt.tar
196 | best_model_file = "best.pt.tar"
197 | best_log_file = os.path.join(args.checkpoint, "eval_best.log")
198 | best_logger = get_logger(best_log_file, file=True)
199 | best_logger.info(f"Evaluating model: {best_model_file}")
200 | evaluate(args, best_model_file, best_logger)
201 |
202 | # eval checkpoints for epochs 110-121 ({epoch}.pt.tar)
203 | for epoch in range(110, 122):
204 | model_file = f"{epoch}.pt.tar"
205 | log_file = os.path.join(args.checkpoint, f"eval_{epoch}.log")
206 | logger = get_logger(log_file, file=True)
207 | logger.info(f"Evaluating model: {model_file}")
208 | evaluate(args, model_file, logger)
209 |
--------------------------------------------------------------------------------
/nnet/libs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__init__.py
--------------------------------------------------------------------------------
/nnet/libs/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__init__.pyc
--------------------------------------------------------------------------------
/nnet/libs/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/libs/__pycache__/audio.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__pycache__/audio.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/libs/__pycache__/conv_stft.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__pycache__/conv_stft.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/libs/__pycache__/dataset_tse.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__pycache__/dataset_tse.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/libs/__pycache__/trainer_unet_tse_steplr_clip.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__pycache__/trainer_unet_tse_steplr_clip.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/libs/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/libs/audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import soundfile as sf
5 | import librosa
6 | import kaldiio
7 | MAX_INT16 = np.iinfo(np.int16).max
8 |
9 |
10 | def write_wav(fname, samps, fs=8000, normalize=True):
11 | """
12 | Write wav files as float32 WAV via soundfile; supports single/multi-channel
13 | """
14 | #if normalize:
15 | # samps = samps * MAX_INT16
16 | ## scipy.io.wavfile.write could write single/multi-channel files
17 | ## for multi-channel, accept ndarray [Nsamples, Nchannels]
18 | #if samps.ndim != 1 and samps.shape[0] < samps.shape[1]:
19 | # samps = np.transpose(samps)
20 | # samps = np.squeeze(samps)
21 | ## same as MATLAB and kaldi
22 | #samps_int16 = samps.astype(np.int16)
23 | #fdir = os.path.dirname(fname)
24 | #if fdir and not os.path.exists(fdir):
25 | # os.makedirs(fdir)
26 | ## NOTE: librosa 0.6.0 seems could not write non-float narray
27 | ## so use scipy.io.wavfile instead
28 | #wf.write(fname, fs, samps_int16)
29 |
30 | # wham and whamr mixture and clean data are float 32, can not use scipy.io.wavfile to read and write int16
31 | # change to soundfile to read and write, although reference speech is int16, soundfile still can read and outputs as float
32 | fdir = os.path.dirname(fname)
33 | if fdir and not os.path.exists(fdir):
34 | os.makedirs(fdir)
35 | sf.write(fname, samps, fs, subtype='FLOAT',format='WAV')
36 |
37 |
38 | def read_wav(fname, normalize=True, return_rate=False):
39 | """
40 | Read wave files using soundfile (supports multi-channel)
41 | """
42 | # samps_int16: N x C or N
43 | # N: number of samples
44 | # C: number of channels
45 | #samp_rate, samps_int16 = wf.read(fname)
46 | ## N x C => C x N
47 | #samps = samps_int16.astype(np.float)
48 | ## tranpose because I used to put channel axis first
49 | #if samps.ndim != 1:
50 | # samps = np.transpose(samps)
51 | ## normalize like MATLAB and librosa
52 | #if normalize:
53 | # samps = samps / MAX_INT16
54 | #if return_rate:
55 | # return samp_rate, samps
56 | #return samps
57 |
58 | # wham and whamr mixture and clean data are float 32, can not use scipy.io.wavfile to read and write int16
59 | # change to soundfile to read and write, although reference speech is int16, soundfile still can read and outputs as float
60 | samps, samp_rate = sf.read(fname)
61 | if return_rate:
62 | return samp_rate, samps
63 | return samps
64 |
65 |
66 | def parse_scripts(scp_path, value_processor=lambda x: x, num_tokens=2):
67 | """
68 | Parse kaldi's script(.scp) file
69 | If num_tokens >= 2, function will check token number
70 | """
71 | scp_dict = dict()
72 | line = 0
73 | with open(scp_path, "r") as f:
74 | for raw_line in f:
75 | scp_tokens = raw_line.strip().split()
76 | line += 1
77 | if (num_tokens >= 2 and len(scp_tokens) != num_tokens) or len(
78 | scp_tokens) < 2:
79 | raise RuntimeError(
80 | "For {}, format error in line[{:d}]: {}".format(
81 | scp_path, line, raw_line))
82 | if num_tokens == 2:
83 | key, value = scp_tokens
84 | else:
85 | key, value = scp_tokens[0], scp_tokens[1:]
86 | if key in scp_dict:
87 | raise ValueError("Duplicated key \'{0}\' exists in {1}".format(
88 | key, scp_path))
89 | scp_dict[key] = value_processor(value)
90 | return scp_dict
91 |
92 |
93 | class Reader(object):
94 | """
95 | Basic Reader Class
96 | """
97 | def __init__(self, scp_path, value_processor=lambda x: x):
98 | self.index_dict = parse_scripts(
99 | scp_path, value_processor=value_processor, num_tokens=2)
100 | self.index_keys = list(self.index_dict.keys())
101 |
102 | def _load(self, key):
103 | # return path
104 | return self.index_dict[key]
105 |
106 | # number of utterance
107 | def __len__(self):
108 | return len(self.index_dict)
109 |
110 | def __contains__(self, key):
111 | return key in self.index_dict
112 |
113 | # sequential index
114 | def __iter__(self):
115 | for key in self.index_keys:
116 | yield key, self._load(key)
117 |
118 | # random index, support str/int as index
119 | def __getitem__(self, index):
120 | if type(index) not in [int, str]:
121 | raise IndexError("Unsupported index type: {}".format(type(index)))
122 | if type(index) == int:
123 | # from int index to key
124 | num_utts = len(self.index_keys)
125 | if index >= num_utts or index < 0:
126 | raise KeyError(
127 | "Interger index out of range, {:d} vs {:d}".format(
128 | index, num_utts))
129 | index = self.index_keys[index]
130 | if index not in self.index_dict:
131 | raise KeyError("Missing utterance {}!".format(index))
132 | return self._load(index)
133 |
134 |
135 | class WaveReader(Reader):
136 | """
137 | Sequential/Random Reader for single channel wave
138 | Format of wav.scp follows Kaldi's definition:
139 | key1 /path/to/wav
140 | ...
141 | """
142 | def __init__(self, wav_scp, sample_rate=None, normalize=True):
143 | super(WaveReader, self).__init__(wav_scp)
144 | self.samp_rate = sample_rate
145 | self.normalize = normalize
146 |
147 | def _load(self, key):
148 | # return C x N or N
149 | samp_rate, samps = read_wav(
150 | self.index_dict[key], normalize=self.normalize, return_rate=True)
151 | # if given samp_rate, check it
152 | if self.samp_rate is not None and samp_rate != self.samp_rate:
153 | samps = librosa.resample(samps, orig_sr=samp_rate, target_sr=self.samp_rate)
154 | # raise RuntimeError("SampleRate mismatch: {:d} vs {:d}".format(
155 | # samp_rate, self.samp_rate))
156 | return samps
157 |
--------------------------------------------------------------------------------
/nnet/libs/conv_stft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from scipy.signal import get_window
6 |
7 |
8 | def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
9 | """
10 | Return window coefficient
11 | """
12 | def sqrthann(win_len):
13 | return get_window("hann", win_len, fftbins=True)**0.5
14 |
15 | if win_type == 'None' or win_type is None:
16 | window = np.ones(win_len)
17 | elif win_type == "sqrthann":
18 | window = sqrthann(win_len)
19 | else:
20 | window = get_window(win_type, win_len, fftbins=True)#**0.5
21 |
22 | N = fft_len
23 | fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
24 | real_kernel = np.real(fourier_basis)
25 | imag_kernel = np.imag(fourier_basis)
26 | kernel = np.concatenate([real_kernel, imag_kernel], 1).T
27 |
28 | if invers:
29 | kernel = np.linalg.pinv(kernel).T
30 |
31 | kernel = kernel*window
32 | kernel = kernel[:, None, :]
33 | return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None,:,None].astype(np.float32))
34 |
35 |
36 | class ConvSTFT(nn.Module):
37 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'):
38 | super(ConvSTFT, self).__init__()
39 |
40 | if fft_len is None:
41 | self.fft_len = int(2**np.ceil(np.log2(win_len)))
42 | else:
43 | self.fft_len = fft_len
44 |
45 | kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
46 | self.register_buffer('weight', kernel)
47 | self.feature_type = feature_type
48 | self.stride = win_inc
49 | self.win_len = win_len
50 | self.dim = self.fft_len
51 |
52 | def forward(self, inputs):
53 | if inputs.dim() == 2:
54 | inputs = torch.unsqueeze(inputs, 1)
55 | inputs = F.pad(inputs,[self.win_len-self.stride, self.win_len-self.stride])
56 | outputs = F.conv1d(inputs, self.weight, stride=self.stride)
57 |
58 | if self.feature_type == 'complex':
59 | return outputs
60 | else:
61 | dim = self.dim//2+1
62 | real = outputs[:, :dim, :]
63 | imag = outputs[:, dim:, :]
64 | mags = torch.sqrt(real**2+imag**2)
65 | phase = torch.atan2(imag, real)
66 | return mags, phase
67 |
68 | class ConviSTFT(nn.Module):
69 |
70 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'):
71 | super(ConviSTFT, self).__init__()
72 | if fft_len is None:
73 | self.fft_len = int(2**np.ceil(np.log2(win_len)))
74 | else:
75 | self.fft_len = fft_len
76 | kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)
77 | self.register_buffer('weight', kernel)
78 | self.feature_type = feature_type
79 | self.win_type = win_type
80 | self.win_len = win_len
81 | self.stride = win_inc
83 | self.dim = self.fft_len
84 | self.register_buffer('window', window)
85 | self.register_buffer('enframe', torch.eye(win_len)[:,None,:])
86 |
87 | def forward(self, inputs, phase=None):
88 | """
89 | inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
90 | phase: [B, N//2+1, T] (if not none)
91 | """
92 |
93 | if phase is not None:
94 | real = inputs*torch.cos(phase)
95 | imag = inputs*torch.sin(phase)
96 | inputs = torch.cat([real, imag], 1)
97 | outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
98 |
99 |
100 |
101 | # this is from torch-stft: https://github.com/pseeth/torch-stft
102 | t = self.window.repeat(1,1,inputs.size(-1))**2
103 | coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
104 | outputs = outputs/(coff+1e-8)
105 | #outputs = torch.where(coff == 0, outputs, outputs/coff)
106 | outputs = outputs[...,self.win_len-self.stride:-(self.win_len-self.stride)]
107 |
108 | return outputs
--------------------------------------------------------------------------------
/nnet/libs/dataset_tse.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch as th
3 | import numpy as np
4 |
5 | from torch.utils.data.dataloader import default_collate
6 | import torch.utils.data as dat
7 | from torch.nn.utils.rnn import pad_sequence
8 | from libs.audio import WaveReader
9 | from conf_unet_tse_32ms import train_data, dev_data
10 |
11 |
12 |
13 | def make_dataloader(train=True,
14 | data_kwargs=None,
15 | num_workers=4,
16 | chunk_size=80000,
17 | batch_size=16):
18 | dataset = Dataset(**data_kwargs)
19 | return DataLoader(dataset,
20 | train=train,
21 | chunk_size=chunk_size,
22 | batch_size=batch_size,
23 | num_workers=num_workers)
24 |
25 | def get_spk_ivec(key):
26 | '''
27 | 409o030h_1.7445_029o0304_-1.7445_409c0211
28 | '''
29 | spk = key.split('_')[-1][0:3]
30 | print(spk)
31 |
32 | class Dataset(object):
33 | """
34 | Per Utterance Loader
35 | """
36 | def __init__(self, mix_scp="", ref_scp=None, aux_scp=None, sample_rate=8000):
37 | self.mix = WaveReader(mix_scp, sample_rate=sample_rate)
38 | self.ref = WaveReader(ref_scp, sample_rate=sample_rate)
39 | self.aux = WaveReader(aux_scp, sample_rate=sample_rate)
40 | self.sample_rate = sample_rate
41 |
42 | def __len__(self):
43 | return len(self.mix)
44 |
45 | def __getitem__(self, index):
46 | key = self.mix.index_keys[index]
47 | mix = self.mix[key]
48 | ref = self.ref[key]
49 | aux = self.aux[key]
50 |
51 | return {
52 | "mix": mix.astype(np.float32),
53 | "ref": ref.astype(np.float32),
54 | "aux": aux.astype(np.float32),
55 | "aux_len": len(aux),
56 | "key": key
57 | }
58 |
59 |
60 | class ChunkSplitter(object):
61 | """
62 | Split utterance into small chunks
63 | """
64 | def __init__(self, chunk_size, train=True, least=2000):
65 | self.chunk_size = chunk_size
66 | self.least = least
67 | self.train = train
68 |
69 | def _make_chunk(self, eg, s):
70 | """
71 | Make a chunk instance, which contains:
72 | "mix": ndarray,
73 | "ref": [ndarray...]
74 | """
75 | chunk = dict()
76 | chunk["mix"] = eg["mix"][s:s + self.chunk_size]
77 | chunk["ref"] = eg["ref"][s:s + self.chunk_size]
78 | chunk["aux"] = eg["aux"]
79 | chunk["aux_len"] = chunk["aux"].shape[0]
80 | chunk["valid_len"] = int(self.chunk_size)
81 | return chunk
82 |
83 | def split(self, eg):
84 | N = eg["mix"].size
85 | # too short, throw away
86 | if N < self.least:
87 | return []
88 | chunks = []
89 | # padding zeros
90 | if N < self.chunk_size:
91 | P = self.chunk_size - N
92 | chunk = dict()
93 | chunk["mix"] = np.pad(eg["mix"], (0, P), "constant")
94 | chunk["ref"] = np.pad(eg["ref"], (0, P), "constant")
95 | chunk["aux"] = eg["aux"]
96 | chunk["aux_len"] = eg["aux_len"]
97 | chunk["valid_len"] = int(N)
98 | chunks.append(chunk)
99 | else:
100 | # random select start point for training
101 | s = random.randint(0, N % self.least) if self.train else 0
102 | while True:
103 | if s + self.chunk_size > N:
104 | break
105 | chunk = self._make_chunk(eg, s)
106 | chunks.append(chunk)
107 | s += self.least
108 | return chunks
109 |
110 |
111 | class DataLoader(object):
112 | """
113 | Online dataloader for chunk-level PIT
114 | """
115 | def __init__(self,
116 | dataset,
117 | num_workers=4,
118 | chunk_size=80000,
119 | batch_size=4,
120 | train=True):
121 | self.batch_size = batch_size
122 | self.train = train
123 | self.splitter = ChunkSplitter(chunk_size,
124 | train=train,
125 | least=chunk_size // 2)
126 | # just return batch of egs, support multiple workers
127 | self.eg_loader = dat.DataLoader(dataset,
128 | batch_size=batch_size // 2,
129 | num_workers=num_workers,
130 | shuffle=train,
131 | collate_fn=self._collate)
132 |
133 | def _collate(self, batch):
134 | """
135 | Online split utterances
136 | """
137 | chunk = []
138 | for eg in batch:
139 | chunk += self.splitter.split(eg)
140 | return chunk
141 |
142 | def _pad_aux(self, chunk_list):
143 | lens_list = []
144 | for chunk_item in chunk_list:
145 | lens_list.append(chunk_item['aux_len'])
146 | max_len = np.max(lens_list)
147 |
148 | for idx in range(len(chunk_list)):
149 | P = max_len - len(chunk_list[idx]["aux"])
150 | chunk_list[idx]["aux"] = np.pad(chunk_list[idx]["aux"], (0, P), "constant")
151 |
152 | return chunk_list
153 |
154 | def _merge(self, chunk_list):
155 | """
156 | Merge chunk list into mini-batch
157 | """
158 | N = len(chunk_list)
159 | if self.train:
160 | random.shuffle(chunk_list)
161 | blist = []
162 | for s in range(0, N - self.batch_size + 1, self.batch_size):
163 | batch = default_collate(self._pad_aux(chunk_list[s:s + self.batch_size]))
164 | blist.append(batch)
165 | rn = N % self.batch_size
166 | return blist, chunk_list[-rn:] if rn else []
167 |
168 | def __iter__(self):
169 | chunk_list = []
170 | for chunks in self.eg_loader:
171 | chunk_list += chunks
172 | batch, chunk_list = self._merge(chunk_list)
173 | for obj in batch:
174 | yield obj
175 |
176 | if __name__=='__main__':
177 | chunk_size=80000
178 | train=True
179 | least=chunk_size // 2
180 | splitter = ChunkSplitter(chunk_size, train, least)
181 | data = Dataset(**train_data)
182 | egs = data[0]
183 | chunk = splitter.split(egs)
184 | dataload = DataLoader(data)
185 | temp = []
186 | for i, obj in enumerate(dataload):
187 | # print('mix...', obj)
188 | #print(i,obj)
189 | temp.append(obj)
190 | # mix,anw = obj[]
191 | # logits = net(mix,anw)
192 | # loss = net.loss(logits,targets)
193 | # loss.backward()
194 |
195 |
196 | # mix = obj[]
197 |
198 |
199 | #if i == 2:
200 | # break
201 | print(len(temp))
202 |
--------------------------------------------------------------------------------
/nnet/libs/metric.py:
--------------------------------------------------------------------------------
1 | # jyhan@2020
2 |
3 | """
4 | Provided measure metrics:
5 | speech separation: (w/ & w/o PIT)
6 | - SDR
7 | - SDRi
8 | - SI-SNR
9 | - SI-SNRi
10 | speech enhancement:
11 | - PESQ
12 | - STOI
13 | """
14 |
15 | import numpy as np
16 |
17 | from pesq import pesq
18 | from pystoi.stoi import stoi
19 |
20 | from itertools import permutations
21 | from mir_eval.separation import bss_eval_sources
22 |
23 | def cal_sisnr(est, ref, remove_dc=True, eps=1e-8):
24 | """
25 | Compute SI-SNR
26 | Arguments:
27 | est: vector, enhanced/separated signal
28 | ref: vector, reference signal(ground truth)
29 | """
30 | assert len(est) == len(ref)
31 | def vec_l2norm(x):
32 | return np.linalg.norm(x, 2)
33 |
34 | # zero mean; this does not seem to hurt results
35 | if remove_dc:
36 | e_zm = est - np.mean(est)
37 | r_zm = ref - np.mean(ref)
38 | t = np.inner(e_zm, r_zm) * r_zm / (vec_l2norm(r_zm)**2 + eps)
39 | n = e_zm - t
40 | else:
41 | t = np.inner(est, ref) * ref / (vec_l2norm(ref)**2 + eps)
42 | n = est - t
43 | return 20 * np.log10(vec_l2norm(t) / (vec_l2norm(n) + eps))
44 |
45 |
46 | def permute_si_snr(est, ref):
47 | """
48 | Compute SI-SNR between N pairs
49 | Arguments:
50 | est: list[vector], enhanced/separated signal
51 | ref: list[vector], reference signal(ground truth)
52 | Return:
53 | max sisnr and its permutation
54 | """
55 | assert len(est) == len(ref)
56 | def si_snr_avg(est, ref):
57 | return sum([cal_sisnr(e, r) for e, r in zip(est, ref)]) / len(est)
58 |
59 | N = len(est)
60 | if N != len(ref):
61 | raise RuntimeError(
62 | "size do not match between est and ref: {:d} vs {:d}".format(
63 | N, len(ref)))
64 | si_snrs = []
65 | perm = []
66 | for order in permutations(range(N)):
67 | si_snrs.append(si_snr_avg(est, [ref[n] for n in order]))
68 | perm.append(order)
69 |
70 | return max(si_snrs), perm[si_snrs.index(max(si_snrs))]
71 |
72 |
73 | def permute_si_snri(mix, est, ref, both=True):
74 | """
75 | Compute SI-SNR improvement
76 | Arguments:
77 | mix: vector, mixture signal
78 | est: list[vector], enhanced/separated signal
79 | ref: list[vector], reference signal(ground truth)
80 | [spk1, spk2, aux]
81 | """
82 | m_mix = sum([cal_sisnr(mix, r) for r in ref[:2]]) / len(ref[:2])
83 | m_enh, _ = permute_si_snr(est, ref)
84 | if both:
85 | return m_enh, m_enh - m_mix
86 | else:
87 | return m_enh - m_mix
88 |
89 | def pit_rank_sisnr(mix, est, ref):
90 | """
91 | Compute SI-SNR improvement
92 | Arguments:
93 | mix: vector, mixture signal
94 | est: list[vector], enhanced/separated signal
95 | ref: list[vector], reference signal(ground truth)
96 | [spk1, spk2, aux]
97 | """
98 | m_mix1 = sum([cal_sisnr(mix, r) for r in ref[:2]]) / len(ref[:2])
99 | m_mix2 = sum([cal_sisnr(mix, r) for r in est[:2]]) / len(est[:2])
100 | m_mix = (m_mix1 + m_mix2) / 2
101 | m_enh, _ = permute_si_snr(est, ref)
102 |
103 | return m_enh, m_mix
104 |
105 | def pit_rank_sisnr_all(mix, est, ref):
106 | """
107 | Compute SI-SNR improvement
108 | Arguments:
109 | mix: vector, mixture signal
110 | est: list[vector], enhanced/separated signal
111 | ref: list[vector], reference signal(ground truth)
112 | [spk1, spk2, aux]
113 | """
114 | m_mix1 = sum([cal_sisnr(mix, r) for r in ref[:2]]) / len(ref[:2])
115 | m_mix2 = sum([cal_sisnr(mix, r) for r in est[:2]]) / len(est[:2])
116 | m_mix = (m_mix1 + m_mix2) / 2
117 | m_enh, _ = permute_si_snr(est, ref)
118 |
119 | return m_enh, m_mix1, m_mix2, m_mix
120 |
121 |
122 | def reorder_list(slist, perm):
123 | """
124 | Arguments:
125 | slist: list[vector], reference signal
126 | perm: permutation label
127 | Return:
128 | list[vector], reordered reference signal
129 | """
130 | return [slist[p] for p in perm]
131 |
132 |
133 | def cal_SDRi(mix, est, ref):
134 | """Calculate Source-to-Distortion Ratio improvement (SDRi).
135 |     NOTE: bss_eval_sources is very slow.
136 |     Args:
137 |         mix: numpy.ndarray, mixture signal
138 |         est: [numpy.ndarray, numpy.ndarray], enhanced/separated signals
139 |         ref: [numpy.ndarray, numpy.ndarray], reference signals (ground truth)
140 | Returns:
141 | avg_sdr, sdri
142 | """
143 | mix = np.array(mix)
144 | est = np.array(est)
145 | ref = np.array(ref)
146 |
147 | mix_anchor = np.stack([mix, mix], axis=0)
148 | sdr, sir, sar, popt = bss_eval_sources(ref, est)
149 | sdr0, sir0, sar0, popt0 = bss_eval_sources(ref, mix_anchor)
150 |     avg_sdr = (sdr[0] + sdr[1]) / 2
151 |     avg_sdr_m = (sdr0[0] + sdr0[1]) / 2
152 |
153 | return avg_sdr, avg_sdr - avg_sdr_m
154 |
155 |
156 | def permute_pesq(est, ref, fs=8000):
157 |     """
158 |     Evaluate PESQ
159 |     Args:
160 |         est: [numpy 1D array, numpy 1D array], estimated audio signal
161 |         ref: [numpy 1D array, numpy 1D array], reference audio signal
162 |         fs: integer, sampling rate
163 |     """
164 |     assert fs in [8000, 16000]
165 |     assert len(est) == len(ref)
166 |     mode = 'nb' if fs == 8000 else 'wb'  # narrow-band at 8 kHz, wide-band at 16 kHz
167 | 
168 |     def pesq_avg(est, ref):
169 |         return sum([pesq(fs, r, e, mode) for e, r in zip(est, ref)]) / len(est)
170 | 
171 |     N = len(est)
172 |     if N != len(ref):
173 |         raise RuntimeError(
174 |             "sizes do not match between est and ref: {:d} vs {:d}".format(
175 |                 N, len(ref)))
176 | pesqs = []
177 | for order in permutations(range(N)):
178 | pesqs.append(pesq_avg(est, [ref[n] for n in order]))
179 |
180 | return max(pesqs)
181 |
182 |
183 | def permute_stoi(est, ref, fs=8000):
184 | """
185 | Evaluate STOI
186 | Args:
187 | est: [numpy 1D array, numpy 1D array], estimated audio signal
188 | ref: [numpy 1D array, numpy 1D array], reference audio signal
189 | fs: integer, sampling rate
190 | """
191 | assert len(est) == len(ref)
192 |
193 | def stoi_avg(est, ref):
194 | return sum([stoi(r, e, fs) for e, r in zip(est, ref)]) / len(est)
195 |
196 | N = len(est)
197 |     if N != len(ref):
198 |         raise RuntimeError(
199 |             "sizes do not match between est and ref: {:d} vs {:d}".format(
200 |                 N, len(ref)))
201 | stois = []
202 | for order in permutations(range(N)):
203 | stois.append(stoi_avg(est, [ref[n] for n in order]))
204 |
205 | return max(stois)
206 |
207 |
208 | def eval_all(mix, est, ref, fs=8000, compute_pesq=False):
209 |     """
210 |     Arguments:
211 |         mix: np.ndarray
212 |         est: list[np.ndarray, np.ndarray]
213 |         ref: list[np.ndarray, np.ndarray]
214 |     Evaluate
215 |         SISNR/SISNRi;
216 |         SDR/SDRi;
217 |         PESQ/STOI
218 |     """
219 |     sisnr, sisnri = permute_si_snri(mix, est, ref, True)
220 |     sdr, sdri = cal_SDRi(mix, est, ref)
221 |     if compute_pesq:  # renamed from `pesq` to avoid shadowing the imported function
222 | enh_pesq = permute_pesq(est, ref, fs)
223 | enh_stoi = permute_stoi(est, ref, fs)
224 | return sisnr, sisnri, sdr, sdri, enh_pesq, enh_stoi
225 | else:
226 | return sisnr, sisnri, sdr, sdri
227 |
228 | if __name__ == '__main__':
229 | # np.random.seed(20)
230 | x = np.random.rand(32000)
231 | xlist = [np.random.rand(32000), np.random.rand(32000)]
232 | slist = [np.random.rand(32000), np.random.rand(32000)]
233 | mlist = [np.random.rand(32000), np.random.rand(32000)]
234 | # print(permute_si_snr(xlist, slist))
235 | # print(permute_si_snri(x, xlist, slist))
236 | # print(permute_si_snri(x, xlist, slist, False))
237 | # rlist = reorder_list(slist, [0,1])
238 |     # sdr, sir, sar, popt = bss_eval_sources(np.array(slist), np.array(xlist))
239 | # sdr, sdri = cal_SDRi(x, xlist, slist)
240 | # pp = permute_pesq(xlist, slist, fs=8000)
241 | # st = permute_stoi(xlist, xlist, fs=8000)
242 |     sisnr, sisnri, sdr, sdri, enh_pesq, enh_stoi = eval_all(x, xlist, slist, 8000, compute_pesq=True)
243 |
244 | # print(sdr)
245 | # print(cal_sdr(np.array(xlist[0]), np.array(slist[0])))
246 | # print(cal_sdr(np.array(xlist[1]), np.array(slist[1])))
247 | # print(cal_sdr(np.array(xlist[0]), np.array(slist[1])))
248 | # print(cal_sdr(np.array(xlist[1]), np.array(slist[0])))
249 | # print(cal_sdr(np.array(xlist), np.array(slist)))
250 |
251 |
252 |
253 |
254 |
255 |
256 |
--------------------------------------------------------------------------------
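
Note on `cal_sisnr` above: it implements SI-SNR = 20*log10(||t|| / ||e||), where t = (<est, ref> / ||ref||^2) * ref is the projection of the estimate onto the reference and e = est - t is the residual. A minimal standalone sketch (synthetic signals, not importing the repo code) showing that the metric is invariant to rescaling the estimate:

```python
import numpy as np

def sisnr(est, ref, eps=1e-8):
    # project est onto ref: the scale-invariant target component
    t = np.inner(est, ref) * ref / (np.linalg.norm(ref) ** 2 + eps)
    n = est - t  # residual noise component
    return 20 * np.log10(np.linalg.norm(t) / (np.linalg.norm(n) + eps))

rng = np.random.default_rng(0)
ref = rng.standard_normal(8000)
est = ref + 0.1 * rng.standard_normal(8000)  # noisy estimate

print(sisnr(est, ref))        # roughly 20 dB
print(sisnr(3.0 * est, ref))  # identical: SI-SNR ignores scale
```
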
/nnet/libs/trainer_unet_tse_steplr_clip.py:
--------------------------------------------------------------------------------
1 | # wujian@2018
2 |
3 | import os
4 | import sys
5 | import time
6 |
7 | # from itertools import permutations
8 | from collections import defaultdict
9 |
10 | import torch as th
11 | import torch.nn.functional as F
12 | # from torch.optim.lr_scheduler import ReduceLROnPlateau
13 | from torch.optim.lr_scheduler import StepLR
14 | from torch.nn.utils import clip_grad_norm_
15 |
16 | from .utils import get_logger
17 | # from torch.utils.tensorboard import SummaryWriter
18 |
19 | def load_obj(obj, device):
20 | """
21 |     Recursively move tensors in obj to the given device
22 | """
23 |
24 | def cuda(obj):
25 | return obj.to(device) if isinstance(obj, th.Tensor) else obj
26 |
27 | if isinstance(obj, dict):
28 | return {key: load_obj(obj[key], device) for key in obj}
29 | elif isinstance(obj, list):
30 | return [load_obj(val, device) for val in obj]
31 | else:
32 | return cuda(obj)
33 |
34 |
35 | class SimpleTimer(object):
36 | """
37 | A simple timer
38 | """
39 |
40 | def __init__(self):
41 | self.reset()
42 |
43 | def reset(self):
44 | self.start = time.time()
45 |
46 | def elapsed(self):
47 | return (time.time() - self.start) / 60
48 |
49 |
50 | class ProgressReporter(object):
51 | """
52 | A simple progress reporter
53 | """
54 |
55 | def __init__(self, logger, period=100):
56 | self.period = period
57 | self.logger = logger
58 | self.loss = []
59 | self.timer = SimpleTimer()
60 | def add(self, loss):
61 | self.loss.append(loss)
62 | N = len(self.loss)
63 | if not N % self.period:
64 | avg = sum(self.loss[-self.period:]) / self.period
65 |             self.logger.info("Processed {:d} batches "
66 |                              "(loss = {:+.2f})...".format(N, avg))
67 | # self.loss_writer.add_scalar('Loss/train', avg, N)
68 | def report(self, details=False):
69 | N = len(self.loss)
70 | if details:
71 | sstr = ",".join(map(lambda f: "{:.2f}".format(f), self.loss))
72 | self.logger.info("Loss on {:d} batches: {}".format(N, sstr))
73 | return {
74 | "loss": sum(self.loss) / N,
75 | "batches": N,
76 | "cost": self.timer.elapsed()
77 | }
78 |
79 | class Trainer(object):
80 | def __init__(self,
81 | nnet,
82 | checkpoint="checkpoint",
83 | optimizer="adam",
84 | gpuid=0,
85 | optimizer_kwargs=None,
86 | clip_norm=1.0,
87 | min_lr=0,
88 | patience=0,
89 | factor=0.5,
90 | logging_period=100,
91 | resume=None,
92 | no_impr=150):
93 | if not th.cuda.is_available():
94 |             raise RuntimeError("CUDA device unavailable... exit")
95 | if not isinstance(gpuid, tuple):
96 | gpuid = (gpuid, )
97 | self.device = th.device("cuda:{}".format(gpuid[0]))
98 | self.gpuid = gpuid
99 | if checkpoint and not os.path.exists(checkpoint):
100 | os.makedirs(checkpoint)
101 | self.checkpoint = checkpoint
102 | self.logger = get_logger(
103 | os.path.join(checkpoint, "trainer.log"), file=True)
104 |
105 | self.clip_norm = clip_norm
106 | self.logging_period = logging_period
107 | self.cur_epoch = 0 # zero based
108 | self.no_impr = no_impr
109 |
110 | if resume:
111 | if not os.path.exists(resume):
112 | raise FileNotFoundError(
113 | "Could not find resume checkpoint: {}".format(resume))
114 | cpt = th.load(resume, map_location="cpu")
115 | self.cur_epoch = cpt["epoch"]
116 | self.logger.info("Resume from checkpoint {}: epoch {:d}".format(
117 | resume, self.cur_epoch))
118 | # load nnet
119 | nnet.load_state_dict(cpt["model_state_dict"])
120 | self.nnet = nnet.to(self.device)
121 | self.optimizer = self.create_optimizer(
122 | optimizer, optimizer_kwargs, state=cpt["optim_state_dict"])
123 | else:
124 | self.nnet = nnet.to(self.device)
125 | self.optimizer = self.create_optimizer(optimizer, optimizer_kwargs)
126 | # self.scheduler = ReduceLROnPlateau(
127 | # self.optimizer,
128 | # mode="min",
129 | # factor=factor,
130 | # patience=patience,
131 | # min_lr=min_lr,
132 | # verbose=True)
133 | self.scheduler1 = StepLR(self.optimizer, step_size=2, gamma=0.98)
134 | self.scheduler2 = StepLR(self.optimizer, step_size=1, gamma=0.9)
135 |
136 | self.num_params = sum(
137 | [param.nelement() for param in nnet.parameters()]) / 10.0**6
138 |
139 | # logging
140 | self.logger.info("Model summary:\n{}".format(nnet))
141 | self.logger.info("Loading model to GPUs:{}, #param: {:.2f}M".format(
142 | gpuid, self.num_params))
143 | if clip_norm > 0:
144 | self.logger.info(
145 | "Gradient clipping by {}, default L2".format(clip_norm))
146 |
147 | def save_checkpoint(self, best=True):
148 | cpt = {
149 | "epoch": self.cur_epoch,
150 | "model_state_dict": self.nnet.state_dict(),
151 | "optim_state_dict": self.optimizer.state_dict()
152 | }
153 | th.save(
154 | cpt,
155 | os.path.join(self.checkpoint,
156 | "{0}.pt.tar".format("best" if best else "last")))
157 |
158 | def save_every_checkpoint(self, idx):
159 | cpt = {
160 | "epoch": self.cur_epoch,
161 | "model_state_dict": self.nnet.state_dict(),
162 | "optim_state_dict": self.optimizer.state_dict()
163 | }
164 | th.save(cpt, os.path.join(self.checkpoint,
165 | "{0}.pt.tar".format(str(idx))))
166 |
167 | def create_optimizer(self, optimizer, kwargs, state=None):
168 | supported_optimizer = {
169 | "sgd": th.optim.SGD, # momentum, weight_decay, lr
170 | "rmsprop": th.optim.RMSprop, # momentum, weight_decay, lr
171 | "adam": th.optim.Adam, # weight_decay, lr
172 | "adadelta": th.optim.Adadelta, # weight_decay, lr
173 | "adagrad": th.optim.Adagrad, # lr, lr_decay, weight_decay
174 | "adamax": th.optim.Adamax # lr, weight_decay
175 | # ...
176 | }
177 | if optimizer not in supported_optimizer:
178 |             raise ValueError("Unsupported optimizer: {}".format(optimizer))
179 | opt = supported_optimizer[optimizer](self.nnet.parameters(), **kwargs)
180 | self.logger.info("Create optimizer {0}: {1}".format(optimizer, kwargs))
181 | if state is not None:
182 | opt.load_state_dict(state)
183 | self.logger.info("Load optimizer state dict from checkpoint")
184 | return opt
185 |
186 | def compute_loss(self, egs):
187 | raise NotImplementedError
188 |
189 | def train(self, data_loader):
190 | self.logger.info("Set train mode...")
191 | self.nnet.train()
192 | reporter = ProgressReporter(self.logger, period=self.logging_period)
193 |
194 | for egs in data_loader:
195 | # load to gpu
196 | egs = load_obj(egs, self.device)
197 |
198 | self.optimizer.zero_grad()
199 | loss = self.compute_loss(egs)
200 | loss.backward()
201 |
202 | if self.clip_norm > 0:
203 | clip_grad_norm_(self.nnet.parameters(), self.clip_norm)
204 | self.optimizer.step()
205 |
206 | reporter.add(loss.item())
207 | return reporter.report()
208 |
209 | def eval(self, data_loader):
210 | self.logger.info("Set eval mode...")
211 | self.nnet.eval()
212 | reporter = ProgressReporter(self.logger, period=self.logging_period)
213 |
214 | with th.no_grad():
215 | for egs in data_loader:
216 | egs = load_obj(egs, self.device)
217 | loss = self.compute_loss(egs)
218 | reporter.add(loss.item())
219 | return reporter.report(details=True)
220 |
221 | def run(self, train_loader, dev_loader, num_epochs=120):
222 | # avoid alloc memory from gpu0
223 | reporter = ProgressReporter(self.logger, period=self.logging_period)
224 | with th.cuda.device(self.gpuid[0]):
225 | stats = dict()
226 | # check if save is OK
227 | self.save_checkpoint(best=False)
228 | cv = self.eval(dev_loader)
229 | best_loss = cv["loss"]
230 | self.logger.info("START FROM EPOCH {:d}, LOSS = {:.4f}".format(
231 | self.cur_epoch, best_loss))
232 | no_impr = 0
233 | # make sure not inf
234 | # self.scheduler.best = best_loss
235 | while self.cur_epoch < num_epochs:
236 | self.cur_epoch += 1
237 | cur_lr = self.optimizer.param_groups[0]["lr"]
238 | stats[
239 | "title"] = "Loss(time/N, lr={:.3e}) - Epoch {:2d}:".format(
240 | cur_lr, self.cur_epoch)
241 | tr = self.train(train_loader)
242 | stats["tr"] = "train = {:+.4f}({:.2f}m/{:d})".format(
243 | tr["loss"], tr["cost"], tr["batches"])
244 | cv = self.eval(dev_loader)
245 | stats["cv"] = "dev = {:+.4f}({:.2f}m/{:d})".format(
246 | cv["loss"], cv["cost"], cv["batches"])
247 | stats["scheduler"] = ""
248 | if cv["loss"] > best_loss:
249 | no_impr += 1
250 | stats["scheduler"] = "| no impr, best = {:.4f}".format(
251 | cv["loss"])
252 | else:
253 | best_loss = cv["loss"]
254 | no_impr = 0
255 | self.save_checkpoint(best=True)
256 |                     if self.cur_epoch == 90 or self.cur_epoch >= 100:
257 | self.save_every_checkpoint(self.cur_epoch)
258 | self.logger.info(
259 | "{title} {tr} | {cv} {scheduler}".format(**stats))
260 | # schedule here
261 | # self.scheduler.step(cv["loss"])
262 | if self.cur_epoch <= 100:
263 | self.scheduler1.step()
264 | else:
265 | self.scheduler2.step()
266 | # flush scheduler info
267 | sys.stdout.flush()
268 | # save last checkpoint
269 | self.save_checkpoint(best=False)
270 |
271 | if no_impr == self.no_impr:
272 | self.logger.info(
273 | "Stop training cause no impr for {:d} epochs".format(
274 | no_impr))
275 | break
276 |
277 |
278 |         self.logger.info("Training for {:d}/{:d} epochs done!".format(
279 | self.cur_epoch, num_epochs))
280 | # reporter.loss_writer.close()
281 |
282 | class SiSnrTrainer(Trainer):
283 | def __init__(self, *args, **kwargs):
284 | super(SiSnrTrainer, self).__init__(*args, **kwargs)
285 |
286 | def sisnr(self, x, s, eps=1e-8):
287 | """
288 | Arguments:
289 | x: separated signal, N x S tensor
290 | s: reference signal, N x S tensor
291 | Return:
292 | sisnr: N tensor
293 | """
294 |
295 | def l2norm(mat, keepdim=False):
296 | return th.norm(mat, dim=-1, keepdim=keepdim)
297 |
298 | if x.shape != s.shape:
299 | raise RuntimeError(
300 |                 "Dimension mismatch when calculating SI-SNR, {} vs {}".format(
301 | x.shape, s.shape))
302 | x_zm = x - th.mean(x, dim=-1, keepdim=True)
303 | s_zm = s - th.mean(s, dim=-1, keepdim=True)
304 | t = th.sum(
305 | x_zm * s_zm, dim=-1,
306 | keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
307 | return 20 * th.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
308 |
309 | def mask_by_length(self, xs, lengths, fill=0):
310 | """
311 | Mask tensor according to length
312 | """
313 | assert xs.size(0) == len(lengths)
314 | ret = xs.data.new(*xs.size()).fill_(fill)
315 | for i, l in enumerate(lengths):
316 | ret[i, :l] = xs[i, :l]
317 | return ret
318 |
319 | def compute_loss(self, egs):
320 | N = egs["mix"].size(0)
321 |
322 | # spks x n x S
323 | nnet_load = th.nn.DataParallel(self.nnet, device_ids=self.gpuid)
324 | ests = nnet_load(egs["mix"], egs["aux"])
325 |
326 | refs = egs['ref']
327 | # N = egs["mix"].size(0)
328 | valid_len = egs["valid_len"]
329 | ests = self.mask_by_length(ests, valid_len)
330 | refs = self.mask_by_length(refs, valid_len)
331 | sisnr_loss = -th.sum(self.sisnr(ests, refs)) / N
332 |
333 | return sisnr_loss
334 |
--------------------------------------------------------------------------------
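
The `run` loop above chains two `StepLR` schedulers on a single optimizer: epochs 1-100 decay the learning rate by 0.98 every 2 epochs, and from epoch 101 onward by 0.9 every epoch. With the chainable schedulers in recent PyTorch (including the pinned torch==2.3.0), each `step()` multiplies the optimizer's *current* LR, so the second stage continues from wherever the first left off. A standalone sketch of the resulting trajectory; the initial LR of 1e-3 is an assumption for illustration:

```python
import torch
from torch.optim.lr_scheduler import StepLR

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-3)  # assumed initial LR
scheduler1 = StepLR(opt, step_size=2, gamma=0.98)
scheduler2 = StepLR(opt, step_size=1, gamma=0.9)

for epoch in range(1, 111):
    opt.step()  # stand-in for one training epoch
    if epoch <= 100:
        scheduler1.step()  # x0.98 every 2 epochs
    else:
        scheduler2.step()  # x0.9 every epoch, from the stage-one value
    if epoch in (2, 50, 100, 101, 110):
        print(epoch, opt.param_groups[0]["lr"])
```
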
/nnet/libs/utils.py:
--------------------------------------------------------------------------------
1 | # wujian@2018
2 |
3 | import os
4 | import json
5 | import logging
6 |
7 |
8 | def get_logger(
9 | name,
10 | format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
11 | date_format="%Y-%m-%d %H:%M:%S",
12 | file=False):
13 | """
14 | Get python logger instance
15 | """
16 | logger = logging.getLogger(name)
17 | logger.setLevel(logging.INFO)
18 | # file or console
19 | handler = logging.StreamHandler() if not file else logging.FileHandler(
20 | name)
21 | handler.setLevel(logging.INFO)
22 | formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
23 | handler.setFormatter(formatter)
24 | logger.addHandler(handler)
25 | return logger
26 |
27 |
28 | def dump_json(obj, fdir, name):
29 | """
30 | Dump python object in json
31 | """
32 | if fdir and not os.path.exists(fdir):
33 | os.makedirs(fdir)
34 | with open(os.path.join(fdir, name), "w") as f:
35 | json.dump(obj, f, indent=4, sort_keys=False)
36 |
37 |
38 | def load_json(fdir, name):
39 | """
40 | Load json as python object
41 | """
42 | path = os.path.join(fdir, name)
43 | if not os.path.exists(path):
44 | raise FileNotFoundError("Could not find json file: {}".format(path))
45 | with open(path, "r") as f:
46 | obj = json.load(f)
47 | return obj
--------------------------------------------------------------------------------
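
A short usage sketch for the helpers above (paths are placeholders). One caveat: `get_logger` adds a new handler on every call, so calling it twice with the same name would duplicate log lines; the scripts in this repo call it once per process.

```python
from libs.utils import get_logger, dump_json, load_json  # assumes running from nnet/

logger = get_logger("demo.log", file=True)  # FileHandler, logs to ./demo.log
logger.info("hello")

dump_json({"lr": 1e-3}, "exp/demo", "trainer.json")  # creates exp/demo if missing
conf = load_json("exp/demo", "trainer.json")
logger.info("loaded: {}".format(conf))
```
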
/nnet/memonger/__init__.py:
--------------------------------------------------------------------------------
1 | from .memonger import SublinearSequential
--------------------------------------------------------------------------------
/nnet/memonger/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/memonger/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/memonger/__pycache__/checkpoint.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/memonger/__pycache__/checkpoint.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/memonger/__pycache__/memonger.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/memonger/__pycache__/memonger.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/memonger/checkpoint.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 | import torch
3 | import warnings
4 |
5 |
6 | def detach_variable(inputs):
7 | if isinstance(inputs, tuple):
8 | out = []
9 | for inp in inputs:
10 | x = inp.detach()
11 | x.requires_grad = inp.requires_grad
12 | out.append(x)
13 | return tuple(out)
14 | else:
15 | raise RuntimeError(
16 |             "Only tuple of tensors is supported. Got unsupported input type: {}".format(type(inputs).__name__))
17 |
18 |
19 | def check_backward_validity(inputs):
20 | if not any(inp.requires_grad for inp in inputs):
21 | warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
22 |
23 |
24 | # Global switch to toggle whether or not checkpointed passes stash and restore
25 | # the RNG state. If True, any checkpoints making use of RNG should achieve deterministic
26 | # output compared to non-checkpointed passes.
27 | preserve_rng_state = True
28 |
29 |
30 | class CheckpointFunction(torch.autograd.Function):
31 |
32 | @staticmethod
33 | def forward(ctx, run_function, *args):
34 | check_backward_validity(args)
35 | ctx.run_function = run_function
36 | if preserve_rng_state:
37 | # We can't know if the user will transfer some args from the host
38 | # to the device during their run_fn. Therefore, we stash both
39 | # the cpu and cuda rng states unconditionally.
40 | #
41 | # TODO:
42 | # We also can't know if the run_fn will internally move some args to a device
43 | # other than the current device, which would require logic to preserve
44 | # rng states for those devices as well. We could paranoically stash and restore
45 | # ALL the rng states for all visible devices, but that seems very wasteful for
46 | # most cases.
47 | ctx.fwd_cpu_rng_state = torch.get_rng_state()
48 | # Don't eagerly initialize the cuda context by accident.
49 | # (If the user intends that the context is initialized later, within their
50 | # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
51 | # we have no way to anticipate this will happen before we run the function.)
52 | ctx.had_cuda_in_fwd = False
53 | if torch.cuda._initialized:
54 | ctx.had_cuda_in_fwd = True
55 | ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
56 | ctx.save_for_backward(*args)
57 | with torch.no_grad():
58 | outputs = run_function(*args)
59 | return outputs
60 |
61 | @staticmethod
62 | def backward(ctx, *args):
63 | if not torch.autograd._is_checkpoint_valid():
64 | raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
65 | inputs = ctx.saved_tensors
66 | # Stash the surrounding rng state, and mimic the state that was
67 |         # present at this time during forward. Restore the surrounding state
68 | # when we're done.
69 | rng_devices = [torch.cuda.current_device()] if ctx.had_cuda_in_fwd else []
70 | with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
71 | if preserve_rng_state:
72 | torch.set_rng_state(ctx.fwd_cpu_rng_state)
73 | if ctx.had_cuda_in_fwd:
74 | torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
75 | detached_inputs = detach_variable(inputs)
76 | with torch.enable_grad():
77 | outputs = ctx.run_function(*detached_inputs)
78 |
79 | if isinstance(outputs, torch.Tensor):
80 | outputs = (outputs,)
81 | torch.autograd.backward(outputs, args)
82 | return (None,) + tuple(inp.grad for inp in detached_inputs)
83 |
84 |
85 | def checkpoint(function, *args):
86 | r"""Checkpoint a model or part of the model
87 |
88 | Checkpointing works by trading compute for memory. Rather than storing all
89 | intermediate activations of the entire computation graph for computing
90 | backward, the checkpointed part does **not** save intermediate activations,
91 | and instead recomputes them in backward pass. It can be applied on any part
92 | of a model.
93 |
94 | Specifically, in the forward pass, :attr:`function` will run in
95 | :func:`torch.no_grad` manner, i.e., not storing the intermediate
96 | activations. Instead, the forward pass saves the inputs tuple and the
97 | :attr:`function` parameter. In the backwards pass, the saved inputs and
98 |     :attr:`function` are retrieved, and the forward pass is computed on
99 | :attr:`function` again, now tracking the intermediate activations, and then
100 | the gradients are calculated using these activation values.
101 |
102 | .. warning::
103 | Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
104 | with :func:`torch.autograd.backward`.
105 |
106 | .. warning::
107 | If :attr:`function` invocation during backward does anything different
108 | than the one during forward, e.g., due to some global variable, the
109 | checkpointed version won't be equivalent, and unfortunately it can't be
110 | detected.
111 |
112 |     .. warning::
113 | At least one of the inputs needs to have :code:`requires_grad=True` if
114 | grads are needed for model inputs, otherwise the checkpointed part of the
115 | model won't have gradients.
116 |
117 | Args:
118 | function: describes what to run in the forward pass of the model or
119 | part of the model. It should also know how to handle the inputs
120 | passed as the tuple. For example, in LSTM, if user passes
121 | ``(activation, hidden)``, :attr:`function` should correctly use the
122 | first input as ``activation`` and the second input as ``hidden``
123 | args: tuple containing inputs to the :attr:`function`
124 |
125 | Returns:
126 | Output of running :attr:`function` on :attr:`*args`
127 | """
128 | return CheckpointFunction.apply(function, *args)
129 |
130 |
131 | def checkpoint_sequential(functions, segments, *inputs):
132 | r"""A helper function for checkpointing sequential models.
133 |
134 | Sequential models execute a list of modules/functions in order
135 | (sequentially). Therefore, we can divide such a model in various segments
136 | and checkpoint each segment. All segments except the last will run in
137 | :func:`torch.no_grad` manner, i.e., not storing the intermediate
138 | activations. The inputs of each checkpointed segment will be saved for
139 | re-running the segment in the backward pass.
140 |
141 | See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
142 |
143 | .. warning::
144 | Checkpointing doesn't work with :func:`torch.autograd.grad`, but only
145 | with :func:`torch.autograd.backward`.
146 |
147 |     .. warning::
148 | At least one of the inputs needs to have :code:`requires_grad=True` if
149 | grads are needed for model inputs, otherwise the checkpointed part of the
150 | model won't have gradients.
151 |
152 | Args:
153 | functions: A :class:`torch.nn.Sequential` or the list of modules or
154 | functions (comprising the model) to run sequentially.
155 | segments: Number of chunks to create in the model
156 | inputs: tuple of Tensors that are inputs to :attr:`functions`
157 |
158 | Returns:
159 | Output of running :attr:`functions` sequentially on :attr:`*inputs`
160 |
161 | Example:
162 | >>> model = nn.Sequential(...)
163 | >>> input_var = checkpoint_sequential(model, chunks, input_var)
164 | """
165 |
166 | def run_function(start, end, functions):
167 | def forward(*inputs):
168 | for j in range(start, end + 1):
169 | if isinstance(inputs, tuple):
170 | inputs = functions[j](*inputs)
171 | else:
172 | inputs = functions[j](inputs)
173 | return inputs
174 | return forward
175 |
176 | if isinstance(functions, torch.nn.Sequential):
177 | functions = list(functions.children())
178 |
179 | segment_size = len(functions) // segments
180 | # the last chunk has to be non-volatile
181 | end = -1
182 | for start in range(0, segment_size * (segments - 1), segment_size):
183 | end = start + segment_size - 1
184 | inputs = checkpoint(run_function(start, end, functions), *inputs)
185 | if not isinstance(inputs, tuple):
186 | inputs = (inputs,)
187 | return run_function(end + 1, len(functions) - 1, functions)(*inputs)
188 |
--------------------------------------------------------------------------------
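
A minimal sketch of the trade described in the docstrings above: only segment-boundary inputs are stored during forward, and each checkpointed segment is re-run during backward. The import path assumes running from the repo root; as the docstring warns, the input needs `requires_grad=True` for gradients to reach it.

```python
import torch
import torch.nn as nn
from nnet.memonger.checkpoint import checkpoint_sequential  # assumed import path

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(),
                      nn.Linear(64, 64), nn.ReLU())

x = torch.randn(8, 64, requires_grad=True)
out = checkpoint_sequential(model, 2, x)  # 2 segments; activations dropped in forward
out.sum().backward()                      # segments re-run here to rebuild activations

assert torch.allclose(out, model(x))      # same values as the plain forward
assert x.grad is not None
```
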
/nnet/memonger/memonger.py:
--------------------------------------------------------------------------------
1 | from math import sqrt, log
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.modules.batchnorm import _BatchNorm
7 |
8 | from .checkpoint import checkpoint
9 |
10 |
11 | def reforward_momentum_fix(origin_momentum):
12 | return (1 - sqrt(1 - origin_momentum))
13 |
14 |
15 | class SublinearSequential(nn.Sequential):
16 | def __init__(self, *args):
17 | super(SublinearSequential, self).__init__(*args)
18 | self.reforward = False
19 | self.momentum_dict = {}
20 | self.set_reforward(True)
21 |
22 | def set_reforward(self, enabled=True):
23 | if not self.reforward and enabled:
24 |             print("Rescale BN momentum for re-forwarding purposes")
25 | for n, m in self.named_modules():
26 | if isinstance(m, _BatchNorm):
27 | self.momentum_dict[n] = m.momentum
28 |                 m.momentum = reforward_momentum_fix(self.momentum_dict[n])
29 | if self.reforward and not enabled:
30 |             print("Restore BN momentum")
31 | for n, m in self.named_modules():
32 | if isinstance(m, _BatchNorm):
33 | m.momentum = self.momentum_dict[n]
34 | self.reforward = enabled
35 |
36 | def forward(self, input):
37 | if self.reforward:
38 | return self.sublinear_forward(input)
39 | else:
40 | return self.normal_forward(input)
41 |
42 | def normal_forward(self, input):
43 | for module in self._modules.values():
44 | input = module(input)
45 | return input
46 |
47 | def sublinear_forward(self, input):
48 | def run_function(start, end, functions):
49 | def forward(*inputs):
50 | input = inputs[0]
51 | for j in range(start, end + 1):
52 | input = functions[j](input)
53 | return input
54 |
55 | return forward
56 |
57 | functions = list(self.children())
58 | segments = int(sqrt(len(functions)))
59 | segment_size = len(functions) // segments
60 | # the last chunk has to be non-volatile
61 | end = -1
62 |         # wrap a bare tensor so it can be splatted into checkpoint()
63 |         inputs = input if isinstance(input, tuple) else (input,)
64 | for start in range(0, segment_size * (segments - 1), segment_size):
65 | end = start + segment_size - 1
66 | inputs = checkpoint(run_function(start, end, functions), *inputs)
67 | if not isinstance(inputs, tuple):
68 | inputs = (inputs,)
69 | # output = run_function(end + 1, len(functions) - 1, functions)(*inputs)
70 | output = checkpoint(run_function(end + 1, len(functions) - 1, functions), *inputs)
71 | return output
72 |
--------------------------------------------------------------------------------
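
Why `reforward_momentum_fix` exists: with re-forwarding, every checkpointed BatchNorm updates its running statistics twice per effective forward (once in the no-grad pass, once in the recompute, on the same batch). Choosing m' such that (1 - m')^2 = 1 - m makes two EMA updates with m' equal one update with the original momentum m, hence m' = 1 - sqrt(1 - m). A quick numeric check of that identity:

```python
from math import sqrt

def ema(old, x, m):
    # one BatchNorm running-stat update: new = (1 - m) * old + m * x
    return (1 - m) * old + m * x

m = 0.1               # original BN momentum
m2 = 1 - sqrt(1 - m)  # rescaled momentum, as in reforward_momentum_fix

old, x = 5.0, 1.0
print(ema(old, x, m))               # one update with m   -> 4.6
print(ema(ema(old, x, m2), x, m2))  # two updates with m' -> 4.6
```
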
/nnet/memonger/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preact_resnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from .memonger import SublinearSequential
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 |
18 | def __init__(self, in_planes, planes, stride=1):
19 | super(BasicBlock, self).__init__()
20 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
21 | self.bn1 = nn.BatchNorm2d(planes)
22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
23 | self.bn2 = nn.BatchNorm2d(planes)
24 |
25 | self.shortcut = nn.Sequential()
26 | if stride != 1 or in_planes != self.expansion*planes:
27 | self.shortcut = nn.Sequential(
28 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
29 | nn.BatchNorm2d(self.expansion*planes)
30 | )
31 |
32 | def forward(self, x):
33 | out = F.relu(self.bn1(self.conv1(x)))
34 | out = self.bn2(self.conv2(out))
35 | out += self.shortcut(x)
36 | out = F.relu(out)
37 | return out
38 |
39 |
40 | class Bottleneck(nn.Module):
41 | expansion = 4
42 |
43 | def __init__(self, in_planes, planes, stride=1):
44 | super(Bottleneck, self).__init__()
45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
46 | self.bn1 = nn.BatchNorm2d(planes)
47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
48 | self.bn2 = nn.BatchNorm2d(planes)
49 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
50 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
51 |
52 | self.shortcut = nn.Sequential()
53 | if stride != 1 or in_planes != self.expansion*planes:
54 | self.shortcut = nn.Sequential(
55 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
56 | nn.BatchNorm2d(self.expansion*planes)
57 | )
58 |
59 | def forward(self, x):
60 | out = F.relu(self.bn1(self.conv1(x)))
61 | out = F.relu(self.bn2(self.conv2(out)))
62 | out = self.bn3(self.conv3(out))
63 | out += self.shortcut(x)
64 | out = F.relu(out)
65 | return out
66 |
67 |
68 | class ResNet(nn.Module):
69 | def __init__(self, block, num_blocks, num_classes=100):
70 | super(ResNet, self).__init__()
71 | self.in_planes = 64
72 |
73 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
74 | self.bn1 = nn.BatchNorm2d(64)
75 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
76 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
77 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
78 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
79 | self.linear = nn.Linear(512*block.expansion, num_classes)
80 |
81 | def _make_layer(self, block, planes, num_blocks, stride):
82 | strides = [stride] + [1]*(num_blocks-1)
83 | layers = []
84 | for stride in strides:
85 | layers.append(block(self.in_planes, planes, stride))
86 | self.in_planes = planes * block.expansion
87 | return SublinearSequential(*layers)
88 |
89 | def forward(self, x):
90 | out = F.relu(self.bn1(self.conv1(x)))
91 | out = self.layer1(out)
92 | out = self.layer2(out)
93 | out = self.layer3(out)
94 | out = self.layer4(out)
95 | out = F.avg_pool2d(out, 4)
96 | out = out.view(out.size(0), -1)
97 | out = self.linear(out)
98 | return out
99 |
100 |
101 | def ResNet18():
102 | return ResNet(BasicBlock, [2,2,2,2])
103 |
104 | def ResNet34():
105 | return ResNet(BasicBlock, [3,4,6,3])
106 |
107 | def ResNet50():
108 | return ResNet(Bottleneck, [3,4,6,3])
109 |
110 | def ResNet101():
111 | return ResNet(Bottleneck, [3,4,23,3])
112 |
113 | def ResNet152():
114 | return ResNet(Bottleneck, [3,8,36,3])
115 |
116 |
117 | def test():
118 | net = ResNet18()
119 | y = net(torch.randn(1,3,32,32))
120 | print(y.size())
121 |
122 | # test()
123 |
--------------------------------------------------------------------------------
/nnet/separate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import time
5 | import argparse
6 | import torch as th
7 | import numpy as np
8 | from SEF_PNet_pse import SEF_PNet
9 | from libs.utils import load_json, get_logger
10 | from libs.audio import write_wav
11 | from libs.dataset_tse import Dataset
12 |
13 | def run(args):
14 | start = time.time()
15 | logger = get_logger(
16 | os.path.join(args.checkpoint, 'separate.log'), file=True)
17 | dataset = Dataset(mix_scp=args.mix_scp, ref_scp=args.ref_scp, aux_scp=args.aux_scp, sample_rate=args.fs)
18 |
19 | # Load model
20 | nnet_conf = load_json(args.checkpoint, "mdl.json")
21 | nnet = SEF_PNet(**nnet_conf)
22 | cpt_fname = os.path.join(args.checkpoint, "best.pt.tar")
23 | cpt = th.load(cpt_fname, map_location="cpu")
24 | nnet.load_state_dict(cpt["model_state_dict"])
25 | logger.info("Load checkpoint from {}, epoch {:d}".format(
26 | cpt_fname, cpt["epoch"]))
27 |
28 | device = th.device(
29 | "cuda:{}".format(args.gpuid)) if args.gpuid >= 0 else th.device("cpu")
30 | nnet = nnet.to(device) if args.gpuid >= 0 else nnet
31 | nnet.eval()
32 |
33 | with th.no_grad():
34 | total_cnt = 0
35 | for i, data in enumerate(dataset):
36 | mix = th.tensor(data['mix'], dtype=th.float32, device=device)
37 | aux = th.tensor(data['aux'], dtype=th.float32, device=device)
38 | key = data['key']
39 |             # add batch dim: (S,) -> (1, S); tensors already sit on device
40 |             mix = mix.unsqueeze(0)
41 |             aux = aux.unsqueeze(0)
42 | 
43 |             # Forward
44 |             ests = nnet(mix, aux)
45 |             ests = ests.cpu().numpy().squeeze()
46 |             norm = np.linalg.norm(mix.cpu().numpy().flatten(), np.inf)
47 |             ests = ests[:mix.shape[-1]]
48 |             # log each utterance
49 |             logger.info("Separate Utt{:d}".format(total_cnt + 1))
50 |             # rescale to the mixture's peak amplitude
51 |             ests = ests * norm / np.max(np.abs(ests))
52 | 
53 |             fname = key + '.wav'
54 |             write_wav(os.path.join(args.dump_dir, fname),
55 |                       ests, fs=args.fs)
56 |             total_cnt += 1
57 |
58 | end = time.time()
59 | logger.info('Utt={:d} | Time Elapsed: {:.1f}s'.format(total_cnt, end-start))
60 |
61 | if __name__ == "__main__":
62 |     parser = argparse.ArgumentParser(description='Separating speech...')
63 | parser.add_argument("--checkpoint", type=str, required=True,
64 | help="Directory of checkpoint")
65 | parser.add_argument("--gpuid", type=int, default=-1,
66 | help="GPU device to offload model to, -1 means running on CPU")
67 | parser.add_argument('--mix_scp', type=str, required=True,
68 | help='mix scp')
69 | parser.add_argument('--ref_scp', type=str, required=True,
70 | help='ref scp')
71 | parser.add_argument('--aux_scp', type=str, required=True,
72 | help='aux scp')
73 | parser.add_argument('--fs', type=int, default=8000,
74 | help="Sample rate for mixture input")
75 | parser.add_argument('--dump-dir', type=str, default="/node/hzl/expriment/SEF_PNet_icassp2025_github/results",
76 | help="Directory to dump separated results out")
77 | args = parser.parse_args()
78 | run(args)
79 |
--------------------------------------------------------------------------------
/nnet/train_unet_tse_steplr_clip.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import pprint
4 | import argparse
5 | from libs.trainer_unet_tse_steplr_clip import SiSnrTrainer
6 | from libs.dataset_tse import make_dataloader
7 | from libs.utils import dump_json, get_logger
8 | from SEF_PNet_pse import SEF_PNet
9 | from conf_unet_tse_32ms import trainer_conf, nnet_conf, train_data, dev_data, chunk_size
10 |
11 | logger = get_logger(__name__)
12 |
13 | def run(args):
14 | gpuids = tuple(map(int, args.gpus.split(",")))
15 | nnet = SEF_PNet(**nnet_conf)
16 | trainer = SiSnrTrainer(nnet,
17 | gpuid=gpuids,
18 | checkpoint=args.checkpoint,
19 | resume=args.resume,
20 | **trainer_conf)
21 |
22 | data_conf = {
23 | "train": train_data,
24 | "dev": dev_data,
25 | "chunk_size": chunk_size
26 | }
27 |
28 | for conf, fname in zip([nnet_conf, trainer_conf, data_conf],
29 | ["mdl.json", "trainer.json", "data.json"]):
30 | dump_json(conf, args.checkpoint, fname)
31 |
32 | train_loader = make_dataloader(train=True,
33 | data_kwargs=train_data,
34 | batch_size=args.batch_size,
35 | chunk_size=chunk_size,
36 | num_workers=args.num_workers)
37 | dev_loader = make_dataloader(train=False,
38 | data_kwargs=dev_data,
39 | batch_size=args.batch_size,
40 | chunk_size=chunk_size,
41 | num_workers=args.num_workers)
42 | trainer.run(train_loader, dev_loader, num_epochs=args.epochs)
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser(
46 | description=
47 |         "Command to start SEF-PNet training, configured from conf_unet_tse_32ms.py",
48 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
49 | parser.add_argument("--gpus",
50 | type=str,
51 | default="0",
52 | help="Training on which GPUs "
53 |                         "(one or more, e.g. 0 or \"0,1\")")
54 | parser.add_argument("--epochs",
55 | type=int,
56 | default=200,
57 | # default=500,
58 | help="Number of training epochs")
59 | parser.add_argument("--checkpoint",
60 | type=str,
61 | default='/node/hzl/expriment/SEF_PNet_icassp2025_github/demo',
62 | #required=True,
63 | help="Directory to dump models")
64 | parser.add_argument("--resume",
65 | type=str,
66 | default=None,
67 | help="Exist model to resume training from")
68 | parser.add_argument("--batch-size",
69 | type=int,
70 | default=32,
71 | help="Number of utterances in each batch")
72 | parser.add_argument("--num-workers",
73 | type=int,
74 | default=32,
75 | help="Number of workers used in data loader")
76 | args = parser.parse_args()
77 | logger.info("Arguments in command:\n{}".format(pprint.pformat(vars(args))))
78 |
79 | run(args)
80 | print("Training done!")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.3.0
2 | numpy==1.22.4
3 | mir_eval==0.7
4 | pesq==0.0.4
5 | pypesq @ https://github.com/vBaiCai/python-pesq/archive/master.zip#sha256=fba27c3d95e8f72fed7c55f675ce6057a64b26a1a67a2e469df2804cca69b8cc
6 | pystoi==0.3.3
7 | soundfile==0.12.1
8 | librosa==0.10.1
--------------------------------------------------------------------------------
/separate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -eu
3 | checkpoint=/node/hzl/expriment/libri2mix_min_wav8k/SEF_PNet
4 | gpuid=0
5 | data_root=/node/hzl/data/data_libri2mix_s1_min_wav8k/test
6 |
7 | mix_scp=$data_root/mix_clean.scp
8 | ref_scp=$data_root/s1.scp
9 | aux_scp=$data_root/auxs1.scp
10 |
11 | fs=8000
12 | dump_dir=/node/hzl/data/enhanced_speech
13 |
14 | ./nnet/separate.py \
15 | --checkpoint $checkpoint \
16 | --gpuid $gpuid \
17 | --mix_scp $mix_scp \
18 | --ref_scp $ref_scp \
19 | --aux_scp $aux_scp \
20 | --fs $fs \
21 | --dump-dir $dump_dir \
22 | > separate.log 2>&1
23 |
24 | echo "Separate done!"
25 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -eu
3 | epochs=200
4 | # constrained by GPU count & memory
5 | batch_size=32
6 | gpuid=0
7 | num_workers=32
8 | cpt_dir=/node/hzl/expriment/SEF_PNet_icassp2025_github/demo
9 | #resume=
10 | #[ $# -ne 1 ] && echo "Script error: $0 " && exit 1
11 | ./nnet/train_unet_tse_steplr_clip.py \
12 |     --gpus $gpuid \
13 | --epochs $epochs \
14 | --batch-size $batch_size \
15 | --num-workers $num_workers \
16 | --checkpoint $cpt_dir \
17 | > train.log 2>&1
18 |
--------------------------------------------------------------------------------