├── LICENSE
├── PESQ.so
├── README.md
├── SE_tutorials.ipynb
├── composite.m
├── config.py
├── dataloader.py
├── estimation
│   └── check_object_metrics.py
├── generate_noisy_data.py
├── models.py
├── tools_for_estimate.py
├── tools_for_loss.py
├── tools_for_model.py
├── train_interface.py
├── trainer.py
└── write_on_tensorboard.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Seo-Rim Hwang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/PESQ.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DNN-based-Speech-Enhancement-in-the-frequency-domain/ed54e8c0eaea1f063c4db8e7a475ea3eb6e2f836/PESQ.so
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DNN-based Speech Enhancement in the frequency domain
2 | You can do DNN-based speech enhancement (SE) in the frequency domain with various methods through this repository.
3 | First, you have to make noisy data by mixing clean speech and noise; this dataset is then used for training.
4 | You can also adjust the network type and its configuration in various ways, as shown below.
5 | The results of the network can be evaluated through various objective metrics (PESQ, STOI, CSIG, CBAK, COVL).
6 |
7 |
8 |
9 |
10 |
11 | You can change:
12 | * Networks
13 | * Learning methods
14 | * Loss functions
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | ## Requirements
26 | > This repository was tested on Ubuntu 20.04 with:
27 | * Python 3.7
28 | * Cuda 11.1
29 | * CuDNN 8.0.5
30 | * Pytorch 1.9.0
31 |
32 |
33 | ## Getting Started
34 | 1. Install the necessary libraries
35 | 2. Make a dataset for training and validation (a minimal example of building one follows this list)
36 | ```sh
37 | # The shape of the dataset
38 | [data_num, 2 (inputs and targets), sampling_frequency * data_length]
39 |
40 | # For example, 1,000 three-second clips at a 16 kHz sampling frequency give the shape
41 | [1000, 2, 48000]
42 | ```
43 | 3. Set [dataloader.py](https://github.com/seorim0/Speech_enhancement_for_you/blob/main/dataloader.py)
44 | ```sh
45 | self.input_path = "DATASET_FILE_PATH"
46 | ```
47 | 4. Set [config.py](https://github.com/seorim0/Speech_enhancement_for_you/blob/main/config.py)
48 | ```sh
49 | # If you need to adjust any settings, simply change this file.
50 | # When you run this project for the first time, you need to set the path where the model and logs will be saved.
51 | ```
52 | 5. Run [train_interface.py](https://github.com/seorim0/Speech_enhancement_for_you/blob/main/train_interface.py)
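
For step 2, here is a minimal, illustrative sketch (the file names and layout are assumptions, not part of the repository) of how such an array could be built from paired noisy/clean wav files and saved for dataloader.py:

```python
import numpy as np
import soundfile as sf

fs, sec = 16000, 3                                # sampling frequency and clip length
pairs = [("noisy/0001.wav", "clean/0001.wav")]    # hypothetical (noisy, clean) file pairs

data = np.zeros((len(pairs), 2, fs * sec), dtype=np.float32)
for i, (noisy_path, clean_path) in enumerate(pairs):
    noisy, _ = sf.read(noisy_path)
    clean, _ = sf.read(clean_path)
    data[i, 0, :] = noisy[:fs * sec]              # index 0: network input (noisy)
    data[i, 1, :] = clean[:fs * sec]              # index 1: target (clean)

np.save("train_dataset.npy", data)                # shape: [data_num, 2, fs * sec]
```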
53 |
54 |
55 | ## Tutorials
56 | ['SE_tutorials.ipynb'](https://github.com/seorim0/Speech_enhancement_for_you/blob/main/SE_tutorials.ipynb) is provided as a tutorial.
57 | You can train the CRN directly with the Colab notebook without any preparation.
58 |
59 |
60 |
61 | ## Networks
62 | > The networks below can be selected and adjusted in various ways through config.py:
63 | * Real network
64 | - convolutional recurrent network (CRN)
65 |     a real-valued counterpart of DCCRN
66 | - FullSubNet [[1]](https://arxiv.org/abs/2010.15508)
67 | * Complex network
68 | - deep complex convolutional recurrent network (DCCRN) [[2]](https://arxiv.org/abs/2008.00264)
69 |
70 |
71 | ## Learning Methods
72 | * T-F masking
73 | * Spectral mapping
74 |
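
In rough terms, the two differ in what the network's output is interpreted as. A simplified, illustrative sketch (the actual logic, including the tanh-bounded mask, is in models.py):

```python
import torch

noisy_spec = torch.randn(1, 257, 100)        # toy magnitude spectrogram of a noisy utterance
network_output = torch.randn(1, 257, 100)    # whatever the network emits for this input

# T-F masking: treat the output as a mask and apply it to the noisy spectrum
est_spec_masking = torch.tanh(network_output) * noisy_spec

# Spectral mapping ('Direct(None make)' in config.py): use the output as the spectrum itself
est_spec_mapping = network_output
```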
75 |
76 | ## Loss Functions
77 | * MSE
78 | * SDR
79 | * SI-SNR (a short sketch is given below)
80 | * SI-SDR
81 |
82 | > You can also combine these loss functions with a perceptual loss:
83 | * LMS
84 | * PMSQE
85 |
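
For orientation, a minimal sketch of SI-SNR as it is commonly defined (the repository's own implementation lives in tools_for_loss.py and may differ in details such as sign and reduction):

```python
import torch

def si_snr_sketch(estimate: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale-invariant SNR in dB for batched [B, T] waveforms."""
    target = target - target.mean(dim=-1, keepdim=True)
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    # project the estimate onto the target to isolate the "true" signal component
    s_target = (estimate * target).sum(-1, keepdim=True) * target \
               / (target.pow(2).sum(-1, keepdim=True) + eps)
    e_noise = estimate - s_target
    return 10 * torch.log10(s_target.pow(2).sum(-1) / (e_noise.pow(2).sum(-1) + eps))
```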
86 |
87 | ## Tensorboard
88 | > As shown below, you can monitor in real time whether the network is training well through ['write_on_tensorboard.py'](https://github.com/seorim0/Speech_enhancement_for_you/blob/main/write_on_tensorboard.py).
89 |
90 | 
91 | * loss
92 | * pesq, stoi
93 | * spectrogram
94 |
95 |
96 | ## Reference
97 | **FullSubNet: A Full-Band and Sub-Band Fusion Model for Real-Time Single-Channel Speech Enhancement**
98 | Xiang Hao, Xiangdong Su, Radu Horaud, Xiaofei Li
99 | [[arXiv]](https://arxiv.org/abs/2010.15508) [[code]](https://github.com/haoxiangsnr/FullSubNet)
100 | **DCCRN: Deep Complex Convolution Recurrent Network for Phase-Aware Speech Enhancement**
101 | Yanxin Hu, Yun Liu, Shubo Lv, Mengtao Xing, Shimin Zhang, Yihui Fu, Jian Wu, Bihong Zhang, Lei Xie
102 | [[arXiv]](https://arxiv.org/abs/2008.00264) [[code]](https://github.com/huyanxin/DeepComplexCRN)
103 | **Other tools**
104 | https://github.com/usimarit/semetrics
105 | https://ecs.utdallas.edu/loizou/speech/software.htm
106 |
107 |
--------------------------------------------------------------------------------
/composite.m:
--------------------------------------------------------------------------------
1 | function [Csig,Cbak,Covl,segSNR]= composite(cleanFile, enhancedFile);
2 |
3 | % --------- composite objective measure ----------------------
4 | %
5 | % Center for Robust Speech Systems
6 | % University of Texas-Dallas
7 | % Copyright (c) 2006
8 | % All Rights Reserved.
9 | %
10 | % Description:
11 | %
12 | % This function implements the composite objective measure
13 | % proposed in [1]. It returns three values: The predicted rating of
14 | % overall quality (Covl), the rating of speech distortion (Csig) and
15 | % the rating of background distortion (Cbak). The ratings are based on the 1-5 MOS scale.
16 | % In addition, it returns the values of the SNRseg, log-likelihood ratio (LLR), PESQ
17 | % and weighted spectral slope (WSS) objective measures.
18 | %
19 | % References:
20 | %     [1]  Hu, Y. and Loizou, P. (2006). "Evaluation of objective measures for speech enhancement,"
21 | % Proceedings of INTERSPEECH-2006, Philadelphia, PA, September 2006.
22 | %
23 | %
24 | % Authors:
25 | % Philipos C. Loizou and Yi Hu
26 | % Bryan L. Pellom and John H. L. Hansen (for the implementation of
27 | % the WSS, LLR and SnrSeg measures)
28 | %
29 | %----------------------------------------------------------
30 |
31 | if nargin<2
32 | fprintf('Usage: [Csig,Cbak,Covl]=composite(cleanfile.wav,enhanced.wav)\n');
33 | fprintf('where ''Csig'' is the predicted rating of speech distortion\n');
34 | fprintf(' ''Cbak'' is the predicted rating of background distortion\n');
35 | fprintf(' ''Covl'' is the predicted rating of overall quality.\n\n');
36 | return;
37 | end
38 |
39 |
40 | alpha= 0.95;
41 |
42 | [data1, Srate1]= audioread(cleanFile);
43 | [data2, Srate2]= audioread(enhancedFile);
44 | info1 = audioinfo(cleanFile);
45 | info2 = audioinfo(enhancedFile);
46 | Nbits1 = info1.BitsPerSample;
47 | Nbits2 = info2.BitsPerSample;
48 | if ( Srate1~= Srate2) | ( Nbits1~= Nbits2)
49 | error( 'The two files do not match!\n');
50 | end
51 |
52 | len= min( length( data1), length( data2));
53 | data1= data1( 1: len)+eps;
54 | data2= data2( 1: len)+eps;
55 |
56 |
57 | % -- compute the WSS measure ---
58 | %
59 | wss_dist_vec= wss( data1, data2,Srate1);
60 | wss_dist_vec= sort( wss_dist_vec);
61 | wss_dist= mean( wss_dist_vec( 1: round( length( wss_dist_vec)*alpha)));
62 |
63 | % --- compute the LLR measure ---------
64 | %
65 | LLR_dist= llr( data1, data2,Srate1);
66 | LLRs= sort(LLR_dist);
67 | LLR_len= round( length(LLR_dist)* alpha);
68 | llr_mean= mean( LLRs( 1: LLR_len));
69 |
70 | % --- compute the SNRseg ----------------
71 | %
72 | [snr_dist, segsnr_dist]= snr( data1, data2,Srate1);
73 | snr_mean= snr_dist;
74 | segSNR= mean( segsnr_dist);
75 |
76 |
77 | % -- compute the pesq ----
78 | %[pesq_mos]= pesq(Srate1,cleanFile, enhancedFile);
79 | pesq_mos = 0;
80 |
81 |
82 | % --- now compute the composite measures ------------------
83 | %
84 | Csig = 3.093 - 1.029*llr_mean + 0.603*pesq_mos-0.009*wss_dist;
85 | Csig = max(1, Csig); Csig = min(5, Csig); % clamp to the 1-5 MOS range
86 | Cbak = 1.634 + 0.478 *pesq_mos - 0.007*wss_dist + 0.063*segSNR;
87 | Cbak = max(1, Cbak); Cbak = min(5, Cbak); % clamp to the 1-5 MOS range
88 | Covl = 1.594 + 0.805*pesq_mos - 0.512*llr_mean - 0.007*wss_dist;
89 | Covl = max(1, Covl); Covl = min(5, Covl); % clamp to the 1-5 MOS range
90 |
91 | %fprintf('\n LLR=%f SNRseg=%f WSS=%f PESQ=%f\n',llr_mean,segSNR,wss_dist,pesq_mos);
92 |
93 | return;
94 |
95 | % ----------------------------------------------------------------------
96 | %
97 | % Weighted Spectral Slope (WSS) Objective Speech Quality Measure
98 | %
99 | % Center for Robust Speech Systems
100 | % University of Texas-Dallas
101 | % Copyright (c) 1998-2006
102 | % All Rights Reserved.
103 | %
104 | % Description:
105 | %
106 | % This function implements the Weighted Spectral Slope (WSS)
107 | % distance measure originally proposed in [1]. The algorithm
108 | % works by first decomposing the speech signal into a set of
109 | % frequency bands (this is done for both the test and reference
110 | % frame). The intensities within each critical band are
111 | % measured. Then, a weighted distances between the measured
112 | % slopes of the log-critical band spectra are computed.
113 | % This measure is also described in Section 2.2.9 (pages 56-58)
114 | % of [2].
115 | %
116 | % Whereas Klatt's original measure used 36 critical-band
117 | % filters to estimate the smoothed short-time spectrum, this
118 | % implementation considers a bank of 25 filters spanning
119 | % the 4 kHz bandwidth.
120 | %
121 | % Input/Output:
122 | %
123 | % The input is a reference 8kHz sampled speech, and processed
124 | % speech (could be noisy or enhanced).
125 | %
126 | % The function returns the numerical distance between each
127 | % frame of the two input files (one distance per frame).
128 | %
129 | % References:
130 | %
131 | % [1] D. H. Klatt, "Prediction of Perceived Phonetic Distance
132 | % from Critical-Band Spectra: A First Step", Proc. IEEE
133 | % ICASSP'82, Volume 2, pp. 1278-1281, May, 1982.
134 | %
135 | % [2] S. R. Quackenbush, T. P. Barnwell, and M. A. Clements,
136 | % Objective Measures of Speech Quality. Prentice Hall
137 | % Advanced Reference Series, Englewood Cliffs, NJ, 1988,
138 | % ISBN: 0-13-629056-6.
139 | %
140 | % Authors:
141 | %
142 | % Bryan L. Pellom and John H. L. Hansen
143 | %
144 | %
145 | % Last Modified:
146 | %
147 | % July 22, 1998
148 | % September 12, 2006 by Philipos Loizou
149 | % ----------------------------------------------------------------------
150 |
151 | function distortion = wss(clean_speech, processed_speech,sample_rate)
152 |
153 |
154 | % ----------------------------------------------------------------------
155 | % Check the length of the clean and processed speech. Must be the same.
156 | % ----------------------------------------------------------------------
157 |
158 | clean_length = length(clean_speech);
159 | processed_length = length(processed_speech);
160 |
161 | if (clean_length ~= processed_length)
162 |    disp('Error: Files must have same length.');
163 | return
164 | end
165 |
166 |
167 |
168 | % ----------------------------------------------------------------------
169 | % Global Variables
170 | % ----------------------------------------------------------------------
171 |
172 | % sample_rate = 8000; % default sample rate
173 | % winlength = 240; % window length in samples
174 | % skiprate = 60; % window skip in samples
175 | winlength = round(30*sample_rate/1000); %240; % window length in samples
176 | skiprate = floor(winlength/4); % window skip in samples
177 | max_freq = sample_rate/2; % maximum bandwidth
178 | num_crit = 25; % number of critical bands
179 |
180 | USE_FFT_SPECTRUM = 1; % defaults to 10th order LP spectrum
181 | %n_fft = 512; % FFT size
182 | n_fft = 2^nextpow2(2*winlength);
183 | n_fftby2 = n_fft/2; % FFT size/2
184 | Kmax = 20; % value suggested by Klatt, pg 1280
185 | Klocmax = 1; % value suggested by Klatt, pg 1280
186 |
187 | % ----------------------------------------------------------------------
188 | % Critical Band Filter Definitions (Center Frequency and Bandwidths in Hz)
189 | % ----------------------------------------------------------------------
190 |
191 | cent_freq(1) = 50.0000; bandwidth(1) = 70.0000;
192 | cent_freq(2) = 120.000; bandwidth(2) = 70.0000;
193 | cent_freq(3) = 190.000; bandwidth(3) = 70.0000;
194 | cent_freq(4) = 260.000; bandwidth(4) = 70.0000;
195 | cent_freq(5) = 330.000; bandwidth(5) = 70.0000;
196 | cent_freq(6) = 400.000; bandwidth(6) = 70.0000;
197 | cent_freq(7) = 470.000; bandwidth(7) = 70.0000;
198 | cent_freq(8) = 540.000; bandwidth(8) = 77.3724;
199 | cent_freq(9) = 617.372; bandwidth(9) = 86.0056;
200 | cent_freq(10) = 703.378; bandwidth(10) = 95.3398;
201 | cent_freq(11) = 798.717; bandwidth(11) = 105.411;
202 | cent_freq(12) = 904.128; bandwidth(12) = 116.256;
203 | cent_freq(13) = 1020.38; bandwidth(13) = 127.914;
204 | cent_freq(14) = 1148.30; bandwidth(14) = 140.423;
205 | cent_freq(15) = 1288.72; bandwidth(15) = 153.823;
206 | cent_freq(16) = 1442.54; bandwidth(16) = 168.154;
207 | cent_freq(17) = 1610.70; bandwidth(17) = 183.457;
208 | cent_freq(18) = 1794.16; bandwidth(18) = 199.776;
209 | cent_freq(19) = 1993.93; bandwidth(19) = 217.153;
210 | cent_freq(20) = 2211.08; bandwidth(20) = 235.631;
211 | cent_freq(21) = 2446.71; bandwidth(21) = 255.255;
212 | cent_freq(22) = 2701.97; bandwidth(22) = 276.072;
213 | cent_freq(23) = 2978.04; bandwidth(23) = 298.126;
214 | cent_freq(24) = 3276.17; bandwidth(24) = 321.465;
215 | cent_freq(25) = 3597.63; bandwidth(25) = 346.136;
216 |
217 | bw_min = bandwidth (1); % minimum critical bandwidth
218 |
219 | % ----------------------------------------------------------------------
220 | % Set up the critical band filters. Note here that Gaussianly shaped
221 | % filters are used. Also, the sum of the filter weights are equivalent
222 | % for each critical band filter. Filter less than -30 dB and set to
223 | % zero.
224 | % ----------------------------------------------------------------------
225 |
226 | min_factor = exp (-30.0 / (2.0 * 2.303)); % -30 dB point of filter
227 |
228 | for i = 1:num_crit
229 | f0 = (cent_freq (i) / max_freq) * (n_fftby2);
230 | all_f0(i) = floor(f0);
231 | bw = (bandwidth (i) / max_freq) * (n_fftby2);
232 | norm_factor = log(bw_min) - log(bandwidth(i));
233 | j = 0:1:n_fftby2-1;
234 | crit_filter(i,:) = exp (-11 *(((j - floor(f0)) ./bw).^2) + norm_factor);
235 | crit_filter(i,:) = crit_filter(i,:).*(crit_filter(i,:) > min_factor);
236 | end
237 |
238 | % ----------------------------------------------------------------------
239 | % For each frame of input speech, calculate the Weighted Spectral
240 | % Slope Measure
241 | % ----------------------------------------------------------------------
242 |
243 | num_frames = clean_length/skiprate-(winlength/skiprate); % number of frames
244 | start = 1; % starting sample
245 | window = 0.5*(1 - cos(2*pi*(1:winlength)'/(winlength+1)));
246 |
247 | for frame_count = 1:num_frames
248 |
249 | % ----------------------------------------------------------
250 | % (1) Get the Frames for the test and reference speech.
251 | % Multiply by Hanning Window.
252 | % ----------------------------------------------------------
253 |
254 | clean_frame = clean_speech(start:start+winlength-1);
255 | processed_frame = processed_speech(start:start+winlength-1);
256 | clean_frame = clean_frame.*window;
257 | processed_frame = processed_frame.*window;
258 |
259 | % ----------------------------------------------------------
260 | % (2) Compute the Power Spectrum of Clean and Processed
261 | % ----------------------------------------------------------
262 |
263 | if (USE_FFT_SPECTRUM)
264 | clean_spec = (abs(fft(clean_frame,n_fft)).^2);
265 | processed_spec = (abs(fft(processed_frame,n_fft)).^2);
266 | else
267 | a_vec = zeros(1,n_fft);
268 | a_vec(1:11) = lpc(clean_frame,10);
269 | clean_spec = 1.0/(abs(fft(a_vec,n_fft)).^2)';
270 |
271 | a_vec = zeros(1,n_fft);
272 | a_vec(1:11) = lpc(processed_frame,10);
273 | processed_spec = 1.0/(abs(fft(a_vec,n_fft)).^2)';
274 | end
275 |
276 | % ----------------------------------------------------------
277 | % (3) Compute Filterbank Output Energies (in dB scale)
278 | % ----------------------------------------------------------
279 |
280 | for i = 1:num_crit
281 | clean_energy(i) = sum(clean_spec(1:n_fftby2) ...
282 | .*crit_filter(i,:)');
283 | processed_energy(i) = sum(processed_spec(1:n_fftby2) ...
284 | .*crit_filter(i,:)');
285 | end
286 | clean_energy = 10*log10(max(clean_energy,1E-10));
287 | processed_energy = 10*log10(max(processed_energy,1E-10));
288 |
289 | % ----------------------------------------------------------
290 | % (4) Compute Spectral Slope (dB[i+1]-dB[i])
291 | % ----------------------------------------------------------
292 |
293 | clean_slope = clean_energy(2:num_crit) - ...
294 | clean_energy(1:num_crit-1);
295 | processed_slope = processed_energy(2:num_crit) - ...
296 | processed_energy(1:num_crit-1);
297 |
298 | % ----------------------------------------------------------
299 | % (5) Find the nearest peak locations in the spectra to
300 | % each critical band. If the slope is negative, we
301 | % search to the left. If positive, we search to the
302 | % right.
303 | % ----------------------------------------------------------
304 |
305 | for i = 1:num_crit-1
306 |
307 | % find the peaks in the clean speech signal
308 |
309 | if (clean_slope(i)>0) % search to the right
310 | n = i;
311 | 	       while ((n < num_crit) & (clean_slope(n) > 0))
312 | n = n+1;
313 | end
314 | clean_loc_peak(i) = clean_energy(n-1);
315 | else % search to the left
316 | n = i;
317 | while ((n>0) & (clean_slope(n) <= 0))
318 | n = n-1;
319 | end
320 | clean_loc_peak(i) = clean_energy(n+1);
321 | end
322 |
323 | % find the peaks in the processed speech signal
324 |
325 | if (processed_slope(i)>0) % search to the right
326 | n = i;
327 | 	       while ((n < num_crit) & (processed_slope(n) > 0))
328 | n = n+1;
329 | end
330 | processed_loc_peak(i) = processed_energy(n-1);
331 | else % search to the left
332 | n = i;
333 | while ((n>0) & (processed_slope(n) <= 0))
334 | n = n-1;
335 | end
336 | processed_loc_peak(i) = processed_energy(n+1);
337 | end
338 |
339 | end
340 |
341 | % ----------------------------------------------------------
342 | % (6) Compute the WSS Measure for this frame. This
343 | % includes determination of the weighting function.
344 | % ----------------------------------------------------------
345 |
346 | dBMax_clean = max(clean_energy);
347 | dBMax_processed = max(processed_energy);
348 |
349 | % The weights are calculated by averaging individual
350 | % weighting factors from the clean and processed frame.
351 | % These weights W_clean and W_processed should range
352 | % from 0 to 1 and place more emphasis on spectral
353 | % peaks and less emphasis on slope differences in spectral
354 | % valleys. This procedure is described on page 1280 of
355 | % Klatt's 1982 ICASSP paper.
356 |
357 | Wmax_clean = Kmax ./ (Kmax + dBMax_clean - ...
358 | clean_energy(1:num_crit-1));
359 | Wlocmax_clean = Klocmax ./ ( Klocmax + clean_loc_peak - ...
360 | clean_energy(1:num_crit-1));
361 | W_clean = Wmax_clean .* Wlocmax_clean;
362 |
363 | Wmax_processed = Kmax ./ (Kmax + dBMax_processed - ...
364 | processed_energy(1:num_crit-1));
365 | Wlocmax_processed = Klocmax ./ ( Klocmax + processed_loc_peak - ...
366 | processed_energy(1:num_crit-1));
367 | W_processed = Wmax_processed .* Wlocmax_processed;
368 |
369 | W = (W_clean + W_processed)./2.0;
370 |
371 | distortion(frame_count) = sum(W.*(clean_slope(1:num_crit-1) - ...
372 | processed_slope(1:num_crit-1)).^2);
373 |
374 | % this normalization is not part of Klatt's paper, but helps
375 | % to normalize the measure. Here we scale the measure by the
376 | % sum of the weights.
377 |
378 | distortion(frame_count) = distortion(frame_count)/sum(W);
379 |
380 | start = start + skiprate;
381 |
382 | end
383 |
384 | %-----------------------------------------------
385 | function distortion = llr(clean_speech, processed_speech,sample_rate)
386 |
387 |
388 | % ----------------------------------------------------------------------
389 | % Check the length of the clean and processed speech. Must be the same.
390 | % ----------------------------------------------------------------------
391 |
392 | clean_length = length(clean_speech);
393 | processed_length = length(processed_speech);
394 |
395 | if (clean_length ~= processed_length)
396 | disp('Error: Both Speech Files must be same length.');
397 | return
398 | end
399 |
400 | % ----------------------------------------------------------------------
401 | % Global Variables
402 | % ----------------------------------------------------------------------
403 |
404 | % sample_rate = 8000; % default sample rate
405 | % winlength = 240; % window length in samples
406 | % skiprate = 60; % window skip in samples
407 | % P = 10; % LPC Analysis Order
408 | winlength = round(30*sample_rate/1000); % window length in samples
409 | skiprate = floor(winlength/4); % window skip in samples
410 | if sample_rate<10000
411 | P = 10; % LPC Analysis Order
412 | else
413 | P=16; % this could vary depending on sampling frequency.
414 | end
415 |
416 | % ----------------------------------------------------------------------
417 | % For each frame of input speech, calculate the Log Likelihood Ratio
418 | % ----------------------------------------------------------------------
419 |
420 | num_frames = clean_length/skiprate-(winlength/skiprate); % number of frames
421 | start = 1; % starting sample
422 | window = 0.5*(1 - cos(2*pi*(1:winlength)'/(winlength+1)));
423 |
424 | for frame_count = 1:num_frames
425 |
426 | % ----------------------------------------------------------
427 | % (1) Get the Frames for the test and reference speech.
428 | % Multiply by Hanning Window.
429 | % ----------------------------------------------------------
430 |
431 | clean_frame = clean_speech(start:start+winlength-1);
432 | processed_frame = processed_speech(start:start+winlength-1);
433 | clean_frame = clean_frame.*window;
434 | processed_frame = processed_frame.*window;
435 |
436 | % ----------------------------------------------------------
437 | % (2) Get the autocorrelation lags and LPC parameters used
438 | % to compute the LLR measure.
439 | % ----------------------------------------------------------
440 |
441 | [R_clean, Ref_clean, A_clean] = ...
442 | lpcoeff(clean_frame, P);
443 | [R_processed, Ref_processed, A_processed] = ...
444 | lpcoeff(processed_frame, P);
445 |
446 | % ----------------------------------------------------------
447 | % (3) Compute the LLR measure
448 | % ----------------------------------------------------------
449 |
450 | numerator = A_processed*toeplitz(R_clean)*A_processed';
451 | denominator = A_clean*toeplitz(R_clean)*A_clean';
452 | distortion(frame_count) = log(numerator/denominator);
453 | start = start + skiprate;
454 |
455 | end
456 |
457 | %---------------------------------------------
458 | function [acorr, refcoeff, lpparams] = lpcoeff(speech_frame, model_order)
459 |
460 | % ----------------------------------------------------------
461 | % (1) Compute Autocorrelation Lags
462 | % ----------------------------------------------------------
463 |
464 | winlength = max(size(speech_frame));
465 | for k=1:model_order+1
466 | R(k) = sum(speech_frame(1:winlength-k+1) ...
467 | .*speech_frame(k:winlength));
468 | end
469 |
470 | % ----------------------------------------------------------
471 | % (2) Levinson-Durbin
472 | % ----------------------------------------------------------
473 |
474 | a = ones(1,model_order);
475 | E(1)=R(1);
476 | for i=1:model_order
477 | a_past(1:i-1) = a(1:i-1);
478 | sum_term = sum(a_past(1:i-1).*R(i:-1:2));
479 | rcoeff(i)=(R(i+1) - sum_term) / E(i);
480 | a(i)=rcoeff(i);
481 | a(1:i-1) = a_past(1:i-1) - rcoeff(i).*a_past(i-1:-1:1);
482 | E(i+1)=(1-rcoeff(i)*rcoeff(i))*E(i);
483 | end
484 |
485 | acorr = R;
486 | refcoeff = rcoeff;
487 | lpparams = [1 -a];
488 |
489 |
490 | % ----------------------------------------------------------------------
491 |
492 | function [overall_snr, segmental_snr] = snr(clean_speech, processed_speech,sample_rate)
493 |
494 | % ----------------------------------------------------------------------
495 | % Check the length of the clean and processed speech. Must be the same.
496 | % ----------------------------------------------------------------------
497 |
498 | clean_length = length(clean_speech);
499 | processed_length = length(processed_speech);
500 |
501 | if (clean_length ~= processed_length)
502 | disp('Error: Both Speech Files must be same length.');
503 | return
504 | end
505 |
506 | % ----------------------------------------------------------------------
507 | % Scale both clean speech and processed speech to have same dynamic
508 | % range. Also remove DC component from each signal
509 | % ----------------------------------------------------------------------
510 |
511 | %clean_speech = clean_speech - mean(clean_speech);
512 | %processed_speech = processed_speech - mean(processed_speech);
513 |
514 | %processed_speech = processed_speech.*(max(abs(clean_speech))/ max(abs(processed_speech)));
515 |
516 | overall_snr = 10* log10( sum(clean_speech.^2)/sum((clean_speech-processed_speech).^2));
517 |
518 | % ----------------------------------------------------------------------
519 | % Global Variables
520 | % ----------------------------------------------------------------------
521 |
522 | % sample_rate = 8000; % default sample rate
523 | % winlength = 240; % window length in samples
524 | % skiprate = 60; % window skip in samples
525 | winlength = round(30*sample_rate/1000); %240; % window length in samples
526 | skiprate = floor(winlength/4); % window skip in samples
527 | MIN_SNR = -10; % minimum SNR in dB
528 | MAX_SNR = 35; % maximum SNR in dB
529 |
530 | % ----------------------------------------------------------------------
531 | % For each frame of input speech, calculate the Segmental SNR
532 | % ----------------------------------------------------------------------
533 |
534 | num_frames = clean_length/skiprate-(winlength/skiprate); % number of frames
535 | start = 1; % starting sample
536 | window = 0.5*(1 - cos(2*pi*(1:winlength)'/(winlength+1)));
537 |
538 | for frame_count = 1: num_frames
539 |
540 | % ----------------------------------------------------------
541 | % (1) Get the Frames for the test and reference speech.
542 | % Multiply by Hanning Window.
543 | % ----------------------------------------------------------
544 |
545 | clean_frame = clean_speech(start:start+winlength-1);
546 | processed_frame = processed_speech(start:start+winlength-1);
547 | clean_frame = clean_frame.*window;
548 | processed_frame = processed_frame.*window;
549 |
550 | % ----------------------------------------------------------
551 | % (2) Compute the Segmental SNR
552 | % ----------------------------------------------------------
553 |
554 | signal_energy = sum(clean_frame.^2);
555 | noise_energy = sum((clean_frame-processed_frame).^2);
556 | segmental_snr(frame_count) = 10*log10(signal_energy/(noise_energy+eps)+eps);
557 | segmental_snr(frame_count) = max(segmental_snr(frame_count),MIN_SNR);
558 | segmental_snr(frame_count) = min(segmental_snr(frame_count),MAX_SNR);
559 |
560 | start = start + skiprate;
561 |
562 | end
563 |
--------------------------------------------------------------------------------
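
For readers working on the Python side, a hedged sketch of the final linear mapping above, with the regression coefficients copied from the MATLAB code (note that this file sets pesq_mos = 0, so the PESQ score has to be supplied separately, e.g. through PESQ.so / tools_for_estimate.py):

```python
def composite_scores(llr_mean: float, wss_dist: float, seg_snr: float, pesq_mos: float):
    """Map LLR, WSS, segmental SNR and PESQ onto the 1-5 MOS-style Csig/Cbak/Covl scales."""
    csig = 3.093 - 1.029 * llr_mean + 0.603 * pesq_mos - 0.009 * wss_dist
    cbak = 1.634 + 0.478 * pesq_mos - 0.007 * wss_dist + 0.063 * seg_snr
    covl = 1.594 + 0.805 * pesq_mos - 0.512 * llr_mean - 0.007 * wss_dist
    clamp = lambda x: min(5.0, max(1.0, x))       # clamp to the 1-5 MOS range
    return clamp(csig), clamp(cbak), clamp(covl)

# e.g. llr_mean=0.5, wss_dist=40, seg_snr=5, pesq_mos=2.5 -> roughly (3.73, 2.86, 3.07)
print(composite_scores(0.5, 40.0, 5.0, 2.5))
```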
/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration for train_interface
3 |
4 | You can check the essential settings here,
5 | and if you want to change the model structure or training method,
6 | this is the file to edit.
7 | """
8 | #######################################################################
9 | # path #
10 | #######################################################################
11 | job_dir = './models/'
12 | logs_dir = './logs/'
13 | chkpt_model = None  # 'FILE_PATH' (if you have a pretrained model)
14 | chkpt = str("EPOCH")
15 | if chkpt_model is not None:
16 | chkpt_path = job_dir + chkpt_model + '/chkpt_' + chkpt + '.pt'
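    # e.g. (illustrative values, not defaults) chkpt_model = 'DCCRN_001' and chkpt = '50'
    # resolve to chkpt_path = './models/DCCRN_001/chkpt_50.pt'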
17 |
18 | #######################################################################
19 | # possible setting #
20 | #######################################################################
21 | # the list you can do
22 | model_list = ['DCCRN', 'CRN', 'FullSubNet']
23 | loss_list = ['MSE', 'SDR', 'SI-SNR', 'SI-SDR']
24 | perceptual_list = [False, 'LMS', 'PMSQE']
25 | lstm_type = ['real', 'complex']
26 | main_net = ['LSTM', 'GRU']
27 | mask_type = ['Direct(None make)', 'E', 'C', 'R']
28 |
29 | # experiment number setting
30 | expr_num = 'EXPERIMENT_NUMBER'
31 | DEVICE = 'cuda'  # change to 'cpu' if you want to run the code on CPU
32 | #######################################################################
33 | # current setting #
34 | #######################################################################
35 | model = model_list[0]
36 | loss = loss_list[1]
37 | perceptual = perceptual_list[0]
38 | lstm = lstm_type[1]
39 | sequence_model = main_net[0]
40 |
41 | masking_mode = mask_type[1]
42 | skip_type = True  # set to False if you want to remove the skip connections
43 |
44 | # hyper-parameters
45 | max_epochs = 100
46 | learning_rate = 0.001
47 | batch = 10
48 |
49 | # kernel size
50 | dccrn_kernel_num = [32, 64, 128, 256, 256, 256]
51 | #######################################################################
52 | # model information #
53 | #######################################################################
54 | fs = 16000
55 | win_len = 400
56 | win_inc = 100
57 | ola_ratio = 0.75
58 | fft_len = 512
59 | sam_sec = fft_len / fs
60 | frm_samp = fs * (fft_len / fs)
61 | window = 'hanning'
62 |
63 | # for DCCRN
64 | rnn_layers = 2
65 | rnn_units = 256
66 |
67 | # for CRN
68 | rnn_input_size = 512
69 |
70 | # for FullSubNet
71 | sb_num_neighbors = 15
72 | fb_num_neighbors = 0
73 | num_freqs = fft_len // 2 + 1
74 | look_ahead = 2
75 | fb_output_activate_function = "ReLU"
76 | sb_output_activate_function = None
77 | fb_model_hidden_size = 512
78 | sb_model_hidden_size = 384
79 | weight_init = False
80 | norm_type = "offline_laplace_norm"
81 | num_groups_in_drop_band = 2
82 | #######################################################################
83 | # setting error check #
84 | #######################################################################
85 | # if the setting is wrong, print error message
86 | assert not (masking_mode == 'Direct(None make)' and perceptual is not False), \
87 |     "Perceptual loss is not supported with the 'Direct(None make)' masking mode"
88 | assert not (model == 'FullSubNet' and perceptual is not False), \
89 |     "Perceptual loss is not supported with FullSubNet"
90 |
91 | #######################################################################
92 | # print setting #
93 | #######################################################################
94 | print('-------------------- C O N F I G ----------------------')
95 | print('--------------------------------------------------------------')
96 | print('MODEL INFO : {}'.format(model))
97 | print('LOSS INFO : {}, perceptual : {}'.format(loss, perceptual))
98 | if model != 'FullSubNet':
99 | print('LSTM : {}'.format(lstm))
100 | print('SKIP : {}'.format(skip_type))
101 | print('MASKING INFO : {}'.format(masking_mode))
102 | else:
103 | print('Main network : {}'.format(sequence_model))
104 | print('\nBATCH : {}'.format(batch))
105 | print('LEARNING RATE : {}'.format(learning_rate))
106 | print('--------------------------------------------------------------')
107 | print('--------------------------------------------------------------\n')
108 |
--------------------------------------------------------------------------------
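
To change what gets trained, you edit the "current setting" block by picking different indices into the lists defined under "possible setting". An illustrative edit inside config.py, e.g. to train FullSubNet with an MSE loss and no perceptual term:

```python
model = model_list[2]               # 'FullSubNet'
loss = loss_list[0]                 # 'MSE'
perceptual = perceptual_list[0]     # False (the asserts reject FullSubNet + perceptual loss)
sequence_model = main_net[0]        # 'LSTM'
```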
/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset, DataLoader
4 | import config as cfg
5 |
6 | # # If you don't set the data type to object when saving the data...
7 | # np_load_old = np.load
8 | # np.load = lambda *a, **k: np_load_old(*a, allow_pickle=True, **k)
9 |
10 |
11 | def create_dataloader(mode, type=0, snr=0):
12 | if mode == 'train':
13 | return DataLoader(
14 | dataset=Wave_Dataset(mode, type, snr),
15 | batch_size=cfg.batch,
16 | shuffle=True,
17 | num_workers=0,
18 | pin_memory=True,
19 | drop_last=True,
20 | sampler=None
21 | )
22 | elif mode == 'valid':
23 | return DataLoader(
24 | dataset=Wave_Dataset(mode, type, snr),
25 | batch_size=cfg.batch, shuffle=False, num_workers=0
26 | )
27 | elif mode == 'test':
28 | return DataLoader(
29 | dataset=Wave_Dataset(mode, type, snr),
30 | batch_size=cfg.batch, shuffle=False, num_workers=0
31 | )
32 |
33 |
34 | class Wave_Dataset(Dataset):
35 | def __init__(self, mode, type, snr):
36 | # load data
37 | if mode == 'train':
38 | self.mode = 'train'
39 | print('')
40 | print('Load the data...')
41 | self.input_path = "DATASET_FILE_PATH"
42 | self.input = np.load(self.input_path)
43 | elif mode == 'valid':
44 | self.mode = 'valid'
45 | print('')
46 | print('Load the data...')
47 | self.input_path = "DATASET_FILE_PATH"
48 | self.input = np.load(self.input_path)
49 | # # if you want to use a part of the dataset
50 | # self.input = self.input[:500]
51 | elif mode == 'test':
52 | self.mode = 'test'
53 | print('')
54 | print('Load the data...')
55 | self.input_path = "DATASET_FILE_PATH"
56 |
57 | self.input = np.load(self.input_path)
58 | self.input = self.input[type][snr]
59 |
60 | def __len__(self):
61 | return len(self.input)
62 |
63 | def __getitem__(self, idx):
64 | inputs = self.input[idx][0]
65 | targets = self.input[idx][1]
66 |
67 | # transform to torch from numpy
68 | inputs = torch.from_numpy(inputs)
69 | targets = torch.from_numpy(targets)
70 |
71 | return inputs, targets
72 |
--------------------------------------------------------------------------------
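
A minimal usage sketch (assuming self.input_path has been pointed at a dataset saved in the [data_num, 2, fs * data_length] shape described in the README):

```python
from dataloader import create_dataloader

train_loader = create_dataloader('train')   # batch size and shuffling come from config.py
for inputs, targets in train_loader:
    # inputs: noisy waveforms, targets: clean waveforms, both [batch, fs * data_length]
    print(inputs.shape, targets.shape)
    break
```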
/estimation/check_object_metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | for checking speech quality with some metrics.
3 |
4 | 1. PESQ
5 | 2. STOI
6 | 3. CSIG, CBAK, COVL
7 | """
8 | import os
9 | from tools_for_estimate import cal_pesq, cal_stoi, composite
10 | from pathlib import Path
11 |
12 | # number of files we want to check
13 | file_num = 1
14 |
15 | target_wav = ['.wav']
16 | estimated_wav = ['.wav']
17 |
18 | file_directory = '/'
19 |
20 | if file_num == 1:
21 | pesq = cal_pesq(estimated_wav, target_wav)
22 | stoi = cal_stoi(estimated_wav, target_wav)
23 |     CSIG, CBAK, COVL, _ = composite(target_wav[0], estimated_wav[0])
24 |
25 | print('{} is ...'.format(estimated_wav[0]))
26 |     print('PESQ {:.4} | STOI {:.4} | CSIG {:.4} | CBAK {:.4} | COVL {:.4}'
27 |           .format(pesq, stoi, CSIG, CBAK, COVL))
28 | else:
29 | # the list of files in file directory
30 | if os.path.isdir(file_directory) is False:
31 | print("[Error] There is no directory '%s'." % file_directory)
32 | exit()
33 | else:
34 | print("Scanning a directory %s " % file_directory)
35 |
36 | # pick target wav from the directory
37 | target_addr = []
38 | for path, dir, files in os.walk(file_directory):
39 | for file in files:
40 |             if 'target' in file:
41 | filepath = Path(path) / file
42 | target_addr.append(filepath)
43 |
44 | for addr in target_addr:
45 | estimated_addr = str(addr).replace('target', 'estimated')
46 |
47 | pesq = cal_pesq([estimated_addr], [addr])
48 | stoi = cal_stoi([estimated_addr], [addr])
49 |         CSIG, CBAK, COVL, _ = composite(addr, estimated_addr)
50 |
51 | print('{} is ...'.format(estimated_addr))
52 |         print('PESQ {:.4} | STOI {:.4} | CSIG {:.4} | CBAK {:.4} | COVL {:.4}'
53 |               .format(pesq, stoi, CSIG, CBAK, COVL))
54 |
--------------------------------------------------------------------------------
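
The directory branch above relies on a naming convention: each reference file contains 'target' in its name, and the enhanced counterpart is found by swapping that substring for 'estimated'. A hypothetical pairing:

```python
target_path = "/data/test/target_sample_001.wav"              # hypothetical file name
estimated_path = target_path.replace('target', 'estimated')
print(estimated_path)                                          # /data/test/estimated_sample_001.wav
```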
/generate_noisy_data.py:
--------------------------------------------------------------------------------
1 | """
2 | generate noisy data with various noise files
3 | """
4 | import os
5 | import sys
6 | import numpy as np
7 | import scipy.io.wavfile as wav
8 | import librosa
9 | from pathlib import Path
10 | import soundfile
11 |
12 | #######################################################################
13 | # data info setting #
14 | #######################################################################
15 | # USE THIS, OR SYS.ARGVS
16 | # mode = 'train' # train / validation / test
17 | # snr_set = [0, 5]
18 | # fs = 16000
19 |
20 | #######################################################################
21 | # main #
22 | #######################################################################
23 | def scan_directory(dir_name):
24 | """Scan directory and save address of clean/noisy wav data.
25 | Args:
26 | dir_name: directroy name to scan
27 | Returns:
28 | addr: all address list of clean/noisy wave data in subdirectory
29 | """
30 | if os.path.isdir(dir_name) is False:
31 | print("[Error] There is no directory '%s'." % dir_name)
32 | exit()
33 | else:
34 | print("Scanning a directory %s " % dir_name)
35 |
36 | addr = []
37 | for subdir, dirs, files in os.walk(dir_name):
38 | for file in files:
39 | if file.endswith(".wav"):
40 | filepath = Path(subdir) / file
41 | addr.append(filepath)
42 | return addr
43 |
44 |
45 | # Generate noisy data given speech, noise, and target SNR.
46 | def generate_noisy_wav(wav_speech, wav_noise, snr):
47 | # Obtain the length of speech and noise components.
48 | len_speech = len(wav_speech)
49 | len_noise = len(wav_noise)
50 |
51 | # Select noise segment randomly to have same length with speech signal.
52 | st = np.random.randint(0, len_noise - len_speech)
53 | ed = st + len_speech
54 | wav_noise = wav_noise[st:ed]
55 |
56 | # Compute the power of speech and noise after removing DC bias.
57 | dc_speech = np.mean(wav_speech)
58 | dc_noise = np.mean(wav_noise)
59 | pow_speech = np.mean(np.power(wav_speech - dc_speech, 2.0))
60 | pow_noise = np.mean(np.power(wav_noise - dc_noise, 2.0))
61 |
62 | # Compute the scale factor of noise component depending on the target SNR.
63 | alpha = np.sqrt(10.0 ** (float(-snr) / 10.0) * pow_speech / (pow_noise + 1e-6))
64 | noisy_wav = (wav_speech + alpha * wav_noise) * 32768
65 | noisy_wav = noisy_wav.astype(np.int16)
66 |
67 | return noisy_wav
68 |
69 |
70 | def main():
71 | argvs = sys.argv[1:]
72 | if len(argvs) != 3:
73 | print('Error: Invalid input arguments')
74 | print('\t Usage: python generate_noisy_data.py [mode] [snr] [fs]')
75 | print("\t\t [mode]: 'train', 'validation'")
76 | print("\t\t [snr]: '0', '0, 5', ...'")
77 | print("\t\t [fs]: '16000', ...")
78 | exit()
79 | mode = argvs[0]
80 | snr_set = argvs[1].split(',')
81 | fs = int(argvs[2])
82 |
83 | # Set speech and noise directory.
84 | speech_dir = Path("./")
85 |
86 | # Make a speech file list.
87 | speech_mode_clean_dir = speech_dir / mode / 'clean'
88 | speech_mode_noisy_dir = speech_dir / mode / 'noisy'
89 | list_speech_files = scan_directory(speech_mode_clean_dir)
90 |
91 | # Make directories of the mode and noisy data.
92 | if os.path.isdir(speech_mode_clean_dir) is False:
93 | os.system('mkdir ' + str(speech_mode_clean_dir))
94 |
95 | if os.path.isdir(speech_mode_noisy_dir) is False:
96 | os.system('mkdir ' + str(speech_mode_noisy_dir))
97 |
98 | # Define a log file name.
99 | log_file_name = Path("./log_generate_data_" + mode + ".txt")
100 | f = open(log_file_name, 'w')
101 |
102 | if mode == 'train':
103 | # Make a noise file list
104 | noise_subset_dir = speech_dir / 'train' / 'noise'
105 | list_noise_files = scan_directory(noise_subset_dir)
106 | for snr_in_db in snr_set:
107 | for addr_speech in list_speech_files:
108 | # Load speech waveform and its sampling frequency.
109 | wav_speech, read_fs = soundfile.read(addr_speech)
110 | if read_fs != fs:
111 | wav_speech = librosa.resample(wav_speech, read_fs, fs)
112 |
113 | # Select a noise component randomly, and read it.
114 | nidx = np.random.randint(0, len(list_noise_files))
115 | addr_noise = list_noise_files[nidx]
116 | wav_noise, read_fs = soundfile.read(addr_noise)
117 | if wav_noise.ndim > 1:
118 | wav_noise = wav_noise.mean(axis=1)
119 | if read_fs != fs:
120 | wav_noise = librosa.resample(wav_noise, read_fs, fs)
121 |
122 | # Generate noisy speech by mixing speech and noise components.
123 | wav_noisy = generate_noisy_wav(wav_speech, wav_noise, int(snr_in_db))
124 | noisy_name = Path(addr_speech).name[:-4] +'_' + Path(addr_noise).name[:-4] + '_' + str(
125 | int(snr_in_db)) + '.wav'
126 | addr_noisy = speech_mode_noisy_dir / noisy_name
127 | wav.write(addr_noisy, fs, wav_noisy)
128 |
129 | # Display progress.
130 | print('%s > %s' % (addr_speech, addr_noisy))
131 | f.write('%s\t%s\t%s\t%d dB\n' % (addr_noisy, addr_speech, addr_noise, int(snr_in_db)))
132 |
133 | elif mode == 'validation':
134 | # Make a noise file list for validation.
135 | noise_subset_dir = speech_dir / 'train' / 'noise'
136 | list_noise_files = scan_directory(noise_subset_dir)
137 |
138 | for addr_speech in list_speech_files:
139 | # Load speech waveform and its sampling frequency.
140 | wav_speech, read_fs = soundfile.read(addr_speech)
141 | if read_fs != fs:
142 | wav_speech = librosa.resample(wav_speech, read_fs, fs)
143 |
144 | # Select a noise component randomly, and read it.
145 | nidx = np.random.randint(0, len(list_noise_files))
146 | addr_noise = list_noise_files[nidx]
147 | wav_noise, read_fs = soundfile.read(addr_noise)
148 | if wav_noise.ndim > 1:
149 | wav_noise = wav_noise.mean(axis=1)
150 | if read_fs != fs:
151 | wav_noise = librosa.resample(wav_noise, read_fs, fs)
152 |
153 | # Select an SNR randomly.
154 | ridx_snr = np.random.randint(0, len(snr_set))
155 | snr_in_db = int(snr_set[ridx_snr])
156 |
157 | # Generate noisy speech by mixing speech and noise components.
158 | wav_noisy = generate_noisy_wav(wav_speech, wav_noise, snr_in_db)
159 |
160 | # Write the generated noisy speech into a file.
161 | noisy_name = Path(addr_speech).name[:-4] + '_' + Path(addr_noise).name[:-4] + '_' + str(
162 | snr_in_db) + '.wav'
163 | addr_noisy = speech_mode_noisy_dir / noisy_name
164 | wav.write(addr_noisy, fs, wav_noisy)
165 |
166 | # Display progress.
167 | print('%s > %s' % (addr_speech, addr_noisy))
168 | f.write('%s\t%s\t%s\t%d dB\n' % (addr_noisy, addr_speech, addr_noise, snr_in_db))
169 | f.close()
170 |
171 |
172 | if __name__ == '__main__':
173 | main()
174 |
--------------------------------------------------------------------------------
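
A quick, illustrative check of the mixing rule in generate_noisy_wav above (the alpha scale factor targets the requested SNR), assuming the module's dependencies are installed:

```python
import numpy as np
from generate_noisy_data import generate_noisy_wav

rng = np.random.default_rng(0)
speech = 0.1 * rng.standard_normal(16000)        # 1 second of toy "speech" at 16 kHz
noise = 0.1 * rng.standard_normal(32000)         # longer toy noise track

noisy = generate_noisy_wav(speech, noise, 5) / 32768.0    # back to float scale
residual = noisy - speech                                  # ~= alpha * noise segment
achieved_snr = 10 * np.log10(np.mean(speech ** 2) / np.mean(residual ** 2))
print(round(achieved_snr, 1))                              # should be close to 5 dB
```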
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from tools_for_model import ConvSTFT, ConviSTFT, \
5 | ComplexConv2d, ComplexConvTranspose2d, NavieComplexLSTM, complex_cat, ComplexBatchNorm, \
6 | RealConv2d, RealConvTranspose2d, \
7 | BaseModel, SequenceModel
8 | import config as cfg
9 | from tools_for_loss import sdr, si_sdr, si_snr, get_array_lms_loss, get_array_pmsqe_loss
10 |
11 |
12 | #######################################################################
13 | # complex network #
14 | #######################################################################
15 | class DCCRN(nn.Module):
16 |
17 | def __init__(
18 | self,
19 | rnn_layers=cfg.rnn_layers,
20 | rnn_units=cfg.rnn_units,
21 | win_len=cfg.win_len,
22 | win_inc=cfg.win_inc,
23 | fft_len=cfg.fft_len,
24 | win_type=cfg.window,
25 | masking_mode=cfg.masking_mode,
26 | use_cbn=False,
27 | kernel_size=5
28 | ):
29 | '''
30 | rnn_layers: the number of lstm layers in the crn,
31 | rnn_units: for clstm, rnn_units = real+imag
32 | '''
33 |
34 | super(DCCRN, self).__init__()
35 |
36 | # for fft
37 | self.win_len = win_len
38 | self.win_inc = win_inc
39 | self.fft_len = fft_len
40 | self.win_type = win_type
41 |
42 | input_dim = win_len
43 | output_dim = win_len
44 |
45 | self.rnn_units = rnn_units
46 | self.input_dim = input_dim
47 | self.output_dim = output_dim
48 | self.hidden_layers = rnn_layers
49 | self.kernel_size = kernel_size
50 | kernel_num = cfg.dccrn_kernel_num
51 | self.kernel_num = [2] + kernel_num
52 | self.masking_mode = masking_mode
53 |
54 | # bidirectional=True
55 | bidirectional = False
56 | fac = 2 if bidirectional else 1
57 |
58 | fix = True
59 | self.fix = fix
60 | self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix)
61 | self.istft = ConviSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix)
62 |
63 | self.encoder = nn.ModuleList()
64 | self.decoder = nn.ModuleList()
65 | for idx in range(len(self.kernel_num) - 1):
66 | self.encoder.append(
67 | nn.Sequential(
68 | # nn.ConstantPad2d([0, 0, 0, 0], 0),
69 | ComplexConv2d(
70 | self.kernel_num[idx],
71 | self.kernel_num[idx + 1],
72 | kernel_size=(self.kernel_size, 2),
73 | stride=(2, 1),
74 | padding=(2, 1)
75 | ),
76 | nn.BatchNorm2d(self.kernel_num[idx + 1]) if not use_cbn else ComplexBatchNorm(
77 | self.kernel_num[idx + 1]),
78 | nn.PReLU()
79 | )
80 | )
81 | hidden_dim = self.fft_len // (2 ** (len(self.kernel_num)))
82 |
83 | if cfg.lstm == 'complex':
84 | rnns = []
85 | for idx in range(rnn_layers):
86 | rnns.append(
87 | NavieComplexLSTM(
88 | input_size=hidden_dim * self.kernel_num[-1] if idx == 0 else self.rnn_units,
89 | hidden_size=self.rnn_units,
90 | bidirectional=bidirectional,
91 | batch_first=False,
92 | projection_dim=hidden_dim * self.kernel_num[-1] if idx == rnn_layers - 1 else None,
93 | )
94 | )
95 | self.enhance = nn.Sequential(*rnns)
96 | else:
97 | self.enhance = nn.LSTM(
98 | input_size=hidden_dim * self.kernel_num[-1],
99 | hidden_size=self.rnn_units,
100 | num_layers=2,
101 | dropout=0.0,
102 | bidirectional=bidirectional,
103 | batch_first=False
104 | )
105 | self.tranform = nn.Linear(self.rnn_units * fac, hidden_dim * self.kernel_num[-1])
106 |
107 | if cfg.skip_type:
108 | for idx in range(len(self.kernel_num) - 1, 0, -1):
109 | if idx != 1:
110 | self.decoder.append(
111 | nn.Sequential(
112 | ComplexConvTranspose2d(
113 | self.kernel_num[idx] * 2,
114 | self.kernel_num[idx - 1],
115 | kernel_size=(self.kernel_size, 2),
116 | stride=(2, 1),
117 | padding=(2, 0),
118 | output_padding=(1, 0)
119 | ),
120 | nn.BatchNorm2d(self.kernel_num[idx - 1]) if not use_cbn else ComplexBatchNorm(
121 | self.kernel_num[idx - 1]),
122 | nn.PReLU()
123 | )
124 | )
125 | else:
126 | self.decoder.append(
127 | nn.Sequential(
128 | ComplexConvTranspose2d(
129 | self.kernel_num[idx] * 2,
130 | self.kernel_num[idx - 1],
131 | kernel_size=(self.kernel_size, 2),
132 | stride=(2, 1),
133 | padding=(2, 0),
134 | output_padding=(1, 0)
135 | ),
136 | )
137 | )
138 | else: # you can erase the skip connection
139 | for idx in range(len(self.kernel_num) - 1, 0, -1):
140 | if idx != 1:
141 | self.decoder.append(
142 | nn.Sequential(
143 | ComplexConvTranspose2d(
144 | self.kernel_num[idx],
145 | self.kernel_num[idx - 1],
146 | kernel_size=(self.kernel_size, 2),
147 | stride=(2, 1),
148 | padding=(2, 0),
149 | output_padding=(1, 0)
150 | ),
151 | nn.BatchNorm2d(self.kernel_num[idx - 1]) if not use_cbn else ComplexBatchNorm(
152 | self.kernel_num[idx - 1]),
153 | # nn.ELU()
154 | nn.PReLU()
155 | )
156 | )
157 | else:
158 | self.decoder.append(
159 | nn.Sequential(
160 | ComplexConvTranspose2d(
161 | self.kernel_num[idx],
162 | self.kernel_num[idx - 1],
163 | kernel_size=(self.kernel_size, 2),
164 | stride=(2, 1),
165 | padding=(2, 0),
166 | output_padding=(1, 0)
167 | ),
168 | )
169 | )
170 | self.flatten_parameters()
171 |
172 | def flatten_parameters(self):
173 | if isinstance(self.enhance, nn.LSTM):
174 | self.enhance.flatten_parameters()
175 |
176 | def forward(self, inputs, targets=0):
177 | specs = self.stft(inputs)
178 | real = specs[:, :self.fft_len // 2 + 1]
179 | imag = specs[:, self.fft_len // 2 + 1:]
180 | spec_mags = torch.sqrt(real ** 2 + imag ** 2 + 1e-8)
181 |
182 | spec_phase = torch.atan2(imag, real)
183 | cspecs = torch.stack([real, imag], 1)
184 | cspecs = cspecs[:, :, 1:]
185 | '''
186 | means = torch.mean(cspecs, [1,2,3], keepdim=True)
187 | std = torch.std(cspecs, [1,2,3], keepdim=True )
188 | normed_cspecs = (cspecs-means)/(std+1e-8)
189 | out = normed_cspecs
190 | '''
191 |
192 | out = cspecs
193 | encoder_out = []
194 |
195 | for idx, layer in enumerate(self.encoder):
196 | out = layer(out)
197 | # print('encoder', out.size())
198 | encoder_out.append(out)
199 |
200 | batch_size, channels, dims, lengths = out.size()
201 | out = out.permute(3, 0, 1, 2)
202 | if cfg.lstm == 'complex':
203 | r_rnn_in = out[:, :, :channels // 2]
204 | i_rnn_in = out[:, :, channels // 2:]
205 | r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, channels // 2 * dims])
206 | i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, channels // 2 * dims])
207 |
208 | r_rnn_in, i_rnn_in = self.enhance([r_rnn_in, i_rnn_in])
209 |
210 | r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, channels // 2, dims])
211 | i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, channels // 2, dims])
212 | out = torch.cat([r_rnn_in, i_rnn_in], 2)
213 | else:
214 | # to [L, B, C, D]
215 | out = torch.reshape(out, [lengths, batch_size, channels * dims])
216 | out, _ = self.enhance(out)
217 | out = self.tranform(out)
218 | out = torch.reshape(out, [lengths, batch_size, channels, dims])
219 |
220 | out = out.permute(1, 2, 3, 0)
221 |
222 | if cfg.skip_type: # use skip connection
223 | for idx in range(len(self.decoder)):
224 | out = complex_cat([out, encoder_out[-1 - idx]], 1)
225 | out = self.decoder[idx](out)
226 | out = out[..., 1:] #
227 | else:
228 | for idx in range(len(self.decoder)):
229 | out = self.decoder[idx](out)
230 | out = out[..., 1:]
231 |
232 | if self.masking_mode == 'Direct(None make)':
233 | # for loss calculation
234 | target_specs = self.stft(targets)
235 | target_real = target_specs[:, :self.fft_len // 2 + 1]
236 | target_imag = target_specs[:, self.fft_len // 2 + 1:]
237 |
238 | # spectral mapping
239 | out_real = out[:, 0]
240 | out_imag = out[:, 1]
241 | out_real = F.pad(out_real, [0, 0, 1, 0])
242 | out_imag = F.pad(out_imag, [0, 0, 1, 0])
243 |
244 | out_spec = torch.cat([out_real, out_imag], 1)
245 |
246 | out_wav = self.istft(out_spec)
247 | out_wav = torch.squeeze(out_wav, 1)
248 | out_wav = torch.clamp_(out_wav, -1, 1)
249 |
250 | return out_real, target_real, out_imag, target_imag, out_wav
251 | else:
252 | # print('decoder', out.size())
253 | mask_real = out[:, 0]
254 | mask_imag = out[:, 1]
255 | mask_real = F.pad(mask_real, [0, 0, 1, 0])
256 | mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
257 |
258 | if self.masking_mode == 'E':
259 | mask_mags = (mask_real ** 2 + mask_imag ** 2) ** 0.5
260 | real_phase = mask_real / (mask_mags + 1e-8)
261 | imag_phase = mask_imag / (mask_mags + 1e-8)
262 | mask_phase = torch.atan2(
263 | imag_phase,
264 | real_phase
265 | )
266 |
267 | # mask_mags = torch.clamp_(mask_mags,0,100)
268 | mask_mags = torch.tanh(mask_mags)
269 | est_mags = mask_mags * spec_mags
270 | est_phase = spec_phase + mask_phase
271 | out_real = est_mags * torch.cos(est_phase)
272 | out_imag = est_mags * torch.sin(est_phase)
273 | elif self.masking_mode == 'C':
274 | out_real, out_imag = real * mask_real - imag * mask_imag, real * mask_imag + imag * mask_real
275 | elif self.masking_mode == 'R':
276 | out_real, out_imag = real * mask_real, imag * mask_imag
277 |
278 | out_spec = torch.cat([out_real, out_imag], 1)
279 |
280 | out_wav = self.istft(out_spec)
281 | out_wav = torch.squeeze(out_wav, 1)
282 | out_wav = torch.clamp_(out_wav, -1, 1)
283 |
284 | return out_real, out_imag, out_wav
285 |
286 | def get_params(self, weight_decay=0.0):
287 | # add L2 penalty
288 | weights, biases = [], []
289 | for name, param in self.named_parameters():
290 | if 'bias' in name:
291 | biases += [param]
292 | else:
293 | weights += [param]
294 | params = [{
295 | 'params': weights,
296 | 'weight_decay': weight_decay,
297 | }, {
298 | 'params': biases,
299 | 'weight_decay': 0.0,
300 | }]
301 | return params
302 |
303 | def loss(self, estimated, target, real_spec=0, img_spec=0, perceptual=False):
304 | if perceptual:
305 | if cfg.perceptual == 'LMS':
306 | clean_specs = self.stft(target)
307 | clean_real = clean_specs[:, :self.fft_len // 2 + 1]
308 | clean_imag = clean_specs[:, self.fft_len // 2 + 1:]
309 | clean_mags = torch.sqrt(clean_real ** 2 + clean_imag ** 2 + 1e-7)
310 |
311 | est_clean_mags = torch.sqrt(real_spec ** 2 + img_spec ** 2 + 1e-7)
312 | return get_array_lms_loss(clean_mags, est_clean_mags)
313 | elif cfg.perceptual == 'PMSQE':
314 | return get_array_pmsqe_loss(target, estimated)
315 | else:
316 | if cfg.loss == 'MSE':
317 | return F.mse_loss(estimated, target, reduction='mean')
318 | elif cfg.loss == 'SDR':
319 | return -sdr(target, estimated)
320 | elif cfg.loss == 'SI-SNR':
321 | return -(si_snr(estimated, target))
322 | elif cfg.loss == 'SI-SDR':
323 | return -(si_sdr(target, estimated))
324 |
325 |
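# Illustrative usage (not part of the original file): with the default config
# (masking_mode='E'), the forward pass takes a batch of noisy waveforms and
# returns the masked real/imaginary spectra plus the enhanced waveform, e.g.
#   model = DCCRN()
#   noisy = torch.randn(cfg.batch, cfg.fs * 3)    # hypothetical 3-second clips
#   out_real, out_imag, out_wav = model(noisy)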
326 | #######################################################################
327 | # real network #
328 | #######################################################################
329 | class CRN(nn.Module):
330 | def __init__(
331 | self,
332 | rnn_layers=cfg.rnn_layers,
333 | rnn_input_size=cfg.rnn_input_size,
334 | rnn_units=cfg.rnn_units,
335 | win_len=cfg.win_len,
336 | win_inc=cfg.win_inc,
337 | fft_len=cfg.fft_len,
338 | win_type=cfg.window,
339 | masking_mode=cfg.masking_mode,
340 | kernel_size=5
341 | ):
342 | '''
343 | rnn_layers: the number of lstm layers in the crn
344 | '''
345 |
346 | super(CRN, self).__init__()
347 |
348 | # for fft
349 | self.win_len = win_len
350 | self.win_inc = win_inc
351 | self.fft_len = fft_len
352 | self.win_type = win_type
353 |
354 | input_dim = win_len
355 | output_dim = win_len
356 |
357 | self.rnn_input_size = rnn_input_size
358 | self.rnn_units = rnn_units//2
359 | self.input_dim = input_dim
360 | self.output_dim = output_dim
361 | self.hidden_layers = rnn_layers
362 | self.kernel_size = kernel_size
363 | kernel_num = cfg.dccrn_kernel_num
364 | self.kernel_num = [2] + kernel_num
365 | self.masking_mode = masking_mode
366 |
367 | # bidirectional=True
368 | bidirectional = False
369 |
370 | self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'real')
371 | self.istft = ConviSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex')
372 |
373 | self.encoder = nn.ModuleList()
374 | self.decoder = nn.ModuleList()
375 | for idx in range(len(self.kernel_num) - 1):
376 | self.encoder.append(
377 | nn.Sequential(
378 | RealConv2d(
379 | self.kernel_num[idx] // 2,
380 | self.kernel_num[idx + 1] // 2,
381 | kernel_size=(self.kernel_size, 2),
382 | stride=(2, 1),
383 | padding=(2, 1)
384 | ),
385 | nn.BatchNorm2d(self.kernel_num[idx + 1] // 2),
386 | nn.PReLU()
387 | )
388 | )
389 | hidden_dim = self.fft_len // (2 ** (len(self.kernel_num)))
390 |
391 | self.enhance = nn.LSTM(
392 | input_size=self.rnn_input_size,
393 | hidden_size=self.rnn_units,
394 | dropout=0.0,
395 | bidirectional=bidirectional,
396 | batch_first=False
397 | )
398 | self.tranform = nn.Linear(self.rnn_units, self.rnn_input_size)
399 |
400 | if cfg.skip_type:
401 | for idx in range(len(self.kernel_num) - 1, 0, -1):
402 | if idx != 1:
403 | self.decoder.append(
404 | nn.Sequential(
405 | RealConvTranspose2d(
406 | self.kernel_num[idx],
407 | self.kernel_num[idx - 1] // 2,
408 | kernel_size=(self.kernel_size, 2),
409 | stride=(2, 1),
410 | padding=(2, 0),
411 | output_padding=(1, 0)
412 | ),
413 | nn.BatchNorm2d(self.kernel_num[idx - 1] // 2),
414 | nn.PReLU()
415 | )
416 | )
417 | else:
418 | self.decoder.append(
419 | nn.Sequential(
420 | RealConvTranspose2d(
421 | self.kernel_num[idx],
422 | self.kernel_num[idx - 1] // 2,
423 | kernel_size=(self.kernel_size, 2),
424 | stride=(2, 1),
425 | padding=(2, 0),
426 | output_padding=(1, 0)
427 | ),
428 | )
429 | )
430 | else:
431 | for idx in range(len(self.kernel_num) - 1, 0, -1):
432 | if idx != 1:
433 | self.decoder.append(
434 | nn.Sequential(
435 | nn.ConvTranspose2d(
436 | self.kernel_num[idx],
437 | self.kernel_num[idx - 1],
438 | kernel_size=(self.kernel_size, 2),
439 | stride=(2, 1),
440 | padding=(2, 0),
441 | output_padding=(1, 0)
442 | ),
443 | nn.BatchNorm2d(self.kernel_num[idx - 1]),
444 | # nn.ELU()
445 | nn.PReLU()
446 | )
447 | )
448 | else:
449 | self.decoder.append(
450 | nn.Sequential(
451 | nn.ConvTranspose2d(
452 | self.kernel_num[idx],
453 | self.kernel_num[idx - 1],
454 | kernel_size=(self.kernel_size, 2),
455 | stride=(2, 1),
456 | padding=(2, 0),
457 | output_padding=(1, 0)
458 | ),
459 | )
460 | )
461 | self.flatten_parameters()
462 |
463 | def flatten_parameters(self):
464 | if isinstance(self.enhance, nn.LSTM):
465 | self.enhance.flatten_parameters()
466 |
467 | def forward(self, inputs, targets=0):
468 | mags, phase = self.stft(inputs)
469 |
470 | out = mags
471 | out = out.unsqueeze(1)
472 | out = out[:, :, 1:]
473 | encoder_out = []
474 |
475 | for idx, layer in enumerate(self.encoder):
476 | out = layer(out)
477 | # print('encoder', out.size())
478 | encoder_out.append(out)
479 |
480 | batch_size, channels, dims, lengths = out.size()
481 | out = out.permute(3, 0, 1, 2)
482 |
483 | rnn_in = torch.reshape(out, [lengths, batch_size, channels * dims])
484 | out, _ = self.enhance(rnn_in)
485 | out = self.tranform(out)
486 | out = torch.reshape(out, [lengths, batch_size, channels, dims])
487 |
488 | out = out.permute(1, 2, 3, 0)
489 |
490 | if cfg.skip_type: # use skip connection
491 | for idx in range(len(self.decoder)):
492 | out = torch.cat([out, encoder_out[-1 - idx]], 1)
493 | out = self.decoder[idx](out)
494 | out = out[..., 1:] #
495 | else:
496 | for idx in range(len(self.decoder)):
497 | out = self.decoder[idx](out)
498 | out = out[..., 1:]
499 |
500 | # mask_mags = F.pad(out, [0, 0, 1, 0])
501 | out = out.squeeze(1)
502 | out = F.pad(out, [0, 0, 1, 0])
503 |
504 | # for loss calculation
505 | target_mags, _ = self.stft(targets)
506 |
507 | if self.masking_mode == 'Direct(None make)': # spectral mapping
508 | out_real = out * torch.cos(phase)
509 | out_imag = out * torch.sin(phase)
510 |
511 | out_spec = torch.cat([out_real, out_imag], 1)
512 |
513 | out_wav = self.istft(out_spec)
514 | out_wav = torch.squeeze(out_wav, 1)
515 | out_wav = torch.clamp_(out_wav, -1, 1)
516 |
517 | return out, target_mags, out_wav
518 | else: # T-F masking
519 | # mask_mags = torch.clamp_(mask_mags,0,100)
520 | # out = F.pad(out, [0, 0, 1, 0])
521 | mask_mags = torch.tanh(out)
522 | est_mags = mask_mags * mags
523 | out_real = est_mags * torch.cos(phase)
524 | out_imag = est_mags * torch.sin(phase)
525 |
526 | out_spec = torch.cat([out_real, out_imag], 1)
527 |
528 | out_wav = self.istft(out_spec)
529 | out_wav = torch.squeeze(out_wav, 1)
530 | out_wav = torch.clamp_(out_wav, -1, 1)
531 |
532 | return est_mags, target_mags, out_wav
533 |
534 | def get_params(self, weight_decay=0.0):
535 | # add L2 penalty
536 | weights, biases = [], []
537 | for name, param in self.named_parameters():
538 | if 'bias' in name:
539 | biases += [param]
540 | else:
541 | weights += [param]
542 | params = [{
543 | 'params': weights,
544 | 'weight_decay': weight_decay,
545 | }, {
546 | 'params': biases,
547 | 'weight_decay': 0.0,
548 | }]
549 | return params
550 |
551 | def loss(self, estimated, target, out_mags=0, target_mags=0, perceptual=False):
552 | if perceptual:
553 | if cfg.perceptual == 'LMS':
554 | return get_array_lms_loss(target_mags, out_mags)
555 | elif cfg.perceptual == 'PMSQE':
556 | return get_array_pmsqe_loss(target, estimated)
557 | else:
558 | if cfg.loss == 'MSE':
559 | return F.mse_loss(estimated, target, reduction='mean')
560 | elif cfg.loss == 'SDR':
561 | return -sdr(target, estimated)
562 | elif cfg.loss == 'SI-SNR':
563 | return -(si_snr(estimated, target))
564 | elif cfg.loss == 'SI-SDR':
565 | return -(si_sdr(target, estimated))
566 |
567 |
568 | class FullSubNet(BaseModel):
569 | def __init__(self,
570 | sb_num_neighbors=cfg.sb_num_neighbors,
571 | fb_num_neighbors=cfg.fb_num_neighbors,
572 | num_freqs=cfg.num_freqs,
573 | look_ahead=cfg.look_ahead,
574 | sequence_model=cfg.sequence_model,
575 | fb_output_activate_function=cfg.fb_output_activate_function,
576 | sb_output_activate_function=cfg.sb_output_activate_function,
577 | fb_model_hidden_size=cfg.fb_model_hidden_size,
578 | sb_model_hidden_size=cfg.sb_model_hidden_size,
579 | weight_init=cfg.weight_init,
580 | norm_type=cfg.norm_type,
581 | ):
582 | """
583 | FullSubNet model (cIRM mask)
584 |
585 | Args:
586 | num_freqs: Frequency dim of the input
587 | look_ahead: Number of future frames to use (look-ahead)
588 | fb_num_neighbors: How many neighboring frequencies on each side are taken from the fullband model's output
589 | sb_num_neighbors: How many neighboring frequencies on each side are taken from the noisy spectrogram
590 | sequence_model: Choose one sequence model as the basic model, e.g., GRU, LSTM
591 | fb_output_activate_function: fullband model's activation function
592 | sb_output_activate_function: subband model's activation function
593 | norm_type: type of normalization, see more details in "BaseModel" class
594 | """
595 | super().__init__()
596 | assert sequence_model in ("GRU", "LSTM"), f"{self.__class__.__name__} only supports GRU and LSTM."
597 |
598 | self.fb_model = SequenceModel(
599 | input_size=num_freqs,
600 | output_size=num_freqs,
601 | hidden_size=fb_model_hidden_size,
602 | num_layers=2,
603 | bidirectional=False,
604 | sequence_model=sequence_model,
605 | output_activate_function=fb_output_activate_function
606 | )
607 |
608 | self.sb_model = SequenceModel(
609 | input_size=(sb_num_neighbors * 2 + 1) + (fb_num_neighbors * 2 + 1),
610 | output_size=2,
611 | hidden_size=sb_model_hidden_size,
612 | num_layers=2,
613 | bidirectional=False,
614 | sequence_model=sequence_model,
615 | output_activate_function=sb_output_activate_function
616 | )
617 |
618 | self.sb_num_neighbors = sb_num_neighbors
619 | self.fb_num_neighbors = fb_num_neighbors
620 | self.look_ahead = look_ahead
621 | self.norm = self.norm_wrapper(norm_type)
622 |
623 | if weight_init:
624 | self.apply(self.weight_init)
625 |
626 | def forward(self, noisy_mag):
627 | """
628 | Args:
629 | noisy_mag: noisy magnitude spectrogram
630 |
631 | Returns:
632 | The real and imaginary parts of the estimated complex ratio mask (cRM)
633 |
634 | Shapes:
635 | noisy_mag: [B, 1, F, T]
636 | return: [B, F, T, 2]
637 | """
638 | if not noisy_mag.dim() == 4:
639 | noisy_mag = noisy_mag.unsqueeze(1)
640 | noisy_mag = F.pad(noisy_mag, [0, self.look_ahead]) # Pad the look ahead
641 | batch_size, num_channels, num_freqs, num_frames = noisy_mag.size()
642 | assert num_channels == 1, f"{self.__class__.__name__} takes the magnitude feature as input."
643 |
644 | # Fullband model
645 | fb_input = self.norm(noisy_mag).reshape(batch_size, num_channels * num_freqs, num_frames)
646 | fb_output = self.fb_model(fb_input).reshape(batch_size, 1, num_freqs, num_frames)
647 |
648 | # Unfold fullband model's output, [B, N=F, C, F_f, T]. N is the number of sub-band units
649 | fb_output_unfolded = self.unfold(fb_output, num_neighbor=self.fb_num_neighbors)
650 | fb_output_unfolded = fb_output_unfolded.reshape(batch_size, num_freqs, self.fb_num_neighbors * 2 + 1, num_frames)
651 |
652 | # Unfold noisy spectrogram, [B, N=F, C, F_s, T]
653 | noisy_mag_unfolded = self.unfold(noisy_mag, num_neighbor=self.sb_num_neighbors)
654 | noisy_mag_unfolded = noisy_mag_unfolded.reshape(batch_size, num_freqs, self.sb_num_neighbors * 2 + 1, num_frames)
655 |
656 | # Concatenation, [B, F, (F_s + F_f), T]
657 | sb_input = torch.cat([noisy_mag_unfolded, fb_output_unfolded], dim=2)
658 | sb_input = self.norm(sb_input)
659 |
660 | sb_input = sb_input.reshape(
661 | batch_size * num_freqs,
662 | (self.sb_num_neighbors * 2 + 1) + (self.fb_num_neighbors * 2 + 1),
663 | num_frames
664 | )
665 |
666 | # [B * F, (F_s + F_f), T] => [B * F, 2, T] => [B, F, 2, T]
667 | sb_mask = self.sb_model(sb_input)
668 | sb_mask = sb_mask.reshape(batch_size, num_freqs, 2, num_frames).permute(0, 2, 1, 3).contiguous()
669 |
670 | output = sb_mask[:, :, :, self.look_ahead:]
671 | output = output.permute(0, 2, 3, 1)
672 | return output
673 |
674 | def loss(self, estimated, target):
675 | if cfg.loss == 'MSE':
676 | return F.mse_loss(estimated, target, reduction='mean')
677 | elif cfg.loss == 'SDR':
678 | return -sdr(target, estimated)
679 | elif cfg.loss == 'SI-SNR':
680 | return -(si_snr(estimated, target))
681 | elif cfg.loss == 'SI-SDR':
682 | return -(si_sdr(target, estimated))
683 |
684 |
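# Illustrative usage sketch for FullSubNet (added for clarity, not part of the original
# file); the batch size, 257 frequency bins, and 200 frames below are assumptions:
# >>> model = FullSubNet()
# >>> noisy_mag = torch.rand(4, 1, 257, 200)   # [B, 1, F, T] magnitude spectrogram
# >>> cRM = model(noisy_mag)                   # [B, F, T, 2] compressed cIRM estimate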
--------------------------------------------------------------------------------
/tools_for_estimate.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | from pesq import pesq
4 | import numpy as np
5 | import ctypes
6 | import logging
7 | import oct2py
8 | from scipy.io import wavfile
9 | from pystoi import stoi
10 | import config as cfg
11 |
12 |
13 | ############################################################################
14 | # MOS #
15 | ############################################################################
16 | # Reference
17 | # https://github.com/usimarit/semetrics # https://ecs.utdallas.edu/loizou/speech/software.htm
18 | logging.basicConfig(level=logging.ERROR)
19 | oc = oct2py.Oct2Py(logger=logging.getLogger())
20 |
21 | COMPOSITE = os.path.join(os.path.abspath(os.path.dirname(__file__)), "composite.m")
22 |
23 |
24 | def composite(clean: str, enhanced: str):
25 | pesq_score = pesq_mos(clean, enhanced)
26 | csig, cbak, covl, ssnr = oc.feval(COMPOSITE, clean, enhanced, nout=4)
27 | csig += 0.603 * pesq_score
28 | cbak += 0.478 * pesq_score
29 | covl += 0.805 * pesq_score
30 | return csig, cbak, covl, ssnr
31 |
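# Hedged usage sketch (not in the original file): both arguments are paths to
# time-aligned wav files at the same sampling rate; the PESQ-weighted terms are
# already folded into the returned CSIG/CBAK/COVL values.
# >>> csig, cbak, covl, ssnr = composite("clean.wav", "enhanced.wav")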
32 |
33 | ############################################################################
34 | # PESQ #
35 | ############################################################################
36 | # Reference
37 | # https://github.com/usimarit/semetrics
38 | # https://ecs.utdallas.edu/loizou/speech/software.htm
39 |
40 | def pesq_mos(clean: str, enhanced: str):
41 | sr1, clean_wav = wavfile.read(clean)
42 | sr2, enhanced_wav = wavfile.read(enhanced)
43 | assert sr1 == sr2
44 | mode = "nb" if sr1 < 16000 else "wb"
45 | return pesq(sr1, clean_wav, enhanced_wav, mode)
46 |
47 |
48 | ###############################################################################
49 | # PESQ (another ref) #
50 | ###############################################################################
51 | pesq_dll = ctypes.CDLL('./PESQ.so')
52 | pesq_dll.pesq.restype = ctypes.c_double
53 |
54 |
55 | # interface to PESQ evaluation, taking in two filenames as input
56 | def run_pesq_filenames(clean, to_eval):
57 | pesq_regex = re.compile(r"\(MOS-LQO\): = ([0-9]+\.[0-9]+)")
58 |
59 | pesq_out = os.popen("./PESQ" + str(cfg.fs) + "wb " + clean + " " + to_eval).read()
60 | regex_result = pesq_regex.search(pesq_out)
61 |
62 | if (regex_result is None):
63 | return 0.0
64 | else:
65 | return float(regex_result.group(1))
66 |
67 |
68 | def run_pesq_waveforms(dirty_wav, clean_wav):
69 | clean_wav = clean_wav.astype(np.double)
70 | dirty_wav = dirty_wav.astype(np.double)
71 | # return pesq(clean_wav, dirty_wav, fs=8000)
72 | return pesq_dll.pesq(ctypes.c_void_p(clean_wav.ctypes.data),
73 | ctypes.c_void_p(dirty_wav.ctypes.data),
74 | len(clean_wav),
75 | len(dirty_wav))
76 |
77 |
78 | # interface to PESQ evaluation, taking in two waveforms as input
79 | def cal_pesq(dirty_wavs, clean_wavs):
80 | scores = []
81 | for i in range(len(dirty_wavs)):
82 | pesq_score = run_pesq_waveforms(dirty_wavs[i], clean_wavs[i])  # avoid shadowing the imported pesq()
83 | scores.append(pesq_score)
84 | return scores
85 |
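# Illustrative sketch (assumption: each element is a 1-D waveform array at cfg.fs Hz):
# >>> scores = cal_pesq(estimated_wavs, clean_wavs)   # one MOS-like score per utterance
# >>> avg_pesq = np.mean(scores)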
86 |
87 | ###############################################################################
88 | # STOI #
89 | ###############################################################################
90 | def cal_stoi(estimated_speechs, clean_speechs):
91 | stoi_scores = []
92 | for i in range(len(estimated_speechs)):
93 | stoi_score = stoi(clean_speechs[i], estimated_speechs[i], cfg.fs, extended=False)
94 | stoi_scores.append(stoi_score)
95 | return stoi_scores
96 |
97 |
98 | ###############################################################################
99 | # SNR #
100 | ###############################################################################
101 | def cal_snr(s1, s2, eps=1e-8):
102 | signal = s2
103 | mean_signal = np.mean(signal)
104 | signal_diff = signal - mean_signal
105 | var_signal = np.sum(np.mean(signal_diff ** 2)) # # variance of original data
106 |
107 | noisy_signal = s1
108 | noise = noisy_signal - signal
109 | mean_noise = np.mean(noise)
110 | noise_diff = noise - mean_noise
111 | var_noise = np.sum(np.mean(noise_diff ** 2)) # # variance of noise
112 |
113 | if var_noise == 0:
114 | snr_score = 100 # # clean
115 | else:
116 | snr_score = (np.log10(var_signal/var_noise + eps))*10
117 | return snr_score
118 |
119 |
120 | def cal_snr_array(estimated_speechs, clean_speechs):
121 | snr_score = []
122 | for i in range(len(estimated_speechs)):
123 | snr = cal_snr(estimated_speechs[i], clean_speechs[i])
124 | snr_score.append(snr)
125 | return snr_score
126 |
--------------------------------------------------------------------------------
/tools_for_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 | import config as cfg
5 | from asteroid.losses import SingleSrcPMSQE, PITLossWrapper
6 | from asteroid_filterbanks import STFTFB, Encoder, transforms
7 |
8 | ############################################################################
9 | # for model structure & loss function #
10 | ############################################################################
11 | def remove_dc(data):
12 | mean = torch.mean(data, -1, keepdim=True)
13 | data = data - mean
14 | return data
15 |
16 |
17 | def l2_norm(s1, s2):
18 | norm = torch.sum(s1 * s2, -1, keepdim=True)
19 | return norm
20 |
21 |
22 | def sdr_linear(s1, s2, eps=1e-8):
23 | sn = l2_norm(s1, s1)
24 | sn_m_shn = l2_norm(s1 - s2, s1 - s2)
25 | sdr_loss = sn**2 / (sn_m_shn**2 + eps)
26 | return torch.mean(sdr_loss)
27 |
28 |
29 | def sdr(s1, s2, eps=1e-8):
30 | sn = l2_norm(s1, s1)
31 | sn_m_shn = l2_norm(s1 - s2, s1 - s2)
32 | sdr_loss = 10 * torch.log10(sn**2 / (sn_m_shn**2 + eps))
33 | return torch.mean(sdr_loss)
34 |
35 |
36 | def si_snr(s1, s2, eps=1e-8):
37 | s1_s2_norm = l2_norm(s1, s2)
38 | s2_s2_norm = l2_norm(s2, s2)
39 | s_target = s1_s2_norm / (s2_s2_norm + eps) * s2
40 | e_noise = s1 - s_target
41 | target_norm = l2_norm(s_target, s_target)
42 | noise_norm = l2_norm(e_noise, e_noise)
43 | snr = 10 * torch.log10((target_norm) / (noise_norm + eps) + eps)
44 | return torch.mean(snr)
45 |
46 |
47 | def si_sdr(reference, estimation, eps=1e-8):
48 | """
49 | Scale-Invariant Signal-to-Distortion Ratio (SI-SDR)
50 | Args:
51 | reference: numpy.ndarray, [..., T]
52 | estimation: numpy.ndarray, [..., T]
53 | Returns:
54 | SI-SDR
55 | [1] SDR - Half-baked or Well Done?
56 | http://www.merl.com/publications/docs/TR2019-013.pdf
57 | >>> np.random.seed(0)
58 | >>> reference = np.random.randn(100)
59 | >>> si_sdr(reference, reference)
60 | inf
61 | >>> si_sdr(reference, reference * 2)
62 | inf
63 | >>> si_sdr(reference, np.flip(reference))
64 | -25.127672346460717
65 | >>> si_sdr(reference, reference + np.flip(reference))
66 | 0.481070445785553
67 | >>> si_sdr(reference, reference + 0.5)
68 | 6.3704606032577304
69 | >>> si_sdr(reference, reference * 2 + 1)
70 | 6.3704606032577304
71 | >>> si_sdr([1., 0], [0., 0]) # never predict only zeros
72 | nan
73 | >>> si_sdr([reference, reference], [reference * 2 + 1, reference * 1 + 0.5])
74 | array([6.3704606, 6.3704606])
75 | :param reference:
76 | :param estimation:
77 | :param eps:
78 | """
79 |
80 | reference_energy = torch.sum(reference ** 2, axis=-1, keepdims=True)
81 |
82 | # This is $\alpha$ after Equation (3) in [1].
83 | optimal_scaling = torch.sum(reference * estimation, axis=-1, keepdims=True) / (reference_energy + eps)
84 |
85 | # This is $e_{\text{target}}$ in Equation (4) in [1].
86 | projection = optimal_scaling * reference
87 |
88 | # This is $e_{\text{res}}$ in Equation (4) in [1].
89 | noise = estimation - projection
90 |
91 | ratio = torch.sum(projection ** 2, axis=-1) / (torch.sum(noise ** 2, axis=-1) + eps)
92 |
93 | ratio = torch.mean(ratio)
94 | return 10 * torch.log10(ratio + eps)
95 |
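# Torch usage sketch (illustrative, not part of the original file):
# >>> ref = torch.randn(8, 16000)              # [B, T] reference waveforms
# >>> est = ref + 0.1 * torch.randn_like(ref)
# >>> loss = -si_sdr(ref, est)                 # negated so minimizing the loss maximizes SI-SDR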
96 |
97 | ############################################################################
98 | # for LMS loss function #
99 | ############################################################################
100 | # MFCC (Mel Frequency Cepstral Coefficients)
101 |
102 | # based on a combination of this article:
103 | # http://practicalcryptography.com/miscellaneous/machine-learning/...
104 | # guide-mel-frequency-cepstral-coefficients-mfccs/
105 | # and some of this code:
106 | # http://stackoverflow.com/questions/5835568/...
107 | # how-to-get-mfcc-from-an-fft-on-a-signal
108 | # Set device
109 | DEVICE = torch.device(cfg.DEVICE)
110 |
111 | FFT_SIZE = cfg.fft_len
112 |
113 | # multi-scale MFCC distance
114 | if cfg.perceptual == 'LMS':
115 | MEL_SCALES = [16, 32, 64]
116 | elif cfg.perceptual == 'PAM':
117 | MEL_SCALES = [32, 64]
118 |
119 |
120 | class rmse(torch.nn.Module):
121 | def __init__(self):
122 | super(rmse, self).__init__()
123 |
124 | def forward(self, y_true, y_pred):
125 | mse = torch.mean((y_pred - y_true) ** 2, axis=-1)
126 | rmse = torch.sqrt(mse + 1e-7)
127 |
128 | return torch.mean(rmse)
129 |
130 |
131 | # conversions between Mel scale and regular frequency scale
132 | def freqToMel(freq):
133 | return 1127.01048 * math.log(1 + freq / 700.0)
134 |
135 |
136 | def melToFreq(mel):
137 | return 700 * (math.exp(mel / 1127.01048) - 1)
138 |
139 | # generate Mel filter bank
140 | def melFilterBank(numCoeffs, fftSize=None):
141 | minHz = 0
142 | maxHz = cfg.fs / 2 # max Hz by Nyquist theorem
143 | if (fftSize is None):
144 | numFFTBins = cfg.win_len
145 | else:
146 | numFFTBins = int(fftSize / 2) + 1
147 |
148 | maxMel = freqToMel(maxHz)
149 | minMel = freqToMel(minHz)
150 |
151 | # we need (numCoeffs + 2) points to create (numCoeffs) filterbanks
152 | melRange = np.array(range(numCoeffs + 2))
153 | melRange = melRange.astype(np.float32)
154 |
155 | # create (numCoeffs + 2) points evenly spaced between minMel and maxMel
156 | melCenterFilters = melRange * (maxMel - minMel) / (numCoeffs + 1) + minMel
157 |
158 | for i in range(numCoeffs + 2):
159 | # mel domain => frequency domain
160 | melCenterFilters[i] = melToFreq(melCenterFilters[i])
161 |
162 | # frequency domain => FFT bins
163 | melCenterFilters[i] = math.floor(numFFTBins * melCenterFilters[i] / maxHz)
164 |
165 | # create matrix of filters (one row is one filter)
166 | filterMat = np.zeros((numCoeffs, numFFTBins))
167 |
168 | # generate triangular filters (in frequency domain)
169 | for i in range(1, numCoeffs + 1):
170 | filter = np.zeros(numFFTBins)
171 |
172 | startRange = int(melCenterFilters[i - 1])
173 | midRange = int(melCenterFilters[i])
174 | endRange = int(melCenterFilters[i + 1])
175 |
176 | for j in range(startRange, midRange):
177 | filter[j] = (float(j) - startRange) / (midRange - startRange)
178 | for j in range(midRange, endRange):
179 | filter[j] = 1 - ((float(j) - midRange) / (endRange - midRange))
180 |
181 | filterMat[i - 1] = filter
182 |
183 | # return filterbank as matrix
184 | return filterMat
185 |
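# Quick shape check (illustrative): with a 512-point FFT the filterbank has one row
# per Mel band and one column per FFT bin (fftSize / 2 + 1).
# >>> melFilterBank(32, 512).shape
# (32, 257)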
186 |
187 | # Finally: a perceptual loss function (based on Mel scale)
188 |
189 | # given a (symbolic Theano) array of size M x WINDOW_SIZE
190 | # this returns an array M x N where each window has been replaced
191 | # by some perceptual transform (in this case, MFCC coeffs)
192 | def perceptual_transform(x):
193 | # precompute Mel filterbank: [FFT_SIZE x NUM_MFCC_COEFFS]
194 | MEL_FILTERBANKS = []
195 | for scale in MEL_SCALES:
196 | filterbank_npy = melFilterBank(scale, FFT_SIZE).transpose()
197 | torch_filterbank_npy = torch.from_numpy(filterbank_npy).type(torch.FloatTensor)
198 | MEL_FILTERBANKS.append(torch_filterbank_npy.to(DEVICE))
199 |
200 | transforms = []
201 | # powerSpectrum = torch_dft_mag(x, DFT_REAL, DFT_IMAG)**2
202 |
203 | powerSpectrum = x.view(-1, FFT_SIZE // 2 + 1)
204 | powerSpectrum = 1.0 / FFT_SIZE * powerSpectrum
205 |
206 | for filterbank in MEL_FILTERBANKS:
207 | filteredSpectrum = torch.mm(powerSpectrum, filterbank)
208 | filteredSpectrum = torch.log(filteredSpectrum + 1e-7)
209 | transforms.append(filteredSpectrum)
210 |
211 | return transforms
212 |
213 |
214 | # perceptual loss function
215 | class perceptual_distance(torch.nn.Module):
216 |
217 | def __init__(self):
218 | super(perceptual_distance, self).__init__()
219 |
220 | def forward(self, y_true, y_pred):
221 | rmse_loss = rmse()
222 | # y_true = torch.reshape(y_true, (-1, WINDOW_SIZE))
223 | # y_pred = torch.reshape(y_pred, (-1, WINDOW_SIZE))
224 |
225 | pvec_true = perceptual_transform(y_true)
226 | pvec_pred = perceptual_transform(y_pred)
227 |
228 | distances = []
229 | for i in range(0, len(pvec_true)):
230 | error = rmse_loss(pvec_pred[i], pvec_true[i])
231 | error = error.unsqueeze(dim=-1)
232 | distances.append(error)
233 | distances = torch.cat(distances, axis=-1)
234 |
235 | loss = torch.mean(distances, axis=-1)
236 | return torch.mean(loss)
237 |
238 |
239 | get_mel_loss = perceptual_distance()
240 |
241 |
242 | def get_array_lms_loss(clean_array, est_array):
243 | array_mel_loss = 0
244 | for i in range(len(clean_array)):
245 | mel_loss = get_mel_loss(clean_array[i], est_array[i])
246 | array_mel_loss += mel_loss
247 |
248 | avg_mel_loss = array_mel_loss / len(clean_array)
249 | return avg_mel_loss
250 |
251 |
252 | ############################################################################
253 | # for pmsqe loss function #
254 | ############################################################################
255 | pmsqe_stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256)).to(DEVICE)
256 | pmsqe_loss = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt').to(DEVICE)
257 |
258 |
259 | def get_array_pmsqe_loss(clean_array, est_array):
260 | clean_wav, est_wav = clean_array, est_array
261 | if clean_wav.dim() == 2:  # add a channel dim for [B, T] inputs
262 | clean_wav, est_wav = torch.unsqueeze(clean_wav, 1), torch.unsqueeze(est_wav, 1)
263 | N, C, H = clean_wav.size()
264 | clean_wav = clean_wav.contiguous().view(N, -1, cfg.fs)
265 | est_wav = est_wav.contiguous().view(N, -1, cfg.fs)
266 |
267 | clean_spec = transforms.mag(pmsqe_stft(clean_wav))
268 | est_spec = transforms.mag(pmsqe_stft(est_wav))
269 | return pmsqe_loss(est_spec, clean_spec)
270 |
--------------------------------------------------------------------------------
/tools_for_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import time
5 | import torch.nn.functional as F
6 | import torch.nn.init as init
7 | from scipy.signal import get_window
8 | import matplotlib.pylab as plt
9 | import config as cfg
10 |
11 |
12 | ############################################################################
13 | # for convolutional STFT #
14 | ############################################################################
15 | # this is from conv_stft https://github.com/huyanxin/DeepComplexCRN
16 | def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
17 | if win_type == 'None' or win_type is None:
18 | window = np.ones(win_len)
19 | else:
20 | window = get_window(win_type, win_len, fftbins=True) # **0.5
21 |
22 | N = fft_len
23 | fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
24 | real_kernel = np.real(fourier_basis)
25 | imag_kernel = np.imag(fourier_basis)
26 | kernel = np.concatenate([real_kernel, imag_kernel], 1).T
27 |
28 | if invers:
29 | kernel = np.linalg.pinv(kernel).T
30 |
31 | kernel = kernel * window
32 | kernel = kernel[:, None, :]
33 | return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32))
34 |
35 |
36 | class ConvSTFT(nn.Module):
37 |
38 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
39 | super(ConvSTFT, self).__init__()
40 |
41 | if fft_len is None:
42 | self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
43 | else:
44 | self.fft_len = fft_len
45 |
46 | kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
47 | # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
48 | self.register_buffer('weight', kernel)
49 | self.feature_type = feature_type
50 | self.stride = win_inc
51 | self.win_len = win_len
52 | self.dim = self.fft_len
53 |
54 | def forward(self, inputs):
55 | if inputs.dim() == 2:
56 | inputs = torch.unsqueeze(inputs, 1)
57 | inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride])
58 | outputs = F.conv1d(inputs, self.weight, stride=self.stride)
59 |
60 | if self.feature_type == 'complex':
61 | return outputs
62 | else:
63 | dim = self.dim // 2 + 1
64 | real = outputs[:, :dim, :]
65 | imag = outputs[:, dim:, :]
66 | mags = torch.sqrt(real ** 2 + imag ** 2)
67 | phase = torch.atan2(imag, real)
68 | return mags, phase
69 |
70 |
71 | class ConviSTFT(nn.Module):
72 |
73 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
74 | super(ConviSTFT, self).__init__()
75 | if fft_len is None:
76 | self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
77 | else:
78 | self.fft_len = fft_len
79 | kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)
80 | # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
81 | self.register_buffer('weight', kernel)
82 | self.feature_type = feature_type
83 | self.win_type = win_type
84 | self.win_len = win_len
85 | self.stride = win_inc
86 | self.dim = self.fft_len
87 | self.register_buffer('window', window)
88 | self.register_buffer('enframe', torch.eye(win_len)[:, None, :])
89 |
90 | def forward(self, inputs, phase=None):
91 | """
92 | inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
93 | phase: [B, N//2+1, T] (if not none)
94 | """
95 |
96 | if phase is not None:
97 | real = inputs * torch.cos(phase)
98 | imag = inputs * torch.sin(phase)
99 | inputs = torch.cat([real, imag], 1)
100 |
101 | outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
102 |
103 | # this is from torch-stft: https://github.com/pseeth/torch-stft
104 | t = self.window.repeat(1, 1, inputs.size(-1)) ** 2
105 | coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
106 |
107 | outputs = outputs / (coff + 1e-8)
108 |
109 | # # outputs = torch.where(coff == 0, outputs, outputs/coff)
110 | outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)]
111 |
112 | return outputs
113 |
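# Round-trip sketch (illustrative; the 400/100-sample window/hop and 512-point FFT
# are assumptions, not values taken from config.py):
# >>> conv_stft = ConvSTFT(400, 100, 512, 'hamming', 'complex')
# >>> conv_istft = ConviSTFT(400, 100, 512, 'hamming', 'complex')
# >>> wav = torch.randn(2, 16000)
# >>> rec = conv_istft(conv_stft(wav))   # approximately reconstructs the input waveform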
114 |
115 | ############################################################################
116 | # for complex rnn #
117 | ############################################################################
118 | def get_casual_padding1d():
119 | pass
120 |
121 |
122 | def get_casual_padding2d():
123 | pass
124 |
125 |
126 | class cPReLU(nn.Module):
127 |
128 | def __init__(self, complex_axis=1):
129 | super(cPReLU, self).__init__()
130 | self.r_prelu = nn.PReLU()
131 | self.i_prelu = nn.PReLU()
132 | self.complex_axis = complex_axis
133 |
134 | def forward(self, inputs):
135 | real, imag = torch.chunk(inputs, 2, self.complex_axis)
136 | real = self.r_prelu(real)
137 | imag = self.i_prelu(imag)
138 | return torch.cat([real, imag], self.complex_axis)
139 |
140 |
141 | class NavieComplexLSTM(nn.Module):
142 | def __init__(self, input_size, hidden_size, projection_dim=None, bidirectional=False, batch_first=False):
143 | super(NavieComplexLSTM, self).__init__()
144 |
145 | self.input_dim = input_size // 2
146 | self.rnn_units = hidden_size // 2
147 | self.real_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional,
148 | batch_first=False)
149 | self.imag_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional,
150 | batch_first=False)
151 | if bidirectional:
152 | bidirectional = 2
153 | else:
154 | bidirectional = 1
155 | if projection_dim is not None:
156 | self.projection_dim = projection_dim // 2
157 | self.r_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
158 | self.i_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
159 | else:
160 | self.projection_dim = None
161 |
162 | def forward(self, inputs):
163 | if isinstance(inputs, list):
164 | real, imag = inputs
165 | elif isinstance(inputs, torch.Tensor):
166 | real, imag = torch.chunk(inputs, 2, -1)
167 | r2r_out = self.real_lstm(real)[0]
168 | r2i_out = self.imag_lstm(real)[0]
169 | i2r_out = self.real_lstm(imag)[0]
170 | i2i_out = self.imag_lstm(imag)[0]
171 | real_out = r2r_out - i2i_out
172 | imag_out = i2r_out + r2i_out
173 | if self.projection_dim is not None:
174 | real_out = self.r_trans(real_out)
175 | imag_out = self.i_trans(imag_out)
176 | # print(real_out.shape,imag_out.shape)
177 | return [real_out, imag_out]
178 |
179 | def flatten_parameters(self):
180 | self.imag_lstm.flatten_parameters()
181 | self.real_lstm.flatten_parameters()
182 |
183 |
184 | def complex_cat(inputs, axis):
185 | real, imag = [], []
186 | for idx, data in enumerate(inputs):
187 | r, i = torch.chunk(data, 2, axis)
188 | real.append(r)
189 | imag.append(i)
190 | real = torch.cat(real, axis)
191 | imag = torch.cat(imag, axis)
192 | outputs = torch.cat([real, imag], axis)
193 | return outputs
194 |
195 |
196 | ############################################################################
197 | # for convolutional layer #
198 | ############################################################################
199 | class ComplexConv2d(nn.Module):
200 |
201 | def __init__(
202 | self,
203 | in_channels,
204 | out_channels,
205 | kernel_size=(1, 1),
206 | stride=(1, 1),
207 | padding=(0, 0),
208 | dilation=1,
209 | groups=1,
210 | causal=True,
211 | complex_axis=1,
212 | ):
213 | '''
214 | in_channels: real+imag
215 | out_channels: real+imag
216 | kernel_size : input [B,C,D,T] kernel size in [D,T]
217 | padding : input [B,C,D,T] padding in [D,T]
218 | causal: if causal, pad only the left side of the time dimension,
219 | otherwise pad both sides
220 |
221 | '''
222 | super(ComplexConv2d, self).__init__()
223 | self.in_channels = in_channels // 2
224 | self.out_channels = out_channels // 2
225 | self.kernel_size = kernel_size
226 | self.stride = stride
227 | self.padding = padding
228 | self.causal = causal
229 | self.groups = groups
230 | self.dilation = dilation
231 | self.complex_axis = complex_axis
232 |
233 | self.real_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
234 | padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
235 | self.imag_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
236 | padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
237 |
238 | nn.init.normal_(self.real_conv.weight.data, std=0.05)
239 | nn.init.normal_(self.imag_conv.weight.data, std=0.05)
240 | nn.init.constant_(self.real_conv.bias, 0.)
241 | nn.init.constant_(self.imag_conv.bias, 0.)
242 |
243 | def forward(self, inputs):
244 | if self.padding[1] != 0 and self.causal:
245 | inputs = F.pad(inputs, [self.padding[1], 0, 0, 0]) # # [width left, width right, height left, height right]
246 | else:
247 | inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0])
248 |
249 | if self.complex_axis == 0:
250 | real = self.real_conv(inputs)
251 | imag = self.imag_conv(inputs)
252 | real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
253 | real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
254 |
255 | else:
256 | if isinstance(inputs, torch.Tensor):
257 | real, imag = torch.chunk(inputs, 2, self.complex_axis)
258 |
259 | real2real = self.real_conv(real, )
260 | imag2imag = self.imag_conv(imag, )
261 |
262 | real2imag = self.imag_conv(real)
263 | imag2real = self.real_conv(imag)
264 |
265 | real = real2real - imag2imag
266 | imag = real2imag + imag2real
267 | out = torch.cat([real, imag], self.complex_axis)
268 |
269 | return out
270 |
271 |
272 | class ComplexConvTranspose2d(nn.Module):
273 |
274 | def __init__(
275 | self,
276 | in_channels,
277 | out_channels,
278 | kernel_size=(1, 1),
279 | stride=(1, 1),
280 | padding=(0, 0),
281 | output_padding=(0, 0),
282 | causal=False,
283 | complex_axis=1,
284 | groups=1
285 | ):
286 | '''
287 | in_channels: real+imag
288 | out_channels: real+imag
289 | '''
290 | super(ComplexConvTranspose2d, self).__init__()
291 | self.in_channels = in_channels // 2
292 | self.out_channels = out_channels // 2
293 | self.kernel_size = kernel_size
294 | self.stride = stride
295 | self.padding = padding
296 | self.output_padding = output_padding
297 | self.groups = groups
298 |
299 | self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
300 | padding=self.padding, output_padding=output_padding, groups=self.groups)
301 | self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
302 | padding=self.padding, output_padding=output_padding, groups=self.groups)
303 |
304 | self.complex_axis = complex_axis
305 |
306 | nn.init.normal_(self.real_conv.weight.data, std=0.05)
307 | nn.init.normal_(self.imag_conv.weight.data, std=0.05)
308 | nn.init.constant_(self.real_conv.bias, 0.)
309 | nn.init.constant_(self.imag_conv.bias, 0.)
310 |
311 | def forward(self, inputs):
312 |
313 | if isinstance(inputs, torch.Tensor):
314 | real, imag = torch.chunk(inputs, 2, self.complex_axis)
315 | elif isinstance(inputs, tuple) or isinstance(inputs, list):
316 | real = inputs[0]
317 | imag = inputs[1]
318 | if self.complex_axis == 0:
319 | real = self.real_conv(inputs)
320 | imag = self.imag_conv(inputs)
321 | real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
322 | real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
323 |
324 | else:
325 | if isinstance(inputs, torch.Tensor):
326 | real, imag = torch.chunk(inputs, 2, self.complex_axis)
327 |
328 | real2real = self.real_conv(real, )
329 | imag2imag = self.imag_conv(imag, )
330 |
331 | real2imag = self.imag_conv(real)
332 | imag2real = self.real_conv(imag)
333 |
334 | real = real2real - imag2imag
335 | imag = real2imag + imag2real
336 | out = torch.cat([real, imag], self.complex_axis)
337 |
338 | return out
339 |
340 |
341 | class RealConv2d(nn.Module):
342 |
343 | def __init__(
344 | self,
345 | in_channels,
346 | out_channels,
347 | kernel_size=(1, 1),
348 | stride=(1, 1),
349 | padding=(0, 0),
350 | dilation=1,
351 | groups=1,
352 | causal=True,
353 | complex_axis=1,
354 | ):
355 | '''
356 | in_channels: number of (real-valued) input channels
357 | out_channels: number of (real-valued) output channels
358 | kernel_size : input [B,C,D,T] kernel size in [D,T]
359 | padding : input [B,C,D,T] padding in [D,T]
360 | causal: if causal, pad only the left side of the time dimension,
361 | otherwise pad both sides
362 |
363 | '''
364 | super(RealConv2d, self).__init__()
365 | self.in_channels = in_channels
366 | self.out_channels = out_channels
367 | self.kernel_size = kernel_size
368 | self.stride = stride
369 | self.padding = padding
370 | self.causal = causal
371 | self.groups = groups
372 | self.dilation = dilation
373 |
374 | self.conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
375 | padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
376 |
377 | nn.init.normal_(self.conv.weight.data, std=0.05)
378 | nn.init.constant_(self.conv.bias, 0.)
379 |
380 | def forward(self, inputs):
381 | if self.padding[1] != 0 and self.causal:
382 | inputs = F.pad(inputs, [self.padding[1], 0, 0, 0]) ## [width left, width right, height left, height right]
383 | else:
384 | inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0])
385 |
386 | out = self.conv(inputs)
387 |
388 | return out
389 |
390 |
391 | class RealConvTranspose2d(nn.Module):
392 |
393 | def __init__(
394 | self,
395 | in_channels,
396 | out_channels,
397 | kernel_size=(1, 1),
398 | stride=(1, 1),
399 | padding=(0, 0),
400 | output_padding=(0, 0),
401 | groups=1
402 | ):
403 | '''
404 | in_channels: number of (real-valued) input channels
405 | out_channels: number of (real-valued) output channels
406 | '''
407 | super(RealConvTranspose2d, self).__init__()
408 | self.in_channels = in_channels
409 | self.out_channels = out_channels
410 | self.kernel_size = kernel_size
411 | self.stride = stride
412 | self.padding = padding
413 | self.output_padding = output_padding
414 | self.groups = groups
415 |
416 | self.conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
417 | padding=self.padding, output_padding=output_padding, groups=self.groups)
418 |
419 | nn.init.normal_(self.conv.weight.data, std=0.05)
420 | nn.init.constant_(self.conv.bias, 0.)
421 |
422 | def forward(self, inputs):
423 | out = self.conv(inputs)
424 |
425 | return out
426 |
427 |
428 | # Source: https://github.com/ChihebTrabelsi/deep_complex_networks/tree/pytorch
429 | # from https://github.com/IMLHF/SE_DCUNet/blob/f28bf1661121c8901ad38149ea827693f1830715/models/layers/complexnn.py#L55
430 | class ComplexBatchNorm(torch.nn.Module):
431 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
432 | track_running_stats=True, complex_axis=1):
433 | super(ComplexBatchNorm, self).__init__()
434 | self.num_features = num_features // 2
435 | self.eps = eps
436 | self.momentum = momentum
437 | self.affine = affine
438 | self.track_running_stats = track_running_stats
439 |
440 | self.complex_axis = complex_axis
441 |
442 | if self.affine:
443 | self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features))
444 | self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features))
445 | self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features))
446 | self.Br = torch.nn.Parameter(torch.Tensor(self.num_features))
447 | self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features))
448 | else:
449 | self.register_parameter('Wrr', None)
450 | self.register_parameter('Wri', None)
451 | self.register_parameter('Wii', None)
452 | self.register_parameter('Br', None)
453 | self.register_parameter('Bi', None)
454 |
455 | if self.track_running_stats:
456 | self.register_buffer('RMr', torch.zeros(self.num_features))
457 | self.register_buffer('RMi', torch.zeros(self.num_features))
458 | self.register_buffer('RVrr', torch.ones(self.num_features))
459 | self.register_buffer('RVri', torch.zeros(self.num_features))
460 | self.register_buffer('RVii', torch.ones(self.num_features))
461 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
462 | else:
463 | self.register_parameter('RMr', None)
464 | self.register_parameter('RMi', None)
465 | self.register_parameter('RVrr', None)
466 | self.register_parameter('RVri', None)
467 | self.register_parameter('RVii', None)
468 | self.register_parameter('num_batches_tracked', None)
469 | self.reset_parameters()
470 |
471 | def reset_running_stats(self):
472 | if self.track_running_stats:
473 | self.RMr.zero_()
474 | self.RMi.zero_()
475 | self.RVrr.fill_(1)
476 | self.RVri.zero_()
477 | self.RVii.fill_(1)
478 | self.num_batches_tracked.zero_()
479 |
480 | def reset_parameters(self):
481 | self.reset_running_stats()
482 | if self.affine:
483 | self.Br.data.zero_()
484 | self.Bi.data.zero_()
485 | self.Wrr.data.fill_(1)
486 | self.Wri.data.uniform_(-.9, +.9) # W will be positive-definite
487 | self.Wii.data.fill_(1)
488 |
489 | def _check_input_dim(self, xr, xi):
490 | assert (xr.shape == xi.shape)
491 | assert (xr.size(1) == self.num_features)
492 |
493 | def forward(self, inputs):
494 | # self._check_input_dim(xr, xi)
495 |
496 | xr, xi = torch.chunk(inputs, 2, axis=self.complex_axis)
497 | exponential_average_factor = 0.0
498 |
499 | if self.training and self.track_running_stats:
500 | self.num_batches_tracked += 1
501 | if self.momentum is None: # use cumulative moving average
502 | exponential_average_factor = 1.0 / self.num_batches_tracked.item()
503 | else: # use exponential moving average
504 | exponential_average_factor = self.momentum
505 |
506 | #
507 | # NOTE: The precise meaning of the "training flag" is:
508 | # True: Normalize using batch statistics, update running statistics
509 | # if they are being collected.
510 | # False: Normalize using running statistics, ignore batch statistics.
511 | #
512 | training = self.training or not self.track_running_stats
513 | redux = [i for i in reversed(range(xr.dim())) if i != 1]
514 | vdim = [1] * xr.dim()
515 | vdim[1] = xr.size(1)
516 |
517 | #
518 | # Mean M Computation and Centering
519 | #
520 | # Includes running mean update if training and running.
521 | #
522 | if training:
523 | Mr, Mi = xr, xi
524 | for d in redux:
525 | Mr = Mr.mean(d, keepdim=True)
526 | Mi = Mi.mean(d, keepdim=True)
527 | if self.track_running_stats:
528 | self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
529 | self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
530 | else:
531 | Mr = self.RMr.view(vdim)
532 | Mi = self.RMi.view(vdim)
533 | xr, xi = xr - Mr, xi - Mi
534 |
535 | #
536 | # Variance Matrix V Computation
537 | #
538 | # Includes epsilon numerical stabilizer/Tikhonov regularizer.
539 | # Includes running variance update if training and running.
540 | #
541 | if training:
542 | Vrr = xr * xr
543 | Vri = xr * xi
544 | Vii = xi * xi
545 | for d in redux:
546 | Vrr = Vrr.mean(d, keepdim=True)
547 | Vri = Vri.mean(d, keepdim=True)
548 | Vii = Vii.mean(d, keepdim=True)
549 | if self.track_running_stats:
550 | self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
551 | self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
552 | self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
553 | else:
554 | Vrr = self.RVrr.view(vdim)
555 | Vri = self.RVri.view(vdim)
556 | Vii = self.RVii.view(vdim)
557 | Vrr = Vrr + self.eps
558 | Vri = Vri
559 | Vii = Vii + self.eps
560 |
561 | #
562 | # Matrix Inverse Square Root U = V^-0.5
563 | #
564 | # sqrt of a 2x2 matrix,
565 | # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
566 | tau = Vrr + Vii
567 | delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)
568 | s = delta.sqrt()
569 | t = (tau + 2 * s).sqrt()
570 |
571 | # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
572 | rst = (s * t).reciprocal()
573 | Urr = (s + Vii) * rst
574 | Uii = (s + Vrr) * rst
575 | Uri = (- Vri) * rst
576 |
577 | #
578 | # Optionally left-multiply U by affine weights W to produce combined
579 | # weights Z, left-multiply the inputs by Z, then optionally bias them.
580 | #
581 | # y = Zx + B
582 | # y = WUx + B
583 | # y = [Wrr Wri][Urr Uri] [xr] + [Br]
584 | # [Wir Wii][Uir Uii] [xi] [Bi]
585 | #
586 | if self.affine:
587 | Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
588 | Zrr = (Wrr * Urr) + (Wri * Uri)
589 | Zri = (Wrr * Uri) + (Wri * Uii)
590 | Zir = (Wri * Urr) + (Wii * Uri)
591 | Zii = (Wri * Uri) + (Wii * Uii)
592 | else:
593 | Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
594 |
595 | yr = (Zrr * xr) + (Zri * xi)
596 | yi = (Zir * xr) + (Zii * xi)
597 |
598 | if self.affine:
599 | yr = yr + self.Br.view(vdim)
600 | yi = yi + self.Bi.view(vdim)
601 |
602 | outputs = torch.cat([yr, yi], self.complex_axis)
603 | return outputs
604 |
605 | def extra_repr(self):
606 | return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
607 | 'track_running_stats={track_running_stats}'.format(**self.__dict__)
608 |
609 |
610 | def complex_cat(inputs, axis):
611 | real, imag = [], []
612 | for idx, data in enumerate(inputs):
613 | r, i = torch.chunk(data, 2, axis)
614 | real.append(r)
615 | imag.append(i)
616 | real = torch.cat(real, axis)
617 | imag = torch.cat(imag, axis)
618 | outputs = torch.cat([real, imag], axis)
619 | return outputs
620 |
621 | ############################################################################
622 | # for FullSubNet #
623 | ############################################################################
624 | # Source: https://github.com/haoxiangsnr/FullSubNet
625 | # from https://github.com/haoxiangsnr/FullSubNet/blob/main/audio_zen/model/module/sequence_model.py
626 | # from https://github.com/haoxiangsnr/FullSubNet/blob/main/audio_zen/model/base_model.py
627 | # from https://github.com/haoxiangsnr/FullSubNet/blob/main/audio_zen/acoustics/feature.py
628 | def stft(y, n_fft=cfg.fft_len, hop_length=int(cfg.win_len*cfg.ola_ratio), win_length=cfg.win_len):
629 | """
630 | Args:
631 | y: [B, T], time-domain waveforms
632 | n_fft: num of FFT
633 | hop_length: hop length
634 | win_length: window length
635 |
636 | Returns:
637 | [B, F, T], **complex-valued** STFT coefficients
638 |
639 | """
640 | assert y.dim() == 2
641 | return torch.stft(
642 | y,
643 | n_fft,
644 | hop_length,
645 | win_length,
646 | window=torch.hann_window(win_length).to(y.device),
647 | return_complex=True
648 | )
649 |
650 |
651 | def istft(features, n_fft=cfg.fft_len, hop_length=int(cfg.win_len*cfg.ola_ratio), win_length=cfg.win_len, length=None, use_mag_phase=False):
652 | """
653 | Wrapper for the official torch.istft
654 |
655 | Args:
656 | features: [B, F, T, 2] (complex) or ([B, F, T], [B, F, T]) (mag and phase)
657 | n_fft:
658 | hop_length:
659 | win_length:
660 | device:
661 | length:
662 | use_mag_phase: use mag and phase as inputs of iSTFT
663 |
664 | Returns:
665 | [B, T]
666 | """
667 | if use_mag_phase:
668 | # (mag, phase) or [mag, phase]
669 | assert isinstance(features, tuple) or isinstance(features, list)
670 | mag, phase = features
671 | features = torch.stack([mag * torch.cos(phase), mag * torch.sin(phase)], dim=-1)
672 |
673 | return torch.istft(
674 | features,
675 | n_fft,
676 | hop_length,
677 | win_length,
678 | window=torch.hann_window(win_length).to(features.device),
679 | length=length
680 | )
681 |
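# Illustrative round-trip sketch using the wrappers above (lengths are assumptions):
# >>> y = torch.randn(4, 32000)                                   # [B, T] waveforms
# >>> spec = stft(y)                                              # complex-valued [B, F, T]
# >>> y_hat = istft(torch.view_as_real(spec), length=y.shape[-1])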
682 |
683 | def mag_phase(complex_tensor):
684 | return torch.abs(complex_tensor), torch.angle(complex_tensor)
685 |
686 |
687 | def build_complex_ideal_ratio_mask(noisy: torch.complex64, clean: torch.complex64) -> torch.Tensor:
688 | """
689 |
690 | Args:
691 | noisy: [B, F, T], noisy complex-valued stft coefficients
692 | clean: [B, F, T], clean complex-valued stft coefficients
693 |
694 | Returns:
695 | [B, F, T, 2]
696 | """
697 | denominator = torch.square(noisy.real) + torch.square(noisy.imag) + EPSILON
698 |
699 | mask_real = (noisy.real * clean.real + noisy.imag * clean.imag) / denominator
700 | mask_imag = (noisy.real * clean.imag - noisy.imag * clean.real) / denominator
701 |
702 | complex_ratio_mask = torch.stack((mask_real, mask_imag), dim=-1)
703 |
704 | return compress_cIRM(complex_ratio_mask, K=10, C=0.1)
705 |
706 |
707 | def compress_cIRM(mask, K=10, C=0.1):
708 | """
709 | Compress from (-inf, +inf) to [-K ~ K]
710 | """
711 | if torch.is_tensor(mask):
712 | mask = -100 * (mask <= -100) + mask * (mask > -100)
713 | mask = K * (1 - torch.exp(-C * mask)) / (1 + torch.exp(-C * mask))
714 | else:
715 | mask = -100 * (mask <= -100) + mask * (mask > -100)
716 | mask = K * (1 - np.exp(-C * mask)) / (1 + np.exp(-C * mask))
717 | return mask
718 |
719 |
720 | def decompress_cIRM(mask, K=10, limit=9.9):
721 | mask = limit * (mask >= limit) - limit * (mask <= -limit) + mask * (torch.abs(mask) < limit)
722 | mask = -K * torch.log((K - mask) / (K + mask))
723 | return mask
724 |
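# Consistency sketch (illustrative): decompress_cIRM inverts compress_cIRM inside (-K, K),
# so a network trained on compressed masks can be decoded before applying the mask.
# >>> m = torch.randn(2, 257, 100, 2)
# >>> torch.allclose(decompress_cIRM(compress_cIRM(m)), m, atol=1e-4)
# True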
725 |
726 | class SequenceModel(nn.Module):
727 | def __init__(
728 | self,
729 | input_size,
730 | output_size,
731 | hidden_size,
732 | num_layers,
733 | bidirectional,
734 | sequence_model="GRU",
735 | output_activate_function="Tanh"
736 | ):
737 | super().__init__()
738 | # Sequence layer
739 | if sequence_model == "LSTM":
740 | self.sequence_model = nn.LSTM(
741 | input_size=input_size,
742 | hidden_size=hidden_size,
743 | num_layers=num_layers,
744 | batch_first=True,
745 | bidirectional=bidirectional,
746 | dropout=0.8,
747 | )
748 | elif sequence_model == "GRU":
749 | self.sequence_model = nn.GRU(
750 | input_size=input_size,
751 | hidden_size=hidden_size,
752 | num_layers=num_layers,
753 | batch_first=True,
754 | bidirectional=bidirectional,
755 | dropout=0.8,
756 | )
757 | else:
758 | raise NotImplementedError(f"Not implemented {sequence_model}")
759 |
760 | # Fully connected layer
761 | if bidirectional:
762 | self.fc_output_layer = nn.Linear(hidden_size * 2, output_size)
763 | else:
764 | self.fc_output_layer = nn.Linear(hidden_size, output_size)
765 |
766 | # Activation function layer
767 | if output_activate_function:
768 | if output_activate_function == "Tanh":
769 | self.activate_function = nn.Tanh()
770 | elif output_activate_function == "ReLU":
771 | self.activate_function = nn.ReLU()
772 | elif output_activate_function == "ReLU6":
773 | self.activate_function = nn.ReLU6()
774 | else:
775 | raise NotImplementedError(f"Not implemented activation function {output_activate_function}")
776 |
777 | self.output_activate_function = output_activate_function
778 |
779 | def forward(self, x):
780 | """
781 | Args:
782 | x: [B, F, T]
783 | Returns:
784 | [B, F, T]
785 | """
786 | assert x.dim() == 3
787 | self.sequence_model.flatten_parameters()
788 |
789 | x = x.permute(0, 2, 1).contiguous() # [B, F, T] => [B, T, F]
790 | o, _ = self.sequence_model(x)
791 | o = self.fc_output_layer(o)
792 | if self.output_activate_function:
793 | o = self.activate_function(o)
794 | o = o.permute(0, 2, 1).contiguous() # [B, T, F] => [B, F, T]
795 | return o
796 |
797 |
798 | EPSILON = np.finfo(np.float32).eps
799 |
800 |
801 | class BaseModel(nn.Module):
802 | def __init__(self):
803 | super(BaseModel, self).__init__()
804 |
805 | @staticmethod
806 | def unfold(input, num_neighbor):
807 | """
808 | Along the frequency dim, split overlapping sub-band units from the spectrogram.
809 |
810 | Args:
811 | input: [B, C, F, T]
812 | num_neighbor:
813 |
814 | Returns:
815 | [B, N, C, F_s, T], F, e.g. [2, 161, 1, 19, 200]
816 | """
817 | assert input.dim() == 4, f"The dim of input is {input.dim()}. It should be four dim."
818 | batch_size, num_channels, num_freqs, num_frames = input.size()
819 |
820 | if num_neighbor < 1:
821 | # No change for the input
822 | return input.permute(0, 2, 1, 3).reshape(batch_size, num_freqs, num_channels, 1, num_frames)
823 |
824 | output = input.reshape(batch_size * num_channels, 1, num_freqs, num_frames)
825 | sub_band_unit_size = num_neighbor * 2 + 1
826 |
827 | # Pad to the top and bottom
828 | output = F.pad(output, [0, 0, num_neighbor, num_neighbor], mode="reflect")
829 |
830 | output = F.unfold(output, (sub_band_unit_size, num_frames))
831 | assert output.shape[-1] == num_freqs, f"n_freqs != N (sub_band), {num_freqs} != {output.shape[-1]}"
832 |
833 | # Split the dim of the unfolded feature
834 | output = output.reshape(batch_size, num_channels, sub_band_unit_size, num_frames, num_freqs)
835 | output = output.permute(0, 4, 1, 2, 3).contiguous()
836 |
837 | return output
838 |
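# Illustrative shape check (not in the original file), matching the docstring example:
# >>> x = torch.rand(2, 1, 161, 200)                    # [B, C, F, T]
# >>> BaseModel.unfold(x, num_neighbor=9).shape
# torch.Size([2, 161, 1, 19, 200])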
839 | @staticmethod
840 | def _reduce_complexity_separately(sub_band_input, full_band_output, device):
841 | """
842 |
843 | Args:
844 | sub_band_input: [60, 257, 1, 33, 200]
845 | full_band_output: [60, 257, 1, 3, 200]
846 | device:
847 |
848 | Notes:
849 | 1. 255 and 256 freq not able to be trained
850 | 2. batch size
851 |
852 | Returns:
853 | [60, 85, 1, 36, 200]
854 | """
855 | batch_size = full_band_output.shape[0]
856 | n_freqs = full_band_output.shape[1]
857 | sub_batch_size = batch_size // 3
858 | final_selected = []
859 |
860 | for idx in range(3):
861 | # [0, 60) => [0, 20)
862 | sub_batch_indices = torch.arange(idx * sub_batch_size, (idx + 1) * sub_batch_size, device=device)
863 | full_band_output_sub_batch = torch.index_select(full_band_output, dim=0, index=sub_batch_indices)
864 | sub_band_output_sub_batch = torch.index_select(sub_band_input, dim=0, index=sub_batch_indices)
865 |
866 | # Avoid using the padded values (first and last freq)
867 | # i = 0, (1, 256, 3) = [1, 4, ..., 253]
868 | # i = 1, (2, 256, 3) = [2, 5, ..., 254]
869 | # i = 2, (3, 256, 3) = [3, 6, ..., 255]
870 | freq_indices = torch.arange(idx + 1, n_freqs - 1, step=3, device=device)
871 | full_band_output_sub_batch = torch.index_select(full_band_output_sub_batch, dim=1, index=freq_indices)
872 | sub_band_output_sub_batch = torch.index_select(sub_band_output_sub_batch, dim=1, index=freq_indices)
873 |
874 | # ([30, 85, 1, 33 200], [30, 85, 1, 3, 200]) => [30, 85, 1, 36, 200]
875 |
876 | final_selected.append(torch.cat([sub_band_output_sub_batch, full_band_output_sub_batch], dim=-2))
877 |
878 | return torch.cat(final_selected, dim=0)
879 |
880 | @staticmethod
881 | def sband_forgetting_norm(input, train_sample_length):
882 | """
883 | Args:
884 | input:
885 | train_sample_length:
886 |
887 | Returns:
888 |
889 | """
890 | assert input.ndim == 3
891 | batch_size, n_freqs, n_frames = input.size()
892 |
893 | eps = 1e-10
894 | alpha = (train_sample_length - 1) / (train_sample_length + 1)
895 | mu = 0
896 | mu_list = []
897 |
898 | for idx in range(input.shape[-1]):
899 | if idx < train_sample_length:
900 | alp = torch.min(torch.tensor([(idx - 1) / (idx + 1), alpha]))
901 | mu = alp * mu + (1 - alp) * torch.mean(input[:, :, idx], dim=1).reshape(batch_size, 1) # [B, 1]
902 | else:
903 | mu = alpha * mu + (1 - alpha) * input[:, (n_freqs // 2 - 1), idx].reshape(batch_size, 1)
904 |
905 | mu_list.append(mu)
906 |
907 | # print("input", input[:, :, idx].min(), input[:, :, idx].max(), input[:, :, idx].mean())
908 | # print(f"alp {idx}: ", alp)
909 | # print(f"mu {idx}: {mu[128, 0]}")
910 |
911 | mu = torch.stack(mu_list, dim=-1) # [B, 1, T]
912 | input = input / (mu + eps)
913 | return input
914 |
915 | @staticmethod
916 | def forgetting_norm(input, sample_length_in_training):
917 | """
918 | Args:
919 | input: [B, F, T]
920 | sample_length_in_training:
921 |
922 | Returns:
923 |
924 | """
925 | assert input.ndim == 3
926 | batch_size, n_freqs, n_frames = input.size()
927 | eps = 1e-10
928 | mu = 0
929 | alpha = (sample_length_in_training - 1) / (sample_length_in_training + 1)
930 |
931 | mu_list = []
932 | for idx in range(input.shape[-1]):
933 | if idx < sample_length_in_training:
934 | alp = torch.min(torch.tensor([(idx - 1) / (idx + 1), alpha]))
935 | mu = alp * mu + (1 - alp) * torch.mean(input[:, :, idx], dim=1).reshape(batch_size, 1) # [B, 1]
936 | else:
937 | current_frame_mu = torch.mean(input[:, :, idx], dim=1).reshape(batch_size, 1) # [B, 1]
938 | mu = alpha * mu + (1 - alpha) * current_frame_mu
939 |
940 | mu_list.append(mu)
941 |
942 | # print("input", input[:, :, idx].min(), input[:, :, idx].max(), input[:, :, idx].mean())
943 | # print(f"alp {idx}: ", alp)
944 | # print(f"mu {idx}: {mu[128, 0]}")
945 |
946 | mu = torch.stack(mu_list, dim=-1) # [B, 1, T]
947 | input = input / (mu + eps)
948 | return input
949 |
950 | @staticmethod
951 | def hybrid_norm(input, sample_length_in_training=192):
952 | """
953 | Args:
954 | input: [B, F, T]
955 | sample_length_in_training:
956 |
957 | Returns:
958 | [B, F, T]
959 | """
960 | assert input.ndim == 3
961 | device = input.device
962 | data_type = input.dtype
963 | batch_size, n_freqs, n_frames = input.size()
964 | eps = 1e-10
965 |
966 | mu = 0
967 | alpha = (sample_length_in_training - 1) / (sample_length_in_training + 1)
968 | mu_list = []
969 | for idx in range(input.shape[-1]):
970 | if idx < sample_length_in_training:
971 | alp = torch.min(torch.tensor([(idx - 1) / (idx + 1), alpha]))
972 | mu = alp * mu + (1 - alp) * torch.mean(input[:, :, idx], dim=1).reshape(batch_size, 1) # [B, 1]
973 | mu_list.append(mu)
974 | else:
975 | break
976 | initial_mu = torch.stack(mu_list, dim=-1) # [B, 1, T]
977 |
978 | step_sum = torch.sum(input, dim=1) # [B, T]
979 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T]
980 |
981 | entry_count = torch.arange(n_freqs, n_freqs * n_frames + 1, n_freqs, dtype=data_type, device=device)
982 | entry_count = entry_count.reshape(1, n_frames) # [1, T]
983 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T]
984 |
985 | cum_mean = cumulative_sum / entry_count # B, T
986 |
987 | cum_mean = cum_mean.reshape(batch_size, 1, n_frames) # [B, 1, T]
988 |
989 | # print(initial_mu[0, 0, :50])
990 | # print("-"*60)
991 | # print(cum_mean[0, 0, :50])
992 | cum_mean[:, :, :sample_length_in_training] = initial_mu
993 |
994 | return input / (cum_mean + eps)
995 |
996 | @staticmethod
997 | def offline_laplace_norm(input):
998 | """
999 |
1000 | Args:
1001 | input: [B, C, F, T]
1002 |
1003 | Returns:
1004 | [B, C, F, T]
1005 | """
1006 | # utterance-level mu
1007 | mu = torch.mean(input, dim=(1, 2, 3), keepdim=True)
1008 |
1009 | normed = input / (mu + 1e-5)
1010 |
1011 | return normed
1012 |
1013 | @staticmethod
1014 | def cumulative_laplace_norm(input):
1015 |         """
1016 |         Online (causal) Laplace norm: each frame is divided by the cumulative mean of all frames seen so far.
1017 |         Args:
1018 |             input: [B, C, F, T]
1019 | 
1020 |         Returns:
1021 |             [B, C, F, T]
1022 |         """
1023 | batch_size, num_channels, num_freqs, num_frames = input.size()
1024 | input = input.reshape(batch_size * num_channels, num_freqs, num_frames)
1025 |
1026 |         step_sum = torch.sum(input, dim=1)  # [B * C, F, T] => [B * C, T]
1027 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T]
1028 |
1029 | entry_count = torch.arange(
1030 | num_freqs,
1031 | num_freqs * num_frames + 1,
1032 | num_freqs,
1033 | dtype=input.dtype,
1034 | device=input.device
1035 | )
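 |         # entry_count[t] = F * (t + 1): the number of T-F bins accumulated up to frame t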
1036 | entry_count = entry_count.reshape(1, num_frames) # [1, T]
1037 |         entry_count = entry_count.expand_as(cumulative_sum)  # [1, T] => [B * C, T]
1038 |
1039 |         cumulative_mean = cumulative_sum / entry_count  # [B * C, T]
1040 | cumulative_mean = cumulative_mean.reshape(batch_size * num_channels, 1, num_frames)
1041 |
1042 | normed = input / (cumulative_mean + EPSILON)
1043 |
1044 | return normed.reshape(batch_size, num_channels, num_freqs, num_frames)
1045 |
1046 | @staticmethod
1047 | def offline_gaussian_norm(input):
1048 | """
1049 | Zero-Norm
1050 | Args:
1051 | input: [B, C, F, T]
1052 |
1053 | Returns:
1054 | [B, C, F, T]
1055 | """
1056 | mu = torch.mean(input, dim=(1, 2, 3), keepdim=True)
1057 | std = torch.std(input, dim=(1, 2, 3), keepdim=True)
1058 |
1059 | normed = (input - mu) / (std + 1e-5)
1060 |
1061 | return normed
1062 |
1063 | @staticmethod
1064 | def cumulative_layer_norm(input):
1065 | """
1066 | Online zero-norm
1067 |
1068 | Args:
1069 | input: [B, C, F, T]
1070 |
1071 | Returns:
1072 | [B, C, F, T]
1073 | """
1074 | batch_size, num_channels, num_freqs, num_frames = input.size()
1075 | input = input.reshape(batch_size * num_channels, num_freqs, num_frames)
1076 |
1077 |         step_sum = torch.sum(input, dim=1)  # [B * C, F, T] => [B * C, T]
1078 | step_pow_sum = torch.sum(torch.square(input), dim=1)
1079 |
1080 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T]
1081 | cumulative_pow_sum = torch.cumsum(step_pow_sum, dim=-1) # [B, T]
1082 |
1083 | entry_count = torch.arange(
1084 | num_freqs,
1085 | num_freqs * num_frames + 1,
1086 | num_freqs,
1087 | dtype=input.dtype,
1088 | device=input.device
1089 | )
1090 | entry_count = entry_count.reshape(1, num_frames) # [1, T]
1091 |         entry_count = entry_count.expand_as(cumulative_sum)  # [1, T] => [B * C, T]
1092 |
1093 |         cumulative_mean = cumulative_sum / entry_count  # [B * C, T]
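 |         # Var[x] = E[x^2] - mu^2, written with the cumulative sums as (S2 - 2*mu*S1)/n + mu^2,
 |         # so the causal mean and variance come out of a single pass over the frames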
1094 | cumulative_var = (
1095 | cumulative_pow_sum - 2 * cumulative_mean * cumulative_sum) / entry_count + cumulative_mean.pow(
1096 | 2) # [B, T]
1097 | cumulative_std = torch.sqrt(cumulative_var + EPSILON) # [B, T]
1098 |
1099 | cumulative_mean = cumulative_mean.reshape(batch_size * num_channels, 1, num_frames)
1100 | cumulative_std = cumulative_std.reshape(batch_size * num_channels, 1, num_frames)
1101 |
1102 | normed = (input - cumulative_mean) / cumulative_std
1103 |
1104 | return normed.reshape(batch_size, num_channels, num_freqs, num_frames)
1105 |
1106 | def norm_wrapper(self, norm_type: str):
1107 | if norm_type == "offline_laplace_norm":
1108 | norm = self.offline_laplace_norm
1109 | elif norm_type == "cumulative_laplace_norm":
1110 | norm = self.cumulative_laplace_norm
1111 | elif norm_type == "offline_gaussian_norm":
1112 | norm = self.offline_gaussian_norm
1113 | elif norm_type == "cumulative_layer_norm":
1114 | norm = self.cumulative_layer_norm
1115 | else:
1116 | raise NotImplementedError("You must set up a type of Norm. "
1117 | "e.g. offline_laplace_norm, cumulative_laplace_norm, forgetting_norm, etc.")
1118 | return norm
1119 |
1120 | def weight_init(self, m):
1121 | """
1122 | Usage:
1123 | model = Model()
1124 | model.apply(weight_init)
1125 | """
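 |         # layer-wise defaults: (Xavier-)normal for conv/linear weights, N(1, 0.02) for batch-norm
 |         # scales with zero bias, orthogonal for recurrent weight matrices, normal for the rest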
1126 | if isinstance(m, nn.Conv1d):
1127 | init.normal_(m.weight.data)
1128 | if m.bias is not None:
1129 | init.normal_(m.bias.data)
1130 | elif isinstance(m, nn.Conv2d):
1131 | init.xavier_normal_(m.weight.data)
1132 | if m.bias is not None:
1133 | init.normal_(m.bias.data)
1134 | elif isinstance(m, nn.Conv3d):
1135 | init.xavier_normal_(m.weight.data)
1136 | if m.bias is not None:
1137 | init.normal_(m.bias.data)
1138 | elif isinstance(m, nn.ConvTranspose1d):
1139 | init.normal_(m.weight.data)
1140 | if m.bias is not None:
1141 | init.normal_(m.bias.data)
1142 | elif isinstance(m, nn.ConvTranspose2d):
1143 | init.xavier_normal_(m.weight.data)
1144 | if m.bias is not None:
1145 | init.normal_(m.bias.data)
1146 | elif isinstance(m, nn.ConvTranspose3d):
1147 | init.xavier_normal_(m.weight.data)
1148 | if m.bias is not None:
1149 | init.normal_(m.bias.data)
1150 | elif isinstance(m, nn.BatchNorm1d):
1151 | init.normal_(m.weight.data, mean=1, std=0.02)
1152 | init.constant_(m.bias.data, 0)
1153 | elif isinstance(m, nn.BatchNorm2d):
1154 | init.normal_(m.weight.data, mean=1, std=0.02)
1155 | init.constant_(m.bias.data, 0)
1156 | elif isinstance(m, nn.BatchNorm3d):
1157 | init.normal_(m.weight.data, mean=1, std=0.02)
1158 | init.constant_(m.bias.data, 0)
1159 | elif isinstance(m, nn.Linear):
1160 | init.xavier_normal_(m.weight.data)
1161 | init.normal_(m.bias.data)
1162 | elif isinstance(m, nn.LSTM):
1163 | for param in m.parameters():
1164 | if len(param.shape) >= 2:
1165 | init.orthogonal_(param.data)
1166 | else:
1167 | init.normal_(param.data)
1168 | elif isinstance(m, nn.LSTMCell):
1169 | for param in m.parameters():
1170 | if len(param.shape) >= 2:
1171 | init.orthogonal_(param.data)
1172 | else:
1173 | init.normal_(param.data)
1174 | elif isinstance(m, nn.GRU):
1175 | for param in m.parameters():
1176 | if len(param.shape) >= 2:
1177 | init.orthogonal_(param.data)
1178 | else:
1179 | init.normal_(param.data)
1180 | elif isinstance(m, nn.GRUCell):
1181 | for param in m.parameters():
1182 | if len(param.shape) >= 2:
1183 | init.orthogonal_(param.data)
1184 | else:
1185 | init.normal_(param.data)
1186 |
1187 |
1188 | ############################################################################
1189 | # for data normalization #
1190 | ############################################################################
1191 | # get mu and sig
1192 | def get_mu_sig(data):
1193 | """Compute mean and standard deviation vector of input data
1194 |
1195 | Returns:
1196 | mu: mean vector (#dim by one)
1197 | sig: standard deviation vector (#dim by one)
1198 | """
1199 | # Initialize array.
1200 | data_num = len(data)
1201 | mu_utt = []
1202 | tmp_utt = []
1203 | for n in range(data_num):
1204 | dim = len(data[n])
1205 | mu_utt_tmp = np.zeros(dim)
1206 | mu_utt.append(mu_utt_tmp)
1207 |
1208 | tmp_utt_tmp = np.zeros(dim)
1209 | tmp_utt.append(tmp_utt_tmp)
1210 |
1211 | # Get mean.
1212 | for n in range(data_num):
1213 | mu_utt[n] = np.mean(data[n], 0)
1214 | mu = mu_utt
1215 |
1216 | # Get standard deviation.
1217 | for n in range(data_num):
1218 | tmp_utt[n] = np.mean(np.square(data[n] - mu[n]), 0)
1219 | sig = np.sqrt(tmp_utt)
1220 |
1221 | # Assign unit variance.
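 |     # (a near-zero deviation is replaced by 1.0 so the later normalization never divides by ~0)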
1222 | for n in range(len(sig)):
1223 | if sig[n] < 1e-5:
1224 | sig[n] = 1.0
1225 | return np.float16(mu), np.float16(sig)
1226 |
1227 |
1228 | def get_statistics_inp(inp):
1229 | """Get statistical parameter of input data.
1230 |
1231 | Args:
1232 | inp: input data
1233 |
1234 | Returns:
1235 | mu_inp: mean vector of input data
1236 | sig_inp: standard deviation vector of input data
1237 | """
1238 |
1239 | mu_inp, sig_inp = get_mu_sig(inp)
1240 |
1241 | return mu_inp, sig_inp
1242 |
1243 |
1244 | ############################################################################
1245 | # for plotting the samples #
1246 | ############################################################################
1247 | def hann_window(win_samp):
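 |     # Hann-type window: w[n] = 0.5 - 0.5 * cos(2 * pi * n / (N + 1)), n = 1 ... N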
1248 | tmp = np.arange(1, win_samp + 1, 1.0, dtype=np.float64)
1249 | window = 0.5 - 0.5 * np.cos((2.0 * np.pi * tmp) / (win_samp + 1))
1250 | return np.float32(window)
1251 |
1252 |
1253 | def fig2np(fig):
1254 |     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated
1255 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
1256 | return data
1257 |
1258 |
1259 | def plot_spectrogram_to_numpy(input_wav, fs, n_fft, n_overlap, mode, clim, label):
1260 | # cuda to cpu
1261 | input_wav = input_wav.cpu().detach().numpy()
1262 |
1263 | fig, ax = plt.subplots(figsize=(12, 3))
1264 |
1265 | if mode == 'phase':
1266 | pxx, freq, t, cax = plt.specgram(input_wav, NFFT=int(n_fft), Fs=int(fs), noverlap=n_overlap,
1267 | cmap='jet',
1268 | mode=mode)
1269 | else:
1270 | pxx, freq, t, cax = plt.specgram(input_wav, NFFT=int(n_fft), Fs=int(fs), noverlap=n_overlap,
1271 | cmap='jet')
1272 |
1273 | plt.xlabel('Time (s)')
1274 | plt.ylabel('Frequency (Hz)')
1275 | plt.tight_layout()
1276 | plt.clim(clim)
1277 |
1278 | if label is None:
1279 | fig.colorbar(cax)
1280 | else:
1281 | fig.colorbar(cax, label=label)
1282 |
1283 | fig.canvas.draw()
1284 | data = fig2np(fig)
1285 | plt.close()
1286 | return data
1287 |
1288 |
1289 | def plot_mask_to_numpy(mask, fs, n_fft, n_overlap, clim1, clim2, cmap):
1290 | frame_num = mask.shape[0]
1291 | shift_length = n_overlap
1292 | frame_length = n_fft
1293 | signal_length = frame_num * shift_length + frame_length
1294 |
1295 | xt = np.arange(0, np.floor(10 * signal_length / fs) / 10, step=0.5) / (signal_length / fs) * frame_num + 1e-8
1296 | yt = (n_fft / 2) / (fs / 1000 / 2) * np.arange(0, (fs / 1000 / 2) + 1)
1297 |
1298 | fig, ax = plt.subplots(figsize=(12, 3))
1299 | im = ax.imshow(np.transpose(mask), aspect='auto', origin='lower', interpolation='none', cmap=cmap)
1300 |
1301 | plt.xlabel('Time (s)')
1302 | plt.ylabel('Frequency (kHz)')
1303 | plt.xticks(xt, np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5))
1304 | plt.yticks(yt, np.int16(np.linspace(0, int((fs / 1000) / 2), len(yt))))
1305 | plt.tight_layout()
1306 | plt.colorbar(im, ax=ax)
1307 | im.set_clim(clim1, clim2)
1308 |
1309 | fig.canvas.draw()
1310 | data = fig2np(fig)
1311 | plt.close()
1312 | return data
1313 |
1314 |
1315 | def plot_error_to_numpy(estimated, target, fs, n_fft, n_overlap, mode, clim1, clim2, label):
1316 | fig, ax = plt.subplots(figsize=(12, 3))
1317 | if mode is None:
1318 | pxx1, freq, t, cax = plt.specgram(estimated, NFFT=n_fft, Fs=int(fs), noverlap=n_overlap, cmap='jet')
1319 | pxx2, freq, t, cax = plt.specgram(target, NFFT=n_fft, Fs=int(fs), noverlap=n_overlap, cmap='jet')
1320 | im = ax.imshow(10 * np.log10(pxx1) - 10 * np.log10(pxx2), aspect='auto', origin='lower', interpolation='none',
1321 | cmap='jet')
1322 | else:
1323 | pxx1, freq, t, cax = plt.specgram(estimated, NFFT=n_fft, Fs=int(fs), noverlap=n_overlap, cmap='jet',
1324 | mode=mode)
1325 | pxx2, freq, t, cax = plt.specgram(target, NFFT=n_fft, Fs=int(fs), noverlap=n_overlap, cmap='jet',
1326 | mode=mode)
1327 | im = ax.imshow(pxx1 - pxx2, aspect='auto', origin='lower', interpolation='none', cmap='jet')
1328 |
1329 | frame_num = pxx1.shape[1]
1330 | shift_length = n_overlap
1331 | frame_length = n_fft
1332 | signal_length = frame_num * shift_length + frame_length
1333 |
1334 | xt = np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5) / (signal_length / fs) * frame_num
1335 | yt = (n_fft / 2) / (fs / 1000 / 2) * np.arange(0, (fs / 1000 / 2) + 1)
1336 |
1337 | plt.xlabel('Time (s)')
1338 | plt.ylabel('Frequency (kHz)')
1339 | plt.xticks(xt, np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5))
1340 | plt.yticks(yt, np.int16(np.linspace(0, int((fs / 1000) / 2), len(yt))))
1341 | plt.tight_layout()
1342 | plt.colorbar(im, ax=ax, label=label)
1343 | im.set_clim(clim1, clim2)
1344 |
1345 | fig.canvas.draw()
1346 | data = fig2np(fig)
1347 | plt.close()
1348 | return data
1349 |
1350 |
1351 | ############################################################################
1352 | # for trainer.py #
1353 | ############################################################################
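 | # Minimal console progress bar: wraps a DataLoader and prints progress plus a rough per-batch ETA.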
1354 | class Bar(object):
1355 | def __init__(self, dataloader):
1356 | if not hasattr(dataloader, 'dataset'):
1357 |             raise ValueError('Attribute `dataset` does not exist in the dataloader.')
1358 | if not hasattr(dataloader, 'batch_size'):
1359 |             raise ValueError('Attribute `batch_size` does not exist in the dataloader.')
1360 |
1361 | self.dataloader = dataloader
1362 | self.iterator = iter(dataloader)
1363 | self.dataset = dataloader.dataset
1364 | self.batch_size = dataloader.batch_size
1365 | self._idx = 0
1366 | self._batch_idx = 0
1367 | self._time = []
1368 | self._DISPLAY_LENGTH = 50
1369 |
1370 | def __len__(self):
1371 | return len(self.dataloader)
1372 |
1373 | def __iter__(self):
1374 | return self
1375 |
1376 | def __next__(self):
1377 | if len(self._time) < 2:
1378 | self._time.append(time.time())
1379 |
1380 | self._batch_idx += self.batch_size
1381 | if self._batch_idx > len(self.dataset):
1382 | self._batch_idx = len(self.dataset)
1383 |
1384 | try:
1385 | batch = next(self.iterator)
1386 | self._display()
1387 | except StopIteration:
1388 | raise StopIteration()
1389 |
1390 | self._idx += 1
1391 | if self._idx >= len(self.dataloader):
1392 | self._reset()
1393 |
1394 | return batch
1395 |
1396 | def _display(self):
1397 | if len(self._time) > 1:
1398 | t = (self._time[-1] - self._time[-2])
1399 | eta = t * (len(self.dataloader) - self._idx)
1400 | else:
1401 | eta = 0
1402 |
1403 | rate = self._idx / len(self.dataloader)
1404 | len_bar = int(rate * self._DISPLAY_LENGTH)
1405 | bar = ('=' * len_bar + '>').ljust(self._DISPLAY_LENGTH, '.')
1406 | idx = str(self._batch_idx).rjust(len(str(len(self.dataset))), ' ')
1407 |
1408 | tmpl = '\r{}/{}: [{}] - ETA {:.1f}s'.format(
1409 | idx,
1410 | len(self.dataset),
1411 | bar,
1412 | eta
1413 | )
1414 | print(tmpl, end='')
1415 | if self._batch_idx == len(self.dataset):
1416 | print()
1417 |
1418 | def _reset(self):
1419 | self._idx = 0
1420 | self._batch_idx = 0
1421 | self._time = []
1422 |
--------------------------------------------------------------------------------
/train_interface.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import shutil
5 | import numpy as np
6 | import config as cfg
7 | from models import DCCRN, CRN, FullSubNet # you can import 'DCCRN' or 'CRN' or 'FullSubNet'
8 | from write_on_tensorboard import Writer
9 | from dataloader import create_dataloader
10 | from trainer import model_train, model_validate, \
11 | model_perceptual_train, model_perceptual_validate, \
12 | dccrn_direct_train, dccrn_direct_validate, \
13 | crn_direct_train, crn_direct_validate, \
14 | fullsubnet_train, fullsubnet_validate
15 |
16 |
17 | ###############################################################################
18 | # Helper function definition #
19 | ###############################################################################
20 | # Write training related parameters into the log file.
21 | def write_status_to_log_file(fp, total_parameters):
22 | fp.write('%d-%d-%d %d:%d:%d\n' %
23 | (time.localtime().tm_year, time.localtime().tm_mon,
24 | time.localtime().tm_mday, time.localtime().tm_hour,
25 | time.localtime().tm_min, time.localtime().tm_sec))
26 | fp.write('total params : %d (%.2f M, %.2f MBytes)\n' %
27 | (total_parameters,
28 | total_parameters / 1000000.0,
29 | total_parameters * 4.0 / 1000000.0))
30 |
31 |
32 | # Calculate the total number of parameters in the network.
33 | def calculate_total_params(our_model):
34 | total_parameters = 0
35 | for variable in our_model.parameters():
36 | shape = variable.size()
37 | variable_parameters = 1
38 | for dim in shape:
39 | variable_parameters *= dim
40 | total_parameters += variable_parameters
41 |
42 | return total_parameters
43 |
44 |
45 | ###############################################################################
46 | # Parameter Initialization and Setting for model training #
47 | ###############################################################################
48 | # Set device
49 | DEVICE = torch.device(cfg.DEVICE)
50 |
51 | # Set model
52 | if cfg.model == 'DCCRN':
53 | model = DCCRN().to(DEVICE)
54 | elif cfg.model == 'CRN':
55 | model = CRN().to(DEVICE)
56 | elif cfg.model == 'FullSubNet':
57 | model = FullSubNet().to(DEVICE)
58 | # Set optimizer and learning rate
59 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
60 | total_params = calculate_total_params(model)
61 |
62 | # Set trainer and estimator
63 | if cfg.perceptual is not False:
64 | trainer = model_perceptual_train
65 | estimator = model_perceptual_validate
66 | elif cfg.model == 'FullSubNet':
67 | trainer = fullsubnet_train
68 | estimator = fullsubnet_validate
69 | elif cfg.masking_mode == 'Direct(None make)' and cfg.model == 'DCCRN':
70 | trainer = dccrn_direct_train
71 | estimator = dccrn_direct_validate
72 | elif cfg.masking_mode == 'Direct(None make)' and cfg.model == 'CRN':
73 | trainer = crn_direct_train
74 | estimator = crn_direct_validate
75 | else:
76 | trainer = model_train
77 | estimator = model_validate
78 |
79 | ###############################################################################
80 | # Confirm model information #
81 | ###############################################################################
82 | print('%d-%d-%d %d:%d:%d\n' %
83 | (time.localtime().tm_year, time.localtime().tm_mon,
84 | time.localtime().tm_mday, time.localtime().tm_hour,
85 | time.localtime().tm_min, time.localtime().tm_sec))
86 | print('total params : %d (%.2f M, %.2f MBytes)\n' %
87 | (total_params,
88 | total_params / 1000000.0,
89 | total_params * 4.0 / 1000000.0))
90 |
91 | ###############################################################################
92 | # Create Dataloader #
93 | ###############################################################################
94 | train_loader = create_dataloader(mode='train')
95 | validation_loader = create_dataloader(mode='valid')
96 |
97 | ###############################################################################
98 | # Set a log file to store progress. #
99 | # Set a hps file to store hyper-parameters information. #
100 | ###############################################################################
101 | if cfg.chkpt_model is not None: # Load the checkpoint
102 | print('Resuming from checkpoint: %s' % cfg.chkpt_path)
103 |
104 | # Set a log file to store progress.
105 | dir_to_save = cfg.job_dir + cfg.chkpt_model
106 | dir_to_logs = cfg.logs_dir + cfg.chkpt_model
107 |
108 | checkpoint = torch.load(cfg.chkpt_path)
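 |     # the checkpoint holds the model/optimizer state dicts and the index of the last finished epoch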
109 | model.load_state_dict(checkpoint['model'])
110 | optimizer.load_state_dict(checkpoint['optimizer'])
111 | epoch_start_idx = checkpoint['epoch'] + 1
112 | mse_vali_total = np.load(str(dir_to_save + '/mse_vali_total.npy'))
113 |     # if the loaded array is shorter than expected, extend it with zeros
114 | if len(mse_vali_total) < cfg.max_epochs:
115 | plus = cfg.max_epochs - len(mse_vali_total)
116 | mse_vali_total = np.concatenate((mse_vali_total, np.zeros(plus)), 0)
117 | else: # First learning
118 | print('Starting new training run...')
119 |
120 | # make the file directory to save the models
121 | if not os.path.exists(cfg.job_dir):
122 | os.mkdir(cfg.job_dir)
123 | if not os.path.exists(cfg.logs_dir):
124 | os.mkdir(cfg.logs_dir)
125 |
126 | epoch_start_idx = 1
127 | mse_vali_total = np.zeros(cfg.max_epochs)
128 |
129 | # Set a log file to store progress.
130 | dir_to_save = cfg.job_dir + cfg.expr_num + '_%d.%d' % (time.localtime().tm_mon, time.localtime().tm_mday) + \
131 | '_%s' % cfg.model + '_%s' % cfg.loss
132 | dir_to_logs = cfg.logs_dir + cfg.expr_num + '_%d.%d' % (time.localtime().tm_mon, time.localtime().tm_mday) \
133 | + '_%s' % cfg.model + '_%s' % cfg.loss
134 |
135 | # make the file directory
136 | if not os.path.exists(dir_to_save):
137 | os.mkdir(dir_to_save)
138 | os.mkdir(dir_to_logs)
139 |
140 | # logging
141 | log_fname = str(dir_to_save + '/log.txt')
142 | if not os.path.exists(log_fname):
143 | fp = open(log_fname, 'w')
144 | write_status_to_log_file(fp, total_params)
145 | else:
146 | fp = open(log_fname, 'a')
147 |
148 | ###############################################################################
149 | ###############################################################################
150 | # Main program start !! #
151 | ###############################################################################
152 | ###############################################################################
153 | # Writer initialize
154 | writer = Writer(dir_to_logs)
155 |
156 | ###############################################################################
157 | # Train #
158 | ###############################################################################
159 | if cfg.perceptual is not False: # train with perceptual loss function
160 | for epoch in range(epoch_start_idx, cfg.max_epochs + 1):
161 | start_time = time.time()
162 | # Training
163 | train_loss, train_main_loss, train_perceptual_loss = trainer(model, optimizer, train_loader, DEVICE)
164 |
165 | # save checkpoint file to resume training
166 | save_path = str(dir_to_save + '/' + ('chkpt_%d.pt' % epoch))
167 | torch.save({
168 | 'model': model.state_dict(),
169 | 'optimizer': optimizer.state_dict(),
170 | 'epoch': epoch
171 | }, save_path)
172 |
173 | # Validation
174 | vali_loss, validation_main_loss, validation_perceptual_loss, vali_pesq, vali_stoi = \
175 | estimator(model, validation_loader, writer, dir_to_save, epoch, DEVICE)
176 | # write the loss on tensorboard
177 | writer.log_loss(train_loss, vali_loss, epoch)
178 | writer.log_score(vali_pesq, vali_stoi, epoch)
179 | writer.log_sub_loss(train_main_loss, train_perceptual_loss,
180 | validation_main_loss, validation_perceptual_loss, epoch)
181 |
182 | print('Epoch [{}] | T {:.6f} | V {:.6} '
183 | .format(epoch, train_loss, vali_loss))
184 |         print('          | T {:.6f} {:.6f} | V {:.6f} {:.6f} takes {:.2f} seconds\n'
185 |               .format(train_main_loss, train_perceptual_loss, validation_main_loss, validation_perceptual_loss,
186 |                       time.time() - start_time))
187 | print(' | V PESQ: {:.6f} | STOI: {:.6f} '.format(vali_pesq, vali_stoi))
188 | # log file save
189 | fp.write('Epoch [{}] | T {:.6f} | V {:.6}\n'
190 | .format(epoch, train_loss, vali_loss))
191 |         fp.write('          | T {:.6f} {:.6f} | V {:.6f} {:.6f} takes {:.2f} seconds\n'
192 |                  .format(train_main_loss, train_perceptual_loss,
193 |                          validation_main_loss, validation_perceptual_loss, time.time() - start_time))
194 | fp.write(' | V PESQ: {:.6f} | STOI: {:.6f} \n'.format(vali_pesq, vali_stoi))
195 |
196 | mse_vali_total[epoch - 1] = vali_loss
197 | np.save(str(dir_to_save + '/mse_vali_total.npy'), mse_vali_total)
198 | else:
199 | for epoch in range(epoch_start_idx, cfg.max_epochs + 1):
200 | start_time = time.time()
201 | # Training
202 | train_loss = trainer(model, optimizer, train_loader, DEVICE)
203 |
204 | # save checkpoint file to resume training
205 | save_path = str(dir_to_save + '/' + ('chkpt_%d.pt' % epoch))
206 | torch.save({
207 | 'model': model.state_dict(),
208 | 'optimizer': optimizer.state_dict(),
209 | 'epoch': epoch
210 | }, save_path)
211 |
212 | # Validation
213 | vali_loss, vali_pesq, vali_stoi = \
214 | estimator(model, validation_loader, writer, dir_to_save, epoch, DEVICE)
215 | # write the loss on tensorboard
216 | writer.log_loss(train_loss, vali_loss, epoch)
217 | writer.log_score(vali_pesq, vali_stoi, epoch)
218 |
219 | print('Epoch [{}] | T {:.6f} | V {:.6} takes {:.2f} seconds\n'
220 | .format(epoch, train_loss, vali_loss, time.time() - start_time))
221 | print(' | V PESQ: {:.6f} | STOI: {:.6f} '.format(vali_pesq, vali_stoi))
222 | # log file save
223 | fp.write('Epoch [{}] | T {:.6f} | V {:.6} takes {:.2f} seconds\n'
224 | .format(epoch, train_loss, vali_loss, time.time() - start_time))
225 | fp.write(' | V PESQ: {:.6f} | STOI: {:.6f} \n'.format(vali_pesq, vali_stoi))
226 |
227 | mse_vali_total[epoch - 1] = vali_loss
228 | np.save(str(dir_to_save + '/mse_vali_total.npy'), mse_vali_total)
229 |
230 | fp.close()
231 | print('Training has been finished.')
232 |
233 | # Copy optimum model that has minimum MSE.
234 | print('Save optimum models...')
235 | min_index = np.argmin(mse_vali_total)
236 | print('Minimum validation loss is at epoch ' + str(min_index + 1) + '.')
237 | src_file = str(dir_to_save + '/' + ('chkpt_%d.pt' % (min_index + 1)))
238 | tgt_file = str(dir_to_save + '/chkpt_opt.pt')
239 | shutil.copy(src_file, tgt_file)
240 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Where the model is actually trained and validated
3 | """
4 |
5 | import torch
6 | import numpy as np
7 | import tools_for_model as tools
8 | from tools_for_estimate import cal_pesq, cal_stoi
9 |
10 |
11 | #######################################################################
12 | # For train #
13 | #######################################################################
14 | # T-F masking
15 | def model_train(model, optimizer, train_loader, DEVICE):
16 | # initialization
17 | train_loss = 0
18 | batch_num = 0
19 |
20 | # arr = []
21 | # train
22 | model.train()
23 | for inputs, targets in tools.Bar(train_loader):
24 | batch_num += 1
25 |
26 | # to cuda
27 | inputs = inputs.float().to(DEVICE)
28 | targets = targets.float().to(DEVICE)
29 |
30 | _, _, outputs = model(inputs, targets)
31 | loss = model.loss(outputs, targets)
32 | # # if you want to check the scale of the loss
33 | # print('loss: {:.4}'.format(loss))
34 |
35 | optimizer.zero_grad()
36 | loss.backward()
37 | optimizer.step()
38 |
39 | train_loss += loss
40 | train_loss /= batch_num
41 |
42 | return train_loss
43 |
44 |
45 | def model_perceptual_train(model, optimizer, train_loader, DEVICE):
46 | # initialization
47 | train_loss = 0
48 | train_main_loss = 0
49 | train_perceptual_loss = 0
50 | batch_num = 0
51 |
52 | # train
53 | model.train()
54 | for inputs, targets in tools.Bar(train_loader):
55 | batch_num += 1
56 |
57 | # to cuda
58 | inputs = inputs.float().to(DEVICE)
59 | targets = targets.float().to(DEVICE)
60 |
61 | real_spec, img_spec, outputs = model(inputs)
62 | main_loss = model.loss(outputs, targets)
63 | perceptual_loss = model.loss(outputs, targets, real_spec, img_spec, perceptual=True)
64 |
65 |         # weighting between the main loss and the perceptual loss (normalized by r3 = r1 + r2)
66 | r1 = 1
67 | r2 = 1
68 | r3 = r1 + r2
69 | loss = (r1 * main_loss + r2 * perceptual_loss) / r3
70 |
71 | optimizer.zero_grad()
72 | loss.backward()
73 | optimizer.step()
74 |
75 | train_loss += loss
76 | train_main_loss += r1 * main_loss
77 | train_perceptual_loss += r2 * perceptual_loss
78 | train_loss /= batch_num
79 | train_main_loss /= batch_num
80 | train_perceptual_loss /= batch_num
81 |
82 | return train_loss, train_main_loss, train_perceptual_loss
83 |
84 |
85 | def fullsubnet_train(model, optimizer, train_loader, DEVICE):
86 | # initialization
87 | train_loss = 0
88 | batch_num = 0
89 |
90 | # arr = []
91 | # train
92 | model.train()
93 | for inputs, targets in tools.Bar(train_loader):
94 | batch_num += 1
95 |
96 | # to cuda
97 | inputs = inputs.float().to(DEVICE)
98 | targets = targets.float().to(DEVICE)
99 |
100 | noisy_complex = tools.stft(inputs)
101 | clean_complex = tools.stft(targets)
102 |
103 | noisy_mag, _ = tools.mag_phase(noisy_complex)
104 | cIRM = tools.build_complex_ideal_ratio_mask(noisy_complex, clean_complex)
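 |         # the model predicts a complex ratio mask (cRM) from the noisy magnitude only; the loss
 |         # compares it with the ideal mask (cIRM) built from the noisy and clean spectra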
105 |
106 | cRM = model(noisy_mag)
107 | loss = model.loss(cIRM, cRM)
108 | # # if you want to check the scale of the loss
109 | # print('loss: {:.4}'.format(loss))
110 |
111 | optimizer.zero_grad()
112 | loss.backward()
113 | optimizer.step()
114 |
115 | train_loss += loss
116 | train_loss /= batch_num
117 |
118 | return train_loss
119 |
120 |
121 | # Spectral mapping
122 | def dccrn_direct_train(model, optimizer, train_loader, DEVICE):
123 | # initialization
124 | train_loss = 0
125 | batch_num = 0
126 |
127 | # train
128 | model.train()
129 | for inputs, targets in tools.Bar(train_loader):
130 | batch_num += 1
131 |
132 | # to cuda
133 | inputs = inputs.float().to(DEVICE)
134 | targets = targets.float().to(DEVICE)
135 |
136 | output_real, target_real, output_imag, target_imag, _ = model(inputs, targets)
137 | real_loss = model.loss(output_real, target_real)
138 | imag_loss = model.loss(output_imag, target_imag)
139 | loss = (real_loss + imag_loss) / 2
140 |
141 | # # if you want to check the scale of the loss
142 | # print('loss: {:.4}'.format(loss))
143 |
144 | optimizer.zero_grad()
145 | loss.backward()
146 | optimizer.step()
147 |
148 | train_loss += loss
149 | train_loss /= batch_num
150 |
151 | return train_loss
152 |
153 |
154 | def crn_direct_train(model, optimizer, train_loader, DEVICE):
155 | # initialization
156 | train_loss = 0
157 | batch_num = 0
158 |
159 | # train
160 | model.train()
161 | for inputs, targets in tools.Bar(train_loader):
162 | batch_num += 1
163 |
164 | # to cuda
165 | inputs = inputs.float().to(DEVICE)
166 | targets = targets.float().to(DEVICE)
167 |
168 | output_mag, target_mag, _ = model(inputs, targets)
169 | loss = model.loss(output_mag, target_mag)
170 |
171 | # # if you want to check the scale of the loss
172 | # print('loss: {:.4}'.format(loss))
173 |
174 | optimizer.zero_grad()
175 | loss.backward()
176 | optimizer.step()
177 |
178 | train_loss += loss
179 | train_loss /= batch_num
180 |
181 | return train_loss
182 |
183 |
184 | #######################################################################
185 | # For validation #
186 | #######################################################################
187 | # T-F masking
188 | def model_validate(model, validation_loader, writer, dir_to_save, epoch, DEVICE):
189 | # initialization
190 | validation_loss = 0
191 | batch_num = 0
192 |
193 | avg_pesq = 0
194 | avg_stoi = 0
195 |
196 |     # to record the score of each sample
197 | f_score = open(dir_to_save + '/Epoch_' + '%d_SCORES' % epoch, 'a')
198 |
199 | model.eval()
200 | with torch.no_grad():
201 | for inputs, targets in tools.Bar(validation_loader):
202 | batch_num += 1
203 |
204 | # to cuda
205 | inputs = inputs.float().to(DEVICE)
206 | targets = targets.float().to(DEVICE)
207 |
208 | _, _, outputs = model(inputs, targets)
209 | loss = model.loss(outputs, targets)
210 |
211 | validation_loss += loss
212 |
213 | # estimate the output speech with pesq and stoi
214 | estimated_wavs = outputs.cpu().detach().numpy()
215 | clean_wavs = targets.cpu().detach().numpy()
216 |
217 | pesq = cal_pesq(estimated_wavs, clean_wavs)
218 | stoi = cal_stoi(estimated_wavs, clean_wavs)
219 |
220 | # pesq: 0.1 better / stoi: 0.01 better
221 | for i in range(len(pesq)):
222 | f_score.write('PESQ {:.6f} | STOI {:.6f}\n'.format(pesq[i], stoi[i]))
223 |
224 | # reshape for sum
225 | pesq = np.reshape(pesq, (1, -1))
226 | stoi = np.reshape(stoi, (1, -1))
227 |
228 | avg_pesq += sum(pesq[0]) / len(inputs)
229 | avg_stoi += sum(stoi[0]) / len(inputs)
230 |
231 | # save the samples to tensorboard
232 | if epoch % 10 == 0:
233 | writer.log_wav(inputs[0], targets[0], outputs[0], epoch)
234 |
235 | validation_loss /= batch_num
236 | avg_pesq /= batch_num
237 | avg_stoi /= batch_num
238 |
239 | return validation_loss, avg_pesq, avg_stoi
240 |
241 |
242 | def model_perceptual_validate(model, validation_loader, writer, dir_to_save, epoch, DEVICE):
243 | # initialization
244 | validation_loss = 0
245 | validation_main_loss = 0
246 | validation_perceptual_loss = 0
247 | batch_num = 0
248 |
249 | avg_pesq = 0
250 | avg_stoi = 0
251 |
252 |     # to record the score of each sample
253 | f_score = open(dir_to_save + '/Epoch_' + '%d_SCORES' % epoch, 'a')
254 |
255 | model.eval()
256 | with torch.no_grad():
257 | for inputs, targets in tools.Bar(validation_loader):
258 | batch_num += 1
259 |
260 | # to cuda
261 | inputs = inputs.float().to(DEVICE)
262 | targets = targets.float().to(DEVICE)
263 |
264 | real_spec, img_spec, outputs = model(inputs)
265 | main_loss = model.loss(outputs, targets)
266 | perceptual_loss = model.loss(outputs, targets, real_spec, img_spec, perceptual=True)
267 |
268 |                 # weighting between the main loss and the perceptual loss (normalized by r3 = r1 + r2)
269 | r1 = 1
270 | r2 = 1
271 | r3 = r1 + r2
272 | loss = (r1 * main_loss + r2 * perceptual_loss) / r3
273 |
274 | validation_loss += loss
275 | validation_main_loss += r1 * main_loss
276 | validation_perceptual_loss += r2 * perceptual_loss
277 |
278 | # estimate the output speech with pesq and stoi
279 | estimated_wavs = outputs.cpu().detach().numpy()
280 | clean_wavs = targets.cpu().detach().numpy()
281 |
282 | pesq = cal_pesq(estimated_wavs, clean_wavs)
283 | stoi = cal_stoi(estimated_wavs, clean_wavs)
284 |
285 | # pesq: 0.1 better / stoi: 0.01 better
286 | for i in range(len(pesq)):
287 | f_score.write('PESQ {:.6f} | STOI {:.6f}\n'.format(pesq[i], stoi[i]))
288 |
289 | # reshape for sum
290 | pesq = np.reshape(pesq, (1, -1))
291 | stoi = np.reshape(stoi, (1, -1))
292 |
293 | avg_pesq += sum(pesq[0]) / len(inputs)
294 | avg_stoi += sum(stoi[0]) / len(inputs)
295 |
296 | # save the samples to tensorboard
297 | if epoch % 10 == 0:
298 | writer.log_wav(inputs[0], targets[0], outputs[0], epoch)
299 |
300 | validation_loss /= batch_num
301 | validation_main_loss /= batch_num
302 | validation_perceptual_loss /= batch_num
303 | avg_pesq /= batch_num
304 | avg_stoi /= batch_num
305 |
306 | return validation_loss, validation_main_loss, validation_perceptual_loss, avg_pesq, avg_stoi
307 |
308 |
309 | def fullsubnet_validate(model, validation_loader, writer, dir_to_save, epoch, DEVICE):
310 | # initialization
311 | validation_loss = 0
312 | batch_num = 0
313 |
314 | avg_pesq = 0
315 | avg_stoi = 0
316 |
317 |     # to record the score of each sample
318 | f_score = open(dir_to_save + '/Epoch_' + '%d_SCORES' % epoch, 'a')
319 |
320 | model.eval()
321 | with torch.no_grad():
322 | for inputs, targets in tools.Bar(validation_loader):
323 | batch_num += 1
324 |
325 | # to cuda
326 | inputs = inputs.float().to(DEVICE)
327 | targets = targets.float().to(DEVICE)
328 |
329 | noisy_complex = tools.stft(inputs)
330 | clean_complex = tools.stft(targets)
331 |
332 | noisy_mag, _ = tools.mag_phase(noisy_complex)
333 | cIRM = tools.build_complex_ideal_ratio_mask(noisy_complex, clean_complex)
334 |
335 | cRM = model(noisy_mag)
336 | loss = model.loss(cIRM, cRM)
337 |
338 | validation_loss += loss
339 |
340 | # estimate the output speech with pesq and stoi
341 | cRM = tools.decompress_cIRM(cRM)
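 |                 # apply the mask by complex multiplication with the noisy spectrum:
 |                 # (Mr + j*Mi) * (Yr + j*Yi) = (Mr*Yr - Mi*Yi) + j*(Mi*Yr + Mr*Yi)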
342 | enhanced_real = cRM[..., 0] * noisy_complex.real - cRM[..., 1] * noisy_complex.imag
343 | enhanced_imag = cRM[..., 1] * noisy_complex.real + cRM[..., 0] * noisy_complex.imag
344 | enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1)
345 | enhanced_outputs = tools.istft(enhanced_complex, length=inputs.size(-1))
346 |
347 | estimated_wavs = enhanced_outputs.cpu().detach().numpy()
348 | clean_wavs = targets.cpu().detach().numpy()
349 |
350 | pesq = cal_pesq(estimated_wavs, clean_wavs)
351 | stoi = cal_stoi(estimated_wavs, clean_wavs)
352 |
353 | # pesq: 0.1 better / stoi: 0.01 better
354 | for i in range(len(pesq)):
355 | f_score.write('PESQ {:.6f} | STOI {:.6f}\n'.format(pesq[i], stoi[i]))
356 |
357 | # reshape for sum
358 | pesq = np.reshape(pesq, (1, -1))
359 | stoi = np.reshape(stoi, (1, -1))
360 |
361 | avg_pesq += sum(pesq[0]) / len(inputs)
362 | avg_stoi += sum(stoi[0]) / len(inputs)
363 |
364 | # save the samples to tensorboard
365 | if epoch % 10 == 0:
366 | writer.log_wav(inputs[0], targets[0], enhanced_outputs[0], epoch)
367 |
368 | validation_loss /= batch_num
369 | avg_pesq /= batch_num
370 | avg_stoi /= batch_num
371 |
372 | return validation_loss, avg_pesq, avg_stoi
373 |
374 |
375 | # Spectral mapping
376 | def dccrn_direct_validate(model, validation_loader, writer, dir_to_save, epoch, DEVICE):
377 | # initialization
378 | validation_loss = 0
379 | batch_num = 0
380 |
381 | avg_pesq = 0
382 | avg_stoi = 0
383 |
384 |     # to record the score of each sample
385 | f_score = open(dir_to_save + '/Epoch_' + '%d_SCORES' % epoch, 'a')
386 |
387 | model.eval()
388 | with torch.no_grad():
389 | for inputs, targets in tools.Bar(validation_loader):
390 | batch_num += 1
391 |
392 | # to cuda
393 | inputs = inputs.float().to(DEVICE)
394 | targets = targets.float().to(DEVICE)
395 |
396 | output_real, target_real, output_imag, target_imag, outputs = model(inputs, targets)
397 | real_loss = model.loss(output_real, target_real)
398 | imag_loss = model.loss(output_imag, target_imag)
399 | loss = (real_loss + imag_loss) / 2
400 |
401 | validation_loss += loss
402 |
403 | # estimate the output speech with pesq and stoi
404 | estimated_wavs = outputs.cpu().detach().numpy()
405 | clean_wavs = targets.cpu().detach().numpy()
406 |
407 | pesq = cal_pesq(estimated_wavs, clean_wavs)
408 | stoi = cal_stoi(estimated_wavs, clean_wavs)
409 |
410 | # pesq: 0.1 better / stoi: 0.01 better
411 | for i in range(len(pesq)):
412 | f_score.write('PESQ {:.6f} | STOI {:.6f}\n'.format(pesq[i], stoi[i]))
413 |
414 | # reshape for sum
415 | pesq = np.reshape(pesq, (1, -1))
416 | stoi = np.reshape(stoi, (1, -1))
417 |
418 | avg_pesq += sum(pesq[0]) / len(inputs)
419 | avg_stoi += sum(stoi[0]) / len(inputs)
420 |
421 | # save the samples to tensorboard
422 | if epoch % 10 == 0:
423 | writer.log_wav(inputs[0], targets[0], outputs[0], epoch)
424 |
425 | validation_loss /= batch_num
426 | avg_pesq /= batch_num
427 | avg_stoi /= batch_num
428 |
429 | return validation_loss, avg_pesq, avg_stoi
430 |
431 |
432 | def crn_direct_validate(model, validation_loader, writer, dir_to_save, epoch, DEVICE):
433 | # initialization
434 | validation_loss = 0
435 | batch_num = 0
436 |
437 | avg_pesq = 0
438 | avg_stoi = 0
439 |
440 |     # to record the score of each sample
441 | f_score = open(dir_to_save + '/Epoch_' + '%d_SCORES' % epoch, 'a')
442 |
443 | model.eval()
444 | with torch.no_grad():
445 | for inputs, targets in tools.Bar(validation_loader):
446 | batch_num += 1
447 |
448 | # to cuda
449 | inputs = inputs.float().to(DEVICE)
450 | targets = targets.float().to(DEVICE)
451 |
452 | output_mag, target_mag, outputs = model(inputs, targets)
453 | loss = model.loss(output_mag, target_mag)
454 |
455 | validation_loss += loss
456 |
457 | # estimate the output speech with pesq and stoi
458 | estimated_wavs = outputs.cpu().detach().numpy()
459 | clean_wavs = targets.cpu().detach().numpy()
460 |
461 | pesq = cal_pesq(estimated_wavs, clean_wavs)
462 | stoi = cal_stoi(estimated_wavs, clean_wavs)
463 |
464 | # pesq: 0.1 better / stoi: 0.01 better
465 | for i in range(len(pesq)):
466 | f_score.write('PESQ {:.6f} | STOI {:.6f}\n'.format(pesq[i], stoi[i]))
467 |
468 | # reshape for sum
469 | pesq = np.reshape(pesq, (1, -1))
470 | stoi = np.reshape(stoi, (1, -1))
471 |
472 | avg_pesq += sum(pesq[0]) / len(inputs)
473 | avg_stoi += sum(stoi[0]) / len(inputs)
474 |
475 | # save the samples to tensorboard
476 | if epoch % 10 == 0:
477 | writer.log_wav(inputs[0], targets[0], outputs[0], epoch)
478 |
479 | validation_loss /= batch_num
480 | avg_pesq /= batch_num
481 | avg_stoi /= batch_num
482 |
483 | return validation_loss, avg_pesq, avg_stoi
484 |
--------------------------------------------------------------------------------
/write_on_tensorboard.py:
--------------------------------------------------------------------------------
1 | """
2 | For observing the results using tensorboard
3 |
4 | 1. wav
5 | 2. spectrogram
6 | 3. loss
7 | """
8 | from tensorboardX import SummaryWriter
9 | import matplotlib
 | import numpy as np
10 | import config as cfg
 | from tools_for_model import plot_spectrogram_to_numpy, plot_mask_to_numpy  # plotting helpers used below
11 |
12 |
13 | class Writer(SummaryWriter):
14 | def __init__(self, logdir):
15 | super(Writer, self).__init__(logdir)
16 | # mask real/ imag
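 |         # matplotlib LinearSegmentedColormap segment data: each tuple is
 |         # (anchor position, value just below the anchor, value just above it)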
17 | cmap_custom = {
18 | 'red': ((0.0, 0.0, 0.0),
19 | (1 / 63, 0.0, 0.0),
20 | (2 / 63, 0.0, 0.0),
21 | (3 / 63, 0.0, 0.0),
22 | (4 / 63, 0.0, 0.0),
23 | (5 / 63, 0.0, 0.0),
24 | (6 / 63, 0.0, 0.0),
25 | (7 / 63, 0.0, 0.0),
26 | (8 / 63, 0.0, 0.0),
27 | (9 / 63, 0.0, 0.0),
28 | (10 / 63, 0.0, 0.0),
29 | (11 / 63, 0.0, 0.0),
30 | (12 / 63, 0.0, 0.0),
31 | (13 / 63, 0.0, 0.0),
32 | (14 / 63, 0.0, 0.0),
33 | (15 / 63, 0.0, 0.0),
34 | (16 / 63, 0.0, 0.0),
35 | (17 / 63, 0.0, 0.0),
36 | (18 / 63, 0.0, 0.0),
37 | (19 / 63, 0.0, 0.0),
38 | (20 / 63, 0.0, 0.0),
39 | (21 / 63, 0.0, 0.0),
40 | (22 / 63, 0.0, 0.0),
41 | (23 / 63, 0.0, 0.0),
42 | (24 / 63, 0.5625, 0.5625),
43 | (25 / 63, 0.6250, 0.6250),
44 | (26 / 63, 0.6875, 0.6875),
45 | (27 / 63, 0.7500, 0.7500),
46 | (28 / 63, 0.8125, 0.8125),
47 | (29 / 63, 0.8750, 0.8750),
48 | (30 / 63, 0.9375, 0.9375),
49 | (31 / 63, 1.0, 1.0),
50 | (32 / 63, 1.0, 1.0),
51 | (33 / 63, 1.0, 1.0),
52 | (34 / 63, 1.0, 1.0),
53 | (35 / 63, 1.0, 1.0),
54 | (36 / 63, 1.0, 1.0),
55 | (37 / 63, 1.0, 1.0),
56 | (38 / 63, 1.0, 1.0),
57 | (39 / 63, 1.0, 1.0),
58 | (40 / 63, 1.0, 1.0),
59 | (41 / 63, 1.0, 1.0),
60 | (42 / 63, 1.0, 1.0),
61 | (43 / 63, 1.0, 1.0),
62 | (44 / 63, 1.0, 1.0),
63 | (45 / 63, 1.0, 1.0),
64 | (46 / 63, 1.0, 1.0),
65 | (47 / 63, 1.0, 1.0),
66 | (48 / 63, 1.0, 1.0),
67 | (49 / 63, 1.0, 1.0),
68 | (50 / 63, 1.0, 1.0),
69 | (51 / 63, 1.0, 1.0),
70 | (52 / 63, 1.0, 1.0),
71 | (53 / 63, 1.0, 1.0),
72 | (54 / 63, 1.0, 1.0),
73 | (55 / 63, 1.0, 1.0),
74 | (56 / 63, 0.9375, 0.9375),
75 | (57 / 63, 0.8750, 0.8750),
76 | (58 / 63, 0.8125, 0.8125),
77 | (59 / 63, 0.7500, 0.7500),
78 | (60 / 63, 0.6875, 0.6875),
79 | (61 / 63, 0.6250, 0.6250),
80 | (62 / 63, 0.5625, 0.5625),
81 | (63 / 63, 0.5000, 0.5000)),
82 | 'green': ((0.0, 0.0, 0.0),
83 | (1 / 63, 0.0, 0.0),
84 | (2 / 63, 0.0, 0.0),
85 | (3 / 63, 0.0, 0.0),
86 | (4 / 63, 0.0, 0.0),
87 | (5 / 63, 0.0, 0.0),
88 | (6 / 63, 0.0, 0.0),
89 | (7 / 63, 0.0, 0.0),
90 | (8 / 63, 0.0625, 0.0625),
91 | (9 / 63, 0.1250, 0.1250),
92 | (10 / 63, 0.1875, 0.1875),
93 | (11 / 63, 0.2500, 0.2500),
94 | (12 / 63, 0.3125, 0.3125),
95 | (13 / 63, 0.3750, 0.3750),
96 | (14 / 63, 0.4375, 0.4375),
97 | (15 / 63, 0.5000, 0.5000),
98 | (16 / 63, 0.5625, 0.5625),
99 | (17 / 63, 0.6250, 0.6250),
100 | (18 / 63, 0.6875, 0.6875),
101 | (19 / 63, 0.7500, 0.7500),
102 | (20 / 63, 0.8125, 0.8125),
103 | (21 / 63, 0.8750, 0.8750),
104 | (22 / 63, 0.9375, 0.9375),
105 | (23 / 63, 1.0, 1.0),
106 | (24 / 63, 1.0, 1.0),
107 | (25 / 63, 1.0, 1.0),
108 | (26 / 63, 1.0, 1.0),
109 | (27 / 63, 1.0, 1.0),
110 | (28 / 63, 1.0, 1.0),
111 | (29 / 63, 1.0, 1.0),
112 | (30 / 63, 1.0, 1.0),
113 | (31 / 63, 1.0, 1.0),
114 | (32 / 63, 1.0, 1.0),
115 | (33 / 63, 1.0, 1.0),
116 | (34 / 63, 1.0, 1.0),
117 | (35 / 63, 1.0, 1.0),
118 | (36 / 63, 1.0, 1.0),
119 | (37 / 63, 1.0, 1.0),
120 | (38 / 63, 1.0, 1.0),
121 | (39 / 63, 1.0, 1.0),
122 | (40 / 63, 0.9375, 0.9375),
123 | (41 / 63, 0.8750, 0.8750),
124 | (42 / 63, 0.8125, 0.8125),
125 | (43 / 63, 0.7500, 0.7500),
126 | (44 / 63, 0.6875, 0.6875),
127 | (45 / 63, 0.6250, 0.6250),
128 | (46 / 63, 0.5625, 0.5625),
129 | (47 / 63, 0.5000, 0.5000),
130 | (48 / 63, 0.4375, 0.4375),
131 | (49 / 63, 0.3750, 0.3750),
132 | (50 / 63, 0.3125, 0.3125),
133 | (51 / 63, 0.2500, 0.2500),
134 | (52 / 63, 0.1875, 0.1875),
135 | (53 / 63, 0.1250, 0.1250),
136 | (54 / 63, 0.0625, 0.0625),
137 | (55 / 63, 0.0, 0.0),
138 | (56 / 63, 0.0, 0.0),
139 | (57 / 63, 0.0, 0.0),
140 | (58 / 63, 0.0, 0.0),
141 | (59 / 63, 0.0, 0.0),
142 | (60 / 63, 0.0, 0.0),
143 | (61 / 63, 0.0, 0.0),
144 | (62 / 63, 0.0, 0.0),
145 | (63 / 63, 0.0, 0.0)),
146 | 'blue': ((0.0, 0.5625, 0.5625),
147 | (1 / 63, 0.6250, 0.6250),
148 | (2 / 63, 0.6875, 0.6875),
149 | (3 / 63, 0.7500, 0.7500),
150 | (4 / 63, 0.8125, 0.8125),
151 | (5 / 63, 0.8750, 0.8750),
152 | (6 / 63, 0.9375, 0.9375),
153 | (7 / 63, 1.0, 1.0),
154 | (8 / 63, 1.0, 1.0),
155 | (9 / 63, 1.0, 1.0),
156 | (10 / 63, 1.0, 1.0),
157 | (11 / 63, 1.0, 1.0),
158 | (12 / 63, 1.0, 1.0),
159 | (13 / 63, 1.0, 1.0),
160 | (14 / 63, 1.0, 1.0),
161 | (15 / 63, 1.0, 1.0),
162 | (16 / 63, 1.0, 1.0),
163 | (17 / 63, 1.0, 1.0),
164 | (18 / 63, 1.0, 1.0),
165 | (19 / 63, 1.0, 1.0),
166 | (20 / 63, 1.0, 1.0),
167 | (21 / 63, 1.0, 1.0),
168 | (22 / 63, 1.0, 1.0),
169 | (23 / 63, 1.0, 1.0),
170 | (24 / 63, 1.0, 1.0),
171 | (25 / 63, 1.0, 1.0),
172 | (26 / 63, 1.0, 1.0),
173 | (27 / 63, 1.0, 1.0),
174 | (28 / 63, 1.0, 1.0),
175 | (29 / 63, 1.0, 1.0),
176 | (30 / 63, 1.0, 1.0),
177 | (31 / 63, 1.0, 1.0),
178 | (32 / 63, 0.9375, 0.9375),
179 | (33 / 63, 0.8750, 0.8750),
180 | (34 / 63, 0.8125, 0.8125),
181 | (35 / 63, 0.7500, 0.7500),
182 | (36 / 63, 0.6875, 0.6875),
183 | (37 / 63, 0.6250, 0.6250),
184 | (38 / 63, 0.5625, 0.5625),
185 | (39 / 63, 0.0, 0.0),
186 | (40 / 63, 0.0, 0.0),
187 | (41 / 63, 0.0, 0.0),
188 | (42 / 63, 0.0, 0.0),
189 | (43 / 63, 0.0, 0.0),
190 | (44 / 63, 0.0, 0.0),
191 | (45 / 63, 0.0, 0.0),
192 | (46 / 63, 0.0, 0.0),
193 | (47 / 63, 0.0, 0.0),
194 | (48 / 63, 0.0, 0.0),
195 | (49 / 63, 0.0, 0.0),
196 | (50 / 63, 0.0, 0.0),
197 | (51 / 63, 0.0, 0.0),
198 | (52 / 63, 0.0, 0.0),
199 | (53 / 63, 0.0, 0.0),
200 | (54 / 63, 0.0, 0.0),
201 | (55 / 63, 0.0, 0.0),
202 | (56 / 63, 0.0, 0.0),
203 | (57 / 63, 0.0, 0.0),
204 | (58 / 63, 0.0, 0.0),
205 | (59 / 63, 0.0, 0.0),
206 | (60 / 63, 0.0, 0.0),
207 | (61 / 63, 0.0, 0.0),
208 | (62 / 63, 0.0, 0.0),
209 | (63 / 63, 0.0, 0.0))
210 | }
211 |
212 | # mask magnitude
213 | cmap_custom2 = {
214 | 'red': ((0.0, 1.0, 1.0),
215 | (1 / 32, 1.0, 1.0),
216 | (2 / 32, 1.0, 1.0),
217 | (3 / 32, 1.0, 1.0),
218 | (4 / 32, 1.0, 1.0),
219 | (5 / 32, 1.0, 1.0),
220 | (6 / 32, 1.0, 1.0),
221 | (7 / 32, 1.0, 1.0),
222 | (8 / 32, 1.0, 1.0),
223 | (9 / 32, 1.0, 1.0),
224 | (10 / 32, 1.0, 1.0),
225 | (11 / 32, 1.0, 1.0),
226 | (12 / 32, 1.0, 1.0),
227 | (13 / 32, 1.0, 1.0),
228 | (14 / 32, 1.0, 1.0),
229 | (15 / 32, 1.0, 1.0),
230 | (16 / 32, 1.0, 1.0),
231 | (17 / 32, 1.0, 1.0),
232 | (18 / 32, 1.0, 1.0),
233 | (19 / 32, 1.0, 1.0),
234 | (20 / 32, 1.0, 1.0),
235 | (21 / 32, 1.0, 1.0),
236 | (22 / 32, 1.0, 1.0),
237 | (23 / 32, 1.0, 1.0),
238 | (24 / 32, 1.0, 1.0),
239 | (25 / 32, 0.9375, 0.9375),
240 | (26 / 32, 0.8750, 0.8750),
241 | (27 / 32, 0.8125, 0.8125),
242 | (28 / 32, 0.7500, 0.7500),
243 | (29 / 32, 0.6875, 0.6875),
244 | (30 / 32, 0.6250, 0.6250),
245 | (31 / 32, 0.5625, 0.5625),
246 | (32 / 32, 0.5000, 0.5000)),
247 | 'green': ((0.0, 1.0, 1.0),
248 | (1 / 32, 1.0, 1.0),
249 | (2 / 32, 1.0, 1.0),
250 | (3 / 32, 1.0, 1.0),
251 | (4 / 32, 1.0, 1.0),
252 | (5 / 32, 1.0, 1.0),
253 | (6 / 32, 1.0, 1.0),
254 | (7 / 32, 1.0, 1.0),
255 | (8 / 32, 1.0, 1.0),
256 | (9 / 32, 0.9375, 0.9375),
257 | (10 / 32, 0.8750, 0.8750),
258 | (11 / 32, 0.8125, 0.8125),
259 | (12 / 32, 0.7500, 0.7500),
260 | (13 / 32, 0.6875, 0.6875),
261 | (14 / 32, 0.6250, 0.6250),
262 | (15 / 32, 0.5625, 0.5625),
263 | (16 / 32, 0.5000, 0.5000),
264 | (17 / 32, 0.4375, 0.4375),
265 | (18 / 32, 0.3750, 0.3750),
266 | (19 / 32, 0.3125, 0.3125),
267 | (20 / 32, 0.2500, 0.2500),
268 | (21 / 32, 0.1875, 0.1875),
269 | (22 / 32, 0.1250, 0.1250),
270 | (23 / 32, 0.0625, 0.0625),
271 | (24 / 32, 0.0, 0.0),
272 | (25 / 32, 0.0, 0.0),
273 | (26 / 32, 0.0, 0.0),
274 | (27 / 32, 0.0, 0.0),
275 | (28 / 32, 0.0, 0.0),
276 | (29 / 32, 0.0, 0.0),
277 | (30 / 32, 0.0, 0.0),
278 | (31 / 32, 0.0, 0.0),
279 | (32 / 32, 0.0, 0.0)),
280 | 'blue': ((0.0, 1.0, 1.0),
281 | (1 / 32, 0.9375, 0.9375),
282 | (2 / 32, 0.8750, 0.8750),
283 | (3 / 32, 0.8125, 0.8125),
284 | (4 / 32, 0.7500, 0.7500),
285 | (5 / 32, 0.6875, 0.6875),
286 | (6 / 32, 0.6250, 0.6250),
287 | (7 / 32, 0.5625, 0.5625),
288 | (8 / 32, 0.0, 0.0),
289 | (9 / 32, 0.0, 0.0),
290 | (10 / 32, 0.0, 0.0),
291 | (11 / 32, 0.0, 0.0),
292 | (12 / 32, 0.0, 0.0),
293 | (13 / 32, 0.0, 0.0),
294 | (14 / 32, 0.0, 0.0),
295 | (15 / 32, 0.0, 0.0),
296 | (16 / 32, 0.0, 0.0),
297 | (17 / 32, 0.0, 0.0),
298 | (18 / 32, 0.0, 0.0),
299 | (19 / 32, 0.0, 0.0),
300 | (20 / 32, 0.0, 0.0),
301 | (21 / 32, 0.0, 0.0),
302 | (22 / 32, 0.0, 0.0),
303 | (23 / 32, 0.0, 0.0),
304 | (24 / 32, 0.0, 0.0),
305 | (25 / 32, 0.0, 0.0),
306 | (26 / 32, 0.0, 0.0),
307 | (27 / 32, 0.0, 0.0),
308 | (28 / 32, 0.0, 0.0),
309 | (29 / 32, 0.0, 0.0),
310 | (30 / 32, 0.0, 0.0),
311 | (31 / 32, 0.0, 0.0),
312 | (32 / 32, 0.0, 0.0))
313 | }
314 |
315 | self.cmap_custom = matplotlib.colors.LinearSegmentedColormap('testCmap', segmentdata=cmap_custom, N=256)
316 | self.cmap_custom2 = matplotlib.colors.LinearSegmentedColormap('testCmap2', segmentdata=cmap_custom2, N=256)
317 |
318 | def log_loss(self, train_loss, vali_loss, step):
319 | self.add_scalar('train_loss', train_loss, step)
320 | self.add_scalar('vali_loss', vali_loss, step)
321 |
322 | def log_sub_loss(self, train_main_loss, train_sub_loss, vali_main_loss, vali_sub_loss, step):
323 | self.add_scalar('train_main_loss', train_main_loss, step)
324 | self.add_scalar('train_sub_loss', train_sub_loss, step)
325 | self.add_scalar('vali_main_loss', vali_main_loss, step)
326 | self.add_scalar('vali_sub_loss', vali_sub_loss, step)
327 |
328 | def log_score(self, vali_pesq, vali_stoi, step):
329 | self.add_scalar('vali_pesq', vali_pesq, step)
330 | self.add_scalar('vali_stoi', vali_stoi, step)
331 |
332 | def log_wav(self, mixed_wav, clean_wav, est_wav, step):
333 | #
334 | self.add_audio('mixed_wav', mixed_wav, step, cfg.fs)
335 | self.add_audio('clean_target_wav', clean_wav, step, cfg.fs)
336 | self.add_audio('estimated_wav', est_wav, step, cfg.fs)
337 |
338 | def log_spectrogram(self, mixed_wav, clean_wav, noise_wav, est_wav, step):
339 | #
340 | self.add_image('data/mixed_spectrogram',
341 | plot_spectrogram_to_numpy(mixed_wav, cfg.fs, cfg.win_len, int(cfg.ola_ratio),
342 | None, [-150, -40], 'dB'), step,
343 | dataformats='HWC')
344 | self.add_image('data/clean_spectrogram',
345 | plot_spectrogram_to_numpy(clean_wav, cfg.fs, cfg.win_len, int(cfg.ola_ratio),
346 | None, [-150, -40], 'dB'), step,
347 | dataformats='HWC')
348 | self.add_image('data/noise_spectrogram',
349 | plot_spectrogram_to_numpy(noise_wav, cfg.fs, cfg.win_len, int(cfg.ola_ratio),
350 | None, [-150, -40], 'dB'), step,
351 | dataformats='HWC')
352 | self.add_image('data/clean_unwrap_phase',
353 | plot_spectrogram_to_numpy(clean_wav, cfg.fs, cfg.win_len, int(cfg.ola_ratio),
354 | 'phase', [-500, 500], None), step,
355 | dataformats='HWC')
356 |
357 | #
358 | self.add_image('result/estimated_spectrogram',
359 | plot_spectrogram_to_numpy(est_wav, cfg.fs, cfg.win_len, int(cfg.ola_ratio),
360 | None, [-150, -40], 'dB'), step,
361 | dataformats='HWC')
362 | self.add_image('result/estimated_unwrap_phase',
363 | plot_spectrogram_to_numpy(est_wav, cfg.fs, cfg.win_len, int(cfg.ola_ratio),
364 | 'phase', [-500, 500], None), step,
365 | dataformats='HWC')
366 | self.add_image('result/estimated_magnitude-clean_magnitude',
367 | plot_spectrogram_to_numpy(est_wav - clean_wav, cfg.fs, cfg.win_len,
368 | int(cfg.ola_ratio), None,
369 | [-80, 80], 'dB'), step, dataformats='HWC')
370 | self.add_image('result/estimated_unwrap_phase-clean_unwrap_phase',
371 | plot_spectrogram_to_numpy(est_wav - clean_wav, cfg.fs, cfg.win_len,
372 | int(cfg.ola_ratio), 'phase',
373 | [-500, 500], None), step, dataformats='HWC')
374 |
375 | def log_mask_spectrogram(self, est_mask_real, est_mask_imag, step):
376 | #
377 | self.add_image('result/estimated_mask_magnitude',
378 | plot_mask_to_numpy(np.sqrt(est_mask_real ** 2 + est_mask_imag ** 2), cfg.fs, cfg.win_len,
379 | int(cfg.ola_ratio), 0, 2,
380 | cmap=self.cmap_custom2), step, dataformats='HWC')
381 | self.add_image('result/estimated_mask_real',
382 | plot_mask_to_numpy(est_mask_real, cfg.fs, cfg.win_len, int(cfg.ola_ratio),
383 | -2, 2, cmap=self.cmap_custom), step, dataformats='HWC')
384 | self.add_image('result/estimated_mask_imag',
385 | plot_mask_to_numpy(est_mask_imag, cfg.fs, cfg.win_len, int(cfg.ola_ratio),
386 | -2, 2, cmap=self.cmap_custom), step, dataformats='HWC')
387 |
--------------------------------------------------------------------------------