├── LICENSE
├── PESQ.so
├── README.md
├── SE_tutorials.ipynb
├── composite.m
├── config.py
├── dataloader.py
├── estimation
│   └── check_object_metrics.py
├── generate_noisy_data.py
├── models.py
├── tools_for_estimate.py
├── tools_for_loss.py
├── tools_for_model.py
├── train_interface.py
├── trainer.py
└── write_on_tensorboard.py

/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2021 Seo-Rim Hwang
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/PESQ.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seorim0/DNN-based-Speech-Enhancement-in-the-frequency-domain/ed54e8c0eaea1f063c4db8e7a475ea3eb6e2f836/PESQ.so
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # DNN-based Speech Enhancement in the frequency domain
 2 | You can do DNN-based speech enhancement (SE) in the frequency domain using various methods through this repository.
 3 | First, you have to make noisy data by mixing clean speech and noise; this dataset is used for training the networks.
 4 | You can also adjust the type of the network and its configuration in various ways, as shown below.
 5 | The results of the network can be evaluated through various objective metrics (PESQ, STOI, CSIG, CBAK, COVL).
 6 | 
 7 | 
 8 | 
 9 | 
10 | 
11 | You can change
12 | 
13 |   1. Networks
14 |   2. Learning methods
15 |   3. Loss functions
16 | 
17 | 
18 | 
19 | 
20 | 
21 | 
22 | 
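Each of the items above is selected in `config.py` (included in full later in this repository). The excerpt below shows how the network, loss, perceptual loss, and masking mode are chosen; the values shown are the ones currently set in that file.

```python
# config.py (excerpt): pick one entry from each list
model_list = ['DCCRN', 'CRN', 'FullSubNet']
loss_list = ['MSE', 'SDR', 'SI-SNR', 'SI-SDR']
perceptual_list = [False, 'LMS', 'PMSQE']
mask_type = ['Direct(None make)', 'E', 'C', 'R']

model = model_list[0]            # DCCRN
loss = loss_list[1]              # SDR
perceptual = perceptual_list[0]  # no perceptual loss term
masking_mode = mask_type[1]      # 'E' masking
```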
23 | 
24 | 
25 | ## Requirements
26 | > This repository is tested on Ubuntu 20.04 with:
27 | * Python 3.7
28 | * CUDA 11.1
29 | * cuDNN 8.0.5
30 | * PyTorch 1.9.0
31 | 
32 | 
33 | ## Getting Started
34 | 1. Install the necessary libraries.
35 | 2. Make a dataset for training and validation.
36 | ```sh
37 | # The shape of the dataset
38 | [data_num, 2 (inputs and targets), sampling_frequency * data_length]
39 | 
40 | # For example, 1,000 three-second utterances sampled at 16 kHz give the shape
41 | [1000, 2, 48000]
42 | ```
43 | 3. Set [dataloader.py](https://github.com/seorim0/Speech_enhancement_for_you/blob/main/dataloader.py).
44 | ```sh
45 | self.input_path = "DATASET_FILE_PATH"
46 | ```
47 | 4. Set [config.py](https://github.com/seorim0/Speech_enhancement_for_you/blob/main/config.py).
48 | ```sh
49 | # If you need to adjust any settings, simply change this file.
50 | # When you run this project for the first time, you need to set the path where the model and logs will be saved.
51 | ```
52 | 5. Run [train_interface.py](https://github.com/seorim0/Speech_enhancement_for_you/blob/main/train_interface.py).
53 | 
54 | 
55 | ## Tutorials
56 | ['SE_tutorials.ipynb'](https://github.com/seorim0/Speech_enhancement_for_you/blob/main/SE_tutorials.ipynb) was made as a tutorial.
57 | You can train the CRN directly in the Colab notebook without any preparation.
58 | 
59 | 
60 | 
61 | ## Networks
62 | > The following networks are available; you can select and adjust them in config.py:
63 | * Real network
64 |   - convolutional recurrent network (CRN)
65 |     (a real-valued version of DCCRN)
66 |   - FullSubNet [[1]](https://arxiv.org/abs/2010.15508)
67 | * Complex network
68 |   - deep complex convolutional recurrent network (DCCRN) [[2]](https://arxiv.org/abs/2008.00264)
69 | 
70 | 
71 | ## Learning Methods
72 | * T-F masking
73 | * Spectral mapping
74 | 
75 | 
76 | ## Loss Functions
77 | * MSE
78 | * SDR
79 | * SI-SNR
80 | * SI-SDR
81 | 
82 | > You can also combine these loss functions with a perceptual loss:
83 | * LMS
84 | * PMSQE
85 | 
86 | 
87 | ## Tensorboard
88 | > As shown below, you can check in real time whether the network is training well through ['write_on_tensorboard.py'](https://github.com/seorim0/Speech_enhancement_for_you/blob/main/write_on_tensorboard.py).
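The block below is a minimal sketch of the kind of logging such a writer performs, assuming `torch.utils.tensorboard`; the `Writer` class, method names, and STFT parameters here are illustrative only, see `write_on_tensorboard.py` in this repository for the actual implementation.

```python
# Minimal sketch (illustrative; not the repository's actual write_on_tensorboard.py API).
import torch
from torch.utils.tensorboard import SummaryWriter


class Writer:
    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)

    def log_metrics(self, train_loss, valid_loss, pesq, stoi, epoch):
        # one scalar curve per quantity, indexed by epoch
        self.writer.add_scalar('loss/train', train_loss, epoch)
        self.writer.add_scalar('loss/valid', valid_loss, epoch)
        self.writer.add_scalar('metric/pesq', pesq, epoch)
        self.writer.add_scalar('metric/stoi', stoi, epoch)

    def log_spectrogram(self, wav, epoch, tag='spectrogram', n_fft=512, hop=100):
        # magnitude spectrogram in dB, rescaled to [0, 1] and logged as an image
        spec = torch.stft(wav, n_fft=n_fft, hop_length=hop,
                          window=torch.hann_window(n_fft), return_complex=True)
        mag_db = 20 * torch.log10(spec.abs() + 1e-8)
        img = (mag_db - mag_db.min()) / (mag_db.max() - mag_db.min() + 1e-8)
        self.writer.add_image(tag, img.unsqueeze(0), epoch)  # shape [1, F, T]


# usage inside the training loop, once per epoch:
# writer = Writer(cfg.logs_dir + cfg.expr_num)
# writer.log_metrics(train_loss, valid_loss, pesq, stoi, epoch)
# writer.log_spectrogram(enhanced_wav[0].cpu(), epoch)
```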
89 | 
90 | ![tensor](https://user-images.githubusercontent.com/55497506/131444707-4459a979-8652-46f4-82f1-0c640cfff685.png)
91 | * loss
92 | * pesq, stoi
93 | * spectrogram
94 | 
95 | 
96 | ## Reference
97 | **FullSubNet: A Full-Band and Sub-Band Fusion Model for Real-Time Single-Channel Speech Enhancement**
98 | Xiang Hao, Xiangdong Su, Radu Horaud, Xiaofei Li
99 | [[arXiv]](https://arxiv.org/abs/2010.15508) [[code]](https://github.com/haoxiangsnr/FullSubNet)
100 | **DCCRN: Deep Complex Convolution Recurrent Network for Phase-Aware Speech Enhancement**
101 | Yanxin Hu, Yun Liu, Shubo Lv, Mengtao Xing, Shimin Zhang, Yihui Fu, Jian Wu, Bihong Zhang, Lei Xie
102 | [[arXiv]](https://arxiv.org/abs/2008.00264) [[code]](https://github.com/huyanxin/DeepComplexCRN)
103 | **Other tools**
104 | https://github.com/usimarit/semetrics
105 | https://ecs.utdallas.edu/loizou/speech/software.htm
106 | 
107 | 
--------------------------------------------------------------------------------
/composite.m:
--------------------------------------------------------------------------------
 1 | function [Csig,Cbak,Covl,segSNR]= composite(cleanFile, enhancedFile);
 2 | 
 3 | % --------- composite objective measure ----------------------
 4 | %
 5 | %   Center for Robust Speech Systems
 6 | %   University of Texas-Dallas
 7 | %   Copyright (c) 2006
 8 | %   All Rights Reserved.
 9 | %
10 | % Description:
11 | %
12 | %   This function implements the composite objective measure
13 | %   proposed in [1]. It returns three values: The predicted rating of
14 | %   overall quality (Covl), the rating of speech distortion (Csig) and
15 | %   the rating of background distortion (Cbak). The ratings are based on the 1-5 MOS scale.
16 | %   In addition, it returns the values of the SNRseg, log-likelihood ratio (LLR), PESQ
17 | %   and weighted spectral slope (WSS) objective measures.
18 | %
19 | % References:
20 | %   [1] Hu, Y. and Loizou, P. (2006). "Evaluation of objective measures for speech enhancement,"
21 | %       Proceedings of INTERSPEECH-2006, Philadelphia, PA, September 2006.
22 | %
23 | %
24 | % Authors:
25 | %   Philipos C. Loizou and Yi Hu
26 | %   Bryan L. Pellom and John H. L.
Hansen (for the implementation of 27 | % the WSS, LLR and SnrSeg measures) 28 | % 29 | %---------------------------------------------------------- 30 | 31 | if nargin<2 32 | fprintf('Usage: [Csig,Cbak,Covl]=composite(cleanfile.wav,enhanced.wav)\n'); 33 | fprintf('where ''Csig'' is the predicted rating of speech distortion\n'); 34 | fprintf(' ''Cbak'' is the predicted rating of background distortion\n'); 35 | fprintf(' ''Covl'' is the predicted rating of overall quality.\n\n'); 36 | return; 37 | end 38 | 39 | 40 | alpha= 0.95; 41 | 42 | [data1, Srate1]= audioread(cleanFile); 43 | [data2, Srate2]= audioread(enhancedFile); 44 | info1 = audioinfo(cleanFile); 45 | info2 = audioinfo(enhancedFile); 46 | Nbits1 = info1.BitsPerSample; 47 | Nbits2 = info2.BitsPerSample; 48 | if ( Srate1~= Srate2) | ( Nbits1~= Nbits2) 49 | error( 'The two files do not match!\n'); 50 | end 51 | 52 | len= min( length( data1), length( data2)); 53 | data1= data1( 1: len)+eps; 54 | data2= data2( 1: len)+eps; 55 | 56 | 57 | % -- compute the WSS measure --- 58 | % 59 | wss_dist_vec= wss( data1, data2,Srate1); 60 | wss_dist_vec= sort( wss_dist_vec); 61 | wss_dist= mean( wss_dist_vec( 1: round( length( wss_dist_vec)*alpha))); 62 | 63 | % --- compute the LLR measure --------- 64 | % 65 | LLR_dist= llr( data1, data2,Srate1); 66 | LLRs= sort(LLR_dist); 67 | LLR_len= round( length(LLR_dist)* alpha); 68 | llr_mean= mean( LLRs( 1: LLR_len)); 69 | 70 | % --- compute the SNRseg ---------------- 71 | % 72 | [snr_dist, segsnr_dist]= snr( data1, data2,Srate1); 73 | snr_mean= snr_dist; 74 | segSNR= mean( segsnr_dist); 75 | 76 | 77 | % -- compute the pesq ---- 78 | %[pesq_mos]= pesq(Srate1,cleanFile, enhancedFile); 79 | pesq_mos = 0; 80 | 81 | 82 | % --- now compute the composite measures ------------------ 83 | % 84 | Csig = 3.093 - 1.029*llr_mean + 0.603*pesq_mos-0.009*wss_dist; 85 | Csig = max(1, Csig); Csig = min(5, Csig); %% adding for fitting range 1 to 5 86 | Cbak = 1.634 + 0.478 *pesq_mos - 0.007*wss_dist + 0.063*segSNR; 87 | Cbak = max(1, Cbak); Cbak = min(5, Cbak); %% adding for fitting range 1 to 5 88 | Covl = 1.594 + 0.805*pesq_mos - 0.512*llr_mean - 0.007*wss_dist; 89 | Covl = max(1, Covl); Covl = min(5, Covl); %% adding for fitting range 1 to 5 90 | 91 | %fprintf('\n LLR=%f SNRseg=%f WSS=%f PESQ=%f\n',llr_mean,segSNR,wss_dist,pesq_mos); 92 | 93 | return; 94 | 95 | % ---------------------------------------------------------------------- 96 | % 97 | % Weighted Spectral Slope (WSS) Objective Speech Quality Measure 98 | % 99 | % Center for Robust Speech Systems 100 | % University of Texas-Dallas 101 | % Copyright (c) 1998-2006 102 | % All Rights Reserved. 103 | % 104 | % Description: 105 | % 106 | % This function implements the Weighted Spectral Slope (WSS) 107 | % distance measure originally proposed in [1]. The algorithm 108 | % works by first decomposing the speech signal into a set of 109 | % frequency bands (this is done for both the test and reference 110 | % frame). The intensities within each critical band are 111 | % measured. Then, a weighted distances between the measured 112 | % slopes of the log-critical band spectra are computed. 113 | % This measure is also described in Section 2.2.9 (pages 56-58) 114 | % of [2]. 115 | % 116 | % Whereas Klatt's original measure used 36 critical-band 117 | % filters to estimate the smoothed short-time spectrum, this 118 | % implementation considers a bank of 25 filters spanning 119 | % the 4 kHz bandwidth. 
120 | % 121 | % Input/Output: 122 | % 123 | % The input is a reference 8kHz sampled speech, and processed 124 | % speech (could be noisy or enhanced). 125 | % 126 | % The function returns the numerical distance between each 127 | % frame of the two input files (one distance per frame). 128 | % 129 | % References: 130 | % 131 | % [1] D. H. Klatt, "Prediction of Perceived Phonetic Distance 132 | % from Critical-Band Spectra: A First Step", Proc. IEEE 133 | % ICASSP'82, Volume 2, pp. 1278-1281, May, 1982. 134 | % 135 | % [2] S. R. Quackenbush, T. P. Barnwell, and M. A. Clements, 136 | % Objective Measures of Speech Quality. Prentice Hall 137 | % Advanced Reference Series, Englewood Cliffs, NJ, 1988, 138 | % ISBN: 0-13-629056-6. 139 | % 140 | % Authors: 141 | % 142 | % Bryan L. Pellom and John H. L. Hansen 143 | % 144 | % 145 | % Last Modified: 146 | % 147 | % July 22, 1998 148 | % September 12, 2006 by Philipos Loizou 149 | % ---------------------------------------------------------------------- 150 | 151 | function distortion = wss(clean_speech, processed_speech,sample_rate) 152 | 153 | 154 | % ---------------------------------------------------------------------- 155 | % Check the length of the clean and processed speech. Must be the same. 156 | % ---------------------------------------------------------------------- 157 | 158 | clean_length = length(clean_speech); 159 | processed_length = length(processed_speech); 160 | 161 | if (clean_length ~= processed_length) 162 | disp('Error: Files musthave same length.'); 163 | return 164 | end 165 | 166 | 167 | 168 | % ---------------------------------------------------------------------- 169 | % Global Variables 170 | % ---------------------------------------------------------------------- 171 | 172 | % sample_rate = 8000; % default sample rate 173 | % winlength = 240; % window length in samples 174 | % skiprate = 60; % window skip in samples 175 | winlength = round(30*sample_rate/1000); %240; % window length in samples 176 | skiprate = floor(winlength/4); % window skip in samples 177 | max_freq = sample_rate/2; % maximum bandwidth 178 | num_crit = 25; % number of critical bands 179 | 180 | USE_FFT_SPECTRUM = 1; % defaults to 10th order LP spectrum 181 | %n_fft = 512; % FFT size 182 | n_fft = 2^nextpow2(2*winlength); 183 | n_fftby2 = n_fft/2; % FFT size/2 184 | Kmax = 20; % value suggested by Klatt, pg 1280 185 | Klocmax = 1; % value suggested by Klatt, pg 1280 186 | 187 | % ---------------------------------------------------------------------- 188 | % Critical Band Filter Definitions (Center Frequency and Bandwidths in Hz) 189 | % ---------------------------------------------------------------------- 190 | 191 | cent_freq(1) = 50.0000; bandwidth(1) = 70.0000; 192 | cent_freq(2) = 120.000; bandwidth(2) = 70.0000; 193 | cent_freq(3) = 190.000; bandwidth(3) = 70.0000; 194 | cent_freq(4) = 260.000; bandwidth(4) = 70.0000; 195 | cent_freq(5) = 330.000; bandwidth(5) = 70.0000; 196 | cent_freq(6) = 400.000; bandwidth(6) = 70.0000; 197 | cent_freq(7) = 470.000; bandwidth(7) = 70.0000; 198 | cent_freq(8) = 540.000; bandwidth(8) = 77.3724; 199 | cent_freq(9) = 617.372; bandwidth(9) = 86.0056; 200 | cent_freq(10) = 703.378; bandwidth(10) = 95.3398; 201 | cent_freq(11) = 798.717; bandwidth(11) = 105.411; 202 | cent_freq(12) = 904.128; bandwidth(12) = 116.256; 203 | cent_freq(13) = 1020.38; bandwidth(13) = 127.914; 204 | cent_freq(14) = 1148.30; bandwidth(14) = 140.423; 205 | cent_freq(15) = 1288.72; bandwidth(15) = 153.823; 206 | cent_freq(16) = 1442.54; 
bandwidth(16) = 168.154; 207 | cent_freq(17) = 1610.70; bandwidth(17) = 183.457; 208 | cent_freq(18) = 1794.16; bandwidth(18) = 199.776; 209 | cent_freq(19) = 1993.93; bandwidth(19) = 217.153; 210 | cent_freq(20) = 2211.08; bandwidth(20) = 235.631; 211 | cent_freq(21) = 2446.71; bandwidth(21) = 255.255; 212 | cent_freq(22) = 2701.97; bandwidth(22) = 276.072; 213 | cent_freq(23) = 2978.04; bandwidth(23) = 298.126; 214 | cent_freq(24) = 3276.17; bandwidth(24) = 321.465; 215 | cent_freq(25) = 3597.63; bandwidth(25) = 346.136; 216 | 217 | bw_min = bandwidth (1); % minimum critical bandwidth 218 | 219 | % ---------------------------------------------------------------------- 220 | % Set up the critical band filters. Note here that Gaussianly shaped 221 | % filters are used. Also, the sum of the filter weights are equivalent 222 | % for each critical band filter. Filter less than -30 dB and set to 223 | % zero. 224 | % ---------------------------------------------------------------------- 225 | 226 | min_factor = exp (-30.0 / (2.0 * 2.303)); % -30 dB point of filter 227 | 228 | for i = 1:num_crit 229 | f0 = (cent_freq (i) / max_freq) * (n_fftby2); 230 | all_f0(i) = floor(f0); 231 | bw = (bandwidth (i) / max_freq) * (n_fftby2); 232 | norm_factor = log(bw_min) - log(bandwidth(i)); 233 | j = 0:1:n_fftby2-1; 234 | crit_filter(i,:) = exp (-11 *(((j - floor(f0)) ./bw).^2) + norm_factor); 235 | crit_filter(i,:) = crit_filter(i,:).*(crit_filter(i,:) > min_factor); 236 | end 237 | 238 | % ---------------------------------------------------------------------- 239 | % For each frame of input speech, calculate the Weighted Spectral 240 | % Slope Measure 241 | % ---------------------------------------------------------------------- 242 | 243 | num_frames = clean_length/skiprate-(winlength/skiprate); % number of frames 244 | start = 1; % starting sample 245 | window = 0.5*(1 - cos(2*pi*(1:winlength)'/(winlength+1))); 246 | 247 | for frame_count = 1:num_frames 248 | 249 | % ---------------------------------------------------------- 250 | % (1) Get the Frames for the test and reference speech. 251 | % Multiply by Hanning Window. 252 | % ---------------------------------------------------------- 253 | 254 | clean_frame = clean_speech(start:start+winlength-1); 255 | processed_frame = processed_speech(start:start+winlength-1); 256 | clean_frame = clean_frame.*window; 257 | processed_frame = processed_frame.*window; 258 | 259 | % ---------------------------------------------------------- 260 | % (2) Compute the Power Spectrum of Clean and Processed 261 | % ---------------------------------------------------------- 262 | 263 | if (USE_FFT_SPECTRUM) 264 | clean_spec = (abs(fft(clean_frame,n_fft)).^2); 265 | processed_spec = (abs(fft(processed_frame,n_fft)).^2); 266 | else 267 | a_vec = zeros(1,n_fft); 268 | a_vec(1:11) = lpc(clean_frame,10); 269 | clean_spec = 1.0/(abs(fft(a_vec,n_fft)).^2)'; 270 | 271 | a_vec = zeros(1,n_fft); 272 | a_vec(1:11) = lpc(processed_frame,10); 273 | processed_spec = 1.0/(abs(fft(a_vec,n_fft)).^2)'; 274 | end 275 | 276 | % ---------------------------------------------------------- 277 | % (3) Compute Filterbank Output Energies (in dB scale) 278 | % ---------------------------------------------------------- 279 | 280 | for i = 1:num_crit 281 | clean_energy(i) = sum(clean_spec(1:n_fftby2) ... 282 | .*crit_filter(i,:)'); 283 | processed_energy(i) = sum(processed_spec(1:n_fftby2) ... 
284 |                             .*crit_filter(i,:)');
285 |    end
286 |    clean_energy = 10*log10(max(clean_energy,1E-10));
287 |    processed_energy = 10*log10(max(processed_energy,1E-10));
288 | 
289 |    % ----------------------------------------------------------
290 |    % (4) Compute Spectral Slope (dB[i+1]-dB[i])
291 |    % ----------------------------------------------------------
292 | 
293 |    clean_slope = clean_energy(2:num_crit) - ...
294 |        clean_energy(1:num_crit-1);
295 |    processed_slope = processed_energy(2:num_crit) - ...
296 |        processed_energy(1:num_crit-1);
297 | 
298 |    % ----------------------------------------------------------
299 |    % (5) Find the nearest peak locations in the spectra to
300 |    %     each critical band.  If the slope is negative, we
301 |    %     search to the left.  If positive, we search to the
302 |    %     right.
303 |    % ----------------------------------------------------------
304 | 
305 |    for i = 1:num_crit-1
306 | 
307 |        % find the peaks in the clean speech signal
308 | 
309 |        if (clean_slope(i)>0)          % search to the right
310 |            n = i;
311 |            while ((n < num_crit) & (clean_slope(n) > 0))
312 |                n = n+1;
313 |            end
314 |            clean_loc_peak(i) = clean_energy(n-1);
315 |        else                           % search to the left
316 |            n = i;
317 |            while ((n>0) & (clean_slope(n) <= 0))
318 |                n = n-1;
319 |            end
320 |            clean_loc_peak(i) = clean_energy(n+1);
321 |        end
322 | 
323 |        % find the peaks in the processed speech signal
324 | 
325 |        if (processed_slope(i)>0)      % search to the right
326 |            n = i;
327 |            while ((n < num_crit) & (processed_slope(n) > 0))
328 |                n = n+1;
329 |            end
330 |            processed_loc_peak(i) = processed_energy(n-1);
331 |        else                           % search to the left
332 |            n = i;
333 |            while ((n>0) & (processed_slope(n) <= 0))
334 |                n = n-1;
335 |            end
336 |            processed_loc_peak(i) = processed_energy(n+1);
337 |        end
338 | 
339 |    end
340 | 
341 |    % ----------------------------------------------------------
342 |    % (6) Compute the WSS Measure for this frame.  This
343 |    %     includes determination of the weighting function.
344 |    % ----------------------------------------------------------
345 | 
346 |    dBMax_clean      = max(clean_energy);
347 |    dBMax_processed  = max(processed_energy);
348 | 
349 |    % The weights are calculated by averaging individual
350 |    % weighting factors from the clean and processed frame.
351 |    % These weights W_clean and W_processed should range
352 |    % from 0 to 1 and place more emphasis on spectral
353 |    % peaks and less emphasis on slope differences in spectral
354 |    % valleys.  This procedure is described on page 1280 of
355 |    % Klatt's 1982 ICASSP paper.
356 | 
357 |    Wmax_clean = Kmax ./ (Kmax + dBMax_clean - ...
358 |        clean_energy(1:num_crit-1));
359 |    Wlocmax_clean = Klocmax ./ ( Klocmax + clean_loc_peak - ...
360 |        clean_energy(1:num_crit-1));
361 |    W_clean = Wmax_clean .* Wlocmax_clean;
362 | 
363 |    Wmax_processed = Kmax ./ (Kmax + dBMax_processed - ...
364 |        processed_energy(1:num_crit-1));
365 |    Wlocmax_processed = Klocmax ./ ( Klocmax + processed_loc_peak - ...
366 |        processed_energy(1:num_crit-1));
367 |    W_processed = Wmax_processed .* Wlocmax_processed;
368 | 
369 |    W = (W_clean + W_processed)./2.0;
370 | 
371 |    distortion(frame_count) = sum(W.*(clean_slope(1:num_crit-1) - ...
372 |        processed_slope(1:num_crit-1)).^2);
373 | 
374 |    % this normalization is not part of Klatt's paper, but helps
375 |    % to normalize the measure.  Here we scale the measure by the
376 |    % sum of the weights.
377 | 378 | distortion(frame_count) = distortion(frame_count)/sum(W); 379 | 380 | start = start + skiprate; 381 | 382 | end 383 | 384 | %----------------------------------------------- 385 | function distortion = llr(clean_speech, processed_speech,sample_rate) 386 | 387 | 388 | % ---------------------------------------------------------------------- 389 | % Check the length of the clean and processed speech. Must be the same. 390 | % ---------------------------------------------------------------------- 391 | 392 | clean_length = length(clean_speech); 393 | processed_length = length(processed_speech); 394 | 395 | if (clean_length ~= processed_length) 396 | disp('Error: Both Speech Files must be same length.'); 397 | return 398 | end 399 | 400 | % ---------------------------------------------------------------------- 401 | % Global Variables 402 | % ---------------------------------------------------------------------- 403 | 404 | % sample_rate = 8000; % default sample rate 405 | % winlength = 240; % window length in samples 406 | % skiprate = 60; % window skip in samples 407 | % P = 10; % LPC Analysis Order 408 | winlength = round(30*sample_rate/1000); % window length in samples 409 | skiprate = floor(winlength/4); % window skip in samples 410 | if sample_rate<10000 411 | P = 10; % LPC Analysis Order 412 | else 413 | P=16; % this could vary depending on sampling frequency. 414 | end 415 | 416 | % ---------------------------------------------------------------------- 417 | % For each frame of input speech, calculate the Log Likelihood Ratio 418 | % ---------------------------------------------------------------------- 419 | 420 | num_frames = clean_length/skiprate-(winlength/skiprate); % number of frames 421 | start = 1; % starting sample 422 | window = 0.5*(1 - cos(2*pi*(1:winlength)'/(winlength+1))); 423 | 424 | for frame_count = 1:num_frames 425 | 426 | % ---------------------------------------------------------- 427 | % (1) Get the Frames for the test and reference speech. 428 | % Multiply by Hanning Window. 429 | % ---------------------------------------------------------- 430 | 431 | clean_frame = clean_speech(start:start+winlength-1); 432 | processed_frame = processed_speech(start:start+winlength-1); 433 | clean_frame = clean_frame.*window; 434 | processed_frame = processed_frame.*window; 435 | 436 | % ---------------------------------------------------------- 437 | % (2) Get the autocorrelation lags and LPC parameters used 438 | % to compute the LLR measure. 439 | % ---------------------------------------------------------- 440 | 441 | [R_clean, Ref_clean, A_clean] = ... 442 | lpcoeff(clean_frame, P); 443 | [R_processed, Ref_processed, A_processed] = ... 
444 | lpcoeff(processed_frame, P); 445 | 446 | % ---------------------------------------------------------- 447 | % (3) Compute the LLR measure 448 | % ---------------------------------------------------------- 449 | 450 | numerator = A_processed*toeplitz(R_clean)*A_processed'; 451 | denominator = A_clean*toeplitz(R_clean)*A_clean'; 452 | distortion(frame_count) = log(numerator/denominator); 453 | start = start + skiprate; 454 | 455 | end 456 | 457 | %--------------------------------------------- 458 | function [acorr, refcoeff, lpparams] = lpcoeff(speech_frame, model_order) 459 | 460 | % ---------------------------------------------------------- 461 | % (1) Compute Autocorrelation Lags 462 | % ---------------------------------------------------------- 463 | 464 | winlength = max(size(speech_frame)); 465 | for k=1:model_order+1 466 | R(k) = sum(speech_frame(1:winlength-k+1) ... 467 | .*speech_frame(k:winlength)); 468 | end 469 | 470 | % ---------------------------------------------------------- 471 | % (2) Levinson-Durbin 472 | % ---------------------------------------------------------- 473 | 474 | a = ones(1,model_order); 475 | E(1)=R(1); 476 | for i=1:model_order 477 | a_past(1:i-1) = a(1:i-1); 478 | sum_term = sum(a_past(1:i-1).*R(i:-1:2)); 479 | rcoeff(i)=(R(i+1) - sum_term) / E(i); 480 | a(i)=rcoeff(i); 481 | a(1:i-1) = a_past(1:i-1) - rcoeff(i).*a_past(i-1:-1:1); 482 | E(i+1)=(1-rcoeff(i)*rcoeff(i))*E(i); 483 | end 484 | 485 | acorr = R; 486 | refcoeff = rcoeff; 487 | lpparams = [1 -a]; 488 | 489 | 490 | % ---------------------------------------------------------------------- 491 | 492 | function [overall_snr, segmental_snr] = snr(clean_speech, processed_speech,sample_rate) 493 | 494 | % ---------------------------------------------------------------------- 495 | % Check the length of the clean and processed speech. Must be the same. 496 | % ---------------------------------------------------------------------- 497 | 498 | clean_length = length(clean_speech); 499 | processed_length = length(processed_speech); 500 | 501 | if (clean_length ~= processed_length) 502 | disp('Error: Both Speech Files must be same length.'); 503 | return 504 | end 505 | 506 | % ---------------------------------------------------------------------- 507 | % Scale both clean speech and processed speech to have same dynamic 508 | % range. 
Also remove DC component from each signal 509 | % ---------------------------------------------------------------------- 510 | 511 | %clean_speech = clean_speech - mean(clean_speech); 512 | %processed_speech = processed_speech - mean(processed_speech); 513 | 514 | %processed_speech = processed_speech.*(max(abs(clean_speech))/ max(abs(processed_speech))); 515 | 516 | overall_snr = 10* log10( sum(clean_speech.^2)/sum((clean_speech-processed_speech).^2)); 517 | 518 | % ---------------------------------------------------------------------- 519 | % Global Variables 520 | % ---------------------------------------------------------------------- 521 | 522 | % sample_rate = 8000; % default sample rate 523 | % winlength = 240; % window length in samples 524 | % skiprate = 60; % window skip in samples 525 | winlength = round(30*sample_rate/1000); %240; % window length in samples 526 | skiprate = floor(winlength/4); % window skip in samples 527 | MIN_SNR = -10; % minimum SNR in dB 528 | MAX_SNR = 35; % maximum SNR in dB 529 | 530 | % ---------------------------------------------------------------------- 531 | % For each frame of input speech, calculate the Segmental SNR 532 | % ---------------------------------------------------------------------- 533 | 534 | num_frames = clean_length/skiprate-(winlength/skiprate); % number of frames 535 | start = 1; % starting sample 536 | window = 0.5*(1 - cos(2*pi*(1:winlength)'/(winlength+1))); 537 | 538 | for frame_count = 1: num_frames 539 | 540 | % ---------------------------------------------------------- 541 | % (1) Get the Frames for the test and reference speech. 542 | % Multiply by Hanning Window. 543 | % ---------------------------------------------------------- 544 | 545 | clean_frame = clean_speech(start:start+winlength-1); 546 | processed_frame = processed_speech(start:start+winlength-1); 547 | clean_frame = clean_frame.*window; 548 | processed_frame = processed_frame.*window; 549 | 550 | % ---------------------------------------------------------- 551 | % (2) Compute the Segmental SNR 552 | % ---------------------------------------------------------- 553 | 554 | signal_energy = sum(clean_frame.^2); 555 | noise_energy = sum((clean_frame-processed_frame).^2); 556 | segmental_snr(frame_count) = 10*log10(signal_energy/(noise_energy+eps)+eps); 557 | segmental_snr(frame_count) = max(segmental_snr(frame_count),MIN_SNR); 558 | segmental_snr(frame_count) = min(segmental_snr(frame_count),MAX_SNR); 559 | 560 | start = start + skiprate; 561 | 562 | end 563 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration for train_interface 3 | 4 | You can check the essential information, 5 | and if you want to change model structure or training method, 6 | you have to change this file. 
7 | """ 8 | ####################################################################### 9 | # path # 10 | ####################################################################### 11 | job_dir = './models/' 12 | logs_dir = './logs/' 13 | chkpt_model = None # 'FILE PATH (if you have pretrained model..)' 14 | chkpt = str("EPOCH") 15 | if chkpt_model is not None: 16 | chkpt_path = job_dir + chkpt_model + '/chkpt_' + chkpt + '.pt' 17 | 18 | ####################################################################### 19 | # possible setting # 20 | ####################################################################### 21 | # the list you can do 22 | model_list = ['DCCRN', 'CRN', 'FullSubNet'] 23 | loss_list = ['MSE', 'SDR', 'SI-SNR', 'SI-SDR'] 24 | perceptual_list = [False, 'LMS', 'PMSQE'] 25 | lstm_type = ['real', 'complex'] 26 | main_net = ['LSTM', 'GRU'] 27 | mask_type = ['Direct(None make)', 'E', 'C', 'R'] 28 | 29 | # experiment number setting 30 | expr_num = 'EXPERIMENT_NUMBER' 31 | DEVICE = 'cuda' # if you want to run the code with 'cpu', change 'cpu' 32 | ####################################################################### 33 | # current setting # 34 | ####################################################################### 35 | model = model_list[0] 36 | loss = loss_list[1] 37 | perceptual = perceptual_list[0] 38 | lstm = lstm_type[1] 39 | sequence_model = main_net[0] 40 | 41 | masking_mode = mask_type[1] 42 | skip_type = True # False, if you want to remove 'skip connection' 43 | 44 | # hyper-parameters 45 | max_epochs = 100 46 | learning_rate = 0.001 47 | batch = 10 48 | 49 | # kernel size 50 | dccrn_kernel_num = [32, 64, 128, 256, 256, 256] 51 | ####################################################################### 52 | # model information # 53 | ####################################################################### 54 | fs = 16000 55 | win_len = 400 56 | win_inc = 100 57 | ola_ratio = 0.75 58 | fft_len = 512 59 | sam_sec = fft_len / fs 60 | frm_samp = fs * (fft_len / fs) 61 | window = 'hanning' 62 | 63 | # for DCCRN 64 | rnn_layers = 2 65 | rnn_units = 256 66 | 67 | # for CRN 68 | rnn_input_size = 512 69 | 70 | # for FullSubNet 71 | sb_num_neighbors = 15 72 | fb_num_neighbors = 0 73 | num_freqs = fft_len // 2 + 1 74 | look_ahead = 2 75 | fb_output_activate_function = "ReLU" 76 | sb_output_activate_function = None 77 | fb_model_hidden_size = 512 78 | sb_model_hidden_size = 384 79 | weight_init = False 80 | norm_type = "offline_laplace_norm" 81 | num_groups_in_drop_band = 2 82 | ####################################################################### 83 | # setting error check # 84 | ####################################################################### 85 | # if the setting is wrong, print error message 86 | assert not (masking_mode == 'Direct(None make)' and perceptual is not False), \ 87 | "This setting is not created " 88 | assert not (model == 'FullSubNet' and perceptual is not False), \ 89 | "This setting is not created " 90 | 91 | ####################################################################### 92 | # print setting # 93 | ####################################################################### 94 | print('-------------------- C O N F I G ----------------------') 95 | print('--------------------------------------------------------------') 96 | print('MODEL INFO : {}'.format(model)) 97 | print('LOSS INFO : {}, perceptual : {}'.format(loss, perceptual)) 98 | if model != 'FullSubNet': 99 | print('LSTM : {}'.format(lstm)) 100 | print('SKIP : {}'.format(skip_type)) 101 | print('MASKING 
INFO : {}'.format(masking_mode)) 102 | else: 103 | print('Main network : {}'.format(sequence_model)) 104 | print('\nBATCH : {}'.format(batch)) 105 | print('LEARNING RATE : {}'.format(learning_rate)) 106 | print('--------------------------------------------------------------') 107 | print('--------------------------------------------------------------\n') 108 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import Dataset, DataLoader 4 | import config as cfg 5 | 6 | # # If you don't set the data type to object when saving the data... 7 | # np_load_old = np.load 8 | # np.load = lambda *a, **k: np_load_old(*a, allow_pickle=True, **k) 9 | 10 | 11 | def create_dataloader(mode, type=0, snr=0): 12 | if mode == 'train': 13 | return DataLoader( 14 | dataset=Wave_Dataset(mode, type, snr), 15 | batch_size=cfg.batch, 16 | shuffle=True, 17 | num_workers=0, 18 | pin_memory=True, 19 | drop_last=True, 20 | sampler=None 21 | ) 22 | elif mode == 'valid': 23 | return DataLoader( 24 | dataset=Wave_Dataset(mode, type, snr), 25 | batch_size=cfg.batch, shuffle=False, num_workers=0 26 | ) 27 | elif mode == 'test': 28 | return DataLoader( 29 | dataset=Wave_Dataset(mode, type, snr), 30 | batch_size=cfg.batch, shuffle=False, num_workers=0 31 | ) 32 | 33 | 34 | class Wave_Dataset(Dataset): 35 | def __init__(self, mode, type, snr): 36 | # load data 37 | if mode == 'train': 38 | self.mode = 'train' 39 | print('') 40 | print('Load the data...') 41 | self.input_path = "DATASET_FILE_PATH" 42 | self.input = np.load(self.input_path) 43 | elif mode == 'valid': 44 | self.mode = 'valid' 45 | print('') 46 | print('Load the data...') 47 | self.input_path = "DATASET_FILE_PATH" 48 | self.input = np.load(self.input_path) 49 | # # if you want to use a part of the dataset 50 | # self.input = self.input[:500] 51 | elif mode == 'test': 52 | self.mode = 'test' 53 | print('') 54 | print('Load the data...') 55 | self.input_path = "DATASET_FILE_PATH" 56 | 57 | self.input = np.load(self.input_path) 58 | self.input = self.input[type][snr] 59 | 60 | def __len__(self): 61 | return len(self.input) 62 | 63 | def __getitem__(self, idx): 64 | inputs = self.input[idx][0] 65 | targets = self.input[idx][1] 66 | 67 | # transform to torch from numpy 68 | inputs = torch.from_numpy(inputs) 69 | targets = torch.from_numpy(targets) 70 | 71 | return inputs, targets 72 | -------------------------------------------------------------------------------- /estimation/check_object_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | for checking speech quality with some metrics. 3 | 4 | 1. PESQ 5 | 2. STOI 6 | 3. 
CSIG, CBAK, COVL
 7 | """
 8 | import os
 9 | from tools_for_estimate import cal_pesq, cal_stoi, composite
10 | from pathlib import Path
11 | 
12 | # number of files we want to check
13 | file_num = 1
14 | 
15 | target_wav = ['.wav']
16 | estimated_wav = ['.wav']
17 | 
18 | file_directory = '/'
19 | 
20 | if file_num == 1:
21 |     pesq = cal_pesq(estimated_wav, target_wav)
22 |     stoi = cal_stoi(estimated_wav, target_wav)
23 |     CSIG, CBAK, COVL, _ = composite(target_wav[0], estimated_wav[0])
24 | 
25 |     print('{} is ...'.format(estimated_wav[0]))
26 |     print('PESQ {:.4} | STOI {:.4} | CSIG {:.4} | CBAK {:.4} | COVL {:.4}'
27 |           .format(pesq, stoi, CSIG, CBAK, COVL))
28 | else:
29 |     # the list of files in the file directory
30 |     if os.path.isdir(file_directory) is False:
31 |         print("[Error] There is no directory '%s'." % file_directory)
32 |         exit()
33 |     else:
34 |         print("Scanning a directory %s " % file_directory)
35 | 
36 |     # pick the target wav files from the directory
37 |     target_addr = []
38 |     for path, dirs, files in os.walk(file_directory):
39 |         for file in files:
40 |             if 'target' in file:
41 |                 filepath = Path(path) / file
42 |                 target_addr.append(filepath)
43 | 
44 |     for addr in target_addr:
45 |         estimated_addr = str(addr).replace('target', 'estimated')
46 | 
47 |         pesq = cal_pesq([estimated_addr], [addr])
48 |         stoi = cal_stoi([estimated_addr], [addr])
49 |         CSIG, CBAK, COVL, _ = composite(addr, estimated_addr)
50 | 
51 |         print('{} is ...'.format(estimated_addr))
52 |         print('PESQ {:.4} | STOI {:.4} | CSIG {:.4} | CBAK {:.4} | COVL {:.4}'
53 |               .format(pesq, stoi, CSIG, CBAK, COVL))
54 | 
--------------------------------------------------------------------------------
/generate_noisy_data.py:
--------------------------------------------------------------------------------
 1 | """
 2 | generate noisy data with various noise files
 3 | """
 4 | import os
 5 | import sys
 6 | import numpy as np
 7 | import scipy.io.wavfile as wav
 8 | import librosa
 9 | from pathlib import Path
10 | import soundfile
11 | 
12 | #######################################################################
13 | #                         data info setting                           #
14 | #######################################################################
15 | # USE THIS, OR SYS.ARGVS
16 | # mode = 'train'  # train / validation / test
17 | # snr_set = [0, 5]
18 | # fs = 16000
19 | 
20 | #######################################################################
21 | #                                main                                 #
22 | #######################################################################
23 | def scan_directory(dir_name):
24 |     """Scan a directory and collect the addresses of clean/noisy wav data.
25 |     Args:
26 |         dir_name: directory name to scan
27 |     Returns:
28 |         addr: list of addresses of all clean/noisy wav files in the subdirectories
29 |     """
30 |     if os.path.isdir(dir_name) is False:
31 |         print("[Error] There is no directory '%s'." % dir_name)
32 |         exit()
33 |     else:
34 |         print("Scanning a directory %s " % dir_name)
35 | 
36 |     addr = []
37 |     for subdir, dirs, files in os.walk(dir_name):
38 |         for file in files:
39 |             if file.endswith(".wav"):
40 |                 filepath = Path(subdir) / file
41 |                 addr.append(filepath)
42 |     return addr
43 | 
44 | 
45 | # Generate noisy data given speech, noise, and target SNR.
46 | def generate_noisy_wav(wav_speech, wav_noise, snr):
47 |     # Obtain the lengths of the speech and noise components.
48 |     len_speech = len(wav_speech)
49 |     len_noise = len(wav_noise)
50 | 
51 |     # Randomly select a noise segment with the same length as the speech signal.
52 | st = np.random.randint(0, len_noise - len_speech) 53 | ed = st + len_speech 54 | wav_noise = wav_noise[st:ed] 55 | 56 | # Compute the power of speech and noise after removing DC bias. 57 | dc_speech = np.mean(wav_speech) 58 | dc_noise = np.mean(wav_noise) 59 | pow_speech = np.mean(np.power(wav_speech - dc_speech, 2.0)) 60 | pow_noise = np.mean(np.power(wav_noise - dc_noise, 2.0)) 61 | 62 | # Compute the scale factor of noise component depending on the target SNR. 63 | alpha = np.sqrt(10.0 ** (float(-snr) / 10.0) * pow_speech / (pow_noise + 1e-6)) 64 | noisy_wav = (wav_speech + alpha * wav_noise) * 32768 65 | noisy_wav = noisy_wav.astype(np.int16) 66 | 67 | return noisy_wav 68 | 69 | 70 | def main(): 71 | argvs = sys.argv[1:] 72 | if len(argvs) != 3: 73 | print('Error: Invalid input arguments') 74 | print('\t Usage: python generate_noisy_data.py [mode] [snr] [fs]') 75 | print("\t\t [mode]: 'train', 'validation'") 76 | print("\t\t [snr]: '0', '0, 5', ...'") 77 | print("\t\t [fs]: '16000', ...") 78 | exit() 79 | mode = argvs[0] 80 | snr_set = argvs[1].split(',') 81 | fs = int(argvs[2]) 82 | 83 | # Set speech and noise directory. 84 | speech_dir = Path("./") 85 | 86 | # Make a speech file list. 87 | speech_mode_clean_dir = speech_dir / mode / 'clean' 88 | speech_mode_noisy_dir = speech_dir / mode / 'noisy' 89 | list_speech_files = scan_directory(speech_mode_clean_dir) 90 | 91 | # Make directories of the mode and noisy data. 92 | if os.path.isdir(speech_mode_clean_dir) is False: 93 | os.system('mkdir ' + str(speech_mode_clean_dir)) 94 | 95 | if os.path.isdir(speech_mode_noisy_dir) is False: 96 | os.system('mkdir ' + str(speech_mode_noisy_dir)) 97 | 98 | # Define a log file name. 99 | log_file_name = Path("./log_generate_data_" + mode + ".txt") 100 | f = open(log_file_name, 'w') 101 | 102 | if mode == 'train': 103 | # Make a noise file list 104 | noise_subset_dir = speech_dir / 'train' / 'noise' 105 | list_noise_files = scan_directory(noise_subset_dir) 106 | for snr_in_db in snr_set: 107 | for addr_speech in list_speech_files: 108 | # Load speech waveform and its sampling frequency. 109 | wav_speech, read_fs = soundfile.read(addr_speech) 110 | if read_fs != fs: 111 | wav_speech = librosa.resample(wav_speech, read_fs, fs) 112 | 113 | # Select a noise component randomly, and read it. 114 | nidx = np.random.randint(0, len(list_noise_files)) 115 | addr_noise = list_noise_files[nidx] 116 | wav_noise, read_fs = soundfile.read(addr_noise) 117 | if wav_noise.ndim > 1: 118 | wav_noise = wav_noise.mean(axis=1) 119 | if read_fs != fs: 120 | wav_noise = librosa.resample(wav_noise, read_fs, fs) 121 | 122 | # Generate noisy speech by mixing speech and noise components. 123 | wav_noisy = generate_noisy_wav(wav_speech, wav_noise, int(snr_in_db)) 124 | noisy_name = Path(addr_speech).name[:-4] +'_' + Path(addr_noise).name[:-4] + '_' + str( 125 | int(snr_in_db)) + '.wav' 126 | addr_noisy = speech_mode_noisy_dir / noisy_name 127 | wav.write(addr_noisy, fs, wav_noisy) 128 | 129 | # Display progress. 130 | print('%s > %s' % (addr_speech, addr_noisy)) 131 | f.write('%s\t%s\t%s\t%d dB\n' % (addr_noisy, addr_speech, addr_noise, int(snr_in_db))) 132 | 133 | elif mode == 'validation': 134 | # Make a noise file list for validation. 135 | noise_subset_dir = speech_dir / 'train' / 'noise' 136 | list_noise_files = scan_directory(noise_subset_dir) 137 | 138 | for addr_speech in list_speech_files: 139 | # Load speech waveform and its sampling frequency. 
140 | wav_speech, read_fs = soundfile.read(addr_speech) 141 | if read_fs != fs: 142 | wav_speech = librosa.resample(wav_speech, read_fs, fs) 143 | 144 | # Select a noise component randomly, and read it. 145 | nidx = np.random.randint(0, len(list_noise_files)) 146 | addr_noise = list_noise_files[nidx] 147 | wav_noise, read_fs = soundfile.read(addr_noise) 148 | if wav_noise.ndim > 1: 149 | wav_noise = wav_noise.mean(axis=1) 150 | if read_fs != fs: 151 | wav_noise = librosa.resample(wav_noise, read_fs, fs) 152 | 153 | # Select an SNR randomly. 154 | ridx_snr = np.random.randint(0, len(snr_set)) 155 | snr_in_db = int(snr_set[ridx_snr]) 156 | 157 | # Generate noisy speech by mixing speech and noise components. 158 | wav_noisy = generate_noisy_wav(wav_speech, wav_noise, snr_in_db) 159 | 160 | # Write the generated noisy speech into a file. 161 | noisy_name = Path(addr_speech).name[:-4] + '_' + Path(addr_noise).name[:-4] + '_' + str( 162 | snr_in_db) + '.wav' 163 | addr_noisy = speech_mode_noisy_dir / noisy_name 164 | wav.write(addr_noisy, fs, wav_noisy) 165 | 166 | # Display progress. 167 | print('%s > %s' % (addr_speech, addr_noisy)) 168 | f.write('%s\t%s\t%s\t%d dB\n' % (addr_noisy, addr_speech, addr_noise, snr_in_db)) 169 | f.close() 170 | 171 | 172 | if __name__ == '__main__': 173 | main() 174 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from tools_for_model import ConvSTFT, ConviSTFT, \ 5 | ComplexConv2d, ComplexConvTranspose2d, NavieComplexLSTM, complex_cat, ComplexBatchNorm, \ 6 | RealConv2d, RealConvTranspose2d, \ 7 | BaseModel, SequenceModel 8 | import config as cfg 9 | from tools_for_loss import sdr, si_sdr, si_snr, get_array_lms_loss, get_array_pmsqe_loss 10 | 11 | 12 | ####################################################################### 13 | # complex network # 14 | ####################################################################### 15 | class DCCRN(nn.Module): 16 | 17 | def __init__( 18 | self, 19 | rnn_layers=cfg.rnn_layers, 20 | rnn_units=cfg.rnn_units, 21 | win_len=cfg.win_len, 22 | win_inc=cfg.win_inc, 23 | fft_len=cfg.fft_len, 24 | win_type=cfg.window, 25 | masking_mode=cfg.masking_mode, 26 | use_cbn=False, 27 | kernel_size=5 28 | ): 29 | ''' 30 | rnn_layers: the number of lstm layers in the crn, 31 | rnn_units: for clstm, rnn_units = real+imag 32 | ''' 33 | 34 | super(DCCRN, self).__init__() 35 | 36 | # for fft 37 | self.win_len = win_len 38 | self.win_inc = win_inc 39 | self.fft_len = fft_len 40 | self.win_type = win_type 41 | 42 | input_dim = win_len 43 | output_dim = win_len 44 | 45 | self.rnn_units = rnn_units 46 | self.input_dim = input_dim 47 | self.output_dim = output_dim 48 | self.hidden_layers = rnn_layers 49 | self.kernel_size = kernel_size 50 | kernel_num = cfg.dccrn_kernel_num 51 | self.kernel_num = [2] + kernel_num 52 | self.masking_mode = masking_mode 53 | 54 | # bidirectional=True 55 | bidirectional = False 56 | fac = 2 if bidirectional else 1 57 | 58 | fix = True 59 | self.fix = fix 60 | self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix) 61 | self.istft = ConviSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix) 62 | 63 | self.encoder = nn.ModuleList() 64 | self.decoder = nn.ModuleList() 65 | for idx in range(len(self.kernel_num) - 1): 66 | self.encoder.append( 67 | 
nn.Sequential( 68 | # nn.ConstantPad2d([0, 0, 0, 0], 0), 69 | ComplexConv2d( 70 | self.kernel_num[idx], 71 | self.kernel_num[idx + 1], 72 | kernel_size=(self.kernel_size, 2), 73 | stride=(2, 1), 74 | padding=(2, 1) 75 | ), 76 | nn.BatchNorm2d(self.kernel_num[idx + 1]) if not use_cbn else ComplexBatchNorm( 77 | self.kernel_num[idx + 1]), 78 | nn.PReLU() 79 | ) 80 | ) 81 | hidden_dim = self.fft_len // (2 ** (len(self.kernel_num))) 82 | 83 | if cfg.lstm == 'complex': 84 | rnns = [] 85 | for idx in range(rnn_layers): 86 | rnns.append( 87 | NavieComplexLSTM( 88 | input_size=hidden_dim * self.kernel_num[-1] if idx == 0 else self.rnn_units, 89 | hidden_size=self.rnn_units, 90 | bidirectional=bidirectional, 91 | batch_first=False, 92 | projection_dim=hidden_dim * self.kernel_num[-1] if idx == rnn_layers - 1 else None, 93 | ) 94 | ) 95 | self.enhance = nn.Sequential(*rnns) 96 | else: 97 | self.enhance = nn.LSTM( 98 | input_size=hidden_dim * self.kernel_num[-1], 99 | hidden_size=self.rnn_units, 100 | num_layers=2, 101 | dropout=0.0, 102 | bidirectional=bidirectional, 103 | batch_first=False 104 | ) 105 | self.tranform = nn.Linear(self.rnn_units * fac, hidden_dim * self.kernel_num[-1]) 106 | 107 | if cfg.skip_type: 108 | for idx in range(len(self.kernel_num) - 1, 0, -1): 109 | if idx != 1: 110 | self.decoder.append( 111 | nn.Sequential( 112 | ComplexConvTranspose2d( 113 | self.kernel_num[idx] * 2, 114 | self.kernel_num[idx - 1], 115 | kernel_size=(self.kernel_size, 2), 116 | stride=(2, 1), 117 | padding=(2, 0), 118 | output_padding=(1, 0) 119 | ), 120 | nn.BatchNorm2d(self.kernel_num[idx - 1]) if not use_cbn else ComplexBatchNorm( 121 | self.kernel_num[idx - 1]), 122 | nn.PReLU() 123 | ) 124 | ) 125 | else: 126 | self.decoder.append( 127 | nn.Sequential( 128 | ComplexConvTranspose2d( 129 | self.kernel_num[idx] * 2, 130 | self.kernel_num[idx - 1], 131 | kernel_size=(self.kernel_size, 2), 132 | stride=(2, 1), 133 | padding=(2, 0), 134 | output_padding=(1, 0) 135 | ), 136 | ) 137 | ) 138 | else: # you can erase the skip connection 139 | for idx in range(len(self.kernel_num) - 1, 0, -1): 140 | if idx != 1: 141 | self.decoder.append( 142 | nn.Sequential( 143 | ComplexConvTranspose2d( 144 | self.kernel_num[idx], 145 | self.kernel_num[idx - 1], 146 | kernel_size=(self.kernel_size, 2), 147 | stride=(2, 1), 148 | padding=(2, 0), 149 | output_padding=(1, 0) 150 | ), 151 | nn.BatchNorm2d(self.kernel_num[idx - 1]) if not use_cbn else ComplexBatchNorm( 152 | self.kernel_num[idx - 1]), 153 | # nn.ELU() 154 | nn.PReLU() 155 | ) 156 | ) 157 | else: 158 | self.decoder.append( 159 | nn.Sequential( 160 | ComplexConvTranspose2d( 161 | self.kernel_num[idx], 162 | self.kernel_num[idx - 1], 163 | kernel_size=(self.kernel_size, 2), 164 | stride=(2, 1), 165 | padding=(2, 0), 166 | output_padding=(1, 0) 167 | ), 168 | ) 169 | ) 170 | self.flatten_parameters() 171 | 172 | def flatten_parameters(self): 173 | if isinstance(self.enhance, nn.LSTM): 174 | self.enhance.flatten_parameters() 175 | 176 | def forward(self, inputs, targets=0): 177 | specs = self.stft(inputs) 178 | real = specs[:, :self.fft_len // 2 + 1] 179 | imag = specs[:, self.fft_len // 2 + 1:] 180 | spec_mags = torch.sqrt(real ** 2 + imag ** 2 + 1e-8) 181 | 182 | spec_phase = torch.atan2(imag, real) 183 | cspecs = torch.stack([real, imag], 1) 184 | cspecs = cspecs[:, :, 1:] 185 | ''' 186 | means = torch.mean(cspecs, [1,2,3], keepdim=True) 187 | std = torch.std(cspecs, [1,2,3], keepdim=True ) 188 | normed_cspecs = (cspecs-means)/(std+1e-8) 189 | out = normed_cspecs 
190 | ''' 191 | 192 | out = cspecs 193 | encoder_out = [] 194 | 195 | for idx, layer in enumerate(self.encoder): 196 | out = layer(out) 197 | # print('encoder', out.size()) 198 | encoder_out.append(out) 199 | 200 | batch_size, channels, dims, lengths = out.size() 201 | out = out.permute(3, 0, 1, 2) 202 | if cfg.lstm == 'complex': 203 | r_rnn_in = out[:, :, :channels // 2] 204 | i_rnn_in = out[:, :, channels // 2:] 205 | r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, channels // 2 * dims]) 206 | i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, channels // 2 * dims]) 207 | 208 | r_rnn_in, i_rnn_in = self.enhance([r_rnn_in, i_rnn_in]) 209 | 210 | r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, channels // 2, dims]) 211 | i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, channels // 2, dims]) 212 | out = torch.cat([r_rnn_in, i_rnn_in], 2) 213 | else: 214 | # to [L, B, C, D] 215 | out = torch.reshape(out, [lengths, batch_size, channels * dims]) 216 | out, _ = self.enhance(out) 217 | out = self.tranform(out) 218 | out = torch.reshape(out, [lengths, batch_size, channels, dims]) 219 | 220 | out = out.permute(1, 2, 3, 0) 221 | 222 | if cfg.skip_type: # use skip connection 223 | for idx in range(len(self.decoder)): 224 | out = complex_cat([out, encoder_out[-1 - idx]], 1) 225 | out = self.decoder[idx](out) 226 | out = out[..., 1:] # 227 | else: 228 | for idx in range(len(self.decoder)): 229 | out = self.decoder[idx](out) 230 | out = out[..., 1:] 231 | 232 | if self.masking_mode == 'Direct(None make)': 233 | # for loss calculation 234 | target_specs = self.stft(targets) 235 | target_real = target_specs[:, :self.fft_len // 2 + 1] 236 | target_imag = target_specs[:, self.fft_len // 2 + 1:] 237 | 238 | # spectral mapping 239 | out_real = out[:, 0] 240 | out_imag = out[:, 1] 241 | out_real = F.pad(out_real, [0, 0, 1, 0]) 242 | out_imag = F.pad(out_imag, [0, 0, 1, 0]) 243 | 244 | out_spec = torch.cat([out_real, out_imag], 1) 245 | 246 | out_wav = self.istft(out_spec) 247 | out_wav = torch.squeeze(out_wav, 1) 248 | out_wav = torch.clamp_(out_wav, -1, 1) 249 | 250 | return out_real, target_real, out_imag, target_imag, out_wav 251 | else: 252 | # print('decoder', out.size()) 253 | mask_real = out[:, 0] 254 | mask_imag = out[:, 1] 255 | mask_real = F.pad(mask_real, [0, 0, 1, 0]) 256 | mask_imag = F.pad(mask_imag, [0, 0, 1, 0]) 257 | 258 | if self.masking_mode == 'E': 259 | mask_mags = (mask_real ** 2 + mask_imag ** 2) ** 0.5 260 | real_phase = mask_real / (mask_mags + 1e-8) 261 | imag_phase = mask_imag / (mask_mags + 1e-8) 262 | mask_phase = torch.atan2( 263 | imag_phase, 264 | real_phase 265 | ) 266 | 267 | # mask_mags = torch.clamp_(mask_mags,0,100) 268 | mask_mags = torch.tanh(mask_mags) 269 | est_mags = mask_mags * spec_mags 270 | est_phase = spec_phase + mask_phase 271 | out_real = est_mags * torch.cos(est_phase) 272 | out_imag = est_mags * torch.sin(est_phase) 273 | elif self.masking_mode == 'C': 274 | out_real, out_imag = real * mask_real - imag * mask_imag, real * mask_imag + imag * mask_real 275 | elif self.masking_mode == 'R': 276 | out_real, out_imag = real * mask_real, imag * mask_imag 277 | 278 | out_spec = torch.cat([out_real, out_imag], 1) 279 | 280 | out_wav = self.istft(out_spec) 281 | out_wav = torch.squeeze(out_wav, 1) 282 | out_wav = torch.clamp_(out_wav, -1, 1) 283 | 284 | return out_real, out_imag, out_wav 285 | 286 | def get_params(self, weight_decay=0.0): 287 | # add L2 penalty 288 | weights, biases = [], [] 289 | for name, param in 
self.named_parameters(): 290 | if 'bias' in name: 291 | biases += [param] 292 | else: 293 | weights += [param] 294 | params = [{ 295 | 'params': weights, 296 | 'weight_decay': weight_decay, 297 | }, { 298 | 'params': biases, 299 | 'weight_decay': 0.0, 300 | }] 301 | return params 302 | 303 | def loss(self, estimated, target, real_spec=0, img_spec=0, perceptual=False): 304 | if perceptual: 305 | if cfg.perceptual == 'LMS': 306 | clean_specs = self.stft(target) 307 | clean_real = clean_specs[:, :self.fft_len // 2 + 1] 308 | clean_imag = clean_specs[:, self.fft_len // 2 + 1:] 309 | clean_mags = torch.sqrt(clean_real ** 2 + clean_imag ** 2 + 1e-7) 310 | 311 | est_clean_mags = torch.sqrt(real_spec ** 2 + img_spec ** 2 + 1e-7) 312 | return get_array_lms_loss(clean_mags, est_clean_mags) 313 | elif cfg.perceptual == 'PMSQE': 314 | return get_array_pmsqe_loss(target, estimated) 315 | else: 316 | if cfg.loss == 'MSE': 317 | return F.mse_loss(estimated, target, reduction='mean') 318 | elif cfg.loss == 'SDR': 319 | return -sdr(target, estimated) 320 | elif cfg.loss == 'SI-SNR': 321 | return -(si_snr(estimated, target)) 322 | elif cfg.loss == 'SI-SDR': 323 | return -(si_sdr(target, estimated)) 324 | 325 | 326 | ####################################################################### 327 | # real network # 328 | ####################################################################### 329 | class CRN(nn.Module): 330 | def __init__( 331 | self, 332 | rnn_layers=cfg.rnn_layers, 333 | rnn_input_size=cfg.rnn_input_size, 334 | rnn_units=cfg.rnn_units, 335 | win_len=cfg.win_len, 336 | win_inc=cfg.win_inc, 337 | fft_len=cfg.fft_len, 338 | win_type=cfg.window, 339 | masking_mode=cfg.masking_mode, 340 | kernel_size=5 341 | ): 342 | ''' 343 | rnn_layers: the number of lstm layers in the crn 344 | ''' 345 | 346 | super(CRN, self).__init__() 347 | 348 | # for fft 349 | self.win_len = win_len 350 | self.win_inc = win_inc 351 | self.fft_len = fft_len 352 | self.win_type = win_type 353 | 354 | input_dim = win_len 355 | output_dim = win_len 356 | 357 | self.rnn_input_size = rnn_input_size 358 | self.rnn_units = rnn_units//2 359 | self.input_dim = input_dim 360 | self.output_dim = output_dim 361 | self.hidden_layers = rnn_layers 362 | self.kernel_size = kernel_size 363 | kernel_num = cfg.dccrn_kernel_num 364 | self.kernel_num = [2] + kernel_num 365 | self.masking_mode = masking_mode 366 | 367 | # bidirectional=True 368 | bidirectional = False 369 | 370 | self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'real') 371 | self.istft = ConviSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex') 372 | 373 | self.encoder = nn.ModuleList() 374 | self.decoder = nn.ModuleList() 375 | for idx in range(len(self.kernel_num) - 1): 376 | self.encoder.append( 377 | nn.Sequential( 378 | RealConv2d( 379 | self.kernel_num[idx] // 2, 380 | self.kernel_num[idx + 1] // 2, 381 | kernel_size=(self.kernel_size, 2), 382 | stride=(2, 1), 383 | padding=(2, 1) 384 | ), 385 | nn.BatchNorm2d(self.kernel_num[idx + 1] // 2), 386 | nn.PReLU() 387 | ) 388 | ) 389 | hidden_dim = self.fft_len // (2 ** (len(self.kernel_num))) 390 | 391 | self.enhance = nn.LSTM( 392 | input_size=self.rnn_input_size, 393 | hidden_size=self.rnn_units, 394 | dropout=0.0, 395 | bidirectional=bidirectional, 396 | batch_first=False 397 | ) 398 | self.tranform = nn.Linear(self.rnn_units, self.rnn_input_size) 399 | 400 | if cfg.skip_type: 401 | for idx in range(len(self.kernel_num) - 1, 0, -1): 402 | if idx != 1: 403 | self.decoder.append( 404 | 
nn.Sequential( 405 | RealConvTranspose2d( 406 | self.kernel_num[idx], 407 | self.kernel_num[idx - 1] // 2, 408 | kernel_size=(self.kernel_size, 2), 409 | stride=(2, 1), 410 | padding=(2, 0), 411 | output_padding=(1, 0) 412 | ), 413 | nn.BatchNorm2d(self.kernel_num[idx - 1] // 2), 414 | nn.PReLU() 415 | ) 416 | ) 417 | else: 418 | self.decoder.append( 419 | nn.Sequential( 420 | RealConvTranspose2d( 421 | self.kernel_num[idx], 422 | self.kernel_num[idx - 1] // 2, 423 | kernel_size=(self.kernel_size, 2), 424 | stride=(2, 1), 425 | padding=(2, 0), 426 | output_padding=(1, 0) 427 | ), 428 | ) 429 | ) 430 | else: 431 | for idx in range(len(self.kernel_num) - 1, 0, -1): 432 | if idx != 1: 433 | self.decoder.append( 434 | nn.Sequential( 435 | nn.ConvTranspose2d( 436 | self.kernel_num[idx], 437 | self.kernel_num[idx - 1], 438 | kernel_size=(self.kernel_size, 2), 439 | stride=(2, 1), 440 | padding=(2, 0), 441 | output_padding=(1, 0) 442 | ), 443 | nn.BatchNorm2d(self.kernel_num[idx - 1]), 444 | # nn.ELU() 445 | nn.PReLU() 446 | ) 447 | ) 448 | else: 449 | self.decoder.append( 450 | nn.Sequential( 451 | nn.ConvTranspose2d( 452 | self.kernel_num[idx], 453 | self.kernel_num[idx - 1], 454 | kernel_size=(self.kernel_size, 2), 455 | stride=(2, 1), 456 | padding=(2, 0), 457 | output_padding=(1, 0) 458 | ), 459 | ) 460 | ) 461 | self.flatten_parameters() 462 | 463 | def flatten_parameters(self): 464 | if isinstance(self.enhance, nn.LSTM): 465 | self.enhance.flatten_parameters() 466 | 467 | def forward(self, inputs, targets=0): 468 | mags, phase = self.stft(inputs) 469 | 470 | out = mags 471 | out = out.unsqueeze(1) 472 | out = out[:, :, 1:] 473 | encoder_out = [] 474 | 475 | for idx, layer in enumerate(self.encoder): 476 | out = layer(out) 477 | # print('encoder', out.size()) 478 | encoder_out.append(out) 479 | 480 | batch_size, channels, dims, lengths = out.size() 481 | out = out.permute(3, 0, 1, 2) 482 | 483 | rnn_in = torch.reshape(out, [lengths, batch_size, channels * dims]) 484 | out, _ = self.enhance(rnn_in) 485 | out = self.tranform(out) 486 | out = torch.reshape(out, [lengths, batch_size, channels, dims]) 487 | 488 | out = out.permute(1, 2, 3, 0) 489 | 490 | if cfg.skip_type: # use skip connection 491 | for idx in range(len(self.decoder)): 492 | out = torch.cat([out, encoder_out[-1 - idx]], 1) 493 | out = self.decoder[idx](out) 494 | out = out[..., 1:] # 495 | else: 496 | for idx in range(len(self.decoder)): 497 | out = self.decoder[idx](out) 498 | out = out[..., 1:] 499 | 500 | # mask_mags = F.pad(out, [0, 0, 1, 0]) 501 | out = out.squeeze(1) 502 | out = F.pad(out, [0, 0, 1, 0]) 503 | 504 | # for loss calculation 505 | target_mags, _ = self.stft(targets) 506 | 507 | if self.masking_mode == 'Direct(None make)': # spectral mapping 508 | out_real = out * torch.cos(phase) 509 | out_imag = out * torch.sin(phase) 510 | 511 | out_spec = torch.cat([out_real, out_imag], 1) 512 | 513 | out_wav = self.istft(out_spec) 514 | out_wav = torch.squeeze(out_wav, 1) 515 | out_wav = torch.clamp_(out_wav, -1, 1) 516 | 517 | return out, target_mags, out_wav 518 | else: # T-F masking 519 | # mask_mags = torch.clamp_(mask_mags,0,100) 520 | # out = F.pad(out, [0, 0, 1, 0]) 521 | mask_mags = torch.tanh(out) 522 | est_mags = mask_mags * mags 523 | out_real = est_mags * torch.cos(phase) 524 | out_imag = est_mags * torch.sin(phase) 525 | 526 | out_spec = torch.cat([out_real, out_imag], 1) 527 | 528 | out_wav = self.istft(out_spec) 529 | out_wav = torch.squeeze(out_wav, 1) 530 | out_wav = torch.clamp_(out_wav, -1, 1) 531 | 532 
| return est_mags, target_mags, out_wav 533 | 534 | def get_params(self, weight_decay=0.0): 535 | # add L2 penalty 536 | weights, biases = [], [] 537 | for name, param in self.named_parameters(): 538 | if 'bias' in name: 539 | biases += [param] 540 | else: 541 | weights += [param] 542 | params = [{ 543 | 'params': weights, 544 | 'weight_decay': weight_decay, 545 | }, { 546 | 'params': biases, 547 | 'weight_decay': 0.0, 548 | }] 549 | return params 550 | 551 | def loss(self, estimated, target, out_mags=0, target_mags=0, perceptual=False): 552 | if perceptual: 553 | if cfg.perceptual == 'LMS': 554 | return get_array_lms_loss(target_mags, out_mags) 555 | elif cfg.perceptual == 'PMSQE': 556 | return get_array_pmsqe_loss(target, estimated) 557 | else: 558 | if cfg.loss == 'MSE': 559 | return F.mse_loss(estimated, target, reduction='mean') 560 | elif cfg.loss == 'SDR': 561 | return -sdr(target, estimated) 562 | elif cfg.loss == 'SI-SNR': 563 | return -(si_snr(estimated, target)) 564 | elif cfg.loss == 'SI-SDR': 565 | return -(si_sdr(target, estimated)) 566 | 567 | 568 | class FullSubNet(BaseModel): 569 | def __init__(self, 570 | sb_num_neighbors=cfg.sb_num_neighbors, 571 | fb_num_neighbors=cfg.fb_num_neighbors, 572 | num_freqs=cfg.num_freqs, 573 | look_ahead=cfg.look_ahead, 574 | sequence_model=cfg.sequence_model, 575 | fb_output_activate_function=cfg.fb_output_activate_function, 576 | sb_output_activate_function=cfg.sb_output_activate_function, 577 | fb_model_hidden_size=cfg.fb_model_hidden_size, 578 | sb_model_hidden_size=cfg.sb_model_hidden_size, 579 | weight_init=cfg.weight_init, 580 | norm_type=cfg.norm_type, 581 | ): 582 | """ 583 | FullSubNet model (cIRM mask) 584 | 585 | Args: 586 | num_freqs: Frequency dim of the input 587 | look_ahead: Number of use of the future frames 588 | fb_num_neighbors: How much neighbor frequencies at each side from fullband model's output 589 | sb_num_neighbors: How much neighbor frequencies at each side from noisy spectrogram 590 | sequence_model: Chose one sequence model as the basic model e.g., GRU, LSTM 591 | fb_output_activate_function: fullband model's activation function 592 | sb_output_activate_function: subband model's activation function 593 | norm_type: type of normalization, see more details in "BaseModel" class 594 | """ 595 | super().__init__() 596 | assert sequence_model in ("GRU", "LSTM"), f"{self.__class__.__name__} only support GRU and LSTM." 
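        # The two sub-models built below follow the FullSubNet design: the fullband
        # model maps each normalized magnitude frame (num_freqs bins) to a fullband
        # embedding of the same size, and the subband model predicts a 2-channel
        # (real/imag) cIRM per frequency from (sb_num_neighbors * 2 + 1) noisy bins
        # plus (fb_num_neighbors * 2 + 1) fullband-output bins.
        # Minimal instantiation sketch; the argument values are illustrative
        # assumptions, not necessarily the defaults in config.py:
        #   >>> net = FullSubNet(sb_num_neighbors=15, fb_num_neighbors=0,
        #   ...                  num_freqs=257, look_ahead=2,
        #   ...                  sequence_model="LSTM",
        #   ...                  fb_output_activate_function="ReLU",
        #   ...                  sb_output_activate_function=None,
        #   ...                  fb_model_hidden_size=512,
        #   ...                  sb_model_hidden_size=384,
        #   ...                  weight_init=False,
        #   ...                  norm_type="offline_laplace_norm")
        #   >>> mask = net(torch.rand(2, 1, 257, 200))  # predicted cIRM (see forward() for the exact output layout)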
597 | 598 | self.fb_model = SequenceModel( 599 | input_size=num_freqs, 600 | output_size=num_freqs, 601 | hidden_size=fb_model_hidden_size, 602 | num_layers=2, 603 | bidirectional=False, 604 | sequence_model=sequence_model, 605 | output_activate_function=fb_output_activate_function 606 | ) 607 | 608 | self.sb_model = SequenceModel( 609 | input_size=(sb_num_neighbors * 2 + 1) + (fb_num_neighbors * 2 + 1), 610 | output_size=2, 611 | hidden_size=sb_model_hidden_size, 612 | num_layers=2, 613 | bidirectional=False, 614 | sequence_model=sequence_model, 615 | output_activate_function=sb_output_activate_function 616 | ) 617 | 618 | self.sb_num_neighbors = sb_num_neighbors 619 | self.fb_num_neighbors = fb_num_neighbors 620 | self.look_ahead = look_ahead 621 | self.norm = self.norm_wrapper(norm_type) 622 | 623 | if weight_init: 624 | self.apply(self.weight_init) 625 | 626 | def forward(self, noisy_mag): 627 | """ 628 | Args: 629 | noisy_mag: noisy magnitude spectrogram 630 | 631 | Returns: 632 | The real part and imag part of the enhanced spectrogram 633 | 634 | Shapes: 635 | noisy_mag: [B, 1, F, T] 636 | return: [B, 2, F, T] 637 | """ 638 | if not noisy_mag.dim() == 4: 639 | noisy_mag = noisy_mag.unsqueeze(1) 640 | noisy_mag = F.pad(noisy_mag, [0, self.look_ahead]) # Pad the look ahead 641 | batch_size, num_channels, num_freqs, num_frames = noisy_mag.size() 642 | assert num_channels == 1, f"{self.__class__.__name__} takes the mag feature as inputs." 643 | 644 | # Fullband model 645 | fb_input = self.norm(noisy_mag).reshape(batch_size, num_channels * num_freqs, num_frames) 646 | fb_output = self.fb_model(fb_input).reshape(batch_size, 1, num_freqs, num_frames) 647 | 648 | # Unfold fullband model's output, [B, N=F, C, F_f, T]. N is the number of sub-band units 649 | fb_output_unfolded = self.unfold(fb_output, num_neighbor=self.fb_num_neighbors) 650 | fb_output_unfolded = fb_output_unfolded.reshape(batch_size, num_freqs, self.fb_num_neighbors * 2 + 1, num_frames) 651 | 652 | # Unfold noisy spectrogram, [B, N=F, C, F_s, T] 653 | noisy_mag_unfolded = self.unfold(noisy_mag, num_neighbor=self.sb_num_neighbors) 654 | noisy_mag_unfolded = noisy_mag_unfolded.reshape(batch_size, num_freqs, self.sb_num_neighbors * 2 + 1, num_frames) 655 | 656 | # Concatenation, [B, F, (F_s + F_f), T] 657 | sb_input = torch.cat([noisy_mag_unfolded, fb_output_unfolded], dim=2) 658 | sb_input = self.norm(sb_input) 659 | 660 | sb_input = sb_input.reshape( 661 | batch_size * num_freqs, 662 | (self.sb_num_neighbors * 2 + 1) + (self.fb_num_neighbors * 2 + 1), 663 | num_frames 664 | ) 665 | 666 | # [B * F, (F_s + F_f), T] => [B * F, 2, T] => [B, F, 2, T] 667 | sb_mask = self.sb_model(sb_input) 668 | sb_mask = sb_mask.reshape(batch_size, num_freqs, 2, num_frames).permute(0, 2, 1, 3).contiguous() 669 | 670 | output = sb_mask[:, :, :, self.look_ahead:] 671 | output = output.permute(0, 2, 3, 1) 672 | return output 673 | 674 | def loss(self, estimated, target): 675 | if cfg.loss == 'MSE': 676 | return F.mse_loss(estimated, target, reduction='mean') 677 | elif cfg.loss == 'SDR': 678 | return -sdr(target, estimated) 679 | elif cfg.loss == 'SI-SNR': 680 | return -(si_snr(estimated, target)) 681 | elif cfg.loss == 'SI-SDR': 682 | return -(si_sdr(target, estimated)) 683 | 684 | -------------------------------------------------------------------------------- /tools_for_estimate.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | from pesq import pesq 4 | import numpy as np 5 | 
import ctypes 6 | import logging 7 | import oct2py 8 | from scipy.io import wavfile 9 | from pystoi import stoi 10 | import config as cfg 11 | 12 | 13 | ############################################################################ 14 | # MOS # 15 | ############################################################################ 16 | # Reference 17 | # https://github.com/usimarit/semetrics # https://ecs.utdallas.edu/loizou/speech/software.htm 18 | logging.basicConfig(level=logging.ERROR) 19 | oc = oct2py.Oct2Py(logger=logging.getLogger()) 20 | 21 | COMPOSITE = os.path.join(os.path.abspath(os.path.dirname(__file__)), "composite.m") 22 | 23 | 24 | def composite(clean: str, enhanced: str): 25 | pesq_score = pesq_mos(clean, enhanced) 26 | csig, cbak, covl, ssnr = oc.feval(COMPOSITE, clean, enhanced, nout=4) 27 | csig += 0.603 * pesq_score 28 | cbak += 0.478 * pesq_score 29 | covl += 0.805 * pesq_score 30 | return csig, cbak, covl, ssnr 31 | 32 | 33 | ############################################################################ 34 | # PESQ # 35 | ############################################################################ 36 | # Reference 37 | # https://github.com/usimarit/semetrics 38 | # https://ecs.utdallas.edu/loizou/speech/software.htm 39 | 40 | def pesq_mos(clean: str, enhanced: str): 41 | sr1, clean_wav = wavfile.read(clean) 42 | sr2, enhanced_wav = wavfile.read(enhanced) 43 | assert sr1 == sr2 44 | mode = "nb" if sr1 < 16000 else "wb" 45 | return pesq(sr1, clean_wav, enhanced_wav, mode) 46 | 47 | 48 | ############################################################################### 49 | # PESQ (another ref) # 50 | ############################################################################### 51 | pesq_dll = ctypes.CDLL('./PESQ.so') 52 | pesq_dll.pesq.restype = ctypes.c_double 53 | 54 | 55 | # interface to PESQ evaluation, taking in two filenames as input 56 | def run_pesq_filenames(clean, to_eval): 57 | pesq_regex = re.compile("\(MOS-LQO\): = ([0-9]+\.[0-9]+)") 58 | 59 | pesq_out = os.popen("./PESQ" + cfg.fs + "wb " + clean + " " + to_eval).read() 60 | regex_result = pesq_regex.search(pesq_out) 61 | 62 | if (regex_result is None): 63 | return 0.0 64 | else: 65 | return float(regex_result.group(1)) 66 | 67 | 68 | def run_pesq_waveforms(dirty_wav, clean_wav): 69 | clean_wav = clean_wav.astype(np.double) 70 | dirty_wav = dirty_wav.astype(np.double) 71 | # return pesq(clean_wav, dirty_wav, fs=8000) 72 | return pesq_dll.pesq(ctypes.c_void_p(clean_wav.ctypes.data), 73 | ctypes.c_void_p(dirty_wav.ctypes.data), 74 | len(clean_wav), 75 | len(dirty_wav)) 76 | 77 | 78 | # interface to PESQ evaluation, taking in two waveforms as input 79 | def cal_pesq(dirty_wavs, clean_wavs): 80 | scores = [] 81 | for i in range(len(dirty_wavs)): 82 | pesq = run_pesq_waveforms(dirty_wavs[i], clean_wavs[i]) 83 | scores.append(pesq) 84 | return scores 85 | 86 | 87 | ############################################################################### 88 | # STOI # 89 | ############################################################################### 90 | def cal_stoi(estimated_speechs, clean_speechs): 91 | stoi_scores = [] 92 | for i in range(len(estimated_speechs)): 93 | stoi_score = stoi(clean_speechs[i], estimated_speechs[i], cfg.fs, extended=False) 94 | stoi_scores.append(stoi_score) 95 | return stoi_scores 96 | 97 | 98 | ############################################################################### 99 | # SNR # 100 | ############################################################################### 101 | def cal_snr(s1, s2, 
eps=1e-8): 102 | signal = s2 103 | mean_signal = np.mean(signal) 104 | signal_diff = signal - mean_signal 105 | var_signal = np.sum(np.mean(signal_diff ** 2)) # # variance of orignal data 106 | 107 | noisy_signal = s1 108 | noise = noisy_signal - signal 109 | mean_noise = np.mean(noise) 110 | noise_diff = noise - mean_noise 111 | var_noise = np.sum(np.mean(noise_diff ** 2)) # # variance of noise 112 | 113 | if var_noise == 0: 114 | snr_score = 100 # # clean 115 | else: 116 | snr_score = (np.log10(var_signal/var_noise + eps))*10 117 | return snr_score 118 | 119 | 120 | def cal_snr_array(estimated_speechs, clean_speechs): 121 | snr_score = [] 122 | for i in range(len(estimated_speechs)): 123 | snr = cal_snr(estimated_speechs[i], clean_speechs[i]) 124 | snr_score.append(snr) 125 | return snr_score 126 | -------------------------------------------------------------------------------- /tools_for_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | import config as cfg 5 | from asteroid.losses import SingleSrcPMSQE, PITLossWrapper 6 | from asteroid_filterbanks import STFTFB, Encoder, transforms 7 | 8 | ############################################################################ 9 | # for model structure & loss function # 10 | ############################################################################ 11 | def remove_dc(data): 12 | mean = torch.mean(data, -1, keepdim=True) 13 | data = data - mean 14 | return data 15 | 16 | 17 | def l2_norm(s1, s2): 18 | norm = torch.sum(s1 * s2, -1, keepdim=True) 19 | return norm 20 | 21 | 22 | def sdr_linear(s1, s2, eps=1e-8): 23 | sn = l2_norm(s1, s1) 24 | sn_m_shn = l2_norm(s1 - s2, s1 - s2) 25 | sdr_loss = sn**2 / (sn_m_shn**2 + eps) 26 | return torch.mean(sdr_loss) 27 | 28 | 29 | def sdr(s1, s2, eps=1e-8): 30 | sn = l2_norm(s1, s1) 31 | sn_m_shn = l2_norm(s1 - s2, s1 - s2) 32 | sdr_loss = 10 * torch.log10(sn**2 / (sn_m_shn**2 + eps)) 33 | return torch.mean(sdr_loss) 34 | 35 | 36 | def si_snr(s1, s2, eps=1e-8): 37 | s1_s2_norm = l2_norm(s1, s2) 38 | s2_s2_norm = l2_norm(s2, s2) 39 | s_target = s1_s2_norm / (s2_s2_norm + eps) * s2 40 | e_nosie = s1 - s_target 41 | target_norm = l2_norm(s_target, s_target) 42 | noise_norm = l2_norm(e_nosie, e_nosie) 43 | snr = 10 * torch.log10((target_norm) / (noise_norm + eps) + eps) 44 | return torch.mean(snr) 45 | 46 | 47 | def si_sdr(reference, estimation, eps=1e-8): 48 | """ 49 | Scale-Invariant Signal-to-Distortion Ratio (SI-SDR) 50 | Args: 51 | reference: numpy.ndarray, [..., T] 52 | estimation: numpy.ndarray, [..., T] 53 | Returns: 54 | SI-SDR 55 | [1] SDR– Half- Baked or Well Done? 
56 | http://www.merl.com/publications/docs/TR2019-013.pdf 57 | >>> np.random.seed(0) 58 | >>> reference = np.random.randn(100) 59 | >>> si_sdr(reference, reference) 60 | inf 61 | >>> si_sdr(reference, reference * 2) 62 | inf 63 | >>> si_sdr(reference, np.flip(reference)) 64 | -25.127672346460717 65 | >>> si_sdr(reference, reference + np.flip(reference)) 66 | 0.481070445785553 67 | >>> si_sdr(reference, reference + 0.5) 68 | 6.3704606032577304 69 | >>> si_sdr(reference, reference * 2 + 1) 70 | 6.3704606032577304 71 | >>> si_sdr([1., 0], [0., 0]) # never predict only zeros 72 | nan 73 | >>> si_sdr([reference, reference], [reference * 2 + 1, reference * 1 + 0.5]) 74 | array([6.3704606, 6.3704606]) 75 | :param reference: 76 | :param estimation: 77 | :param eps: 78 | """ 79 | 80 | reference_energy = torch.sum(reference ** 2, axis=-1, keepdims=True) 81 | 82 | # This is $\alpha$ after Equation (3) in [1]. 83 | optimal_scaling = torch.sum(reference * estimation, axis=-1, keepdims=True) / reference_energy + eps 84 | 85 | # This is $e_{\text{target}}$ in Equation (4) in [1]. 86 | projection = optimal_scaling * reference 87 | 88 | # This is $e_{\text{res}}$ in Equation (4) in [1]. 89 | noise = estimation - projection 90 | 91 | ratio = torch.sum(projection ** 2, axis=-1) / torch.sum(noise ** 2, axis=-1) + eps 92 | 93 | ratio = torch.mean(ratio) 94 | return 10 * torch.log10(ratio + eps) 95 | 96 | 97 | ############################################################################ 98 | # for LMS loss function # 99 | ############################################################################ 100 | # MFCC (Mel Frequency Cepstral Coefficients) 101 | 102 | # based on a combination of this article: 103 | # http://practicalcryptography.com/miscellaneous/machine-learning/... 104 | # guide-mel-frequency-cepstral-coefficients-mfccs/ 105 | # and some of this code: 106 | # http://stackoverflow.com/questions/5835568/... 
107 | # how-to-get-mfcc-from-an-fft-on-a-signal 108 | # Set device 109 | DEVICE = torch.device(cfg.DEVICE) 110 | 111 | FFT_SIZE = cfg.fft_len 112 | 113 | # multi-scale MFCC distance 114 | if cfg.perceptual == 'LMS': 115 | MEL_SCALES = [16, 32, 64] 116 | elif cfg.perceptual == 'PAM': 117 | MEL_SCALES = [32, 64] 118 | 119 | 120 | class rmse(torch.nn.Module): 121 | def __init__(self): 122 | super(rmse, self).__init__() 123 | 124 | def forward(self, y_true, y_pred): 125 | mse = torch.mean((y_pred - y_true) ** 2, axis=-1) 126 | rmse = torch.sqrt(mse + 1e-7) 127 | 128 | return torch.mean(rmse) 129 | 130 | 131 | # conversions between Mel scale and regular frequency scale 132 | def freqToMel(freq): 133 | return 1127.01048 * math.log(1 + freq / 700.0) 134 | 135 | 136 | def melToFreq(mel): 137 | return 700 * (math.exp(mel / 1127.01048) - 1) 138 | 139 | # generate Mel filter bank 140 | def melFilterBank(numCoeffs, fftSize=None): 141 | minHz = 0 142 | maxHz = cfg.fs / 2 # max Hz by Nyquist theorem 143 | if (fftSize is None): 144 | numFFTBins = cfg.win_len 145 | else: 146 | numFFTBins = int(fftSize / 2) + 1 147 | 148 | maxMel = freqToMel(maxHz) 149 | minMel = freqToMel(minHz) 150 | 151 | # we need (numCoeffs + 2) points to create (numCoeffs) filterbanks 152 | melRange = np.array(range(numCoeffs + 2)) 153 | melRange = melRange.astype(np.float32) 154 | 155 | # create (numCoeffs + 2) points evenly spaced between minMel and maxMel 156 | melCenterFilters = melRange * (maxMel - minMel) / (numCoeffs + 1) + minMel 157 | 158 | for i in range(numCoeffs + 2): 159 | # mel domain => frequency domain 160 | melCenterFilters[i] = melToFreq(melCenterFilters[i]) 161 | 162 | # frequency domain => FFT bins 163 | melCenterFilters[i] = math.floor(numFFTBins * melCenterFilters[i] / maxHz) 164 | 165 | # create matrix of filters (one row is one filter) 166 | filterMat = np.zeros((numCoeffs, numFFTBins)) 167 | 168 | # generate triangular filters (in frequency domain) 169 | for i in range(1, numCoeffs + 1): 170 | filter = np.zeros(numFFTBins) 171 | 172 | startRange = int(melCenterFilters[i - 1]) 173 | midRange = int(melCenterFilters[i]) 174 | endRange = int(melCenterFilters[i + 1]) 175 | 176 | for j in range(startRange, midRange): 177 | filter[j] = (float(j) - startRange) / (midRange - startRange) 178 | for j in range(midRange, endRange): 179 | filter[j] = 1 - ((float(j) - midRange) / (endRange - midRange)) 180 | 181 | filterMat[i - 1] = filter 182 | 183 | # return filterbank as matrix 184 | return filterMat 185 | 186 | 187 | # Finally: a perceptual loss function (based on Mel scale) 188 | 189 | # given a (symbolic Theano) array of size M x WINDOW_SIZE 190 | # this returns an array M x N where each window has been replaced 191 | # by some perceptual transform (in this case, MFCC coeffs) 192 | def perceptual_transform(x): 193 | # precompute Mel filterbank: [FFT_SIZE x NUM_MFCC_COEFFS] 194 | MEL_FILTERBANKS = [] 195 | for scale in MEL_SCALES: 196 | filterbank_npy = melFilterBank(scale, FFT_SIZE).transpose() 197 | torch_filterbank_npy = torch.from_numpy(filterbank_npy).type(torch.FloatTensor) 198 | MEL_FILTERBANKS.append(torch_filterbank_npy.to(DEVICE)) 199 | 200 | transforms = [] 201 | # powerSpectrum = torch_dft_mag(x, DFT_REAL, DFT_IMAG)**2 202 | 203 | powerSpectrum = x.view(-1, FFT_SIZE // 2 + 1) 204 | powerSpectrum = 1.0 / FFT_SIZE * powerSpectrum 205 | 206 | for filterbank in MEL_FILTERBANKS: 207 | filteredSpectrum = torch.mm(powerSpectrum, filterbank) 208 | filteredSpectrum = torch.log(filteredSpectrum + 1e-7) 209 | 
transforms.append(filteredSpectrum) 210 | 211 | return transforms 212 | 213 | 214 | # perceptual loss function 215 | class perceptual_distance(torch.nn.Module): 216 | 217 | def __init__(self): 218 | super(perceptual_distance, self).__init__() 219 | 220 | def forward(self, y_true, y_pred): 221 | rmse_loss = rmse() 222 | # y_true = torch.reshape(y_true, (-1, WINDOW_SIZE)) 223 | # y_pred = torch.reshape(y_pred, (-1, WINDOW_SIZE)) 224 | 225 | pvec_true = perceptual_transform(y_true) 226 | pvec_pred = perceptual_transform(y_pred) 227 | 228 | distances = [] 229 | for i in range(0, len(pvec_true)): 230 | error = rmse_loss(pvec_pred[i], pvec_true[i]) 231 | error = error.unsqueeze(dim=-1) 232 | distances.append(error) 233 | distances = torch.cat(distances, axis=-1) 234 | 235 | loss = torch.mean(distances, axis=-1) 236 | return torch.mean(loss) 237 | 238 | 239 | get_mel_loss = perceptual_distance() 240 | 241 | 242 | def get_array_lms_loss(clean_array, est_array): 243 | array_mel_loss = 0 244 | for i in range(len(clean_array)): 245 | mel_loss = get_mel_loss(clean_array[i], est_array[i]) 246 | array_mel_loss += mel_loss 247 | 248 | avg_mel_loss = array_mel_loss / len(clean_array) 249 | return avg_mel_loss 250 | 251 | 252 | ############################################################################ 253 | # for pmsqe loss function # 254 | ############################################################################ 255 | pmsqe_stft = Encoder(STFTFB(kernel_size=512, n_filters=512, stride=256)).to(DEVICE) 256 | pmsqe_loss = PITLossWrapper(SingleSrcPMSQE(), pit_from='pw_pt').to(DEVICE) 257 | 258 | 259 | def get_array_pmsqe_loss(clean_array, est_array): 260 | if clean_array.dim() == 2: 261 | clean_wav = torch.unsqueeze(clean_array, 1) 262 | est_wav = torch.unsqueeze(est_array, 1) 263 | N, C, H = clean_wav.size() 264 | clean_wav = clean_wav.contiguous().view(N, -1, cfg.fs) 265 | est_wav = est_wav.contiguous().view(N, -1, cfg.fs) 266 | 267 | clean_spec = transforms.mag(pmsqe_stft(clean_wav)) 268 | est_spec = transforms.mag(pmsqe_stft(est_wav)) 269 | return pmsqe_loss(est_spec, clean_spec) 270 | -------------------------------------------------------------------------------- /tools_for_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import time 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | from scipy.signal import get_window 8 | import matplotlib.pylab as plt 9 | import config as cfg 10 | 11 | 12 | ############################################################################ 13 | # for convolutional STFT # 14 | ############################################################################ 15 | # this is from conv_stft https://github.com/huyanxin/DeepComplexCRN 16 | def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): 17 | if win_type == 'None' or win_type is None: 18 | window = np.ones(win_len) 19 | else: 20 | window = get_window(win_type, win_len, fftbins=True) # **0.5 21 | 22 | N = fft_len 23 | fourier_basis = np.fft.rfft(np.eye(N))[:win_len] 24 | real_kernel = np.real(fourier_basis) 25 | imag_kernel = np.imag(fourier_basis) 26 | kernel = np.concatenate([real_kernel, imag_kernel], 1).T 27 | 28 | if invers: 29 | kernel = np.linalg.pinv(kernel).T 30 | 31 | kernel = kernel * window 32 | kernel = kernel[:, None, :] 33 | return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32)) 34 | 35 | 36 | class 
ConvSTFT(nn.Module): 37 | 38 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True): 39 | super(ConvSTFT, self).__init__() 40 | 41 | if fft_len == None: 42 | self.fft_len = np.int(2 ** np.ceil(np.log2(win_len))) 43 | else: 44 | self.fft_len = fft_len 45 | 46 | kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type) 47 | # self.weight = nn.Parameter(kernel, requires_grad=(not fix)) 48 | self.register_buffer('weight', kernel) 49 | self.feature_type = feature_type 50 | self.stride = win_inc 51 | self.win_len = win_len 52 | self.dim = self.fft_len 53 | 54 | def forward(self, inputs): 55 | if inputs.dim() == 2: 56 | inputs = torch.unsqueeze(inputs, 1) 57 | inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride]) 58 | outputs = F.conv1d(inputs, self.weight, stride=self.stride) 59 | 60 | if self.feature_type == 'complex': 61 | return outputs 62 | else: 63 | dim = self.dim // 2 + 1 64 | real = outputs[:, :dim, :] 65 | imag = outputs[:, dim:, :] 66 | mags = torch.sqrt(real ** 2 + imag ** 2) 67 | phase = torch.atan2(imag, real) 68 | return mags, phase 69 | 70 | 71 | class ConviSTFT(nn.Module): 72 | 73 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True): 74 | super(ConviSTFT, self).__init__() 75 | if fft_len == None: 76 | self.fft_len = np.int(2 ** np.ceil(np.log2(win_len))) 77 | else: 78 | self.fft_len = fft_len 79 | kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True) 80 | # self.weight = nn.Parameter(kernel, requires_grad=(not fix)) 81 | self.register_buffer('weight', kernel) 82 | self.feature_type = feature_type 83 | self.win_type = win_type 84 | self.win_len = win_len 85 | self.stride = win_inc 86 | self.dim = self.fft_len 87 | self.register_buffer('window', window) 88 | self.register_buffer('enframe', torch.eye(win_len)[:, None, :]) 89 | 90 | def forward(self, inputs, phase=None): 91 | """ 92 | inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags) 93 | phase: [B, N//2+1, T] (if not none) 94 | """ 95 | 96 | if phase is not None: 97 | real = inputs * torch.cos(phase) 98 | imag = inputs * torch.sin(phase) 99 | inputs = torch.cat([real, imag], 1) 100 | 101 | outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride) 102 | 103 | # this is from torch-stft: https://github.com/pseeth/torch-stft 104 | t = self.window.repeat(1, 1, inputs.size(-1)) ** 2 105 | coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) 106 | 107 | outputs = outputs / (coff + 1e-8) 108 | 109 | # # outputs = torch.where(coff == 0, outputs, outputs/coff) 110 | outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)] 111 | 112 | return outputs 113 | 114 | 115 | ############################################################################ 116 | # for complex rnn # 117 | ############################################################################ 118 | def get_casual_padding1d(): 119 | pass 120 | 121 | 122 | def get_casual_padding2d(): 123 | pass 124 | 125 | 126 | class cPReLU(nn.Module): 127 | 128 | def __init__(self, complex_axis=1): 129 | super(cPReLU, self).__init__() 130 | self.r_prelu = nn.PReLU() 131 | self.i_prelu = nn.PReLU() 132 | self.complex_axis = complex_axis 133 | 134 | def forward(self, inputs): 135 | real, imag = torch.chunk(inputs, 2, self.complex_axis) 136 | real = self.r_prelu(real) 137 | imag = self.i_prelu(imag) 138 | return torch.cat([real, imag], self.complex_axis) 139 | 140 | 141 | class 
NavieComplexLSTM(nn.Module): 142 | def __init__(self, input_size, hidden_size, projection_dim=None, bidirectional=False, batch_first=False): 143 | super(NavieComplexLSTM, self).__init__() 144 | 145 | self.input_dim = input_size // 2 146 | self.rnn_units = hidden_size // 2 147 | self.real_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional, 148 | batch_first=False) 149 | self.imag_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional, 150 | batch_first=False) 151 | if bidirectional: 152 | bidirectional = 2 153 | else: 154 | bidirectional = 1 155 | if projection_dim is not None: 156 | self.projection_dim = projection_dim // 2 157 | self.r_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim) 158 | self.i_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim) 159 | else: 160 | self.projection_dim = None 161 | 162 | def forward(self, inputs): 163 | if isinstance(inputs, list): 164 | real, imag = inputs 165 | elif isinstance(inputs, torch.Tensor): 166 | real, imag = torch.chunk(inputs, -1) 167 | r2r_out = self.real_lstm(real)[0] 168 | r2i_out = self.imag_lstm(real)[0] 169 | i2r_out = self.real_lstm(imag)[0] 170 | i2i_out = self.imag_lstm(imag)[0] 171 | real_out = r2r_out - i2i_out 172 | imag_out = i2r_out + r2i_out 173 | if self.projection_dim is not None: 174 | real_out = self.r_trans(real_out) 175 | imag_out = self.i_trans(imag_out) 176 | # print(real_out.shape,imag_out.shape) 177 | return [real_out, imag_out] 178 | 179 | def flatten_parameters(self): 180 | self.imag_lstm.flatten_parameters() 181 | self.real_lstm.flatten_parameters() 182 | 183 | 184 | def complex_cat(inputs, axis): 185 | real, imag = [], [] 186 | for idx, data in enumerate(inputs): 187 | r, i = torch.chunk(data, 2, axis) 188 | real.append(r) 189 | imag.append(i) 190 | real = torch.cat(real, axis) 191 | imag = torch.cat(imag, axis) 192 | outputs = torch.cat([real, imag], axis) 193 | return outputs 194 | 195 | 196 | ############################################################################ 197 | # for convolutional layer # 198 | ############################################################################ 199 | class ComplexConv2d(nn.Module): 200 | 201 | def __init__( 202 | self, 203 | in_channels, 204 | out_channels, 205 | kernel_size=(1, 1), 206 | stride=(1, 1), 207 | padding=(0, 0), 208 | dilation=1, 209 | groups=1, 210 | causal=True, 211 | complex_axis=1, 212 | ): 213 | ''' 214 | in_channels: real+imag 215 | out_channels: real+imag 216 | kernel_size : input [B,C,D,T] kernel size in [D,T] 217 | padding : input [B,C,D,T] padding in [D,T] 218 | causal: if causal, will padding time dimension's left side, 219 | otherwise both 220 | 221 | ''' 222 | super(ComplexConv2d, self).__init__() 223 | self.in_channels = in_channels // 2 224 | self.out_channels = out_channels // 2 225 | self.kernel_size = kernel_size 226 | self.stride = stride 227 | self.padding = padding 228 | self.causal = causal 229 | self.groups = groups 230 | self.dilation = dilation 231 | self.complex_axis = complex_axis 232 | 233 | self.real_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride, 234 | padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups) 235 | self.imag_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride, 236 | padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups) 237 | 238 | nn.init.normal_(self.real_conv.weight.data, std=0.05) 239 | 
nn.init.normal_(self.imag_conv.weight.data, std=0.05) 240 | nn.init.constant_(self.real_conv.bias, 0.) 241 | nn.init.constant_(self.imag_conv.bias, 0.) 242 | 243 | def forward(self, inputs): 244 | if self.padding[1] != 0 and self.causal: 245 | inputs = F.pad(inputs, [self.padding[1], 0, 0, 0]) # # [width left, width right, height left, height right] 246 | else: 247 | inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0]) 248 | 249 | if self.complex_axis == 0: 250 | real = self.real_conv(inputs) 251 | imag = self.imag_conv(inputs) 252 | real2real, imag2real = torch.chunk(real, 2, self.complex_axis) 253 | real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis) 254 | 255 | else: 256 | if isinstance(inputs, torch.Tensor): 257 | real, imag = torch.chunk(inputs, 2, self.complex_axis) 258 | 259 | real2real = self.real_conv(real, ) 260 | imag2imag = self.imag_conv(imag, ) 261 | 262 | real2imag = self.imag_conv(real) 263 | imag2real = self.real_conv(imag) 264 | 265 | real = real2real - imag2imag 266 | imag = real2imag + imag2real 267 | out = torch.cat([real, imag], self.complex_axis) 268 | 269 | return out 270 | 271 | 272 | class ComplexConvTranspose2d(nn.Module): 273 | 274 | def __init__( 275 | self, 276 | in_channels, 277 | out_channels, 278 | kernel_size=(1, 1), 279 | stride=(1, 1), 280 | padding=(0, 0), 281 | output_padding=(0, 0), 282 | causal=False, 283 | complex_axis=1, 284 | groups=1 285 | ): 286 | ''' 287 | in_channels: real+imag 288 | out_channels: real+imag 289 | ''' 290 | super(ComplexConvTranspose2d, self).__init__() 291 | self.in_channels = in_channels // 2 292 | self.out_channels = out_channels // 2 293 | self.kernel_size = kernel_size 294 | self.stride = stride 295 | self.padding = padding 296 | self.output_padding = output_padding 297 | self.groups = groups 298 | 299 | self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride, 300 | padding=self.padding, output_padding=output_padding, groups=self.groups) 301 | self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride, 302 | padding=self.padding, output_padding=output_padding, groups=self.groups) 303 | 304 | self.complex_axis = complex_axis 305 | 306 | nn.init.normal_(self.real_conv.weight.data, std=0.05) 307 | nn.init.normal_(self.imag_conv.weight.data, std=0.05) 308 | nn.init.constant_(self.real_conv.bias, 0.) 309 | nn.init.constant_(self.imag_conv.bias, 0.) 
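        # Complex transposed convolution follows the same arithmetic as
        # ComplexConv2d: with the input split into real/imag halves along
        # `complex_axis`,
        #   real_out = real_conv(xr) - imag_conv(xi)
        #   imag_out = imag_conv(xr) + real_conv(xi)
        # and the two halves are concatenated back on `complex_axis`.
        # Illustrative usage; channel counts are real+imag totals and the sizes
        # are assumptions for the sketch, not values taken from config.py:
        #   >>> up = ComplexConvTranspose2d(64, 32, kernel_size=(5, 2),
        #   ...                             stride=(2, 1), padding=(2, 0),
        #   ...                             output_padding=(1, 0))
        #   >>> y = up(torch.rand(1, 64, 32, 100))  # real/imag halves stacked on dim 1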
310 | 311 | def forward(self, inputs): 312 | 313 | if isinstance(inputs, torch.Tensor): 314 | real, imag = torch.chunk(inputs, 2, self.complex_axis) 315 | elif isinstance(inputs, tuple) or isinstance(inputs, list): 316 | real = inputs[0] 317 | imag = inputs[1] 318 | if self.complex_axis == 0: 319 | real = self.real_conv(inputs) 320 | imag = self.imag_conv(inputs) 321 | real2real, imag2real = torch.chunk(real, 2, self.complex_axis) 322 | real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis) 323 | 324 | else: 325 | if isinstance(inputs, torch.Tensor): 326 | real, imag = torch.chunk(inputs, 2, self.complex_axis) 327 | 328 | real2real = self.real_conv(real, ) 329 | imag2imag = self.imag_conv(imag, ) 330 | 331 | real2imag = self.imag_conv(real) 332 | imag2real = self.real_conv(imag) 333 | 334 | real = real2real - imag2imag 335 | imag = real2imag + imag2real 336 | out = torch.cat([real, imag], self.complex_axis) 337 | 338 | return out 339 | 340 | 341 | class RealConv2d(nn.Module): 342 | 343 | def __init__( 344 | self, 345 | in_channels, 346 | out_channels, 347 | kernel_size=(1, 1), 348 | stride=(1, 1), 349 | padding=(0, 0), 350 | dilation=1, 351 | groups=1, 352 | causal=True, 353 | complex_axis=1, 354 | ): 355 | ''' 356 | in_channels: real+imag 357 | out_channels: real+imag 358 | kernel_size : input [B,C,D,T] kernel size in [D,T] 359 | padding : input [B,C,D,T] padding in [D,T] 360 | causal: if causal, will padding time dimension's left side, 361 | otherwise both 362 | 363 | ''' 364 | super(RealConv2d, self).__init__() 365 | self.in_channels = in_channels 366 | self.out_channels = out_channels 367 | self.kernel_size = kernel_size 368 | self.stride = stride 369 | self.padding = padding 370 | self.causal = causal 371 | self.groups = groups 372 | self.dilation = dilation 373 | 374 | self.conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride, 375 | padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups) 376 | 377 | nn.init.normal_(self.conv.weight.data, std=0.05) 378 | nn.init.constant_(self.conv.bias, 0.) 379 | 380 | def forward(self, inputs): 381 | if self.padding[1] != 0 and self.causal: 382 | inputs = F.pad(inputs, [self.padding[1], 0, 0, 0]) ## [width left, width right, height left, height right] 383 | else: 384 | inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0]) 385 | 386 | out = self.conv(inputs) 387 | 388 | return out 389 | 390 | 391 | class RealConvTranspose2d(nn.Module): 392 | 393 | def __init__( 394 | self, 395 | in_channels, 396 | out_channels, 397 | kernel_size=(1, 1), 398 | stride=(1, 1), 399 | padding=(0, 0), 400 | output_padding=(0, 0), 401 | groups=1 402 | ): 403 | ''' 404 | in_channels: real+imag 405 | out_channels: real+imag 406 | ''' 407 | super(RealConvTranspose2d, self).__init__() 408 | self.in_channels = in_channels 409 | self.out_channels = out_channels 410 | self.kernel_size = kernel_size 411 | self.stride = stride 412 | self.padding = padding 413 | self.output_padding = output_padding 414 | self.groups = groups 415 | 416 | self.conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride, 417 | padding=self.padding, output_padding=output_padding, groups=self.groups) 418 | 419 | nn.init.normal_(self.conv.weight.data, std=0.05) 420 | nn.init.constant_(self.conv.bias, 0.) 
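        # Real-valued counterpart of ComplexConvTranspose2d: a plain
        # nn.ConvTranspose2d with normal(std=0.05) weights and zero bias, used in
        # the CRN decoder defined in models.py. Channel counts here are ordinary
        # (single-branch) channels. Illustrative decoder-style usage; the sizes
        # are assumptions for the sketch:
        #   >>> up = RealConvTranspose2d(256, 128, kernel_size=(5, 2),
        #   ...                          stride=(2, 1), padding=(2, 0),
        #   ...                          output_padding=(1, 0))
        #   >>> y = up(torch.rand(1, 256, 4, 100))  # frequency dim is upsampled from 4 to 8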
421 | 422 | def forward(self, inputs): 423 | out = self.conv(inputs) 424 | 425 | return out 426 | 427 | 428 | # Source: https://github.com/ChihebTrabelsi/deep_complex_networks/tree/pytorch 429 | # from https://github.com/IMLHF/SE_DCUNet/blob/f28bf1661121c8901ad38149ea827693f1830715/models/layers/complexnn.py#L55 430 | class ComplexBatchNorm(torch.nn.Module): 431 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 432 | track_running_stats=True, complex_axis=1): 433 | super(ComplexBatchNorm, self).__init__() 434 | self.num_features = num_features // 2 435 | self.eps = eps 436 | self.momentum = momentum 437 | self.affine = affine 438 | self.track_running_stats = track_running_stats 439 | 440 | self.complex_axis = complex_axis 441 | 442 | if self.affine: 443 | self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features)) 444 | self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features)) 445 | self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features)) 446 | self.Br = torch.nn.Parameter(torch.Tensor(self.num_features)) 447 | self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features)) 448 | else: 449 | self.register_parameter('Wrr', None) 450 | self.register_parameter('Wri', None) 451 | self.register_parameter('Wii', None) 452 | self.register_parameter('Br', None) 453 | self.register_parameter('Bi', None) 454 | 455 | if self.track_running_stats: 456 | self.register_buffer('RMr', torch.zeros(self.num_features)) 457 | self.register_buffer('RMi', torch.zeros(self.num_features)) 458 | self.register_buffer('RVrr', torch.ones(self.num_features)) 459 | self.register_buffer('RVri', torch.zeros(self.num_features)) 460 | self.register_buffer('RVii', torch.ones(self.num_features)) 461 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 462 | else: 463 | self.register_parameter('RMr', None) 464 | self.register_parameter('RMi', None) 465 | self.register_parameter('RVrr', None) 466 | self.register_parameter('RVri', None) 467 | self.register_parameter('RVii', None) 468 | self.register_parameter('num_batches_tracked', None) 469 | self.reset_parameters() 470 | 471 | def reset_running_stats(self): 472 | if self.track_running_stats: 473 | self.RMr.zero_() 474 | self.RMi.zero_() 475 | self.RVrr.fill_(1) 476 | self.RVri.zero_() 477 | self.RVii.fill_(1) 478 | self.num_batches_tracked.zero_() 479 | 480 | def reset_parameters(self): 481 | self.reset_running_stats() 482 | if self.affine: 483 | self.Br.data.zero_() 484 | self.Bi.data.zero_() 485 | self.Wrr.data.fill_(1) 486 | self.Wri.data.uniform_(-.9, +.9) # W will be positive-definite 487 | self.Wii.data.fill_(1) 488 | 489 | def _check_input_dim(self, xr, xi): 490 | assert (xr.shape == xi.shape) 491 | assert (xr.size(1) == self.num_features) 492 | 493 | def forward(self, inputs): 494 | # self._check_input_dim(xr, xi) 495 | 496 | xr, xi = torch.chunk(inputs, 2, axis=self.complex_axis) 497 | exponential_average_factor = 0.0 498 | 499 | if self.training and self.track_running_stats: 500 | self.num_batches_tracked += 1 501 | if self.momentum is None: # use cumulative moving average 502 | exponential_average_factor = 1.0 / self.num_batches_tracked.item() 503 | else: # use exponential moving average 504 | exponential_average_factor = self.momentum 505 | 506 | # 507 | # NOTE: The precise meaning of the "training flag" is: 508 | # True: Normalize using batch statistics, update running statistics 509 | # if they are being collected. 510 | # False: Normalize using running statistics, ignore batch statistics. 
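        # The whitening below treats each channel's (real, imag) pair as a 2-D
        # variable: the centered input is multiplied by U = V^(-1/2), where
        #   V = [[Vrr, Vri], [Vri, Vii]] (+ eps on the diagonal)
        # is the 2x2 covariance, and U is obtained in closed form from the
        # trace tau = Vrr + Vii and determinant delta = Vrr*Vii - Vri^2
        # (see the 2x2 matrix square-root identity referenced further down).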
511 | # 512 | training = self.training or not self.track_running_stats 513 | redux = [i for i in reversed(range(xr.dim())) if i != 1] 514 | vdim = [1] * xr.dim() 515 | vdim[1] = xr.size(1) 516 | 517 | # 518 | # Mean M Computation and Centering 519 | # 520 | # Includes running mean update if training and running. 521 | # 522 | if training: 523 | Mr, Mi = xr, xi 524 | for d in redux: 525 | Mr = Mr.mean(d, keepdim=True) 526 | Mi = Mi.mean(d, keepdim=True) 527 | if self.track_running_stats: 528 | self.RMr.lerp_(Mr.squeeze(), exponential_average_factor) 529 | self.RMi.lerp_(Mi.squeeze(), exponential_average_factor) 530 | else: 531 | Mr = self.RMr.view(vdim) 532 | Mi = self.RMi.view(vdim) 533 | xr, xi = xr - Mr, xi - Mi 534 | 535 | # 536 | # Variance Matrix V Computation 537 | # 538 | # Includes epsilon numerical stabilizer/Tikhonov regularizer. 539 | # Includes running variance update if training and running. 540 | # 541 | if training: 542 | Vrr = xr * xr 543 | Vri = xr * xi 544 | Vii = xi * xi 545 | for d in redux: 546 | Vrr = Vrr.mean(d, keepdim=True) 547 | Vri = Vri.mean(d, keepdim=True) 548 | Vii = Vii.mean(d, keepdim=True) 549 | if self.track_running_stats: 550 | self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor) 551 | self.RVri.lerp_(Vri.squeeze(), exponential_average_factor) 552 | self.RVii.lerp_(Vii.squeeze(), exponential_average_factor) 553 | else: 554 | Vrr = self.RVrr.view(vdim) 555 | Vri = self.RVri.view(vdim) 556 | Vii = self.RVii.view(vdim) 557 | Vrr = Vrr + self.eps 558 | Vri = Vri 559 | Vii = Vii + self.eps 560 | 561 | # 562 | # Matrix Inverse Square Root U = V^-0.5 563 | # 564 | # sqrt of a 2x2 matrix, 565 | # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix 566 | tau = Vrr + Vii 567 | delta = torch.addcmul(Vrr * Vii, -1, Vri, Vri) 568 | s = delta.sqrt() 569 | t = (tau + 2 * s).sqrt() 570 | 571 | # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html 572 | rst = (s * t).reciprocal() 573 | Urr = (s + Vii) * rst 574 | Uii = (s + Vrr) * rst 575 | Uri = (- Vri) * rst 576 | 577 | # 578 | # Optionally left-multiply U by affine weights W to produce combined 579 | # weights Z, left-multiply the inputs by Z, then optionally bias them. 
580 | # 581 | # y = Zx + B 582 | # y = WUx + B 583 | # y = [Wrr Wri][Urr Uri] [xr] + [Br] 584 | # [Wir Wii][Uir Uii] [xi] [Bi] 585 | # 586 | if self.affine: 587 | Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim) 588 | Zrr = (Wrr * Urr) + (Wri * Uri) 589 | Zri = (Wrr * Uri) + (Wri * Uii) 590 | Zir = (Wri * Urr) + (Wii * Uri) 591 | Zii = (Wri * Uri) + (Wii * Uii) 592 | else: 593 | Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii 594 | 595 | yr = (Zrr * xr) + (Zri * xi) 596 | yi = (Zir * xr) + (Zii * xi) 597 | 598 | if self.affine: 599 | yr = yr + self.Br.view(vdim) 600 | yi = yi + self.Bi.view(vdim) 601 | 602 | outputs = torch.cat([yr, yi], self.complex_axis) 603 | return outputs 604 | 605 | def extra_repr(self): 606 | return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 607 | 'track_running_stats={track_running_stats}'.format(**self.__dict__) 608 | 609 | 610 | def complex_cat(inputs, axis): 611 | real, imag = [], [] 612 | for idx, data in enumerate(inputs): 613 | r, i = torch.chunk(data, 2, axis) 614 | real.append(r) 615 | imag.append(i) 616 | real = torch.cat(real, axis) 617 | imag = torch.cat(imag, axis) 618 | outputs = torch.cat([real, imag], axis) 619 | return outputs 620 | 621 | ############################################################################ 622 | # for FullSubNet # 623 | ############################################################################ 624 | # Source: https://github.com/haoxiangsnr/FullSubNet 625 | # from https://github.com/haoxiangsnr/FullSubNet/blob/main/audio_zen/model/module/sequence_model.py 626 | # from https://github.com/haoxiangsnr/FullSubNet/blob/main/audio_zen/model/base_model.py 627 | # from https://github.com/haoxiangsnr/FullSubNet/blob/main/audio_zen/acoustics/feature.py 628 | def stft(y, n_fft=cfg.fft_len, hop_length=int(cfg.win_len*cfg.ola_ratio), win_length=cfg.win_len): 629 | """ 630 | Args: 631 | y: [B, F, T] 632 | n_fft: num of FFT 633 | hop_length: hop length 634 | win_length: window length 635 | 636 | Returns: 637 | [B, F, T], **complex-valued** STFT coefficients 638 | 639 | """ 640 | assert y.dim() == 2 641 | return torch.stft( 642 | y, 643 | n_fft, 644 | hop_length, 645 | win_length, 646 | window=torch.hann_window(win_length).to(y.device), 647 | return_complex=True 648 | ) 649 | 650 | 651 | def istft(features, n_fft=cfg.fft_len, hop_length=int(cfg.win_len*cfg.ola_ratio), win_length=cfg.win_len, length=None, use_mag_phase=False): 652 | """ 653 | Wrapper for the official torch.istft 654 | 655 | Args: 656 | features: [B, F, T, 2] (complex) or ([B, F, T], [B, F, T]) (mag and phase) 657 | n_fft: 658 | hop_length: 659 | win_length: 660 | device: 661 | length: 662 | use_mag_phase: use mag and phase as inputs of iSTFT 663 | 664 | Returns: 665 | [B, T] 666 | """ 667 | if use_mag_phase: 668 | # (mag, phase) or [mag, phase] 669 | assert isinstance(features, tuple) or isinstance(features, list) 670 | mag, phase = features 671 | features = torch.stack([mag * torch.cos(phase), mag * torch.sin(phase)], dim=-1) 672 | 673 | return torch.istft( 674 | features, 675 | n_fft, 676 | hop_length, 677 | win_length, 678 | window=torch.hann_window(win_length).to(features.device), 679 | length=length 680 | ) 681 | 682 | 683 | def mag_phase(complex_tensor): 684 | return torch.abs(complex_tensor), torch.angle(complex_tensor) 685 | 686 | 687 | def build_complex_ideal_ratio_mask(noisy: torch.complex64, clean: torch.complex64) -> torch.Tensor: 688 | """ 689 | 690 | Args: 691 | noisy: [B, F, T], noisy complex-valued stft 
coefficients 692 | clean: [B, F, T], clean complex-valued stft coefficients 693 | 694 | Returns: 695 | [B, F, T, 2] 696 | """ 697 | denominator = torch.square(noisy.real) + torch.square(noisy.imag) + EPSILON 698 | 699 | mask_real = (noisy.real * clean.real + noisy.imag * clean.imag) / denominator 700 | mask_imag = (noisy.real * clean.imag - noisy.imag * clean.real) / denominator 701 | 702 | complex_ratio_mask = torch.stack((mask_real, mask_imag), dim=-1) 703 | 704 | return compress_cIRM(complex_ratio_mask, K=10, C=0.1) 705 | 706 | 707 | def compress_cIRM(mask, K=10, C=0.1): 708 | """ 709 | Compress from (-inf, +inf) to [-K ~ K] 710 | """ 711 | if torch.is_tensor(mask): 712 | mask = -100 * (mask <= -100) + mask * (mask > -100) 713 | mask = K * (1 - torch.exp(-C * mask)) / (1 + torch.exp(-C * mask)) 714 | else: 715 | mask = -100 * (mask <= -100) + mask * (mask > -100) 716 | mask = K * (1 - np.exp(-C * mask)) / (1 + np.exp(-C * mask)) 717 | return mask 718 | 719 | 720 | def decompress_cIRM(mask, K=10, limit=9.9): 721 | mask = limit * (mask >= limit) - limit * (mask <= -limit) + mask * (torch.abs(mask) < limit) 722 | mask = -K * torch.log((K - mask) / (K + mask)) 723 | return mask 724 | 725 | 726 | class SequenceModel(nn.Module): 727 | def __init__( 728 | self, 729 | input_size, 730 | output_size, 731 | hidden_size, 732 | num_layers, 733 | bidirectional, 734 | sequence_model="GRU", 735 | output_activate_function="Tanh" 736 | ): 737 | super().__init__() 738 | # Sequence layer 739 | if sequence_model == "LSTM": 740 | self.sequence_model = nn.LSTM( 741 | input_size=input_size, 742 | hidden_size=hidden_size, 743 | num_layers=num_layers, 744 | batch_first=True, 745 | bidirectional=bidirectional, 746 | dropout=0.8, 747 | ) 748 | elif sequence_model == "GRU": 749 | self.sequence_model = nn.GRU( 750 | input_size=input_size, 751 | hidden_size=hidden_size, 752 | num_layers=num_layers, 753 | batch_first=True, 754 | bidirectional=bidirectional, 755 | dropout=0.8, 756 | ) 757 | else: 758 | raise NotImplementedError(f"Not implemented {sequence_model}") 759 | 760 | # Fully connected layer 761 | if bidirectional: 762 | self.fc_output_layer = nn.Linear(hidden_size * 2, output_size) 763 | else: 764 | self.fc_output_layer = nn.Linear(hidden_size, output_size) 765 | 766 | # Activation function layer 767 | if output_activate_function: 768 | if output_activate_function == "Tanh": 769 | self.activate_function = nn.Tanh() 770 | elif output_activate_function == "ReLU": 771 | self.activate_function = nn.ReLU() 772 | elif output_activate_function == "ReLU6": 773 | self.activate_function = nn.ReLU6() 774 | else: 775 | raise NotImplementedError(f"Not implemented activation function {self.activate_function}") 776 | 777 | self.output_activate_function = output_activate_function 778 | 779 | def forward(self, x): 780 | """ 781 | Args: 782 | x: [B, F, T] 783 | Returns: 784 | [B, F, T] 785 | """ 786 | assert x.dim() == 3 787 | self.sequence_model.flatten_parameters() 788 | 789 | x = x.permute(0, 2, 1).contiguous() # [B, F, T] => [B, T, F] 790 | o, _ = self.sequence_model(x) 791 | o = self.fc_output_layer(o) 792 | if self.output_activate_function: 793 | o = self.activate_function(o) 794 | o = o.permute(0, 2, 1).contiguous() # [B, T, F] => [B, F, T] 795 | return o 796 | 797 | 798 | EPSILON = np.finfo(np.float32).eps 799 | 800 | 801 | class BaseModel(nn.Module): 802 | def __init__(self): 803 | super(BaseModel, self).__init__() 804 | 805 | @staticmethod 806 | def unfold(input, num_neighbor): 807 | """ 808 | Along with the frequency 
dim, split overlapped sub band units from spectrogram. 809 | 810 | Args: 811 | input: [B, C, F, T] 812 | num_neighbor: 813 | 814 | Returns: 815 | [B, N, C, F_s, T], F, e.g. [2, 161, 1, 19, 200] 816 | """ 817 | assert input.dim() == 4, f"The dim of input is {input.dim()}. It should be four dim." 818 | batch_size, num_channels, num_freqs, num_frames = input.size() 819 | 820 | if num_neighbor < 1: 821 | # No change for the input 822 | return input.permute(0, 2, 1, 3).reshape(batch_size, num_freqs, num_channels, 1, num_frames) 823 | 824 | output = input.reshape(batch_size * num_channels, 1, num_freqs, num_frames) 825 | sub_band_unit_size = num_neighbor * 2 + 1 826 | 827 | # Pad to the top and bottom 828 | output = F.pad(output, [0, 0, num_neighbor, num_neighbor], mode="reflect") 829 | 830 | output = F.unfold(output, (sub_band_unit_size, num_frames)) 831 | assert output.shape[-1] == num_freqs, f"n_freqs != N (sub_band), {num_freqs} != {output.shape[-1]}" 832 | 833 | # Split the dim of the unfolded feature 834 | output = output.reshape(batch_size, num_channels, sub_band_unit_size, num_frames, num_freqs) 835 | output = output.permute(0, 4, 1, 2, 3).contiguous() 836 | 837 | return output 838 | 839 | @staticmethod 840 | def _reduce_complexity_separately(sub_band_input, full_band_output, device): 841 | """ 842 | 843 | Args: 844 | sub_band_input: [60, 257, 1, 33, 200] 845 | full_band_output: [60, 257, 1, 3, 200] 846 | device: 847 | 848 | Notes: 849 | 1. 255 and 256 freq not able to be trained 850 | 2. batch size 851 | 852 | Returns: 853 | [60, 85, 1, 36, 200] 854 | """ 855 | batch_size = full_band_output.shape[0] 856 | n_freqs = full_band_output.shape[1] 857 | sub_batch_size = batch_size // 3 858 | final_selected = [] 859 | 860 | for idx in range(3): 861 | # [0, 60) => [0, 20) 862 | sub_batch_indices = torch.arange(idx * sub_batch_size, (idx + 1) * sub_batch_size, device=device) 863 | full_band_output_sub_batch = torch.index_select(full_band_output, dim=0, index=sub_batch_indices) 864 | sub_band_output_sub_batch = torch.index_select(sub_band_input, dim=0, index=sub_batch_indices) 865 | 866 | # Avoid to use padded value (first freq and last freq) 867 | # i = 0, (1, 256, 3) = [1, 4, ..., 253] 868 | # i = 1, (2, 256, 3) = [2, 5, ..., 254] 869 | # i = 2, (3, 256, 3) = [3, 6, ..., 255] 870 | freq_indices = torch.arange(idx + 1, n_freqs - 1, step=3, device=device) 871 | full_band_output_sub_batch = torch.index_select(full_band_output_sub_batch, dim=1, index=freq_indices) 872 | sub_band_output_sub_batch = torch.index_select(sub_band_output_sub_batch, dim=1, index=freq_indices) 873 | 874 | # ([30, 85, 1, 33 200], [30, 85, 1, 3, 200]) => [30, 85, 1, 36, 200] 875 | 876 | final_selected.append(torch.cat([sub_band_output_sub_batch, full_band_output_sub_batch], dim=-2)) 877 | 878 | return torch.cat(final_selected, dim=0) 879 | 880 | @staticmethod 881 | def sband_forgetting_norm(input, train_sample_length): 882 | """ 883 | Args: 884 | input: 885 | train_sample_length: 886 | 887 | Returns: 888 | 889 | """ 890 | assert input.ndim == 3 891 | batch_size, n_freqs, n_frames = input.size() 892 | 893 | eps = 1e-10 894 | alpha = (train_sample_length - 1) / (train_sample_length + 1) 895 | mu = 0 896 | mu_list = [] 897 | 898 | for idx in range(input.shape[-1]): 899 | if idx < train_sample_length: 900 | alp = torch.min(torch.tensor([(idx - 1) / (idx + 1), alpha])) 901 | mu = alp * mu + (1 - alp) * torch.mean(input[:, :, idx], dim=1).reshape(batch_size, 1) # [B, 1] 902 | else: 903 | mu = alpha * mu + (1 - alpha) * input[:, 
(n_freqs // 2 - 1), idx].reshape(batch_size, 1) 904 | 905 | mu_list.append(mu) 906 | 907 | # print("input", input[:, :, idx].min(), input[:, :, idx].max(), input[:, :, idx].mean()) 908 | # print(f"alp {idx}: ", alp) 909 | # print(f"mu {idx}: {mu[128, 0]}") 910 | 911 | mu = torch.stack(mu_list, dim=-1) # [B, 1, T] 912 | input = input / (mu + eps) 913 | return input 914 | 915 | @staticmethod 916 | def forgetting_norm(input, sample_length_in_training): 917 | """ 918 | Args: 919 | input: [B, F, T] 920 | sample_length_in_training: 921 | 922 | Returns: 923 | 924 | """ 925 | assert input.ndim == 3 926 | batch_size, n_freqs, n_frames = input.size() 927 | eps = 1e-10 928 | mu = 0 929 | alpha = (sample_length_in_training - 1) / (sample_length_in_training + 1) 930 | 931 | mu_list = [] 932 | for idx in range(input.shape[-1]): 933 | if idx < sample_length_in_training: 934 | alp = torch.min(torch.tensor([(idx - 1) / (idx + 1), alpha])) 935 | mu = alp * mu + (1 - alp) * torch.mean(input[:, :, idx], dim=1).reshape(batch_size, 1) # [B, 1] 936 | else: 937 | current_frame_mu = torch.mean(input[:, :, idx], dim=1).reshape(batch_size, 1) # [B, 1] 938 | mu = alpha * mu + (1 - alpha) * current_frame_mu 939 | 940 | mu_list.append(mu) 941 | 942 | # print("input", input[:, :, idx].min(), input[:, :, idx].max(), input[:, :, idx].mean()) 943 | # print(f"alp {idx}: ", alp) 944 | # print(f"mu {idx}: {mu[128, 0]}") 945 | 946 | mu = torch.stack(mu_list, dim=-1) # [B, 1, T] 947 | input = input / (mu + eps) 948 | return input 949 | 950 | @staticmethod 951 | def hybrid_norm(input, sample_length_in_training=192): 952 | """ 953 | Args: 954 | input: [B, F, T] 955 | sample_length_in_training: 956 | 957 | Returns: 958 | [B, F, T] 959 | """ 960 | assert input.ndim == 3 961 | device = input.device 962 | data_type = input.dtype 963 | batch_size, n_freqs, n_frames = input.size() 964 | eps = 1e-10 965 | 966 | mu = 0 967 | alpha = (sample_length_in_training - 1) / (sample_length_in_training + 1) 968 | mu_list = [] 969 | for idx in range(input.shape[-1]): 970 | if idx < sample_length_in_training: 971 | alp = torch.min(torch.tensor([(idx - 1) / (idx + 1), alpha])) 972 | mu = alp * mu + (1 - alp) * torch.mean(input[:, :, idx], dim=1).reshape(batch_size, 1) # [B, 1] 973 | mu_list.append(mu) 974 | else: 975 | break 976 | initial_mu = torch.stack(mu_list, dim=-1) # [B, 1, T] 977 | 978 | step_sum = torch.sum(input, dim=1) # [B, T] 979 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T] 980 | 981 | entry_count = torch.arange(n_freqs, n_freqs * n_frames + 1, n_freqs, dtype=data_type, device=device) 982 | entry_count = entry_count.reshape(1, n_frames) # [1, T] 983 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T] 984 | 985 | cum_mean = cumulative_sum / entry_count # B, T 986 | 987 | cum_mean = cum_mean.reshape(batch_size, 1, n_frames) # [B, 1, T] 988 | 989 | # print(initial_mu[0, 0, :50]) 990 | # print("-"*60) 991 | # print(cum_mean[0, 0, :50]) 992 | cum_mean[:, :, :sample_length_in_training] = initial_mu 993 | 994 | return input / (cum_mean + eps) 995 | 996 | @staticmethod 997 | def offline_laplace_norm(input): 998 | """ 999 | 1000 | Args: 1001 | input: [B, C, F, T] 1002 | 1003 | Returns: 1004 | [B, C, F, T] 1005 | """ 1006 | # utterance-level mu 1007 | mu = torch.mean(input, dim=(1, 2, 3), keepdim=True) 1008 | 1009 | normed = input / (mu + 1e-5) 1010 | 1011 | return normed 1012 | 1013 | @staticmethod 1014 | def cumulative_laplace_norm(input): 1015 | """ 1016 | 1017 | Args: 1018 | input: [B, C, F, T] 1019 | 1020 | 
Returns: 1021 | 1022 | """ 1023 | batch_size, num_channels, num_freqs, num_frames = input.size() 1024 | input = input.reshape(batch_size * num_channels, num_freqs, num_frames) 1025 | 1026 | step_sum = torch.sum(input, dim=1) # [B * C, F, T] => [B, T] 1027 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T] 1028 | 1029 | entry_count = torch.arange( 1030 | num_freqs, 1031 | num_freqs * num_frames + 1, 1032 | num_freqs, 1033 | dtype=input.dtype, 1034 | device=input.device 1035 | ) 1036 | entry_count = entry_count.reshape(1, num_frames) # [1, T] 1037 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T] 1038 | 1039 | cumulative_mean = cumulative_sum / entry_count # B, T 1040 | cumulative_mean = cumulative_mean.reshape(batch_size * num_channels, 1, num_frames) 1041 | 1042 | normed = input / (cumulative_mean + EPSILON) 1043 | 1044 | return normed.reshape(batch_size, num_channels, num_freqs, num_frames) 1045 | 1046 | @staticmethod 1047 | def offline_gaussian_norm(input): 1048 | """ 1049 | Zero-Norm 1050 | Args: 1051 | input: [B, C, F, T] 1052 | 1053 | Returns: 1054 | [B, C, F, T] 1055 | """ 1056 | mu = torch.mean(input, dim=(1, 2, 3), keepdim=True) 1057 | std = torch.std(input, dim=(1, 2, 3), keepdim=True) 1058 | 1059 | normed = (input - mu) / (std + 1e-5) 1060 | 1061 | return normed 1062 | 1063 | @staticmethod 1064 | def cumulative_layer_norm(input): 1065 | """ 1066 | Online zero-norm 1067 | 1068 | Args: 1069 | input: [B, C, F, T] 1070 | 1071 | Returns: 1072 | [B, C, F, T] 1073 | """ 1074 | batch_size, num_channels, num_freqs, num_frames = input.size() 1075 | input = input.reshape(batch_size * num_channels, num_freqs, num_frames) 1076 | 1077 | step_sum = torch.sum(input, dim=1) # [B * C, F, T] => [B, T] 1078 | step_pow_sum = torch.sum(torch.square(input), dim=1) 1079 | 1080 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T] 1081 | cumulative_pow_sum = torch.cumsum(step_pow_sum, dim=-1) # [B, T] 1082 | 1083 | entry_count = torch.arange( 1084 | num_freqs, 1085 | num_freqs * num_frames + 1, 1086 | num_freqs, 1087 | dtype=input.dtype, 1088 | device=input.device 1089 | ) 1090 | entry_count = entry_count.reshape(1, num_frames) # [1, T] 1091 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T] 1092 | 1093 | cumulative_mean = cumulative_sum / entry_count # [B, T] 1094 | cumulative_var = ( 1095 | cumulative_pow_sum - 2 * cumulative_mean * cumulative_sum) / entry_count + cumulative_mean.pow( 1096 | 2) # [B, T] 1097 | cumulative_std = torch.sqrt(cumulative_var + EPSILON) # [B, T] 1098 | 1099 | cumulative_mean = cumulative_mean.reshape(batch_size * num_channels, 1, num_frames) 1100 | cumulative_std = cumulative_std.reshape(batch_size * num_channels, 1, num_frames) 1101 | 1102 | normed = (input - cumulative_mean) / cumulative_std 1103 | 1104 | return normed.reshape(batch_size, num_channels, num_freqs, num_frames) 1105 | 1106 | def norm_wrapper(self, norm_type: str): 1107 | if norm_type == "offline_laplace_norm": 1108 | norm = self.offline_laplace_norm 1109 | elif norm_type == "cumulative_laplace_norm": 1110 | norm = self.cumulative_laplace_norm 1111 | elif norm_type == "offline_gaussian_norm": 1112 | norm = self.offline_gaussian_norm 1113 | elif norm_type == "cumulative_layer_norm": 1114 | norm = self.cumulative_layer_norm 1115 | else: 1116 | raise NotImplementedError("You must set up a type of Norm. " 1117 | "e.g. 
offline_laplace_norm, cumulative_laplace_norm, forgetting_norm, etc.") 1118 | return norm 1119 | 1120 | def weight_init(self, m): 1121 | """ 1122 | Usage: 1123 | model = Model() 1124 | model.apply(weight_init) 1125 | """ 1126 | if isinstance(m, nn.Conv1d): 1127 | init.normal_(m.weight.data) 1128 | if m.bias is not None: 1129 | init.normal_(m.bias.data) 1130 | elif isinstance(m, nn.Conv2d): 1131 | init.xavier_normal_(m.weight.data) 1132 | if m.bias is not None: 1133 | init.normal_(m.bias.data) 1134 | elif isinstance(m, nn.Conv3d): 1135 | init.xavier_normal_(m.weight.data) 1136 | if m.bias is not None: 1137 | init.normal_(m.bias.data) 1138 | elif isinstance(m, nn.ConvTranspose1d): 1139 | init.normal_(m.weight.data) 1140 | if m.bias is not None: 1141 | init.normal_(m.bias.data) 1142 | elif isinstance(m, nn.ConvTranspose2d): 1143 | init.xavier_normal_(m.weight.data) 1144 | if m.bias is not None: 1145 | init.normal_(m.bias.data) 1146 | elif isinstance(m, nn.ConvTranspose3d): 1147 | init.xavier_normal_(m.weight.data) 1148 | if m.bias is not None: 1149 | init.normal_(m.bias.data) 1150 | elif isinstance(m, nn.BatchNorm1d): 1151 | init.normal_(m.weight.data, mean=1, std=0.02) 1152 | init.constant_(m.bias.data, 0) 1153 | elif isinstance(m, nn.BatchNorm2d): 1154 | init.normal_(m.weight.data, mean=1, std=0.02) 1155 | init.constant_(m.bias.data, 0) 1156 | elif isinstance(m, nn.BatchNorm3d): 1157 | init.normal_(m.weight.data, mean=1, std=0.02) 1158 | init.constant_(m.bias.data, 0) 1159 | elif isinstance(m, nn.Linear): 1160 | init.xavier_normal_(m.weight.data) 1161 | init.normal_(m.bias.data) 1162 | elif isinstance(m, nn.LSTM): 1163 | for param in m.parameters(): 1164 | if len(param.shape) >= 2: 1165 | init.orthogonal_(param.data) 1166 | else: 1167 | init.normal_(param.data) 1168 | elif isinstance(m, nn.LSTMCell): 1169 | for param in m.parameters(): 1170 | if len(param.shape) >= 2: 1171 | init.orthogonal_(param.data) 1172 | else: 1173 | init.normal_(param.data) 1174 | elif isinstance(m, nn.GRU): 1175 | for param in m.parameters(): 1176 | if len(param.shape) >= 2: 1177 | init.orthogonal_(param.data) 1178 | else: 1179 | init.normal_(param.data) 1180 | elif isinstance(m, nn.GRUCell): 1181 | for param in m.parameters(): 1182 | if len(param.shape) >= 2: 1183 | init.orthogonal_(param.data) 1184 | else: 1185 | init.normal_(param.data) 1186 | 1187 | 1188 | ############################################################################ 1189 | # for data normalization # 1190 | ############################################################################ 1191 | # get mu and sig 1192 | def get_mu_sig(data): 1193 | """Compute mean and standard deviation vector of input data 1194 | 1195 | Returns: 1196 | mu: mean vector (#dim by one) 1197 | sig: standard deviation vector (#dim by one) 1198 | """ 1199 | # Initialize array. 1200 | data_num = len(data) 1201 | mu_utt = [] 1202 | tmp_utt = [] 1203 | for n in range(data_num): 1204 | dim = len(data[n]) 1205 | mu_utt_tmp = np.zeros(dim) 1206 | mu_utt.append(mu_utt_tmp) 1207 | 1208 | tmp_utt_tmp = np.zeros(dim) 1209 | tmp_utt.append(tmp_utt_tmp) 1210 | 1211 | # Get mean. 1212 | for n in range(data_num): 1213 | mu_utt[n] = np.mean(data[n], 0) 1214 | mu = mu_utt 1215 | 1216 | # Get standard deviation. 1217 | for n in range(data_num): 1218 | tmp_utt[n] = np.mean(np.square(data[n] - mu[n]), 0) 1219 | sig = np.sqrt(tmp_utt) 1220 | 1221 | # Assign unit variance. 
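# (standard deviations below 1e-5 are floored to 1.0 in the loop below, so that normalizing with these statistics never divides by a value that is effectively zero)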
1222 | for n in range(len(sig)): 1223 | if sig[n] < 1e-5: 1224 | sig[n] = 1.0 1225 | return np.float16(mu), np.float16(sig) 1226 | 1227 | 1228 | def get_statistics_inp(inp): 1229 | """Get statistical parameter of input data. 1230 | 1231 | Args: 1232 | inp: input data 1233 | 1234 | Returns: 1235 | mu_inp: mean vector of input data 1236 | sig_inp: standard deviation vector of input data 1237 | """ 1238 | 1239 | mu_inp, sig_inp = get_mu_sig(inp) 1240 | 1241 | return mu_inp, sig_inp 1242 | 1243 | 1244 | ############################################################################ 1245 | # for plotting the samples # 1246 | ############################################################################ 1247 | def hann_window(win_samp): 1248 | tmp = np.arange(1, win_samp + 1, 1.0, dtype=np.float64) 1249 | window = 0.5 - 0.5 * np.cos((2.0 * np.pi * tmp) / (win_samp + 1)) 1250 | return np.float32(window) 1251 | 1252 | 1253 | def fig2np(fig): 1254 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 1255 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 1256 | return data 1257 | 1258 | 1259 | def plot_spectrogram_to_numpy(input_wav, fs, n_fft, n_overlap, mode, clim, label): 1260 | # cuda to cpu 1261 | input_wav = input_wav.cpu().detach().numpy() 1262 | 1263 | fig, ax = plt.subplots(figsize=(12, 3)) 1264 | 1265 | if mode == 'phase': 1266 | pxx, freq, t, cax = plt.specgram(input_wav, NFFT=int(n_fft), Fs=int(fs), noverlap=n_overlap, 1267 | cmap='jet', 1268 | mode=mode) 1269 | else: 1270 | pxx, freq, t, cax = plt.specgram(input_wav, NFFT=int(n_fft), Fs=int(fs), noverlap=n_overlap, 1271 | cmap='jet') 1272 | 1273 | plt.xlabel('Time (s)') 1274 | plt.ylabel('Frequency (Hz)') 1275 | plt.tight_layout() 1276 | plt.clim(clim) 1277 | 1278 | if label is None: 1279 | fig.colorbar(cax) 1280 | else: 1281 | fig.colorbar(cax, label=label) 1282 | 1283 | fig.canvas.draw() 1284 | data = fig2np(fig) 1285 | plt.close() 1286 | return data 1287 | 1288 | 1289 | def plot_mask_to_numpy(mask, fs, n_fft, n_overlap, clim1, clim2, cmap): 1290 | frame_num = mask.shape[0] 1291 | shift_length = n_overlap 1292 | frame_length = n_fft 1293 | signal_length = frame_num * shift_length + frame_length 1294 | 1295 | xt = np.arange(0, np.floor(10 * signal_length / fs) / 10, step=0.5) / (signal_length / fs) * frame_num + 1e-8 1296 | yt = (n_fft / 2) / (fs / 1000 / 2) * np.arange(0, (fs / 1000 / 2) + 1) 1297 | 1298 | fig, ax = plt.subplots(figsize=(12, 3)) 1299 | im = ax.imshow(np.transpose(mask), aspect='auto', origin='lower', interpolation='none', cmap=cmap) 1300 | 1301 | plt.xlabel('Time (s)') 1302 | plt.ylabel('Frequency (kHz)') 1303 | plt.xticks(xt, np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5)) 1304 | plt.yticks(yt, np.int16(np.linspace(0, int((fs / 1000) / 2), len(yt)))) 1305 | plt.tight_layout() 1306 | plt.colorbar(im, ax=ax) 1307 | im.set_clim(clim1, clim2) 1308 | 1309 | fig.canvas.draw() 1310 | data = fig2np(fig) 1311 | plt.close() 1312 | return data 1313 | 1314 | 1315 | def plot_error_to_numpy(estimated, target, fs, n_fft, n_overlap, mode, clim1, clim2, label): 1316 | fig, ax = plt.subplots(figsize=(12, 3)) 1317 | if mode is None: 1318 | pxx1, freq, t, cax = plt.specgram(estimated, NFFT=n_fft, Fs=int(fs), noverlap=n_overlap, cmap='jet') 1319 | pxx2, freq, t, cax = plt.specgram(target, NFFT=n_fft, Fs=int(fs), noverlap=n_overlap, cmap='jet') 1320 | im = ax.imshow(10 * np.log10(pxx1) - 10 * np.log10(pxx2), aspect='auto', origin='lower', interpolation='none', 1321 | cmap='jet') 1322 | 
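# for the other specgram modes (e.g. 'phase'), the returned values are already in the requested scale, so the raw difference pxx1 - pxx2 is plotted instead of a dB difference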
else: 1323 | pxx1, freq, t, cax = plt.specgram(estimated, NFFT=n_fft, Fs=int(fs), noverlap=n_overlap, cmap='jet', 1324 | mode=mode) 1325 | pxx2, freq, t, cax = plt.specgram(target, NFFT=n_fft, Fs=int(fs), noverlap=n_overlap, cmap='jet', 1326 | mode=mode) 1327 | im = ax.imshow(pxx1 - pxx2, aspect='auto', origin='lower', interpolation='none', cmap='jet') 1328 | 1329 | frame_num = pxx1.shape[1] 1330 | shift_length = n_overlap 1331 | frame_length = n_fft 1332 | signal_length = frame_num * shift_length + frame_length 1333 | 1334 | xt = np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5) / (signal_length / fs) * frame_num 1335 | yt = (n_fft / 2) / (fs / 1000 / 2) * np.arange(0, (fs / 1000 / 2) + 1) 1336 | 1337 | plt.xlabel('Time (s)') 1338 | plt.ylabel('Frequency (kHz)') 1339 | plt.xticks(xt, np.arange(0, np.floor(10 * (signal_length / fs)) / 10, step=0.5)) 1340 | plt.yticks(yt, np.int16(np.linspace(0, int((fs / 1000) / 2), len(yt)))) 1341 | plt.tight_layout() 1342 | plt.colorbar(im, ax=ax, label=label) 1343 | im.set_clim(clim1, clim2) 1344 | 1345 | fig.canvas.draw() 1346 | data = fig2np(fig) 1347 | plt.close() 1348 | return data 1349 | 1350 | 1351 | ############################################################################ 1352 | # for trainer.py # 1353 | ############################################################################ 1354 | class Bar(object): 1355 | def __init__(self, dataloader): 1356 | if not hasattr(dataloader, 'dataset'): 1357 | raise ValueError('Attribute `dataset` does not exist in the dataloader.') 1358 | if not hasattr(dataloader, 'batch_size'): 1359 | raise ValueError('Attribute `batch_size` does not exist in the dataloader.') 1360 | 1361 | self.dataloader = dataloader 1362 | self.iterator = iter(dataloader) 1363 | self.dataset = dataloader.dataset 1364 | self.batch_size = dataloader.batch_size 1365 | self._idx = 0 1366 | self._batch_idx = 0 1367 | self._time = [] 1368 | self._DISPLAY_LENGTH = 50 1369 | 1370 | def __len__(self): 1371 | return len(self.dataloader) 1372 | 1373 | def __iter__(self): 1374 | return self 1375 | 1376 | def __next__(self): 1377 | if len(self._time) < 2: 1378 | self._time.append(time.time()) 1379 | 1380 | self._batch_idx += self.batch_size 1381 | if self._batch_idx > len(self.dataset): 1382 | self._batch_idx = len(self.dataset) 1383 | 1384 | try: 1385 | batch = next(self.iterator) 1386 | self._display() 1387 | except StopIteration: 1388 | raise StopIteration() 1389 | 1390 | self._idx += 1 1391 | if self._idx >= len(self.dataloader): 1392 | self._reset() 1393 | 1394 | return batch 1395 | 1396 | def _display(self): 1397 | if len(self._time) > 1: 1398 | t = (self._time[-1] - self._time[-2]) 1399 | eta = t * (len(self.dataloader) - self._idx) 1400 | else: 1401 | eta = 0 1402 | 1403 | rate = self._idx / len(self.dataloader) 1404 | len_bar = int(rate * self._DISPLAY_LENGTH) 1405 | bar = ('=' * len_bar + '>').ljust(self._DISPLAY_LENGTH, '.') 1406 | idx = str(self._batch_idx).rjust(len(str(len(self.dataset))), ' ') 1407 | 1408 | tmpl = '\r{}/{}: [{}] - ETA {:.1f}s'.format( 1409 | idx, 1410 | len(self.dataset), 1411 | bar, 1412 | eta 1413 | ) 1414 | print(tmpl, end='') 1415 | if self._batch_idx == len(self.dataset): 1416 | print() 1417 | 1418 | def _reset(self): 1419 | self._idx = 0 1420 | self._batch_idx = 0 1421 | self._time = [] 1422 | -------------------------------------------------------------------------------- /train_interface.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import 
torch 4 | import shutil 5 | import numpy as np 6 | import config as cfg 7 | from models import DCCRN, CRN, FullSubNet # you can import 'DCCRN' or 'CRN' or 'FullSubNet' 8 | from write_on_tensorboard import Writer 9 | from dataloader import create_dataloader 10 | from trainer import model_train, model_validate, \ 11 | model_perceptual_train, model_perceptual_validate, \ 12 | dccrn_direct_train, dccrn_direct_validate, \ 13 | crn_direct_train, crn_direct_validate, \ 14 | fullsubnet_train, fullsubnet_validate 15 | 16 | 17 | ############################################################################### 18 | # Helper function definition # 19 | ############################################################################### 20 | # Write training related parameters into the log file. 21 | def write_status_to_log_file(fp, total_parameters): 22 | fp.write('%d-%d-%d %d:%d:%d\n' % 23 | (time.localtime().tm_year, time.localtime().tm_mon, 24 | time.localtime().tm_mday, time.localtime().tm_hour, 25 | time.localtime().tm_min, time.localtime().tm_sec)) 26 | fp.write('total params : %d (%.2f M, %.2f MBytes)\n' % 27 | (total_parameters, 28 | total_parameters / 1000000.0, 29 | total_parameters * 4.0 / 1000000.0)) 30 | 31 | 32 | # Calculate the size of total network. 33 | def calculate_total_params(our_model): 34 | total_parameters = 0 35 | for variable in our_model.parameters(): 36 | shape = variable.size() 37 | variable_parameters = 1 38 | for dim in shape: 39 | variable_parameters *= dim 40 | total_parameters += variable_parameters 41 | 42 | return total_parameters 43 | 44 | 45 | ############################################################################### 46 | # Parameter Initialization and Setting for model training # 47 | ############################################################################### 48 | # Set device 49 | DEVICE = torch.device(cfg.DEVICE) 50 | 51 | # Set model 52 | if cfg.model == 'DCCRN': 53 | model = DCCRN().to(DEVICE) 54 | elif cfg.model == 'CRN': 55 | model = CRN().to(DEVICE) 56 | elif cfg.model == 'FullSubNet': 57 | model = FullSubNet().to(DEVICE) 58 | # Set optimizer and learning rate 59 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate) 60 | total_params = calculate_total_params(model) 61 | 62 | # Set trainer and estimator 63 | if cfg.perceptual is not False: 64 | trainer = model_perceptual_train 65 | estimator = model_perceptual_validate 66 | elif cfg.model == 'FullSubNet': 67 | trainer = fullsubnet_train 68 | estimator = fullsubnet_validate 69 | elif cfg.masking_mode == 'Direct(None make)' and cfg.model == 'DCCRN': 70 | trainer = dccrn_direct_train 71 | estimator = dccrn_direct_validate 72 | elif cfg.masking_mode == 'Direct(None make)' and cfg.model == 'CRN': 73 | trainer = crn_direct_train 74 | estimator = crn_direct_validate 75 | else: 76 | trainer = model_train 77 | estimator = model_validate 78 | 79 | ############################################################################### 80 | # Confirm model information # 81 | ############################################################################### 82 | print('%d-%d-%d %d:%d:%d\n' % 83 | (time.localtime().tm_year, time.localtime().tm_mon, 84 | time.localtime().tm_mday, time.localtime().tm_hour, 85 | time.localtime().tm_min, time.localtime().tm_sec)) 86 | print('total params : %d (%.2f M, %.2f MBytes)\n' % 87 | (total_params, 88 | total_params / 1000000.0, 89 | total_params * 4.0 / 1000000.0)) 90 | 91 | ############################################################################### 92 | # Create 
Dataloader # 93 | ############################################################################### 94 | train_loader = create_dataloader(mode='train') 95 | validation_loader = create_dataloader(mode='valid') 96 | 97 | ############################################################################### 98 | # Set a log file to store progress. # 99 | # Set a hps file to store hyper-parameters information. # 100 | ############################################################################### 101 | if cfg.chkpt_model is not None: # Load the checkpoint 102 | print('Resuming from checkpoint: %s' % cfg.chkpt_path) 103 | 104 | # Set a log file to store progress. 105 | dir_to_save = cfg.job_dir + cfg.chkpt_model 106 | dir_to_logs = cfg.logs_dir + cfg.chkpt_model 107 | 108 | checkpoint = torch.load(cfg.chkpt_path) 109 | model.load_state_dict(checkpoint['model']) 110 | optimizer.load_state_dict(checkpoint['optimizer']) 111 | epoch_start_idx = checkpoint['epoch'] + 1 112 | mse_vali_total = np.load(str(dir_to_save + '/mse_vali_total.npy')) 113 | # if the loaded length is shorter than I expected, extend the length 114 | if len(mse_vali_total) < cfg.max_epochs: 115 | plus = cfg.max_epochs - len(mse_vali_total) 116 | mse_vali_total = np.concatenate((mse_vali_total, np.zeros(plus)), 0) 117 | else: # First learning 118 | print('Starting new training run...') 119 | 120 | # make the file directory to save the models 121 | if not os.path.exists(cfg.job_dir): 122 | os.mkdir(cfg.job_dir) 123 | if not os.path.exists(cfg.logs_dir): 124 | os.mkdir(cfg.logs_dir) 125 | 126 | epoch_start_idx = 1 127 | mse_vali_total = np.zeros(cfg.max_epochs) 128 | 129 | # Set a log file to store progress. 130 | dir_to_save = cfg.job_dir + cfg.expr_num + '_%d.%d' % (time.localtime().tm_mon, time.localtime().tm_mday) + \ 131 | '_%s' % cfg.model + '_%s' % cfg.loss 132 | dir_to_logs = cfg.logs_dir + cfg.expr_num + '_%d.%d' % (time.localtime().tm_mon, time.localtime().tm_mday) \ 133 | + '_%s' % cfg.model + '_%s' % cfg.loss 134 | 135 | # make the file directory 136 | if not os.path.exists(dir_to_save): 137 | os.mkdir(dir_to_save) 138 | os.mkdir(dir_to_logs) 139 | 140 | # logging 141 | log_fname = str(dir_to_save + '/log.txt') 142 | if not os.path.exists(log_fname): 143 | fp = open(log_fname, 'w') 144 | write_status_to_log_file(fp, total_params) 145 | else: 146 | fp = open(log_fname, 'a') 147 | 148 | ############################################################################### 149 | ############################################################################### 150 | # Main program start !! 
# 151 | ############################################################################### 152 | ############################################################################### 153 | # Writer initialize 154 | writer = Writer(dir_to_logs) 155 | 156 | ############################################################################### 157 | # Train # 158 | ############################################################################### 159 | if cfg.perceptual is not False: # train with perceptual loss function 160 | for epoch in range(epoch_start_idx, cfg.max_epochs + 1): 161 | start_time = time.time() 162 | # Training 163 | train_loss, train_main_loss, train_perceptual_loss = trainer(model, optimizer, train_loader, DEVICE) 164 | 165 | # save checkpoint file to resume training 166 | save_path = str(dir_to_save + '/' + ('chkpt_%d.pt' % epoch)) 167 | torch.save({ 168 | 'model': model.state_dict(), 169 | 'optimizer': optimizer.state_dict(), 170 | 'epoch': epoch 171 | }, save_path) 172 | 173 | # Validation 174 | vali_loss, validation_main_loss, validation_perceptual_loss, vali_pesq, vali_stoi = \ 175 | estimator(model, validation_loader, writer, dir_to_save, epoch, DEVICE) 176 | # write the loss on tensorboard 177 | writer.log_loss(train_loss, vali_loss, epoch) 178 | writer.log_score(vali_pesq, vali_stoi, epoch) 179 | writer.log_sub_loss(train_main_loss, train_perceptual_loss, 180 | validation_main_loss, validation_perceptual_loss, epoch) 181 | 182 | print('Epoch [{}] | T {:.6f} | V {:.6f} ' 183 | .format(epoch, train_loss, vali_loss)) 184 | print(' | T {:.6f} {:.6f} | V {:.6f} {:.6f} takes {:.2f} seconds\n' 185 | .format(train_main_loss, train_perceptual_loss, validation_main_loss, validation_perceptual_loss, 186 | time.time() - start_time)) 187 | print(' | V PESQ: {:.6f} | STOI: {:.6f} '.format(vali_pesq, vali_stoi)) 188 | # log file save 189 | fp.write('Epoch [{}] | T {:.6f} | V {:.6f}\n' 190 | .format(epoch, train_loss, vali_loss)) 191 | fp.write(' | T {:.6f} {:.6f} | V {:.6f} {:.6f} takes {:.2f} seconds\n' 192 | .format(train_main_loss, train_perceptual_loss, 193 | validation_main_loss, validation_perceptual_loss, time.time() - start_time)) 194 | fp.write(' | V PESQ: {:.6f} | STOI: {:.6f} \n'.format(vali_pesq, vali_stoi)) 195 | 196 | mse_vali_total[epoch - 1] = vali_loss 197 | np.save(str(dir_to_save + '/mse_vali_total.npy'), mse_vali_total) 198 | else: 199 | for epoch in range(epoch_start_idx, cfg.max_epochs + 1): 200 | start_time = time.time() 201 | # Training 202 | train_loss = trainer(model, optimizer, train_loader, DEVICE) 203 | 204 | # save checkpoint file to resume training 205 | save_path = str(dir_to_save + '/' + ('chkpt_%d.pt' % epoch)) 206 | torch.save({ 207 | 'model': model.state_dict(), 208 | 'optimizer': optimizer.state_dict(), 209 | 'epoch': epoch 210 | }, save_path) 211 | 212 | # Validation 213 | vali_loss, vali_pesq, vali_stoi = \ 214 | estimator(model, validation_loader, writer, dir_to_save, epoch, DEVICE) 215 | # write the loss on tensorboard 216 | writer.log_loss(train_loss, vali_loss, epoch) 217 | writer.log_score(vali_pesq, vali_stoi, epoch) 218 | 219 | print('Epoch [{}] | T {:.6f} | V {:.6f} takes {:.2f} seconds\n' 220 | .format(epoch, train_loss, vali_loss, time.time() - start_time)) 221 | print(' | V PESQ: {:.6f} | STOI: {:.6f} '.format(vali_pesq, vali_stoi)) 222 | # log file save 223 | fp.write('Epoch [{}] | T {:.6f} | V {:.6f} takes {:.2f} seconds\n' 224 | .format(epoch, train_loss, vali_loss, time.time() - start_time)) 225 | fp.write(' | V PESQ: {:.6f} | STOI: 
{:.6f} \n'.format(vali_pesq, vali_stoi)) 226 | 227 | mse_vali_total[epoch - 1] = vali_loss 228 | np.save(str(dir_to_save + '/mse_vali_total.npy'), mse_vali_total) 229 | 230 | fp.close() 231 | print('Training has been finished.') 232 | 233 | # Copy optimum model that has minimum MSE. 234 | print('Save optimum models...') 235 | min_index = np.argmin(mse_vali_total) 236 | print('Minimum validation loss is at ' + str(min_index + 1) + '.') 237 | src_file = str(dir_to_save + '/' + ('chkpt_%d.pt' % (min_index + 1))) 238 | tgt_file = str(dir_to_save + '/chkpt_opt.pt') 239 | shutil.copy(src_file, tgt_file) 240 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Where the model is actually trained and validated 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | import tools_for_model as tools 8 | from tools_for_estimate import cal_pesq, cal_stoi 9 | 10 | 11 | ####################################################################### 12 | # For train # 13 | ####################################################################### 14 | # T-F masking 15 | def model_train(model, optimizer, train_loader, DEVICE): 16 | # initialization 17 | train_loss = 0 18 | batch_num = 0 19 | 20 | # arr = [] 21 | # train 22 | model.train() 23 | for inputs, targets in tools.Bar(train_loader): 24 | batch_num += 1 25 | 26 | # to cuda 27 | inputs = inputs.float().to(DEVICE) 28 | targets = targets.float().to(DEVICE) 29 | 30 | _, _, outputs = model(inputs, targets) 31 | loss = model.loss(outputs, targets) 32 | # # if you want to check the scale of the loss 33 | # print('loss: {:.4}'.format(loss)) 34 | 35 | optimizer.zero_grad() 36 | loss.backward() 37 | optimizer.step() 38 | 39 | train_loss += loss 40 | train_loss /= batch_num 41 | 42 | return train_loss 43 | 44 | 45 | def model_perceptual_train(model, optimizer, train_loader, DEVICE): 46 | # initialization 47 | train_loss = 0 48 | train_main_loss = 0 49 | train_perceptual_loss = 0 50 | batch_num = 0 51 | 52 | # train 53 | model.train() 54 | for inputs, targets in tools.Bar(train_loader): 55 | batch_num += 1 56 | 57 | # to cuda 58 | inputs = inputs.float().to(DEVICE) 59 | targets = targets.float().to(DEVICE) 60 | 61 | real_spec, img_spec, outputs = model(inputs) 62 | main_loss = model.loss(outputs, targets) 63 | perceptual_loss = model.loss(outputs, targets, real_spec, img_spec, perceptual=True) 64 | 65 | # the constraint ratio 66 | r1 = 1 67 | r2 = 1 68 | r3 = r1 + r2 69 | loss = (r1 * main_loss + r2 * perceptual_loss) / r3 70 | 71 | optimizer.zero_grad() 72 | loss.backward() 73 | optimizer.step() 74 | 75 | train_loss += loss 76 | train_main_loss += r1 * main_loss 77 | train_perceptual_loss += r2 * perceptual_loss 78 | train_loss /= batch_num 79 | train_main_loss /= batch_num 80 | train_perceptual_loss /= batch_num 81 | 82 | return train_loss, train_main_loss, train_perceptual_loss 83 | 84 | 85 | def fullsubnet_train(model, optimizer, train_loader, DEVICE): 86 | # initialization 87 | train_loss = 0 88 | batch_num = 0 89 | 90 | # arr = [] 91 | # train 92 | model.train() 93 | for inputs, targets in tools.Bar(train_loader): 94 | batch_num += 1 95 | 96 | # to cuda 97 | inputs = inputs.float().to(DEVICE) 98 | targets = targets.float().to(DEVICE) 99 | 100 | noisy_complex = tools.stft(inputs) 101 | clean_complex = tools.stft(targets) 102 | 103 | noisy_mag, _ = tools.mag_phase(noisy_complex) 104 | cIRM = 
tools.build_complex_ideal_ratio_mask(noisy_complex, clean_complex) 105 | 106 | cRM = model(noisy_mag) 107 | loss = model.loss(cIRM, cRM) 108 | # # if you want to check the scale of the loss 109 | # print('loss: {:.4}'.format(loss)) 110 | 111 | optimizer.zero_grad() 112 | loss.backward() 113 | optimizer.step() 114 | 115 | train_loss += loss 116 | train_loss /= batch_num 117 | 118 | return train_loss 119 | 120 | 121 | # Spectral mapping 122 | def dccrn_direct_train(model, optimizer, train_loader, DEVICE): 123 | # initialization 124 | train_loss = 0 125 | batch_num = 0 126 | 127 | # train 128 | model.train() 129 | for inputs, targets in tools.Bar(train_loader): 130 | batch_num += 1 131 | 132 | # to cuda 133 | inputs = inputs.float().to(DEVICE) 134 | targets = targets.float().to(DEVICE) 135 | 136 | output_real, target_real, output_imag, target_imag, _ = model(inputs, targets) 137 | real_loss = model.loss(output_real, target_real) 138 | imag_loss = model.loss(output_imag, target_imag) 139 | loss = (real_loss + imag_loss) / 2 140 | 141 | # # if you want to check the scale of the loss 142 | # print('loss: {:.4}'.format(loss)) 143 | 144 | optimizer.zero_grad() 145 | loss.backward() 146 | optimizer.step() 147 | 148 | train_loss += loss 149 | train_loss /= batch_num 150 | 151 | return train_loss 152 | 153 | 154 | def crn_direct_train(model, optimizer, train_loader, DEVICE): 155 | # initialization 156 | train_loss = 0 157 | batch_num = 0 158 | 159 | # train 160 | model.train() 161 | for inputs, targets in tools.Bar(train_loader): 162 | batch_num += 1 163 | 164 | # to cuda 165 | inputs = inputs.float().to(DEVICE) 166 | targets = targets.float().to(DEVICE) 167 | 168 | output_mag, target_mag, _ = model(inputs, targets) 169 | loss = model.loss(output_mag, target_mag) 170 | 171 | # # if you want to check the scale of the loss 172 | # print('loss: {:.4}'.format(loss)) 173 | 174 | optimizer.zero_grad() 175 | loss.backward() 176 | optimizer.step() 177 | 178 | train_loss += loss 179 | train_loss /= batch_num 180 | 181 | return train_loss 182 | 183 | 184 | ####################################################################### 185 | # For validation # 186 | ####################################################################### 187 | # T-F masking 188 | def model_validate(model, validation_loader, writer, dir_to_save, epoch, DEVICE): 189 | # initialization 190 | validation_loss = 0 191 | batch_num = 0 192 | 193 | avg_pesq = 0 194 | avg_stoi = 0 195 | 196 | # for record the score each samples 197 | f_score = open(dir_to_save + '/Epoch_' + '%d_SCORES' % epoch, 'a') 198 | 199 | model.eval() 200 | with torch.no_grad(): 201 | for inputs, targets in tools.Bar(validation_loader): 202 | batch_num += 1 203 | 204 | # to cuda 205 | inputs = inputs.float().to(DEVICE) 206 | targets = targets.float().to(DEVICE) 207 | 208 | _, _, outputs = model(inputs, targets) 209 | loss = model.loss(outputs, targets) 210 | 211 | validation_loss += loss 212 | 213 | # estimate the output speech with pesq and stoi 214 | estimated_wavs = outputs.cpu().detach().numpy() 215 | clean_wavs = targets.cpu().detach().numpy() 216 | 217 | pesq = cal_pesq(estimated_wavs, clean_wavs) 218 | stoi = cal_stoi(estimated_wavs, clean_wavs) 219 | 220 | # pesq: 0.1 better / stoi: 0.01 better 221 | for i in range(len(pesq)): 222 | f_score.write('PESQ {:.6f} | STOI {:.6f}\n'.format(pesq[i], stoi[i])) 223 | 224 | # reshape for sum 225 | pesq = np.reshape(pesq, (1, -1)) 226 | stoi = np.reshape(stoi, (1, -1)) 227 | 228 | avg_pesq += sum(pesq[0]) / len(inputs) 229 | 
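# (PESQ/STOI are averaged over the utterances in the batch here and then averaged over all validation batches after the loop, via the division by batch_num)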
avg_stoi += sum(stoi[0]) / len(inputs) 230 | 231 | # save the samples to tensorboard 232 | if epoch % 10 == 0: 233 | writer.log_wav(inputs[0], targets[0], outputs[0], epoch) 234 | 235 | validation_loss /= batch_num 236 | avg_pesq /= batch_num 237 | avg_stoi /= batch_num 238 | 239 | return validation_loss, avg_pesq, avg_stoi 240 | 241 | 242 | def model_perceptual_validate(model, validation_loader, writer, dir_to_save, epoch, DEVICE): 243 | # initialization 244 | validation_loss = 0 245 | validation_main_loss = 0 246 | validation_perceptual_loss = 0 247 | batch_num = 0 248 | 249 | avg_pesq = 0 250 | avg_stoi = 0 251 | 252 | # for record the score each samples 253 | f_score = open(dir_to_save + '/Epoch_' + '%d_SCORES' % epoch, 'a') 254 | 255 | model.eval() 256 | with torch.no_grad(): 257 | for inputs, targets in tools.Bar(validation_loader): 258 | batch_num += 1 259 | 260 | # to cuda 261 | inputs = inputs.float().to(DEVICE) 262 | targets = targets.float().to(DEVICE) 263 | 264 | real_spec, img_spec, outputs = model(inputs) 265 | main_loss = model.loss(outputs, targets) 266 | perceptual_loss = model.loss(outputs, targets, real_spec, img_spec, perceptual=True) 267 | 268 | # the constraint ratio 269 | r1 = 1 270 | r2 = 1 271 | r3 = r1 + r2 272 | loss = (r1 * main_loss + r2 * perceptual_loss) / r3 273 | 274 | validation_loss += loss 275 | validation_main_loss += r1 * main_loss 276 | validation_perceptual_loss += r2 * perceptual_loss 277 | 278 | # estimate the output speech with pesq and stoi 279 | estimated_wavs = outputs.cpu().detach().numpy() 280 | clean_wavs = targets.cpu().detach().numpy() 281 | 282 | pesq = cal_pesq(estimated_wavs, clean_wavs) 283 | stoi = cal_stoi(estimated_wavs, clean_wavs) 284 | 285 | # pesq: 0.1 better / stoi: 0.01 better 286 | for i in range(len(pesq)): 287 | f_score.write('PESQ {:.6f} | STOI {:.6f}\n'.format(pesq[i], stoi[i])) 288 | 289 | # reshape for sum 290 | pesq = np.reshape(pesq, (1, -1)) 291 | stoi = np.reshape(stoi, (1, -1)) 292 | 293 | avg_pesq += sum(pesq[0]) / len(inputs) 294 | avg_stoi += sum(stoi[0]) / len(inputs) 295 | 296 | # save the samples to tensorboard 297 | if epoch % 10 == 0: 298 | writer.log_wav(inputs[0], targets[0], outputs[0], epoch) 299 | 300 | validation_loss /= batch_num 301 | validation_main_loss /= batch_num 302 | validation_perceptual_loss /= batch_num 303 | avg_pesq /= batch_num 304 | avg_stoi /= batch_num 305 | 306 | return validation_loss, validation_main_loss, validation_perceptual_loss, avg_pesq, avg_stoi 307 | 308 | 309 | def fullsubnet_validate(model, validation_loader, writer, dir_to_save, epoch, DEVICE): 310 | # initialization 311 | validation_loss = 0 312 | batch_num = 0 313 | 314 | avg_pesq = 0 315 | avg_stoi = 0 316 | 317 | # for record the score each samples 318 | f_score = open(dir_to_save + '/Epoch_' + '%d_SCORES' % epoch, 'a') 319 | 320 | model.eval() 321 | with torch.no_grad(): 322 | for inputs, targets in tools.Bar(validation_loader): 323 | batch_num += 1 324 | 325 | # to cuda 326 | inputs = inputs.float().to(DEVICE) 327 | targets = targets.float().to(DEVICE) 328 | 329 | noisy_complex = tools.stft(inputs) 330 | clean_complex = tools.stft(targets) 331 | 332 | noisy_mag, _ = tools.mag_phase(noisy_complex) 333 | cIRM = tools.build_complex_ideal_ratio_mask(noisy_complex, clean_complex) 334 | 335 | cRM = model(noisy_mag) 336 | loss = model.loss(cIRM, cRM) 337 | 338 | validation_loss += loss 339 | 340 | # estimate the output speech with pesq and stoi 341 | cRM = tools.decompress_cIRM(cRM) 342 | enhanced_real = cRM[..., 0] * 
noisy_complex.real - cRM[..., 1] * noisy_complex.imag 343 | enhanced_imag = cRM[..., 1] * noisy_complex.real + cRM[..., 0] * noisy_complex.imag 344 | enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1) 345 | enhanced_outputs = tools.istft(enhanced_complex, length=inputs.size(-1)) 346 | 347 | estimated_wavs = enhanced_outputs.cpu().detach().numpy() 348 | clean_wavs = targets.cpu().detach().numpy() 349 | 350 | pesq = cal_pesq(estimated_wavs, clean_wavs) 351 | stoi = cal_stoi(estimated_wavs, clean_wavs) 352 | 353 | # pesq: 0.1 better / stoi: 0.01 better 354 | for i in range(len(pesq)): 355 | f_score.write('PESQ {:.6f} | STOI {:.6f}\n'.format(pesq[i], stoi[i])) 356 | 357 | # reshape for sum 358 | pesq = np.reshape(pesq, (1, -1)) 359 | stoi = np.reshape(stoi, (1, -1)) 360 | 361 | avg_pesq += sum(pesq[0]) / len(inputs) 362 | avg_stoi += sum(stoi[0]) / len(inputs) 363 | 364 | # save the samples to tensorboard 365 | if epoch % 10 == 0: 366 | writer.log_wav(inputs[0], targets[0], enhanced_outputs[0], epoch) 367 | 368 | validation_loss /= batch_num 369 | avg_pesq /= batch_num 370 | avg_stoi /= batch_num 371 | 372 | return validation_loss, avg_pesq, avg_stoi 373 | 374 | 375 | # Spectral mapping 376 | def dccrn_direct_validate(model, validation_loader, writer, dir_to_save, epoch, DEVICE): 377 | # initialization 378 | validation_loss = 0 379 | batch_num = 0 380 | 381 | avg_pesq = 0 382 | avg_stoi = 0 383 | 384 | # for record the score each samples 385 | f_score = open(dir_to_save + '/Epoch_' + '%d_SCORES' % epoch, 'a') 386 | 387 | model.eval() 388 | with torch.no_grad(): 389 | for inputs, targets in tools.Bar(validation_loader): 390 | batch_num += 1 391 | 392 | # to cuda 393 | inputs = inputs.float().to(DEVICE) 394 | targets = targets.float().to(DEVICE) 395 | 396 | output_real, target_real, output_imag, target_imag, outputs = model(inputs, targets) 397 | real_loss = model.loss(output_real, target_real) 398 | imag_loss = model.loss(output_imag, target_imag) 399 | loss = (real_loss + imag_loss) / 2 400 | 401 | validation_loss += loss 402 | 403 | # estimate the output speech with pesq and stoi 404 | estimated_wavs = outputs.cpu().detach().numpy() 405 | clean_wavs = targets.cpu().detach().numpy() 406 | 407 | pesq = cal_pesq(estimated_wavs, clean_wavs) 408 | stoi = cal_stoi(estimated_wavs, clean_wavs) 409 | 410 | # pesq: 0.1 better / stoi: 0.01 better 411 | for i in range(len(pesq)): 412 | f_score.write('PESQ {:.6f} | STOI {:.6f}\n'.format(pesq[i], stoi[i])) 413 | 414 | # reshape for sum 415 | pesq = np.reshape(pesq, (1, -1)) 416 | stoi = np.reshape(stoi, (1, -1)) 417 | 418 | avg_pesq += sum(pesq[0]) / len(inputs) 419 | avg_stoi += sum(stoi[0]) / len(inputs) 420 | 421 | # save the samples to tensorboard 422 | if epoch % 10 == 0: 423 | writer.log_wav(inputs[0], targets[0], outputs[0], epoch) 424 | 425 | validation_loss /= batch_num 426 | avg_pesq /= batch_num 427 | avg_stoi /= batch_num 428 | 429 | return validation_loss, avg_pesq, avg_stoi 430 | 431 | 432 | def crn_direct_validate(model, validation_loader, writer, dir_to_save, epoch, DEVICE): 433 | # initialization 434 | validation_loss = 0 435 | batch_num = 0 436 | 437 | avg_pesq = 0 438 | avg_stoi = 0 439 | 440 | # for record the score each samples 441 | f_score = open(dir_to_save + '/Epoch_' + '%d_SCORES' % epoch, 'a') 442 | 443 | model.eval() 444 | with torch.no_grad(): 445 | for inputs, targets in tools.Bar(validation_loader): 446 | batch_num += 1 447 | 448 | # to cuda 449 | inputs = inputs.float().to(DEVICE) 450 | targets = 
targets.float().to(DEVICE) 451 | 452 | output_mag, target_mag, outputs = model(inputs, targets) 453 | loss = model.loss(output_mag, target_mag) 454 | 455 | validation_loss += loss 456 | 457 | # estimate the output speech with pesq and stoi 458 | estimated_wavs = outputs.cpu().detach().numpy() 459 | clean_wavs = targets.cpu().detach().numpy() 460 | 461 | pesq = cal_pesq(estimated_wavs, clean_wavs) 462 | stoi = cal_stoi(estimated_wavs, clean_wavs) 463 | 464 | # pesq: 0.1 better / stoi: 0.01 better 465 | for i in range(len(pesq)): 466 | f_score.write('PESQ {:.6f} | STOI {:.6f}\n'.format(pesq[i], stoi[i])) 467 | 468 | # reshape for sum 469 | pesq = np.reshape(pesq, (1, -1)) 470 | stoi = np.reshape(stoi, (1, -1)) 471 | 472 | avg_pesq += sum(pesq[0]) / len(inputs) 473 | avg_stoi += sum(stoi[0]) / len(inputs) 474 | 475 | # save the samples to tensorboard 476 | if epoch % 10 == 0: 477 | writer.log_wav(inputs[0], targets[0], outputs[0], epoch) 478 | 479 | validation_loss /= batch_num 480 | avg_pesq /= batch_num 481 | avg_stoi /= batch_num 482 | 483 | return validation_loss, avg_pesq, avg_stoi 484 | -------------------------------------------------------------------------------- /write_on_tensorboard.py: -------------------------------------------------------------------------------- 1 | """ 2 | For observing the results using tensorboard 3 | 4 | 1. wav 5 | 2. spectrogram 6 | 3. loss 7 | """ 8 | from tensorboardX import SummaryWriter 9 | import matplotlib 10 | import config as cfg 11 | 12 | 13 | class Writer(SummaryWriter): 14 | def __init__(self, logdir): 15 | super(Writer, self).__init__(logdir) 16 | # mask real/ imag 17 | cmap_custom = { 18 | 'red': ((0.0, 0.0, 0.0), 19 | (1 / 63, 0.0, 0.0), 20 | (2 / 63, 0.0, 0.0), 21 | (3 / 63, 0.0, 0.0), 22 | (4 / 63, 0.0, 0.0), 23 | (5 / 63, 0.0, 0.0), 24 | (6 / 63, 0.0, 0.0), 25 | (7 / 63, 0.0, 0.0), 26 | (8 / 63, 0.0, 0.0), 27 | (9 / 63, 0.0, 0.0), 28 | (10 / 63, 0.0, 0.0), 29 | (11 / 63, 0.0, 0.0), 30 | (12 / 63, 0.0, 0.0), 31 | (13 / 63, 0.0, 0.0), 32 | (14 / 63, 0.0, 0.0), 33 | (15 / 63, 0.0, 0.0), 34 | (16 / 63, 0.0, 0.0), 35 | (17 / 63, 0.0, 0.0), 36 | (18 / 63, 0.0, 0.0), 37 | (19 / 63, 0.0, 0.0), 38 | (20 / 63, 0.0, 0.0), 39 | (21 / 63, 0.0, 0.0), 40 | (22 / 63, 0.0, 0.0), 41 | (23 / 63, 0.0, 0.0), 42 | (24 / 63, 0.5625, 0.5625), 43 | (25 / 63, 0.6250, 0.6250), 44 | (26 / 63, 0.6875, 0.6875), 45 | (27 / 63, 0.7500, 0.7500), 46 | (28 / 63, 0.8125, 0.8125), 47 | (29 / 63, 0.8750, 0.8750), 48 | (30 / 63, 0.9375, 0.9375), 49 | (31 / 63, 1.0, 1.0), 50 | (32 / 63, 1.0, 1.0), 51 | (33 / 63, 1.0, 1.0), 52 | (34 / 63, 1.0, 1.0), 53 | (35 / 63, 1.0, 1.0), 54 | (36 / 63, 1.0, 1.0), 55 | (37 / 63, 1.0, 1.0), 56 | (38 / 63, 1.0, 1.0), 57 | (39 / 63, 1.0, 1.0), 58 | (40 / 63, 1.0, 1.0), 59 | (41 / 63, 1.0, 1.0), 60 | (42 / 63, 1.0, 1.0), 61 | (43 / 63, 1.0, 1.0), 62 | (44 / 63, 1.0, 1.0), 63 | (45 / 63, 1.0, 1.0), 64 | (46 / 63, 1.0, 1.0), 65 | (47 / 63, 1.0, 1.0), 66 | (48 / 63, 1.0, 1.0), 67 | (49 / 63, 1.0, 1.0), 68 | (50 / 63, 1.0, 1.0), 69 | (51 / 63, 1.0, 1.0), 70 | (52 / 63, 1.0, 1.0), 71 | (53 / 63, 1.0, 1.0), 72 | (54 / 63, 1.0, 1.0), 73 | (55 / 63, 1.0, 1.0), 74 | (56 / 63, 0.9375, 0.9375), 75 | (57 / 63, 0.8750, 0.8750), 76 | (58 / 63, 0.8125, 0.8125), 77 | (59 / 63, 0.7500, 0.7500), 78 | (60 / 63, 0.6875, 0.6875), 79 | (61 / 63, 0.6250, 0.6250), 80 | (62 / 63, 0.5625, 0.5625), 81 | (63 / 63, 0.5000, 0.5000)), 82 | 'green': ((0.0, 0.0, 0.0), 83 | (1 / 63, 0.0, 0.0), 84 | (2 / 63, 0.0, 0.0), 85 | (3 / 63, 0.0, 0.0), 86 | (4 / 63, 0.0, 0.0), 87 | (5 / 63, 
0.0, 0.0), 88 | (6 / 63, 0.0, 0.0), 89 | (7 / 63, 0.0, 0.0), 90 | (8 / 63, 0.0625, 0.0625), 91 | (9 / 63, 0.1250, 0.1250), 92 | (10 / 63, 0.1875, 0.1875), 93 | (11 / 63, 0.2500, 0.2500), 94 | (12 / 63, 0.3125, 0.3125), 95 | (13 / 63, 0.3750, 0.3750), 96 | (14 / 63, 0.4375, 0.4375), 97 | (15 / 63, 0.5000, 0.5000), 98 | (16 / 63, 0.5625, 0.5625), 99 | (17 / 63, 0.6250, 0.6250), 100 | (18 / 63, 0.6875, 0.6875), 101 | (19 / 63, 0.7500, 0.7500), 102 | (20 / 63, 0.8125, 0.8125), 103 | (21 / 63, 0.8750, 0.8750), 104 | (22 / 63, 0.9375, 0.9375), 105 | (23 / 63, 1.0, 1.0), 106 | (24 / 63, 1.0, 1.0), 107 | (25 / 63, 1.0, 1.0), 108 | (26 / 63, 1.0, 1.0), 109 | (27 / 63, 1.0, 1.0), 110 | (28 / 63, 1.0, 1.0), 111 | (29 / 63, 1.0, 1.0), 112 | (30 / 63, 1.0, 1.0), 113 | (31 / 63, 1.0, 1.0), 114 | (32 / 63, 1.0, 1.0), 115 | (33 / 63, 1.0, 1.0), 116 | (34 / 63, 1.0, 1.0), 117 | (35 / 63, 1.0, 1.0), 118 | (36 / 63, 1.0, 1.0), 119 | (37 / 63, 1.0, 1.0), 120 | (38 / 63, 1.0, 1.0), 121 | (39 / 63, 1.0, 1.0), 122 | (40 / 63, 0.9375, 0.9375), 123 | (41 / 63, 0.8750, 0.8750), 124 | (42 / 63, 0.8125, 0.8125), 125 | (43 / 63, 0.7500, 0.7500), 126 | (44 / 63, 0.6875, 0.6875), 127 | (45 / 63, 0.6250, 0.6250), 128 | (46 / 63, 0.5625, 0.5625), 129 | (47 / 63, 0.5000, 0.5000), 130 | (48 / 63, 0.4375, 0.4375), 131 | (49 / 63, 0.3750, 0.3750), 132 | (50 / 63, 0.3125, 0.3125), 133 | (51 / 63, 0.2500, 0.2500), 134 | (52 / 63, 0.1875, 0.1875), 135 | (53 / 63, 0.1250, 0.1250), 136 | (54 / 63, 0.0625, 0.0625), 137 | (55 / 63, 0.0, 0.0), 138 | (56 / 63, 0.0, 0.0), 139 | (57 / 63, 0.0, 0.0), 140 | (58 / 63, 0.0, 0.0), 141 | (59 / 63, 0.0, 0.0), 142 | (60 / 63, 0.0, 0.0), 143 | (61 / 63, 0.0, 0.0), 144 | (62 / 63, 0.0, 0.0), 145 | (63 / 63, 0.0, 0.0)), 146 | 'blue': ((0.0, 0.5625, 0.5625), 147 | (1 / 63, 0.6250, 0.6250), 148 | (2 / 63, 0.6875, 0.6875), 149 | (3 / 63, 0.7500, 0.7500), 150 | (4 / 63, 0.8125, 0.8125), 151 | (5 / 63, 0.8750, 0.8750), 152 | (6 / 63, 0.9375, 0.9375), 153 | (7 / 63, 1.0, 1.0), 154 | (8 / 63, 1.0, 1.0), 155 | (9 / 63, 1.0, 1.0), 156 | (10 / 63, 1.0, 1.0), 157 | (11 / 63, 1.0, 1.0), 158 | (12 / 63, 1.0, 1.0), 159 | (13 / 63, 1.0, 1.0), 160 | (14 / 63, 1.0, 1.0), 161 | (15 / 63, 1.0, 1.0), 162 | (16 / 63, 1.0, 1.0), 163 | (17 / 63, 1.0, 1.0), 164 | (18 / 63, 1.0, 1.0), 165 | (19 / 63, 1.0, 1.0), 166 | (20 / 63, 1.0, 1.0), 167 | (21 / 63, 1.0, 1.0), 168 | (22 / 63, 1.0, 1.0), 169 | (23 / 63, 1.0, 1.0), 170 | (24 / 63, 1.0, 1.0), 171 | (25 / 63, 1.0, 1.0), 172 | (26 / 63, 1.0, 1.0), 173 | (27 / 63, 1.0, 1.0), 174 | (28 / 63, 1.0, 1.0), 175 | (29 / 63, 1.0, 1.0), 176 | (30 / 63, 1.0, 1.0), 177 | (31 / 63, 1.0, 1.0), 178 | (32 / 63, 0.9375, 0.9375), 179 | (33 / 63, 0.8750, 0.8750), 180 | (34 / 63, 0.8125, 0.8125), 181 | (35 / 63, 0.7500, 0.7500), 182 | (36 / 63, 0.6875, 0.6875), 183 | (37 / 63, 0.6250, 0.6250), 184 | (38 / 63, 0.5625, 0.5625), 185 | (39 / 63, 0.0, 0.0), 186 | (40 / 63, 0.0, 0.0), 187 | (41 / 63, 0.0, 0.0), 188 | (42 / 63, 0.0, 0.0), 189 | (43 / 63, 0.0, 0.0), 190 | (44 / 63, 0.0, 0.0), 191 | (45 / 63, 0.0, 0.0), 192 | (46 / 63, 0.0, 0.0), 193 | (47 / 63, 0.0, 0.0), 194 | (48 / 63, 0.0, 0.0), 195 | (49 / 63, 0.0, 0.0), 196 | (50 / 63, 0.0, 0.0), 197 | (51 / 63, 0.0, 0.0), 198 | (52 / 63, 0.0, 0.0), 199 | (53 / 63, 0.0, 0.0), 200 | (54 / 63, 0.0, 0.0), 201 | (55 / 63, 0.0, 0.0), 202 | (56 / 63, 0.0, 0.0), 203 | (57 / 63, 0.0, 0.0), 204 | (58 / 63, 0.0, 0.0), 205 | (59 / 63, 0.0, 0.0), 206 | (60 / 63, 0.0, 0.0), 207 | (61 / 63, 0.0, 0.0), 208 | (62 / 63, 0.0, 0.0), 209 | (63 / 63, 0.0, 0.0)) 210 
| } 211 | 212 | # mask magnitude 213 | cmap_custom2 = { 214 | 'red': ((0.0, 1.0, 1.0), 215 | (1 / 32, 1.0, 1.0), 216 | (2 / 32, 1.0, 1.0), 217 | (3 / 32, 1.0, 1.0), 218 | (4 / 32, 1.0, 1.0), 219 | (5 / 32, 1.0, 1.0), 220 | (6 / 32, 1.0, 1.0), 221 | (7 / 32, 1.0, 1.0), 222 | (8 / 32, 1.0, 1.0), 223 | (9 / 32, 1.0, 1.0), 224 | (10 / 32, 1.0, 1.0), 225 | (11 / 32, 1.0, 1.0), 226 | (12 / 32, 1.0, 1.0), 227 | (13 / 32, 1.0, 1.0), 228 | (14 / 32, 1.0, 1.0), 229 | (15 / 32, 1.0, 1.0), 230 | (16 / 32, 1.0, 1.0), 231 | (17 / 32, 1.0, 1.0), 232 | (18 / 32, 1.0, 1.0), 233 | (19 / 32, 1.0, 1.0), 234 | (20 / 32, 1.0, 1.0), 235 | (21 / 32, 1.0, 1.0), 236 | (22 / 32, 1.0, 1.0), 237 | (23 / 32, 1.0, 1.0), 238 | (24 / 32, 1.0, 1.0), 239 | (25 / 32, 0.9375, 0.9375), 240 | (26 / 32, 0.8750, 0.8750), 241 | (27 / 32, 0.8125, 0.8125), 242 | (28 / 32, 0.7500, 0.7500), 243 | (29 / 32, 0.6875, 0.6875), 244 | (30 / 32, 0.6250, 0.6250), 245 | (31 / 32, 0.5625, 0.5625), 246 | (32 / 32, 0.5000, 0.5000)), 247 | 'green': ((0.0, 1.0, 1.0), 248 | (1 / 32, 1.0, 1.0), 249 | (2 / 32, 1.0, 1.0), 250 | (3 / 32, 1.0, 1.0), 251 | (4 / 32, 1.0, 1.0), 252 | (5 / 32, 1.0, 1.0), 253 | (6 / 32, 1.0, 1.0), 254 | (7 / 32, 1.0, 1.0), 255 | (8 / 32, 1.0, 1.0), 256 | (9 / 32, 0.9375, 0.9375), 257 | (10 / 32, 0.8750, 0.8750), 258 | (11 / 32, 0.8125, 0.8125), 259 | (12 / 32, 0.7500, 0.7500), 260 | (13 / 32, 0.6875, 0.6875), 261 | (14 / 32, 0.6250, 0.6250), 262 | (15 / 32, 0.5625, 0.5625), 263 | (16 / 32, 0.5000, 0.5000), 264 | (17 / 32, 0.4375, 0.4375), 265 | (18 / 32, 0.3750, 0.3750), 266 | (19 / 32, 0.3125, 0.3125), 267 | (20 / 32, 0.2500, 0.2500), 268 | (21 / 32, 0.1875, 0.1875), 269 | (22 / 32, 0.1250, 0.1250), 270 | (23 / 32, 0.0625, 0.0625), 271 | (24 / 32, 0.0, 0.0), 272 | (25 / 32, 0.0, 0.0), 273 | (26 / 32, 0.0, 0.0), 274 | (27 / 32, 0.0, 0.0), 275 | (28 / 32, 0.0, 0.0), 276 | (29 / 32, 0.0, 0.0), 277 | (30 / 32, 0.0, 0.0), 278 | (31 / 32, 0.0, 0.0), 279 | (32 / 32, 0.0, 0.0)), 280 | 'blue': ((0.0, 1.0, 1.0), 281 | (1 / 32, 0.9375, 0.9375), 282 | (2 / 32, 0.8750, 0.8750), 283 | (3 / 32, 0.8125, 0.8125), 284 | (4 / 32, 0.7500, 0.7500), 285 | (5 / 32, 0.6875, 0.6875), 286 | (6 / 32, 0.6250, 0.6250), 287 | (7 / 32, 0.5625, 0.5625), 288 | (8 / 32, 0.0, 0.0), 289 | (9 / 32, 0.0, 0.0), 290 | (10 / 32, 0.0, 0.0), 291 | (11 / 32, 0.0, 0.0), 292 | (12 / 32, 0.0, 0.0), 293 | (13 / 32, 0.0, 0.0), 294 | (14 / 32, 0.0, 0.0), 295 | (15 / 32, 0.0, 0.0), 296 | (16 / 32, 0.0, 0.0), 297 | (17 / 32, 0.0, 0.0), 298 | (18 / 32, 0.0, 0.0), 299 | (19 / 32, 0.0, 0.0), 300 | (20 / 32, 0.0, 0.0), 301 | (21 / 32, 0.0, 0.0), 302 | (22 / 32, 0.0, 0.0), 303 | (23 / 32, 0.0, 0.0), 304 | (24 / 32, 0.0, 0.0), 305 | (25 / 32, 0.0, 0.0), 306 | (26 / 32, 0.0, 0.0), 307 | (27 / 32, 0.0, 0.0), 308 | (28 / 32, 0.0, 0.0), 309 | (29 / 32, 0.0, 0.0), 310 | (30 / 32, 0.0, 0.0), 311 | (31 / 32, 0.0, 0.0), 312 | (32 / 32, 0.0, 0.0)) 313 | } 314 | 315 | self.cmap_custom = matplotlib.colors.LinearSegmentedColormap('testCmap', segmentdata=cmap_custom, N=256) 316 | self.cmap_custom2 = matplotlib.colors.LinearSegmentedColormap('testCmap2', segmentdata=cmap_custom2, N=256) 317 | 318 | def log_loss(self, train_loss, vali_loss, step): 319 | self.add_scalar('train_loss', train_loss, step) 320 | self.add_scalar('vali_loss', vali_loss, step) 321 | 322 | def log_sub_loss(self, train_main_loss, train_sub_loss, vali_main_loss, vali_sub_loss, step): 323 | self.add_scalar('train_main_loss', train_main_loss, step) 324 | self.add_scalar('train_sub_loss', train_sub_loss, step) 325 | 
self.add_scalar('vali_main_loss', vali_main_loss, step) 326 | self.add_scalar('vali_sub_loss', vali_sub_loss, step) 327 | 328 | def log_score(self, vali_pesq, vali_stoi, step): 329 | self.add_scalar('vali_pesq', vali_pesq, step) 330 | self.add_scalar('vali_stoi', vali_stoi, step) 331 | 332 | def log_wav(self, mixed_wav, clean_wav, est_wav, step): 333 | #