├── FaSNet.py ├── README.md ├── data ├── README.md ├── configs │ ├── MC_Libri_adhoc_test.pkl │ ├── MC_Libri_adhoc_train.pkl │ ├── MC_Libri_adhoc_validation.pkl │ ├── MC_Libri_fixed_test.pkl │ ├── MC_Libri_fixed_train.pkl │ └── MC_Libri_fixed_validation.pkl └── create_dataset.py ├── flowchart.png ├── iFaSNet.py └── utility ├── __init__.py ├── models.py └── sdr.py /FaSNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import os 6 | import numpy as np 7 | 8 | from utility.models import * 9 | 10 | # DPRNN for beamforming filter estimation 11 | class BF_module(DPRNN_base): 12 | def __init__(self, *args, **kwargs): 13 | super(BF_module, self).__init__(*args, **kwargs) 14 | 15 | # gated output layer 16 | self.output = nn.Sequential(nn.Conv1d(self.feature_dim, self.output_dim, 1), 17 | nn.Tanh() 18 | ) 19 | self.output_gate = nn.Sequential(nn.Conv1d(self.feature_dim, self.output_dim, 1), 20 | nn.Sigmoid() 21 | ) 22 | 23 | def forward(self, input, num_mic): 24 | 25 | if self.model_type == 'DPRNN': 26 | # input: (B, N, T) 27 | batch_size, N, seq_length = input.shape 28 | ch = 1 29 | elif self.model_type == 'DPRNN_TAC': 30 | # input: (B, ch, N, T) 31 | batch_size, ch, N, seq_length = input.shape 32 | 33 | input = input.view(batch_size*ch, N, seq_length) # B*ch, N, T 34 | enc_feature = self.BN(input) 35 | 36 | # split the encoder output into overlapped, longer segments 37 | enc_segments, enc_rest = self.split_feature(enc_feature, self.segment_size) # B*ch, N, L, K 38 | 39 | # pass to DPRNN 40 | if self.model_type == 'DPRNN': 41 | output = self.DPRNN(enc_segments).view(batch_size*ch*self.num_spk, self.feature_dim, self.segment_size, -1) # B*ch*nspk, N, L, K 42 | elif self.model_type == 'DPRNN_TAC': 43 | enc_segments = enc_segments.view(batch_size, ch, -1, enc_segments.shape[2], enc_segments.shape[3]) # B, ch, N, L, K 44 | output = self.DPRNN(enc_segments, num_mic).view(batch_size*ch*self.num_spk, self.feature_dim, self.segment_size, -1) # B*ch*nspk, N, L, K 45 | 46 | # overlap-and-add of the outputs 47 | output = self.merge_feature(output, enc_rest) # B*ch*nspk, N, T 48 | 49 | # gated output layer for filter generation 50 | bf_filter = self.output(output) * self.output_gate(output) # B*ch*nspk, K, T 51 | bf_filter = bf_filter.transpose(1, 2).contiguous().view(batch_size, ch, self.num_spk, -1, self.output_dim) # B, ch, nspk, L, N 52 | 53 | return bf_filter 54 | 55 | 56 | # base module for FaSNet 57 | class FaSNet_base(nn.Module): 58 | def __init__(self, enc_dim, feature_dim, hidden_dim, layer, segment_size=50, 59 | nspk=2, win_len=4, context_len=16, sr=16000): 60 | super(FaSNet_base, self).__init__() 61 | 62 | # parameters 63 | self.window = int(sr * win_len / 1000) 64 | self.context = int(sr * context_len / 1000) 65 | self.stride = self.window // 2 66 | 67 | self.filter_dim = self.context*2+1 68 | self.enc_dim = enc_dim 69 | self.feature_dim = feature_dim 70 | self.hidden_dim = hidden_dim 71 | self.segment_size = segment_size 72 | 73 | self.layer = layer 74 | self.num_spk = nspk 75 | self.eps = 1e-8 76 | 77 | # waveform encoder 78 | self.encoder = nn.Conv1d(1, self.enc_dim, self.context*2+self.window, bias=False) 79 | self.enc_LN = nn.GroupNorm(1, self.enc_dim, eps=1e-8) 80 | 81 | def pad_input(self, input, window): 82 | """ 83 | Zero-padding input according to window/stride size. 
84 | """ 85 | batch_size, nmic, nsample = input.shape 86 | stride = window // 2 87 | 88 | # pad the signals at the end for matching the window/stride size 89 | rest = window - (stride + nsample % window) % window 90 | if rest > 0: 91 | pad = torch.zeros(batch_size, nmic, rest).type(input.type()) 92 | input = torch.cat([input, pad], 2) 93 | pad_aux = torch.zeros(batch_size, nmic, stride).type(input.type()) 94 | input = torch.cat([pad_aux, input, pad_aux], 2) 95 | 96 | return input, rest 97 | 98 | 99 | def seg_signal_context(self, x, window, context): 100 | """ 101 | Segmenting the signal into chunks with specific context. 102 | input: 103 | x: size (B, ch, T) 104 | window: int 105 | context: int 106 | 107 | """ 108 | 109 | # pad input accordingly 110 | # first pad according to window size 111 | input, rest = self.pad_input(x, window) 112 | batch_size, nmic, nsample = input.shape 113 | stride = window // 2 114 | 115 | # pad another context size 116 | pad_context = torch.zeros(batch_size, nmic, context).type(input.type()) 117 | input = torch.cat([pad_context, input, pad_context], 2) # B, ch, L 118 | 119 | # calculate index for each chunk 120 | nchunk = 2*nsample // window - 1 121 | begin_idx = np.arange(nchunk)*stride 122 | begin_idx = torch.from_numpy(begin_idx).type(input.type()).long().view(1, 1, -1) # 1, 1, nchunk 123 | begin_idx = begin_idx.expand(batch_size, nmic, nchunk) # B, ch, nchunk 124 | # select entries from index 125 | chunks = [torch.gather(input, 2, begin_idx+i).unsqueeze(3) for i in range(2*context + window)] # B, ch, nchunk, 1 126 | chunks = torch.cat(chunks, 3) # B, ch, nchunk, chunk_size 127 | 128 | # center frame 129 | center_frame = chunks[:,:,:,context:context+window] 130 | 131 | return center_frame, chunks, rest 132 | 133 | def seq_cos_sim(self, ref, target): 134 | """ 135 | Cosine similarity between some reference mics and some target mics 136 | ref: shape (nmic1, L, seg1) 137 | target: shape (nmic2, L, seg2) 138 | """ 139 | 140 | assert ref.size(1) == target.size(1), "Inputs should have same length." 141 | assert ref.size(2) >= target.size(2), "Reference input should be no smaller than the target input." 142 | 143 | seq_length = ref.size(1) 144 | 145 | larger_ch = ref.size(0) 146 | if target.size(0) > ref.size(0): 147 | ref = ref.expand(target.size(0), ref.size(1), ref.size(2)).contiguous() # nmic2, L, seg1 148 | larger_ch = target.size(0) 149 | elif target.size(0) < ref.size(0): 150 | target = target.expand(ref.size(0), target.size(1), target.size(2)).contiguous() # nmic1, L, seg2 151 | 152 | # L2 norms 153 | ref_norm = F.conv1d(ref.view(1, -1, ref.size(2)).pow(2), 154 | torch.ones(ref.size(0)*ref.size(1), 1, target.size(2)).type(ref.type()), 155 | groups=larger_ch*seq_length) # 1, larger_ch*L, seg1-seg2+1 156 | ref_norm = ref_norm.sqrt() + self.eps 157 | target_norm = target.norm(2, dim=2).view(1, -1, 1) + self.eps # 1, larger_ch*L, 1 158 | # cosine similarity 159 | cos_sim = F.conv1d(ref.view(1, -1, ref.size(2)), 160 | target.view(-1, 1, target.size(2)), 161 | groups=larger_ch*seq_length) # 1, larger_ch*L, seg1-seg2+1 162 | cos_sim = cos_sim / (ref_norm * target_norm) 163 | 164 | return cos_sim.view(larger_ch, seq_length, -1) 165 | 166 | def forward(self, input, num_mic): 167 | """ 168 | input: shape (batch, max_num_ch, T) 169 | num_mic: shape (batch, ), the number of channels for each input. Zero for fixed geometry configuration. 
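        output: shape (batch, nspk, T); forward() is left unimplemented here and is provided by the FaSNet_origin / FaSNet_TAC subclasses, which return the separated waveforms in this shape.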
170 | """ 171 | pass 172 | 173 | 174 | 175 | # original FaSNet 176 | class FaSNet_origin(FaSNet_base): 177 | def __init__(self, *args, **kwargs): 178 | super(FaSNet_origin, self).__init__(*args, **kwargs) 179 | 180 | # DPRNN for ref mic 181 | self.ref_BF = BF_module(self.filter_dim+self.enc_dim, self.feature_dim, self.hidden_dim, 182 | self.filter_dim, self.num_spk, self.layer, self.segment_size, model_type='DPRNN') 183 | 184 | # DPRNN for other mics 185 | self.other_BF = BF_module(self.filter_dim+self.enc_dim, self.feature_dim, self.hidden_dim, 186 | self.filter_dim, 1, self.layer, self.segment_size, model_type='DPRNN') 187 | 188 | 189 | def forward(self, input, num_mic): 190 | 191 | batch_size = input.size(0) 192 | nmic = input.size(1) 193 | 194 | # split input into chunks 195 | all_seg, all_mic_context, rest = self.seg_signal_context(input, self.window, self.context) # B, nmic, L, win/chunk 196 | seq_length = all_seg.size(2) 197 | 198 | # first step: filtering the ref mic to create a clean estimate 199 | # calculate cosine similarity 200 | ref_context = all_mic_context[:,0].contiguous().view(1, -1, self.context*2+self.window) # 1, B*L, 3*win 201 | other_segment = all_seg[:,1:].contiguous().transpose(0, 1).contiguous().view(nmic-1, -1, self.window) # nmic-1, B*L, win 202 | ref_cos_sim = self.seq_cos_sim(ref_context, other_segment) # nmic-1, B*L, 2*win+1 203 | ref_cos_sim = ref_cos_sim.view(nmic-1, batch_size, seq_length, self.filter_dim) # nmic-1, B, L, 2*win+1 204 | if num_mic.max() == 0: 205 | ref_cos_sim = ref_cos_sim.mean(0) # B, L, 2*win+1 206 | ref_cos_sim = ref_cos_sim.transpose(1, 2).contiguous() # B, 2*win+1, L 207 | else: 208 | # consider only the valid channels 209 | ref_cos_sim = [ref_cos_sim[:num_mic[b],b,:].mean(0).unsqueeze(0) for b in range(batch_size)] # 1, L, 2*win+1 210 | ref_cos_sim = torch.cat(ref_cos_sim, 0).transpose(1, 2).contiguous() # B, 2*win+1, L 211 | 212 | 213 | # pass to a DPRNN 214 | ref_feature = all_mic_context[:,0].contiguous().view(batch_size*seq_length, 1, self.context*2+self.window) 215 | ref_feature = self.encoder(ref_feature) # B*L, N, 1 216 | ref_feature = ref_feature.view(batch_size, seq_length, self.enc_dim).transpose(1, 2).contiguous() # B, N, L 217 | ref_filter = self.ref_BF(torch.cat([self.enc_LN(ref_feature), ref_cos_sim], 1), num_mic) # B, 1, nspk, L, 2*win+1 218 | 219 | # convolve with ref mic context segments 220 | ref_context = torch.cat([all_mic_context[:,0].unsqueeze(1)]*self.num_spk, 1) # B, nspk, L, 3*win 221 | ref_output = F.conv1d(ref_context.view(1, -1, self.context*2+self.window), 222 | ref_filter.view(-1, 1, self.filter_dim), 223 | groups=batch_size*self.num_spk*seq_length) # 1, B*nspk*L, win 224 | ref_output = ref_output.view(batch_size*self.num_spk, seq_length, self.window) # B*nspk, L, win 225 | 226 | # second step: use the ref output as the cue, beamform other mics 227 | # calculate cosine similarity 228 | other_context = torch.cat([all_mic_context[:,1:].unsqueeze(1)]*self.num_spk, 1) # B, nspk, nmic-1, L, 3*win 229 | other_context_saved = other_context.view(batch_size*self.num_spk, nmic-1, seq_length, self.context*2+self.window) # B*nspk, nmic-1, L, 3*win 230 | other_context = other_context_saved.transpose(0, 1).contiguous().view(nmic-1, -1, self.context*2+self.window) # nmic-1, B*nspk*L, 3*win 231 | ref_segment = ref_output.view(1, -1, self.window) # 1, B*nspk*L, win 232 | other_cos_sim = self.seq_cos_sim(other_context, ref_segment) # nmic-1, B*nspk*L, 2*win+1 233 | other_cos_sim = other_cos_sim.view(nmic-1, 
batch_size*self.num_spk, seq_length, self.filter_dim) # nmic-1, B*nspk, L, 2*win+1 234 | other_cos_sim = other_cos_sim.permute(1,0,3,2).contiguous().view(-1, self.filter_dim, seq_length) # B*nspk*(nmic-1), 2*win+1, L 235 | 236 | # pass to another DPRNN 237 | other_feature = self.encoder(other_context_saved.view(-1, 1, self.context*2+self.window)).view(-1, seq_length, self.enc_dim) # B*nspk*(nmic-1), L, N 238 | other_feature = other_feature.transpose(1, 2).contiguous() # B*nspk*(nmic-1), N, L 239 | other_filter = self.other_BF(torch.cat([self.enc_LN(other_feature), other_cos_sim], 1), num_mic) # B*nspk*(nmic-1), 1, 1, L, 2*win+1 240 | 241 | # convolve with other mic context segments 242 | other_output = F.conv1d(other_context_saved.view(1, -1, self.context*2+self.window), 243 | other_filter.view(-1, 1, self.filter_dim), 244 | groups=batch_size*self.num_spk*(nmic-1)*seq_length) # 1, B*nspk*(nmic-1)*L, win 245 | other_output = other_output.view(batch_size*self.num_spk, nmic-1, seq_length, self.window) # B*nspk, nmic-1, L, win 246 | 247 | all_bf_output = torch.cat([ref_output.unsqueeze(1), other_output], 1) # B*nspk, nmic, L, win 248 | 249 | # reshape to utterance 250 | bf_signal = all_bf_output.view(batch_size*self.num_spk*nmic, -1, self.window*2) 251 | bf_signal1 = bf_signal[:,:,:self.window].contiguous().view(batch_size*self.num_spk*nmic, 1, -1)[:,:,self.stride:] 252 | bf_signal2 = bf_signal[:,:,self.window:].contiguous().view(batch_size*self.num_spk*nmic, 1, -1)[:,:,:-self.stride] 253 | bf_signal = bf_signal1 + bf_signal2 # B*nspk*nmic, 1, T 254 | if rest > 0: 255 | bf_signal = bf_signal[:,:,:-rest] 256 | 257 | bf_signal = bf_signal.view(batch_size, self.num_spk, nmic, -1) # B, nspk, nmic, T 258 | # consider only the valid channels 259 | if num_mic.max() == 0: 260 | bf_signal = bf_signal.mean(2) # B, nspk, T 261 | else: 262 | bf_signal = [bf_signal[b,:,:num_mic[b]].mean(1).unsqueeze(0) for b in range(batch_size)] # nspk, T 263 | bf_signal = torch.cat(bf_signal, 0) # B, nspk, T 264 | 265 | return bf_signal 266 | 267 | # single-stage FaSNet + TAC 268 | class FaSNet_TAC(FaSNet_base): 269 | def __init__(self, *args, **kwargs): 270 | super(FaSNet_TAC, self).__init__(*args, **kwargs) 271 | 272 | # DPRNN + TAC for estimation 273 | self.all_BF = BF_module(self.filter_dim+self.enc_dim, self.feature_dim, self.hidden_dim, 274 | self.filter_dim, self.num_spk, self.layer, self.segment_size, model_type='DPRNN_TAC') 275 | 276 | def forward(self, input, num_mic): 277 | 278 | batch_size = input.size(0) 279 | nmic = input.size(1) 280 | 281 | # split input into chunks 282 | all_seg, all_mic_context, rest = self.seg_signal_context(input, self.window, self.context) # B, nmic, L, win/chunk 283 | seq_length = all_seg.size(2) 284 | 285 | # embeddings for all channels 286 | enc_output = self.encoder(all_mic_context.view(-1, 1, self.context*2+self.window)).view(batch_size*nmic, seq_length, self.enc_dim).transpose(1, 2).contiguous() # B*nmic, N, L 287 | enc_output = self.enc_LN(enc_output).view(batch_size, nmic, self.enc_dim, seq_length) # B, nmic, N, L 288 | 289 | # calculate the cosine similarities for ref channel's center frame with all channels' context 290 | 291 | ref_seg = all_seg[:,0].contiguous().view(1, -1, self.window) # 1, B*L, win 292 | all_context = all_mic_context.transpose(0, 1).contiguous().view(nmic, -1, self.context*2+self.window) # 1, B*L, 3*win 293 | all_cos_sim = self.seq_cos_sim(all_context, ref_seg) # nmic, B*L, 2*win+1 294 | all_cos_sim = all_cos_sim.view(nmic, batch_size, seq_length, 
self.filter_dim).permute(1,0,3,2).contiguous() # B, nmic, 2*win+1, L 295 | 296 | input_feature = torch.cat([enc_output, all_cos_sim], 2) # B, nmic, N+2*win+1, L 297 | 298 | # pass to DPRNN 299 | all_filter = self.all_BF(input_feature, num_mic) # B, ch, nspk, L, 2*win+1 300 | 301 | # convolve with all mic's context 302 | mic_context = torch.cat([all_mic_context.view(batch_size*nmic, 1, seq_length, 303 | self.context*2+self.window)]*self.num_spk, 1) # B*nmic, nspk, L, 3*win 304 | all_bf_output = F.conv1d(mic_context.view(1, -1, self.context*2+self.window), 305 | all_filter.view(-1, 1, self.filter_dim), 306 | groups=batch_size*nmic*self.num_spk*seq_length) # 1, B*nmic*nspk*L, win 307 | all_bf_output = all_bf_output.view(batch_size, nmic, self.num_spk, seq_length, self.window) # B, nmic, nspk, L, win 308 | 309 | # reshape to utterance 310 | bf_signal = all_bf_output.view(batch_size*nmic*self.num_spk, -1, self.window*2) 311 | bf_signal1 = bf_signal[:,:,:self.window].contiguous().view(batch_size*nmic*self.num_spk, 1, -1)[:,:,self.stride:] 312 | bf_signal2 = bf_signal[:,:,self.window:].contiguous().view(batch_size*nmic*self.num_spk, 1, -1)[:,:,:-self.stride] 313 | bf_signal = bf_signal1 + bf_signal2 # B*nmic*nspk, 1, T 314 | if rest > 0: 315 | bf_signal = bf_signal[:,:,:-rest] 316 | 317 | bf_signal = bf_signal.view(batch_size, nmic, self.num_spk, -1) # B, nmic, nspk, T 318 | # consider only the valid channels 319 | if num_mic.max() == 0: 320 | bf_signal = bf_signal.mean(1) # B, nspk, T 321 | else: 322 | bf_signal = [bf_signal[b,:num_mic[b]].mean(0).unsqueeze(0) for b in range(batch_size)] # nspk, T 323 | bf_signal = torch.cat(bf_signal, 0) # B, nspk, T 324 | 325 | return bf_signal 326 | 327 | 328 | def test_model(model): 329 | x = torch.rand(2, 4, 32000) # (batch, num_mic, length) 330 | num_mic = torch.from_numpy(np.array([3, 2])).view(-1,).type(x.type()) # ad-hoc array 331 | none_mic = torch.zeros(1).type(x.type()) # fixed-array 332 | y1 = model(x, num_mic.long()) 333 | y2 = model(x, none_mic.long()) 334 | print(y1.shape, y2.shape) # (batch, nspk, length) 335 | 336 | 337 | if __name__ == "__main__": 338 | model_origin = FaSNet_origin(enc_dim=64, feature_dim=64, hidden_dim=128, layer=6, segment_size=50, 339 | nspk=2, win_len=4, context_len=16, sr=16000) 340 | 341 | model_TAC = FaSNet_TAC(enc_dim=64, feature_dim=64, hidden_dim=128, layer=4, segment_size=50, 342 | nspk=2, win_len=4, context_len=16, sr=16000) 343 | 344 | test_model(model_origin) 345 | test_model(model_TAC) 346 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Creative Commons License
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 3.0 United States License. 2 | 3 | # Transform-average-concatenate (TAC) for end-to-end microphone permutation and number invariant multi-channel speech separation 4 | 5 | This repository provides the model implementation and dataset generation scripts for the paper ["End-to-end Microphone Permutation and Number Invariant Multi-channel Speech Separation"](https://arxiv.org/abs/1910.14104) by Yi Luo, Zhuo Chen, Nima Mesgarani and Takuya Yoshioka. The paper introduces ***transform-average-concatenate (TAC)***, a simple module that allows end-to-end multi-channel separation systems to be invariant to microphone permutation (indexing) and number. Although designed for ad-hoc array configurations, TAC also provides a significant performance improvement in fixed geometry microphone configurations, showing that it can serve as a general design paradigm for end-to-end multi-channel processing systems. 6 | 7 | ## Model 8 | 9 | We implement TAC in the framework of the ***filter-and-sum network (FaSNet)***, a recently proposed multi-channel speech separation model operating in the time domain. FaSNet is a neural beamformer that performs standard filter-and-sum beamforming in the time domain, while the beamforming coefficients are estimated by a neural network in an end-to-end fashion. For details please refer to the original paper: ["FaSNet: Low-latency Adaptive Beamforming for Multi-microphone Audio Processing"](https://arxiv.org/abs/1909.13387). 10 | 11 | In this paper we make two main modifications to the original FaSNet: 12 | 1) Instead of the original two-stage architecture, we change it into a single-stage architecture. 13 | 2) TAC is applied throughout the filter estimation module to synchronize the information across different microphones and to allow the model to make *global* decisions while estimating the filter coefficients. 14 | 15 | The figure below shows different designs of FaSNet models. 16 | ![](https://github.com/yluo42/TAC/blob/master/flowchart.png) 17 | 18 | The building blocks for the filter estimation modules are based on ***dual-path RNNs (DPRNNs)***, a simple yet effective method for organizing RNN layers that allows successful modeling of extremely long sequential data. For details about DPRNN please refer to ["Dual-path RNN: efficient long sequence modeling for time-domain single-channel speech separation"](https://arxiv.org/abs/1910.06379). The implementation of DPRNN, as well as the combination of DPRNN and TAC, can be found in [*utility/models*](https://github.com/yluo42/TAC/blob/master/utility/models.py). 19 | 20 | ## Dataset 21 | 22 | The model is evaluated on both ad-hoc array and fixed geometry array configurations. We simulate two datasets based on the publicly available Librispeech corpus. For data generation please refer to the *data* folder. 23 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Data generation 3 | 4 | The simulated datasets are based on the [Librispeech](http://www.openslr.org/12) corpus and the [100 Nonspeech Sounds](http://web.cse.ohio-state.edu/pnl/corpus/HuNonspeech/HuCorpus.html) corpus. 5 | 6 | **Update 2021.06.16:** It is observed that several utterances might encounter clipping when saving the waveforms to files. 
Since this is also a possible case in real-world communications, I do not force the waveforms to be rescaled, but instead add an option to the generation script for choosing whether to avoid clipping. 7 | 8 | **Update 2019.11.10:** The configuration files and the data generation script have been updated. Please make sure you are using the most recent files for data generation. 9 | 10 | ## Raw data download 11 | 12 | 1) Download the *train-clean-100*, *dev-clean* and *test-clean* data from Librispeech's website and unzip them into any directory. The absolute path for the directory is denoted as *libri_path*, which should contain 3 subfolders *train-clean-100*, *dev-clean* and *test-clean*. 13 | 2) Download the 100 Nonspeech Sounds data and unzip it into any directory. The absolute path for the directory is denoted as *noise_path*. 14 | 3) Download or clone this repository. 15 | 16 | ## Additional Python packages 17 | 18 | - [soundfile](https://pypi.org/project/SoundFile/)==0.10.0 19 | - [gpuRIR](https://github.com/DavidDiazGuerra/gpuRIR) (it does not provide version information, but the latest build should be fine) 20 | 21 | ## Dataset generation 22 | 23 | Run `python create_dataset.py --output-path=your_output_path --avoid-clipping=0 --dataset='adhoc' --libri-path=libri_path --noise-path=noise_path`, where: 24 | 1) *output_path*: the absolute path for saving the output. Default is empty, which uses the current directory as the output path. 25 | 2) *avoid_clipping*: whether to avoid clipping when saving the waveforms to files. 0 (default): keep the original scale. 1: rescale to avoid clipping. 26 | 3) *dataset*: the dataset to generate. It can only be *'adhoc'* or *'fixed'*. 27 | 4) *libri_path*: the absolute path for the Librispeech data. 28 | 5) *noise_path*: the absolute path for the noise data. 
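For reference, `create_dataset.py` writes each utterance to `<output_path>/MC_Libri_<dataset>/<split>/<num_mic>mic/sample<idx>/` as `mixture_mic*.wav`, `spk1_mic*.wav` and `spk2_mic*.wav`. The snippet below is a minimal sketch of how a generated sample could be read back into the `(num_mic, T)` layout the models expect; the split, microphone count and sample index used here are placeholders.

```python
import os
import numpy as np
import soundfile as sf

def load_sample(root, dataset='adhoc', split='train', num_mic=4, sample_idx=1):
    # folder layout produced by create_dataset.py:
    # <root>/MC_Libri_<dataset>/<split>/<num_mic>mic/sample<idx>/{mixture,spk1,spk2}_mic<k>.wav
    sample_dir = os.path.join(root, 'MC_Libri_' + dataset, split,
                              '%dmic' % num_mic, 'sample%d' % sample_idx)
    mixture, sources = [], []
    for mic in range(1, num_mic + 1):
        mix, sr = sf.read(os.path.join(sample_dir, 'mixture_mic%d.wav' % mic))
        spk1, _ = sf.read(os.path.join(sample_dir, 'spk1_mic%d.wav' % mic))
        spk2, _ = sf.read(os.path.join(sample_dir, 'spk2_mic%d.wav' % mic))
        mixture.append(mix)
        sources.append(np.stack([spk1, spk2], 0))
    # mixture: (num_mic, T), sources: (num_mic, nspk, T); T is 4 s at 16 kHz
    return np.stack(mixture, 0), np.stack(sources, 0), sr
```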
29 | -------------------------------------------------------------------------------- /data/configs/MC_Libri_adhoc_test.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/data/configs/MC_Libri_adhoc_test.pkl -------------------------------------------------------------------------------- /data/configs/MC_Libri_adhoc_train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/data/configs/MC_Libri_adhoc_train.pkl -------------------------------------------------------------------------------- /data/configs/MC_Libri_adhoc_validation.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/data/configs/MC_Libri_adhoc_validation.pkl -------------------------------------------------------------------------------- /data/configs/MC_Libri_fixed_test.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/data/configs/MC_Libri_fixed_test.pkl -------------------------------------------------------------------------------- /data/configs/MC_Libri_fixed_train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/data/configs/MC_Libri_fixed_train.pkl -------------------------------------------------------------------------------- /data/configs/MC_Libri_fixed_validation.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/data/configs/MC_Libri_fixed_validation.pkl -------------------------------------------------------------------------------- /data/create_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import signal 3 | import os 4 | import soundfile as sf 5 | import pickle 6 | import argparse 7 | import gpuRIR 8 | 9 | # generate audio files 10 | def generate_data(output_path='', avoid_clipping=0, dataset='adhoc', libri_path='/home/yi/data/Librispeech', noise_path='/home/yi/data/Nonspeech'): 11 | assert dataset in ['adhoc', 'fixed'], "dataset can only be adhoc or fixed." 
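# each entry in the config pickle files describes one utterance; the fields consumed below are
# 'speech' (two Librispeech relative paths), 'noise', 'overlap_ratio', 'start_idx', 'spk_snr',
# 'noise_snr', 'mic_pos', 'spk_pos', 'noise_pos', 'room_size', 'RT60' and, for the adhoc
# configuration only, 'num_mic'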
12 | 13 | if output_path == '': 14 | output_path = os.getcwd() 15 | 16 | data_type = ['train', 'validation', 'test'] 17 | for i in range(len(data_type)): 18 | # path for config 19 | config_path = os.path.join('configs', 'MC_Libri_'+dataset+'_'+data_type[i]+'.pkl') 20 | 21 | # load pickle file 22 | with open(config_path, 'rb') as f: 23 | configs = pickle.load(f) 24 | 25 | # sample rate is 16k Hz 26 | sr = 16000 27 | # signal length is 4 sec 28 | sig_len = 4 29 | 30 | for utt in range(len(configs)): 31 | this_config = configs[utt] 32 | 33 | # load audio files 34 | speakers = this_config['speech'] 35 | noise = this_config['noise'] 36 | spk1, _ = sf.read(os.path.join(libri_path, speakers[0])) 37 | spk2, _ = sf.read(os.path.join(libri_path, speakers[1])) 38 | noise, _ = sf.read(os.path.join(noise_path, noise)) 39 | 40 | # calculate signal length according to overlap ratio 41 | overlap_ratio = this_config['overlap_ratio'] 42 | actual_len = int(sig_len / (2 - overlap_ratio) * sr) 43 | overlap = int(actual_len*overlap_ratio) 44 | 45 | # truncate speech according to start and end indexes 46 | start_idx = this_config['start_idx'] 47 | end_idx = start_idx + actual_len 48 | spk1 = spk1[start_idx:end_idx] 49 | spk2 = spk2[start_idx:end_idx] 50 | 51 | # rescaling speaker and noise energy according to relative SNR 52 | spk1 = spk1 / np.sqrt(np.sum(spk1**2)+1e-8) * 1e2 53 | spk2 = spk2 / np.sqrt(np.sum(spk2**2)+1e-8) * 1e2 54 | spk2 = spk2 * np.power(10, this_config['spk_snr']/20.) 55 | # repeat noise if necessary 56 | noise = noise[:int(sig_len*sr)] 57 | if len(noise) < int(sig_len*sr): 58 | num_repeat = int(sig_len*sr) // len(noise) 59 | res = int(sig_len*sr) - num_repeat * len(noise) 60 | noise = np.concatenate([np.concatenate([noise]*num_repeat), noise[:res]]) 61 | # rescale noise energy w.r.t mixture energy 62 | noise = noise / np.sqrt(np.sum(noise**2)+1e-8) * np.sqrt(np.sum((spk1+spk2)**2)+1e-8) 63 | noise = noise / np.power(10, this_config['noise_snr']/20.) 
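# net effect of the rescaling above: spk1 and spk2 are first normalized to equal energy,
# spk2 is then offset by spk_snr dB relative to spk1, and the noise is scaled to sit
# noise_snr dB below the energy of the clean (spk1 + spk2) mixture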
64 | 65 | # load locations and room configs 66 | mic_pos = np.asarray(this_config['mic_pos']) 67 | spk_pos = np.asarray(this_config['spk_pos']) 68 | noise_pos = np.asarray(this_config['noise_pos']) 69 | room_size = np.asarray(this_config['room_size']) 70 | rt60 = this_config['RT60'] 71 | num_mic = len(mic_pos) 72 | 73 | # generate RIR 74 | beta = gpuRIR.beta_SabineEstimation(room_size, rt60) 75 | nb_img = gpuRIR.t2n(rt60, room_size) 76 | spk_rir = gpuRIR.simulateRIR(room_size, beta, spk_pos, mic_pos, nb_img, rt60, sr) 77 | noise_rir = gpuRIR.simulateRIR(room_size, beta, noise_pos, mic_pos, nb_img, rt60, sr) 78 | 79 | # convolve with RIR at different mic 80 | echoic_spk1 = [] 81 | echoic_spk2 = [] 82 | echoic_mixture = [] 83 | 84 | if dataset == 'adhoc': 85 | nmic = this_config['num_mic'] 86 | else: 87 | nmic = 6 88 | for mic in range(nmic): 89 | spk1_echoic_sig = signal.fftconvolve(spk1, spk_rir[0][mic]) 90 | spk2_echoic_sig = signal.fftconvolve(spk2, spk_rir[1][mic]) 91 | noise_echoic_sig = signal.fftconvolve(noise, noise_rir[0][mic]) 92 | 93 | # align the speakers according to overlap ratio 94 | pad_length = int((1 - overlap_ratio) * actual_len) 95 | padding = np.zeros(pad_length) 96 | spk1_echoic_sig = np.concatenate([spk1_echoic_sig, padding]) 97 | spk2_echoic_sig = np.concatenate([padding, spk2_echoic_sig]) 98 | # pad or truncate length to 4s if necessary 99 | def pad_sig(x): 100 | if len(x) < sig_len*sr: 101 | zeros = np.zeros(sig_len * sr - len(x)) 102 | return np.concatenate([x, zeros]) 103 | else: 104 | return x[:sig_len*sr] 105 | 106 | spk1_echoic_sig = pad_sig(spk1_echoic_sig) 107 | spk2_echoic_sig = pad_sig(spk2_echoic_sig) 108 | noise_echoic_sig = pad_sig(noise_echoic_sig) 109 | 110 | # sum up for mixture 111 | mixture = spk1_echoic_sig + spk2_echoic_sig + noise_echoic_sig 112 | 113 | if avoid_clipping: 114 | # avoid clipping 115 | max_scale = np.max([np.max(np.abs(mixture)), np.max(np.abs(spk1_echoic_sig)), np.max(np.abs(spk2_echoic_sig))]) 116 | mixture = mixture / max_scale * 0.9 117 | spk1_echoic_sig = spk1_echoic_sig / max_scale * 0.9 118 | spk2_echoic_sig = spk2_echoic_sig / max_scale * 0.9 119 | 120 | # save waveforms 121 | this_save_dir = os.path.join(output_path, 'MC_Libri_'+dataset, data_type[i], str(num_mic)+'mic', 'sample'+str(utt+1)) 122 | if not os.path.exists(this_save_dir): 123 | os.makedirs(this_save_dir) 124 | sf.write(os.path.join(this_save_dir, 'spk1_mic'+str(mic+1)+'.wav'), spk1_echoic_sig, sr) 125 | sf.write(os.path.join(this_save_dir, 'spk2_mic'+str(mic+1)+'.wav'), spk2_echoic_sig, sr) 126 | sf.write(os.path.join(this_save_dir, 'mixture_mic'+str(mic+1)+'.wav'), mixture, sr) 127 | 128 | # print progress 129 | if (utt+1) % (len(configs) // 5) == 0: 130 | print("{} configuration, {} set, {:d} out of {:d} utterances generated.".format(dataset, data_type[i], 131 | utt+1, len(configs))) 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser(description='Generate multi-channel Librispeech data') 136 | parser.add_argument('--output-path', metavar='absolute path', required=False, default='', 137 | help="The path to the output directory. Default is the current directory.") 138 | parser.add_argument('--avoid-clipping', metavar='avoid clipping', required=False, default=0, 139 | help="Whether to avoid clipping when saving the waveforms. 0: no clipping. 1: clipping.") 140 | parser.add_argument('--dataset', metavar='dataset type', required=True, 141 | help="The type of dataset to generate. 
Can only be 'adhoc' or 'fixed'.") 142 | parser.add_argument('--libri-path', metavar='absolute path', required=True, 143 | help="Absolute path for Librispeech folder containing train-clean-100, dev-clean and test-clean folders.") 144 | parser.add_argument('--noise-path', metavar='absolute path', required=True, 145 | help="Absolute path for the 100 Nonspeech sound folder.") 146 | args = parser.parse_args() 147 | generate_data(output_path=args.output_path, dataset=args.dataset, libri_path=args.libri_path, noise_path=args.noise_path) 148 | -------------------------------------------------------------------------------- /flowchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/flowchart.png -------------------------------------------------------------------------------- /iFaSNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import os 6 | import numpy as np 7 | 8 | from utility.models import * 9 | 10 | # base module for DPRNN/DPRNN-TAC 11 | class BF_module(DPRNN_base): 12 | def __init__(self, *args, **kwargs): 13 | super(BF_module, self).__init__(*args, **kwargs) 14 | 15 | # output layer 16 | self.output = nn.Conv1d(self.feature_dim, self.output_dim, 1) 17 | 18 | def forward(self, input, num_mic): 19 | 20 | if self.model_type == 'DPRNN': 21 | # input: (B, N, T) 22 | batch_size, N, seq_length = input.shape 23 | ch = 1 24 | elif self.model_type == 'DPRNN_TAC': 25 | # input: (B, ch, N, T) 26 | batch_size, ch, N, seq_length = input.shape 27 | 28 | input = input.view(batch_size*ch, N, seq_length) # B*ch, N, T 29 | enc_feature = self.BN(input) 30 | 31 | # split the encoder output into overlapped, longer segments 32 | enc_segments, enc_rest = self.split_feature(enc_feature, self.segment_size) # B*ch, N, L, K 33 | 34 | # pass to DPRNN 35 | if self.model_type == 'DPRNN': 36 | output = self.DPRNN(enc_segments).view(batch_size*ch*self.num_spk, self.feature_dim, self.segment_size, -1) # B*ch*nspk, N, L, K 37 | elif self.model_type == 'DPRNN_TAC': 38 | enc_segments = enc_segments.view(batch_size, ch, -1, enc_segments.shape[2], enc_segments.shape[3]) # B, ch, N, L, K 39 | output = self.DPRNN(enc_segments, num_mic).view(batch_size*ch*self.num_spk, self.feature_dim, self.segment_size, -1) # B*ch*nspk, N, L, K 40 | 41 | # overlap-and-add of the outputs 42 | output = self.merge_feature(output, enc_rest) # B*ch*nspk, N, T 43 | 44 | # output layer 45 | bf_filter = self.output(output) # B*ch*nspk, K, T 46 | bf_filter = bf_filter.view(batch_size, ch, self.num_spk, self.output_dim, -1) # B, ch, nspk, K, L 47 | 48 | return bf_filter 49 | 50 | 51 | # base module for FaSNet 52 | class FaSNet_base(nn.Module): 53 | def __init__(self, enc_dim, feature_dim, hidden_dim, layer, segment_size=24, 54 | nspk=2, win_len=16, context_len=16, sr=16000): 55 | super(FaSNet_base, self).__init__() 56 | 57 | # parameters 58 | self.window = int(sr * win_len / 1000) 59 | self.stride = self.window // 2 60 | self.context = context_len*2 // win_len 61 | 62 | self.enc_dim = enc_dim 63 | self.feature_dim = feature_dim 64 | self.hidden_dim = hidden_dim 65 | self.segment_size = segment_size 66 | 67 | self.layer = layer 68 | self.num_spk = nspk 69 | self.eps = 1e-8 70 | 71 | # waveform encoder/decoder 72 | self.encoder = nn.Conv1d(1, self.enc_dim, self.window, stride=self.stride, bias=False) 73 | 
self.decoder = nn.ConvTranspose1d(self.enc_dim, 1, self.window, stride=self.stride, bias=False) 74 | self.enc_LN = nn.GroupNorm(1, self.enc_dim, eps=self.eps) 75 | 76 | def pad_input(self, input, window, stride): 77 | """ 78 | Zero-padding input according to window/stride size. 79 | """ 80 | batch_size, nmic, nsample = input.shape 81 | 82 | # pad the signals at the end for matching the window/stride size 83 | rest = window - (stride + nsample % window) % window 84 | if rest > 0: 85 | pad = torch.zeros(batch_size, nmic, rest).type(input.type()) 86 | input = torch.cat([input, pad], 2) 87 | pad_aux = torch.zeros(batch_size, nmic, stride).type(input.type()) 88 | input = torch.cat([pad_aux, input, pad_aux], 2) 89 | 90 | return input, rest 91 | 92 | def signal_context(self, x, context): 93 | """ 94 | Segmenting the signal into chunks with specific context. 95 | input: 96 | x: size (B, dim, nframe) 97 | context: int 98 | 99 | """ 100 | 101 | batch_size, dim, nframe = x.shape 102 | 103 | zero_pad = torch.zeros(batch_size, dim, context).type(x.type()) 104 | pad_past = [] 105 | pad_future = [] 106 | for i in range(context): 107 | pad_past.append(torch.cat([zero_pad[:,:,i:], x[:,:,:-context+i]], 2).unsqueeze(2)) 108 | pad_future.append(torch.cat([x[:,:,i+1:], zero_pad[:,:,:i+1]], 2).unsqueeze(2)) 109 | 110 | pad_past = torch.cat(pad_past, 2) # B, D, C, L 111 | pad_future = torch.cat(pad_future, 2) # B, D, C, L 112 | all_context = torch.cat([pad_past, x.unsqueeze(2), pad_future], 2) # B, D, 2*C+1, L 113 | 114 | return all_context 115 | 116 | def forward(self, input, num_mic): 117 | """ 118 | input: shape (batch, max_num_ch, T) 119 | num_mic: shape (batch, ), the number of channels for each input. Zero for fixed geometry configuration. 120 | """ 121 | pass 122 | 123 | 124 | # implicit FaSNet (iFaSNet) 125 | class iFaSNet(FaSNet_base): 126 | def __init__(self, *args, **kwargs): 127 | super(iFaSNet, self).__init__(*args, **kwargs) 128 | 129 | # context compression 130 | self.summ_BN = nn.Linear(self.enc_dim, self.feature_dim) 131 | self.summ_RNN = SingleRNN('LSTM',self.feature_dim, self.hidden_dim, bidirectional=True) 132 | self.summ_LN = nn.GroupNorm(1, self.feature_dim, eps=self.eps) 133 | self.summ_output = nn.Linear(self.feature_dim, self.enc_dim) 134 | 135 | # DPRNN-TAC 136 | self.separator = BF_module(self.enc_dim+(self.context*2+1)**2, 137 | self.feature_dim, self.hidden_dim, 138 | self.enc_dim, self.num_spk, self.layer, 139 | self.segment_size, model_type='DPRNN_TAC') 140 | 141 | # context decompression 142 | self.gen_BN = nn.Conv1d(self.enc_dim*2, self.feature_dim, 1) 143 | self.gen_RNN = SingleRNN('LSTM', self.feature_dim, self.hidden_dim, bidirectional=True) 144 | self.gen_LN = nn.GroupNorm(1, self.feature_dim, eps=self.eps) 145 | self.gen_output = nn.Conv1d(self.feature_dim, self.enc_dim, 1) 146 | 147 | def forward(self, input, num_mic): 148 | 149 | batch_size = input.size(0) 150 | nmic = input.size(1) 151 | 152 | # pad input accordingly 153 | input, rest = self.pad_input(input, self.window, self.stride) 154 | 155 | # encoder on all channels 156 | enc_output = self.encoder(input.view(batch_size*nmic, 1, -1)) # B*nmic, N, L 157 | seq_length = enc_output.shape[-1] 158 | 159 | # calculate the context of the encoder output 160 | # consider both past and future 161 | enc_context = self.signal_context(enc_output, self.context) # B*nmic, N, 2C+1, L 162 | enc_context = enc_context.view(batch_size, nmic, self.enc_dim, -1, seq_length) # B, nmic, N, 2C+1, L 163 | 164 | # NCC feature 165 | ref_enc = 
enc_context[:,0].contiguous() # B, N, 2C+1, L 166 | ref_enc = ref_enc.permute(0,3,1,2).contiguous().view(batch_size*seq_length, self.enc_dim, -1) # B*L, N, 2C+1 167 | enc_context_copy = enc_context.permute(0,4,1,3,2).contiguous().view(batch_size*seq_length, nmic, 168 | -1, self.enc_dim) # B*L, nmic, 2C+1, N 169 | NCC = torch.cat([enc_context_copy[:,i].bmm(ref_enc).unsqueeze(1) for i in range(nmic)], 1) # B*L, nmic, 2C+1, 2C+1 170 | ref_norm = (ref_enc.pow(2).sum(1).unsqueeze(1) + self.eps).sqrt() # B*L, 1, 2C+1 171 | enc_norm = (enc_context_copy.pow(2).sum(3).unsqueeze(3) + self.eps).sqrt() # B*L, nmic, 2C+1, 1 172 | NCC = NCC / (ref_norm.unsqueeze(1) * enc_norm) # B*L, nmic, 2C+1, 2C+1 173 | NCC = torch.cat([NCC[:,:,i] for i in range(NCC.shape[2])], 2) # B*L, nmic, (2C+1)^2 174 | NCC = NCC.view(batch_size, seq_length, nmic, -1).permute(0,2,3,1).contiguous() # B, nmic, (2C+1)^2, L 175 | 176 | # context compression 177 | norm_output = self.enc_LN(enc_output) # B*nmic, N, L 178 | norm_context = self.signal_context(norm_output, self.context) # B*nmic, N, 2C+1, L 179 | norm_context = norm_context.permute(0,3,2,1).contiguous().view(-1, self.context*2+1, self.enc_dim) 180 | norm_context_BN = self.summ_BN(norm_context.view(-1, self.enc_dim)).view(-1, self.context*2+1, self.feature_dim) 181 | embedding = self.summ_RNN(norm_context_BN).transpose(1, 2).contiguous() # B*nmic*L, N, 2C+1 182 | embedding = norm_context_BN.transpose(1, 2).contiguous() + self.summ_LN(embedding) # B*nmic*L, N, 2C+1 183 | embedding = self.summ_output(embedding.mean(2)).view(batch_size, nmic, seq_length, self.enc_dim) # B, nmic, L, N 184 | embedding = embedding.transpose(2, 3).contiguous() # B, nmic, N, L 185 | 186 | input_feature = torch.cat([embedding, NCC], 2) # B, nmic, N+(2C+1)^2, L 187 | 188 | # pass to DPRNN-TAC 189 | embedding = self.separator(input_feature, num_mic)[:,0].contiguous() # B, nspk, N, L 190 | 191 | # concatenate with encoder outputs and generate masks 192 | # context decompression 193 | norm_context = norm_context.view(batch_size, nmic, seq_length, -1, self.enc_dim) # B, nmic, L, 2C+1, N 194 | norm_context = norm_context.permute(0,1,4,3,2)[:,:1].contiguous() # B, 1, N, 2C+1, L 195 | 196 | embedding = torch.cat([embedding.unsqueeze(3)]*(self.context*2+1), 3) # B, nspk, N, 2C+1, L 197 | norm_context = torch.cat([norm_context]*self.num_spk, 1) # B, nspk, N, 2C+1, L 198 | embedding = torch.cat([norm_context, embedding], 2).permute(0,1,4,2,3).contiguous() # B, nspk, L, 2N, 2C+1 199 | all_filter = self.gen_BN(embedding.view(-1, self.enc_dim*2, self.context*2+1)) # B*nspk*L, N, 2C+1 200 | all_filter = all_filter + self.gen_LN(self.gen_RNN(all_filter.transpose(1, 2)).transpose(1, 2)) # B*nspk*L, N, 2C+1 201 | all_filter = self.gen_output(all_filter) # B*nspk*L, N, 2C+1 202 | all_filter = all_filter.view(batch_size, self.num_spk, seq_length, self.enc_dim, -1) # B, nspk, L, N+1, 2C+1 203 | all_filter = all_filter.permute(0,1,3,4,2).contiguous() # B, nspk, N, 2C+1, L 204 | 205 | # apply to with ref mic's encoder context 206 | output = (enc_context[:,:1] * all_filter).mean(3) # B, nspk, N, L 207 | 208 | # decode 209 | bf_signal = self.decoder(output.view(batch_size*self.num_spk, self.enc_dim, -1)) # B*nspk, 1, T 210 | 211 | if rest > 0: 212 | bf_signal = bf_signal[:,:,self.stride:-rest-self.stride] 213 | 214 | bf_signal = bf_signal.view(batch_size, self.num_spk, -1) # B, nspk, T 215 | 216 | return bf_signal 217 | 218 | def test_model(model): 219 | x = torch.rand(2, 4, 32000) # (batch, num_mic, length) 220 | num_mic 
= torch.from_numpy(np.array([3, 2])).view(-1,).type(x.type()) # ad-hoc array 221 | none_mic = torch.zeros(1).type(x.type()) # fixed-array 222 | y1 = model(x, num_mic.long()) 223 | y2 = model(x, none_mic.long()) 224 | print(y1.shape, y2.shape) # (batch, nspk, length) 225 | 226 | 227 | if __name__ == "__main__": 228 | model_iFaSNet = iFaSNet(enc_dim=64, feature_dim=64, hidden_dim=128, layer=6, segment_size=24, 229 | nspk=2, win_len=16, context_len=16, sr=16000) 230 | 231 | test_model(model_iFaSNet) 232 | -------------------------------------------------------------------------------- /utility/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utility/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | import sys 9 | 10 | 11 | class SingleRNN(nn.Module): 12 | """ 13 | Container module for a single RNN layer. 14 | 15 | args: 16 | rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'. 17 | input_size: int, dimension of the input feature. The input should have shape 18 | (batch, seq_len, input_size). 19 | hidden_size: int, dimension of the hidden state. 20 | dropout: float, dropout ratio. Default is 0. 21 | bidirectional: bool, whether the RNN layers are bidirectional. Default is False. 22 | """ 23 | 24 | def __init__(self, rnn_type, input_size, hidden_size, dropout=0, bidirectional=False): 25 | super(SingleRNN, self).__init__() 26 | 27 | self.rnn_type = rnn_type 28 | self.input_size = input_size 29 | self.hidden_size = hidden_size 30 | self.num_direction = int(bidirectional) + 1 31 | 32 | self.rnn = getattr(nn, rnn_type)(input_size, hidden_size, 1, dropout=dropout, batch_first=True, bidirectional=bidirectional) 33 | 34 | # linear projection layer 35 | self.proj = nn.Linear(hidden_size*self.num_direction, input_size) 36 | 37 | def forward(self, input): 38 | # input shape: batch, seq, dim 39 | output = input 40 | rnn_output, _ = self.rnn(output) 41 | rnn_output = self.proj(rnn_output.contiguous().view(-1, rnn_output.shape[2])).view(output.shape) 42 | return rnn_output 43 | 44 | # dual-path RNN 45 | class DPRNN(nn.Module): 46 | """ 47 | Deep duaL-path RNN. 48 | 49 | args: 50 | rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'. 51 | input_size: int, dimension of the input feature. The input should have shape 52 | (batch, seq_len, input_size). 53 | hidden_size: int, dimension of the hidden state. 54 | output_size: int, dimension of the output size. 55 | dropout: float, dropout ratio. Default is 0. 56 | num_layers: int, number of stacked RNN layers. Default is 1. 57 | bidirectional: bool, whether the RNN layers are bidirectional. Default is False. 
58 | """ 59 | def __init__(self, rnn_type, input_size, hidden_size, output_size, 60 | dropout=0, num_layers=1, bidirectional=True): 61 | super(DPRNN, self).__init__() 62 | 63 | self.input_size = input_size 64 | self.output_size = output_size 65 | self.hidden_size = hidden_size 66 | 67 | # dual-path RNN 68 | self.row_rnn = nn.ModuleList([]) 69 | self.col_rnn = nn.ModuleList([]) 70 | self.row_norm = nn.ModuleList([]) 71 | self.col_norm = nn.ModuleList([]) 72 | for i in range(num_layers): 73 | self.row_rnn.append(SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=True)) # intra-segment RNN is always noncausal 74 | self.col_rnn.append(SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=bidirectional)) 75 | self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8)) 76 | # default is to use noncausal LayerNorm for inter-chunk RNN. For causal setting change it to causal normalization techniques accordingly. 77 | self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8)) 78 | 79 | # output layer 80 | self.output = nn.Sequential(nn.PReLU(), 81 | nn.Conv2d(input_size, output_size, 1) 82 | ) 83 | 84 | def forward(self, input): 85 | # input shape: batch, N, dim1, dim2 86 | # apply RNN on dim1 first and then dim2 87 | 88 | batch_size, _, dim1, dim2 = input.shape 89 | output = input 90 | for i in range(len(self.row_rnn)): 91 | row_input = output.permute(0,3,2,1).contiguous().view(batch_size*dim2, dim1, -1) # B*dim2, dim1, N 92 | row_output = self.row_rnn[i](row_input) # B*dim2, dim1, H 93 | row_output = row_output.view(batch_size, dim2, dim1, -1).permute(0,3,2,1).contiguous() # B, N, dim1, dim2 94 | row_output = self.row_norm[i](row_output) 95 | output = output + row_output 96 | 97 | col_input = output.permute(0,2,3,1).contiguous().view(batch_size*dim1, dim2, -1) # B*dim1, dim2, N 98 | col_output = self.col_rnn[i](col_input) # B*dim1, dim2, H 99 | col_output = col_output.view(batch_size, dim1, dim2, -1).permute(0,3,1,2).contiguous() # B, N, dim1, dim2 100 | col_output = self.col_norm[i](col_output) 101 | output = output + col_output 102 | 103 | output = self.output(output) 104 | 105 | return output 106 | 107 | 108 | # dual-path RNN with transform-average-concatenate (TAC) 109 | class DPRNN_TAC(nn.Module): 110 | """ 111 | Deep duaL-path RNN with transform-average-concatenate (TAC) applied to each layer/block. 112 | 113 | args: 114 | rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'. 115 | input_size: int, dimension of the input feature. The input should have shape 116 | (batch, seq_len, input_size). 117 | hidden_size: int, dimension of the hidden state. 118 | output_size: int, dimension of the output size. 119 | dropout: float, dropout ratio. Default is 0. 120 | num_layers: int, number of stacked RNN layers. Default is 1. 121 | bidirectional: bool, whether the RNN layers are bidirectional. Default is False. 
122 | """ 123 | def __init__(self, rnn_type, input_size, hidden_size, output_size, 124 | dropout=0, num_layers=1, bidirectional=True): 125 | super(DPRNN_TAC, self).__init__() 126 | 127 | self.input_size = input_size 128 | self.output_size = output_size 129 | self.hidden_size = hidden_size 130 | 131 | # DPRNN + TAC for 3D input (ch, N, T) 132 | self.row_rnn = nn.ModuleList([]) 133 | self.col_rnn = nn.ModuleList([]) 134 | self.ch_transform = nn.ModuleList([]) 135 | self.ch_average = nn.ModuleList([]) 136 | self.ch_concat = nn.ModuleList([]) 137 | 138 | self.row_norm = nn.ModuleList([]) 139 | self.col_norm = nn.ModuleList([]) 140 | self.ch_norm = nn.ModuleList([]) 141 | 142 | 143 | for i in range(num_layers): 144 | self.row_rnn.append(SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=True)) # intra-segment RNN is always noncausal 145 | self.col_rnn.append(SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=bidirectional)) 146 | self.ch_transform.append(nn.Sequential(nn.Linear(input_size, hidden_size*3), 147 | nn.PReLU() 148 | ) 149 | ) 150 | self.ch_average.append(nn.Sequential(nn.Linear(hidden_size*3, hidden_size*3), 151 | nn.PReLU() 152 | ) 153 | ) 154 | self.ch_concat.append(nn.Sequential(nn.Linear(hidden_size*6, input_size), 155 | nn.PReLU() 156 | ) 157 | ) 158 | 159 | 160 | self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8)) 161 | # default is to use noncausal LayerNorm for inter-chunk RNN and TAC modules. For causal setting change them to causal normalization techniques accordingly. 162 | self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8)) 163 | self.ch_norm.append(nn.GroupNorm(1, input_size, eps=1e-8)) 164 | 165 | # output layer 166 | self.output = nn.Sequential(nn.PReLU(), 167 | nn.Conv2d(input_size, output_size, 1) 168 | ) 169 | 170 | def forward(self, input, num_mic): 171 | # input shape: batch, ch, N, dim1, dim2 172 | # num_mic shape: batch, 173 | # apply RNN on dim1 first, then dim2, then ch 174 | 175 | batch_size, ch, N, dim1, dim2 = input.shape 176 | output = input 177 | for i in range(len(self.row_rnn)): 178 | # intra-segment RNN 179 | output = output.view(batch_size*ch, N, dim1, dim2) # B*ch, N, dim1, dim2 180 | row_input = output.permute(0,3,2,1).contiguous().view(batch_size*ch*dim2, dim1, -1) # B*ch*dim2, dim1, N 181 | row_output = self.row_rnn[i](row_input) # B*ch*dim2, dim1, N 182 | row_output = row_output.view(batch_size*ch, dim2, dim1, -1).permute(0,3,2,1).contiguous() # B*ch, N, dim1, dim2 183 | row_output = self.row_norm[i](row_output) 184 | output = output + row_output # B*ch, N, dim1, dim2 185 | 186 | # inter-segment RNN 187 | col_input = output.permute(0,2,3,1).contiguous().view(batch_size*ch*dim1, dim2, -1) # B*ch*dim1, dim2, N 188 | col_output = self.col_rnn[i](col_input) # B*dim1, dim2, N 189 | col_output = col_output.view(batch_size*ch, dim1, dim2, -1).permute(0,3,1,2).contiguous() # B*ch, N, dim1, dim2 190 | col_output = self.col_norm[i](col_output) 191 | output = output + col_output # B*ch, N, dim1, dim2 192 | 193 | # TAC for cross-channel communication 194 | ch_input = output.view(input.shape) # B, ch, N, dim1, dim2 195 | ch_input = ch_input.permute(0,3,4,1,2).contiguous().view(-1, N) # B*dim1*dim2*ch, N 196 | ch_output = self.ch_transform[i](ch_input).view(batch_size, dim1*dim2, ch, -1) # B, dim1*dim2, ch, H 197 | # mean pooling across channels 198 | if num_mic.max() == 0: 199 | # fixed geometry array 200 | ch_mean = ch_output.mean(2).view(batch_size*dim1*dim2, -1) # B*dim1*dim2, H 201 | else: 202 | # only 
consider valid channels 203 | ch_mean = [ch_output[b,:,:num_mic[b]].mean(1).unsqueeze(0) for b in range(batch_size)] # 1, dim1*dim2, H 204 | ch_mean = torch.cat(ch_mean, 0).view(batch_size*dim1*dim2, -1) # B*dim1*dim2, H 205 | ch_output = ch_output.view(batch_size*dim1*dim2, ch, -1) # B*dim1*dim2, ch, H 206 | ch_mean = self.ch_average[i](ch_mean).unsqueeze(1).expand_as(ch_output).contiguous() # B*dim1*dim2, ch, H 207 | ch_output = torch.cat([ch_output, ch_mean], 2) # B*dim1*dim2, ch, 2H 208 | ch_output = self.ch_concat[i](ch_output.view(-1, ch_output.shape[-1])) # B*dim1*dim2*ch, N 209 | ch_output = ch_output.view(batch_size, dim1, dim2, ch, -1).permute(0,3,4,1,2).contiguous() # B, ch, N, dim1, dim2 210 | ch_output = self.ch_norm[i](ch_output.view(batch_size*ch, N, dim1, dim2)) # B*ch, N, dim1, dim2 211 | output = output + ch_output 212 | 213 | output = self.output(output) # B*ch, N, dim1, dim2 214 | 215 | return output 216 | 217 | # base module for deep DPRNN 218 | class DPRNN_base(nn.Module): 219 | def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2, 220 | layer=4, segment_size=100, bidirectional=True, model_type='DPRNN', 221 | rnn_type='LSTM'): 222 | super(DPRNN_base, self).__init__() 223 | 224 | assert model_type in ['DPRNN', 'DPRNN_TAC'], "model_type can only be 'DPRNN' or 'DPRNN_TAC'." 225 | 226 | self.input_dim = input_dim 227 | self.feature_dim = feature_dim 228 | self.hidden_dim = hidden_dim 229 | self.output_dim = output_dim 230 | 231 | self.layer = layer 232 | self.segment_size = segment_size 233 | self.num_spk = num_spk 234 | 235 | self.model_type = model_type 236 | 237 | self.eps = 1e-8 238 | 239 | # bottleneck 240 | self.BN = nn.Conv1d(self.input_dim, self.feature_dim, 1, bias=False) 241 | 242 | # DPRNN model 243 | self.DPRNN = getattr(sys.modules[__name__], model_type)(rnn_type, self.feature_dim, self.hidden_dim, self.feature_dim*self.num_spk, 244 | num_layers=layer, bidirectional=bidirectional) 245 | 246 | def pad_segment(self, input, segment_size): 247 | # input is the features: (B, N, T) 248 | batch_size, dim, seq_len = input.shape 249 | segment_stride = segment_size // 2 250 | 251 | rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size 252 | if rest > 0: 253 | pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type()) 254 | input = torch.cat([input, pad], 2) 255 | 256 | pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(input.type()) 257 | input = torch.cat([pad_aux, input, pad_aux], 2) 258 | 259 | return input, rest 260 | 261 | def split_feature(self, input, segment_size): 262 | # split the feature into chunks of segment size 263 | # input is the features: (B, N, T) 264 | 265 | input, rest = self.pad_segment(input, segment_size) 266 | batch_size, dim, seq_len = input.shape 267 | segment_stride = segment_size // 2 268 | 269 | segments1 = input[:,:,:-segment_stride].contiguous().view(batch_size, dim, -1, segment_size) 270 | segments2 = input[:,:,segment_stride:].contiguous().view(batch_size, dim, -1, segment_size) 271 | segments = torch.cat([segments1, segments2], 3).view(batch_size, dim, -1, segment_size).transpose(2, 3) 272 | 273 | return segments.contiguous(), rest 274 | 275 | def merge_feature(self, input, rest): 276 | # merge the splitted features into full utterance 277 | # input is the features: (B, N, L, K) 278 | 279 | batch_size, dim, segment_size, _ = input.shape 280 | segment_stride = segment_size // 2 281 | input = input.transpose(2, 3).contiguous().view(batch_size, dim, -1, 
segment_size*2) # B, N, K, L 282 | 283 | input1 = input[:,:,:,:segment_size].contiguous().view(batch_size, dim, -1)[:,:,segment_stride:] 284 | input2 = input[:,:,:,segment_size:].contiguous().view(batch_size, dim, -1)[:,:,:-segment_stride] 285 | 286 | output = input1 + input2 287 | if rest > 0: 288 | output = output[:,:,:-rest] 289 | 290 | return output.contiguous() # B, N, T 291 | 292 | def forward(self, input): 293 | pass -------------------------------------------------------------------------------- /utility/sdr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import permutations 3 | from torch.autograd import Variable 4 | 5 | import scipy,time,numpy 6 | 7 | import torch 8 | 9 | # original implementation 10 | def compute_measures(se,s,j): 11 | Rss=s.transpose().dot(s) 12 | this_s=s[:,j] 13 | 14 | a=this_s.transpose().dot(se)/Rss[j,j] 15 | e_true=a*this_s 16 | e_res=se-a*this_s 17 | Sss=np.sum((e_true)**2) 18 | Snn=np.sum((e_res)**2) 19 | 20 | SDR=10*np.log10(Sss/Snn) 21 | 22 | Rsr= s.transpose().dot(e_res) 23 | b=np.linalg.inv(Rss).dot(Rsr) 24 | 25 | e_interf = s.dot(b) 26 | e_artif= e_res-e_interf 27 | 28 | SIR=10*np.log10(Sss/np.sum((e_interf)**2)) 29 | SAR=10*np.log10(Sss/np.sum((e_artif)**2)) 30 | return SDR, SIR, SAR 31 | 32 | def GetSDR(se,s): 33 | se=se-np.mean(se,axis=0) 34 | s=s-np.mean(s,axis=0) 35 | nsampl,nsrc=se.shape 36 | nsampl2,nsrc2=s.shape 37 | assert(nsrc2==nsrc) 38 | assert(nsampl2==nsampl) 39 | 40 | SDR=np.zeros((nsrc,nsrc)) 41 | SIR=SDR.copy() 42 | SAR=SDR.copy() 43 | 44 | for jest in range(nsrc): 45 | for jtrue in range(nsrc): 46 | SDR[jest,jtrue],SIR[jest,jtrue],SAR[jest,jtrue]=compute_measures(se[:,jest],s,jtrue) 47 | 48 | 49 | perm=list(permutations(np.arange(nsrc))) 50 | nperm=len(perm) 51 | meanSIR=np.zeros((nperm,)) 52 | for p in range(nperm): 53 | tp=SIR.transpose().reshape(nsrc*nsrc) 54 | idx=np.arange(nsrc)*nsrc+list(perm[p]) 55 | meanSIR[p]=np.mean(tp[idx]) 56 | popt=np.argmax(meanSIR) 57 | per=list(perm[popt]) 58 | idx=np.arange(nsrc)*nsrc+per 59 | SDR=SDR.transpose().reshape(nsrc*nsrc)[idx] 60 | SIR=SIR.transpose().reshape(nsrc*nsrc)[idx] 61 | SAR=SAR.transpose().reshape(nsrc*nsrc)[idx] 62 | return SDR, SIR, SAR, per 63 | 64 | # Pytorch implementation with batch processing 65 | def calc_sdr_torch(estimation, origin, mask=None): 66 | """ 67 | batch-wise SDR caculation for one audio file on pytorch Variables. 68 | estimation: (batch, nsample) 69 | origin: (batch, nsample) 70 | mask: an optional mask for sequence masking. This is for cases where zero-padding was applied at the end and should not be consider for SDR calculation. 71 | """ 72 | 73 | if mask is not None: 74 | origin = origin * mask 75 | estimation = estimation * mask 76 | 77 | def calculate(estimation, origin): 78 | origin_power = torch.pow(origin, 2).sum(1, keepdim=True) + 1e-8 # (batch, 1) 79 | scale = torch.sum(origin*estimation, 1, keepdim=True) / origin_power # (batch, 1) 80 | 81 | est_true = scale * origin # (batch, nsample) 82 | est_res = estimation - est_true # (batch, nsample) 83 | 84 | true_power = torch.pow(est_true, 2).sum(1) + 1e-8 85 | res_power = torch.pow(est_res, 2).sum(1) + 1e-8 86 | 87 | return 10*torch.log10(true_power) - 10*torch.log10(res_power) # (batch, ) 88 | 89 | best_sdr = calculate(estimation, origin) 90 | 91 | return best_sdr 92 | 93 | 94 | def batch_SDR_torch(estimation, origin, mask=None, return_perm=False): 95 | """ 96 | batch-wise SDR caculation for multiple audio files. 
97 | estimation: (batch, nsource, nsample) 98 | origin: (batch, nsource, nsample) 99 | mask: optional, (batch, nsample), binary 100 | return_perm: bool, whether to return the permutation index. Default is false. 101 | """ 102 | 103 | batch_size_est, nsource_est, nsample_est = estimation.size() 104 | batch_size_ori, nsource_ori, nsample_ori = origin.size() 105 | 106 | assert batch_size_est == batch_size_ori, "Estimation and original sources should have same shape." 107 | assert nsource_est == nsource_ori, "Estimation and original sources should have same shape." 108 | assert nsample_est == nsample_ori, "Estimation and original sources should have same shape." 109 | 110 | assert nsource_est < nsample_est, "Axis 1 should be the number of sources, and axis 2 should be the signal." 111 | 112 | batch_size = batch_size_est 113 | nsource = nsource_est 114 | 115 | # zero mean signals 116 | estimation = estimation - torch.mean(estimation, 2, keepdim=True).expand_as(estimation) 117 | origin = origin - torch.mean(origin, 2, keepdim=True).expand_as(estimation) 118 | 119 | # SDR for each permutation 120 | SDR = torch.zeros((batch_size, nsource, nsource)).type(estimation.type()) 121 | for i in range(nsource): 122 | for j in range(nsource): 123 | SDR[:,i,j] = calc_sdr_torch(estimation[:,i], origin[:,j], mask) 124 | 125 | # choose the best permutation 126 | SDR_max = [] 127 | SDR_perm = [] 128 | perm = sorted(list(set(permutations(np.arange(nsource))))) 129 | for permute in perm: 130 | sdr = [] 131 | for idx in range(len(permute)): 132 | sdr.append(SDR[:,idx,permute[idx]].view(batch_size,-1)) 133 | sdr = torch.sum(torch.cat(sdr, 1), 1) 134 | SDR_perm.append(sdr.view(batch_size, 1)) 135 | SDR_perm = torch.cat(SDR_perm, 1) 136 | SDR_max, SDR_idx = torch.max(SDR_perm, dim=1) 137 | 138 | if not return_perm: 139 | return SDR_max / nsource 140 | else: 141 | return SDR_max / nsource, SDR_idx --------------------------------------------------------------------------------
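The repository ships the separation models and the SDR utilities but no training loop. As a rough illustration of how the pieces fit together, the sketch below wires `FaSNet_TAC` and `batch_SDR_torch` into a single optimization step; the random tensors stand in for a real dataloader, and using the time-domain SDR from `utility/sdr.py` as the objective is an illustrative choice rather than the exact recipe from the paper.

```python
import torch
from FaSNet import FaSNet_TAC
from utility.sdr import batch_SDR_torch

# hypothetical single training step; hyperparameters follow the defaults used in test_model()
model = FaSNet_TAC(enc_dim=64, feature_dim=64, hidden_dim=128, layer=4,
                   segment_size=50, nspk=2, win_len=4, context_len=16, sr=16000)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

mixture = torch.rand(2, 6, 32000)   # (batch, num_mic, T), random stand-in for real mixtures
targets = torch.rand(2, 2, 32000)   # (batch, nspk, T), random stand-in for the clean sources
num_mic = torch.zeros(2).long()     # all zeros -> fixed-geometry mode, all channels are used

optimizer.zero_grad()
estimate = model(mixture, num_mic)                  # (batch, nspk, T)
# batch_SDR_torch resolves the best speaker permutation, so maximize it via its negative
loss = -batch_SDR_torch(estimate, targets).mean()
loss.backward()
optimizer.step()
```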