├── README.md
├── data
    ├── dev
    │   ├── auxs1.scp
    │   ├── mix_clean.scp
    │   └── ref.scp
    ├── test
    │   ├── auxs1.scp
    │   ├── mix_clean.scp
    │   └── ref.scp
    └── train
    │   ├── auxs1.scp
    │   ├── mix_clean.scp
    │   └── ref.scp
├── eval.sh
├── nnet
    ├── SEF_PNet_pse.py
    ├── __pycache__
    │   ├── SEF_PNet_pse.cpython-39.pyc
    │   └── conf_unet_tse_32ms.cpython-39.pyc
    ├── conf_unet_tse_32ms.py
    ├── evaluate.py
    ├── libs
    │   ├── __init__.py
    │   ├── __init__.pyc
    │   ├── __pycache__
    │   │   ├── __init__.cpython-39.pyc
    │   │   ├── audio.cpython-39.pyc
    │   │   ├── conv_stft.cpython-39.pyc
    │   │   ├── dataset_tse.cpython-39.pyc
    │   │   ├── trainer_unet_tse_steplr_clip.cpython-39.pyc
    │   │   └── utils.cpython-39.pyc
    │   ├── audio.py
    │   ├── conv_stft.py
    │   ├── dataset_tse.py
    │   ├── metric.py
    │   ├── trainer_unet_tse_steplr_clip.py
    │   └── utils.py
    ├── memonger
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-39.pyc
    │   │   ├── checkpoint.cpython-39.pyc
    │   │   └── memonger.cpython-39.pyc
    │   ├── checkpoint.py
    │   ├── memonger.py
    │   └── resnet.py
    ├── separate.py
    └── train_unet_tse_steplr_clip.py
├── requirements.txt
├── separate.sh
└── train.sh

/README.md:
--------------------------------------------------------------------------------
# SEF-PNet

Official PyTorch implementation of the paper "[SEF-PNet: Speaker Encoder-Free Personalized Speech Enhancement with Local and Global Contexts Aggregation](https://arxiv.org/abs/2501.11274)" in ICASSP 2025.

## Dataset
[Libri2Mix](https://github.com/JorisCos/LibriMix) min wav8k dataset. The `data` folder contains three subfolders: `train`, `dev`, and `test`. Each subfolder includes three files:
- `mix_clean.scp`: Clean mixtures of 2 speakers.
- `ref.scp`: Target speaker’s speech.
- `auxs1.scp`: Enrollment speech from the target speaker, which is different from the target speaker’s speech in the mixture.

The `mix_clean.scp` corresponds to the **2-speaker** scenario in the results section.
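
The three `scp` files listed above have to point at audio on the local machine. A minimal sketch of how the paths could be repointed in bulk is given here; it assumes each line has the layout `key path` separated by whitespace, which this README does not confirm, and the helper name `repoint_scp` is hypothetical rather than part of the repository.

```python
from pathlib import Path


def repoint_scp(scp_path, old_root, new_root):
    """Rewrite the path column of a `key path` scp file in place.

    Assumption: one utterance per line, key and wav path separated by
    whitespace. Lines whose path does not start with `old_root` are kept
    unchanged. Returns the number of entries written.
    """
    out_lines = []
    for line in Path(scp_path).read_text().splitlines():
        if not line.strip():
            continue  # skip blank lines
        key, wav = line.split(maxsplit=1)
        if wav.startswith(old_root):
            # Swap only the leading directory prefix.
            wav = new_root + wav[len(old_root):]
        out_lines.append(f"{key} {wav}")
    Path(scp_path).write_text("\n".join(out_lines) + "\n")
    return len(out_lines)
```

A call such as `repoint_scp("data/train/mix_clean.scp", "/node/hzl/data", "/my/data")` would then be repeated for every `scp` file in `train`, `dev`, and `test`; the second and third arguments here are placeholders.
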
Note that in this dataset, only the first speaker in the mixed speech is considered the target speaker.
Make sure to update the file paths in the `scp` files to match your local data locations. Also, remember to update the data paths in `conf_unet_tse_32ms.py` accordingly.

## Training
- **`train.sh`**: Shell script that initiates training by setting parameters (e.g., epochs, batch size, GPU settings) and calling the Python script (`train_unet_tse_steplr_clip.py`). To train the model, run:
  ```bash
  ./train.sh
  ```

- **`train_unet_tse_steplr_clip.py`**: Main Python script for training. It initializes the model, sets up data loaders, and manages the training loop.

- **`conf_unet_tse_32ms.py`**: Configuration file containing model architecture, data paths, and training hyperparameters.

- **`SEF_PNet_pse.py`**: Defines the `SEF_PNet` model, which is used in the training script.

## Evaluation

To evaluate the model, use the provided `eval.sh` script. It sets the necessary parameters (e.g., model checkpoint, GPU ID, data paths) and calls `evaluate.py` for performance evaluation.

- **`eval.sh`**: Runs the evaluation by setting paths and calling `evaluate.py`.
  - Usage:
    ```bash
    ./eval.sh
    ```

- **`evaluate.py`**: Evaluates the model on the test set, computing metrics like SDR, SI-SNR, PESQ, and STOI.

## Results

Condition-wise results on three Libri2Mix PSE tasks:

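| Condition | Method | SI-SDR (dB) | PESQ | STOI (%) |
|---|---|---|---|---|
| 1-speaker+noise | Mixture | 3.27 | 1.75 | 79.51 |
| | sDPCCN | 14.49 | 3.04 | 92.47 |
| | SEF-PNet | 14.50 | 3.05 | 92.47 |
| 2-speaker | Mixture | -0.03 | 1.60 | 71.38 |
| | sDPCCN | 11.62 | 2.76 | 87.19 |
| | SEF-PNet | 13.00 | 3.05 | 89.71 |
| 2-speaker+noise | Mixture | -2.03 | 1.43 | 64.65 |
| | sDPCCN | 6.93 | 2.12 | 79.32 |
| | SEF-PNet | 7.54 | 2.14 | 80.58 |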
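
For readers who want to sanity-check the first metric column, the sketch below follows the formula used by `cal_SISNR` in `nnet/evaluate.py` (zero-mean both signals, project the estimate onto the reference, take the ratio of target to residual energy in dB). It is written in plain Python to stay dependency-free; the function name `si_snr_db` and the toy sinusoid signals are assumptions made for illustration, and treating this SI-SNR as the table's SI-SDR is also an assumption.

```python
import math


def si_snr_db(est, ref, eps=1e-8):
    """Scale-invariant SNR in dB between two equal-length sample sequences."""
    assert len(est) == len(ref)
    n = len(ref)
    est_mean = sum(est) / n
    ref_mean = sum(ref) / n
    est_zm = [e - est_mean for e in est]
    ref_zm = [r - ref_mean for r in ref]
    # Project the zero-mean estimate onto the zero-mean reference.
    scale = sum(e * r for e, r in zip(est_zm, ref_zm)) / (
        sum(r * r for r in ref_zm) + eps)
    target = [scale * r for r in ref_zm]
    noise = [e - t for e, t in zip(est_zm, target)]
    target_norm = math.sqrt(sum(t * t for t in target))
    noise_norm = math.sqrt(sum(x * x for x in noise))
    return 20 * math.log10(eps + target_norm / (noise_norm + eps))


# Toy signals (hypothetical): a "target" tone plus an interfering tone.
ref = [math.sin(0.1 * i) for i in range(800)]
interferer = [math.sin(0.37 * i + 1.0) for i in range(800)]
mix = [r + v for r, v in zip(ref, interferer)]
est = [r + 0.1 * v for r, v in zip(ref, interferer)]

# Improvement over the unprocessed mixture, as logged per utterance
# by evaluate.py under the name SI-SNRi.
gain = si_snr_db(est, ref) - si_snr_db(mix, ref)
```

Because the estimate is projected onto the reference before the ratio is taken, rescaling `est` by a constant leaves the score essentially unchanged, which is the point of the scale-invariant variant.
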
115 | 116 | ### GPU Setup 117 | This code is designed to run on a single GPU. By default, in the `train.sh` script, the `gpuid` is set to `0`. 118 | 119 | To use multiple GPUs, modify `gpuid=0,1,2,...` in `train.sh`. 120 | 121 | Additionally, for multi-GPU setups, comment out the line: 122 | ```python 123 | from memonger import SublinearSequential 124 | ``` 125 | and replace SublinearSequential with nn.Sequential in SEF_PNet_pse.py to avoid memory issues. 126 | 127 | ### Create SCP 128 | The SCP file I provided is from [DPCCN](https://github.com/jyhan03/icassp22-dataset/tree/main/lst/libri2mix). It only uses the first speaker as the target. To match MC-Spex results for the 2-speaker condition in Libri2Mix, you'll need to use double the data, with two speakers taking turns as the target. This means you’ll need to recreate the SCP files for training, validation, and testing. You can use the script in the link for reference. 129 | 130 | Any problems, contact me at hzlkycg111@163.com, and a reply will be given promptly. 131 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eu 3 | 4 | checkpoint=/node/hzl/expriment/libri2mix_min_wav8k/SEF_PNet 5 | gpuid=0 6 | 7 | data_root=/node/hzl/data/data_libri2mix_s1_min_wav8k/test 8 | 9 | mix_scp=$data_root/mix_clean.scp 10 | spk1_scp=$data_root/s1.scp 11 | aux_scp=$data_root/auxs1.scp 12 | 13 | cal_sdr=1 14 | 15 | ./nnet/evaluate.py \ 16 | --checkpoint $checkpoint \ 17 | --gpuid $gpuid \ 18 | --mix_scp $mix_scp \ 19 | --ref_scp $spk1_scp \ 20 | --aux_scp $aux_scp \ 21 | --cal_sdr $cal_sdr \ 22 | > eval.log 2>&1 23 | 24 | echo "eval done!" 

--------------------------------------------------------------------------------
/nnet/SEF_PNet_pse.py:
--------------------------------------------------------------------------------
"""
Created on Sun June 2 2024
@author: Ziling Huang
"""
import torch as th
import torch.nn as nn
import torch.nn.functional as nn_f
from typing import Tuple, List
from memonger import SublinearSequential
from libs.conv_stft import ConvSTFT, ConviSTFT

def param(nnet, Mb=True):
    """
    Return the number of parameters (not bytes) in nnet
    """
    neles = sum([param.nelement() for param in nnet.parameters()])
    return neles / 10**6 if Mb else neles

class Conv1D(nn.Conv1d):
    """
    1D conv in ConvTasNet
    """

    def __init__(self, *args, **kwargs):
        super(Conv1D, self).__init__(*args, **kwargs)

    def forward(self, x, squeeze=False):
        """
        x: N x L or N x C x L
        """
        if x.dim() not in [2, 3]:
            # instances have no __name__; use the class name instead
            raise RuntimeError("{} accept 2/3D tensor as input".format(
                self.__class__.__name__))
        x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
        if squeeze:
            x = th.squeeze(x)
        return x

class ChannelWiseLayerNorm(nn.LayerNorm):
    """
    Channel wise layer normalization
    """

    def __init__(self, *args, **kwargs):
        super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, x):
        """
        x: N x C x T
        """
        if x.dim() != 3:
            raise RuntimeError("{} accept 3D tensor as input".format(
                self.__class__.__name__))
        # N x C x T => N x T x C
        x = th.transpose(x, 1, 2)
        # LN
        x = super().forward(x)
        # N x T x C => N x C x T
        x = th.transpose(x, 1, 2)
        return x

class GlobalChannelLayerNorm(nn.Module):
    """
    Global channel layer normalization
    """

    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        super(GlobalChannelLayerNorm, self).__init__()
        self.eps = eps
        self.normalized_dim = dim
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.beta = nn.Parameter(th.zeros(dim, 1))
            self.gamma = nn.Parameter(th.ones(dim, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """
        x: N x C x T
        """
        if x.dim() != 3:
            raise RuntimeError("{} accept 3D tensor as input".format(
                self.__class__.__name__))
        # N x 1 x 1
        mean = th.mean(x, (1, 2), keepdim=True)
        var = th.mean((x - mean)**2, (1, 2), keepdim=True)
        # N x C x T
        if self.elementwise_affine:
            x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta
        else:
            x = (x - mean) / th.sqrt(var + self.eps)
        return x

    def extra_repr(self):
        return "{normalized_dim}, eps={eps}, " \
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)

def build_norm(norm, dim):
    """
    Build normalize layer
    LN cost more memory than BN
    """
    if norm not in ["cLN", "gLN", "BN"]:
        raise RuntimeError("Unsupported normalize layer: {}".format(norm))
    if norm == "cLN":
        return ChannelWiseLayerNorm(dim, elementwise_affine=True)
    elif norm == "BN":
        return nn.BatchNorm1d(dim)
    else:
        return GlobalChannelLayerNorm(dim, elementwise_affine=True)

class Conv2dBlock(nn.Module):
    def __init__(self,
                 in_dims: int = 16,
                 out_dims: int = 32,
                 kernel_size: Tuple[int] = (3, 3),
                 stride: Tuple[int] = (1, 1),
                 padding: Tuple[int] = (1, 1)) -> None:
        super(Conv2dBlock, self).__init__()
        self.conv2d = nn.Conv2d(in_dims, out_dims, kernel_size, stride, padding)
        self.elu = nn.ELU()
        self.norm = nn.InstanceNorm2d(out_dims)

    def forward(self, x: th.Tensor) -> th.Tensor:
        x = self.conv2d(x)
        x = self.elu(x)
        return self.norm(x)

class ConvTrans2dBlock(nn.Module):
    def __init__(self,
                 in_dims: int = 32,
                 out_dims: int = 16,
                 kernel_size: Tuple[int] = (3, 3),
                 stride: Tuple[int] = (1, 2),
                 padding: Tuple[int] = (1, 0),
                 output_padding: Tuple[int] = (0, 0)) -> None:
        super(ConvTrans2dBlock, self).__init__()
        self.convtrans2d = nn.ConvTranspose2d(in_dims, out_dims, kernel_size, stride, padding, output_padding)
        self.elu = nn.ELU()
        self.norm = nn.InstanceNorm2d(out_dims)

    def forward(self, x: th.Tensor) -> th.Tensor:
        x = self.convtrans2d(x)
        x = self.elu(x)
        return self.norm(x)

class DenseBlock(nn.Module):
    def __init__(self, in_dims, out_dims, mode = "enc", **kargs):
        super(DenseBlock, self).__init__()
        if mode not in ["enc", "dec"]:
            raise RuntimeError("The mode option must be 'enc' or 'dec'!")

        n = 1 if mode == "enc" else 2
        self.conv1 = Conv2dBlock(in_dims=in_dims*n, out_dims=in_dims, **kargs)
        self.conv2 = Conv2dBlock(in_dims=in_dims*(n+1), out_dims=in_dims, **kargs)
        self.conv3 = Conv2dBlock(in_dims=in_dims*(n+2), out_dims=in_dims, **kargs)
        self.conv4 = Conv2dBlock(in_dims=in_dims*(n+3), out_dims=in_dims, **kargs)
        self.conv5 = Conv2dBlock(in_dims=in_dims*(n+4), out_dims=out_dims, **kargs)

    def forward(self, x: th.Tensor) -> th.Tensor:
        y1 = self.conv1(x)
        y2 = self.conv2(th.cat([x, y1], 1))
        y3 = self.conv3(th.cat([x, y1, y2], 1))
        y4 = self.conv4(th.cat([x, y1, y2, y3], 1))
        y5 = self.conv5(th.cat([x, y1, y2, y3, y4], 1))
        return y5

class TCNBlock(nn.Module):
    """
    TCN block:
        IN - ELU - Conv1D - IN - ELU - Conv1D
    """

    def __init__(self,
                 in_dims: int = 384,
                 out_dims: int = 384,
                 kernel_size: int = 3,
                 stride: int = 1,
                 paddings: int = 1,
                 dilation: int = 1,
                 causal: bool = False) -> None:
        super(TCNBlock, self).__init__()
        self.norm1 = nn.InstanceNorm1d(in_dims)
        self.elu1 = nn.ELU()
        dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        # dilated conv
        self.dconv1 = nn.Conv1d(
            in_dims,
            out_dims,
            kernel_size,
            padding=dconv_pad,
            dilation=dilation,
            groups=in_dims,
            bias=True)

        self.norm2 = nn.InstanceNorm1d(in_dims)
        self.elu2 = nn.ELU()
        self.dconv2 = nn.Conv1d(in_dims, out_dims, 1, bias=True)

        # different padding way
        self.causal = causal
        self.dconv_pad = dconv_pad

    def forward(self, x: th.Tensor) -> th.Tensor:
        y = self.elu1(self.norm1(x))
        y = self.dconv1(y)
        if self.causal:
            y = y[:, :, :-self.dconv_pad]
        y = self.elu2(self.norm2(y))
        y = self.dconv2(y)
        x = x + y

        return x


class LCA(nn.Module):
    def __init__(self, channels=64, r=4):
        super(LCA, self).__init__()
        inter_channels = int(channels // r)

        self.local_att = nn.Sequential(
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.global_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        xl = self.local_att(x)
        xg = self.global_att(x)
        xlg = xl + xg
        wei = self.sigmoid(xlg)
        return x * wei

class IFI(nn.Module):

    def __init__(self, channels=64, r=4):
        super(IFI, self).__init__()
        inter_channels = int(channels // r)

        self.local_att = nn.Sequential(
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.global_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.local_att2 = nn.Sequential(
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.global_att2 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        xa = x + residual
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        xlg = xl + xg
        wei = self.sigmoid(xlg)
        xi = x * wei + residual * (1 - wei)

        xl2 = self.local_att2(xi)
        # NOTE: the second stage calls self.global_att again;
        # self.global_att2 is defined above but not used here.
        xg2 = self.global_att(xi)
        xlg2 = xl2 + xg2
        wei2 = self.sigmoid(xlg2)
        xo = x * wei2 + residual * (1 - wei2)
        return xo

class SEF_PNet(nn.Module):
    def __init__(self,
                 win_len: int = 256,    # 32 ms
                 win_inc: int = 64,     # 8 ms
                 fft_len: int = 256,
                 win_type: str = "sqrthann",
                 kernel_size: Tuple[int] = (3, 3),
                 stride1: Tuple[int] = (1, 1),
                 stride2: Tuple[int] = (1, 2),
                 paddings: Tuple[int] = (1, 0),
                 output_padding: Tuple[int] = (0, 0),
                 tcn_dims: int = 384,
                 tcn_blocks: int = 10,
                 tcn_layers: int = 2,
                 causal: bool = False,
                 pool_size: Tuple[int] = (4, 8, 16, 32),
                 num_spks: int = 1,
                 L: int = 20) -> None:
        super(SEF_PNet, self).__init__()

        self.L = L
        self.fft_len = fft_len
        self.num_spks = num_spks
        self.stft = ConvSTFT(win_len, win_inc, fft_len, win_type, 'complex')
        self.softmax = nn.Softmax(dim=-2)
        self.ifi = IFI(channels=2, r=1/32)
        self.upconv1 = nn.Conv2d(4, 64, 1, 1, 0)
        self.lca = LCA(64)
        self.conv2d = nn.Conv2d(64, 16, (1, 3), 1, 0)
        self.relu = nn.ReLU()
        self.encoder = self._build_encoder(
            kernel_size=kernel_size,
            stride=stride2,
            padding=paddings
        )
        self.tcn_layers = self._build_tcn_layers(
            tcn_layers,
            tcn_blocks,
            in_dims=tcn_dims,
            out_dims=tcn_dims,
            causal=causal
        )
        self.decoder = self._build_decoder(
            kernel_size=kernel_size,
            stride=stride2,
            padding=paddings,
            output_padding=output_padding
        )
        self.avg_pool = self._build_avg_pool(pool_size)
        self.avg_proj = nn.Conv2d(64, 32, 1, 1)
        self.deconv2d = nn.ConvTranspose2d(32, 2*num_spks, kernel_size, stride1, paddings)
        self.istft = ConviSTFT(win_len, win_inc, fft_len, win_type, 'complex')

    def _build_encoder(self, **enc_kargs):
        """
        Build encoder layers
        """
        encoder = nn.ModuleList()
        encoder.append(SublinearSequential(DenseBlock(16, 16, "enc"),LCA(16)))

        for i in range(3):
            encoder.append(
                SublinearSequential(
                    Conv2dBlock(in_dims=16 if i==0 else 32,
                                out_dims=32, **enc_kargs),
                    DenseBlock(32, 32, "enc"),
                    LCA(32)
                )
            )
        encoder.append(
            SublinearSequential(
                Conv2dBlock(in_dims=32, out_dims=64, **enc_kargs),
                LCA(64)
            )
        )
        encoder.append(
            SublinearSequential(
                Conv2dBlock(in_dims=64, out_dims=128, **enc_kargs),
                LCA(128)
            )
        )
        encoder.append(
            SublinearSequential(
                Conv2dBlock(in_dims=128, out_dims=384, **enc_kargs),
                LCA(384)
            )
        )

        return encoder

    def _build_decoder(self, **dec_kargs):
        """
        Build decoder layers
        """
        decoder = nn.ModuleList()
        decoder.append(ConvTrans2dBlock(in_dims=384*2, out_dims=128, **dec_kargs))
        decoder.append(ConvTrans2dBlock(in_dims=128*2, out_dims=64, **dec_kargs))
        decoder.append(ConvTrans2dBlock(in_dims=64*2, out_dims=32, **dec_kargs))
        for i in range(3):
            decoder.append(
                SublinearSequential(
                    DenseBlock(32, 64, "dec"),
                    ConvTrans2dBlock(in_dims=64,
                                     out_dims=32 if i!=2 else 16,
                                     **dec_kargs)
                )
            )
        decoder.append(DenseBlock(16, 32, "dec"))

        return decoder

    def _build_tcn_blocks(self, tcn_blocks, **tcn_kargs):
        """
        Build TCN blocks in each repeat (layer)
        """
        blocks = [
            TCNBlock(**tcn_kargs, dilation=(2**b))
            for b in range(tcn_blocks)
        ]

        return SublinearSequential(*blocks)

    def _build_tcn_layers(self, tcn_layers, tcn_blocks, **tcn_kargs):
        """
        Build TCN layers
        """
        layers = [
            self._build_tcn_blocks(tcn_blocks, **tcn_kargs)
            for _ in range(tcn_layers)
        ]

        return SublinearSequential(*layers)

    def _build_avg_pool(self, pool_size):
        """
        Build avg pooling layers
        """
        avg_pool = nn.ModuleList()
        for sz in pool_size:
            avg_pool.append(
                SublinearSequential(
                    nn.AvgPool2d(sz),
                    nn.Conv2d(32, 8, 1, 1)
                )
            )

        return avg_pool

    def wav2spec(self, x: th.Tensor, mags: bool = False) -> th.Tensor:
        """
        convert waveform to spectrogram
        """
        # print(x.shape)
        assert x.dim() == 2
        # x = x / th.std(x, -1, keepdims=True) # variance normalization
        specs = self.stft(x)
        real = specs[:,:self.fft_len//2+1]
        imag = specs[:,self.fft_len//2+1:]
        spec = th.stack([real,imag], 1) #[B,2,F,T]
        # spec = th.einsum("hijk->hikj", spec)  # batchsize, 2, T, F
        if mags:
            return th.sqrt(real**2+imag**2+1e-8)
        else:
            return spec

    def FeaCompression(self, input, factor=0.5):
        input_change = input.float()
        complex_spectrum = th.complex(input_change[:, 0, :, :], input_change[:, 1, :, :])
        magnitude = th.abs(complex_spectrum).unsqueeze(1) ** factor
        phase = th.angle(complex_spectrum).unsqueeze(1)

        real = magnitude * th.cos(phase)
        imag = magnitude * th.sin(phase)
        output = th.cat((real, imag), dim=1)

        return output

    def FeaDecompression(self, input, factor=0.5):
        input_change = input.float()
        complex_spectrum = th.complex(input_change[:, 0, :, :], input_change[:, 1, :, :])
        magnitude = th.abs(complex_spectrum).unsqueeze(1) ** (1 / factor)
        phase = th.angle(complex_spectrum).unsqueeze(1)

        real = magnitude * th.cos(phase)
        imag = magnitude * th.sin(phase)
        output = th.cat((real, imag), dim=1)

        return output

    def ComputeSimilarity(self, input, enrollment):
        att = enrollment.transpose(-2, -1) @ input
        att = self.softmax(att)
        output = enrollment @ att

        return output.unsqueeze(0).unsqueeze(0)

    def sep(self, spec: th.Tensor) -> List[th.Tensor]:
        """
        spec: (batchsize, 2, F, T)
        return the waveform reconstructed by the inverse STFT
        """
        # spec = th.einsum("hijk->hikj", spec)  # (batchsize, 2, F, T)
        B, N, F, T = spec.shape
        est = th.chunk(spec, 2, 1) # [(B, 1, F, T), (B, 1, F, T)]
        est = th.cat(est, 2).reshape(B, -1, T) # B, 2F, T
        return th.squeeze(self.istft(est))

    def forward(self,
                mix: th.Tensor,
                enrollment: th.Tensor) -> th.Tensor:
        """
        mix: N x S mixture waveform (S for a single utterance)
        enrollment: N x S' enrollment waveform (S' for a single utterance)
        return the estimated target waveform
        """
        if mix.dim() == 1:
            mix = th.unsqueeze(mix, 0)
            # the original line referenced an undefined name `aux` here
            enrollment = th.unsqueeze(enrollment, 0)
        batch_size = mix.shape[0]
        mix_spec = self.wav2spec(mix, False)
        mix_spec_change = self.FeaCompression(mix_spec) #[B,2,F,T]
        similarity = []
        aux_drc = []
        for i in range(batch_size):
            aux = self.wav2spec(enrollment[i].unsqueeze(0), False)
            aux_spec_change = self.FeaCompression(aux)
            aux_drc.append(aux_spec_change)
            similarity.append(th.cat([self.ComputeSimilarity(mix_spec_change[i, 0, ...], aux_spec_change[0, 0, ...]), self.ComputeSimilarity(mix_spec_change[i, 1, ...], aux_spec_change[0, 1, ...])], dim=1))
        similarity = th.cat(similarity, dim=0)
        aux_drc = th.cat(aux_drc, dim=0)
        aux_drc = th.mean(aux_drc, dim=-1).unsqueeze(-1).repeat(1, 1,1, similarity.shape[-1])
        similarity = self.ifi(similarity, aux_drc)
        fus = th.cat((mix_spec_change, similarity), dim=1) #[1,4,129,251]
        fus = self.upconv1(fus)
        fus = self.lca(fus)
        # speech separation
        fus = fus.permute(0, 1, 3, 2)
        out = self.relu(self.conv2d(fus))
        out_list = []
        out = self.encoder[0](out)
        out_list.append(out)
        for idx, enc in enumerate(self.encoder[1:]):
            out = enc(out)
            out_list.append(out)

        B, N, T, F = out.shape
        out = out.reshape(B, N, T*F)
        out = self.tcn_layers(out)
        out = th.unsqueeze(out, -1)

        out_list = out_list[::-1]
        for idx, dec in enumerate(self.decoder):
            decinput = th.cat([out_list[idx], out], 1)
            out = dec(decinput)

        # Pyramidal pooling
        B, N, T, F = out.shape
        upsample = nn.Upsample(size=(T, F), mode='bilinear')
        pool_list = []
        for avg in self.avg_pool:
            pool_list.append(upsample(avg(out)))
        out = th.cat([out, *pool_list], 1)
        out = self.avg_proj(out)
        out = self.deconv2d(out)
        out = out.permute(0, 1, 3, 2)
        out = self.FeaDecompression(out)
        out = self.sep(out)
        return out


def test_covn2d_block():
    x = th.randn(2, 16, 257, 200)
    conv = Conv2dBlock()
    y = conv(x)
    convtrans = ConvTrans2dBlock()
    z = convtrans(y)

def test_dense_block():
    x = th.randn(2, 16, 257, 200)
    dense = DenseBlock(16, 32, "enc")
    y = dense(x)

def test_tcn_block():
    x = th.randn(2, 384, 1000)
    tcn1 = TCNBlock(dilation=128)

if __name__ == "__main__":
    from thop import profile, clever_format
    nnet = SEF_PNet()
    mix = th.randn(2, 8000)
    aux = th.randn(2, 8000)
    est = nnet(mix, aux)
    macs, params = profile(nnet, inputs=(mix,aux))
    macs, params = clever_format([macs, params], "%.3f")
    print(macs, params)

--------------------------------------------------------------------------------
/nnet/__pycache__/SEF_PNet_pse.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/__pycache__/SEF_PNet_pse.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/__pycache__/conf_unet_tse_32ms.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/__pycache__/conf_unet_tse_32ms.cpython-39.pyc
--------------------------------------------------------------------------------
/nnet/conf_unet_tse_32ms.py:
--------------------------------------------------------------------------------
fs = 8000
chunk_len = 4  # (s)
chunk_size = chunk_len * fs

nnet_conf = {
    "win_len": 256,
    "win_inc": 64,
    "fft_len": 256,
    "win_type": "sqrthann",
    "kernel_size": (3, 3),
    "stride1": (1, 1),
    "stride2": (1, 2),
    "paddings": (1, 0),
    "output_padding": (0, 0),
    "tcn_dims": 384,
    "tcn_blocks": 10,
    "tcn_layers": 2,
    "causal": False,
    "num_spks": 1
}


# data configure:
train_dir = "/node/hzl/expriment/SEF_PNet_icassp2025_github/data/train/"
dev_dir = "/node/hzl/expriment/SEF_PNet_icassp2025_github/data/dev/"

train_data = {
    "mix_scp": train_dir + "mix_clean.scp",
    "ref_scp": train_dir + "ref.scp",
    "aux_scp": train_dir + "auxs1.scp",
    "sample_rate": fs,
}

dev_data = {
    "mix_scp": dev_dir + "mix_clean.scp",
    "ref_scp": dev_dir + "ref.scp",
    "aux_scp": dev_dir + "auxs1.scp",
    "sample_rate": fs,
}

# trainer config
adam_kwargs = {
    "lr": 0.5e-3,
    "weight_decay": 1e-5,
}

trainer_conf = {
    "optimizer": "adam",
    "optimizer_kwargs": adam_kwargs,
    "min_lr": 1e-8,
    "patience": 2,
    "factor": 0.5,
    "logging_period": 200
}

--------------------------------------------------------------------------------
/nnet/evaluate.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

import os
import time
import argparse
import torch as th
import numpy as np
from mir_eval.separation import bss_eval_sources
from pesq import pesq as pesq2
from pypesq import pesq as pesq1
from pystoi.stoi import stoi
from SEF_PNet_pse import SEF_PNet
from libs.utils import load_json, get_logger
from libs.dataset_tse import Dataset

def evaluate(args, model_file, logger):
    start = time.time()
    total_SISNR = 0
    total_SISNRi = 0
    total_PESQ = 0
    total_PESQi = 0
    total_PESQ2 = 0
    total_PESQi2 = 0
    total_STOI = 0
    total_STOIi = 0
    total_SDR = 0
    total_cnt = 0

    # Load model
    nnet_conf = load_json(args.checkpoint, "mdl.json")
    nnet = SEF_PNet(**nnet_conf)
    cpt_fname = os.path.join(args.checkpoint, model_file)
    cpt = th.load(cpt_fname, map_location="cpu")
    nnet.load_state_dict(cpt["model_state_dict"])
    logger.info("Loaded checkpoint from {}, epoch {:d}".format(
        cpt_fname, cpt["epoch"]))

    device = th.device(
        "cuda:{}".format(args.gpuid)) if args.gpuid >= 0 else th.device("cpu")
    nnet = nnet.to(device) if args.gpuid >= 0 else nnet
    nnet.eval()

    # Load data
    dataset = Dataset(mix_scp=args.mix_scp, ref_scp=args.ref_scp, aux_scp=args.aux_scp, sample_rate=8000)

    with th.no_grad():
        for i, data in enumerate(dataset):
            mix = th.tensor(data['mix'], dtype=th.float32, device=device)
            aux = th.tensor(data['aux'], dtype=th.float32, device=device)

            if args.gpuid >= 0:
                mix = mix.unsqueeze(0).to(device)
                aux = aux.unsqueeze(0).to(device)

            # Forward
            ref = data['ref']
            key = data['key']
            ests = nnet(mix, aux)
            ests = ests.cpu().numpy()
            mix = mix.squeeze(0).cpu().numpy()
            # Trim all signals to a common length before scoring
            if ests.size != ref.size:
                end = min(ests.size, ref.size)
                ests = ests[:end]
                ref = ref[:end]
                mix = mix[:end]

            # Compute metrics
            if args.cal_sdr == 1:
                SDR, sir, sar, popt = bss_eval_sources(ref, ests)
                total_SDR += SDR[0]
            SISNR, delta = cal_SISNRi(ests, ref, mix)
            PESQ, PESQi, PESQ2, PESQi2 = cal_PESQi(ests, ref, mix)
            STOI, STOIi = cal_STOIi(ests, ref, mix)
            if args.cal_sdr == 1:
                logger.info("Utt={:d} | SDR={:.2f} | SI-SNR={:.2f} | SI-SNRi={:.2f} | PESQ={:.2f} | PESQi={:.2f} | PESQ2={:.2f} | PESQi2={:.2f} | STOI={:.2f} | STOIi={:.2f}".format(
                    total_cnt+1, SDR[0], SISNR, delta, PESQ, PESQi, PESQ2, PESQi2, STOI, STOIi))
            else:
                logger.info("Utt={:d} | SI-SNR={:.2f} | SI-SNRi={:.2f} | PESQ={:.2f} | PESQi={:.2f} | PESQ2={:.2f} | PESQi2={:.2f} | STOI={:.2f} | STOIi={:.2f}".format(
                    total_cnt+1, SISNR, delta, PESQ, PESQi, PESQ2, PESQi2, STOI, STOIi))
            total_SISNR += SISNR
            total_SISNRi += delta
            total_PESQ += PESQ
            total_PESQi += PESQi
            total_PESQ2 += PESQ2
            total_PESQi2 += PESQi2
            total_STOI += STOI
            total_STOIi += STOIi
            total_cnt += 1
    end = time.time()

    logger.info('Time Elapsed: {:.1f}s'.format(end-start))
    if args.cal_sdr == 1:
        logger.info("Average SDR: {0:.2f}".format(total_SDR / total_cnt))
    logger.info("Average SI-SNR: {:.2f}".format(total_SISNR / total_cnt))
    logger.info("Average SI-SNRi: {:.2f}".format(total_SISNRi / total_cnt))
    logger.info("Average PESQ: {:.2f}".format(total_PESQ / total_cnt))
    logger.info("Average PESQi: {:.2f}".format(total_PESQi / total_cnt))
    logger.info("Average PESQ2: {:.2f}".format(total_PESQ2 / total_cnt))
    logger.info("Average PESQi2: {:.2f}".format(total_PESQi2 / total_cnt))
    logger.info("Average STOI: {:.2f}".format(total_STOI / total_cnt))
    logger.info("Average STOIi: {:.2f}".format(total_STOIi / total_cnt))

def cal_SISNR(est, ref, eps=1e-8):
    """Calculate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
    Args:
        est: separated signal, numpy.ndarray, [T]
        ref: reference signal, numpy.ndarray, [T]
    Returns:
        SISNR
    """
    assert len(est) == len(ref)
    est_zm = est - np.mean(est)
    ref_zm = ref - np.mean(ref)

    t = np.sum(est_zm * ref_zm) * ref_zm / (np.linalg.norm(ref_zm)**2 + eps)

    return 20 * np.log10(eps + np.linalg.norm(t) / (np.linalg.norm(est_zm - t) + eps))

def cal_SISNRi(est, ref, mix, eps=1e-8):
    """Calculate SI-SNR and its improvement over the mixture (SI-SNRi)
    Args:
        est: separated signal, numpy.ndarray, [T]
        ref: reference signal, numpy.ndarray, [T]
        mix: mixture signal, numpy.ndarray, [T]
    Returns:
        SISNR, SISNRi
    """
    assert len(est) == len(ref) == len(mix)
    sisnr1 = cal_SISNR(est, ref)
    sisnr2 = cal_SISNR(mix, ref)

    return sisnr1, sisnr1 - sisnr2

def cal_PESQ(est, ref):
    """Calculate PESQ with both the pypesq and pesq (narrow-band) packages"""
    assert len(est) == len(ref)
    mode ='nb'
    p = pesq1(ref, est,8000)
    p2 = pesq2(8000, ref, est, mode)
    return p,p2

def cal_PESQi(est, ref, mix):
    """Calculate PESQ and its improvement over the mixture (PESQi)
    Args:
        est: separated signal, numpy.ndarray, [T]
        ref: reference signal, numpy.ndarray, [T]
        mix: mixture signal, numpy.ndarray, [T]
    Returns:
        PESQ, PESQi (pypesq), PESQ2, PESQi2 (pesq)
    """
    assert len(est) == len(ref) == len(mix)
    # local names chosen so they do not shadow the imported pesq1/pesq2
    pesq_est, pesq2_est = cal_PESQ(est, ref)
    pesq_mix, pesq2_mix = cal_PESQ(mix, ref)

    return pesq_est, pesq_est - pesq_mix, pesq2_est, pesq2_est - pesq2_mix

def cal_STOI(est, ref):
    assert len(est) == len(ref)
    p = stoi(ref, est, 8000)
    return p

def cal_STOIi(est, ref, mix):
    """Calculate STOI (in percent) and its improvement over the mixture (STOIi)
    Args:
        est: separated signal, numpy.ndarray, [T]
        ref: reference signal, numpy.ndarray, [T]
        mix: mixture signal, numpy.ndarray, [T]
    Returns:
        STOI, STOIi
    """
    assert len(est) == len(ref) == len(mix)
    stoi1 = cal_STOI(est, ref)*100
    stoi2 = cal_STOI(mix, ref)*100

    return stoi1, stoi1 - stoi2

if __name__ == '__main__':
    parser = argparse.ArgumentParser('Evaluate personalized speech enhancement performance using SEF-PNet')
    parser.add_argument('--checkpoint', type=str,
                        default='/node/hzl/expriment/SEF_PNet_icassp2025_github/demo',
                        help='Path to model directory containing checkpoints')
    parser.add_argument('--gpuid', type=int, default=0,
                        help="GPU device to offload model to, -1 means running on CPU")
    parser.add_argument('--mix_scp', type=str,
                        default='/node/hzl/expriment/SEF_PNet_icassp2025_github/data/test/mix_clean.scp',
                        help='mix scp')
    parser.add_argument('--ref_scp', type=str,
                        default='/node/hzl/expriment/SEF_PNet_icassp2025_github/data/test/ref.scp',
                        help='ref scp')
    parser.add_argument('--aux_scp', type=str,
                        default='/node/hzl/expriment/SEF_PNet_icassp2025_github/data/test/auxs1.scp',
                        help='aux scp')
    parser.add_argument('--cal_sdr', type=int, default=None,
                        help='Whether to calculate SDR; optional because SDR calculation is very slow')

    args = parser.parse_args()


    # eval best.pt.tar
    best_model_file = "best.pt.tar"
    best_log_file = os.path.join(args.checkpoint, "eval_best.log")
    best_logger = get_logger(best_log_file, file=True)
    best_logger.info(f"Evaluating model: {best_model_file}")
    evaluate(args, best_model_file, best_logger)

    # eval the per-epoch checkpoints 110.pt.tar through 121.pt.tar
    for epoch in range(110, 122):
        model_file = f"{epoch}.pt.tar"
        log_file = os.path.join(args.checkpoint, f"eval_{epoch}.log")
        logger = get_logger(log_file, file=True)
        logger.info(f"Evaluating model: {model_file}")
        evaluate(args, model_file, logger)

--------------------------------------------------------------------------------
/nnet/libs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__init__.py
--------------------------------------------------------------------------------
/nnet/libs/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__init__.pyc
--------------------------------------------------------------------------------
/nnet/libs/__pycache__/__init__.cpython-39.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/audio.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__pycache__/audio.cpython-39.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/conv_stft.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__pycache__/conv_stft.cpython-39.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/dataset_tse.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__pycache__/dataset_tse.cpython-39.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/trainer_unet_tse_steplr_clip.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__pycache__/trainer_unet_tse_steplr_clip.cpython-39.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/libs/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /nnet/libs/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import soundfile as sf 5 | import librosa 6 | import kaldiio 7 | MAX_INT16 = np.iinfo(np.int16).max 8 | 9 | 10 | def write_wav(fname, samps, fs=8000, normalize=True): 11 | """ 12 | Write wav files in int16, support single/multi-channel 13 | """ 14 | #if normalize: 15 | # samps = samps * MAX_INT16 16 | ## scipy.io.wavfile.write could write single/multi-channel files 17 | ## for multi-channel, accept ndarray [Nsamples, Nchannels] 18 | #if samps.ndim != 1 and samps.shape[0] < samps.shape[1]: 19 | # samps = np.transpose(samps) 20 | # samps = np.squeeze(samps) 21 | ## same as MATLAB and kaldi 22 | #samps_int16 = samps.astype(np.int16) 23 | #fdir = os.path.dirname(fname) 24 | #if fdir and not os.path.exists(fdir): 25 | # os.makedirs(fdir) 26 | ## NOTE: librosa 0.6.0 seems could not write non-float narray 27 | ## so use scipy.io.wavfile instead 28 | #wf.write(fname, fs, samps_int16) 29 | 30 | # wham and whamr mixture and clean data are float 32, can not use scipy.io.wavfile to read and write int16 31 | # change to soundfile to read and write, although reference speech is int16, soundfile still can read and outputs as float 32 | fdir = os.path.dirname(fname) 33 | if fdir and not os.path.exists(fdir): 34 | os.makedirs(fdir) 35 | sf.write(fname, samps, fs, subtype='FLOAT',format='WAV') 36 | 37 | 38 | def read_wav(fname, normalize=True, return_rate=False): 39 | """ 40 | Read wave files using scipy.io.wavfile(support multi-channel) 41 | """ 42 | # samps_int16: N x C or N 43 | # N: number of samples 44 | # C: number of channels 45 | #samp_rate, samps_int16 = wf.read(fname) 46 | ## N x C => C x N 47 | 
#samps = samps_int16.astype(np.float) 48 | ## tranpose because I used to put channel axis first 49 | #if samps.ndim != 1: 50 | # samps = np.transpose(samps) 51 | ## normalize like MATLAB and librosa 52 | #if normalize: 53 | # samps = samps / MAX_INT16 54 | #if return_rate: 55 | # return samp_rate, samps 56 | #return samps 57 | 58 | # wham and whamr mixture and clean data are float 32, can not use scipy.io.wavfile to read and write int16 59 | # change to soundfile to read and write, although reference speech is int16, soundfile still can read and outputs as float 60 | samps, samp_rate = sf.read(fname) 61 | if return_rate: 62 | return samp_rate, samps 63 | return samps 64 | 65 | 66 | def parse_scripts(scp_path, value_processor=lambda x: x, num_tokens=2): 67 | """ 68 | Parse kaldi's script(.scp) file 69 | If num_tokens >= 2, function will check token number 70 | """ 71 | scp_dict = dict() 72 | line = 0 73 | with open(scp_path, "r") as f: 74 | for raw_line in f: 75 | scp_tokens = raw_line.strip().split() 76 | line += 1 77 | if num_tokens >= 2 and len(scp_tokens) != num_tokens or len( 78 | scp_tokens) < 2: 79 | raise RuntimeError( 80 | "For {}, format error in line[{:d}]: {}".format( 81 | scp_path, line, raw_line)) 82 | if num_tokens == 2: 83 | key, value = scp_tokens 84 | else: 85 | key, value = scp_tokens[0], scp_tokens[1:] 86 | if key in scp_dict: 87 | raise ValueError("Duplicated key \'{0}\' exists in {1}".format( 88 | key, scp_path)) 89 | scp_dict[key] = value_processor(value) 90 | return scp_dict 91 | 92 | 93 | class Reader(object): 94 | """ 95 | Basic Reader Class 96 | """ 97 | def __init__(self, scp_path, value_processor=lambda x: x): 98 | self.index_dict = parse_scripts( 99 | scp_path, value_processor=value_processor, num_tokens=2) 100 | self.index_keys = list(self.index_dict.keys()) 101 | 102 | def _load(self, key): 103 | # return path 104 | return self.index_dict[key] 105 | 106 | # number of utterance 107 | def __len__(self): 108 | return len(self.index_dict) 
109 | 110 | def __contains__(self, key): 111 | return key in self.index_dict 112 | 113 | # sequential index 114 | def __iter__(self): 115 | for key in self.index_keys: 116 | yield key, self._load(key) 117 | 118 | # random index, support str/int as index 119 | def __getitem__(self, index): 120 | if type(index) not in [int, str]: 121 | raise IndexError("Unsupported index type: {}".format(type(index))) 122 | if type(index) == int: 123 | # from int index to key 124 | num_utts = len(self.index_keys) 125 | if index >= num_utts or index < 0: 126 | raise KeyError( 127 | "Integer index out of range, {:d} vs {:d}".format( 128 | index, num_utts)) 129 | index = self.index_keys[index] 130 | if index not in self.index_dict: 131 | raise KeyError("Missing utterance {}!".format(index)) 132 | return self._load(index) 133 | 134 | 135 | class WaveReader(Reader): 136 | """ 137 | Sequential/Random Reader for single channel wave 138 | Format of wav.scp follows Kaldi's definition: 139 | key1 /path/to/wav 140 | ... 
141 | """ 142 | def __init__(self, wav_scp, sample_rate=None, normalize=True): 143 | super(WaveReader, self).__init__(wav_scp) 144 | self.samp_rate = sample_rate 145 | self.normalize = normalize 146 | 147 | def _load(self, key): 148 | # return C x N or N 149 | samp_rate, samps = read_wav( 150 | self.index_dict[key], normalize=self.normalize, return_rate=True) 151 | # if given samp_rate, check it 152 | if self.samp_rate is not None and samp_rate != self.samp_rate: 153 | samps = librosa.resample(samps, orig_sr=samp_rate, target_sr=self.samp_rate) 154 | # raise RuntimeError("SampleRate mismatch: {:d} vs {:d}".format( 155 | # samp_rate, self.samp_rate)) 156 | return samps 157 | -------------------------------------------------------------------------------- /nnet/libs/conv_stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from scipy.signal import get_window 6 | 7 | 8 | def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): 9 | """ 10 | Return window coefficient 11 | """ 12 | def sqrthann(win_len): 13 | return get_window("hann", win_len, fftbins=True)**0.5 14 | 15 | if win_type == 'None' or win_type is None: 16 | window = np.ones(win_len) 17 | elif win_type == "sqrthann": 18 | window = sqrthann(win_len) 19 | else: 20 | window = get_window(win_type, win_len, fftbins=True)#**0.5 21 | 22 | N = fft_len 23 | fourier_basis = np.fft.rfft(np.eye(N))[:win_len] 24 | real_kernel = np.real(fourier_basis) 25 | imag_kernel = np.imag(fourier_basis) 26 | kernel = np.concatenate([real_kernel, imag_kernel], 1).T 27 | 28 | if invers : 29 | kernel = np.linalg.pinv(kernel).T 30 | 31 | kernel = kernel*window 32 | kernel = kernel[:, None, :] 33 | return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None,:,None].astype(np.float32)) 34 | 35 | 36 | class ConvSTFT(nn.Module): 37 | def __init__(self, win_len, 
win_inc, fft_len=None, win_type='hamming', feature_type='real'): 38 | super(ConvSTFT, self).__init__() 39 | 40 | if fft_len is None: 41 | self.fft_len = int(2**np.ceil(np.log2(win_len))) 42 | else: 43 | self.fft_len = fft_len 44 | 45 | kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type) 46 | self.register_buffer('weight', kernel) 47 | self.feature_type = feature_type 48 | self.stride = win_inc 49 | self.win_len = win_len 50 | self.dim = self.fft_len 51 | 52 | def forward(self, inputs): 53 | if inputs.dim() == 2: 54 | inputs = torch.unsqueeze(inputs, 1) 55 | inputs = F.pad(inputs,[self.win_len-self.stride, self.win_len-self.stride]) 56 | outputs = F.conv1d(inputs, self.weight, stride=self.stride) 57 | 58 | if self.feature_type == 'complex': 59 | return outputs 60 | else: 61 | dim = self.dim//2+1 62 | real = outputs[:, :dim, :] 63 | imag = outputs[:, dim:, :] 64 | mags = torch.sqrt(real**2+imag**2) 65 | phase = torch.atan2(imag, real) 66 | return mags, phase 67 | 68 | class ConviSTFT(nn.Module): 69 | 70 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real'): 71 | super(ConviSTFT, self).__init__() 72 | if fft_len is None: 73 | self.fft_len = int(2**np.ceil(np.log2(win_len))) 74 | else: 75 | self.fft_len = fft_len 76 | kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True) 77 | self.register_buffer('weight', kernel) 78 | self.feature_type = feature_type 79 | self.win_type = win_type 80 | self.win_len = win_len 81 | self.stride = win_inc 82 | self.stride = win_inc 83 | self.dim = self.fft_len 84 | self.register_buffer('window', window) 85 | self.register_buffer('enframe', torch.eye(win_len)[:,None,:]) 86 | 87 | def forward(self, inputs, phase=None): 88 | """ 89 | inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags) 90 | phase: [B, N//2+1, T] (if not none) 91 | """ 92 | 93 | if phase is not None: 94 | real = inputs*torch.cos(phase) 95 | imag = inputs*torch.sin(phase) 96 | 
inputs = torch.cat([real, imag], 1) 97 | outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride) 98 | 99 | 100 | 101 | # this is from torch-stft: https://github.com/pseeth/torch-stft 102 | t = self.window.repeat(1,1,inputs.size(-1))**2 103 | coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) 104 | outputs = outputs/(coff+1e-8) 105 | #outputs = torch.where(coff == 0, outputs, outputs/coff) 106 | outputs = outputs[...,self.win_len-self.stride:-(self.win_len-self.stride)] 107 | 108 | return outputs -------------------------------------------------------------------------------- /nnet/libs/dataset_tse.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch as th 3 | import numpy as np 4 | 5 | from torch.utils.data.dataloader import default_collate 6 | import torch.utils.data as dat 7 | from torch.nn.utils.rnn import pad_sequence 8 | from libs.audio import WaveReader 9 | from conf_unet_tse_32ms import train_data, dev_data 10 | 11 | 12 | 13 | def make_dataloader(train=True, 14 | data_kwargs=None, 15 | num_workers=4, 16 | chunk_size=80000, 17 | batch_size=16): 18 | dataset = Dataset(**data_kwargs) 19 | return DataLoader(dataset, 20 | train=train, 21 | chunk_size=chunk_size, 22 | batch_size=batch_size, 23 | num_workers=num_workers) 24 | 25 | def get_spk_ivec(key): 26 | ''' 27 | 409o030h_1.7445_029o0304_-1.7445_409c0211 28 | ''' 29 | spk = key.split('_')[-1][0:3] 30 | print(spk) 31 | 32 | class Dataset(object): 33 | """ 34 | Per Utterance Loader 35 | """ 36 | def __init__(self, mix_scp="", ref_scp=None, aux_scp=None, sample_rate=8000): 37 | self.mix = WaveReader(mix_scp, sample_rate=sample_rate) 38 | self.ref = WaveReader(ref_scp, sample_rate=sample_rate) 39 | self.aux = WaveReader(aux_scp, sample_rate=sample_rate) 40 | self.sample_rate = sample_rate 41 | 42 | def __len__(self): 43 | return len(self.mix) 44 | 45 | def __getitem__(self, index): 46 | key = self.mix.index_keys[index] 47 | 
mix = self.mix[key] 48 | ref = self.ref[key] 49 | aux = self.aux[key] 50 | 51 | return { 52 | "mix": mix.astype(np.float32), 53 | "ref": ref.astype(np.float32), 54 | "aux": aux.astype(np.float32), 55 | "aux_len": len(aux), 56 | "key": key 57 | } 58 | 59 | 60 | class ChunkSplitter(object): 61 | """ 62 | Split utterance into small chunks 63 | """ 64 | def __init__(self, chunk_size, train=True, least=2000): 65 | self.chunk_size = chunk_size 66 | self.least = least 67 | self.train = train 68 | 69 | def _make_chunk(self, eg, s): 70 | """ 71 | Make a chunk instance, which contains: 72 | "mix": ndarray, 73 | "ref": [ndarray...] 74 | """ 75 | chunk = dict() 76 | chunk["mix"] = eg["mix"][s:s + self.chunk_size] 77 | chunk["ref"] = eg["ref"][s:s + self.chunk_size] 78 | chunk["aux"] = eg["aux"] 79 | chunk["aux_len"] = chunk["aux"].shape[0] 80 | chunk["valid_len"] = int(self.chunk_size) 81 | return chunk 82 | 83 | def split(self, eg): 84 | N = eg["mix"].size 85 | # too short, throw away 86 | if N < self.least: 87 | return [] 88 | chunks = [] 89 | # padding zeros 90 | if N < self.chunk_size: 91 | P = self.chunk_size - N 92 | chunk = dict() 93 | chunk["mix"] = np.pad(eg["mix"], (0, P), "constant") 94 | chunk["ref"] = np.pad(eg["ref"], (0, P), "constant") 95 | chunk["aux"] = eg["aux"] 96 | chunk["aux_len"] = eg["aux_len"] 97 | chunk["valid_len"] = int(N) 98 | chunks.append(chunk) 99 | else: 100 | # random select start point for training 101 | s = random.randint(0, N % self.least) if self.train else 0 102 | while True: 103 | if s + self.chunk_size > N: 104 | break 105 | chunk = self._make_chunk(eg, s) 106 | chunks.append(chunk) 107 | s += self.least 108 | return chunks 109 | 110 | 111 | class DataLoader(object): 112 | """ 113 | Online dataloader for chunk-level PIT 114 | """ 115 | def __init__(self, 116 | dataset, 117 | num_workers=4, 118 | chunk_size=80000, 119 | batch_size=4, 120 | train=True): 121 | self.batch_size = batch_size 122 | self.train = train 123 | self.splitter = 
ChunkSplitter(chunk_size, 124 | train=train, 125 | least=chunk_size // 2) 126 | # just return batch of egs, support multiple workers 127 | self.eg_loader = dat.DataLoader(dataset, 128 | batch_size=batch_size // 2, 129 | num_workers=num_workers, 130 | shuffle=train, 131 | collate_fn=self._collate) 132 | 133 | def _collate(self, batch): 134 | """ 135 | Online split utterances 136 | """ 137 | chunk = [] 138 | for eg in batch: 139 | chunk += self.splitter.split(eg) 140 | return chunk 141 | 142 | def _pad_aux(self, chunk_list): 143 | lens_list = [] 144 | for chunk_item in chunk_list: 145 | lens_list.append(chunk_item['aux_len']) 146 | max_len = np.max(lens_list) 147 | 148 | for idx in range(len(chunk_list)): 149 | P = max_len - len(chunk_list[idx]["aux"]) 150 | chunk_list[idx]["aux"] = np.pad(chunk_list[idx]["aux"], (0, P), "constant") 151 | 152 | return chunk_list 153 | 154 | def _merge(self, chunk_list): 155 | """ 156 | Merge chunk list into mini-batch 157 | """ 158 | N = len(chunk_list) 159 | if self.train: 160 | random.shuffle(chunk_list) 161 | blist = [] 162 | for s in range(0, N - self.batch_size + 1, self.batch_size): 163 | batch = default_collate(self._pad_aux(chunk_list[s:s + self.batch_size])) 164 | blist.append(batch) 165 | rn = N % self.batch_size 166 | return blist, chunk_list[-rn:] if rn else [] 167 | 168 | def __iter__(self): 169 | chunk_list = [] 170 | for chunks in self.eg_loader: 171 | chunk_list += chunks 172 | batch, chunk_list = self._merge(chunk_list) 173 | for obj in batch: 174 | yield obj 175 | 176 | if __name__=='__main__': 177 | chunk_size=80000 178 | train=True 179 | least=chunk_size // 2 180 | splitter = ChunkSplitter(chunk_size, train, least) 181 | data = Dataset(**train_data) 182 | egs = data[0] 183 | chunk = splitter.split(egs) 184 | dataload = DataLoader(data) 185 | temp = [] 186 | for i, obj in enumerate(dataload): 187 | # print('mix...', obj) 188 | #print(i,obj) 189 | temp.append(obj) 190 | # mix,anw = obj[] 191 | # logits = 
net(mix,anw) 192 | # loss = net.loss(logits,targets) 193 | # loss.backward() 194 | 195 | 196 | # mix = obj[] 197 | 198 | 199 | #if i == 2: 200 | # break 201 | print(len(temp)) 202 | -------------------------------------------------------------------------------- /nnet/libs/metric.py: -------------------------------------------------------------------------------- 1 | # jyhan@2020 2 | 3 | """ 4 | Provided metrics: 5 | speech separation: (w/ & w/o PIT) 6 | - SDR 7 | - SDRi 8 | - SI-SNR 9 | - SI-SNRi 10 | speech enhancement: 11 | - PESQ 12 | - STOI 13 | """ 14 | 15 | import numpy as np 16 | 17 | from pesq import pesq 18 | from pystoi.stoi import stoi 19 | 20 | from itertools import permutations 21 | from mir_eval.separation import bss_eval_sources 22 | 23 | def cal_sisnr(est, ref, remove_dc=True, eps=1e-8): 24 | """ 25 | Compute SI-SNR 26 | Arguments: 27 | est: vector, enhanced/separated signal 28 | ref: vector, reference signal(ground truth) 29 | """ 30 | assert len(est) == len(ref) 31 | def vec_l2norm(x): 32 | return np.linalg.norm(x, 2) 33 | 34 | # zero mean, does not seem to hurt results 35 | if remove_dc: 36 | e_zm = est - np.mean(est) 37 | r_zm = ref - np.mean(ref) 38 | t = np.inner(e_zm, r_zm) * r_zm / (vec_l2norm(r_zm)**2 + eps) 39 | n = e_zm - t 40 | else: 41 | t = np.inner(est, ref) * ref / (vec_l2norm(ref)**2 + eps) 42 | n = est - t 43 | return 20 * np.log10(vec_l2norm(t) / (vec_l2norm(n) + eps)) 44 | 45 | 46 | def permute_si_snr(est, ref): 47 | """ 48 | Compute SI-SNR between N pairs 49 | Arguments: 50 | est: list[vector], enhanced/separated signal 51 | ref: list[vector], reference signal(ground truth) 52 | Return: 53 | max sisnr and its permutation 54 | """ 55 | assert len(est) == len(ref) 56 | def si_snr_avg(est, ref): 57 | return sum([cal_sisnr(e, r) for e, r in zip(est, ref)]) / len(est) 58 | 59 | N = len(est) 60 | if N != len(ref): 61 | raise RuntimeError( 62 | "size do not match between est and ref: {:d} vs {:d}".format( 63 | N, len(ref))) 64 | 
si_snrs = [] 65 | perm = [] 66 | for order in permutations(range(N)): 67 | si_snrs.append(si_snr_avg(est, [ref[n] for n in order])) 68 | perm.append(order) 69 | 70 | return max(si_snrs), perm[si_snrs.index(max(si_snrs))] 71 | 72 | 73 | def permute_si_snri(mix, est, ref, both=True): 74 | """ 75 | Compute SI-SNR improvement 76 | Arguments: 77 | mix: vector, mixture signal 78 | est: list[vector], enhanced/separated signal 79 | ref: list[vector], reference signal(ground truth) 80 | [spk1, spk2, aux] 81 | """ 82 | m_mix = sum([cal_sisnr(mix, r) for r in ref[:2]]) / len(ref[:2]) 83 | m_enh, _ = permute_si_snr(est, ref) 84 | if both: 85 | return m_enh, m_enh - m_mix 86 | else: 87 | return m_enh - m_mix 88 | 89 | def pit_rank_sisnr(mix, est, ref): 90 | """ 91 | Compute SI-SNR improvement 92 | Arguments: 93 | mix: vector, mixture signal 94 | est: list[vector], enhanced/separated signal 95 | ref: list[vector], reference signal(ground truth) 96 | [spk1, spk2, aux] 97 | """ 98 | m_mix1 = sum([cal_sisnr(mix, r) for r in ref[:2]]) / len(ref[:2]) 99 | m_mix2 = sum([cal_sisnr(mix, r) for r in est[:2]]) / len(est[:2]) 100 | m_mix = (m_mix1 + m_mix2) / 2 101 | m_enh, _ = permute_si_snr(est, ref) 102 | 103 | return m_enh, m_mix 104 | 105 | def pit_rank_sisnr_all(mix, est, ref): 106 | """ 107 | Compute SI-SNR improvement 108 | Arguments: 109 | mix: vector, mixture signal 110 | est: list[vector], enhanced/separated signal 111 | ref: list[vector], reference signal(ground truth) 112 | [spk1, spk2, aux] 113 | """ 114 | m_mix1 = sum([cal_sisnr(mix, r) for r in ref[:2]]) / len(ref[:2]) 115 | m_mix2 = sum([cal_sisnr(mix, r) for r in est[:2]]) / len(est[:2]) 116 | m_mix = (m_mix1 + m_mix2) / 2 117 | m_enh, _ = permute_si_snr(est, ref) 118 | 119 | return m_enh, m_mix1, m_mix2, m_mix 120 | 121 | 122 | def reorder_list(slist, perm): 123 | """ 124 | Arguments: 125 | slist: list[vector], reference signal 126 | perm: permutation label 127 | Return: 128 | list[vector], reordered reference signal 129 
| """ 130 | return [slist[p] for p in perm] 131 | 132 | 133 | def cal_SDRi(mix, est, ref): 134 | """Calculate Source-to-Distortion Ratio improvement (SDRi). 135 | NOTE: bss_eval_sources is very very slow. 136 | Args: 137 | mix: numpy.ndarray, 138 | est: [numpy.ndarray, numpy.ndarray] enhanced/separated signal 139 | ref: [numpy.ndarray, numpy.ndarray] , reference signal(ground truth) 140 | Returns: 141 | avg_sdr, sdri 142 | """ 143 | mix = np.array(mix) 144 | est = np.array(est) 145 | ref = np.array(ref) 146 | 147 | mix_anchor = np.stack([mix, mix], axis=0) 148 | sdr, sir, sar, popt = bss_eval_sources(ref, est) 149 | sdr0, sir0, sar0, popt0 = bss_eval_sources(ref, mix_anchor) 150 | avg_sdr = (sdr[0] + sdr[1] ) / 2 151 | avg_sdr_m = (sdr0[0] + sdr0[1] ) / 2 152 | 153 | return avg_sdr, avg_sdr - avg_sdr_m 154 | 155 | 156 | def permute_pesq(est, ref, fs=8000, mode='nb'): 157 | """ 158 | Evaluate PESQ 159 | Args: 160 | est: [numpy 1D array, numpy 1D array], estimated audio signal 161 | ref: [numpy 1D array, numpy 1D array], reference audio signal 162 | fs: integer, sampling rate 163 | """ 164 | assert fs in [8000, 16000] 165 | assert len(est) == len(ref) 166 | mode = 'nb' if fs == 8000 else 'wb' 167 | 168 | def pesq_avg(est, ref): 169 | return sum([pesq(fs, r, e, mode) for e, r in zip(est, ref)]) / len(est) 170 | 171 | N = len(est) 172 | if N != len(ref): 173 | raise RuntimeError( 174 | "size do not match between est and ref: {:d} vs {:d}".format( 175 | N, len(ref))) 176 | pesqs = [] 177 | for order in permutations(range(N)): 178 | pesqs.append(pesq_avg(est, [ref[n] for n in order])) 179 | 180 | return max(pesqs) 181 | 182 | 183 | def permute_stoi(est, ref, fs=8000): 184 | """ 185 | Evaluate STOI 186 | Args: 187 | est: [numpy 1D array, numpy 1D array], estimated audio signal 188 | ref: [numpy 1D array, numpy 1D array], reference audio signal 189 | fs: integer, sampling rate 190 | """ 191 | assert len(est) == len(ref) 192 | 193 | def stoi_avg(est, ref): 194 | return 
sum([stoi(r, e, fs) for e, r in zip(est, ref)]) / len(est) 195 | 196 | N = len(est) 197 | if N != len(ref): 198 | raise RuntimeError( 199 | "size do not match between est and ref: {:d} vs {:d}".format( 200 | N, len(ref))) 201 | stois = [] 202 | for order in permutations(range(N)): 203 | stois.append(stoi_avg(est, [ref[n] for n in order])) 204 | 205 | return max(stois) 206 | 207 | 208 | def eval_all(mix, est, ref, fs=8000, pesq=False): 209 | """ 210 | Arguments: 211 | mix: np.ndarray 212 | est: list[np.ndarray, np.ndarray] 213 | ref: list[np.ndarray, np.ndarray] 214 | Evaluate 215 | SISNR/SISNRi; 216 | SDR/SDRi; 217 | PESQ/STOI 218 | """ 219 | sisnr, sisnri = permute_si_snri(mix, est, ref, True) 220 | sdr, sdri = cal_SDRi(mix, est, ref) 221 | if pesq: 222 | enh_pesq = permute_pesq(est, ref, fs) 223 | enh_stoi = permute_stoi(est, ref, fs) 224 | return sisnr, sisnri, sdr, sdri, enh_pesq, enh_stoi 225 | else: 226 | return sisnr, sisnri, sdr, sdri 227 | 228 | if __name__ == '__main__': 229 | # np.random.seed(20) 230 | x = np.random.rand(32000) 231 | xlist = [np.random.rand(32000), np.random.rand(32000)] 232 | slist = [np.random.rand(32000), np.random.rand(32000)] 233 | mlist = [np.random.rand(32000), np.random.rand(32000)] 234 | # print(permute_si_snr(xlist, slist)) 235 | # print(permute_si_snri(x, xlist, slist)) 236 | # print(permute_si_snri(x, xlist, slist, False)) 237 | # rlist = reorder_list(slist, [0,1]) 238 | # sdr, sir, sar, popt = bss_eval_sources(np.array(slist), np.array(xlist)) 239 | # sdr, sdri = cal_SDRi(x, xlist, slist) 240 | # pp = permute_pesq(xlist, slist, fs=8000) 241 | # st = permute_stoi(xlist, xlist, fs=8000) 242 | sisnr, sisnri, sdr, sdri, enh_pesq, enh_stoi = eval_all(x, xlist, slist, 8000, pesq=True) 243 | 244 | # print(sdr) 245 | # print(cal_sdr(np.array(xlist[0]), np.array(slist[0]))) 246 | # print(cal_sdr(np.array(xlist[1]), np.array(slist[1]))) 247 | # print(cal_sdr(np.array(xlist[0]), np.array(slist[1]))) 248 | # print(cal_sdr(np.array(xlist[1]), 
np.array(slist[0]))) 249 | # print(cal_sdr(np.array(xlist), np.array(slist))) 250 | 251 | 252 | 253 | 254 | 255 | 256 | -------------------------------------------------------------------------------- /nnet/libs/trainer_unet_tse_steplr_clip.py: -------------------------------------------------------------------------------- 1 | # wujian@2018 2 | 3 | import os 4 | import sys 5 | import time 6 | 7 | # from itertools import permutations 8 | from collections import defaultdict 9 | 10 | import torch as th 11 | import torch.nn.functional as F 12 | # from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | from torch.optim.lr_scheduler import StepLR 14 | from torch.nn.utils import clip_grad_norm_ 15 | 16 | from .utils import get_logger 17 | # from torch.utils.tensorboard import SummaryWriter 18 | 19 | def load_obj(obj, device): 20 | """ 21 | Offload tensor object in obj to cuda device 22 | """ 23 | 24 | def cuda(obj): 25 | return obj.to(device) if isinstance(obj, th.Tensor) else obj 26 | 27 | if isinstance(obj, dict): 28 | return {key: load_obj(obj[key], device) for key in obj} 29 | elif isinstance(obj, list): 30 | return [load_obj(val, device) for val in obj] 31 | else: 32 | return cuda(obj) 33 | 34 | 35 | class SimpleTimer(object): 36 | """ 37 | A simple timer 38 | """ 39 | 40 | def __init__(self): 41 | self.reset() 42 | 43 | def reset(self): 44 | self.start = time.time() 45 | 46 | def elapsed(self): 47 | return (time.time() - self.start) / 60 48 | 49 | 50 | class ProgressReporter(object): 51 | """ 52 | A simple progress reporter 53 | """ 54 | 55 | def __init__(self, logger, period=100): 56 | self.period = period 57 | self.logger = logger 58 | self.loss = [] 59 | self.timer = SimpleTimer() 60 | def add(self, loss): 61 | self.loss.append(loss) 62 | N = len(self.loss) 63 | if not N % self.period: 64 | avg = sum(self.loss[-self.period:]) / self.period 65 | self.logger.info("Processed {:d} batches" 66 | "(loss = {:+.2f})...".format(N, avg)) 67 | # 
self.loss_writer.add_scalar('Loss/train', avg, N) 68 | def report(self, details=False): 69 | N = len(self.loss) 70 | if details: 71 | sstr = ",".join(map(lambda f: "{:.2f}".format(f), self.loss)) 72 | self.logger.info("Loss on {:d} batches: {}".format(N, sstr)) 73 | return { 74 | "loss": sum(self.loss) / N, 75 | "batches": N, 76 | "cost": self.timer.elapsed() 77 | } 78 | 79 | class Trainer(object): 80 | def __init__(self, 81 | nnet, 82 | checkpoint="checkpoint", 83 | optimizer="adam", 84 | gpuid=0, 85 | optimizer_kwargs=None, 86 | clip_norm=1.0, 87 | min_lr=0, 88 | patience=0, 89 | factor=0.5, 90 | logging_period=100, 91 | resume=None, 92 | no_impr=150): 93 | if not th.cuda.is_available(): 94 | raise RuntimeError("CUDA device unavailable...exit") 95 | if not isinstance(gpuid, tuple): 96 | gpuid = (gpuid, ) 97 | self.device = th.device("cuda:{}".format(gpuid[0])) 98 | self.gpuid = gpuid 99 | if checkpoint and not os.path.exists(checkpoint): 100 | os.makedirs(checkpoint) 101 | self.checkpoint = checkpoint 102 | self.logger = get_logger( 103 | os.path.join(checkpoint, "trainer.log"), file=True) 104 | 105 | self.clip_norm = clip_norm 106 | self.logging_period = logging_period 107 | self.cur_epoch = 0 # zero based 108 | self.no_impr = no_impr 109 | 110 | if resume: 111 | if not os.path.exists(resume): 112 | raise FileNotFoundError( 113 | "Could not find resume checkpoint: {}".format(resume)) 114 | cpt = th.load(resume, map_location="cpu") 115 | self.cur_epoch = cpt["epoch"] 116 | self.logger.info("Resume from checkpoint {}: epoch {:d}".format( 117 | resume, self.cur_epoch)) 118 | # load nnet 119 | nnet.load_state_dict(cpt["model_state_dict"]) 120 | self.nnet = nnet.to(self.device) 121 | self.optimizer = self.create_optimizer( 122 | optimizer, optimizer_kwargs, state=cpt["optim_state_dict"]) 123 | else: 124 | self.nnet = nnet.to(self.device) 125 | self.optimizer = self.create_optimizer(optimizer, optimizer_kwargs) 126 | # self.scheduler = ReduceLROnPlateau( 127 | # 
self.optimizer, 128 | # mode="min", 129 | # factor=factor, 130 | # patience=patience, 131 | # min_lr=min_lr, 132 | # verbose=True) 133 | self.scheduler1 = StepLR(self.optimizer, step_size=2, gamma=0.98) 134 | self.scheduler2 = StepLR(self.optimizer, step_size=1, gamma=0.9) 135 | 136 | self.num_params = sum( 137 | [param.nelement() for param in nnet.parameters()]) / 10.0**6 138 | 139 | # logging 140 | self.logger.info("Model summary:\n{}".format(nnet)) 141 | self.logger.info("Loading model to GPUs:{}, #param: {:.2f}M".format( 142 | gpuid, self.num_params)) 143 | if clip_norm > 0: 144 | self.logger.info( 145 | "Gradient clipping by {}, default L2".format(clip_norm)) 146 | 147 | def save_checkpoint(self, best=True): 148 | cpt = { 149 | "epoch": self.cur_epoch, 150 | "model_state_dict": self.nnet.state_dict(), 151 | "optim_state_dict": self.optimizer.state_dict() 152 | } 153 | th.save( 154 | cpt, 155 | os.path.join(self.checkpoint, 156 | "{0}.pt.tar".format("best" if best else "last"))) 157 | 158 | def save_every_checkpoint(self, idx): 159 | cpt = { 160 | "epoch": self.cur_epoch, 161 | "model_state_dict": self.nnet.state_dict(), 162 | "optim_state_dict": self.optimizer.state_dict() 163 | } 164 | th.save(cpt, os.path.join(self.checkpoint, 165 | "{0}.pt.tar".format(str(idx)))) 166 | 167 | def create_optimizer(self, optimizer, kwargs, state=None): 168 | supported_optimizer = { 169 | "sgd": th.optim.SGD, # momentum, weight_decay, lr 170 | "rmsprop": th.optim.RMSprop, # momentum, weight_decay, lr 171 | "adam": th.optim.Adam, # weight_decay, lr 172 | "adadelta": th.optim.Adadelta, # weight_decay, lr 173 | "adagrad": th.optim.Adagrad, # lr, lr_decay, weight_decay 174 | "adamax": th.optim.Adamax # lr, weight_decay 175 | # ... 
176 | } 177 | if optimizer not in supported_optimizer: 178 | raise ValueError("Now only support optimizer {}".format(optimizer)) 179 | opt = supported_optimizer[optimizer](self.nnet.parameters(), **kwargs) 180 | self.logger.info("Create optimizer {0}: {1}".format(optimizer, kwargs)) 181 | if state is not None: 182 | opt.load_state_dict(state) 183 | self.logger.info("Load optimizer state dict from checkpoint") 184 | return opt 185 | 186 | def compute_loss(self, egs): 187 | raise NotImplementedError 188 | 189 | def train(self, data_loader): 190 | self.logger.info("Set train mode...") 191 | self.nnet.train() 192 | reporter = ProgressReporter(self.logger, period=self.logging_period) 193 | 194 | for egs in data_loader: 195 | # load to gpu 196 | egs = load_obj(egs, self.device) 197 | 198 | self.optimizer.zero_grad() 199 | loss = self.compute_loss(egs) 200 | loss.backward() 201 | 202 | if self.clip_norm > 0: 203 | clip_grad_norm_(self.nnet.parameters(), self.clip_norm) 204 | self.optimizer.step() 205 | 206 | reporter.add(loss.item()) 207 | return reporter.report() 208 | 209 | def eval(self, data_loader): 210 | self.logger.info("Set eval mode...") 211 | self.nnet.eval() 212 | reporter = ProgressReporter(self.logger, period=self.logging_period) 213 | 214 | with th.no_grad(): 215 | for egs in data_loader: 216 | egs = load_obj(egs, self.device) 217 | loss = self.compute_loss(egs) 218 | reporter.add(loss.item()) 219 | return reporter.report(details=True) 220 | 221 | def run(self, train_loader, dev_loader, num_epochs=120): 222 | # avoid alloc memory from gpu0 223 | reporter = ProgressReporter(self.logger, period=self.logging_period) 224 | with th.cuda.device(self.gpuid[0]): 225 | stats = dict() 226 | # check if save is OK 227 | self.save_checkpoint(best=False) 228 | cv = self.eval(dev_loader) 229 | best_loss = cv["loss"] 230 | self.logger.info("START FROM EPOCH {:d}, LOSS = {:.4f}".format( 231 | self.cur_epoch, best_loss)) 232 | no_impr = 0 233 | # make sure not inf 234 | # 
self.scheduler.best = best_loss 235 | while self.cur_epoch < num_epochs: 236 | self.cur_epoch += 1 237 | cur_lr = self.optimizer.param_groups[0]["lr"] 238 | stats[ 239 | "title"] = "Loss(time/N, lr={:.3e}) - Epoch {:2d}:".format( 240 | cur_lr, self.cur_epoch) 241 | tr = self.train(train_loader) 242 | stats["tr"] = "train = {:+.4f}({:.2f}m/{:d})".format( 243 | tr["loss"], tr["cost"], tr["batches"]) 244 | cv = self.eval(dev_loader) 245 | stats["cv"] = "dev = {:+.4f}({:.2f}m/{:d})".format( 246 | cv["loss"], cv["cost"], cv["batches"]) 247 | stats["scheduler"] = "" 248 | if cv["loss"] > best_loss: 249 | no_impr += 1 250 | stats["scheduler"] = "| no impr, best = {:.4f}".format( 251 | best_loss) 252 | else: 253 | best_loss = cv["loss"] 254 | no_impr = 0 255 | self.save_checkpoint(best=True) 256 | if self.cur_epoch == 90 or self.cur_epoch >= 100: 257 | self.save_every_checkpoint(self.cur_epoch) 258 | self.logger.info( 259 | "{title} {tr} | {cv} {scheduler}".format(**stats)) 260 | # schedule here 261 | # self.scheduler.step(cv["loss"]) 262 | if self.cur_epoch <= 100: 263 | self.scheduler1.step() 264 | else: 265 | self.scheduler2.step() 266 | # flush scheduler info 267 | sys.stdout.flush() 268 | # save last checkpoint 269 | self.save_checkpoint(best=False) 270 | 271 | if no_impr == self.no_impr: 272 | self.logger.info( 273 | "Stop training cause no impr for {:d} epochs".format( 274 | no_impr)) 275 | break 276 | 277 | 278 | self.logger.info("Training for {:d}/{:d} epochs done!".format( 279 | self.cur_epoch, num_epochs)) 280 | # reporter.loss_writer.close() 281 | 282 | class SiSnrTrainer(Trainer): 283 | def __init__(self, *args, **kwargs): 284 | super(SiSnrTrainer, self).__init__(*args, **kwargs) 285 | 286 | def sisnr(self, x, s, eps=1e-8): 287 | """ 288 | Arguments: 289 | x: separated signal, N x S tensor 290 | s: reference signal, N x S tensor 291 | Return: 292 | sisnr: N tensor 293 | """ 294 | 295 | def l2norm(mat, keepdim=False): 296 | return th.norm(mat, dim=-1, 
keepdim=keepdim) 297 | 298 | if x.shape != s.shape: 299 | raise RuntimeError( 300 | "Dimension mismatch when calculating si-snr, {} vs {}".format( 301 | x.shape, s.shape)) 302 | x_zm = x - th.mean(x, dim=-1, keepdim=True) 303 | s_zm = s - th.mean(s, dim=-1, keepdim=True) 304 | t = th.sum( 305 | x_zm * s_zm, dim=-1, 306 | keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps) 307 | return 20 * th.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps)) 308 | 309 | def mask_by_length(self, xs, lengths, fill=0): 310 | """ 311 | Mask tensor according to length 312 | """ 313 | assert xs.size(0) == len(lengths) 314 | ret = xs.data.new(*xs.size()).fill_(fill) 315 | for i, l in enumerate(lengths): 316 | ret[i, :l] = xs[i, :l] 317 | return ret 318 | 319 | def compute_loss(self, egs): 320 | N = egs["mix"].size(0) 321 | 322 | # spks x n x S 323 | nnet_load = th.nn.DataParallel(self.nnet, device_ids=self.gpuid) 324 | ests = nnet_load(egs["mix"], egs["aux"]) 325 | 326 | refs = egs['ref'] 327 | # N = egs["mix"].size(0) 328 | valid_len = egs["valid_len"] 329 | ests = self.mask_by_length(ests, valid_len) 330 | refs = self.mask_by_length(refs, valid_len) 331 | sisnr_loss = -th.sum(self.sisnr(ests, refs)) / N 332 | 333 | return sisnr_loss 334 | -------------------------------------------------------------------------------- /nnet/libs/utils.py: -------------------------------------------------------------------------------- 1 | # wujian@2018 2 | 3 | import os 4 | import json 5 | import logging 6 | 7 | 8 | def get_logger( 9 | name, 10 | format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s", 11 | date_format="%Y-%m-%d %H:%M:%S", 12 | file=False): 13 | """ 14 | Get python logger instance 15 | """ 16 | logger = logging.getLogger(name) 17 | logger.setLevel(logging.INFO) 18 | # file or console 19 | handler = logging.StreamHandler() if not file else logging.FileHandler( 20 | name) 21 | handler.setLevel(logging.INFO) 22 | formatter = logging.Formatter(fmt=format_str, 
datefmt=date_format) 23 | handler.setFormatter(formatter) 24 | logger.addHandler(handler) 25 | return logger 26 | 27 | 28 | def dump_json(obj, fdir, name): 29 | """ 30 | Dump python object in json 31 | """ 32 | if fdir and not os.path.exists(fdir): 33 | os.makedirs(fdir) 34 | with open(os.path.join(fdir, name), "w") as f: 35 | json.dump(obj, f, indent=4, sort_keys=False) 36 | 37 | 38 | def load_json(fdir, name): 39 | """ 40 | Load json as python object 41 | """ 42 | path = os.path.join(fdir, name) 43 | if not os.path.exists(path): 44 | raise FileNotFoundError("Could not find json file: {}".format(path)) 45 | with open(path, "r") as f: 46 | obj = json.load(f) 47 | return obj -------------------------------------------------------------------------------- /nnet/memonger/__init__.py: -------------------------------------------------------------------------------- 1 | from .memonger import SublinearSequential -------------------------------------------------------------------------------- /nnet/memonger/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/memonger/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /nnet/memonger/__pycache__/checkpoint.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/memonger/__pycache__/checkpoint.cpython-39.pyc -------------------------------------------------------------------------------- /nnet/memonger/__pycache__/memonger.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/isHuangZiling/SEF-PNet/22fa1c36a762058122d1dd33ead4b42e751daa95/nnet/memonger/__pycache__/memonger.cpython-39.pyc -------------------------------------------------------------------------------- /nnet/memonger/checkpoint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | import torch 3 | import warnings 4 | 5 | 6 | def detach_variable(inputs): 7 | if isinstance(inputs, tuple): 8 | out = [] 9 | for inp in inputs: 10 | x = inp.detach() 11 | x.requires_grad = inp.requires_grad 12 | out.append(x) 13 | return tuple(out) 14 | else: 15 | raise RuntimeError( 16 | "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__) 17 | 18 | 19 | def check_backward_validity(inputs): 20 | if not any(inp.requires_grad for inp in inputs): 21 | warnings.warn("None of the inputs have requires_grad=True. Gradients will be None") 22 | 23 | 24 | # Global switch to toggle whether or not checkpointed passes stash and restore 25 | # the RNG state. If True, any checkpoints making use of RNG should achieve deterministic 26 | # output compared to non-checkpointed passes. 27 | preserve_rng_state = True 28 | 29 | 30 | class CheckpointFunction(torch.autograd.Function): 31 | 32 | @staticmethod 33 | def forward(ctx, run_function, *args): 34 | check_backward_validity(args) 35 | ctx.run_function = run_function 36 | if preserve_rng_state: 37 | # We can't know if the user will transfer some args from the host 38 | # to the device during their run_fn. Therefore, we stash both 39 | # the cpu and cuda rng states unconditionally. 40 | # 41 | # TODO: 42 | # We also can't know if the run_fn will internally move some args to a device 43 | # other than the current device, which would require logic to preserve 44 | # rng states for those devices as well. 
We could paranoically stash and restore 45 | # ALL the rng states for all visible devices, but that seems very wasteful for 46 | # most cases. 47 | ctx.fwd_cpu_rng_state = torch.get_rng_state() 48 | # Don't eagerly initialize the cuda context by accident. 49 | # (If the user intends that the context is initialized later, within their 50 | # run_function, we SHOULD actually stash the cuda state here. Unfortunately, 51 | # we have no way to anticipate this will happen before we run the function.) 52 | ctx.had_cuda_in_fwd = False 53 | if torch.cuda._initialized: 54 | ctx.had_cuda_in_fwd = True 55 | ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() 56 | ctx.save_for_backward(*args) 57 | with torch.no_grad(): 58 | outputs = run_function(*args) 59 | return outputs 60 | 61 | @staticmethod 62 | def backward(ctx, *args): 63 | if not torch.autograd._is_checkpoint_valid(): 64 | raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible") 65 | inputs = ctx.saved_tensors 66 | # Stash the surrounding rng state, and mimic the state that was 67 | # present at this time during forward. Restore the surrounding state 68 | # when we're done. 69 | rng_devices = [torch.cuda.current_device()] if ctx.had_cuda_in_fwd else [] 70 | with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state): 71 | if preserve_rng_state: 72 | torch.set_rng_state(ctx.fwd_cpu_rng_state) 73 | if ctx.had_cuda_in_fwd: 74 | torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state) 75 | detached_inputs = detach_variable(inputs) 76 | with torch.enable_grad(): 77 | outputs = ctx.run_function(*detached_inputs) 78 | 79 | if isinstance(outputs, torch.Tensor): 80 | outputs = (outputs,) 81 | torch.autograd.backward(outputs, args) 82 | return (None,) + tuple(inp.grad for inp in detached_inputs) 83 | 84 | 85 | def checkpoint(function, *args): 86 | r"""Checkpoint a model or part of the model 87 | 88 | Checkpointing works by trading compute for memory. 
Rather than storing all 89 | intermediate activations of the entire computation graph for computing 90 | backward, the checkpointed part does **not** save intermediate activations, 91 | and instead recomputes them in backward pass. It can be applied on any part 92 | of a model. 93 | 94 | Specifically, in the forward pass, :attr:`function` will run in 95 | :func:`torch.no_grad` manner, i.e., not storing the intermediate 96 | activations. Instead, the forward pass saves the inputs tuple and the 97 | :attr:`function` parameter. In the backwards pass, the saved inputs and 98 | :attr:`function` are retrieved, and the forward pass is computed on 99 | :attr:`function` again, now tracking the intermediate activations, and then 100 | the gradients are calculated using these activation values. 101 | 102 | .. warning:: 103 | Checkpointing doesn't work with :func:`torch.autograd.grad`, but only 104 | with :func:`torch.autograd.backward`. 105 | 106 | .. warning:: 107 | If :attr:`function` invocation during backward does anything different 108 | than the one during forward, e.g., due to some global variable, the 109 | checkpointed version won't be equivalent, and unfortunately it can't be 110 | detected. 111 | 112 | .. warning: 113 | At least one of the inputs needs to have :code:`requires_grad=True` if 114 | grads are needed for model inputs, otherwise the checkpointed part of the 115 | model won't have gradients. 116 | 117 | Args: 118 | function: describes what to run in the forward pass of the model or 119 | part of the model. It should also know how to handle the inputs 120 | passed as the tuple. 
For example, in LSTM, if user passes 121 | ``(activation, hidden)``, :attr:`function` should correctly use the 122 | first input as ``activation`` and the second input as ``hidden`` 123 | args: tuple containing inputs to the :attr:`function` 124 | 125 | Returns: 126 | Output of running :attr:`function` on :attr:`*args` 127 | """ 128 | return CheckpointFunction.apply(function, *args) 129 | 130 | 131 | def checkpoint_sequential(functions, segments, *inputs): 132 | r"""A helper function for checkpointing sequential models. 133 | 134 | Sequential models execute a list of modules/functions in order 135 | (sequentially). Therefore, we can divide such a model in various segments 136 | and checkpoint each segment. All segments except the last will run in 137 | :func:`torch.no_grad` manner, i.e., not storing the intermediate 138 | activations. The inputs of each checkpointed segment will be saved for 139 | re-running the segment in the backward pass. 140 | 141 | See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works. 142 | 143 | .. warning:: 144 | Checkpointing doesn't work with :func:`torch.autograd.grad`, but only 145 | with :func:`torch.autograd.backward`. 146 | 147 | .. warning: 148 | At least one of the inputs needs to have :code:`requires_grad=True` if 149 | grads are needed for model inputs, otherwise the checkpointed part of the 150 | model won't have gradients. 151 | 152 | Args: 153 | functions: A :class:`torch.nn.Sequential` or the list of modules or 154 | functions (comprising the model) to run sequentially. 155 | segments: Number of chunks to create in the model 156 | inputs: tuple of Tensors that are inputs to :attr:`functions` 157 | 158 | Returns: 159 | Output of running :attr:`functions` sequentially on :attr:`*inputs` 160 | 161 | Example: 162 | >>> model = nn.Sequential(...) 
163 | >>> input_var = checkpoint_sequential(model, chunks, input_var) 164 | """ 165 | 166 | def run_function(start, end, functions): 167 | def forward(*inputs): 168 | for j in range(start, end + 1): 169 | if isinstance(inputs, tuple): 170 | inputs = functions[j](*inputs) 171 | else: 172 | inputs = functions[j](inputs) 173 | return inputs 174 | return forward 175 | 176 | if isinstance(functions, torch.nn.Sequential): 177 | functions = list(functions.children()) 178 | 179 | segment_size = len(functions) // segments 180 | # the last chunk has to be non-volatile 181 | end = -1 182 | for start in range(0, segment_size * (segments - 1), segment_size): 183 | end = start + segment_size - 1 184 | inputs = checkpoint(run_function(start, end, functions), *inputs) 185 | if not isinstance(inputs, tuple): 186 | inputs = (inputs,) 187 | return run_function(end + 1, len(functions) - 1, functions)(*inputs) 188 | -------------------------------------------------------------------------------- /nnet/memonger/memonger.py: -------------------------------------------------------------------------------- 1 | from math import sqrt, log 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.modules.batchnorm import _BatchNorm 7 | 8 | from .checkpoint import checkpoint 9 | 10 | 11 | def reforwad_momentum_fix(origin_momentum): 12 | return (1 - sqrt(1 - origin_momentum)) 13 | 14 | 15 | class SublinearSequential(nn.Sequential): 16 | def __init__(self, *args): 17 | super(SublinearSequential, self).__init__(*args) 18 | self.reforward = False 19 | self.momentum_dict = {} 20 | self.set_reforward(True) 21 | 22 | def set_reforward(self, enabled=True): 23 | if not self.reforward and enabled: 24 | print("Rescale BN Momentum for re-forwarding purpose") 25 | for n, m in self.named_modules(): 26 | if isinstance(m, _BatchNorm): 27 | self.momentum_dict[n] = m.momentum 28 | m.momentum = reforwad_momentum_fix(self.momentum_dict[n]) 29 | if self.reforward and 
not enabled: 30 | print("Restore BN Momentum") 31 | for n, m in self.named_modules(): 32 | if isinstance(m, _BatchNorm): 33 | m.momentum = self.momentum_dict[n] 34 | self.reforward = enabled 35 | 36 | def forward(self, input): 37 | if self.reforward: 38 | return self.sublinear_forward(input) 39 | else: 40 | return self.normal_forward(input) 41 | 42 | def normal_forward(self, input): 43 | for module in self._modules.values(): 44 | input = module(input) 45 | return input 46 | 47 | def sublinear_forward(self, input): 48 | def run_function(start, end, functions): 49 | def forward(*inputs): 50 | input = inputs[0] 51 | for j in range(start, end + 1): 52 | input = functions[j](input) 53 | return input 54 | 55 | return forward 56 | 57 | functions = list(self.children()) 58 | segments = int(sqrt(len(functions))) 59 | segment_size = len(functions) // segments 60 | # the last chunk has to be non-volatile 61 | end = -1 62 | # always hand checkpoint() a tuple; a tuple input previously left `inputs` undefined 63 | inputs = input if isinstance(input, tuple) else (input,) 64 | for start in range(0, segment_size * (segments - 1), segment_size): 65 | end = start + segment_size - 1 66 | inputs = checkpoint(run_function(start, end, functions), *inputs) 67 | if not isinstance(inputs, tuple): 68 | inputs = (inputs,) 69 | # output = run_function(end + 1, len(functions) - 1, functions)(*inputs) 70 | output = checkpoint(run_function(end + 1, len(functions) - 1, functions), *inputs) 71 | return output 72 | -------------------------------------------------------------------------------- /nnet/memonger/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .memonger import SublinearSequential 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, in_planes, planes, stride=1): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 23 | self.bn2 = nn.BatchNorm2d(planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride != 1 or in_planes != self.expansion*planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 29 | nn.BatchNorm2d(self.expansion*planes) 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = self.bn2(self.conv2(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class Bottleneck(nn.Module): 41 | expansion = 4 42 | 43 | def __init__(self, in_planes, planes, stride=1): 44 | super(Bottleneck, self).__init__() 45 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 51 | 52 | self.shortcut = nn.Sequential() 53 | if stride != 1 or in_planes != self.expansion*planes: 54 | self.shortcut = nn.Sequential( 55 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 56 | nn.BatchNorm2d(self.expansion*planes) 57 | ) 58 | 59 | def forward(self, x): 60 | out = F.relu(self.bn1(self.conv1(x))) 61 | out = F.relu(self.bn2(self.conv2(out))) 62 | 
out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | out = F.relu(out) 65 | return out 66 | 67 | 68 | class ResNet(nn.Module): 69 | def __init__(self, block, num_blocks, num_classes=100): 70 | super(ResNet, self).__init__() 71 | self.in_planes = 64 72 | 73 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 74 | self.bn1 = nn.BatchNorm2d(64) 75 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 76 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 77 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 78 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 79 | self.linear = nn.Linear(512*block.expansion, num_classes) 80 | 81 | def _make_layer(self, block, planes, num_blocks, stride): 82 | strides = [stride] + [1]*(num_blocks-1) 83 | layers = [] 84 | for stride in strides: 85 | layers.append(block(self.in_planes, planes, stride)) 86 | self.in_planes = planes * block.expansion 87 | return SublinearSequential(*layers) 88 | 89 | def forward(self, x): 90 | out = F.relu(self.bn1(self.conv1(x))) 91 | out = self.layer1(out) 92 | out = self.layer2(out) 93 | out = self.layer3(out) 94 | out = self.layer4(out) 95 | out = F.avg_pool2d(out, 4) 96 | out = out.view(out.size(0), -1) 97 | out = self.linear(out) 98 | return out 99 | 100 | 101 | def ResNet18(): 102 | return ResNet(BasicBlock, [2,2,2,2]) 103 | 104 | def ResNet34(): 105 | return ResNet(BasicBlock, [3,4,6,3]) 106 | 107 | def ResNet50(): 108 | return ResNet(Bottleneck, [3,4,6,3]) 109 | 110 | def ResNet101(): 111 | return ResNet(Bottleneck, [3,4,23,3]) 112 | 113 | def ResNet152(): 114 | return ResNet(Bottleneck, [3,8,36,3]) 115 | 116 | 117 | def test(): 118 | net = ResNet18() 119 | y = net(torch.randn(1,3,32,32)) 120 | print(y.size()) 121 | 122 | # test() 123 | -------------------------------------------------------------------------------- /nnet/separate.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import time 5 | import argparse 6 | import torch as th 7 | import numpy as np 8 | from SEF_PNet_pse import SEF_PNet 9 | from libs.utils import load_json, get_logger 10 | from libs.audio import write_wav 11 | from libs.dataset_tse import Dataset 12 | 13 | def run(args): 14 | start = time.time() 15 | logger = get_logger( 16 | os.path.join(args.checkpoint, 'separate.log'), file=True) 17 | dataset = Dataset(mix_scp=args.mix_scp, ref_scp=args.ref_scp, aux_scp=args.aux_scp, sample_rate=args.fs) 18 | 19 | # Load model 20 | nnet_conf = load_json(args.checkpoint, "mdl.json") 21 | nnet = SEF_PNet(**nnet_conf) 22 | cpt_fname = os.path.join(args.checkpoint, "best.pt.tar") 23 | cpt = th.load(cpt_fname, map_location="cpu") 24 | nnet.load_state_dict(cpt["model_state_dict"]) 25 | logger.info("Load checkpoint from {}, epoch {:d}".format( 26 | cpt_fname, cpt["epoch"])) 27 | 28 | device = th.device( 29 | "cuda:{}".format(args.gpuid)) if args.gpuid >= 0 else th.device("cpu") 30 | nnet = nnet.to(device) if args.gpuid >= 0 else nnet 31 | nnet.eval() 32 | 33 | with th.no_grad(): 34 | total_cnt = 0 35 | for i, data in enumerate(dataset): 36 | mix = th.tensor(data['mix'], dtype=th.float32, device=device) 37 | aux = th.tensor(data['aux'], dtype=th.float32, device=device) 38 | key = data['key'] 39 | if args.gpuid >= 0: 40 | mix = mix.unsqueeze(0).to(device) 41 | aux = aux.unsqueeze(0).to(device) 42 | 43 | # Forward 44 | ests = nnet(mix, aux) 45 | ests = ests.cpu().numpy() 46 | norm = np.linalg.norm(mix.cpu().numpy(), np.inf) 47 | ests = ests[:mix.shape[-1]] 48 | # for each utts 49 | logger.info("Separate Utt{:d}".format(total_cnt + 1)) 50 | # norm 51 | ests = ests*norm/np.max(np.abs(ests)) 52 | 53 | fname = key + '.wav' 54 | write_wav(os.path.join(args.dump_dir, fname), 55 | ests, fs=args.fs) 56 | total_cnt += 1 57 | 58 | end = time.time() 59 | 
logger.info('Utt={:d} | Time Elapsed: {:.1f}s'.format(total_cnt, end-start)) 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser('Separating speech...') 63 | parser.add_argument("--checkpoint", type=str, required=True, 64 | help="Directory of checkpoint") 65 | parser.add_argument("--gpuid", type=int, default=-1, 66 | help="GPU device to offload model to, -1 means running on CPU") 67 | parser.add_argument('--mix_scp', type=str, required=True, 68 | help='mix scp') 69 | parser.add_argument('--ref_scp', type=str, required=True, 70 | help='ref scp') 71 | parser.add_argument('--aux_scp', type=str, required=True, 72 | help='aux scp') 73 | parser.add_argument('--fs', type=int, default=8000, 74 | help="Sample rate for mixture input") 75 | parser.add_argument('--dump-dir', type=str, default="/node/hzl/expriment/SEF_PNet_icassp2025_github/results", 76 | help="Directory to dump separated results out") 77 | args = parser.parse_args() 78 | run(args) 79 | -------------------------------------------------------------------------------- /nnet/train_unet_tse_steplr_clip.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import pprint 4 | import argparse 5 | from libs.trainer_unet_tse_steplr_clip import SiSnrTrainer 6 | from libs.dataset_tse import make_dataloader 7 | from libs.utils import dump_json, get_logger 8 | from SEF_PNet_pse import SEF_PNet 9 | from conf_unet_tse_32ms import trainer_conf, nnet_conf, train_data, dev_data, chunk_size 10 | 11 | logger = get_logger(__name__) 12 | 13 | def run(args): 14 | gpuids = tuple(map(int, args.gpus.split(","))) 15 | nnet = SEF_PNet(**nnet_conf) 16 | trainer = SiSnrTrainer(nnet, 17 | gpuid=gpuids, 18 | checkpoint=args.checkpoint, 19 | resume=args.resume, 20 | **trainer_conf) 21 | 22 | data_conf = { 23 | "train": train_data, 24 | "dev": dev_data, 25 | "chunk_size": chunk_size 26 | } 27 | 28 | for conf, fname in zip([nnet_conf, trainer_conf, 
data_conf], 29 | ["mdl.json", "trainer.json", "data.json"]): 30 | dump_json(conf, args.checkpoint, fname) 31 | 32 | train_loader = make_dataloader(train=True, 33 | data_kwargs=train_data, 34 | batch_size=args.batch_size, 35 | chunk_size=chunk_size, 36 | num_workers=args.num_workers) 37 | dev_loader = make_dataloader(train=False, 38 | data_kwargs=dev_data, 39 | batch_size=args.batch_size, 40 | chunk_size=chunk_size, 41 | num_workers=args.num_workers) 42 | trainer.run(train_loader, dev_loader, num_epochs=args.epochs) 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser( 46 | description= 47 | "Command to start SEF-PNet training, configured from conf_unet_tse_32ms.py", 48 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 49 | parser.add_argument("--gpus", 50 | type=str, 51 | default="0", 52 | help="Training on which GPUs " 53 | "(one or more, egs: 0, \"0,1\")") 54 | parser.add_argument("--epochs", 55 | type=int, 56 | default=200, 57 | # default=500, 58 | help="Number of training epochs") 59 | parser.add_argument("--checkpoint", 60 | type=str, 61 | default='/node/hzl/expriment/SEF_PNet_icassp2025_github/demo', 62 | #required=True, 63 | help="Directory to dump models") 64 | parser.add_argument("--resume", 65 | type=str, 66 | default=None, 67 | help="Existing model to resume training from") 68 | parser.add_argument("--batch-size", 69 | type=int, 70 | default=32, 71 | help="Number of utterances in each batch") 72 | parser.add_argument("--num-workers", 73 | type=int, 74 | default=32, 75 | help="Number of workers used in data loader") 76 | args = parser.parse_args() 77 | logger.info("Arguments in command:\n{}".format(pprint.pformat(vars(args)))) 78 | 79 | run(args) 80 | print("train Done!") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.3.0 2 | numpy==1.22.4 3 | mir_eval==0.7 4 | pesq==0.0.4 5 | pypesq @ 
https://github.com/vBaiCai/python-pesq/archive/master.zip#sha256=fba27c3d95e8f72fed7c55f675ce6057a64b26a1a67a2e469df2804cca69b8cc 6 | pystoi==0.3.3 7 | soundfile==0.12.1 8 | librosa==0.10.1 -------------------------------------------------------------------------------- /separate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eu 3 | checkpoint=/node/hzl/expriment/libri2mix_min_wav8k/SEF_PNet 4 | gpuid=0 5 | data_root=/node/hzl/data/data_libri2mix_s1_min_wav8k/test 6 | 7 | mix_scp=$data_root/mix_clean.scp 8 | ref_scp=$data_root/s1.scp 9 | aux_scp=$data_root/auxs1.scp 10 | 11 | fs=8000 12 | dump_dir=/node/hzl/data/enhanced_speech 13 | 14 | ./nnet/separate.py \ 15 | --checkpoint $checkpoint \ 16 | --gpuid $gpuid \ 17 | --mix_scp $mix_scp \ 18 | --ref_scp $ref_scp \ 19 | --aux_scp $aux_scp \ 20 | --fs $fs \ 21 | --dump-dir $dump_dir \ 22 | > separate.log 2>&1 23 | 24 | echo "Separate done!" 25 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eu 3 | epochs=200 4 | # constrained by GPU number & memory 5 | batch_size=32 6 | gpuid=0 7 | num_workers=32 8 | cpt_dir=/node/hzl/expriment/SEF_PNet_icassp2025_github/demo 9 | #resume= 10 | #[ $# -ne 1 ] && echo "Script error: $0 " && exit 1 11 | ./nnet/train_unet_tse_steplr_clip.py \ 12 | --gpus $gpuid \ 13 | --epochs $epochs \ 14 | --batch-size $batch_size \ 15 | --num-workers $num_workers \ 16 | --checkpoint $cpt_dir \ 17 | > train.log 2>&1 18 | --------------------------------------------------------------------------------
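The trainer listed earlier steps two `StepLR` schedulers on one optimizer: `scheduler1` (`step_size=2`, `gamma=0.98`) while `cur_epoch <= 100`, and `scheduler2` (`step_size=1`, `gamma=0.9`) afterwards. With the 200 epochs set in `train.sh`, the resulting learning rate can be sketched in plain Python. This is a rough illustration, not code from the repository: `lr_after_epoch` and `base_lr` are hypothetical names, the starting rate of `1e-3` is an assumed placeholder (the real value is set through `conf_unet_tse_32ms.py`, which is not shown here), and the sketch assumes a fresh run without `--resume` and that `StepLR` multiplies the optimizer's current rate by `gamma` on every `step_size`-th call to `step()`.

```python
def lr_after_epoch(epoch, base_lr=1e-3):
    """Approximate learning rate once `epoch` epochs have finished.

    Mirrors the branch in Trainer.run: scheduler1 is stepped after
    epochs 1..100, scheduler2 after every later epoch.
    base_lr is an assumed placeholder, not the repository's setting.
    """
    phase1_steps = min(epoch, 100)      # calls made to scheduler1.step()
    phase2_steps = max(epoch - 100, 0)  # calls made to scheduler2.step()
    decays1 = phase1_steps // 2         # step_size=2: decay on every 2nd call
    decays2 = phase2_steps              # step_size=1: decay on every call
    return base_lr * (0.98 ** decays1) * (0.9 ** decays2)


if __name__ == "__main__":
    for epoch in (0, 2, 100, 101, 120, 200):
        print("after epoch {:3d}: lr = {:.3e}".format(
            epoch, lr_after_epoch(epoch)))
```

Under these assumptions the rate decays slowly (2% every two epochs) for the first 100 epochs and then by 10% per epoch. The `lr` printed in the log line for epoch `e` is read before that epoch trains, so it corresponds to `lr_after_epoch(e - 1)`.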