├── DPARNet
│   ├── Beam-Guided-DPARNet.py
│   ├── DPARNet.py
│   ├── conf.yml
│   ├── dataset.py
│   ├── dataset_css.py
│   ├── eval.py
│   ├── feature.py
│   ├── mvdr_util.py
│   ├── pesq_stoi.py
│   ├── run.sh
│   ├── system.py
│   ├── train.py
│   └── utils
│       ├── parse_options.sh
│       └── prepare_python_env.sh
├── README.md
├── generate_rir.py
└── sms_wsj_replace
    ├── create_rirs.py
    └── scenario.py

/DPARNet/Beam-Guided-DPARNet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # yangyi 2022.06
4 | 
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from fast_transformers.masking import FullMask, LengthMask, TriangularCausalMask
9 | from fast_transformers.attention import SharedLinearAttention, LinearAttention
10 | 
11 | from asteroid.engine.optimizers import make_optimizer
12 | from asteroid.losses import PITLossWrapper, pairwise_neg_snr
13 | 
14 | from feature import FeatureExtractor, AngleFeature
15 | from mvdr_util import MVDR
16 | 
17 | def make_model_and_optimizer(conf):
18 |     model = DPARNet1_DPARNet2()
19 |     print(model)
20 |     optimizer = make_optimizer(model.parameters(), **conf['optim'])
21 |     return model, optimizer
22 | 
23 | 
24 | class DPARNet1_DPARNet2(nn.Module):
25 |     def __init__(self):
26 |         super(DPARNet1_DPARNet2, self).__init__()
27 | 
28 |         self.mvdr = MVDR(False)
29 | 
30 |         self.DPARNet1 = DPARNet1(
31 |             use_dense = False,
32 |             use_att = [True, True, True, True, False],
33 |             use_rnn = [True, True, False, False, False],
34 |             width = 64,
35 |             num_layers = 4,
36 |             dropout_rate = 0.4,
37 |             causal_conf = False,
38 |         )
39 | 
40 |         self.DPARNet2 = DPARNet2(
41 |             use_dense = False,
42 |             use_att = [True, True, True, True, False],
43 |             use_rnn = [True, False, False, False, False],
44 |             width = 64,
45 |             num_layers = 4,
46 |             dropout_rate = 0.4,
47 |             causal_conf = False,
48 |         )
49 | 
50 |     def forward(self, mixture, do_eval = 0): # mixture [b m n]
51 |         assert do_eval == 0 or do_eval == 1, "Eval type should be 0 (training) or 1 (mvdr) !"
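        # Overall flow, as read from the code below (b=batch, s=speakers, m=mics, n=samples):
        #   1) DPARNet1 estimates per-channel magnitude masks for each speaker;
        #   2) MVDR beamforming is applied to those first-pass estimates;
        #   3) DPARNet2 refines the separation from the mixture plus the beamformed
        #      signals, and with do_eval=1 the MVDR -> DPARNet2 step is iterated.
        # Rough usage sketch (variable names here are illustrative, not from this repo):
        #   model = DPARNet1_DPARNet2()
        #   out1, out2_1, out2_2, pha = model(mix)          # training, mix: [b, m, n]
        #   ref_est, multich_est = model(mix, do_eval=1)    # evaluation path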
52 | 53 | out_DPARNet1, pha_src = self.DPARNet1(mixture) # [b s m n] # [b m f t] 54 | bf_DPARNet1 = self.mvdr(mixture, out_DPARNet1) # [b s m n] 55 | out_DPARNet2_iter1 = self.DPARNet2(mixture, bf_DPARNet1) # [b s m n] 56 | 57 | bf_DPARNet2_iter1 = self.mvdr(mixture, out_DPARNet2_iter1) # [b s m n] 58 | out_DPARNet2_iter2 = self.DPARNet2(mixture, bf_DPARNet2_iter1) # [b s m n] 59 | 60 | if do_eval == 0: 61 | return out_DPARNet1, out_DPARNet2_iter1, out_DPARNet2_iter2, pha_src 62 | 63 | else: 64 | bf_DPARNet2_iter1 = self.mvdr(mixture, out_DPARNet2_iter1) # [b s m n] 65 | 66 | out_DPARNet2_iter2 = self.DPARNet2(mixture, bf_DPARNet2_iter1) # [b s m n] 67 | bf_DPARNet2_iter2 = self.mvdr(mixture, out_DPARNet2_iter2) # [b s m n] 68 | 69 | out_DPARNet2_iter3 = self.DPARNet2(mixture, bf_DPARNet2_iter2) # [b s m n] 70 | bf_DPARNet2_iter3 = self.mvdr(mixture, out_DPARNet2_iter3) # [b s m n] 71 | 72 | out_DPARNet2_iter4 = self.DPARNet2(mixture, bf_DPARNet2_iter3) # [b s m n] 73 | 74 | return out_DPARNet2_iter1[:,:,0], out_DPARNet2_iter1 75 | 76 | 77 | def half_reshape(x, inverse): 78 | if(inverse): 79 | x = torch.cat([x[:,0:x.shape[1]//2,],x[:,x.shape[1]//2:x.shape[1]//2*2,]],-1) 80 | x = x[...,:-1] 81 | return x 82 | else: 83 | x = F.pad(x, (0,1,0,0), 'replicate') 84 | x = torch.cat([x[:,:,:,0:x.shape[-1]//2],x[:,:,:,x.shape[-1]//2:x.shape[-1]//2*2]],1) 85 | return x 86 | 87 | class DPARNet1(nn.Module): 88 | def __init__(self, 89 | use_dense, 90 | use_att, 91 | use_rnn, 92 | num_channels=7, 93 | num_spks=2, 94 | frame_len=512, 95 | frame_hop=128, 96 | width=64, 97 | num_layers=3, 98 | dropout_rate=0.4, 99 | causal_conf=False, 100 | ): 101 | super(DPARNet1, self).__init__() 102 | 103 | self.use_dense = use_dense 104 | self.use_att = use_att 105 | self.use_rnn = use_rnn 106 | 107 | self.num_channels = num_channels 108 | self.num_spks = num_spks 109 | self.frame_len = frame_len 110 | self.frame_hop = frame_hop 111 | self.width = width 112 | self.num_layers = num_layers 113 | self.dropout_rate = dropout_rate 114 | self.causal_conf = causal_conf 115 | self.num_bins = self.frame_len // 2 + 1 116 | 117 | self.extractor = FeatureExtractor(frame_len=self.frame_len, frame_hop=self.frame_hop, do_ipd=False) 118 | 119 | self.in_Conv = nn.Sequential( 120 | nn.Conv2d(in_channels=self.num_channels * 2 * 2, out_channels=self.width, kernel_size=(1, 1)), 121 | nn.LayerNorm(self.num_bins // 2 + 1), 122 | nn.PReLU(self.width), 123 | ) 124 | 125 | self.in_Conv_att = nn.Sequential(nn.Conv2d(self.width, self.width // 2, kernel_size=(1, 1)), nn.PReLU()) 126 | 127 | self.dualrnn_attention = nn.ModuleList() 128 | for i in range (self.num_layers): 129 | self.dualrnn_attention.append(DualRNN_Attention(dropout_rate=self.dropout_rate, d_model=self.width//2, use_att=self.use_att[i], use_rnn=self.use_rnn[i])) 130 | 131 | self.out_Conv_att = nn.Sequential(nn.Conv2d(self.width // 2, self.width, kernel_size=(1, 1)), nn.PReLU()) 132 | 133 | self.out_Conv1 = nn.Sequential(nn.Conv2d(in_channels=self.width, out_channels=self.width, kernel_size=(1, 1)), nn.Tanh(),) 134 | self.out_Conv2 = nn.Sequential(nn.Conv2d(in_channels=self.width, out_channels=self.width, kernel_size=(1, 1)), nn.Sigmoid(),) 135 | 136 | self.out_Conv = nn.ConvTranspose2d(in_channels=self.width//2, out_channels=self.num_spks * self.num_channels, kernel_size=(1, 1)) 137 | 138 | if (self.use_dense): 139 | self.in_DenseBlock = DenseBlock(init_ch=self.width, g1=8, g2=self.width) 140 | self.out_DenseBlock = DenseBlock(init_ch=self.width, g1=8, g2=self.width) 141 | else: 142 
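        # When use_dense is off, the dense blocks are set to None rather than omitted,
        # so the attributes always exist and forward() can simply guard on self.use_dense.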
| self.in_DenseBlock = None 143 | self.out_DenseBlock = None 144 | 145 | self.sigmoid = nn.Sigmoid() 146 | 147 | 148 | def forward(self, mixture): # mixture: [b m n] 149 | 150 | real_spec_src, imag_spec_src = self.extractor.stft(mixture, cplx=True) # [b m f t] 151 | com_spec_src = torch.cat((real_spec_src, imag_spec_src), 1).permute(0,1,3,2) # [b 2m t f=257] 152 | 153 | x = half_reshape(com_spec_src, False) # [b 4m t f=129] 154 | 155 | x = self.in_Conv(x) # [b w=64 t f] 156 | if (self.use_dense): 157 | x = self.in_DenseBlock(x) 158 | x = self.in_Conv_att(x) # [b w=32 t f] 159 | 160 | for i in range (self.num_layers): 161 | x = self.dualrnn_attention[i](x) # [b w=64 t f] 162 | 163 | x = self.out_Conv_att(x) # [b w=64 t f] 164 | 165 | x = self.out_Conv1(x) * self.out_Conv2(x) # [b w=64 t f] 166 | 167 | if (self.use_dense): 168 | x = self.out_DenseBlock(x) 169 | 170 | x = half_reshape(x, True) # [b w=32 t f=257] 171 | 172 | irm_est = self.sigmoid(self.out_Conv(x)).transpose(2,3) # [b sxm f t] 173 | irm_est = irm_est.chunk(self.num_spks, dim=1) # [b m f t] * s 174 | 175 | mag_src, pha_src = self.extractor.stft(mixture, cplx=False) # [b m f t] 176 | 177 | est_sig = [] 178 | for id_spk in range (self.num_spks): 179 | for id_chan in range (self.num_channels): 180 | mag_est = irm_est[id_spk][:,id_chan] * mag_src[:,id_chan] 181 | est_sig.append(self.extractor.istft(mag_est, pha_src[:,id_chan], cplx=False)) 182 | 183 | output = torch.stack(est_sig, 1) # [b mxs n] 184 | output = torch.stack(output.chunk(self.num_spks, dim=1), 1) # [b s m n] 185 | output = torch.nn.functional.pad(output,[0,mixture.shape[-1]-output.shape[-1]]) 186 | 187 | return output, pha_src ## [b s m n] [b m f t] 188 | 189 | 190 | class DPARNet2(nn.Module): 191 | def __init__(self, 192 | use_dense, 193 | use_att, 194 | use_rnn, 195 | num_channels=7, 196 | num_spks=2, 197 | frame_len=512, 198 | frame_hop=128, 199 | width=64, 200 | num_layers=3, 201 | dropout_rate=0.4, 202 | causal_conf=False, 203 | ): 204 | super(DPARNet2, self).__init__() 205 | 206 | self.use_dense = use_dense 207 | self.use_att = use_att 208 | self.use_rnn = use_rnn 209 | 210 | self.num_channels = num_channels 211 | self.num_spks = num_spks 212 | self.frame_len = frame_len 213 | self.frame_hop = frame_hop 214 | self.width = width 215 | self.num_layers = num_layers 216 | self.dropout_rate = dropout_rate 217 | self.causal_conf = causal_conf 218 | self.num_bins = self.frame_len // 2 + 1 219 | 220 | self.extractor = FeatureExtractor(frame_len=self.frame_len, frame_hop=self.frame_hop, do_ipd=False) 221 | 222 | self.in_Conv = nn.Sequential( 223 | nn.Conv2d(in_channels=self.num_channels * 2 * 2 * (self.num_spks+1), out_channels=self.width, kernel_size=(1, 1)), 224 | nn.LayerNorm(self.num_bins // 2 + 1), 225 | nn.PReLU(self.width), 226 | ) 227 | 228 | self.in_Conv_att = nn.Sequential(nn.Conv2d(self.width, self.width // 2, kernel_size=(1, 1)), nn.PReLU()) 229 | 230 | self.dualrnn_attention = nn.ModuleList() 231 | for i in range (self.num_layers): 232 | self.dualrnn_attention.append(DualRNN_Attention(dropout_rate=self.dropout_rate, d_model=self.width//2, use_att=self.use_att[i], use_rnn=self.use_rnn[i])) 233 | 234 | self.out_Conv1 = nn.Sequential(nn.Conv2d(in_channels=self.width, out_channels=self.width, kernel_size=(1, 1)), nn.Tanh(),) 235 | self.out_Conv2 = nn.Sequential(nn.Conv2d(in_channels=self.width, out_channels=self.width, kernel_size=(1, 1)), nn.Sigmoid(),) 236 | 237 | self.out_Conv_att = nn.Sequential(nn.Conv2d(self.width // 2, self.width, kernel_size=(1, 1)), 
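                                          # 1x1 projection back to full width: the DualRNN_Attention
                                          # stack runs at width/2 channels (cf. in_Conv_att), and
                                          # forward() applies this before the gated Tanh/Sigmoid
                                          # output convolutions.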
nn.PReLU()) 238 | 239 | self.out_Conv = nn.ConvTranspose2d(in_channels=self.width//2, out_channels=self.num_channels * self.num_spks * 2, kernel_size=(1, 1)) 240 | 241 | if (self.use_dense): 242 | self.in_DenseBlock = DenseBlock(init_ch=self.width, g1=8, g2=self.width) 243 | self.out_DenseBlock = DenseBlock(init_ch=self.width, g1=8, g2=self.width) 244 | else: 245 | self.in_DenseBlock = None 246 | self.out_DenseBlock = None 247 | 248 | 249 | def forward(self, mixture, bf_DPARnet1): # mixture: [b m n] bf_DPARnet1: [b s m n] 250 | 251 | B = mixture.size(0) 252 | real_spec_src, imag_spec_src = self.extractor.stft(mixture, cplx=True) # [b m f t] 253 | real_spec_dparnet1, imag_spec_dparnet1 = self.extractor.stft(bf_DPARnet1.view(B, self.num_spks*self.num_channels, -1), cplx=True) # [b sxm f t] 254 | com_spec_src = torch.cat((real_spec_src, imag_spec_src, real_spec_dparnet1, imag_spec_dparnet1), 1).permute(0,1,3,2) # [b 2xmx(s+1) t f=257] 255 | 256 | x = half_reshape(com_spec_src, False) # [b 4xmx(s+1) t f=129] 257 | 258 | x = self.in_Conv(x) # [b w=64 t f] 259 | if (self.use_dense): 260 | x = self.in_DenseBlock(x) 261 | 262 | x = self.in_Conv_att(x) # [b w=32 t f] 263 | 264 | for i in range (self.num_layers): 265 | x = self.dualrnn_attention[i](x) # [b w t f] 266 | 267 | x = self.out_Conv_att(x) # [b w=64 t f] 268 | 269 | x = self.out_Conv1(x) * self.out_Conv2(x) # [b w=64 t f] 270 | 271 | if (self.use_dense): 272 | x = self.out_DenseBlock(x) 273 | 274 | x = half_reshape(x, True) # [b w=32 t f=257] 275 | 276 | cmask_est = self.out_Conv(x).transpose(2,3) # [b 2xsxm f t] 277 | cmask_est = cmask_est.chunk(self.num_spks, dim=1) # [b 2m f t] * 2 278 | 279 | est_sig = [] 280 | for id_spk in range (self.num_spks): 281 | for id_chan in range (self.num_channels): 282 | rmask_est = cmask_est[id_spk][:,0+2*id_chan] # [b f t] 283 | imask_est = cmask_est[id_spk][:,1+2*id_chan] 284 | 285 | real_spec_est = rmask_est * real_spec_src[:,id_chan] - imask_est * imag_spec_src[:,id_chan] 286 | imag_spec_est = rmask_est * imag_spec_src[:,id_chan] + imask_est * real_spec_src[:,id_chan] 287 | est_sig.append(self.extractor.istft(real_spec_est, imag_spec_est, cplx=True)) # [b n] * m * s 288 | 289 | output = torch.stack(est_sig, 1) # [b mxs n] 290 | output = torch.stack(output.chunk(self.num_spks,1), 1) # [b s m n] 291 | output = torch.nn.functional.pad(output,[0,mixture.shape[-1]-output.shape[-1]]) 292 | 293 | return output 294 | 295 | 296 | # NOTE 297 | class DualRNN_Attention(nn.Module): 298 | 299 | def __init__(self, dropout_rate, d_model, nhead=4, use_att=False, use_rnn=False): 300 | super(DualRNN_Attention,self).__init__() 301 | self.dropout_rate = dropout_rate 302 | self.d_model = d_model 303 | self.nhead = nhead 304 | self.use_att = use_att 305 | self.use_rnn = use_rnn 306 | 307 | self.bn = nn.BatchNorm1d(1) 308 | 309 | if (self.use_att): 310 | self.shared_att1 = SharedLinearAttention(self.d_model) 311 | self.linear_att1 = nn.Linear(self.d_model, self.d_model*3) 312 | self.shared_att2 = SharedLinearAttention(self.d_model) 313 | self.linear_att2 = nn.Linear(self.d_model, self.d_model*3) 314 | self.ln_att = nn.LayerNorm(self.d_model) 315 | else: 316 | self.shared_att1 = None 317 | self.linear_att1 = None 318 | self.shared_att2 = None 319 | self.linear_att2 = None 320 | self.ln_att = None 321 | 322 | if (self.use_rnn): 323 | self.rnn1 = nn.LSTM(input_size=self.d_model, hidden_size=self.d_model * 2, num_layers=1, bias=False, bidirectional=True, batch_first=True) 324 | self.rnn2 = nn.LSTM(input_size=self.d_model, 
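                            # bidirectional with hidden_size = 2*d_model, so each LSTM emits
                            # 2 (directions) * 2*d_model = 4*d_model features per step; the
                            # linear_rnn1/linear_rnn2 layers below project this back to d_model.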
hidden_size=self.d_model * 2, num_layers=1, bias=False, bidirectional=True, batch_first=True) 325 | self.dropout = nn.Dropout(p=self.dropout_rate) 326 | self.linear_rnn1 = nn.Linear(in_features = self.d_model * 4, out_features = self.d_model) 327 | self.linear_rnn2 = nn.Linear(in_features = self.d_model * 4, out_features = self.d_model) 328 | #self.relu = nn.ReLU() 329 | self.ln_rnn = nn.LayerNorm(self.d_model) 330 | else: 331 | self.rnn1 = None 332 | self.rnn2 = None 333 | self.dropout = None 334 | self.linear_rnn1 = None 335 | self.linear_rnn2 = None 336 | #self.relu = None 337 | self.ln_rnn = None 338 | 339 | def forward(self, x): # [b w t f=129] 340 | B, W, T, F = x.size() 341 | att_in1 = x.permute(0,2,3,1).contiguous().view(B*T, F, -1) 342 | 343 | if (not self.use_att): 344 | att_out1 = att_in1 345 | else: 346 | q, k, v = self.linear_att1(att_in1).view(B*T, F, self.nhead, -1).chunk(3,-1) 347 | m1, m2, m3 = FullMask(q.shape[1], k.shape[1], device=x.device),FullMask(q.shape[0], q.shape[1], device=x.device),FullMask(k.shape[0], k.shape[1], device=x.device) 348 | att_out1 = self.shared_att1(q, k, v, m1, m2, m3, causal=False).view(B*T, F, -1) 349 | att_out1 = self.ln_att(att_in1 + att_out1) 350 | 351 | rnn_in1 = att_out1 # [bxt f w] 352 | 353 | if (not self.use_rnn): 354 | rnn_out1 = rnn_in1 355 | else: 356 | rnn_out1, _ = self.rnn1(rnn_in1) 357 | rnn_out1 = self.linear_rnn1(self.dropout(rnn_out1)) 358 | rnn_out1 = self.ln_rnn(att_in1 + rnn_out1) 359 | 360 | rnn_out1 = rnn_out1.view(B, T, F, -1).permute(0,3,1,2) # [b w t f] 361 | 362 | rnn_out1 = (self.bn(rnn_out1.reshape(B,1,-1))).reshape(*rnn_out1.shape) # [b w t f] 363 | rnn_out1 = rnn_out1 + x 364 | 365 | att_in2 = rnn_out1.permute(0,3,2,1).contiguous().view(B*F, T, -1) 366 | 367 | if (not self.use_att): 368 | att_out2 = att_in2 369 | else: 370 | q, k, v = self.linear_att2(att_in2).view(B*F, T, self.nhead, -1).chunk(3,-1) 371 | m1, m2, m3 = FullMask(q.shape[1], k.shape[1], device=x.device),FullMask(q.shape[0], q.shape[1], device=x.device),FullMask(k.shape[0], k.shape[1], device=x.device) 372 | att_out2 = self.shared_att2(q, k, v, m1, m2, m3, causal=False).view(B*F, T, -1) 373 | att_out2 = self.ln_att(att_in2 + att_out2) 374 | 375 | rnn_in2 = att_out2 # [bxf t w] 376 | 377 | if (not self.use_rnn): 378 | rnn_out2 = rnn_in2 379 | else: 380 | rnn_out2, _ = self.rnn2(rnn_in2) 381 | rnn_out2 = self.linear_rnn2(self.dropout(rnn_out2)) 382 | rnn_out2 = self.ln_rnn(att_in2 + rnn_out2) 383 | 384 | rnn_out2 = rnn_out2.view(B, F, T, -1).permute(0,3,2,1) # [b w t f] 385 | 386 | rnn_out2 = (self.bn(rnn_out2.reshape(B,1,-1))).reshape(*rnn_out2.shape) # [b w t f] 387 | rnn_out2 = rnn_out2 + rnn_out1 388 | 389 | return rnn_out2 390 | 391 | 392 | class DenseBlock(nn.Module): 393 | 394 | def __init__(self, init_ch, g1, g2): 395 | super(DenseBlock,self).__init__() 396 | 397 | self.conv1 = nn.Sequential( 398 | nn.Conv2d(init_ch, g1, kernel_size=(3,3),stride=(1,1),padding=(1,1)), 399 | nn.ELU(), 400 | nn.InstanceNorm2d(g1,affine=False) 401 | ) 402 | self.conv2 = nn.Sequential( 403 | nn.Conv2d(init_ch+g1, g1, kernel_size=(3,3),stride=(1,1),padding=(1,1)), 404 | nn.ELU(), 405 | nn.InstanceNorm2d(g1,affine=False) 406 | ) 407 | self.conv3 = nn.Sequential( 408 | nn.Conv2d(init_ch+2*g1, g1, kernel_size=(3,3),stride=(1,1),padding=(1,1)), 409 | nn.ELU(), 410 | nn.InstanceNorm2d(g1,affine=False) 411 | ) 412 | self.conv4 = nn.Sequential( 413 | nn.Conv2d(init_ch+3*g1, g1, kernel_size=(3,3),stride=(1,1),padding=(1,1)), 414 | nn.ELU(), 415 | 
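            # dense connectivity: conv_k receives init_ch + (k-1)*g1 input channels because
            # all earlier activations are concatenated with the input; conv5 then maps
            # init_ch + 4*g1 channels to g2, which is set to init_ch at the call sites,
            # so the block preserves the feature shape.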
nn.InstanceNorm2d(g1,affine=False) 416 | ) 417 | self.conv5 = nn.Sequential( 418 | nn.Conv2d(init_ch+4*g1, g2, kernel_size=(3,3),stride=(1,1),padding=(1,1)), 419 | nn.ELU(), 420 | nn.InstanceNorm2d(g2,affine=False) 421 | ) 422 | 423 | def forward(self,x): 424 | y0 = self.conv1(x) 425 | 426 | y0_x = torch.cat((x,y0),dim=1) 427 | y1 = self.conv2(y0_x) 428 | 429 | y1_0_x = torch.cat((x,y0,y1),dim=1) 430 | y2 = self.conv3(y1_0_x) 431 | 432 | y2_1_0_x = torch.cat((x,y0,y1,y2),dim=1) 433 | y3 = self.conv4(y2_1_0_x) 434 | 435 | y3_2_1_0_x = torch.cat((x,y0,y1,y2,y3),dim=1) 436 | y4 = self.conv5(y3_2_1_0_x) 437 | 438 | return y4 439 | 440 | 441 | # NOTE 442 | class com_sisdr_loss1(nn.Module): 443 | def __init__(self, frame_len=512, frame_hop=128, num_channel=7): 444 | super().__init__() 445 | self.extractor = FeatureExtractor(frame_len=frame_len, frame_hop=frame_hop) 446 | self.sig_loss = PITLossWrapper(pairwise_neg_snr, pit_from='pw_mtx') 447 | 448 | # est_targets [b s m n] pha_mix [b m f t] targets [b s m n] SP [b] 449 | #def forward(self, est_targets1, est_targets2_1, pha_mix, targets, SP): 450 | def forward(self, est_targets1, est_targets2_1, est_targets2_2, pha_mix, targets, SP): 451 | B, S, M, N = est_targets1.size() 452 | 453 | mag_src_spk1, pha_src_spk1 = self.extractor.stft(targets[:,0]) # [b m f t] 454 | mag_src_spk2, pha_src_spk2 = self.extractor.stft(targets[:,1]) # [b m f t] 455 | _, _, F, T = mag_src_spk1.size() 456 | mag_src_spk1, pha_src_spk1 = mag_src_spk1.view(-1,F,T), pha_src_spk1.view(-1,F,T) 457 | mag_src_spk2, pha_src_spk2 = mag_src_spk2.view(-1,F,T), pha_src_spk2.view(-1,F,T) 458 | pha_mix = pha_mix.view(-1,F,T) 459 | 460 | targets_recover1 = self.extractor.istft(mag_src_spk1 * torch.cos(pha_mix - pha_src_spk1), pha_src_spk1).view(B,M,N) 461 | targets_recover2 = self.extractor.istft(mag_src_spk2 * torch.cos(pha_mix - pha_src_spk2), pha_src_spk2).view(B,M,N) 462 | 463 | targets_recover = torch.stack((targets_recover1, targets_recover2), 1) # [b s m n] 464 | 465 | est_targets1, targets_recover = est_targets1.permute(0,2,1,3).contiguous().view(B*M, S, N), targets_recover.permute(0,2,1,3).contiguous().view(B*M, S, N) 466 | est_targets2_1, targets = est_targets2_1.permute(0,2,1,3).contiguous().view(B*M, S, N), targets.permute(0,2,1,3).contiguous().view(B*M, S, N) 467 | est_targets2_2 = est_targets2_2.permute(0,2,1,3).contiguous().view(B*M, S, N) 468 | SP = SP.repeat(M,1).transpose(0,1).contiguous().view(B*M) 469 | 470 | if (sum(SP)==0): 471 | snr_loss1 = self.sig_loss(est_targets1, targets_recover) 472 | snr_loss2_1 = self.sig_loss(est_targets2_1, targets) 473 | snr_loss2_2 = self.sig_loss(est_targets2_2, targets) 474 | elif (sum(SP)==SP.shape[0]): 475 | snr_loss1 = 0.05 * self.sig_loss(est_targets1[:,[0]], targets_recover[:,[0]]) 476 | snr_loss2_1 = 0.05 * self.sig_loss(est_targets2_1[:,[0]], targets[:,[0]]) 477 | snr_loss2_2 = 0.05 * self.sig_loss(est_targets2_2[:,[0]], targets[:,[0]]) 478 | else: 479 | snr_loss1 = 0.05 * self.sig_loss(est_targets1[SP == 1][:,[0]], targets_recover[SP == 1][:,[0]]) + self.sig_loss(est_targets1[SP == 0], targets_recover[SP == 0]) 480 | snr_loss2_1 = 0.05 * self.sig_loss(est_targets2_1[SP == 1][:,[0]], targets[SP == 1][:,[0]]) + self.sig_loss(est_targets2_1[SP == 0], targets[SP == 0]) 481 | snr_loss2_2 = 0.05 * self.sig_loss(est_targets2_2[SP == 1][:,[0]], targets[SP == 1][:,[0]]) + self.sig_loss(est_targets2_2[SP == 0], targets[SP == 0]) 482 | 483 | #loss = snr_loss2_1.mean() 484 | #loss_dict = dict(sig_loss=loss, snr_loss1=snr_loss1.mean(), 
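        # SP marks single-speaker samples (overlap == 0). For those, only output stream 0
        # has a real target and its loss term is down-weighted by 0.05, so overlapped
        # mixtures dominate training. E.g. with B=2, M=7 and SP = [1, 0], the 7 mic copies
        # of sample 0 contribute 0.05 * loss on stream 0 only, while sample 1 contributes
        # the full PIT loss over both streams. Note that est_targets1 is scored against
        # targets_recover, a phase-sensitive target (source magnitude scaled by
        # cos(pha_mix - pha_src), resynthesised above) matching DPARNet1's magnitude-mask
        # output, while the DPARNet2 estimates are scored against the raw targets.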
snr_loss2_1=snr_loss2_1.mean()) 485 | loss = (snr_loss2_1.mean() + snr_loss2_2.mean()) / 2 486 | loss_dict = dict(sig_loss=loss, snr_loss1=snr_loss1.mean(), snr_loss2_1=snr_loss2_1.mean(), snr_loss2_2=snr_loss2_2.mean()) 487 | 488 | return loss, loss_dict 489 | -------------------------------------------------------------------------------- /DPARNet/DPARNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # -Dense macs: 2.482 G/s params:147.848 k 4 | # +Dense macs:10.305 G/s params:639.880 k 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from fast_transformers.masking import FullMask, LengthMask, TriangularCausalMask 10 | from fast_transformers.attention import SharedLinearAttention, LinearAttention 11 | 12 | from asteroid.engine.optimizers import make_optimizer 13 | from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr, pairwise_neg_snr, pairwise_mse 14 | 15 | from feature import FeatureExtractor, AngleFeature 16 | 17 | def make_model_and_optimizer(conf): 18 | model = MISO1_MISO2() 19 | #print(model) 20 | optimizer = make_optimizer(model.parameters(), **conf['optim']) 21 | return model, optimizer 22 | 23 | class MISO1_MISO2(nn.Module): 24 | def __init__(self): 25 | super(MISO1_MISO2, self).__init__() 26 | 27 | self.MISO2 = MISO2( 28 | use_dense = False, 29 | use_att = [True, True, True, True, False], 30 | use_rnn = [True, False, False, False, False], 31 | width = 64, 32 | num_layers = 4, 33 | dropout_rate = 0.4, 34 | causal_conf = False, 35 | ) 36 | 37 | def forward(self, mixture, do_eval = 0): 38 | assert do_eval == 0 or do_eval == 1, "0 - training OR 1 - multi-channel beamforming" 39 | 40 | o_MISO2 = self.MISO2(mixture) 41 | 42 | if do_eval == 0: 43 | return o_MISO2 44 | else: 45 | return o_MISO2[:,:,0], o_MISO2 46 | 47 | class MISO2(nn.Module): 48 | def __init__(self, 49 | use_dense, 50 | use_att, 51 | use_rnn, 52 | num_channels=7, 53 | num_spks=2, 54 | frame_len=512, 55 | frame_hop=128, 56 | width=48, 57 | num_layers=3, 58 | dropout_rate=0.4, 59 | causal_conf=False, 60 | ): 61 | super(MISO2, self).__init__() 62 | 63 | self.use_dense = use_dense 64 | self.use_att = use_att 65 | self.use_rnn = use_rnn 66 | 67 | self.num_channels = num_channels 68 | self.num_spks = num_spks 69 | self.frame_len = frame_len 70 | self.frame_hop = frame_hop 71 | self.width = width 72 | self.num_layers = num_layers 73 | self.dropout_rate = dropout_rate 74 | self.causal_conf = causal_conf 75 | self.num_bins = self.frame_len // 2 + 1 76 | 77 | self.extractor = FeatureExtractor(frame_len=self.frame_len, frame_hop=self.frame_hop, do_ipd=False) 78 | 79 | self.in_Conv = nn.Sequential( 80 | nn.Conv2d(in_channels=self.num_channels * 2 * 2, out_channels=self.width, kernel_size=(1, 1)), 81 | nn.LayerNorm(self.num_bins // 2 + 1), 82 | nn.PReLU(self.width), 83 | ) 84 | 85 | self.in_Conv_att = nn.Sequential(nn.Conv2d(self.width, self.width // 2, kernel_size=(1, 1)), nn.PReLU()) 86 | 87 | self.dualrnn_attention = nn.ModuleList() 88 | for i in range (self.num_layers): 89 | self.dualrnn_attention.append(DualRNN_Attention(dropout_rate=self.dropout_rate, d_model=self.width//2, use_att=self.use_att[i], use_rnn=self.use_rnn[i])) 90 | 91 | self.out_Conv1 = nn.Sequential(nn.Conv2d(in_channels=self.width, out_channels=self.width, kernel_size=(1, 1)), nn.Tanh(),) 92 | self.out_Conv2 = nn.Sequential(nn.Conv2d(in_channels=self.width, out_channels=self.width, kernel_size=(1, 1)), nn.Sigmoid(),) 
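        # out_Conv1/out_Conv2 form a multiplicative (GLU-style) gate: tanh features scaled
        # by sigmoid gates, combined in forward() as out_Conv1(x) * out_Conv2(x).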
93 | 94 | self.out_Conv_att = nn.Sequential(nn.Conv2d(self.width // 2, self.width, kernel_size=(1, 1)), nn.PReLU()) 95 | 96 | self.out_Conv = nn.ConvTranspose2d(in_channels=self.width//2, out_channels=self.num_channels * self.num_spks * 2, kernel_size=(1, 1)) 97 | 98 | if (self.use_dense): 99 | self.in_DenseBlock = DenseBlock(init_ch=self.width, g1=64, g2=self.width) 100 | self.out_DenseBlock = DenseBlock(init_ch=self.width, g1=64, g2=self.width) 101 | else: 102 | self.in_DenseBlock = None 103 | self.out_DenseBlock = None 104 | 105 | 106 | def half_reshape(self, x, inverse): 107 | if(inverse): 108 | x = torch.cat([x[:,0:x.shape[1]//2,],x[:,x.shape[1]//2:x.shape[1]//2*2,]],-1) 109 | x = x[...,:-1] 110 | return x 111 | else: 112 | x = F.pad(x, (0,1,0,0), 'replicate') 113 | x = torch.cat([x[:,:,:,0:x.shape[-1]//2],x[:,:,:,x.shape[-1]//2:x.shape[-1]//2*2]],1) 114 | return x 115 | 116 | 117 | def forward(self, mixture): # mixture: [b m n] 118 | 119 | real_spec_src, imag_spec_src = self.extractor.stft(mixture, cplx=True) # [b m f t] 120 | com_spec_src = torch.cat((real_spec_src, imag_spec_src), 1).permute(0,1,3,2) # [b 2m t f=257] 121 | 122 | x = self.half_reshape(com_spec_src, False) # [b 4m t f=129] 123 | 124 | x = self.in_Conv(x) # [b w=64 t f] 125 | if (self.use_dense): 126 | x = self.in_DenseBlock(x) 127 | 128 | x = self.in_Conv_att(x) # [b w=32 t f] 129 | 130 | for i in range (self.num_layers): 131 | x = self.dualrnn_attention[i](x) # [b w t f] 132 | 133 | x = self.out_Conv_att(x) # [b w=64 t f] 134 | 135 | x = self.out_Conv1(x) * self.out_Conv2(x) # [b w=64 t f] 136 | 137 | if (self.use_dense): 138 | x = self.out_DenseBlock(x) 139 | 140 | x = self.half_reshape(x, True) # [b w=32 t f=257] 141 | 142 | cmask_est = self.out_Conv(x).transpose(2,3) # [b 2xsxm f t] 143 | cmask_est = cmask_est.chunk(self.num_spks, dim=1) # [b 2m f t] * 2 144 | 145 | est_sig = [] 146 | for id_spk in range (self.num_spks): 147 | for id_chan in range (self.num_channels): 148 | rmask_est = cmask_est[id_spk][:,0+2*id_chan] # [b f t] 149 | imask_est = cmask_est[id_spk][:,1+2*id_chan] 150 | 151 | real_spec_est = rmask_est * real_spec_src[:,id_chan] - imask_est * imag_spec_src[:,id_chan] 152 | imag_spec_est = rmask_est * imag_spec_src[:,id_chan] + imask_est * real_spec_src[:,id_chan] 153 | est_sig.append(self.extractor.istft(real_spec_est, imag_spec_est, cplx=True)) # [b n] * m * s 154 | 155 | output = torch.stack(est_sig, 1) # [b mxs n] 156 | output = torch.stack(output.chunk(self.num_spks,1), 1) # [b s m n] 157 | output = torch.nn.functional.pad(output,[0,mixture.shape[-1]-output.shape[-1]]) 158 | 159 | return output 160 | 161 | 162 | class DualRNN_Attention(nn.Module): 163 | 164 | def __init__(self, dropout_rate, d_model, nhead=4, use_att=False, use_rnn=False): 165 | super(DualRNN_Attention,self).__init__() 166 | self.dropout_rate = dropout_rate 167 | self.d_model = d_model 168 | self.nhead = nhead 169 | self.use_att = use_att 170 | self.use_rnn = use_rnn 171 | 172 | self.bn = nn.BatchNorm1d(1) 173 | 174 | if (self.use_att): 175 | self.shared_att1 = SharedLinearAttention(self.d_model) 176 | self.linear_att1 = nn.Linear(self.d_model, self.d_model*3) 177 | self.shared_att2 = SharedLinearAttention(self.d_model) 178 | self.linear_att2 = nn.Linear(self.d_model, self.d_model*3) 179 | self.ln_att = nn.LayerNorm(self.d_model) 180 | else: 181 | self.shared_att1 = None 182 | self.linear_att1 = None 183 | self.shared_att2 = None 184 | self.linear_att2 = None 185 | self.ln_att = None 186 | 187 | if (self.use_rnn): 188 | self.rnn1 
= nn.LSTM(input_size=self.d_model, hidden_size=self.d_model * 2, num_layers=1, bias=False, bidirectional=True, batch_first=True) 189 | self.rnn2 = nn.LSTM(input_size=self.d_model, hidden_size=self.d_model * 2, num_layers=1, bias=False, bidirectional=True, batch_first=True) 190 | self.dropout = nn.Dropout(p=self.dropout_rate) 191 | self.linear_rnn1 = nn.Linear(in_features = self.d_model * 4, out_features = self.d_model) 192 | self.linear_rnn2 = nn.Linear(in_features = self.d_model * 4, out_features = self.d_model) 193 | #self.relu = nn.ReLU() 194 | self.ln_rnn = nn.LayerNorm(self.d_model) 195 | else: 196 | self.rnn1 = None 197 | self.rnn2 = None 198 | self.dropout = None 199 | self.linear_rnn1 = None 200 | self.linear_rnn2 = None 201 | #self.relu = None 202 | self.ln_rnn = None 203 | 204 | def forward(self,x): # [b w t f=129] 205 | B, W, T, F = x.size() 206 | att_in1 = x.permute(0,2,3,1).contiguous().view(B*T, F, -1) 207 | 208 | if (not self.use_att): 209 | att_out1 = att_in1 210 | else: 211 | q, k, v = self.linear_att1(att_in1).view(B*T, F, self.nhead, -1).chunk(3,-1) 212 | m1, m2, m3 = FullMask(q.shape[1], k.shape[1], device=x.device),FullMask(q.shape[0], q.shape[1], device=x.device),FullMask(k.shape[0], k.shape[1], device=x.device) 213 | att_out1 = self.shared_att1(q, k, v, m1, m2, m3, causal=False).view(B*T, F, -1) 214 | att_out1 = self.ln_att(att_in1 + att_out1) 215 | 216 | rnn_in1 = att_out1 # [bxt f w] 217 | 218 | if (not self.use_rnn): 219 | rnn_out1 = rnn_in1 220 | else: 221 | rnn_out1, _ = self.rnn1(rnn_in1) 222 | rnn_out1 = self.linear_rnn1(self.dropout(rnn_out1)) 223 | rnn_out1 = self.ln_rnn(att_in1 + rnn_out1) 224 | 225 | rnn_out1 = rnn_out1.view(B, T, F, -1).permute(0,3,1,2) # [b w t f] 226 | 227 | rnn_out1 = (self.bn(rnn_out1.reshape(B,1,-1))).reshape(*rnn_out1.shape) # [b w t f] 228 | rnn_out1 = rnn_out1 + x 229 | 230 | att_in2 = rnn_out1.permute(0,3,2,1).contiguous().view(B*F, T, -1) 231 | 232 | if (not self.use_att): 233 | att_out2 = att_in2 234 | else: 235 | q, k, v = self.linear_att2(att_in2).view(B*F, T, self.nhead, -1).chunk(3,-1) 236 | m1, m2, m3 = FullMask(q.shape[1], k.shape[1], device=x.device),FullMask(q.shape[0], q.shape[1], device=x.device),FullMask(k.shape[0], k.shape[1], device=x.device) 237 | att_out2 = self.shared_att2(q, k, v, m1, m2, m3, causal=False).view(B*F, T, -1) 238 | att_out2 = self.ln_att(att_in2 + att_out2) 239 | 240 | rnn_in2 = att_out2 # [bxf t w] 241 | 242 | if (not self.use_rnn): 243 | rnn_out2 = rnn_in2 244 | else: 245 | rnn_out2, _ = self.rnn2(rnn_in2) 246 | rnn_out2 = self.linear_rnn2(self.dropout(rnn_out2)) 247 | rnn_out2 = self.ln_rnn(att_in2 + rnn_out2) 248 | 249 | rnn_out2 = rnn_out2.view(B, F, T, -1).permute(0,3,2,1) # [b w t f] 250 | 251 | rnn_out2 = (self.bn(rnn_out2.reshape(B,1,-1))).reshape(*rnn_out2.shape) # [b w t f] 252 | rnn_out2 = rnn_out2 + rnn_out1 253 | 254 | return rnn_out2 255 | 256 | 257 | class DenseBlock(nn.Module): 258 | 259 | def __init__(self, init_ch, g1, g2): 260 | super(DenseBlock,self).__init__() 261 | 262 | self.conv1 = nn.Sequential( 263 | nn.Conv2d(init_ch, g1, kernel_size=(2,3),stride=(1,1),padding=(1,1)), 264 | nn.ELU(), 265 | nn.InstanceNorm2d(g1,affine=False) 266 | ) 267 | self.conv2 = nn.Sequential( 268 | nn.Conv2d(g1*2, g1, kernel_size=(2,3),stride=(1,1),padding=(1,1)), 269 | nn.ELU(), 270 | nn.InstanceNorm2d(g1,affine=False) 271 | ) 272 | self.conv3 = nn.Sequential( 273 | nn.Conv2d(g1*3, g1, kernel_size=(2,3),stride=(1,1),padding=(1,1)), 274 | nn.ELU(), 275 | nn.InstanceNorm2d(g1,affine=False) 276 | ) 
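        # Note on the (2,3) kernels: padding=(1,1) adds one extra frame on the time axis
        # (T + 2 - 2 + 1 = T + 1), and forward() trims it with [:, :, :-1]; each conv
        # therefore sees only the current and previous frame, i.e. it is causal in time
        # and same-padded in frequency.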
277 |         self.conv4 = nn.Sequential(
278 |             nn.Conv2d(g1*4, g2, kernel_size=(2,3),stride=(1,1),padding=(1,1)),
279 |             nn.ELU(),
280 |             nn.InstanceNorm2d(g2,affine=False)
281 |         )
282 | 
283 |     def forward(self,x): # x [b 64 t f]
284 | 
285 |         y0 = self.conv1(x)[:,:,:-1]
286 | 
287 |         y0_x = torch.cat((x,y0),dim=1)
288 |         y1 = self.conv2(y0_x)[:,:,:-1]
289 | 
290 |         y1_0_x = torch.cat((x,y0,y1),dim=1)
291 |         y2 = self.conv3(y1_0_x)[:,:,:-1]
292 | 
293 |         y2_1_0_x = torch.cat((x,y0,y1,y2),dim=1)
294 |         y3 = self.conv4(y2_1_0_x)[:,:,:-1]
295 | 
296 |         return y3
297 | 
298 | class com_sisdr_loss2(nn.Module):
299 |     def __init__(self):
300 |         super().__init__()
301 |         self.sig_loss = PITLossWrapper(pairwise_neg_snr, pit_from='pw_mtx')
302 | 
303 |     def forward(self, est_targets, targets, SP): # est_targets [b s m n] targets [b s m n] SP [b]
304 |         B, S, M, N = est_targets.size()
305 |         est_targets, targets = est_targets.permute(0,2,1,3).contiguous().view(B*M, S, N), targets.permute(0,2,1,3).contiguous().view(B*M, S, N)
306 |         SP = SP.repeat(M,1).transpose(0,1).contiguous().view(B*M)
307 | 
308 |         if (sum(SP)==0):
309 |             sisdr_loss = self.sig_loss(est_targets, targets)
310 |         elif (sum(SP)==SP.shape[0]):
311 |             sisdr_loss = 0.05 * self.sig_loss(est_targets[:,[0]], targets[:,[0]])
312 |         else:
313 |             sisdr_loss = 0.05 * self.sig_loss(est_targets[SP == 1][:,[0]], targets[SP == 1][:,[0]]) + self.sig_loss(est_targets[SP == 0], targets[SP == 0])
314 | 
315 |         loss = sisdr_loss.mean()
316 |         loss_dict = dict(sig_loss=loss, sisdr_loss=sisdr_loss.mean())
317 | 
318 |         return loss, loss_dict
319 | 
320 | if __name__ == "__main__":
321 |     import torch
322 |     from thop import profile
323 |     from thop import clever_format
324 | 
325 |     model = MISO1_MISO2()
326 |     #print(model)
327 |     mixture = torch.randn(1, 7, 16000) # b m n
328 |     macs, params = profile(model, inputs=(mixture,)) # thop expects a tuple of inputs
329 |     macs, params = clever_format([macs, params], "%.3f")
330 | 
331 |     print('macs:', macs)
332 |     print('params:', params)
333 | 
334 | 
--------------------------------------------------------------------------------
/DPARNet/conf.yml:
--------------------------------------------------------------------------------
1 | # Training config
2 | training:
3 |   epochs: 50
4 |   batch_size: 8
5 |   num_workers: 4
6 |   half_lr: yes
7 |   early_stop: yes
8 | # Optim config
9 | optim:
10 |   optimizer: adam
11 |   lr: 0.001
12 |   weight_decay: 0.
13 | # Data config
14 | data:
15 |   train_dir:
16 |   valid_dir:
17 |   sample_rate: 16000
--------------------------------------------------------------------------------
/DPARNet/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import data
3 | import numpy as np
4 | import os
5 | import soundfile as sf
6 | import math
7 | import random
8 | import shutil
9 | 
10 | from sms_wsj.database.create_rirs import config, scenarios, rirs
11 | from sms_wsj.reverb.reverb_utils import convolve
12 | 
13 | EPS=1e-8
14 | 
15 | def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None):
16 |     mean = torch.mean(wav_tensor, dim=-1, keepdim=True)
17 |     if std is None:
18 |         std = wav_tensor.std(-1, keepdim=True)
19 |     return (wav_tensor - mean) / (std + eps)
20 | 
21 | def rms(y):
22 |     return np.sqrt(np.mean(np.abs(y) ** 2, axis=0, keepdims=False))
23 | 
24 | def get_amplitude_scaling_factor(s, n, snr, method='rms'):
25 |     original_sn_rms_ratio = rms(s) / rms(n)
26 |     target_sn_rms_ratio = 10. ** (float(snr) / 20.) # snr = 20 * lg(rms(s) / rms(n))
27 |     return target_sn_rms_ratio / original_sn_rms_ratio
28 | 
29 | class Dataset(data.Dataset):
30 | 
31 |     def __init__(
32 |         self,
33 |         reverb_matrixs_dir,
34 |         rirNO = 5,
35 |         trainingNO = 5000,
36 |         segment = 6,
37 |         channel = [0,1,2,3,4,5,6],
38 |         overlap = [0.1, 0.2, 0.3, 0.4, 0.5],
39 |         raw_dir = '/path/to/LibriSpeech/filelist-all/',
40 |         noise_dir = '/path/to/noise/',
41 |         sample_rate = 16000,
42 |         use_aneconic = False,
43 |         channel_permute = False,
44 |         normalize = False,
45 |     ):
46 |         super(Dataset, self).__init__()
47 |         self.reverb_matrixs_dir = reverb_matrixs_dir
48 |         self.rirNO = rirNO
49 |         self.trainingNO = trainingNO
50 |         self.segment = segment
51 |         self.channel = channel
52 |         self.overlap = overlap
53 |         self.raw_dir = raw_dir
54 |         self.noise_list = [os.path.join(noise_dir, f) for f in os.listdir(noise_dir) if '.wav' in f]
55 |         self.sample_rate = sample_rate
56 |         self.use_aneconic = use_aneconic
57 |         self.channel_permute = channel_permute
58 |         self.normalize = normalize
59 | 
60 |     def __len__(self):
61 |         return self.trainingNO
62 | 
63 | 
64 |     def add_reverb(self,raw_dir1,raw_dir2,raw_dir3,h_use):
65 |         with open(raw_dir1,'r') as fin1:
66 |             with open(raw_dir2,'r') as fin2:
67 |                 with open(raw_dir3,'r') as fin3:
68 |                     wav1 = fin1.readlines()
69 |                     wav2 = fin2.readlines()
70 |                     wav3 = fin3.readlines()
71 |         mix_location = np.random.choice(['front','end','both']) # scalar choice, so the string comparisons below work directly
72 |         choose_wav = True
73 |         while(choose_wav):
74 |             i = np.random.randint(0,len(wav1))
75 |             j = np.random.randint(0,len(wav2))
76 |             k = np.random.randint(0,len(wav3))
77 |             w1,fs = sf.read(os.path.join('/path/to/LibriSpeech', wav1[i].rstrip("\n")), dtype="float32")
78 |             w2,fs = sf.read(os.path.join('/path/to/LibriSpeech', wav2[j].rstrip("\n")), dtype="float32")
79 |             w3,fs = sf.read(os.path.join('/path/to/LibriSpeech', wav3[k].rstrip("\n")), dtype="float32")
80 | 
81 |             if mix_location == 'front' or mix_location == 'end':
82 |                 overlap = np.random.choice(self.overlap)
83 |                 if (overlap == 0.0):
84 |                     single_speaker = 1
85 |                 else:
86 |                     single_speaker = 0
87 |                 seg_len1 = int(fs * self.segment)
88 |                 seg_len2 = int(fs * overlap * self.segment)
89 |                 if (w1.shape[0] > seg_len1 + 1 and w2.shape[0] > seg_len2 + 1):
90 |                     choose_wav = False
91 | 
92 |                 mix_name = 'overlap' + str(overlap) + '_' + os.path.basename(raw_dir1)[:-4] + '-' + os.path.basename(raw_dir2)[:-4] + '.wav'
93 | 
94 |             elif mix_location == 'both':
95 |                 overlap1 = np.random.choice([0.1, 0.2, 0.3, 0.4])
96 |                 overlap2 = np.random.choice([0.1, 0.2, 0.3, 0.4])
97 |                 seg_len1 = int(fs * self.segment)
98 |                 seg_len2 = int(fs * overlap1 * self.segment)
99 |                 seg_len3 = int(fs * overlap2 * self.segment)
100 |                 single_speaker = 0
101 |                 if (w1.shape[0] > seg_len1 + 1 and w2.shape[0] > seg_len2 + 1 and w3.shape[0] > seg_len3 + 1):
102 |                     choose_wav = False
103 | 
104 |                 mix_name='overlap' + str(overlap1) + '_' + str(overlap2) + '_' + os.path.basename(raw_dir1)[:-4] + '-' + os.path.basename(raw_dir2)[:-4] + '.wav'
105 | 
106 |         rand_start1 = np.random.randint(0, w1.shape[0] - seg_len1)
107 |         rand_start2 = np.random.randint(0, w2.shape[0] - seg_len2)
108 |         stop1 = int(rand_start1 + seg_len1)
109 |         stop2 = int(rand_start2 + seg_len2)
110 | 
111 |         if mix_location == 'both':
112 |             rand_start3 = np.random.randint(0, w3.shape[0] - seg_len3)
113 |             stop3 = int(rand_start3 + seg_len3)
114 | 
115 |         if (self.use_aneconic):
116 |             #print('Using anechoic...')
117 |             h_ely = h_use.copy()
118 |             for oneid in range(h_ely.shape[0]):
119 | 
start_inx=(np.abs(h_ely[oneid,0])>(np.abs(h_ely[oneid,0]).max()/10.0)).argmax() 120 | end_inx=start_inx+self.sample_rate//1000*50 121 | h_ely[oneid,:,end_inx:]=0.0 122 | 123 | w1_con = convolve(w1, h_ely[0,:,:]).T 124 | w2_con = convolve(w2, h_ely[1,:,:]).T 125 | w3_con = convolve(w3, h_ely[2,:,:]).T 126 | 127 | else: 128 | w1_con = convolve(w1, h_use[0,:,:]).T 129 | w2_con = convolve(w2, h_use[1,:,:]).T 130 | w3_con = convolve(w3, h_use[2,:,:]).T 131 | 132 | # dynamic SIR 133 | SIR1 = random.uniform(-5,5) 134 | scalar1=get_amplitude_scaling_factor(w1_con, w2_con, snr = SIR1) 135 | w2_con = w2_con / scalar1 136 | 137 | SIR2 = random.uniform(-5,5) 138 | scalar2=get_amplitude_scaling_factor(w1_con, w3_con, snr = SIR2) 139 | w3_con = w3_con / scalar2 140 | 141 | if (mix_location == 'front'): 142 | mix_reverb = np.concatenate([w1_con[rand_start1:rand_start1 + seg_len2] + w2_con[rand_start2:stop2], \ 143 | w1_con[rand_start1 + seg_len2:stop1]], axis=0) 144 | 145 | s1_reverb = w1_con[rand_start1:stop1] 146 | s2_reverb = np.concatenate([w2_con[rand_start2:stop2], np.zeros_like(w1_con[rand_start1 + seg_len2:stop1])], axis=0) 147 | 148 | if (mix_location == 'end'): 149 | mix_reverb = np.concatenate([w1_con[rand_start1:rand_start1 + seg_len1 - seg_len2], \ 150 | w1_con[rand_start1 + seg_len1 - seg_len2:rand_start1 + seg_len1] + w2_con[rand_start2:stop2]], axis=0) 151 | 152 | s1_reverb = w1_con[rand_start1:stop1] 153 | s2_reverb = np.concatenate([np.zeros_like(w1_con[rand_start1:rand_start1 + seg_len1 - seg_len2]), w2_con[rand_start2:stop2]], axis=0) 154 | 155 | if (mix_location == 'both'): 156 | mix_reverb = np.concatenate([w1_con[rand_start1:rand_start1 + seg_len2] + w2_con[rand_start2:stop2], \ 157 | w1_con[rand_start1 + seg_len2:stop1 - seg_len3], \ 158 | w1_con[stop1 - seg_len3:stop1] + w3_con[rand_start3:stop3]], axis=0) 159 | s1_reverb = w1_con[rand_start1:stop1] 160 | s2_reverb = np.concatenate([w2_con[rand_start2:stop2], \ 161 | np.zeros_like(w1_con[rand_start1 + seg_len2:stop1 - seg_len3]), \ 162 | w3_con[rand_start3:stop3]], axis=0) 163 | 164 | return mix_reverb, s1_reverb, s2_reverb, mix_name, single_speaker 165 | 166 | def add_noise(self, mix_reverb): 167 | # dynamic SNR 168 | SNR = random.uniform(5,20) 169 | if(random.uniform(0,1)<0.1): 170 | w_n = np.random.randn(*mix_reverb.shape) 171 | else: 172 | w_n = sf.read(random.choice(self.noise_list), dtype="float32")[0] 173 | start_inx = random.randint(0,w_n.shape[0]-mix_reverb.shape[0]-1) 174 | w_n = w_n[start_inx:start_inx+mix_reverb.shape[0],0:mix_reverb.shape[-1]] 175 | scalar = get_amplitude_scaling_factor(mix_reverb[:,0], w_n[:,0], snr = SNR) 176 | 177 | mix_noise = mix_reverb + w_n / scalar 178 | return mix_noise 179 | 180 | def __getitem__(self,idx): 181 | raw_list = os.listdir(self.raw_dir) 182 | SpeakerNo = len(raw_list) 183 | 184 | speaker1 = np.random.randint(0,SpeakerNo) 185 | speaker2 = np.random.randint(0,SpeakerNo) 186 | speaker3 = np.random.randint(0,SpeakerNo) 187 | while (speaker1 == speaker2): 188 | speaker2 = np.random.randint(0,SpeakerNo) 189 | while (speaker3 == speaker1 or speaker3 == speaker2): 190 | speaker3 = np.random.randint(0,SpeakerNo) 191 | raw_dir1 = self.raw_dir+raw_list[speaker1] 192 | raw_dir2 = self.raw_dir+raw_list[speaker2] 193 | raw_dir3 = self.raw_dir+raw_list[speaker3] 194 | 195 | choose_rir = np.random.randint(0,self.rirNO) 196 | rand_rir = np.load(self.reverb_matrixs_dir + str(choose_rir).zfill(5) + '.npz') 197 | h_use, _source_positions, _sensor_positions = rand_rir['h'], 
rand_rir['source_positions'], rand_rir['sensor_positions']
198 | 
199 |         # step 1: add reverb to the utterances
200 |         mix_reverb, s1_reverb, s2_reverb, mix_name, single_speaker = self.add_reverb(raw_dir1,raw_dir2,raw_dir3,h_use)
201 | 
202 |         # step 2: add noise
203 |         mix_noise = self.add_noise(mix_reverb)
204 |         mix_noise = mix_noise.transpose()[self.channel]
205 | 
206 |         # choose reference channel
207 |         source_arrays = []
208 |         if (self.channel_permute):
209 |             #print('Using channel permutation...')
210 |             ref_channel = np.random.randint(0, len(self.channel))
211 |             # s1_reverb [n c]
212 |             source_arrays.append(np.concatenate((s1_reverb.T[ref_channel:], s1_reverb.T[:ref_channel]), axis=0))
213 |             source_arrays.append(np.concatenate((s2_reverb.T[ref_channel:], s2_reverb.T[:ref_channel]), axis=0))
214 |             mixture = np.concatenate((mix_noise[ref_channel:], mix_noise[:ref_channel]), axis=0)
215 | 
216 |         else:
217 |             source_arrays.append(s1_reverb.T[self.channel])
218 |             source_arrays.append(s2_reverb.T[self.channel])
219 |             mixture = mix_noise
220 | 
221 |         # [s c n]
222 |         sources = torch.from_numpy(np.stack(source_arrays, axis=0).astype(np.float32))
223 |         # [c n]
224 |         mixture = torch.from_numpy(np.array(mixture).astype(np.float32))
225 | 
226 |         # 2022.04.06
227 |         # normalization
228 |         if (self.normalize):
229 |             print('Using normalization...')
230 |             # [c n]
231 |             m_std = mixture.std(-1, keepdim=True)
232 |             # [c n]
233 |             mixture = normalize_tensor_wav(mixture, eps=EPS, std=m_std)
234 |             # [s c n]
235 |             sources = normalize_tensor_wav(sources, eps=EPS, std=m_std[[0]]) # the reference channel sits at index 0 after the optional permutation
236 | 
237 |         # mixture [c n] sources [s c n]
238 |         return mixture, sources, single_speaker
239 | 
240 | 
241 | if __name__ == "__main__":
242 |     from tqdm import tqdm
243 | 
244 |     base_dir = 'path/to/testset'
245 |     mix_reverb = os.path.join(base_dir, 'mix_reverb')
246 |     s1_reverb = os.path.join(base_dir, 's1_reverb')
247 |     s2_reverb = os.path.join(base_dir, 's2_reverb')
248 | 
249 |     dir_list = [mix_reverb, s1_reverb, s2_reverb,]
250 |     for item in dir_list:
251 |         try:
252 |             os.makedirs(item)
253 |         except OSError:
254 |             pass
255 | 
256 |     d = Dataset(
257 |         reverb_matrixs_dir = '/path/to/reverb-set/',
258 |         rirNO = 10000,
259 |         trainingNO = 1,
260 |         segment = 6,
261 |         channel = [0,1,2,3,4,5,6],
262 |     )
263 | 
264 |     pbar = tqdm(range(10))
265 |     for i in pbar:
266 |         mix, src, _ = d[i] # __getitem__ returns (mixture, sources, single_speaker)
267 |         sf.write(os.path.join(mix_reverb,'{}.wav'.format(i)),mix[0].numpy(),16000)
268 |         sf.write(os.path.join(s1_reverb,'{}.wav'.format(i)),src[0,0,:].numpy().transpose(),16000)
269 |         sf.write(os.path.join(s2_reverb,'{}.wav'.format(i)),src[1,0,:].numpy().transpose(),16000)
270 | 
271 |     print('Done.')
272 | 
--------------------------------------------------------------------------------
/DPARNet/dataset_css.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils import data
3 | import json
4 | import os
5 | import numpy as np
6 | import soundfile as sf
7 | 
8 | EPS = 1e-8
9 | 
10 | DATASET = "xiandao2020"
11 | sep_clean = {"mixture": "mix_noise", "infos": [], "default_nsrc": 2}
12 | 
13 | xiandao2020_TASKS = {"sep_clean": sep_clean,}
14 | 
15 | def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None):
16 |     mean = wav_tensor.mean(-1, keepdim=True)
17 |     if std is None:
18 |         std = wav_tensor.std(-1, keepdim=True)
19 |     #return (wav_tensor - mean) / (std + eps)
20 |     wav_tensor = wav_tensor - mean
21 |     return wav_tensor / (wav_tensor.max() + EPS)
22 | 
23 | 
24 | class XiandaoDataset(data.Dataset):
25 | 
26 |     dataset_name = 
"xiandao2020" 27 | 28 | def __init__( 29 | self, 30 | json_dir, 31 | task, 32 | sample_rate=16000, 33 | segment=5, 34 | nondefault_nsrc=None, 35 | normalize_audio=False, 36 | channel = [0] 37 | ): 38 | super(XiandaoDataset, self).__init__() 39 | if task not in xiandao2020_TASKS.keys(): 40 | raise ValueError( 41 | "Unexpected task {}, expected one of " "{}".format(task, xiandao2020_TASKS.keys()) 42 | ) 43 | # Task setting 44 | self.json_dir = json_dir 45 | self.task = task 46 | self.task_dict = xiandao2020_TASKS[task] 47 | self.sample_rate = sample_rate 48 | self.normalize_audio = normalize_audio 49 | self.seg_len = None if segment is None else int(segment * sample_rate) 50 | self.channel = channel 51 | if not nondefault_nsrc: 52 | self.n_src = self.task_dict["default_nsrc"] 53 | else: 54 | assert nondefault_nsrc >= self.task_dict["default_nsrc"] 55 | self.n_src = nondefault_nsrc 56 | self.like_test = self.seg_len is None 57 | # Load json files 58 | mix_json = os.path.join(json_dir, self.task_dict["mixture"] + ".json") 59 | 60 | with open(mix_json, "r") as f: 61 | mix_infos = json.load(f) 62 | 63 | # Filter out short utterances only when segment is specified 64 | orig_len = len(mix_infos) 65 | drop_utt, drop_len = 0, 0 66 | if not self.like_test: 67 | for i in range(len(mix_infos) - 1, -1, -1): # Go backward 68 | if mix_infos[i][1] < self.seg_len: 69 | drop_utt += 1 70 | drop_len += mix_infos[i][1] 71 | del mix_infos[i] 72 | 73 | print( 74 | "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format( 75 | drop_utt, drop_len / sample_rate / 36000, orig_len, self.seg_len 76 | ) 77 | ) 78 | self.mix = mix_infos 79 | # Handle the case n_src > default_nsrc 80 | 81 | def __add__(self, xiandao): 82 | if self.n_src != xiandao.n_src: 83 | raise ValueError( 84 | "Only datasets having the same number of sources" 85 | "can be added together. Received " 86 | "{} and {}".format(self.n_src, xiandao.n_src) 87 | ) 88 | if self.seg_len != xiandao.seg_len: 89 | self.seg_len = min(self.seg_len, xiandao.seg_len) 90 | print( 91 | "Segment length mismatched between the two Dataset" 92 | "passed one the smallest to the sum." 
93 | ) 94 | 95 | self.mix = self.mix + xiandao.mix 96 | 97 | def __len__(self): 98 | return len(self.mix) 99 | 100 | def __getitem__(self, idx): 101 | 102 | # Random start 103 | if self.mix[idx][1] == self.seg_len or self.like_test: 104 | rand_start = 0 105 | else: 106 | rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len) 107 | if self.like_test: 108 | stop = None 109 | else: 110 | stop = rand_start + self.seg_len 111 | 112 | # Load mixture 113 | x, _ = sf.read(self.mix[idx][0], start=rand_start, stop=stop, dtype="float32") 114 | x = x.transpose()[self.channel] 115 | 116 | base_name = os.path.basename(self.mix[idx][0]) 117 | dir_name = os.path.basename(os.path.dirname(self.mix[idx][0])) 118 | name = dir_name + '/' + base_name 119 | seg_len = torch.as_tensor([len(x)]) 120 | 121 | mixture = torch.from_numpy(np.array(x).astype(np.float32)).permute(1,0) 122 | 123 | if self.normalize_audio: 124 | m_std = mixture.std(-1, keepdim=True) 125 | mixture = normalize_tensor_wav(mixture, eps=EPS, std=m_std) 126 | 127 | return mixture, name 128 | -------------------------------------------------------------------------------- /DPARNet/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import soundfile as sf 4 | import torch 5 | import yaml 6 | import json 7 | import argparse 8 | import numpy as np 9 | import pandas as pd 10 | import math 11 | from tqdm import tqdm 12 | from pprint import pprint 13 | 14 | from asteroid import torch_utils 15 | from asteroid.metrics import get_metrics 16 | from asteroid.losses import pairwise_neg_sisdr 17 | from asteroid.losses.pit_wrapper import PITLossWrapper 18 | from DPARNet import make_model_and_optimizer 19 | from asteroid.utils import tensors_to_device 20 | from mvdr_util import MVDR 21 | 22 | from dataset_css import XiandaoDataset 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--normalize", type=int, required=True, help="") 26 | parser.add_argument("--test_dir_simu", type=str, required=True, help="Test directory") 27 | parser.add_argument("--test_dir_css", type=str, required=True, help="Test directory") 28 | parser.add_argument("--save_wav_simu", type=int, default=0, help="Whether to save wav files") 29 | parser.add_argument("--save_wav_css", type=int, default=0, help="Whether to save wav files") 30 | parser.add_argument("--save_dir_simu", type=str, required=True, help="Output directory") 31 | parser.add_argument("--save_dir_css", type=str, required=True, help="Output directory") 32 | parser.add_argument("--use_gpu", type=int, default=0, help="Whether to use the GPU for model execution") 33 | parser.add_argument("--do_mvdr", type=int, default=0, help="Whether to use mvdr") 34 | parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root") 35 | 36 | compute_metrics = ["si_sdr"] 37 | 38 | def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None): 39 | mean = torch.mean(wav_tensor, dim=-1, keepdim=True) 40 | if std is None: 41 | std = wav_tensor.std(-1, keepdim=True) 42 | return (wav_tensor - mean) / (std + eps) 43 | 44 | 45 | def load_best_model(model, exp_dir): 46 | # Create the model from recipe-local function 47 | try: 48 | # Last best model summary 49 | with open(os.path.join(exp_dir, 'best_k_models.json'), "r") as f: 50 | best_k = json.load(f) 51 | best_model_path = min(best_k, key=best_k.get) 52 | except FileNotFoundError: 53 | # Get last checkpoint 54 | all_ckpt = os.listdir(os.path.join(exp_dir, 'checkpoints/')) 55 | 
all_ckpt=[(ckpt,int("".join(filter(str.isdigit,ckpt)))) for ckpt in all_ckpt] 56 | all_ckpt.sort(key=lambda x:x[1]) 57 | best_model_path = os.path.join(exp_dir, 'checkpoints', all_ckpt[-1][0]) 58 | print( 'LOADING from ',best_model_path) 59 | # Load checkpoint 60 | checkpoint = torch.load(best_model_path, map_location='cpu') 61 | for k in list(checkpoint['state_dict'].keys()): 62 | if('loss_func' in k): 63 | del checkpoint['state_dict'][k] 64 | # Load state_dict into model. 65 | model = torch_utils.load_state_dict_in(checkpoint['state_dict'], model) 66 | model = model.eval() 67 | return model 68 | 69 | class sisdr_loss(torch.nn.Module): 70 | def __init__(self): 71 | super().__init__() 72 | self.sig_loss = PITLossWrapper(pairwise_neg_sisdr, pit_from='pw_mtx') 73 | 74 | def forward(self, est_targets, targets): 75 | sig_loss, reordered_sources = self.sig_loss(est_targets, targets, return_est=True) 76 | 77 | return sig_loss.mean(), reordered_sources 78 | 79 | def main1(model, conf): 80 | causal = False 81 | mvdr = MVDR(causal) 82 | if conf["use_gpu"]: 83 | model.cuda() 84 | mvdr.cuda() 85 | model_device = next(model.parameters()).device 86 | 87 | normalize = conf['normalize'] 88 | test_dir_simu = conf['test_dir_simu'] 89 | save_dir_simu = conf['save_dir_simu'] 90 | dlist = os.listdir(test_dir_simu) 91 | pbar = tqdm(range(len(dlist))) 92 | series_list = [] 93 | torch.no_grad().__enter__() 94 | for idx in pbar: 95 | test_wav = np.load(test_dir_simu + dlist[idx]) 96 | mix, sources, name, single_speaker = tensors_to_device([torch.from_numpy(test_wav['mix']), torch.from_numpy(test_wav['src']), \ 97 | str(test_wav['n']), test_wav['single_speaker']], device=model_device) 98 | 99 | mix = mix.permute(1,0) # [m n] 100 | 101 | if (normalize): 102 | m_std = mix.std(1, keepdim=True) 103 | mix = normalize_tensor_wav(mix, eps=1e-8, std=m_std) 104 | sources = normalize_tensor_wav(sources, eps=1e-8, std=m_std[[0]]) # [s n] 105 | 106 | est_sources_7ch = model(mix[None]) # [b s m n] 107 | est_sources = est_sources_7ch[:,:,0] # [b s n] 108 | 109 | loss, reordered_sources = sisdr_loss()(est_sources, sources[None]) 110 | 111 | sources_np = sources.cpu().data.numpy() 112 | 113 | if conf["do_mvdr"]: 114 | est_sources = mvdr(mix[None], est_sources_7ch) # b s m n 115 | est_sources_np = est_sources.squeeze(0).cpu().data.numpy() # s m n 116 | est_sources_np = est_sources_np[:,0] 117 | else: 118 | est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy() 119 | 120 | if (single_speaker == 1): 121 | sources_np = sources_np[[0],:] 122 | est_sources_np = est_sources_np[[0],:] 123 | 124 | mix = mix[0,:] 125 | mix_np = mix[None].cpu().data.numpy() 126 | 127 | # save wave 128 | if not os.path.exists(os.path.join(save_dir_simu, name.split('_')[0])): 129 | os.makedirs(os.path.join(save_dir_simu, name.split('_')[0])) 130 | if idx<1000: 131 | est_s1 = est_sources_np[0] 132 | est_s1 *= np.max(np.abs(mix_np.squeeze()))/np.max(np.abs(est_s1)) 133 | est_sources_np *= np.max(np.abs(mix_np.squeeze()))/np.max(np.abs(est_sources_np)) 134 | sf.write(os.path.join(save_dir_simu, name.split('_')[0], name[:-4]+'_0.wav'), est_s1.squeeze() / est_s1.max(), conf["sample_rate"]) 135 | est_s2 = est_sources_np[1] 136 | est_s2 *= np.max(np.abs(mix_np.squeeze()))/np.max(np.abs(est_s2)) 137 | sf.write(os.path.join(save_dir_simu, name.split('_')[0], name[:-4]+'_1.wav'), est_s2.squeeze() / est_s2.max(), conf["sample_rate"]) 138 | 139 | utt_metrics = get_metrics( 140 | mix_np, 141 | sources_np, 142 | est_sources_np, 143 | 
sample_rate=conf["sample_rate"], 144 | metrics_list=compute_metrics, 145 | ) 146 | utt_metrics["mix_path"] = name 147 | series_list.append(pd.Series(utt_metrics)) 148 | pbar.set_description("si_sdr : {}".format(pd.DataFrame(series_list)['si_sdr'].mean())) 149 | # Save all metrics to the experiment folder. 150 | all_metrics_df = pd.DataFrame(series_list) 151 | all_metrics_df.to_csv(os.path.join(conf["exp_dir"], "all_metrics_permute.csv")) 152 | 153 | # Print and save summary metrics 154 | final_results = {} 155 | for metric_name in compute_metrics: 156 | input_metric_name = "input_" + metric_name 157 | ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name] 158 | final_results[metric_name] = all_metrics_df[metric_name].mean() 159 | final_results[metric_name + "_imp"] = ldf.mean() 160 | print("Overall metrics :") 161 | pprint(final_results) 162 | 163 | 164 | def main2(model, conf): 165 | causal = False 166 | mvdr = MVDR(causal) 167 | if conf["use_gpu"]: 168 | model.cuda() 169 | mvdr.cuda() 170 | model_device = next(model.parameters()).device 171 | normalize = conf['normalize'] 172 | test_dir_css = conf['test_dir_css'] 173 | save_dir_css = conf['save_dir_css'] 174 | test_set = XiandaoDataset( 175 | conf["test_dir_css"], 176 | 'sep_clean', 177 | sample_rate=conf["sample_rate"], 178 | nondefault_nsrc=2, 179 | segment=None, 180 | channel=[0,1,2,3,4,5,6] 181 | ) # Uses all segment length 182 | 183 | series_list = [] 184 | torch.no_grad().__enter__() 185 | for idx in tqdm(range(len(test_set))): 186 | # Forward the network on the mixture. 187 | mix, name = tensors_to_device(test_set[idx], device=model_device) 188 | name = name[:-4] 189 | 190 | # normalization 191 | mix = mix.permute(1,0) # [m n] 192 | if (normalize): 193 | m_std = mix.std(1, keepdim=True) 194 | mix = normalize_tensor_wav(mix, eps=1e-8, std=m_std) 195 | 196 | est_sources, est_sources_7ch = model(mix[None], do_eval=1) 197 | 198 | mix_7ch = mix 199 | if (normalize): 200 | mix = mix[0,:] 201 | else: 202 | mix = mix[:,0] 203 | 204 | mix_np = mix[None].cpu().data.numpy() 205 | 206 | if conf["do_mvdr"]: 207 | est_sources = mvdr(mix_7ch[None], est_sources_7ch) # b s c n 208 | est_sources_np = est_sources.squeeze(0).cpu().data.numpy() 209 | est_sources_s1_np = est_sources_np[0,0] 210 | est_sources_s2_np = est_sources_np[1,0] 211 | else: 212 | est_sources_np = est_sources.squeeze(0).cpu().data.numpy() 213 | est_sources_s1_np = est_sources_np[0] 214 | est_sources_s2_np = est_sources_np[1] 215 | 216 | # save wave 217 | est_sources_np = est_sources.squeeze(0).cpu().data.numpy() 218 | est_wav_s1 = est_sources_s1_np * np.max(np.abs(mix_np.squeeze()))/np.max(np.abs(est_sources_s1_np)) 219 | est_wav_s2 = est_sources_s2_np * np.max(np.abs(mix_np.squeeze()))/np.max(np.abs(est_sources_s2_np)) 220 | if not os.path.exists(save_dir_css + os.path.dirname(name)): 221 | os.makedirs(save_dir_css + os.path.dirname(name)) 222 | sf.write(save_dir_css+name+'_0.wav', est_wav_s1.squeeze() / est_wav_s1.max(), conf["sample_rate"]) 223 | sf.write(save_dir_css+name+'_1.wav', est_wav_s2.squeeze() / est_wav_s2.max(), conf["sample_rate"]) 224 | 225 | 226 | def main(conf): 227 | model, _ = make_model_and_optimizer(train_conf) 228 | model = load_best_model(model, conf['exp_dir']) 229 | save_dir_simu = conf['save_dir_simu'] 230 | save_dir_css = conf['save_dir_css'] 231 | 232 | if (conf['save_wav_simu']): 233 | main1(model, conf) 234 | if (conf['save_wav_css']): 235 | main2(model, conf) 236 | 237 | 238 | if __name__ == "__main__": 239 | args = 
parser.parse_args() 240 | arg_dic = dict(vars(args)) 241 | 242 | # Load training config 243 | conf_path = os.path.join(args.exp_dir, "conf.yml") 244 | with open(conf_path) as f: 245 | train_conf = yaml.safe_load(f) 246 | arg_dic["sample_rate"] = train_conf["data"]["sample_rate"] 247 | arg_dic["train_conf"] = train_conf 248 | 249 | main(arg_dic) 250 | -------------------------------------------------------------------------------- /DPARNet/feature.py: -------------------------------------------------------------------------------- 1 | # This code use for reference https://github.com/Sanyuan-Chen/CSS_with_Conformer/blob/master/executor/feature.py 2 | # Function: Implementation of front-end feature via PyTorch 3 | 4 | import torch 5 | import torch as th 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import math 10 | from collections.abc import Sequence 11 | MATH_PI = math.pi 12 | 13 | class FeatureExtractor(nn.Module): 14 | """ 15 | A PyTorch module to handle spectral & spatial features 16 | """ 17 | def __init__(self, 18 | frame_len=512, 19 | frame_hop=128, 20 | normalize=True, 21 | round_pow_of_two=True, 22 | window="sqrt_hann", 23 | log_spectrogram=True, 24 | mvn_spectrogram=True, 25 | ipd_mean_normalize=True, 26 | ipd_mean_normalize_version=2, 27 | ipd_cos=True, 28 | ipd_sin=False, 29 | ipd_index="1,0;2,0;3,0;4,0;5,0;6,0", 30 | ang_index="1,0;2,0;3,0;4,0;5,0;6,0", 31 | do_ipd=False, 32 | ): 33 | super(FeatureExtractor, self).__init__() 34 | # forward STFT 35 | self.forward_stft = STFT(frame_len, frame_hop, normalize=normalize, window=window, round_pow_of_two=round_pow_of_two) 36 | self.inverse_stft = iSTFT(frame_len, frame_hop, normalize=normalize, round_pow_of_two=round_pow_of_two) 37 | # BN or not 38 | self.mvn_mag = mvn_spectrogram 39 | # apply log or not 40 | self.log_mag = log_spectrogram 41 | 42 | # IPD or not 43 | self.do_ipd = do_ipd 44 | self.ipd_extractor = IPDFeature(ipd_index, cos=ipd_cos, sin=ipd_sin, ipd_mean_normalize_version=ipd_mean_normalize_version, ipd_mean_normalize=ipd_mean_normalize) 45 | self.ang_extractor = AngleFeature(num_bins=257, num_doas=1, af_index=ang_index) 46 | 47 | def stft(self, x, cplx=False): 48 | return self.forward_stft(x, cplx=cplx) 49 | 50 | def istft(self, m, p, cplx=False): 51 | return self.inverse_stft(m, p, cplx=cplx) 52 | 53 | def compute_spectra(self, x): 54 | """ 55 | Compute spectra features 56 | args 57 | x: b x c x n (multi-channel) or b x 1 x n (single channel) 58 | return: 59 | mag & pha: b x f x t or b x c x f x t 60 | feature: b x * x t 61 | """ 62 | mag, pha = self.forward_stft(x) 63 | # ch0: N x F x T 64 | if mag.dim() == 4: 65 | f = th.clamp(mag[:, 0], min=1e-8) 66 | else: 67 | f = th.clamp(mag, min=1e-8) 68 | # log 69 | if self.log_mag: 70 | f = th.log(f) 71 | # mvn 72 | if self.mvn_mag: 73 | f = (f - f.mean(-1, keepdim=True)) / (f.std(-1, keepdim=True) + 1e-8) 74 | 75 | if mag.dim() == 4: 76 | f = f.reshape(mag.size()[0], -1, mag.size()[-1]) 77 | 78 | return mag, pha, f 79 | 80 | 81 | def compute_spatial(self, pha, doa=None): 82 | """ 83 | Compute spatial features 84 | args 85 | pha: b x c x f x t 86 | return 87 | feature: b x * x t 88 | """ 89 | ipd = self.ipd_extractor(pha) 90 | if (doa.size() != torch.Size([])): 91 | doa = self.ang_extractor(pha,doa) 92 | return torch.cat((ipd, doa), 1) 93 | else: 94 | return ipd 95 | 96 | def forward(self, x, doa=None): 97 | """ 98 | args 99 | x: b x c x n (multi-channel) or b x 1 x n (single channel) 100 | return: 101 | mag & pha: b x f x t (if 
ref_channel is not None), b x c x f x t 102 | feature: b x * x t 103 | """ 104 | mag, pha, f = self.compute_spectra(x) 105 | feature = [f] 106 | if (self.do_ipd): 107 | spatial = self.compute_spatial(pha=pha, doa=doa) 108 | feature.append(spatial) 109 | # b x * x t 110 | feature = th.cat(feature, 1) 111 | 112 | return mag, pha, feature 113 | 114 | def init_kernel(frame_len, 115 | frame_hop, 116 | normalize=True, 117 | round_pow_of_two=True, 118 | window="sqrt_hann"): 119 | if window != "sqrt_hann" and window != "hann": 120 | raise RuntimeError("Now only support sqrt hanning window or hann window") 121 | # FFT points 122 | N = 2**math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len 123 | # window 124 | W = th.hann_window(frame_len) 125 | if window == "sqrt_hann": 126 | W = W**0.5 127 | # scale factor to make same magnitude after iSTFT 128 | if window == "sqrt_hann" and normalize: 129 | S = 0.5 * (N * N / frame_hop)**0.5 130 | else: 131 | S = 1 132 | # F x N/2+1 x 2 133 | K = th.rfft(th.eye(N) / S, 1)[:frame_len] 134 | # 2 x N/2+1 x F 135 | K = th.transpose(K, 0, 2) * W 136 | # N+2 x 1 x F 137 | K = th.reshape(K, (N + 2, 1, frame_len)) 138 | return K 139 | 140 | class STFTBase(nn.Module): 141 | """ 142 | Base layer for (i)STFT 143 | NOTE: 144 | 1) Recommend sqrt_hann window with 2**N frame length, because it 145 | could achieve perfect reconstruction after overlap-add 146 | 2) Now haven't consider padding problems yet 147 | """ 148 | def __init__(self, 149 | frame_len, 150 | frame_hop, 151 | window="sqrt_hann", 152 | normalize=True, 153 | round_pow_of_two=True): 154 | super(STFTBase, self).__init__() 155 | K = init_kernel(frame_len, 156 | frame_hop, 157 | round_pow_of_two=round_pow_of_two, 158 | window=window) 159 | self.K = nn.Parameter(K, requires_grad=False) 160 | self.stride = frame_hop 161 | self.window = window 162 | self.normalize = normalize 163 | self.num_bins = self.K.shape[0] // 2 164 | if window == "hann": 165 | self.conjugate = True 166 | else: 167 | self.conjugate = False 168 | 169 | def extra_repr(self): 170 | return (f"window={self.window}, stride={self.stride}, " + 171 | f"kernel_size={self.K.shape[0]}x{self.K.shape[2]}, " + 172 | f"normalize={self.normalize}") 173 | 174 | class STFT(STFTBase): 175 | """ 176 | Short-time Fourier Transform as a Layer 177 | """ 178 | def __init__(self, *args, **kwargs): 179 | super(STFT, self).__init__(*args, **kwargs) 180 | 181 | def forward(self, x, cplx=False): 182 | """ 183 | Accept (single or multiple channel) raw waveform and output magnitude and phase 184 | args 185 | x: input signal, N x C x S or N x S 186 | return 187 | m: magnitude, N x C x F x T or N x F x T 188 | p: phase, N x C x F x T or N x F x T 189 | """ 190 | if x.dim() not in [2, 3]: 191 | raise RuntimeError( 192 | "{} expect 2D/3D tensor, but got {:d}D signal".format( 193 | self.__name__, x.dim())) 194 | # if N x S, reshape N x 1 x S 195 | if x.dim() == 2: 196 | x = th.unsqueeze(x, 1) 197 | # N x 2F x T 198 | c = F.conv1d(x, self.K, stride=self.stride, padding=0) 199 | # N x F x T 200 | r, i = th.chunk(c, 2, dim=1) 201 | if self.conjugate: 202 | # to match with science pipeline, we need to do conjugate 203 | i = -i 204 | # else reshape NC x 1 x S 205 | else: 206 | N, C, S = x.shape 207 | x = x.contiguous().view(N * C, 1, S) 208 | # NC x 2F x T 209 | c = F.conv1d(x, self.K, stride=self.stride, padding=0) 210 | # N x C x 2F x T 211 | c = c.view(N, C, -1, c.shape[-1]) 212 | # N x C x F x T 213 | r, i = th.chunk(c, 2, dim=2) 214 | if self.conjugate: 215 | # to 
match with science pipeline, we need to do conjugate 216 | i = -i 217 | if cplx: 218 | return r, i 219 | m = (r**2 + i**2)**0.5 220 | p = th.atan2(i, r) 221 | return m, p 222 | 223 | 224 | class iSTFT(STFTBase): 225 | """ 226 | Inverse Short-time Fourier Transform as a Layer 227 | """ 228 | def __init__(self, *args, **kwargs): 229 | super(iSTFT, self).__init__(*args, **kwargs) 230 | 231 | def forward(self, m, p, cplx=False, squeeze=False): 232 | """ 233 | Accept phase & magnitude and output raw waveform 234 | args 235 | m, p: N x F x T 236 | return 237 | s: N x S 238 | """ 239 | if p.dim() != m.dim() or p.dim() not in [2, 3]: 240 | raise RuntimeError("Expect 2D/3D tensor, but got {:d}D".format( 241 | p.dim())) 242 | # if F x T, reshape 1 x F x T 243 | if p.dim() == 2: 244 | p = th.unsqueeze(p, 0) 245 | m = th.unsqueeze(m, 0) 246 | if cplx: 247 | # N x 2F x T 248 | c = th.cat([m, p], dim=1) 249 | else: 250 | r = m * th.cos(p) 251 | i = m * th.sin(p) 252 | # N x 2F x T 253 | c = th.cat([r, i], dim=1) 254 | # N x 2F x T 255 | s = F.conv_transpose1d(c, self.K, stride=self.stride, padding=0) 256 | # N x S 257 | s = s.squeeze(1) 258 | if squeeze: 259 | s = th.squeeze(s) 260 | return s 261 | 262 | 263 | class IPDFeature(nn.Module): 264 | """ 265 | Compute inter-channel phase difference 266 | """ 267 | def __init__(self, 268 | ipd_index="1,0;2,0;3,0;4,0;5,0;6,0", 269 | cos=True, 270 | sin=False, 271 | ipd_mean_normalize_version=2, 272 | ipd_mean_normalize=True): 273 | super(IPDFeature, self).__init__() 274 | split_index = lambda sstr: [ 275 | tuple(map(int, p.split(","))) for p in sstr.split(";") 276 | ] 277 | # ipd index 278 | pair = split_index(ipd_index) 279 | self.index_l = [t[0] for t in pair] 280 | self.index_r = [t[1] for t in pair] 281 | self.ipd_index = ipd_index 282 | self.cos = cos 283 | self.sin = sin 284 | self.ipd_mean_normalize=ipd_mean_normalize 285 | self.ipd_mean_normalize_version=ipd_mean_normalize_version 286 | self.num_pairs = len(pair) * 2 if cos and sin else len(pair) 287 | 288 | def extra_repr(self): 289 | return f"ipd_index={self.ipd_index}, cos={self.cos}, sin={self.sin}" 290 | 291 | def forward(self, p): 292 | """ 293 | Accept multi-channel phase and output inter-channel phase difference 294 | args 295 | p: phase matrix, N x C x F x T 296 | return 297 | ipd: N x MF x T 298 | """ 299 | if p.dim() not in [3, 4]: 300 | raise RuntimeError( 301 | "{} expect 3/4D tensor, but got {:d} instead".format( 302 | self.__name__, p.dim())) 303 | # C x F x T => 1 x C x F x T 304 | if p.dim() == 3: 305 | p = p.unsqueeze(0) 306 | N, _, _, T = p.shape 307 | pha_dif = p[:, self.index_l] - p[:, self.index_r] 308 | if self.ipd_mean_normalize: 309 | yr = th.cos(pha_dif) 310 | yi = th.sin(pha_dif) 311 | yrm = yr.mean(-1, keepdim=True) 312 | yim = yi.mean(-1, keepdim=True) 313 | if self.ipd_mean_normalize_version == 1: 314 | pha_dif = th.atan2(yi - yim, yr - yrm) 315 | elif self.ipd_mean_normalize_version == 2: 316 | pha_dif_mean = th.atan2(yim, yrm) 317 | pha_dif -= pha_dif_mean 318 | elif self.ipd_mean_normalize_version == 3: 319 | pha_dif_mean = pha_dif.mean(-1, keepdim=True) 320 | pha_dif -= pha_dif_mean 321 | else: 322 | # we only support version 1, 2 and 3 323 | raise RuntimeError( 324 | "{} expect ipd_mean_normalization version 1 or version 2, but got {:d} instead".format( 325 | self.__name__, self.ipd_mean_normalize_version)) 326 | 327 | if self.cos: 328 | # N x M x F x T 329 | ipd = th.cos(pha_dif) 330 | if self.sin: 331 | # N x M x 2F x T, along frequency axis 332 | ipd = th.cat([ipd, 
th.sin(pha_dif)], 2)
333 |         else:
334 |             # th.fmod behaves differently from np.mod for inputs that are less than -math.pi.
335 |             # I believe it is a bug,
336 |             # so we need to ensure it is larger than -math.pi by adding an extra 6 * math.pi
337 |             #ipd = th.fmod(pha_dif + math.pi, 2 * math.pi) - math.pi
338 |             ipd = pha_dif
339 |         # N x MF x T
340 |         ipd = ipd.view(N, -1, T)
341 |         # N x MF x T
342 |         return ipd
343 | 
344 | 
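A minimal shape check for `IPDFeature` above (a sketch: the 7-channel / 257-bin sizes follow the defaults, while the batch size and frame count are arbitrary):

```python
import math
import torch as th
from feature import IPDFeature

ipd_extractor = IPDFeature(ipd_index="1,0;2,0;3,0;4,0;5,0;6,0", cos=True, sin=False)

# N x C x F x T random phases in [-pi, pi)
pha = th.rand(2, 7, 257, 100) * 2 * math.pi - math.pi
ipd = ipd_extractor(pha)
print(ipd.shape)  # torch.Size([2, 1542, 100]): 6 pairs x 257 bins flattened on dim 1
```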
345 | class AngleFeature(nn.Module):
346 |     """
347 |     Compute angle/directional feature
348 |     1) num_doas == 1: we know the DoA of the target speaker
349 |     2) num_doas != 1: we do not have that prior, so we sample #num_doas DoAs
350 |        and compute the feature in each direction
351 |     """
352 |     def __init__(self,
353 |                  geometric="princeton",
354 |                  sr=16000,
355 |                  velocity=340,
356 |                  num_bins=257,
357 |                  num_doas=1,
358 |                  af_index="1,0;2,0;3,0;4,0;5,0;6,0"):
359 |         super(AngleFeature, self).__init__()
360 |         if geometric not in ["princeton"]:
361 |             raise RuntimeError(
362 |                 "Unsupported array geometric: {}".format(geometric))
363 |         self.geometric = geometric
364 |         self.sr = sr
365 |         self.num_bins = num_bins
366 |         self.num_doas = num_doas
367 |         self.velocity = velocity
368 |         split_index = lambda sstr: [
369 |             tuple(map(int, p.split(","))) for p in sstr.split(";")
370 |         ]
371 |         # ipd index
372 |         pair = split_index(af_index)
373 |         self.index_l = [t[0] for t in pair]
374 |         self.index_r = [t[1] for t in pair]
375 |         self.af_index = af_index
376 |         omega = th.tensor(
377 |             [math.pi * sr * f / (num_bins - 1) for f in range(num_bins)])
378 |         # 1 x F
379 |         self.omega = nn.Parameter(omega[None, :], requires_grad=False)
380 | 
381 |     def _oracle_phase_delay(self, doa):
382 |         """
383 |         Compute oracle phase delay given DoA
384 |         args
385 |             doa: N
386 |         return
387 |             phi: N x C x F or N x D x C x F
388 |         """
389 |         device = doa.device
390 |         if self.num_doas != 1:
391 |             # doa is an unused, fake parameter
392 |             N = doa.shape[0]
393 |             # N x D
394 |             doa = th.linspace(0, MATH_PI * 2, self.num_doas + 1,
395 |                               device=device)[:-1].repeat(N, 1)
396 |         # for princeton
397 |         # M = 7, R = 0.0425, treat M_0 as (0, 0)
398 |         #      *3    *2
399 |         #
400 |         #   *4    *0    *1
401 |         #
402 |         #      *5    *6
403 |         if self.geometric == "princeton":
404 |             R = 0.0425
405 |             zero = th.zeros_like(doa)
406 |             # N x 7 or N x D x 7
407 |             tau = R * th.stack([
408 |                 zero, -th.cos(doa), -th.cos(MATH_PI / 3 - doa),
409 |                 -th.cos(2 * MATH_PI / 3 - doa),
410 |                 th.cos(doa),
411 |                 th.cos(MATH_PI / 3 - doa),
412 |                 th.cos(2 * MATH_PI / 3 - doa)
413 |             ],
414 |                 dim=-1) / self.velocity
415 |             # (Nx7x1) x (1xF) => Nx7xF or (NxDx7x1) x (1xF) => NxDx7xF
416 |             phi = th.matmul(tau.unsqueeze(-1), -self.omega)
417 |             return phi
418 |         else:
419 |             return None
420 | 
421 |     def extra_repr(self):
422 |         return (
423 |             f"geometric={self.geometric}, af_index={self.af_index}, " +
424 |             f"sr={self.sr}, num_bins={self.num_bins}, velocity={self.velocity}, "
425 |             + f"known_doa={self.num_doas == 1}")
426 | 
427 |     def _compute_af(self, ipd, doa):
428 |         """
429 |         Compute angle feature
430 |         args
431 |             ipd: N x C x F x T
432 |             doa: DoA of the target speaker (if we know it), N
433 |                  or N x D (if we do not, D sampled DoAs instead)
434 |         return
435 |             af: N x F x T or N x D x F x T
436 |         """
437 |         # N x C x F or N x D x C x F
438 |         d = self._oracle_phase_delay(doa)
439 |         d = d.unsqueeze(-1)
440 |         if self.num_doas == 1:
441 |             dif = d[:, self.index_l] - d[:, self.index_r]
442 |             # N x C x F x T
443 |             af = th.cos(ipd - dif)
444 |             # on the channel dimension (mean or sum)
445 |             af = th.mean(af, dim=1)
446 |         else:
447 |             # N x D x C x F x 1
448 |             dif = d[:, :, self.index_l] - d[:, :, self.index_r]
449 |             # N x D x C x F x T
450 |             af = th.cos(ipd.unsqueeze(1) - dif)
451 |             # N x D x F x T
452 |             af = th.mean(af, dim=2)
453 |         return af
454 | 
455 |     def forward(self, p, doa):
456 |         """
457 |         Accept DoA of the speaker & multi-channel phase, output angle feature
458 |         args
459 |             doa: DoA of target/each speaker, N or [N, ...]
460 |             p: phase matrix, N x C x F x T
461 |         return
462 |             af: angle feature, N x F* x T or N x D x F x T (known_doa=False)
463 |         """
464 |         if p.dim() not in [3, 4]:
465 |             raise RuntimeError(
466 |                 "{} expect 3/4D tensor, but got {:d} instead".format(
467 |                     type(self).__name__, p.dim()))
468 |         # C x F x T => 1 x C x F x T
469 |         if p.dim() == 3:
470 |             p = p.unsqueeze(0)
471 |         ipd = p[:, self.index_l] - p[:, self.index_r]
472 | 
473 |         if isinstance(doa, Sequence):
474 |             if self.num_doas != 1:
475 |                 raise RuntimeError("known_doa=False, no need to pass "
476 |                                    "doa as a Sequence object")
477 |             # [N x F x T or N x D x F x T, ...]
478 |             af = [self._compute_af(ipd, spk_doa) for spk_doa in doa]
479 |             # N x F x T => N x F* x T
480 |             af = th.cat(af, 1)
481 |         else:
482 |             # N x F x T or N x D x F x T
483 |             af = self._compute_af(ipd, doa)
484 |         return af
485 | 
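`AngleFeature` can be exercised the same way; a sketch for the known-DoA case (`num_doas == 1`), with one random DoA per utterance:

```python
import math
import torch as th
from feature import AngleFeature

af_extractor = AngleFeature(num_bins=257, num_doas=1)  # known-DoA mode

pha = th.rand(2, 7, 257, 100) * 2 * math.pi - math.pi  # N x C x F x T phases
doa = th.rand(2) * 2 * math.pi                         # one DoA per utterance, in radians
af = af_extractor(pha, doa)
print(af.shape)  # torch.Size([2, 257, 100]): one F x T feature map per utterance
```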
--------------------------------------------------------------------------------
/DPARNet/mvdr_util.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import torch
4 | from torch import nn
5 | from sklearn.cluster import KMeans
6 | from asteroid import torch_utils
7 | import asteroid.filterbanks as fb
8 | from asteroid.engine.optimizers import make_optimizer
9 | from asteroid.masknn import norms, activations
10 | from asteroid.utils.torch_utils import pad_x_to_y
11 | from asteroid.losses import PITLossWrapper, pairwise_neg_snr
12 | from torch_complex.tensor import ComplexTensor
13 | from torch_complex import functional as FC
14 | from distutils.version import LooseVersion
15 | 
16 | stft_dict={
17 |     'n_filters': 4096,
18 |     'kernel_size': 4096,
19 |     'stride': 1024,
20 | }
21 | 
22 | class STFT(nn.Module):
23 |     def __init__(self,stft_dict=stft_dict):
24 |         super().__init__()
25 |         self.stft_dict=stft_dict
26 |         enc, dec = fb.make_enc_dec('stft', **stft_dict)
27 |         self.enc = enc
28 |         self.dec = dec
29 | 
30 |     def stft(self,x):
31 |         # x should be ..., t
32 |         tf = self.enc(x.contiguous())
33 |         # ..., F, T
34 |         return tf
35 | 
36 |     def istft(self,x,y=None):
37 |         # x ...,f,t
38 |         x=self.dec(x)
39 |         if(y is not None):
40 |             x=torch_utils.pad_x_to_y(x,y)
41 |         return x
42 | 
43 | def get_causal_power_spectral_density_matrix(observation, normalize=False, causal=False):
44 |     '''
45 |     psd = np.einsum('...dft,...eft->...deft', observation, observation.conj()) # (..., sensors, sensors, freq, frames)
46 |     if normalize:
47 |         psd = np.cumsum(psd, axis=-1)/np.arange(1,psd.shape[-1]+1,dtype=np.complex64)
48 |     if(psd.shape[-1]%causal_step==0):
49 |         return psd[...,causal_step-1::causal_step]
50 |     else:
51 |         return np.concatenate([psd[...,causal_step-1::causal_step], psd[...,[-1]]],-1)
52 |     '''
53 |     obsr, obsi = observation.chunk(2,-2) # S C F T
54 |     psdr = torch.einsum('saft,sbft->sabft',obsr,obsr) + torch.einsum('saft,sbft->sabft',obsi,obsi)
55 |     psdi = -torch.einsum('saft,sbft->sabft',obsr,obsi) + torch.einsum('saft,sbft->sbaft',obsr,obsi)
56 |     if causal:
57 |         psd = torch.cat([psdr,psdi],-2).cumsum(-1) # S C C F T
58 |         if(normalize):
59 |             psd = psd/torch.arange(1,psd.shape[-1]+1,1,dtype=psd.dtype, device=psd.device)[None,None,None,None,:]
60 |     else:
61 |         psd = torch.cat([psdr,psdi],-2).sum(-1,keepdim=True) # S C C F T
62 |         if(normalize):
63 |             psd = psd/psdr.shape[-1]
64 |     return psd
65 | 
66 | def get_mvdr_vector(
67 |     psd_s: ComplexTensor,
68 |     psd_n: ComplexTensor,
69 |     reference_vector = 0,
70 |     use_torch_solver: bool = True,
71 |     diagonal_loading: bool = True,
72 |     diag_eps: float = 1e-7,
73 |     eps: float = 1e-8,
74 | ) -> ComplexTensor:
75 |     """Return the MVDR (Minimum Variance Distortionless Response) vector:
76 | 
77 |         h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u
78 | 
79 |     Reference:
80 |         On optimal frequency-domain multichannel linear filtering
81 |         for noise reduction; M. Souden et al., 2010;
82 |         https://ieeexplore.ieee.org/document/5089420
83 | 
84 |     Args:
85 |         psd_s (ComplexTensor): speech covariance matrix (..., F, C, C)
86 |         psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C)
87 |         reference_vector (torch.Tensor): (..., C)
88 |         use_torch_solver (bool): Whether to use `solve` instead of `inverse`
89 |         diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n
90 |         diag_eps (float):
91 |         eps (float):
92 |     Returns:
93 |         beamform_vector (ComplexTensor): (..., F, C)
94 |     """ # noqa: D400
95 |     if diagonal_loading:
96 |         psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps)
97 |     '''
98 |     if use_torch_solver and is_torch_1_1_plus:
99 |         # torch.solve is required, which is only available after pytorch 1.1.0+
100 |         numerator = FC.solve(psd_s, psd_n)[0]
101 |     else:
102 |     '''
103 |     numerator = FC.matmul(psd_n.inverse2(), psd_s)
104 |     # ws: (..., C, C) / (...,) -> (..., C, C)
105 |     ws = numerator / (FC.trace(numerator)[..., None, None] + eps)
106 |     # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
107 |     beamform_vector = ws
108 |     return beamform_vector
109 | 
110 | def tik_reg(mat: ComplexTensor, reg: float = 1e-8, eps: float = 1e-8) -> ComplexTensor:
111 |     """Perform Tikhonov regularization (only modifying real part).
112 | 113 | Args: 114 | mat (ComplexTensor): input matrix (..., C, C) 115 | reg (float): regularization factor 116 | eps (float) 117 | Returns: 118 | ret (ComplexTensor): regularized matrix (..., C, C) 119 | """ 120 | # Add eps 121 | C = mat.size(-1) 122 | eye = torch.eye(C, dtype=mat.dtype, device=mat.device) 123 | shape = [1 for _ in range(mat.dim() - 2)] + [C, C] 124 | eye = eye.view(*shape).repeat(*mat.shape[:-2], 1, 1) 125 | with torch.no_grad(): 126 | epsilon = FC.trace(mat).real[..., None, None] * reg 127 | # in case that correlation_matrix is all-zero 128 | epsilon = epsilon + eps 129 | mat = mat + epsilon * eye 130 | return mat 131 | 132 | class MVDR(nn.Module): 133 | def __init__(self, causal): 134 | super().__init__() 135 | self.stft_model = STFT() 136 | self.causal = causal 137 | self.permute = PITLossWrapper(pairwise_neg_snr, pit_from='pw_mtx') 138 | 139 | def forward(self, x, s, do_permute=True): 140 | """ 141 | x: mix b x c x n 142 | s: est b x s x c x n 143 | """ 144 | n_batch, n_src, n_chan, n_samp = s.shape 145 | if (do_permute): 146 | s = self.permute_sig(s) 147 | x = x.unsqueeze(1).repeat(1,n_src,1,1).view(n_batch*n_src, n_chan, n_samp) 148 | s = s.view(n_batch*n_src, n_chan, n_samp) 149 | 150 | X = self.stft_model.stft(x) # B*S C F T 151 | S = self.stft_model.stft(s) # B*S C F T 152 | N = X - S 153 | 154 | n_freq, n_frame = S.shape[-2:] 155 | 156 | # print('N ', N.shape) 157 | Sscm = get_causal_power_spectral_density_matrix(S, normalize=True, causal=self.causal) # B*S C C F T 158 | Nscm = get_causal_power_spectral_density_matrix(N, normalize=True, causal=self.causal) # B*S C C F T 159 | 160 | # print('N maxtrix ', N.shape) 161 | Sscm = ComplexTensor(*Sscm.chunk(2,-2)).permute(0,4,3,1,2) # B*S T F C C 162 | Nscm = ComplexTensor(*Nscm.chunk(2,-2)).permute(0,4,3,1,2) 163 | est_filt = get_mvdr_vector(Sscm, Nscm) # B*S T F C C 164 | est_filt = torch.cat([est_filt.real,est_filt.imag],2) # B*S T F C C 165 | est_filt = est_filt.permute(0,3,4,2,1) # B*S C C F T 166 | # print('est_filt ', est_filt.shape) 167 | 168 | est_S = self.apply_bf(est_filt,X) # B*S C F T 169 | est_s = torch_utils.pad_x_to_y(self.stft_model.istft(est_S), s).view(n_batch, n_src, n_chan, n_samp) # b*s c t 170 | s = s.view(n_batch, n_src, n_chan, n_samp) 171 | 172 | return est_s 173 | 174 | def apply_bf(self,f,X): 175 | ''' 176 | f B C C F T 177 | X B C F T 178 | ''' 179 | X_real, X_imag = X.unsqueeze(2).chunk(2,-2) # B C 1 F T 180 | f_real, f_imag = f.chunk(2,-2) 181 | f_imag = -1.0 * f_imag 182 | # enhX_real = (X_real * (f_real + torch.ones_like(f_real))).sum(1) - (X_imag * f_imag).sum(1) # B C F T 183 | # enhX_imag = (X_real * f_imag).sum(1) + (X_imag * (f_real + torch.ones_like(f_real))).sum(1) 184 | enhX_real = (X_real * f_real).sum(1) - (X_imag * f_imag).sum(1) # B C F T 185 | enhX_imag = (X_real * f_imag).sum(1) + (X_imag * f_real).sum(1) 186 | enhX = torch.cat([enhX_real, enhX_imag],2) 187 | return enhX 188 | 189 | def permute_sig(self, est_sources): 190 | # b s c t 191 | reest_sources = [est_sources[:,:,0,:],] 192 | for chan in range(1,est_sources.shape[2]): 193 | if(self.causal): 194 | est_sources_rest = torch.zeros_like(est_sources[:,:,chan,:]) 195 | if(est_sources.shape[-1] None: 66 | if self.scheduler is not None: 67 | if not isinstance(self.scheduler, (list, tuple)): 68 | self.scheduler = [self.scheduler] # support multiple schedulers 69 | for sched in self.scheduler: 70 | if isinstance(sched, dict) and sched["interval"] == "batch": 71 | sched["scheduler"].step() # call step on each batch 
scheduler 72 | super().optimizer_step(*args, **kwargs) 73 | def configure_optimizers(self): 74 | """ Required by pytorch-lightning. """ 75 | 76 | if self.scheduler is not None: 77 | if not isinstance(self.scheduler, (list, tuple)): 78 | self.scheduler = [self.scheduler] # support multiple schedulers 79 | epoch_schedulers = [] 80 | for sched in self.scheduler: 81 | if not isinstance(sched, dict): 82 | epoch_schedulers.append(sched) 83 | else: 84 | assert sched["interval"] in [ 85 | "batch", 86 | "epoch", 87 | ], "Scheduler interval should be either batch or epoch" 88 | if sched["interval"] == "epoch": 89 | epoch_schedulers.append(sched) 90 | return [self.optimizer], epoch_schedulers 91 | return self.optimizer 92 | 93 | def train_dataloader(self): 94 | return self.train_loader 95 | 96 | def val_dataloader(self): 97 | return self.val_loader 98 | 99 | def on_save_checkpoint(self, checkpoint): 100 | """ Overwrite if you want to save more things in the checkpoint.""" 101 | checkpoint["training_config"] = self.config 102 | return checkpoint 103 | 104 | @staticmethod 105 | def config_to_hparams(dic): 106 | dic = flatten_dict(dic) 107 | for k, v in dic.items(): 108 | if v is None: 109 | dic[k] = str(v) 110 | elif isinstance(v, (list, tuple)): 111 | dic[k] = torch.Tensor(v) 112 | return dic 113 | 114 | -------------------------------------------------------------------------------- /DPARNet/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import numpy as np 5 | import random 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from torch.optim.lr_scheduler import ReduceLROnPlateau 10 | from torch.utils.data import DataLoader 11 | import pytorch_lightning as pl 12 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 13 | 14 | from dataset import Dataset 15 | from asteroid.engine.optimizers import make_optimizer 16 | from system import System 17 | from DPARNet import make_model_and_optimizer, com_sisdr_loss1 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--use_aneconic", type=int, required=True) 21 | parser.add_argument("--channel_permute", type=int, required=True) 22 | parser.add_argument("--normalize", type=int, required=True) 23 | parser.add_argument("--train_dirs", type=str, required=True) 24 | parser.add_argument("--val_dirs", type=str, required=True) 25 | parser.add_argument("--exp_dir", default="exp/tmp") 26 | 27 | def _worker_init_fn_(worker_id): 28 | torch_seed = torch.initial_seed() 29 | 30 | random.seed(torch_seed + worker_id) 31 | if torch_seed >= 2**32: 32 | torch_seed = torch_seed % 2**32 33 | np.random.seed(torch_seed + worker_id) 34 | 35 | def main(conf): 36 | 37 | use_aneconic = conf["main_args"]['use_aneconic'] 38 | channel_permute = conf["main_args"]['channel_permute'] 39 | normalize = conf["main_args"]['normalize'] 40 | train_dir = conf["main_args"]['train_dirs'] 41 | val_dir = conf["main_args"]['val_dirs'] 42 | 43 | rirNO_train = len(os.listdir(train_dir)) 44 | rirNO_val = len(os.listdir(val_dir)) 45 | 46 | train_set = Dataset( 47 | train_dir, 48 | rirNO_train, 49 | trainingNO = 8000, 50 | segment=6, 51 | channel=[0,1,2,3,4,5,6], 52 | overlap = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 53 | use_aneconic = use_aneconic, 54 | channel_permute = channel_permute, 55 | normalize = normalize, 56 | ) 57 | 58 | val_set = Dataset( 59 | val_dir, 60 | rirNO_val, 61 | trainingNO = 1000, 62 | segment=6, 63 | channel=[0,1,2,3,4,5,6], 64 | 
overlap = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], 65 | use_aneconic = use_aneconic, 66 | channel_permute = channel_permute, 67 | normalize = normalize, 68 | ) 69 | 70 | train_loader = DataLoader( 71 | train_set, 72 | shuffle=True, 73 | batch_size=conf["training"]["batch_size"], 74 | num_workers=conf["training"]["num_workers"], 75 | drop_last=True, 76 | worker_init_fn=_worker_init_fn_ 77 | ) 78 | 79 | val_loader = DataLoader( 80 | val_set, 81 | shuffle=False, 82 | batch_size=conf["training"]["batch_size"], 83 | num_workers=conf["training"]["num_workers"], 84 | drop_last=True, 85 | worker_init_fn=_worker_init_fn_ 86 | ) 87 | 88 | # Define model and optimizer 89 | model, optimizer = make_model_and_optimizer(conf) 90 | loss_func = com_sisdr_loss1() 91 | 92 | # Define scheduler 93 | scheduler = None 94 | if conf["training"]["half_lr"]: 95 | scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5) 96 | 97 | # Just after instantiating, save the args. Easy loading in the future. 98 | exp_dir = conf["main_args"]["exp_dir"] 99 | os.makedirs(exp_dir, exist_ok=True) 100 | conf_path = os.path.join(exp_dir, "conf.yml") 101 | with open(conf_path, "w") as outfile: 102 | yaml.safe_dump(conf, outfile) 103 | 104 | system = System( 105 | model=model, 106 | optimizer=optimizer, 107 | loss_func=loss_func, 108 | train_loader=train_loader, 109 | val_loader=val_loader, 110 | scheduler=scheduler, 111 | config=conf, 112 | ) 113 | 114 | # Define callbacks 115 | checkpoint_dir = os.path.join(exp_dir, "checkpoints/") 116 | checkpoint = ModelCheckpoint( 117 | checkpoint_dir, monitor="val_loss", mode="min", save_top_k=5, verbose=True 118 | ) 119 | early_stopping = False 120 | if conf["training"]["early_stop"]: 121 | early_stopping = EarlyStopping(monitor="val_loss", patience=30, verbose=True) 122 | 123 | # Don't ask GPU if they are not available. 124 | gpus = -1 if torch.cuda.is_available() else None 125 | trainer = pl.Trainer( 126 | max_epochs=conf["training"]["epochs"], 127 | checkpoint_callback=checkpoint, 128 | #resume_from_checkpoint='', 129 | early_stop_callback=early_stopping, 130 | default_root_dir=exp_dir, 131 | gpus=gpus, 132 | distributed_backend="dp", 133 | train_percent_check=1.0, # Useful for fast experiment 134 | gradient_clip_val=5.0, 135 | ) 136 | trainer.fit(system) 137 | 138 | best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()} 139 | with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f: 140 | json.dump(best_k, f, indent=0) 141 | 142 | 143 | if __name__ == "__main__": 144 | import yaml 145 | from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict 146 | 147 | with open("conf.yml") as f: 148 | def_conf = yaml.safe_load(f) 149 | parser = prepare_parser_from_dict(def_conf, parser=parser) 150 | 151 | arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True) 152 | main(arg_dic) 153 | 154 | 155 | -------------------------------------------------------------------------------- /DPARNet/utils/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### Now we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! 
-z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 98 | -------------------------------------------------------------------------------- /DPARNet/utils/prepare_python_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Usage ./utils/install_env.sh --install_dir A --asteroid_root B --pip_requires C 3 | install_dir=~ 4 | asteroid_root=../../../../ 5 | pip_requires=../../../requirements.txt # Expects a requirement.txt 6 | 7 | . utils/parse_options.sh || exit 1 8 | 9 | mkdir -p $install_dir 10 | cd $install_dir 11 | echo "Download and install latest version of miniconda3 into ${install_dir}" 12 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 13 | 14 | bash Miniconda3-latest-Linux-x86_64.sh -b -p miniconda3 15 | pip_path=$PWD/miniconda3/bin/pip 16 | 17 | rm Miniconda3-latest-Linux-x86_64.sh 18 | cd - 19 | 20 | if [[ ! -z ${pip_requires} ]]; then 21 | $pip_path install -r $pip_requires 22 | fi 23 | $pip_path install soundfile 24 | $pip_path install -e $asteroid_root 25 | #$pip_path install ${asteroid_root}/\[""evaluate""\] 26 | echo -e "\nAsteroid has been installed in editable mode. Feel free to apply your changes !" 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DPARNet 2 | **Light-weight speech separation based on dual-path attention and recurrent neural network** 3 | 4 | **基于双路注意力循环网络的轻量化语音分离** 5 | 6 | **This work has been published on 声学学报 *(Chinese Journal of Acoustics)*. The paper is available [here][Paper].** 7 | 8 | ## Contents 9 | * **[DPARNet](#dparnet)** 10 | * **[Contents](#contents)** 11 | * **[Introduction](#introduction)** 12 | * **[Dataset](#dataset)** 13 | * **[Requirement](#requirement)** 14 | * **[Train](#train)** 15 | * **[Test](#test)** 16 | * **[Results](#results)** 17 | * **[Citation](#citation)** 18 | * **[References](#references)** 19 | 20 | ## Introduction 21 | **DPARNet, which is an improvement of DPTFSNet [1], is composed of encoder, separation network and decoder. To alleviate the computation burden, sub-band processing approach is leveraged in the encoder. Dual-path attention mechanism and recurrent network structure are introduced in the separation network to model the speech signals in each sub-band, which facilitate extraction of deep feature information and rich spectrum details.** 22 | 23 | **The parameters and computation cost of DPARNet model is only 0.15M and 15.2G/6s.** 24 | 25 | **Inspired by [2], we also introduce Beam-Guided DPARNet, which makes full use of spatial information.** 26 | 27 | ## Dataset 28 | **We use [sms_wsj][sms_wsj] to generate room impulse responses (RIRs) set. 
29 | 
30 | **Use ```python generate_rir.py``` to generate the training and validation data.**
31 | 
32 | **We use the [LibriCSS][libricss] dataset as the test set.**
33 | 
34 | ## Requirement
35 | **Our scripts use the [asteroid][asteroid] toolkit as the basic environment.**
36 | 
37 | ## Train
38 | **We recommend running the following to train end-to-end:**
39 | 
40 | **```./run.sh --id 0,1,2,3```**
41 | 
42 | **or:**
43 | 
44 | **```./run.sh --id 0,1,2,3 --stage 1```**
45 | 
46 | ## Test
47 | **```./run.sh --id 0 --stage 2```**
48 | 
49 | ## Results
50 | **WER (%) on LibriCSS, model parameters (M) and computation (G per 6 s of speech)**
51 | 
52 | |**Model** |**Year**|**0S** |**0L** |**OV10**|**OV20**|**OV30**|**OV40**|**parameters**|**computation**|
53 | | :----- | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
54 | |**Raw[3]** |2020 |11.8 |11.7 |18.8 |27.2 |35.6 |43.3 | - | - |
55 | |**BLSTM[4]** |2021 |**7.0** |7.5 |10.8 |13.4 |16.5 |18.8 |21.8 |17.1 |
56 | |**PW-NBDF[5]** |2021 |7.3 |7.3 | 8.3 |10.6 |13.4 |15.8 |18.9 |20.1 |
57 | |**Conformer-large[4]** |2021 |7.2 |7.5 |9.6 |11.3 |13.7 |15.1 |58.7 |43.6 |
58 | |**DPT-FSNet[1]** |2022 |7.1 |7.3 |7.6 |8.9 |10.8 |11.3 |0.50 |49.1 |
59 | |**Beam-Guided DPT-FSNet[2]**|2022 |7.1 |7.1 |**7.1** |8.0 |9.2 |9.7 |1.0 |50.1 |
60 | |**Proposed DPARNet** |- |7.2 |7.2 |7.4 |8.6 |10.3 |10.9 |0.15 |15.2 |
61 | |**Beam-Guided DPARNet** |- |7.3 |**6.9** |7.2 |**7.7** |**9.0** |**9.4** |0.41 |41.1 |
62 | 
63 | 
64 | ## Citation
65 | **Cite our paper by:**
66 | 
67 | **@article{XIBA202305016,**
68 | 
69 | **title={双路注意力循环网络的轻量化语音分离},**
70 | 
71 | **author={杨弋 and 胡琦 and 张鹏远},**
72 | 
73 | **journal={声学学报},**
74 | 
75 | **volume={48},**
76 | 
77 | **number={05},**
78 | 
79 | **pages={1060-1069},**
80 | 
81 | **year={2023},**
82 | 
83 | **doi={10.15949/j.cnki.0371-0025.2023.05.013}**
84 | 
85 | **}**
86 | 
87 | ## References
88 | 
89 | **[1] Dang F, Chen H T, Zhang P Y. DPT-FSNet: Dual-path Transformer Based Full-band and Sub-band Fusion Network for Speech Enhancement. Proc. IEEE
90 | Int. Conf. Acoust. Speech Signal Process., 2022: 6857—6861**
91 | 
92 | **[2] Chen H T, Zhang P Y. Beam-Guided TasNet: An Iterative Speech Separation Framework with Multi-Channel Output, 2021: arXiv preprint arXiv:
93 | 2102.02998**
94 | 
95 | **[3] Chen Z, Yoshioka T, Lu L et al. Continuous speech separation: dataset and analysis. Proc. IEEE Int. Conf. Acoust. Speech Signal Process., 2020:
96 | 7284—7288**
97 | 
98 | **[4] Chen S Y, Wu Y, Chen Z et al. Continuous Speech Separation with Conformer. Proc. IEEE Int. Conf. Acoust. Speech Signal Process., 2021: 5749—5753**
99 | 
100 | **[5] Zhang S Y, Li X F. Microphone Array Generalization for Multichannel Narrowband Deep Speech Enhancement. Proc. Interspeech, 2021: 666—670**
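**As a quick sanity check of the parameter counts in the Results table above, the model can be instantiated from the shipped config; a minimal sketch, assuming it is run from the DPARNet folder so that ```DPARNet.py``` and ```conf.yml``` are found:**

```python
import yaml
from DPARNet import make_model_and_optimizer

with open("conf.yml") as f:
    conf = yaml.safe_load(f)
model, _ = make_model_and_optimizer(conf)

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.2f} M parameters")  # about 0.15 M for DPARNet
```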
101 | 
102 | **Please feel free to contact us if you have any questions.**
103 | 
104 | [Paper]: https://kns.cnki.net/kcms2/article/abstract?v=8oX70opUeL_csRlg3CyHljRzX8_P7Bdf9bdDPZNM78u11yizB0xSj5i-PufxXLHIMnJFouMxrqzguCdi9UGX_E043B1UU4db444UqiIsQNMeEz9kNSMelZInznKw-fNdUkUA0G3tvXUBujQEyyNh8-C1qsGAhlXbn3NgZ6a3-wIVp3jhC1eHlRQuBs61r7rsnBT1-dlIXK47IfxkBBoBA4FDbV8uN9Qm&uniplatform=NZKPT&language=CHS
105 | [libricss]: https://github.com/chenzhuo1011/libri_css
106 | [asteroid]: https://github.com/asteroid-team/asteroid
107 | [sms_wsj]: https://github.com/fgnt/sms_wsj
108 | 
--------------------------------------------------------------------------------
/generate_rir.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | import numpy as np
5 | from tqdm import tqdm
6 | from multiprocessing import Pool
7 | import torch
8 | import random
9 | 
10 | import time
11 | from sms_wsj.database.create_rirs import config, scenarios, rirs
12 | from sms_wsj.reverb.reverb_utils import convolve
13 | 
14 | 
15 | def _worker_init_fn_(worker_id):
16 |     # per-process RNG seeding, same scheme as in train.py
17 |     torch_seed = torch.initial_seed()
18 |     random.seed(torch_seed + worker_id)
19 |     if torch_seed >= 2**32:
20 |         torch_seed = torch_seed % 2**32
21 |     np.random.seed(torch_seed + worker_id)
22 | 
23 | 
24 | def generate_rir(i):
25 |     _worker_init_fn_(i)
26 |     reverb_matrixs_dir = '/path/to/reverb-set/'
27 |     geometry, sound_decay_time_range, sample_rate, filter_length = config()
28 |     room_dimensions, source_positions, sensor_positions, sound_decay_time = scenarios(geometry, sound_decay_time_range,)
29 |     h = rirs(sample_rate, filter_length, room_dimensions, source_positions, sensor_positions, sound_decay_time)
30 |     np.savez(reverb_matrixs_dir + str(i).zfill(5) + '.npz', h=h, source_positions=source_positions, sensor_positions=sensor_positions,)
31 | 
32 | 
33 | def main(conf):
34 |     reverb_matrixs_dir = '/path/to/reverb-set/'
35 |     generate_NO = 10000
36 | 
37 |     # generate new rirs
38 |     if not os.path.exists(reverb_matrixs_dir):
39 |         os.makedirs(reverb_matrixs_dir)
40 |     else:
41 |         if (input('target dir already exists, continue? [y/n] ') == 'n'):
42 |             print('Exit. Nothing happens.')
43 |             sys.exit()
44 |     print('Generating reverb matrices into ', reverb_matrixs_dir, '......')
45 |     '''
46 |     # single process
47 |     pbar = tqdm(range(generate_NO))
48 |     for i in pbar:
49 |         generate_rir(i)
50 |     '''
51 |     # multi process
52 |     time_start = time.time()
53 |     pool = Pool(processes=32)
54 |     args = []
55 |     for i in range(generate_NO):
56 |         args.append(i)
57 |     pool.map(generate_rir, args)
58 |     pool.close()
59 |     pool.join()
60 |     time_end = time.time()
61 |     print('total time:', round((time_end - time_start) / 60), 'minutes')
62 | 
63 | 
64 | if __name__ == '__main__':
65 |     main(None)
--------------------------------------------------------------------------------
/sms_wsj_replace/create_rirs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sms_wsj.reverb.reverb_utils import generate_rir
3 | from sms_wsj.reverb.scenario import generate_random_source_positions
4 | from sms_wsj.reverb.scenario import generate_sensor_positions
5 | from sms_wsj.reverb.scenario import sample_from_random_box
6 | 
7 | def config():
8 |     # Either set it to zero or above 0.15 s. Otherwise, RIR contains NaN.
9 | sound_decay_time_range = dict(low=0.15, high=0.6) 10 | 11 | geometry = dict( 12 | number_of_sources=3, 13 | number_of_sensors=7, 14 | sensor_shape="circular_center", 15 | center=[[3.5], [3.], [1.5]], # m 16 | scale=0.0425, # m 17 | room=[[7.], [6.], [3.]], # m 18 | random_box=[[4.], [2.], [0.4]], # m 19 | ) 20 | 21 | sample_rate = 16000 22 | filter_length = 2 ** 14 # 1.024 seconds when sample_rate == 16000 23 | 24 | return geometry, sound_decay_time_range, sample_rate, filter_length 25 | 26 | def scenarios(geometry,sound_decay_time_range,): 27 | room_dimensions = sample_from_random_box(geometry["room"], geometry["random_box"]) 28 | center = sample_from_random_box(geometry["center"], geometry["random_box"]) 29 | source_positions = generate_random_source_positions(center=center,sources=geometry["number_of_sources"], dims=2) 30 | 31 | sensor_positions = generate_sensor_positions( 32 | shape=geometry["sensor_shape"], 33 | center=center, 34 | room_dimensions = room_dimensions, 35 | scale=geometry["scale"], 36 | number_of_sensors=geometry["number_of_sensors"], 37 | rotate_x=np.random.uniform(0, 0.01 * 2 * np.pi), 38 | rotate_y=np.random.uniform(0, 0.01 * 2 * np.pi), 39 | rotate_z=np.random.uniform(0, 2 * np.pi), 40 | ) 41 | sound_decay_time = np.random.uniform(**sound_decay_time_range) 42 | 43 | return room_dimensions, source_positions, sensor_positions, sound_decay_time 44 | 45 | def rirs(sample_rate, filter_length, room_dimensions, source_positions, sensor_positions, sound_decay_time): 46 | h = generate_rir( 47 | room_dimensions=room_dimensions, 48 | source_positions=source_positions, 49 | sensor_positions=sensor_positions, 50 | sound_decay_time=sound_decay_time, 51 | sample_rate=sample_rate, 52 | filter_length=filter_length, 53 | sensor_orientations=None, 54 | sensor_directivity=None, 55 | sound_velocity=343 56 | ) 57 | 58 | return h 59 | -------------------------------------------------------------------------------- /sms_wsj_replace/scenario.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helps to quickly create source and sensor positions. 3 | Try it with the following code: 4 | 5 | >>> import numpy as np 6 | >>> import sms_wsj.reverb.scenario as scenario 7 | >>> src = scenario.generate_random_source_positions(dims=2, sources=1000) 8 | >>> src[1, :] = np.abs(src[1, :]) 9 | >>> mic = scenario.generate_sensor_positions(shape='linear', scale=0.1, number_of_sensors=6) 10 | """ 11 | 12 | import numpy as np 13 | from sms_wsj.reverb.rotation import rot_x, rot_y, rot_z 14 | 15 | 16 | def sample_from_random_box(center, edge_lengths, rng=np.random): 17 | """ Sample from a random box to get somewhat random locations. 18 | 19 | >>> points = np.asarray([sample_from_random_box( 20 | ... [[10], [20], [30]], [[1], [2], [3]] 21 | ... ) for _ in range(1000)]) 22 | >>> import matplotlib.pyplot as plt 23 | >>> from mpl_toolkits.mplot3d import Axes3D 24 | >>> fig = plt.figure() 25 | >>> ax = fig.add_subplot(111, projection='3d') 26 | >>> _ = ax.scatter(points[:, 0, 0], points[:, 1, 0], points[:, 2, 0]) 27 | >>> _ = plt.show() 28 | 29 | Args: 30 | center: Original center (mean). 31 | edge_lengths: Edge length of the box to be sampled from. 
32 | 33 | Returns: 34 | 35 | """ 36 | center = np.asarray(center) 37 | edge_lengths = np.asarray(edge_lengths) 38 | return center + rng.uniform( 39 | low=-edge_lengths / 2, 40 | high=edge_lengths / 2 41 | ) 42 | 43 | def generate_sensor_positions( 44 | shape='cube', 45 | center=np.zeros((3, 1), dtype=np.float), 46 | room_dimensions = [[6], [4], [3]], 47 | scale=0.01, 48 | number_of_sensors=None, 49 | jitter=None, 50 | rng=np.random, 51 | rotate_x=0, rotate_y=0, rotate_z=0 52 | ): 53 | """ Generate different sensor configurations. 54 | 55 | Sensors are index counter-clockwise starting with the 0th sensor below 56 | the x axis. This is done, such that the first two sensors point towards 57 | the x axis. 58 | 59 | :param shape: A shape, i.e. 'cube', 'triangle', 'linear' or 'circular'. 60 | :param center: Numpy array with shape (3, 1) 61 | which holds coordinates x, y and z. 62 | :param scale: Scalar responsible for scale of the array. See individual 63 | implementations, if it is used as radius or edge length. 64 | :param jitter: Add random Gaussian noise with standard deviation ``jitter`` 65 | to sensor positions. 66 | :return: Numpy array with shape (3, number_of_sensors). 67 | """ 68 | 69 | center = np.array(center) 70 | if center.ndim == 1: 71 | center = center[:, None] 72 | 73 | if shape == 'cube': 74 | b = scale / 2 75 | sensor_positions = np.array([ 76 | [-b, -b, -b], 77 | [-b, -b, b], 78 | [-b, b, -b], 79 | [-b, b, b], 80 | [b, -b, -b], 81 | [b, -b, b], 82 | [b, b, -b], 83 | [b, b, b] 84 | ]).T 85 | 86 | elif shape == 'triangle': 87 | assert number_of_sensors == 3, ( 88 | "triangle is only defined for 3 sensors", 89 | number_of_sensors) 90 | sensor_positions = generate_sensor_positions( 91 | shape='circular', scale=scale, number_of_sensors=3, rng=rng 92 | ) 93 | 94 | elif shape == 'linear': 95 | sensor_positions = np.zeros((3, number_of_sensors), dtype=np.float) 96 | sensor_positions[1, :] = scale * np.arange(number_of_sensors) 97 | sensor_positions -= np.mean(sensor_positions, keepdims=True, axis=1) 98 | 99 | elif shape == 'circular': 100 | if number_of_sensors == 1: 101 | sensor_positions = np.zeros((3, 1), dtype=np.float) 102 | else: 103 | radius = scale 104 | delta_phi = 2 * np.pi / number_of_sensors 105 | phi_0 = delta_phi / 2 106 | phi = np.arange(0, number_of_sensors) * delta_phi - phi_0 107 | sensor_positions = np.asarray([ 108 | radius * np.cos(phi), 109 | radius * np.sin(phi), 110 | np.zeros(phi.shape) 111 | ]) 112 | 113 | elif shape == 'circular_center': 114 | radius = scale 115 | delta_phi = 2 * np.pi / (number_of_sensors - 1) 116 | phi_0 = delta_phi / 2 117 | phi = np.arange(0, number_of_sensors-1) * delta_phi - phi_0 118 | sensor_positions_cir = np.asarray([ 119 | radius * np.cos(phi), 120 | radius * np.sin(phi), 121 | np.zeros(phi.shape) 122 | ]) 123 | sensor_positions_cen = np.asarray([ 124 | [0], 125 | [0], 126 | [0] 127 | ]) 128 | sensor_positions = np.hstack([sensor_positions_cen, sensor_positions_cir]) 129 | 130 | elif shape == 'chime3': 131 | assert scale is None, scale 132 | assert ( 133 | number_of_sensors is None or number_of_sensors == 6 134 | ), number_of_sensors 135 | 136 | sensor_positions = np.asarray( 137 | [ 138 | [-0.1, 0, 0.1, -0.1, 0, 0.1], 139 | [0.095, 0.095, 0.095, -0.095, -0.095, -0.095], 140 | [0, -0.02, 0, 0, 0, 0] 141 | ] 142 | ) 143 | 144 | else: 145 | raise NotImplementedError('Given shape is not implemented.') 146 | 147 | # NOTE rotation 148 | #sensor_positions = rot_x(rotate_x) @ sensor_positions 149 | #sensor_positions = rot_y(rotate_y) @ 
sensor_positions
150 |     #sensor_positions = rot_z(rotate_z) @ sensor_positions
151 | 
152 |     if jitter is not None:
153 |         sensor_positions += rng.normal(
154 |             0., jitter, size=sensor_positions.shape
155 |         )
156 | 
157 |     return np.asarray(sensor_positions + center)
158 | 
159 | def generate_random_source_positions(
160 |         center=np.zeros((3, 1)),
161 |         sources=1,
162 |         distance_interval=(0.3, 2.5),
163 |         dims=2,
164 |         minimum_angular_distance=None,
165 |         maximum_angular_distance=None,
166 |         rng=np.random
167 | ):
168 |     """ Generates random positions on a hollow sphere or circle.
169 | 
170 |     Samples are drawn from a uniform distribution on a hollow sphere with
171 |     inner and outer radius according to distance_interval.
172 | 
173 |     The idea is to sample from an angular centric Gaussian distribution.
174 | 
175 |     Params:
176 |         center
177 |         sources
178 |         distance_interval
179 |         dims
180 |         minimum_angular_distance: In radians or None.
181 |         maximum_angular_distance: In radians or None.
182 |         rng: Random number generator, if you need to set the seed.
183 |     """
184 |     enforce_angular_constrains = (
185 |         minimum_angular_distance is not None or
186 |         maximum_angular_distance is not None
187 |     )
188 | 
189 |     if not dims == 2 and enforce_angular_constrains:
190 |         raise NotImplementedError(
191 |             'Only implemented distance constraints for 2D.'
192 |         )
193 | 
194 |     accept = False
195 |     while not accept:
196 |         x = rng.normal(size=(3, sources))
197 |         if dims == 2:
198 |             x[2, :] = 0
199 | 
200 |         if enforce_angular_constrains:
201 |             if not sources == 2:
202 |                 raise NotImplementedError
203 |             angle = np.arctan2(x[1, :], x[0, :])
204 |             difference = np.angle(
205 |                 np.exp(1j * (angle[None, :] - angle[:, None])))
206 |             difference = difference[np.triu_indices_from(difference, k=1)]
207 |             distance = np.abs(difference)
208 |             if (
209 |                 minimum_angular_distance is not None and
210 |                 minimum_angular_distance > np.min(distance)
211 |             ):
212 |                 continue
213 |             if (
214 |                 maximum_angular_distance is not None and
215 |                 maximum_angular_distance < np.max(distance)
216 |             ):
217 |                 continue
218 |         accept = True
219 | 
220 |     x /= np.linalg.norm(x, axis=0)  # normalize to unit direction vectors
221 | 
222 |     radius = rng.uniform(
223 |         distance_interval[0] ** dims,
224 |         distance_interval[1] ** dims,
225 |         size=(1, sources)
226 |     ) ** (1 / dims)
227 | 
228 |     x *= radius
229 | 
230 |     return np.asarray(x + center)
231 | 
--------------------------------------------------------------------------------
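A usage sketch for the sampler above (the center matches the geometry in create_rirs.py; the seed and the two-source count are arbitrary; the import works once this file replaces sms_wsj/reverb/scenario.py):

```python
import numpy as np
from sms_wsj.reverb.scenario import generate_random_source_positions

rng = np.random.RandomState(0)
center = np.array([[3.5], [3.0], [1.5]])  # room center, m

srcs = generate_random_source_positions(center=center, sources=2, dims=2, rng=rng)
print(srcs.shape)                             # (3, 2): one x/y/z column per source
print(np.linalg.norm(srcs - center, axis=0))  # radii lie in [0.3, 2.5] m
```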