├── FaSNet.py
├── README.md
├── data
│   ├── README.md
│   ├── configs
│   │   ├── MC_Libri_adhoc_test.pkl
│   │   ├── MC_Libri_adhoc_train.pkl
│   │   ├── MC_Libri_adhoc_validation.pkl
│   │   ├── MC_Libri_fixed_test.pkl
│   │   ├── MC_Libri_fixed_train.pkl
│   │   └── MC_Libri_fixed_validation.pkl
│   └── create_dataset.py
├── flowchart.png
├── iFaSNet.py
└── utility
    ├── __init__.py
    ├── models.py
    └── sdr.py
/FaSNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import os
6 | import numpy as np
7 |
8 | from utility.models import *
9 |
10 | # DPRNN for beamforming filter estimation
11 | class BF_module(DPRNN_base):
12 | def __init__(self, *args, **kwargs):
13 | super(BF_module, self).__init__(*args, **kwargs)
14 |
15 | # gated output layer
16 | self.output = nn.Sequential(nn.Conv1d(self.feature_dim, self.output_dim, 1),
17 | nn.Tanh()
18 | )
19 | self.output_gate = nn.Sequential(nn.Conv1d(self.feature_dim, self.output_dim, 1),
20 | nn.Sigmoid()
21 | )
22 |
23 | def forward(self, input, num_mic):
24 |
25 | if self.model_type == 'DPRNN':
26 | # input: (B, N, T)
27 | batch_size, N, seq_length = input.shape
28 | ch = 1
29 | elif self.model_type == 'DPRNN_TAC':
30 | # input: (B, ch, N, T)
31 | batch_size, ch, N, seq_length = input.shape
32 |
33 | input = input.view(batch_size*ch, N, seq_length) # B*ch, N, T
34 | enc_feature = self.BN(input)
35 |
36 | # split the encoder output into overlapped, longer segments
37 | enc_segments, enc_rest = self.split_feature(enc_feature, self.segment_size) # B*ch, N, L, K
38 |
39 | # pass to DPRNN
40 | if self.model_type == 'DPRNN':
41 | output = self.DPRNN(enc_segments).view(batch_size*ch*self.num_spk, self.feature_dim, self.segment_size, -1) # B*ch*nspk, N, L, K
42 | elif self.model_type == 'DPRNN_TAC':
43 | enc_segments = enc_segments.view(batch_size, ch, -1, enc_segments.shape[2], enc_segments.shape[3]) # B, ch, N, L, K
44 | output = self.DPRNN(enc_segments, num_mic).view(batch_size*ch*self.num_spk, self.feature_dim, self.segment_size, -1) # B*ch*nspk, N, L, K
45 |
46 | # overlap-and-add of the outputs
47 | output = self.merge_feature(output, enc_rest) # B*ch*nspk, N, T
48 |
49 | # gated output layer for filter generation
50 | bf_filter = self.output(output) * self.output_gate(output) # B*ch*nspk, K, T
51 | bf_filter = bf_filter.transpose(1, 2).contiguous().view(batch_size, ch, self.num_spk, -1, self.output_dim) # B, ch, nspk, L, N
52 |
53 | return bf_filter
54 |
55 |
56 | # base module for FaSNet
57 | class FaSNet_base(nn.Module):
58 | def __init__(self, enc_dim, feature_dim, hidden_dim, layer, segment_size=50,
59 | nspk=2, win_len=4, context_len=16, sr=16000):
60 | super(FaSNet_base, self).__init__()
61 |
62 | # parameters
63 | self.window = int(sr * win_len / 1000)
64 | self.context = int(sr * context_len / 1000)
65 | self.stride = self.window // 2
66 |
67 | self.filter_dim = self.context*2+1
68 | self.enc_dim = enc_dim
69 | self.feature_dim = feature_dim
70 | self.hidden_dim = hidden_dim
71 | self.segment_size = segment_size
72 |
73 | self.layer = layer
74 | self.num_spk = nspk
75 | self.eps = 1e-8
76 |
77 | # waveform encoder
78 | self.encoder = nn.Conv1d(1, self.enc_dim, self.context*2+self.window, bias=False)
79 | self.enc_LN = nn.GroupNorm(1, self.enc_dim, eps=1e-8)
80 |
81 | def pad_input(self, input, window):
82 | """
83 | Zero-padding input according to window/stride size.
84 | """
85 | batch_size, nmic, nsample = input.shape
86 | stride = window // 2
87 |
88 | # pad the signals at the end for matching the window/stride size
89 | rest = window - (stride + nsample % window) % window
90 | if rest > 0:
91 | pad = torch.zeros(batch_size, nmic, rest).type(input.type())
92 | input = torch.cat([input, pad], 2)
93 | pad_aux = torch.zeros(batch_size, nmic, stride).type(input.type())
94 | input = torch.cat([pad_aux, input, pad_aux], 2)
95 |
96 | return input, rest
97 |
98 |
99 | def seg_signal_context(self, x, window, context):
100 | """
101 | Segmenting the signal into chunks with specific context.
102 | input:
103 | x: size (B, ch, T)
104 | window: int
105 | context: int
106 |
107 | """
108 |
109 | # pad input accordingly
110 | # first pad according to window size
111 | input, rest = self.pad_input(x, window)
112 | batch_size, nmic, nsample = input.shape
113 | stride = window // 2
114 |
115 | # pad another context size
116 | pad_context = torch.zeros(batch_size, nmic, context).type(input.type())
117 | input = torch.cat([pad_context, input, pad_context], 2) # B, ch, L
118 |
119 | # calculate index for each chunk
120 | nchunk = 2*nsample // window - 1
121 | begin_idx = np.arange(nchunk)*stride
122 | begin_idx = torch.from_numpy(begin_idx).type(input.type()).long().view(1, 1, -1) # 1, 1, nchunk
123 | begin_idx = begin_idx.expand(batch_size, nmic, nchunk) # B, ch, nchunk
124 | # select entries from index
125 | chunks = [torch.gather(input, 2, begin_idx+i).unsqueeze(3) for i in range(2*context + window)] # B, ch, nchunk, 1
126 | chunks = torch.cat(chunks, 3) # B, ch, nchunk, chunk_size
127 |
128 | # center frame
129 | center_frame = chunks[:,:,:,context:context+window]
130 |
131 | return center_frame, chunks, rest
132 |
133 | def seq_cos_sim(self, ref, target):
134 | """
135 | Cosine similarity between some reference mics and some target mics
136 | ref: shape (nmic1, L, seg1)
137 | target: shape (nmic2, L, seg2)
138 | """
139 |
140 | assert ref.size(1) == target.size(1), "Inputs should have same length."
141 | assert ref.size(2) >= target.size(2), "Reference input should be no smaller than the target input."
142 |
143 | seq_length = ref.size(1)
144 |
145 | larger_ch = ref.size(0)
146 | if target.size(0) > ref.size(0):
147 | ref = ref.expand(target.size(0), ref.size(1), ref.size(2)).contiguous() # nmic2, L, seg1
148 | larger_ch = target.size(0)
149 | elif target.size(0) < ref.size(0):
150 | target = target.expand(ref.size(0), target.size(1), target.size(2)).contiguous() # nmic1, L, seg2
151 |
152 | # L2 norms
153 | ref_norm = F.conv1d(ref.view(1, -1, ref.size(2)).pow(2),
154 | torch.ones(ref.size(0)*ref.size(1), 1, target.size(2)).type(ref.type()),
155 | groups=larger_ch*seq_length) # 1, larger_ch*L, seg1-seg2+1
156 | ref_norm = ref_norm.sqrt() + self.eps
157 | target_norm = target.norm(2, dim=2).view(1, -1, 1) + self.eps # 1, larger_ch*L, 1
158 | # cosine similarity
159 | cos_sim = F.conv1d(ref.view(1, -1, ref.size(2)),
160 | target.view(-1, 1, target.size(2)),
161 | groups=larger_ch*seq_length) # 1, larger_ch*L, seg1-seg2+1
162 | cos_sim = cos_sim / (ref_norm * target_norm)
163 |
164 | return cos_sim.view(larger_ch, seq_length, -1)
165 |
166 | def forward(self, input, num_mic):
167 | """
168 | input: shape (batch, max_num_ch, T)
169 | num_mic: shape (batch, ), the number of channels for each input. Zero for fixed geometry configuration.
170 | """
171 | pass
172 |
173 |
174 |
175 | # original FaSNet
176 | class FaSNet_origin(FaSNet_base):
177 | def __init__(self, *args, **kwargs):
178 | super(FaSNet_origin, self).__init__(*args, **kwargs)
179 |
180 | # DPRNN for ref mic
181 | self.ref_BF = BF_module(self.filter_dim+self.enc_dim, self.feature_dim, self.hidden_dim,
182 | self.filter_dim, self.num_spk, self.layer, self.segment_size, model_type='DPRNN')
183 |
184 | # DPRNN for other mics
185 | self.other_BF = BF_module(self.filter_dim+self.enc_dim, self.feature_dim, self.hidden_dim,
186 | self.filter_dim, 1, self.layer, self.segment_size, model_type='DPRNN')
187 |
188 |
189 | def forward(self, input, num_mic):
190 |
191 | batch_size = input.size(0)
192 | nmic = input.size(1)
193 |
194 | # split input into chunks
195 | all_seg, all_mic_context, rest = self.seg_signal_context(input, self.window, self.context) # B, nmic, L, win/chunk
196 | seq_length = all_seg.size(2)
197 |
198 | # first step: filtering the ref mic to create a clean estimate
199 | # calculate cosine similarity
200 | ref_context = all_mic_context[:,0].contiguous().view(1, -1, self.context*2+self.window) # 1, B*L, 3*win
201 | other_segment = all_seg[:,1:].contiguous().transpose(0, 1).contiguous().view(nmic-1, -1, self.window) # nmic-1, B*L, win
202 | ref_cos_sim = self.seq_cos_sim(ref_context, other_segment) # nmic-1, B*L, 2*win+1
203 | ref_cos_sim = ref_cos_sim.view(nmic-1, batch_size, seq_length, self.filter_dim) # nmic-1, B, L, 2*win+1
204 | if num_mic.max() == 0:
205 | ref_cos_sim = ref_cos_sim.mean(0) # B, L, 2*win+1
206 | ref_cos_sim = ref_cos_sim.transpose(1, 2).contiguous() # B, 2*win+1, L
207 | else:
208 | # consider only the valid channels
209 | ref_cos_sim = [ref_cos_sim[:num_mic[b],b,:].mean(0).unsqueeze(0) for b in range(batch_size)] # 1, L, 2*win+1
210 | ref_cos_sim = torch.cat(ref_cos_sim, 0).transpose(1, 2).contiguous() # B, 2*win+1, L
211 |
212 |
213 | # pass to a DPRNN
214 | ref_feature = all_mic_context[:,0].contiguous().view(batch_size*seq_length, 1, self.context*2+self.window)
215 | ref_feature = self.encoder(ref_feature) # B*L, N, 1
216 | ref_feature = ref_feature.view(batch_size, seq_length, self.enc_dim).transpose(1, 2).contiguous() # B, N, L
217 | ref_filter = self.ref_BF(torch.cat([self.enc_LN(ref_feature), ref_cos_sim], 1), num_mic) # B, 1, nspk, L, 2*win+1
218 |
219 | # convolve with ref mic context segments
220 | ref_context = torch.cat([all_mic_context[:,0].unsqueeze(1)]*self.num_spk, 1) # B, nspk, L, 3*win
221 | ref_output = F.conv1d(ref_context.view(1, -1, self.context*2+self.window),
222 | ref_filter.view(-1, 1, self.filter_dim),
223 | groups=batch_size*self.num_spk*seq_length) # 1, B*nspk*L, win
224 | ref_output = ref_output.view(batch_size*self.num_spk, seq_length, self.window) # B*nspk, L, win
225 |
226 | # second step: use the ref output as the cue, beamform other mics
227 | # calculate cosine similarity
228 | other_context = torch.cat([all_mic_context[:,1:].unsqueeze(1)]*self.num_spk, 1) # B, nspk, nmic-1, L, 3*win
229 | other_context_saved = other_context.view(batch_size*self.num_spk, nmic-1, seq_length, self.context*2+self.window) # B*nspk, nmic-1, L, 3*win
230 | other_context = other_context_saved.transpose(0, 1).contiguous().view(nmic-1, -1, self.context*2+self.window) # nmic-1, B*nspk*L, 3*win
231 | ref_segment = ref_output.view(1, -1, self.window) # 1, B*nspk*L, win
232 | other_cos_sim = self.seq_cos_sim(other_context, ref_segment) # nmic-1, B*nspk*L, 2*win+1
233 | other_cos_sim = other_cos_sim.view(nmic-1, batch_size*self.num_spk, seq_length, self.filter_dim) # nmic-1, B*nspk, L, 2*win+1
234 | other_cos_sim = other_cos_sim.permute(1,0,3,2).contiguous().view(-1, self.filter_dim, seq_length) # B*nspk*(nmic-1), 2*win+1, L
235 |
236 | # pass to another DPRNN
237 | other_feature = self.encoder(other_context_saved.view(-1, 1, self.context*2+self.window)).view(-1, seq_length, self.enc_dim) # B*nspk*(nmic-1), L, N
238 | other_feature = other_feature.transpose(1, 2).contiguous() # B*nspk*(nmic-1), N, L
239 | other_filter = self.other_BF(torch.cat([self.enc_LN(other_feature), other_cos_sim], 1), num_mic) # B*nspk*(nmic-1), 1, 1, L, 2*win+1
240 |
241 | # convolve with other mic context segments
242 | other_output = F.conv1d(other_context_saved.view(1, -1, self.context*2+self.window),
243 | other_filter.view(-1, 1, self.filter_dim),
244 | groups=batch_size*self.num_spk*(nmic-1)*seq_length) # 1, B*nspk*(nmic-1)*L, win
245 | other_output = other_output.view(batch_size*self.num_spk, nmic-1, seq_length, self.window) # B*nspk, nmic-1, L, win
246 |
247 | all_bf_output = torch.cat([ref_output.unsqueeze(1), other_output], 1) # B*nspk, nmic, L, win
248 |
249 | # reshape to utterance
250 | bf_signal = all_bf_output.view(batch_size*self.num_spk*nmic, -1, self.window*2)
251 | bf_signal1 = bf_signal[:,:,:self.window].contiguous().view(batch_size*self.num_spk*nmic, 1, -1)[:,:,self.stride:]
252 | bf_signal2 = bf_signal[:,:,self.window:].contiguous().view(batch_size*self.num_spk*nmic, 1, -1)[:,:,:-self.stride]
253 | bf_signal = bf_signal1 + bf_signal2 # B*nspk*nmic, 1, T
254 | if rest > 0:
255 | bf_signal = bf_signal[:,:,:-rest]
256 |
257 | bf_signal = bf_signal.view(batch_size, self.num_spk, nmic, -1) # B, nspk, nmic, T
258 | # consider only the valid channels
259 | if num_mic.max() == 0:
260 | bf_signal = bf_signal.mean(2) # B, nspk, T
261 | else:
262 | bf_signal = [bf_signal[b,:,:num_mic[b]].mean(1).unsqueeze(0) for b in range(batch_size)] # nspk, T
263 | bf_signal = torch.cat(bf_signal, 0) # B, nspk, T
264 |
265 | return bf_signal
266 |
267 | # single-stage FaSNet + TAC
268 | class FaSNet_TAC(FaSNet_base):
269 | def __init__(self, *args, **kwargs):
270 | super(FaSNet_TAC, self).__init__(*args, **kwargs)
271 |
272 | # DPRNN + TAC for estimation
273 | self.all_BF = BF_module(self.filter_dim+self.enc_dim, self.feature_dim, self.hidden_dim,
274 | self.filter_dim, self.num_spk, self.layer, self.segment_size, model_type='DPRNN_TAC')
275 |
276 | def forward(self, input, num_mic):
277 |
278 | batch_size = input.size(0)
279 | nmic = input.size(1)
280 |
281 | # split input into chunks
282 | all_seg, all_mic_context, rest = self.seg_signal_context(input, self.window, self.context) # B, nmic, L, win/chunk
283 | seq_length = all_seg.size(2)
284 |
285 | # embeddings for all channels
286 | enc_output = self.encoder(all_mic_context.view(-1, 1, self.context*2+self.window)).view(batch_size*nmic, seq_length, self.enc_dim).transpose(1, 2).contiguous() # B*nmic, N, L
287 | enc_output = self.enc_LN(enc_output).view(batch_size, nmic, self.enc_dim, seq_length) # B, nmic, N, L
288 |
289 | # calculate the cosine similarities for ref channel's center frame with all channels' context
290 |
291 | ref_seg = all_seg[:,0].contiguous().view(1, -1, self.window) # 1, B*L, win
292 | all_context = all_mic_context.transpose(0, 1).contiguous().view(nmic, -1, self.context*2+self.window) # nmic, B*L, 3*win
293 | all_cos_sim = self.seq_cos_sim(all_context, ref_seg) # nmic, B*L, 2*win+1
294 | all_cos_sim = all_cos_sim.view(nmic, batch_size, seq_length, self.filter_dim).permute(1,0,3,2).contiguous() # B, nmic, 2*win+1, L
295 |
296 | input_feature = torch.cat([enc_output, all_cos_sim], 2) # B, nmic, N+2*win+1, L
297 |
298 | # pass to DPRNN
299 | all_filter = self.all_BF(input_feature, num_mic) # B, ch, nspk, L, 2*win+1
300 |
301 | # convolve with all mic's context
302 | mic_context = torch.cat([all_mic_context.view(batch_size*nmic, 1, seq_length,
303 | self.context*2+self.window)]*self.num_spk, 1) # B*nmic, nspk, L, 3*win
304 | all_bf_output = F.conv1d(mic_context.view(1, -1, self.context*2+self.window),
305 | all_filter.view(-1, 1, self.filter_dim),
306 | groups=batch_size*nmic*self.num_spk*seq_length) # 1, B*nmic*nspk*L, win
307 | all_bf_output = all_bf_output.view(batch_size, nmic, self.num_spk, seq_length, self.window) # B, nmic, nspk, L, win
308 |
309 | # reshape to utterance
310 | bf_signal = all_bf_output.view(batch_size*nmic*self.num_spk, -1, self.window*2)
311 | bf_signal1 = bf_signal[:,:,:self.window].contiguous().view(batch_size*nmic*self.num_spk, 1, -1)[:,:,self.stride:]
312 | bf_signal2 = bf_signal[:,:,self.window:].contiguous().view(batch_size*nmic*self.num_spk, 1, -1)[:,:,:-self.stride]
313 | bf_signal = bf_signal1 + bf_signal2 # B*nmic*nspk, 1, T
314 | if rest > 0:
315 | bf_signal = bf_signal[:,:,:-rest]
316 |
317 | bf_signal = bf_signal.view(batch_size, nmic, self.num_spk, -1) # B, nmic, nspk, T
318 | # consider only the valid channels
319 | if num_mic.max() == 0:
320 | bf_signal = bf_signal.mean(1) # B, nspk, T
321 | else:
322 | bf_signal = [bf_signal[b,:num_mic[b]].mean(0).unsqueeze(0) for b in range(batch_size)] # nspk, T
323 | bf_signal = torch.cat(bf_signal, 0) # B, nspk, T
324 |
325 | return bf_signal
326 |
327 |
328 | def test_model(model):
329 | x = torch.rand(2, 4, 32000) # (batch, num_mic, length)
330 | num_mic = torch.from_numpy(np.array([3, 2])).view(-1,).type(x.type()) # ad-hoc array
331 | none_mic = torch.zeros(1).type(x.type()) # fixed-array
332 | y1 = model(x, num_mic.long())
333 | y2 = model(x, none_mic.long())
334 | print(y1.shape, y2.shape) # (batch, nspk, length)
335 |
336 |
337 | if __name__ == "__main__":
338 | model_origin = FaSNet_origin(enc_dim=64, feature_dim=64, hidden_dim=128, layer=6, segment_size=50,
339 | nspk=2, win_len=4, context_len=16, sr=16000)
340 |
341 | model_TAC = FaSNet_TAC(enc_dim=64, feature_dim=64, hidden_dim=128, layer=4, segment_size=50,
342 | nspk=2, win_len=4, context_len=16, sr=16000)
343 |
344 | test_model(model_origin)
345 | test_model(model_TAC)
346 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 3.0 United States License.
2 |
3 | # Transform-average-concatenate (TAC) for end-to-end microphone permutation and number invariant multi-channel speech separation
4 |
5 | This repository provides the model implementation and dataset generation scripts for the paper ["End-to-end Microphone Permutation and Number Invariant Multi-channel Speech Separation"](https://arxiv.org/abs/1910.14104) by Yi Luo, Zhuo Chen, Nima Mesgarani and Takuya Yoshioka. The paper introduces ***transform-average-concatenate (TAC)***, a simple module that allows end-to-end multi-channel separation systems to be invariant to microphone permutation (indexing) and number. Although designed for ad-hoc array configurations, TAC also provides significant performance improvements in fixed-geometry microphone configurations, showing that it can serve as a general design paradigm for end-to-end multi-channel processing systems.
6 |
7 | ## Model
8 |
9 | We implement TAC in the framework of the ***filter-and-sum network (FaSNet)***, a recently proposed multi-channel speech separation model that operates in the time domain. FaSNet is a neural beamformer that performs standard filter-and-sum beamforming in the time domain, with the beamforming coefficients estimated by a neural network in an end-to-end fashion. For details please refer to the original paper: ["FaSNet: Low-latency Adaptive Beamforming for Multi-microphone Audio Processing"](https://arxiv.org/abs/1909.13387).
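To make the filter-and-sum operation concrete, here is a minimal, self-contained illustration of time-domain filter-and-sum beamforming with random placeholder filters; in FaSNet the per-channel filters are estimated by the network rather than fixed, so this is only a sketch of the operation itself:

```python
import torch
import torch.nn.functional as F

# Toy filter-and-sum beamforming: convolve each channel with its own FIR
# filter, then sum the filtered channels. Filters here are random
# placeholders; FaSNet estimates them from the input signals.
n_ch, filt_len, sig_len = 4, 33, 16000
signals = torch.rand(n_ch, sig_len)        # (ch, T) multi-channel input
filters = torch.rand(n_ch, filt_len)       # (ch, K) one filter per channel

filtered = F.conv1d(signals.unsqueeze(0),  # (1, ch, T)
                    filters.unsqueeze(1),  # (ch, 1, K)
                    groups=n_ch,           # apply filter i to channel i only
                    padding=filt_len // 2) # (1, ch, T)
beamformed = filtered.sum(dim=1)           # (1, T) filter-and-sum output
print(beamformed.shape)
```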
10 |
11 | In this paper we make two main modifications to the original FaSNet:
12 | 1) We replace the original two-stage architecture with a single-stage architecture.
13 | 2) TAC is applied throughout the filter estimation module to synchronize the information across microphones and allow the model to make *global* decisions while estimating the filter coefficients (a minimal sketch of the TAC operation follows this list).
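For intuition, the snippet below is a minimal, simplified sketch of the TAC operation on a per-channel feature tensor. It is not the repository implementation (see `DPRNN_TAC` in `utility/models.py`, which additionally handles variable numbers of valid channels and uses larger hidden layers); it only shows the transform-average-concatenate pattern and why the block is invariant to microphone order and count:

```python
import torch
import torch.nn as nn

class TACSketch(nn.Module):
    """Simplified transform-average-concatenate block (illustrative only)."""
    def __init__(self, feat_dim=64, hidden_dim=96):
        super().__init__()
        self.transform = nn.Sequential(nn.Linear(feat_dim, hidden_dim), nn.PReLU())
        self.average = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.PReLU())
        self.concat = nn.Sequential(nn.Linear(hidden_dim * 2, feat_dim), nn.PReLU())

    def forward(self, x):
        # x: (batch, n_ch, feat_dim); output keeps the same shape.
        z = self.transform(x)                                   # transform each channel
        z_mean = self.average(z.mean(dim=1, keepdim=True))      # average across channels
        z_mean = z_mean.expand_as(z)                            # broadcast back to every channel
        return x + self.concat(torch.cat([z, z_mean], dim=-1))  # concatenate + residual

x = torch.rand(2, 4, 64)      # 2 utterances, 4 microphones, 64-dim features
print(TACSketch()(x).shape)   # torch.Size([2, 4, 64]), independent of mic order/count
```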
14 |
15 | The figure below shows different designs of FaSNet models.
16 | 
17 |
18 | The building blocks for the filter estimation modules are based on ***dual-path RNNs (DPRNNs)***, a simple yet effective method for organizing RNN layers to allow successful modeling of extremely long sequential data. For details about DPRNN please refer to ["Dual-path RNN: efficient long sequence modeling for time-domain single-channel speech separation"](https://arxiv.org/abs/1910.06379). The implementation of DPRNN, as well as the combination of DPRNN and TAC, can be found in [*utility/models*](https://github.com/yluo42/TAC/blob/master/utility/models.py).
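For a quick usage reference, the models can be instantiated and run as in the `__main__` test blocks of `FaSNet.py` and `iFaSNet.py`; the snippet below follows the FaSNet+TAC test block with the same hyperparameters:

```python
import torch
from FaSNet import FaSNet_TAC

# Hyperparameters follow the __main__ test block of FaSNet.py.
model = FaSNet_TAC(enc_dim=64, feature_dim=64, hidden_dim=128, layer=4,
                   segment_size=50, nspk=2, win_len=4, context_len=16, sr=16000)

x = torch.rand(2, 4, 32000)            # (batch, num_mic, samples), 2 s at 16 kHz
num_mic = torch.tensor([3, 2]).long()  # ad-hoc arrays: valid mics per utterance
none_mic = torch.zeros(1).long()       # all zeros: fixed geometry, all mics valid
print(model(x, num_mic).shape)         # (batch, nspk, samples)
print(model(x, none_mic).shape)
```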
19 |
20 | ## Dataset
21 |
22 | The model is evaluated on both ad-hoc array and fixed geometry array configurations. We simulate two datasets based on the publicly available Librispeech corpus. For data generation please refer to the *data* folder.
23 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Data generation
3 |
4 | The simulated datasets are based on the [Librispeech](http://www.openslr.org/12) corpus and the [100 Nonspeech Sounds](http://web.cse.ohio-state.edu/pnl/corpus/HuNonspeech/HuCorpus.html) corpus.
5 |
6 | **Update 2021.06.16:** Several utterances may clip when the waveforms are saved to files. Since clipping can also occur in real-world recordings, the generation script does not force them to be rescaled by default; instead, it now provides an optional flag that controls whether clipping is avoided.
7 |
8 | **Update 2019.11.10:** The configuration files and the data generation script have been updated. Please make sure you are using the most recent files for data generation.
9 |
10 | ## Raw data download
11 |
12 | 1) Download the *train-clean-100*, *dev-clean* and *test-clean* data from Librispeech's website and unzip them into any directory. The absolute path for the directory is denoted as *libri_path*, which should contain 3 subfolders *train-clean-100*, *dev-clean* and *test-clean*.
13 | 2) Download the 100 Nonspeech Sounds data and unzip it into any directory. The absolute path for the directory is denoted as *noise_path*.
14 | 3) Download or clone this repository.
15 |
16 | ## Additional Python packages
17 |
18 | - [soundfile](https://pypi.org/project/SoundFile/)==0.10.0
19 | - [gpuRIR](https://github.com/DavidDiazGuerra/gpuRIR) (it does not provide version information, but the latest build should be fine)
20 |
21 | ## Dataset generation
22 |
23 | Run `python create_dataset.py --output-path=your_output_path --avoid-clipping=0 --dataset='adhoc' --libri-path=libri_path --noise-path=noise_path`, where:
24 | 1) *output_path*: the absolute path for saving the output. Default is empty, which uses the current directory as the output path.
25 | 2) *avoid_clipping*: whether to avoid clipping when saving the waveforms to files. 0 (default): keep the original scale. 1: avoid clipping.
26 | 3) *dataset*: the dataset to generate. It can only be *'adhoc'* or *'fixed'*.
27 | 4) *libri_path*: the absolute path for Librispeech data.
28 | 5) *noise_path*: the absolute path for noise data.
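Alternatively, the generation can be invoked programmatically through `generate_data` in `create_dataset.py`. Run it from inside this *data* folder so that the relative *configs* path resolves; the Librispeech and noise paths below are placeholders:

```python
from create_dataset import generate_data

generate_data(output_path='',      # '' uses the current directory
              avoid_clipping=0,    # 0: keep the original scale, 1: rescale to avoid clipping
              dataset='adhoc',     # 'adhoc' or 'fixed'
              libri_path='/path/to/Librispeech',
              noise_path='/path/to/Nonspeech')
```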
29 |
--------------------------------------------------------------------------------
/data/configs/MC_Libri_adhoc_test.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/data/configs/MC_Libri_adhoc_test.pkl
--------------------------------------------------------------------------------
/data/configs/MC_Libri_adhoc_train.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/data/configs/MC_Libri_adhoc_train.pkl
--------------------------------------------------------------------------------
/data/configs/MC_Libri_adhoc_validation.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/data/configs/MC_Libri_adhoc_validation.pkl
--------------------------------------------------------------------------------
/data/configs/MC_Libri_fixed_test.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/data/configs/MC_Libri_fixed_test.pkl
--------------------------------------------------------------------------------
/data/configs/MC_Libri_fixed_train.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/data/configs/MC_Libri_fixed_train.pkl
--------------------------------------------------------------------------------
/data/configs/MC_Libri_fixed_validation.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/data/configs/MC_Libri_fixed_validation.pkl
--------------------------------------------------------------------------------
/data/create_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import signal
3 | import os
4 | import soundfile as sf
5 | import pickle
6 | import argparse
7 | import gpuRIR
8 |
9 | # generate audio files
10 | def generate_data(output_path='', avoid_clipping=0, dataset='adhoc', libri_path='/home/yi/data/Librispeech', noise_path='/home/yi/data/Nonspeech'):
11 | assert dataset in ['adhoc', 'fixed'], "dataset can only be adhoc or fixed."
12 |
13 | if output_path == '':
14 | output_path = os.getcwd()
15 |
16 | data_type = ['train', 'validation', 'test']
17 | for i in range(len(data_type)):
18 | # path for config
19 | config_path = os.path.join('configs', 'MC_Libri_'+dataset+'_'+data_type[i]+'.pkl')
20 |
21 | # load pickle file
22 | with open(config_path, 'rb') as f:
23 | configs = pickle.load(f)
24 |
25 | # sample rate is 16k Hz
26 | sr = 16000
27 | # signal length is 4 sec
28 | sig_len = 4
29 |
30 | for utt in range(len(configs)):
31 | this_config = configs[utt]
32 |
33 | # load audio files
34 | speakers = this_config['speech']
35 | noise = this_config['noise']
36 | spk1, _ = sf.read(os.path.join(libri_path, speakers[0]))
37 | spk2, _ = sf.read(os.path.join(libri_path, speakers[1]))
38 | noise, _ = sf.read(os.path.join(noise_path, noise))
39 |
40 | # calculate signal length according to overlap ratio
41 | overlap_ratio = this_config['overlap_ratio']
42 | actual_len = int(sig_len / (2 - overlap_ratio) * sr)
43 | overlap = int(actual_len*overlap_ratio)
44 |
45 | # truncate speech according to start and end indexes
46 | start_idx = this_config['start_idx']
47 | end_idx = start_idx + actual_len
48 | spk1 = spk1[start_idx:end_idx]
49 | spk2 = spk2[start_idx:end_idx]
50 |
51 | # rescaling speaker and noise energy according to relative SNR
52 | spk1 = spk1 / np.sqrt(np.sum(spk1**2)+1e-8) * 1e2
53 | spk2 = spk2 / np.sqrt(np.sum(spk2**2)+1e-8) * 1e2
54 | spk2 = spk2 * np.power(10, this_config['spk_snr']/20.)
55 | # repeat noise if necessary
56 | noise = noise[:int(sig_len*sr)]
57 | if len(noise) < int(sig_len*sr):
58 | num_repeat = int(sig_len*sr) // len(noise)
59 | res = int(sig_len*sr) - num_repeat * len(noise)
60 | noise = np.concatenate([np.concatenate([noise]*num_repeat), noise[:res]])
61 | # rescale noise energy w.r.t mixture energy
62 | noise = noise / np.sqrt(np.sum(noise**2)+1e-8) * np.sqrt(np.sum((spk1+spk2)**2)+1e-8)
63 | noise = noise / np.power(10, this_config['noise_snr']/20.)
64 |
65 | # load locations and room configs
66 | mic_pos = np.asarray(this_config['mic_pos'])
67 | spk_pos = np.asarray(this_config['spk_pos'])
68 | noise_pos = np.asarray(this_config['noise_pos'])
69 | room_size = np.asarray(this_config['room_size'])
70 | rt60 = this_config['RT60']
71 | num_mic = len(mic_pos)
72 |
73 | # generate RIR
74 | beta = gpuRIR.beta_SabineEstimation(room_size, rt60)
75 | nb_img = gpuRIR.t2n(rt60, room_size)
76 | spk_rir = gpuRIR.simulateRIR(room_size, beta, spk_pos, mic_pos, nb_img, rt60, sr)
77 | noise_rir = gpuRIR.simulateRIR(room_size, beta, noise_pos, mic_pos, nb_img, rt60, sr)
78 |
79 | # convolve with RIR at different mic
80 | echoic_spk1 = []
81 | echoic_spk2 = []
82 | echoic_mixture = []
83 |
84 | if dataset == 'adhoc':
85 | nmic = this_config['num_mic']
86 | else:
87 | nmic = 6
88 | for mic in range(nmic):
89 | spk1_echoic_sig = signal.fftconvolve(spk1, spk_rir[0][mic])
90 | spk2_echoic_sig = signal.fftconvolve(spk2, spk_rir[1][mic])
91 | noise_echoic_sig = signal.fftconvolve(noise, noise_rir[0][mic])
92 |
93 | # align the speakers according to overlap ratio
94 | pad_length = int((1 - overlap_ratio) * actual_len)
95 | padding = np.zeros(pad_length)
96 | spk1_echoic_sig = np.concatenate([spk1_echoic_sig, padding])
97 | spk2_echoic_sig = np.concatenate([padding, spk2_echoic_sig])
98 | # pad or truncate length to 4s if necessary
99 | def pad_sig(x):
100 | if len(x) < sig_len*sr:
101 | zeros = np.zeros(sig_len * sr - len(x))
102 | return np.concatenate([x, zeros])
103 | else:
104 | return x[:sig_len*sr]
105 |
106 | spk1_echoic_sig = pad_sig(spk1_echoic_sig)
107 | spk2_echoic_sig = pad_sig(spk2_echoic_sig)
108 | noise_echoic_sig = pad_sig(noise_echoic_sig)
109 |
110 | # sum up for mixture
111 | mixture = spk1_echoic_sig + spk2_echoic_sig + noise_echoic_sig
112 |
113 | if avoid_clipping:
114 | # avoid clipping
115 | max_scale = np.max([np.max(np.abs(mixture)), np.max(np.abs(spk1_echoic_sig)), np.max(np.abs(spk2_echoic_sig))])
116 | mixture = mixture / max_scale * 0.9
117 | spk1_echoic_sig = spk1_echoic_sig / max_scale * 0.9
118 | spk2_echoic_sig = spk2_echoic_sig / max_scale * 0.9
119 |
120 | # save waveforms
121 | this_save_dir = os.path.join(output_path, 'MC_Libri_'+dataset, data_type[i], str(num_mic)+'mic', 'sample'+str(utt+1))
122 | if not os.path.exists(this_save_dir):
123 | os.makedirs(this_save_dir)
124 | sf.write(os.path.join(this_save_dir, 'spk1_mic'+str(mic+1)+'.wav'), spk1_echoic_sig, sr)
125 | sf.write(os.path.join(this_save_dir, 'spk2_mic'+str(mic+1)+'.wav'), spk2_echoic_sig, sr)
126 | sf.write(os.path.join(this_save_dir, 'mixture_mic'+str(mic+1)+'.wav'), mixture, sr)
127 |
128 | # print progress
129 | if (utt+1) % (len(configs) // 5) == 0:
130 | print("{} configuration, {} set, {:d} out of {:d} utterances generated.".format(dataset, data_type[i],
131 | utt+1, len(configs)))
132 |
133 |
134 | if __name__ == "__main__":
135 | parser = argparse.ArgumentParser(description='Generate multi-channel Librispeech data')
136 | parser.add_argument('--output-path', metavar='absolute path', required=False, default='',
137 | help="The path to the output directory. Default is the current directory.")
138 | parser.add_argument('--avoid-clipping', metavar='avoid clipping', type=int, required=False, default=0,
139 | help="Whether to avoid clipping when saving the waveforms. 0: keep the original scale (may clip). 1: rescale to avoid clipping.")
140 | parser.add_argument('--dataset', metavar='dataset type', required=True,
141 | help="The type of dataset to generate. Can only be 'adhoc' or 'fixed'.")
142 | parser.add_argument('--libri-path', metavar='absolute path', required=True,
143 | help="Absolute path for Librispeech folder containing train-clean-100, dev-clean and test-clean folders.")
144 | parser.add_argument('--noise-path', metavar='absolute path', required=True,
145 | help="Absolute path for the 100 Nonspeech sound folder.")
146 | args = parser.parse_args()
147 | generate_data(output_path=args.output_path, avoid_clipping=args.avoid_clipping, dataset=args.dataset, libri_path=args.libri_path, noise_path=args.noise_path)
148 |
--------------------------------------------------------------------------------
/flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yluo42/TAC/e3373b73358a96af6f64fdbe25327def8d6bd973/flowchart.png
--------------------------------------------------------------------------------
/iFaSNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import os
6 | import numpy as np
7 |
8 | from utility.models import *
9 |
10 | # base module for DPRNN/DPRNN-TAC
11 | class BF_module(DPRNN_base):
12 | def __init__(self, *args, **kwargs):
13 | super(BF_module, self).__init__(*args, **kwargs)
14 |
15 | # output layer
16 | self.output = nn.Conv1d(self.feature_dim, self.output_dim, 1)
17 |
18 | def forward(self, input, num_mic):
19 |
20 | if self.model_type == 'DPRNN':
21 | # input: (B, N, T)
22 | batch_size, N, seq_length = input.shape
23 | ch = 1
24 | elif self.model_type == 'DPRNN_TAC':
25 | # input: (B, ch, N, T)
26 | batch_size, ch, N, seq_length = input.shape
27 |
28 | input = input.view(batch_size*ch, N, seq_length) # B*ch, N, T
29 | enc_feature = self.BN(input)
30 |
31 | # split the encoder output into overlapped, longer segments
32 | enc_segments, enc_rest = self.split_feature(enc_feature, self.segment_size) # B*ch, N, L, K
33 |
34 | # pass to DPRNN
35 | if self.model_type == 'DPRNN':
36 | output = self.DPRNN(enc_segments).view(batch_size*ch*self.num_spk, self.feature_dim, self.segment_size, -1) # B*ch*nspk, N, L, K
37 | elif self.model_type == 'DPRNN_TAC':
38 | enc_segments = enc_segments.view(batch_size, ch, -1, enc_segments.shape[2], enc_segments.shape[3]) # B, ch, N, L, K
39 | output = self.DPRNN(enc_segments, num_mic).view(batch_size*ch*self.num_spk, self.feature_dim, self.segment_size, -1) # B*ch*nspk, N, L, K
40 |
41 | # overlap-and-add of the outputs
42 | output = self.merge_feature(output, enc_rest) # B*ch*nspk, N, T
43 |
44 | # output layer
45 | bf_filter = self.output(output) # B*ch*nspk, K, T
46 | bf_filter = bf_filter.view(batch_size, ch, self.num_spk, self.output_dim, -1) # B, ch, nspk, K, L
47 |
48 | return bf_filter
49 |
50 |
51 | # base module for FaSNet
52 | class FaSNet_base(nn.Module):
53 | def __init__(self, enc_dim, feature_dim, hidden_dim, layer, segment_size=24,
54 | nspk=2, win_len=16, context_len=16, sr=16000):
55 | super(FaSNet_base, self).__init__()
56 |
57 | # parameters
58 | self.window = int(sr * win_len / 1000)
59 | self.stride = self.window // 2
60 | self.context = context_len*2 // win_len
61 |
62 | self.enc_dim = enc_dim
63 | self.feature_dim = feature_dim
64 | self.hidden_dim = hidden_dim
65 | self.segment_size = segment_size
66 |
67 | self.layer = layer
68 | self.num_spk = nspk
69 | self.eps = 1e-8
70 |
71 | # waveform encoder/decoder
72 | self.encoder = nn.Conv1d(1, self.enc_dim, self.window, stride=self.stride, bias=False)
73 | self.decoder = nn.ConvTranspose1d(self.enc_dim, 1, self.window, stride=self.stride, bias=False)
74 | self.enc_LN = nn.GroupNorm(1, self.enc_dim, eps=self.eps)
75 |
76 | def pad_input(self, input, window, stride):
77 | """
78 | Zero-padding input according to window/stride size.
79 | """
80 | batch_size, nmic, nsample = input.shape
81 |
82 | # pad the signals at the end for matching the window/stride size
83 | rest = window - (stride + nsample % window) % window
84 | if rest > 0:
85 | pad = torch.zeros(batch_size, nmic, rest).type(input.type())
86 | input = torch.cat([input, pad], 2)
87 | pad_aux = torch.zeros(batch_size, nmic, stride).type(input.type())
88 | input = torch.cat([pad_aux, input, pad_aux], 2)
89 |
90 | return input, rest
91 |
92 | def signal_context(self, x, context):
93 | """
94 | Segmenting the signal into chunks with specific context.
95 | input:
96 | x: size (B, dim, nframe)
97 | context: int
98 |
99 | """
100 |
101 | batch_size, dim, nframe = x.shape
102 |
103 | zero_pad = torch.zeros(batch_size, dim, context).type(x.type())
104 | pad_past = []
105 | pad_future = []
106 | for i in range(context):
107 | pad_past.append(torch.cat([zero_pad[:,:,i:], x[:,:,:-context+i]], 2).unsqueeze(2))
108 | pad_future.append(torch.cat([x[:,:,i+1:], zero_pad[:,:,:i+1]], 2).unsqueeze(2))
109 |
110 | pad_past = torch.cat(pad_past, 2) # B, D, C, L
111 | pad_future = torch.cat(pad_future, 2) # B, D, C, L
112 | all_context = torch.cat([pad_past, x.unsqueeze(2), pad_future], 2) # B, D, 2*C+1, L
113 |
114 | return all_context
115 |
116 | def forward(self, input, num_mic):
117 | """
118 | input: shape (batch, max_num_ch, T)
119 | num_mic: shape (batch, ), the number of channels for each input. Zero for fixed geometry configuration.
120 | """
121 | pass
122 |
123 |
124 | # implicit FaSNet (iFaSNet)
125 | class iFaSNet(FaSNet_base):
126 | def __init__(self, *args, **kwargs):
127 | super(iFaSNet, self).__init__(*args, **kwargs)
128 |
129 | # context compression
130 | self.summ_BN = nn.Linear(self.enc_dim, self.feature_dim)
131 | self.summ_RNN = SingleRNN('LSTM',self.feature_dim, self.hidden_dim, bidirectional=True)
132 | self.summ_LN = nn.GroupNorm(1, self.feature_dim, eps=self.eps)
133 | self.summ_output = nn.Linear(self.feature_dim, self.enc_dim)
134 |
135 | # DPRNN-TAC
136 | self.separator = BF_module(self.enc_dim+(self.context*2+1)**2,
137 | self.feature_dim, self.hidden_dim,
138 | self.enc_dim, self.num_spk, self.layer,
139 | self.segment_size, model_type='DPRNN_TAC')
140 |
141 | # context decompression
142 | self.gen_BN = nn.Conv1d(self.enc_dim*2, self.feature_dim, 1)
143 | self.gen_RNN = SingleRNN('LSTM', self.feature_dim, self.hidden_dim, bidirectional=True)
144 | self.gen_LN = nn.GroupNorm(1, self.feature_dim, eps=self.eps)
145 | self.gen_output = nn.Conv1d(self.feature_dim, self.enc_dim, 1)
146 |
147 | def forward(self, input, num_mic):
148 |
149 | batch_size = input.size(0)
150 | nmic = input.size(1)
151 |
152 | # pad input accordingly
153 | input, rest = self.pad_input(input, self.window, self.stride)
154 |
155 | # encoder on all channels
156 | enc_output = self.encoder(input.view(batch_size*nmic, 1, -1)) # B*nmic, N, L
157 | seq_length = enc_output.shape[-1]
158 |
159 | # calculate the context of the encoder output
160 | # consider both past and future
161 | enc_context = self.signal_context(enc_output, self.context) # B*nmic, N, 2C+1, L
162 | enc_context = enc_context.view(batch_size, nmic, self.enc_dim, -1, seq_length) # B, nmic, N, 2C+1, L
163 |
164 | # NCC feature
165 | ref_enc = enc_context[:,0].contiguous() # B, N, 2C+1, L
166 | ref_enc = ref_enc.permute(0,3,1,2).contiguous().view(batch_size*seq_length, self.enc_dim, -1) # B*L, N, 2C+1
167 | enc_context_copy = enc_context.permute(0,4,1,3,2).contiguous().view(batch_size*seq_length, nmic,
168 | -1, self.enc_dim) # B*L, nmic, 2C+1, N
169 | NCC = torch.cat([enc_context_copy[:,i].bmm(ref_enc).unsqueeze(1) for i in range(nmic)], 1) # B*L, nmic, 2C+1, 2C+1
170 | ref_norm = (ref_enc.pow(2).sum(1).unsqueeze(1) + self.eps).sqrt() # B*L, 1, 2C+1
171 | enc_norm = (enc_context_copy.pow(2).sum(3).unsqueeze(3) + self.eps).sqrt() # B*L, nmic, 2C+1, 1
172 | NCC = NCC / (ref_norm.unsqueeze(1) * enc_norm) # B*L, nmic, 2C+1, 2C+1
173 | NCC = torch.cat([NCC[:,:,i] for i in range(NCC.shape[2])], 2) # B*L, nmic, (2C+1)^2
174 | NCC = NCC.view(batch_size, seq_length, nmic, -1).permute(0,2,3,1).contiguous() # B, nmic, (2C+1)^2, L
175 |
176 | # context compression
177 | norm_output = self.enc_LN(enc_output) # B*nmic, N, L
178 | norm_context = self.signal_context(norm_output, self.context) # B*nmic, N, 2C+1, L
179 | norm_context = norm_context.permute(0,3,2,1).contiguous().view(-1, self.context*2+1, self.enc_dim)
180 | norm_context_BN = self.summ_BN(norm_context.view(-1, self.enc_dim)).view(-1, self.context*2+1, self.feature_dim)
181 | embedding = self.summ_RNN(norm_context_BN).transpose(1, 2).contiguous() # B*nmic*L, N, 2C+1
182 | embedding = norm_context_BN.transpose(1, 2).contiguous() + self.summ_LN(embedding) # B*nmic*L, N, 2C+1
183 | embedding = self.summ_output(embedding.mean(2)).view(batch_size, nmic, seq_length, self.enc_dim) # B, nmic, L, N
184 | embedding = embedding.transpose(2, 3).contiguous() # B, nmic, N, L
185 |
186 | input_feature = torch.cat([embedding, NCC], 2) # B, nmic, N+(2C+1)^2, L
187 |
188 | # pass to DPRNN-TAC
189 | embedding = self.separator(input_feature, num_mic)[:,0].contiguous() # B, nspk, N, L
190 |
191 | # concatenate with encoder outputs and generate masks
192 | # context decompression
193 | norm_context = norm_context.view(batch_size, nmic, seq_length, -1, self.enc_dim) # B, nmic, L, 2C+1, N
194 | norm_context = norm_context.permute(0,1,4,3,2)[:,:1].contiguous() # B, 1, N, 2C+1, L
195 |
196 | embedding = torch.cat([embedding.unsqueeze(3)]*(self.context*2+1), 3) # B, nspk, N, 2C+1, L
197 | norm_context = torch.cat([norm_context]*self.num_spk, 1) # B, nspk, N, 2C+1, L
198 | embedding = torch.cat([norm_context, embedding], 2).permute(0,1,4,2,3).contiguous() # B, nspk, L, 2N, 2C+1
199 | all_filter = self.gen_BN(embedding.view(-1, self.enc_dim*2, self.context*2+1)) # B*nspk*L, N, 2C+1
200 | all_filter = all_filter + self.gen_LN(self.gen_RNN(all_filter.transpose(1, 2)).transpose(1, 2)) # B*nspk*L, N, 2C+1
201 | all_filter = self.gen_output(all_filter) # B*nspk*L, N, 2C+1
202 | all_filter = all_filter.view(batch_size, self.num_spk, seq_length, self.enc_dim, -1) # B, nspk, L, N, 2C+1
203 | all_filter = all_filter.permute(0,1,3,4,2).contiguous() # B, nspk, N, 2C+1, L
204 |
205 | # apply the filters to the ref mic's encoder context
206 | output = (enc_context[:,:1] * all_filter).mean(3) # B, nspk, N, L
207 |
208 | # decode
209 | bf_signal = self.decoder(output.view(batch_size*self.num_spk, self.enc_dim, -1)) # B*nspk, 1, T
210 |
211 | if rest > 0:
212 | bf_signal = bf_signal[:,:,self.stride:-rest-self.stride]
213 |
214 | bf_signal = bf_signal.view(batch_size, self.num_spk, -1) # B, nspk, T
215 |
216 | return bf_signal
217 |
218 | def test_model(model):
219 | x = torch.rand(2, 4, 32000) # (batch, num_mic, length)
220 | num_mic = torch.from_numpy(np.array([3, 2])).view(-1,).type(x.type()) # ad-hoc array
221 | none_mic = torch.zeros(1).type(x.type()) # fixed-array
222 | y1 = model(x, num_mic.long())
223 | y2 = model(x, none_mic.long())
224 | print(y1.shape, y2.shape) # (batch, nspk, length)
225 |
226 |
227 | if __name__ == "__main__":
228 | model_iFaSNet = iFaSNet(enc_dim=64, feature_dim=64, hidden_dim=128, layer=6, segment_size=24,
229 | nspk=2, win_len=16, context_len=16, sr=16000)
230 |
231 | test_model(model_iFaSNet)
232 |
--------------------------------------------------------------------------------
/utility/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/utility/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 |
8 | import sys
9 |
10 |
11 | class SingleRNN(nn.Module):
12 | """
13 | Container module for a single RNN layer.
14 |
15 | args:
16 | rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
17 | input_size: int, dimension of the input feature. The input should have shape
18 | (batch, seq_len, input_size).
19 | hidden_size: int, dimension of the hidden state.
20 | dropout: float, dropout ratio. Default is 0.
21 | bidirectional: bool, whether the RNN layers are bidirectional. Default is False.
22 | """
23 |
24 | def __init__(self, rnn_type, input_size, hidden_size, dropout=0, bidirectional=False):
25 | super(SingleRNN, self).__init__()
26 |
27 | self.rnn_type = rnn_type
28 | self.input_size = input_size
29 | self.hidden_size = hidden_size
30 | self.num_direction = int(bidirectional) + 1
31 |
32 | self.rnn = getattr(nn, rnn_type)(input_size, hidden_size, 1, dropout=dropout, batch_first=True, bidirectional=bidirectional)
33 |
34 | # linear projection layer
35 | self.proj = nn.Linear(hidden_size*self.num_direction, input_size)
36 |
37 | def forward(self, input):
38 | # input shape: batch, seq, dim
39 | output = input
40 | rnn_output, _ = self.rnn(output)
41 | rnn_output = self.proj(rnn_output.contiguous().view(-1, rnn_output.shape[2])).view(output.shape)
42 | return rnn_output
43 |
44 | # dual-path RNN
45 | class DPRNN(nn.Module):
46 | """
47 | Deep dual-path RNN.
48 |
49 | args:
50 | rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
51 | input_size: int, dimension of the input feature. The input should have shape
52 | (batch, seq_len, input_size).
53 | hidden_size: int, dimension of the hidden state.
54 | output_size: int, dimension of the output.
55 | dropout: float, dropout ratio. Default is 0.
56 | num_layers: int, number of stacked RNN layers. Default is 1.
57 | bidirectional: bool, whether the RNN layers are bidirectional. Default is True.
58 | """
59 | def __init__(self, rnn_type, input_size, hidden_size, output_size,
60 | dropout=0, num_layers=1, bidirectional=True):
61 | super(DPRNN, self).__init__()
62 |
63 | self.input_size = input_size
64 | self.output_size = output_size
65 | self.hidden_size = hidden_size
66 |
67 | # dual-path RNN
68 | self.row_rnn = nn.ModuleList([])
69 | self.col_rnn = nn.ModuleList([])
70 | self.row_norm = nn.ModuleList([])
71 | self.col_norm = nn.ModuleList([])
72 | for i in range(num_layers):
73 | self.row_rnn.append(SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=True)) # intra-segment RNN is always noncausal
74 | self.col_rnn.append(SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=bidirectional))
75 | self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
76 | # default is to use noncausal LayerNorm for inter-chunk RNN. For causal setting change it to causal normalization techniques accordingly.
77 | self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
78 |
79 | # output layer
80 | self.output = nn.Sequential(nn.PReLU(),
81 | nn.Conv2d(input_size, output_size, 1)
82 | )
83 |
84 | def forward(self, input):
85 | # input shape: batch, N, dim1, dim2
86 | # apply RNN on dim1 first and then dim2
87 |
88 | batch_size, _, dim1, dim2 = input.shape
89 | output = input
90 | for i in range(len(self.row_rnn)):
91 | row_input = output.permute(0,3,2,1).contiguous().view(batch_size*dim2, dim1, -1) # B*dim2, dim1, N
92 | row_output = self.row_rnn[i](row_input) # B*dim2, dim1, H
93 | row_output = row_output.view(batch_size, dim2, dim1, -1).permute(0,3,2,1).contiguous() # B, N, dim1, dim2
94 | row_output = self.row_norm[i](row_output)
95 | output = output + row_output
96 |
97 | col_input = output.permute(0,2,3,1).contiguous().view(batch_size*dim1, dim2, -1) # B*dim1, dim2, N
98 | col_output = self.col_rnn[i](col_input) # B*dim1, dim2, H
99 | col_output = col_output.view(batch_size, dim1, dim2, -1).permute(0,3,1,2).contiguous() # B, N, dim1, dim2
100 | col_output = self.col_norm[i](col_output)
101 | output = output + col_output
102 |
103 | output = self.output(output)
104 |
105 | return output
106 |
107 |
108 | # dual-path RNN with transform-average-concatenate (TAC)
109 | class DPRNN_TAC(nn.Module):
110 | """
111 | Deep dual-path RNN with transform-average-concatenate (TAC) applied to each layer/block.
112 |
113 | args:
114 | rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
115 | input_size: int, dimension of the input feature. The input should have shape
116 | (batch, seq_len, input_size).
117 | hidden_size: int, dimension of the hidden state.
118 | output_size: int, dimension of the output.
119 | dropout: float, dropout ratio. Default is 0.
120 | num_layers: int, number of stacked RNN layers. Default is 1.
121 | bidirectional: bool, whether the RNN layers are bidirectional. Default is True.
122 | """
123 | def __init__(self, rnn_type, input_size, hidden_size, output_size,
124 | dropout=0, num_layers=1, bidirectional=True):
125 | super(DPRNN_TAC, self).__init__()
126 |
127 | self.input_size = input_size
128 | self.output_size = output_size
129 | self.hidden_size = hidden_size
130 |
131 | # DPRNN + TAC for 3D input (ch, N, T)
132 | self.row_rnn = nn.ModuleList([])
133 | self.col_rnn = nn.ModuleList([])
134 | self.ch_transform = nn.ModuleList([])
135 | self.ch_average = nn.ModuleList([])
136 | self.ch_concat = nn.ModuleList([])
137 |
138 | self.row_norm = nn.ModuleList([])
139 | self.col_norm = nn.ModuleList([])
140 | self.ch_norm = nn.ModuleList([])
141 |
142 |
143 | for i in range(num_layers):
144 | self.row_rnn.append(SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=True)) # intra-segment RNN is always noncausal
145 | self.col_rnn.append(SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=bidirectional))
146 | self.ch_transform.append(nn.Sequential(nn.Linear(input_size, hidden_size*3),
147 | nn.PReLU()
148 | )
149 | )
150 | self.ch_average.append(nn.Sequential(nn.Linear(hidden_size*3, hidden_size*3),
151 | nn.PReLU()
152 | )
153 | )
154 | self.ch_concat.append(nn.Sequential(nn.Linear(hidden_size*6, input_size),
155 | nn.PReLU()
156 | )
157 | )
158 |
159 |
160 | self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
161 | # default is to use noncausal LayerNorm for inter-chunk RNN and TAC modules. For causal setting change them to causal normalization techniques accordingly.
162 | self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
163 | self.ch_norm.append(nn.GroupNorm(1, input_size, eps=1e-8))
164 |
165 | # output layer
166 | self.output = nn.Sequential(nn.PReLU(),
167 | nn.Conv2d(input_size, output_size, 1)
168 | )
169 |
170 | def forward(self, input, num_mic):
171 | # input shape: batch, ch, N, dim1, dim2
172 | # num_mic shape: batch,
173 | # apply RNN on dim1 first, then dim2, then ch
174 |
175 | batch_size, ch, N, dim1, dim2 = input.shape
176 | output = input
177 | for i in range(len(self.row_rnn)):
178 | # intra-segment RNN
179 | output = output.view(batch_size*ch, N, dim1, dim2) # B*ch, N, dim1, dim2
180 | row_input = output.permute(0,3,2,1).contiguous().view(batch_size*ch*dim2, dim1, -1) # B*ch*dim2, dim1, N
181 | row_output = self.row_rnn[i](row_input) # B*ch*dim2, dim1, N
182 | row_output = row_output.view(batch_size*ch, dim2, dim1, -1).permute(0,3,2,1).contiguous() # B*ch, N, dim1, dim2
183 | row_output = self.row_norm[i](row_output)
184 | output = output + row_output # B*ch, N, dim1, dim2
185 |
186 | # inter-segment RNN
187 | col_input = output.permute(0,2,3,1).contiguous().view(batch_size*ch*dim1, dim2, -1) # B*ch*dim1, dim2, N
188 | col_output = self.col_rnn[i](col_input) # B*dim1, dim2, N
189 | col_output = col_output.view(batch_size*ch, dim1, dim2, -1).permute(0,3,1,2).contiguous() # B*ch, N, dim1, dim2
190 | col_output = self.col_norm[i](col_output)
191 | output = output + col_output # B*ch, N, dim1, dim2
192 |
193 | # TAC for cross-channel communication
194 | ch_input = output.view(input.shape) # B, ch, N, dim1, dim2
195 | ch_input = ch_input.permute(0,3,4,1,2).contiguous().view(-1, N) # B*dim1*dim2*ch, N
196 | ch_output = self.ch_transform[i](ch_input).view(batch_size, dim1*dim2, ch, -1) # B, dim1*dim2, ch, H
197 | # mean pooling across channels
198 | if num_mic.max() == 0:
199 | # fixed geometry array
200 | ch_mean = ch_output.mean(2).view(batch_size*dim1*dim2, -1) # B*dim1*dim2, H
201 | else:
202 | # only consider valid channels
203 | ch_mean = [ch_output[b,:,:num_mic[b]].mean(1).unsqueeze(0) for b in range(batch_size)] # 1, dim1*dim2, H
204 | ch_mean = torch.cat(ch_mean, 0).view(batch_size*dim1*dim2, -1) # B*dim1*dim2, H
205 | ch_output = ch_output.view(batch_size*dim1*dim2, ch, -1) # B*dim1*dim2, ch, H
206 | ch_mean = self.ch_average[i](ch_mean).unsqueeze(1).expand_as(ch_output).contiguous() # B*dim1*dim2, ch, H
207 | ch_output = torch.cat([ch_output, ch_mean], 2) # B*dim1*dim2, ch, 2H
208 | ch_output = self.ch_concat[i](ch_output.view(-1, ch_output.shape[-1])) # B*dim1*dim2*ch, N
209 | ch_output = ch_output.view(batch_size, dim1, dim2, ch, -1).permute(0,3,4,1,2).contiguous() # B, ch, N, dim1, dim2
210 | ch_output = self.ch_norm[i](ch_output.view(batch_size*ch, N, dim1, dim2)) # B*ch, N, dim1, dim2
211 | output = output + ch_output
212 |
213 | output = self.output(output) # B*ch, N, dim1, dim2
214 |
215 | return output
216 |
217 | # base module for deep DPRNN
218 | class DPRNN_base(nn.Module):
219 | def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
220 | layer=4, segment_size=100, bidirectional=True, model_type='DPRNN',
221 | rnn_type='LSTM'):
222 | super(DPRNN_base, self).__init__()
223 |
224 | assert model_type in ['DPRNN', 'DPRNN_TAC'], "model_type can only be 'DPRNN' or 'DPRNN_TAC'."
225 |
226 | self.input_dim = input_dim
227 | self.feature_dim = feature_dim
228 | self.hidden_dim = hidden_dim
229 | self.output_dim = output_dim
230 |
231 | self.layer = layer
232 | self.segment_size = segment_size
233 | self.num_spk = num_spk
234 |
235 | self.model_type = model_type
236 |
237 | self.eps = 1e-8
238 |
239 | # bottleneck
240 | self.BN = nn.Conv1d(self.input_dim, self.feature_dim, 1, bias=False)
241 |
242 | # DPRNN model
243 | self.DPRNN = getattr(sys.modules[__name__], model_type)(rnn_type, self.feature_dim, self.hidden_dim, self.feature_dim*self.num_spk,
244 | num_layers=layer, bidirectional=bidirectional)
245 |
246 | def pad_segment(self, input, segment_size):
247 | # input is the features: (B, N, T)
248 | batch_size, dim, seq_len = input.shape
249 | segment_stride = segment_size // 2
250 |
251 | rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
252 | if rest > 0:
253 | pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type())
254 | input = torch.cat([input, pad], 2)
255 |
256 | pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(input.type())
257 | input = torch.cat([pad_aux, input, pad_aux], 2)
258 |
259 | return input, rest
260 |
261 | def split_feature(self, input, segment_size):
262 | # split the feature into chunks of segment size
263 | # input is the features: (B, N, T)
264 |
265 | input, rest = self.pad_segment(input, segment_size)
266 | batch_size, dim, seq_len = input.shape
267 | segment_stride = segment_size // 2
268 |
269 | segments1 = input[:,:,:-segment_stride].contiguous().view(batch_size, dim, -1, segment_size)
270 | segments2 = input[:,:,segment_stride:].contiguous().view(batch_size, dim, -1, segment_size)
271 | segments = torch.cat([segments1, segments2], 3).view(batch_size, dim, -1, segment_size).transpose(2, 3)
272 |
273 | return segments.contiguous(), rest
274 |
275 | def merge_feature(self, input, rest):
276 | # merge the split features into the full utterance
277 | # input is the features: (B, N, L, K)
278 |
279 | batch_size, dim, segment_size, _ = input.shape
280 | segment_stride = segment_size // 2
281 | input = input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size*2) # B, N, K, L
282 |
283 | input1 = input[:,:,:,:segment_size].contiguous().view(batch_size, dim, -1)[:,:,segment_stride:]
284 | input2 = input[:,:,:,segment_size:].contiguous().view(batch_size, dim, -1)[:,:,:-segment_stride]
285 |
286 | output = input1 + input2
287 | if rest > 0:
288 | output = output[:,:,:-rest]
289 |
290 | return output.contiguous() # B, N, T
291 |
292 | def forward(self, input):
293 | pass
--------------------------------------------------------------------------------
/utility/sdr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from itertools import permutations
3 | from torch.autograd import Variable
4 |
5 | import scipy,time,numpy
6 |
7 | import torch
8 |
9 | # original implementation
10 | def compute_measures(se,s,j):
11 | Rss=s.transpose().dot(s)
12 | this_s=s[:,j]
13 |
14 | a=this_s.transpose().dot(se)/Rss[j,j]
15 | e_true=a*this_s
16 | e_res=se-a*this_s
17 | Sss=np.sum((e_true)**2)
18 | Snn=np.sum((e_res)**2)
19 |
20 | SDR=10*np.log10(Sss/Snn)
21 |
22 | Rsr= s.transpose().dot(e_res)
23 | b=np.linalg.inv(Rss).dot(Rsr)
24 |
25 | e_interf = s.dot(b)
26 | e_artif= e_res-e_interf
27 |
28 | SIR=10*np.log10(Sss/np.sum((e_interf)**2))
29 | SAR=10*np.log10(Sss/np.sum((e_artif)**2))
30 | return SDR, SIR, SAR
31 |
32 | def GetSDR(se,s):
33 | se=se-np.mean(se,axis=0)
34 | s=s-np.mean(s,axis=0)
35 | nsampl,nsrc=se.shape
36 | nsampl2,nsrc2=s.shape
37 | assert(nsrc2==nsrc)
38 | assert(nsampl2==nsampl)
39 |
40 | SDR=np.zeros((nsrc,nsrc))
41 | SIR=SDR.copy()
42 | SAR=SDR.copy()
43 |
44 | for jest in range(nsrc):
45 | for jtrue in range(nsrc):
46 | SDR[jest,jtrue],SIR[jest,jtrue],SAR[jest,jtrue]=compute_measures(se[:,jest],s,jtrue)
47 |
48 |
49 | perm=list(permutations(np.arange(nsrc)))
50 | nperm=len(perm)
51 | meanSIR=np.zeros((nperm,))
52 | for p in range(nperm):
53 | tp=SIR.transpose().reshape(nsrc*nsrc)
54 | idx=np.arange(nsrc)*nsrc+list(perm[p])
55 | meanSIR[p]=np.mean(tp[idx])
56 | popt=np.argmax(meanSIR)
57 | per=list(perm[popt])
58 | idx=np.arange(nsrc)*nsrc+per
59 | SDR=SDR.transpose().reshape(nsrc*nsrc)[idx]
60 | SIR=SIR.transpose().reshape(nsrc*nsrc)[idx]
61 | SAR=SAR.transpose().reshape(nsrc*nsrc)[idx]
62 | return SDR, SIR, SAR, per
63 |
64 | # Pytorch implementation with batch processing
65 | def calc_sdr_torch(estimation, origin, mask=None):
66 | """
67 | batch-wise SDR calculation for one audio file on pytorch Variables.
68 | estimation: (batch, nsample)
69 | origin: (batch, nsample)
70 | mask: an optional mask for sequence masking. This is for cases where zero-padding was applied at the end and should not be considered for SDR calculation.
71 | """
72 |
73 | if mask is not None:
74 | origin = origin * mask
75 | estimation = estimation * mask
76 |
77 | def calculate(estimation, origin):
78 | origin_power = torch.pow(origin, 2).sum(1, keepdim=True) + 1e-8 # (batch, 1)
79 | scale = torch.sum(origin*estimation, 1, keepdim=True) / origin_power # (batch, 1)
80 |
81 | est_true = scale * origin # (batch, nsample)
82 | est_res = estimation - est_true # (batch, nsample)
83 |
84 | true_power = torch.pow(est_true, 2).sum(1) + 1e-8
85 | res_power = torch.pow(est_res, 2).sum(1) + 1e-8
86 |
87 | return 10*torch.log10(true_power) - 10*torch.log10(res_power) # (batch, )
88 |
89 | best_sdr = calculate(estimation, origin)
90 |
91 | return best_sdr
92 |
93 |
94 | def batch_SDR_torch(estimation, origin, mask=None, return_perm=False):
95 | """
96 | batch-wise SDR calculation for multiple audio files.
97 | estimation: (batch, nsource, nsample)
98 | origin: (batch, nsource, nsample)
99 | mask: optional, (batch, nsample), binary
100 | return_perm: bool, whether to return the permutation index. Default is false.
101 | """
102 |
103 | batch_size_est, nsource_est, nsample_est = estimation.size()
104 | batch_size_ori, nsource_ori, nsample_ori = origin.size()
105 |
106 | assert batch_size_est == batch_size_ori, "Estimation and original sources should have same shape."
107 | assert nsource_est == nsource_ori, "Estimation and original sources should have same shape."
108 | assert nsample_est == nsample_ori, "Estimation and original sources should have same shape."
109 |
110 | assert nsource_est < nsample_est, "Axis 1 should be the number of sources, and axis 2 should be the signal."
111 |
112 | batch_size = batch_size_est
113 | nsource = nsource_est
114 |
115 | # zero mean signals
116 | estimation = estimation - torch.mean(estimation, 2, keepdim=True).expand_as(estimation)
117 | origin = origin - torch.mean(origin, 2, keepdim=True).expand_as(estimation)
118 |
119 | # SDR for each permutation
120 | SDR = torch.zeros((batch_size, nsource, nsource)).type(estimation.type())
121 | for i in range(nsource):
122 | for j in range(nsource):
123 | SDR[:,i,j] = calc_sdr_torch(estimation[:,i], origin[:,j], mask)
124 |
125 | # choose the best permutation
126 | SDR_max = []
127 | SDR_perm = []
128 | perm = sorted(list(set(permutations(np.arange(nsource)))))
129 | for permute in perm:
130 | sdr = []
131 | for idx in range(len(permute)):
132 | sdr.append(SDR[:,idx,permute[idx]].view(batch_size,-1))
133 | sdr = torch.sum(torch.cat(sdr, 1), 1)
134 | SDR_perm.append(sdr.view(batch_size, 1))
135 | SDR_perm = torch.cat(SDR_perm, 1)
136 | SDR_max, SDR_idx = torch.max(SDR_perm, dim=1)
137 |
138 | if not return_perm:
139 | return SDR_max / nsource
140 | else:
141 | return SDR_max / nsource, SDR_idx
--------------------------------------------------------------------------------