├── download_sedata.sh
├── exp
│   └── unet16.json
├── models
│   ├── layers
│   │   ├── complexnn.py
│   │   └── istft.py
│   └── unet.py
├── se_dataset.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/download_sedata.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# DOWNLOAD THE DATASETS
mkdir -p dataset_zip
pushd dataset_zip
# TRAINING DATASET
if [[ $1 == all ]]; then
    echo "[INFO] Downloading train dataset"
    if [ ! -f clean_trainset_28spk_wav.zip ]; then
        wget https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_trainset_28spk_wav.zip
    fi
    if [ ! -f noisy_trainset_28spk_wav.zip ]; then
        wget https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/noisy_trainset_28spk_wav.zip
    fi
else
    echo "[INFO] Skipping train dataset download; pass 'all' to download everything"
fi
# VALIDATION DATASET
echo "[INFO] Downloading valid dataset"
if [ ! -f clean_testset_wav.zip ]; then
    wget https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/clean_testset_wav.zip
fi
if [ ! -f noisy_testset_wav.zip ]; then
    wget https://datashare.is.ed.ac.uk/bitstream/handle/10283/2791/noisy_testset_wav.zip
fi
popd

## INFLATE DATA
mkdir -p dataset_tmp
pushd dataset_tmp
if [[ $1 == all ]]; then
    echo "[INFO] Unzip train dataset"
    unzip -q -j ../dataset_zip/clean_trainset_28spk_wav.zip -d trainset_clean
    unzip -q -j ../dataset_zip/noisy_trainset_28spk_wav.zip -d trainset_noisy
fi
echo "[INFO] Unzip valid dataset"
unzip -q -j ../dataset_zip/clean_testset_wav.zip -d valset_clean
unzip -q -j ../dataset_zip/noisy_testset_wav.zip -d valset_noisy
popd

## RESAMPLE
if [[ $1 == all ]]; then
    declare -a arr=("trainset_clean" "trainset_noisy" "valset_clean" "valset_noisy")
else
    declare -a arr=("valset_clean" "valset_noisy")
fi
echo "[INFO] Resampling datasets: ${arr[*]}"
mkdir -p dataset
pushd dataset_tmp
for d in */; do
    mkdir -p "../dataset/$d"
    pushd "$d"
    for f in *.wav; do
        sox "$f" "../../dataset/$d$f" rate -v -I 16000
    done
    popd
done
popd

# REMOVE TMP DATA
rm -r dataset_tmp

--------------------------------------------------------------------------------
/exp/unet16.json:
--------------------------------------------------------------------------------
{
    "seed" : 2017,
    "save_path" : "/deep/group/awni/speech_models/test",

    "data" : {
        "train_set" : "/deep/group/speech/datasets/LibriSpeech/train-toy.json",
        "dev_set" : "/deep/group/speech/datasets/LibriSpeech/dev-toy.json"
    },

    "optimizer" : {
        "batch_size" : 8,
        "epochs" : 1000,
        "learning_rate" : 1e-3,
        "momentum" : 0.0
    },

    "model" : {
        "leaky_slope" : 0.1,
        "ratio_mask" : "BDT",
        "encoders" : [
            [1, 32, [7, 5], [2, 2], [3, 2]],
            [32, 32, [7, 5], [2, 1], [3, 2]],
            [32, 64, [7, 5], [2, 2], [3, 2]],
            [64, 64, [5, 3], [2, 1], [2, 1]],
            [64, 64, [5, 3], [2, 2], [2, 1]],
            [64, 64, [5, 3], [2, 1], [2, 1]],
            [64, 64, [5, 3], [2, 2], [2, 1]],
            [64, 64, [5, 3], [2, 1], [2, 1]]
        ],
        "decoders" : [
            [64, 64, [5, 3], [2, 1], [2, 1]],
            [128, 64, [5, 3], [2, 2], [2, 1]],
            [128, 64, [5, 3], [2, 1], [2, 1]],
            [128, 64, [5, 3], [2, 2], [2, 1]],
            [128, 64, [5, 3], [2, 1], [2, 1]],
            [128, 32, [7, 5], [2, 2], [3, 2]],
            [64, 32, [7, 5], [2, 1], [3, 2]],
            [64, 1, [7, 5], [2, 2], [3, 2]]
        ],
        "__coder_keys" : [
            "in_channels", "out_channels", "kernel_size", "stride", "padding"
        ]
    }
}
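Not part of the repo, but for orientation: each row of the "encoders"/"decoders" lists above is splatted positionally into nn.Conv2d / nn.ConvTranspose2d (see models/unet.py), and "__coder_keys" documents the positions. A minimal sketch of that mapping using the first encoder row:

coder_keys = ["in_channels", "out_channels", "kernel_size", "stride", "padding"]
first_encoder = [1, 32, [7, 5], [2, 2], [3, 2]]
print(dict(zip(coder_keys, first_encoder)))
# -> {'in_channels': 1, 'out_channels': 32, 'kernel_size': [7, 5],
#     'stride': [2, 2], 'padding': [3, 2]}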
--------------------------------------------------------------------------------
/models/layers/complexnn.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Utility functions for initialization
def complex_rayleigh_init(Wr, Wi, fanin=None, gain=1):
    if not fanin:
        fanin = 1
        for p in Wr.shape[1:]:
            fanin *= p
    scale = float(gain) / float(fanin)
    theta = torch.empty_like(Wr).uniform_(-math.pi / 2, +math.pi / 2)
    rho = np.random.rayleigh(scale, tuple(Wr.shape))
    rho = torch.tensor(rho).to(Wr)
    Wr.data.copy_(rho * theta.cos())
    Wi.data.copy_(rho * theta.sin())

# Layers
class ComplexConvWrapper(nn.Module):
    def __init__(self, conv_module, *args, **kwargs):
        super(ComplexConvWrapper, self).__init__()
        self.conv_re = conv_module(*args, **kwargs)
        self.conv_im = conv_module(*args, **kwargs)

    def reset_parameters(self):
        fanin = self.conv_re.in_channels // self.conv_re.groups
        for s in self.conv_re.kernel_size:
            fanin *= s
        complex_rayleigh_init(self.conv_re.weight, self.conv_im.weight, fanin)
        if self.conv_re.bias is not None:
            self.conv_re.bias.data.zero_()
            self.conv_im.bias.data.zero_()

    def forward(self, xr, xi):
        # Complex multiplication: (Wr + i*Wi)(xr + i*xi)
        real = self.conv_re(xr) - self.conv_im(xi)
        imag = self.conv_re(xi) + self.conv_im(xr)
        return real, imag

# Real-valued network module for complex input
class RealConvWrapper(nn.Module):
    def __init__(self, conv_module, *args, **kwargs):
        super(RealConvWrapper, self).__init__()
        self.conv_re = conv_module(*args, **kwargs)

    def forward(self, xr, xi):
        real = self.conv_re(xr)
        imag = self.conv_re(xi)
        return real, imag

class CLeakyReLU(nn.LeakyReLU):
    def forward(self, xr, xi):
        return F.leaky_relu(xr, self.negative_slope, self.inplace), \
               F.leaky_relu(xi, self.negative_slope, self.inplace)
# Source: https://github.com/ChihebTrabelsi/deep_complex_networks/tree/pytorch
class ComplexBatchNorm(torch.nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(ComplexBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.Wrr = torch.nn.Parameter(torch.Tensor(num_features))
            self.Wri = torch.nn.Parameter(torch.Tensor(num_features))
            self.Wii = torch.nn.Parameter(torch.Tensor(num_features))
            self.Br = torch.nn.Parameter(torch.Tensor(num_features))
            self.Bi = torch.nn.Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('Wrr', None)
            self.register_parameter('Wri', None)
            self.register_parameter('Wii', None)
            self.register_parameter('Br', None)
            self.register_parameter('Bi', None)
        if self.track_running_stats:
            self.register_buffer('RMr', torch.zeros(num_features))
            self.register_buffer('RMi', torch.zeros(num_features))
            self.register_buffer('RVrr', torch.ones(num_features))
            self.register_buffer('RVri', torch.zeros(num_features))
            self.register_buffer('RVii', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('RMr', None)
            self.register_parameter('RMi', None)
            self.register_parameter('RVrr', None)
            self.register_parameter('RVri', None)
            self.register_parameter('RVii', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.RMr.zero_()
            self.RMi.zero_()
            self.RVrr.fill_(1)
            self.RVri.zero_()
            self.RVii.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            self.Br.data.zero_()
            self.Bi.data.zero_()
            self.Wrr.data.fill_(1)
            self.Wri.data.uniform_(-.9, +.9)  # W will be positive-definite
            self.Wii.data.fill_(1)

    def _check_input_dim(self, xr, xi):
        assert(xr.shape == xi.shape)
        assert(xr.size(1) == self.num_features)

    def forward(self, xr, xi):
        self._check_input_dim(xr, xi)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        #
        # NOTE: The precise meaning of the "training flag" is:
        #       True:  Normalize using batch statistics, update running statistics
        #              if they are being collected.
        #       False: Normalize using running statistics, ignore batch statistics.
        #
        training = self.training or not self.track_running_stats
        redux = [i for i in reversed(range(xr.dim())) if i != 1]
        vdim = [1] * xr.dim()
        vdim[1] = xr.size(1)

        #
        # Mean M Computation and Centering
        #
        # Includes running mean update if training and running.
        #
        if training:
            Mr, Mi = xr, xi
            for d in redux:
                Mr = Mr.mean(d, keepdim=True)
                Mi = Mi.mean(d, keepdim=True)
            if self.track_running_stats:
                self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
                self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
        else:
            Mr = self.RMr.view(vdim)
            Mi = self.RMi.view(vdim)
        xr, xi = xr - Mr, xi - Mi

        #
        # Variance Matrix V Computation
        #
        # Includes epsilon numerical stabilizer/Tikhonov regularizer.
        # Includes running variance update if training and running.
        #
        if training:
            Vrr = xr * xr
            Vri = xr * xi
            Vii = xi * xi
            for d in redux:
                Vrr = Vrr.mean(d, keepdim=True)
                Vri = Vri.mean(d, keepdim=True)
                Vii = Vii.mean(d, keepdim=True)
            if self.track_running_stats:
                self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
                self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
                self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
        else:
            Vrr = self.RVrr.view(vdim)
            Vri = self.RVri.view(vdim)
            Vii = self.RVii.view(vdim)
        Vrr = Vrr + self.eps
        Vri = Vri
        Vii = Vii + self.eps

        #
        # Matrix Inverse Square Root U = V^-0.5
        #
        # sqrt of a 2x2 matrix,
        # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        tau = Vrr + Vii                 # trace of V
        delta = Vrr * Vii - Vri * Vri   # determinant of V
        s = delta.sqrt()
        t = (tau + 2 * s).sqrt()

        # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
        rst = (s * t).reciprocal()
        Urr = (s + Vii) * rst
        Uii = (s + Vrr) * rst
        Uri = (- Vri) * rst

        #
        # Optionally left-multiply U by affine weights W to produce combined
        # weights Z, left-multiply the inputs by Z, then optionally bias them.
        #
        # y = Zx + B
        # y = WUx + B
        # y = [Wrr Wri][Urr Uri] [xr] + [Br]
        #     [Wir Wii][Uir Uii] [xi]   [Bi]
        #
        if self.affine:
            Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
            Zrr = (Wrr * Urr) + (Wri * Uri)
            Zri = (Wrr * Uri) + (Wri * Uii)
            Zir = (Wri * Urr) + (Wii * Uri)
            Zii = (Wri * Uri) + (Wii * Uii)
        else:
            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii

        yr = (Zrr * xr) + (Zri * xi)
        yi = (Zir * xr) + (Zii * xi)

        if self.affine:
            yr = yr + self.Br.view(vdim)
            yi = yi + self.Bi.view(vdim)

        return yr, yi

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)
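Two quick sanity checks for the layers above — a sketch, not part of the repo. The first confirms that ComplexConvWrapper with a 1x1 kernel reduces to pointwise complex multiplication; the second confirms that the tau/delta/s/t closed form used in ComplexBatchNorm.forward really computes U = V^-0.5 for a 2x2 positive-definite matrix.

import torch
import torch.nn as nn
from models.layers.complexnn import ComplexConvWrapper

torch.manual_seed(0)

# 1) A 1x1 complex convolution is just (xr + i*xi) * (wr + i*wi).
conv = ComplexConvWrapper(nn.Conv2d, 1, 1, kernel_size=1, bias=False)
xr, xi = torch.randn(2, 1, 4, 4), torch.randn(2, 1, 4, 4)
yr, yi = conv(xr, xi)
w = torch.complex(conv.conv_re.weight, conv.conv_im.weight).squeeze()
y = torch.complex(xr, xi) * w
assert torch.allclose(yr, y.real, atol=1e-6) and torch.allclose(yi, y.imag, atol=1e-6)

# 2) Whitening a random 2x2 PSD "covariance" with the closed-form U on both
#    sides yields the identity, i.e. U = V^{-1/2}.
a = torch.randn(2, 2, dtype=torch.float64)
V = a @ a.t() + 1e-3 * torch.eye(2, dtype=torch.float64)
Vrr, Vri, Vii = V[0, 0], V[0, 1], V[1, 1]
tau = Vrr + Vii                 # trace
delta = Vrr * Vii - Vri * Vri   # determinant
s = delta.sqrt()
t = (tau + 2 * s).sqrt()
rst = (s * t).reciprocal()
U = rst * torch.stack([torch.stack([s + Vii, -Vri]),
                       torch.stack([-Vri, s + Vrr])])
assert torch.allclose(U @ V @ U, torch.eye(2, dtype=torch.float64), atol=1e-6)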
--------------------------------------------------------------------------------
/models/layers/istft.py:
--------------------------------------------------------------------------------
import numpy as np
import scipy.signal
import torch
import torch.nn.functional as F

class ISTFT(torch.nn.Module):
    def __init__(self, filter_length=1024, hop_length=512, window='hann', center=True):
        super(ISTFT, self).__init__()

        self.filter_length = filter_length
        self.hop_length = hop_length
        self.center = center

        win_cof = scipy.signal.get_window(window, filter_length)
        self.inv_win = self.inverse_stft_window(win_cof, hop_length)

        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])
        inverse_basis = torch.FloatTensor(self.inv_win * \
                np.linalg.pinv(fourier_basis).T[:, None, :])

        self.register_buffer('inverse_basis', inverse_basis.float())

    # Use equation 8 from Griffin, Lim.
    # Paper: "Signal Estimation from Modified Short-Time Fourier Transform"
    # Reference implementation: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/signal/spectral_ops.py
    # librosa uses equation 6 from the same paper: https://github.com/librosa/librosa/blob/0dcd53f462db124ed3f54edf2334f28738d2ecc6/librosa/core/spectrum.py#L302-L311
    def inverse_stft_window(self, window, hop_length):
        window_length = len(window)
        denom = window ** 2
        overlaps = -(-window_length // hop_length)  # Ceiling division.
        denom = np.pad(denom, (0, overlaps * hop_length - window_length), 'constant')
        denom = np.reshape(denom, (overlaps, hop_length)).sum(0)
        denom = np.tile(denom, (overlaps, 1)).reshape(overlaps * hop_length)
        return window / denom[:window_length]

    def forward(self, real_part, imag_part, length=None):
        if (real_part.dim() == 2):
            real_part = real_part.unsqueeze(0)
            imag_part = imag_part.unsqueeze(0)

        recombined = torch.cat([real_part, imag_part], dim=1)

        inverse_transform = F.conv_transpose1d(recombined,
                                               self.inverse_basis,
                                               stride=self.hop_length,
                                               padding=0)

        padded = int(self.filter_length // 2)
        if length is None:
            if self.center:
                inverse_transform = inverse_transform[:, :, padded:-padded]
        else:
            if self.center:
                inverse_transform = inverse_transform[:, :, padded:]
            inverse_transform = inverse_transform[:, :, :length]

        return inverse_transform
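A quick round-trip check — a sketch, assuming it is run from the repo root: pairing torch.stft with this ISTFT should reconstruct the interior of a random signal to within float precision, because inverse_stft_window rescales the synthesis window so the overlap-add of squared windows sums to one.

import torch
from models.layers.istft import ISTFT

n_fft, hop = 1024, 256
x = torch.randn(1, 16000)
window = torch.hann_window(n_fft)
# return_complex=False keeps the (..., 2) real/imag layout this repo uses
# (deprecated but still accepted in newer PyTorch).
spec = torch.stft(x, n_fft, hop, window=window, return_complex=False)
real, imag = spec[..., 0], spec[..., 1]
x_hat = ISTFT(n_fft, hop, window='hann')(real, imag, length=x.size(1)).squeeze(1)

# Edge frames see fewer overlapping windows, so compare interior samples only.
print((x - x_hat)[:, n_fft:-n_fft].abs().max())  # ~1e-6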
--------------------------------------------------------------------------------
/models/unet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.layers.complexnn as dcnn


# NOTE: Use Complex Ops for DCUnet when implemented
# Reference:
# > Progress: https://github.com/pytorch/pytorch/issues/755
def pad2d_as(x1, x2):
    # Pad x1 to have the same size as x2; inputs are NCHW.
    diffH = x2.size()[2] - x1.size()[2]
    diffW = x2.size()[3] - x1.size()[3]

    return F.pad(x1, (0, diffW, 0, diffH))  # (L, R, T, B)

def padded_cat(x1, x2, dim):
    # NOTE: Use torch.cat with pad instead when merged
    # > https://github.com/pytorch/pytorch/pull/11494
    x1 = pad2d_as(x1, x2)
    x1 = torch.cat([x1, x2], dim=dim)
    return x1

class Encoder(nn.Module):
    def __init__(self, conv_cfg, leaky_slope):
        super(Encoder, self).__init__()
        self.conv = dcnn.ComplexConvWrapper(nn.Conv2d, *conv_cfg, bias=False)
        self.bn = dcnn.ComplexBatchNorm(conv_cfg[1])
        self.act = dcnn.CLeakyReLU(leaky_slope, inplace=True)

    def forward(self, xr, xi):
        xr, xi = self.act(*self.bn(*self.conv(xr, xi)))
        return xr, xi

class Decoder(nn.Module):
    def __init__(self, dconv_cfg, leaky_slope):
        super(Decoder, self).__init__()
        self.dconv = dcnn.ComplexConvWrapper(nn.ConvTranspose2d, *dconv_cfg, bias=False)
        self.bn = dcnn.ComplexBatchNorm(dconv_cfg[1])
        self.act = dcnn.CLeakyReLU(leaky_slope, inplace=True)

    def forward(self, xr, xi, skip=None):
        if skip is not None:
            xr, xi = padded_cat(xr, skip[0], dim=1), padded_cat(xi, skip[1], dim=1)
        xr, xi = self.act(*self.bn(*self.dconv(xr, xi)))
        return xr, xi

class Unet(nn.Module):
    def __init__(self, cfg):
        super(Unet, self).__init__()
        self.encoders = nn.ModuleList()
        for conv_cfg in cfg['encoders']:
            self.encoders.append(Encoder(conv_cfg, cfg['leaky_slope']))

        self.decoders = nn.ModuleList()
        for dconv_cfg in cfg['decoders'][:-1]:
            self.decoders.append(Decoder(dconv_cfg, cfg['leaky_slope']))

        # Last decoder doesn't use BN & LeakyReLU. Use bias.
        self.last_decoder = dcnn.ComplexConvWrapper(nn.ConvTranspose2d,
                                                    *cfg['decoders'][-1], bias=True)

        self.ratio_mask_type = cfg['ratio_mask']

    def get_ratio_mask(self, outr, outi):
        def inner_fn(r, i):
            if self.ratio_mask_type == 'BDSS':
                return torch.sigmoid(outr) * r, torch.sigmoid(outi) * i
            else:
                # Polar coordinate masks
                # ~1.4x slower
                mag_mask = torch.sqrt(outr**2 + outi**2)
                # M_phase = O/|O| for O = g(X)
                # Same phase rotation (theta) for the phase mask O/|O| and O.
                phase_rotate = torch.atan2(outi, outr)

                if self.ratio_mask_type == 'BDT':
                    mag_mask = torch.tanh(mag_mask)
                # else: UBD (unbounded)

                mag = mag_mask * torch.sqrt(r**2 + i**2)
                phase = phase_rotate + torch.atan2(i, r)

                # return real, imag
                return mag * torch.cos(phase), mag * torch.sin(phase)

        return inner_fn

    def forward(self, xr, xi):
        input_real, input_imag = xr, xi
        skips = list()

        for encoder in self.encoders:
            xr, xi = encoder(xr, xi)
            skips.append((xr, xi))

        # The first decoder's input is the last encoder's output itself,
        # so the matching skip connection is dropped.
        skips.pop()
        skip = None
        for decoder in self.decoders:
            xr, xi = decoder(xr, xi, skip)
            skip = skips.pop()

        xr, xi = padded_cat(xr, skip[0], dim=1), padded_cat(xi, skip[1], dim=1)
        xr, xi = self.last_decoder(xr, xi)

        xr, xi = pad2d_as(xr, input_real), pad2d_as(xi, input_imag)
        ratio_mask_fn = self.get_ratio_mask(xr, xi)
        return ratio_mask_fn(input_real, input_imag)
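A small numeric check of the polar 'BDT' mask above — a sketch, not repo code: the real/imaginary arithmetic in get_ratio_mask is equivalent to multiplying the complex input X by the complex mask tanh(|O|) * O/|O|, where O is the network output.

import torch

# toy "network output" O and input X as real/imag pairs
outr, outi = torch.tensor([0.3]), torch.tensor([-1.2])
r, i = torch.tensor([0.8]), torch.tensor([0.5])

# polar 'BDT' mask exactly as in Unet.get_ratio_mask
mag = torch.tanh(torch.sqrt(outr**2 + outi**2)) * torch.sqrt(r**2 + i**2)
phase = torch.atan2(outi, outr) + torch.atan2(i, r)
yr, yi = mag * torch.cos(phase), mag * torch.sin(phase)

# same computation in complex arithmetic: Y = tanh(|O|) * (O/|O|) * X
O, X = torch.complex(outr, outi), torch.complex(r, i)
Y = torch.tanh(O.abs()) * (O / O.abs()) * X
assert torch.allclose(yr, Y.real, atol=1e-6) and torch.allclose(yi, Y.imag, atol=1e-6)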
--------------------------------------------------------------------------------
/se_dataset.py:
--------------------------------------------------------------------------------
import numpy as np
from tqdm import tqdm
import librosa
import os
import torch
from torch.utils import data

# Reference
# DATA LOADING - LOAD FILE LISTS
def load_data_list(folder='./dataset', setname='train'):
    assert(setname in ['train', 'val'])

    dataset = {}
    foldername = folder + '/' + setname + 'set'

    print("Loading files...")
    dataset['innames'] = []
    dataset['outnames'] = []
    dataset['shortnames'] = []

    filelist = os.listdir("%s_noisy"%(foldername))
    filelist = [f for f in filelist if f.endswith(".wav")]
    for i in tqdm(filelist):
        dataset['innames'].append("%s_noisy/%s"%(foldername,i))
        dataset['outnames'].append("%s_clean/%s"%(foldername,i))
        dataset['shortnames'].append("%s"%(i))

    return dataset

# DATA LOADING - LOAD FILE DATA
def load_data(dataset):
    dataset['inaudio'] = [None]*len(dataset['innames'])
    dataset['outaudio'] = [None]*len(dataset['outnames'])

    for idx in tqdm(range(len(dataset['innames']))):
        if dataset['inaudio'][idx] is None:
            inputData, sr = librosa.load(dataset['innames'][idx], sr=None)
            outputData, sr = librosa.load(dataset['outnames'][idx], sr=None)

            dataset['inaudio'][idx] = np.float32(inputData)
            dataset['outaudio'][idx] = np.float32(outputData)

    return dataset

class AudioDataset(data.Dataset):
    """
    Audio sample reader.
    """

    def __init__(self, data_type):
        dataset = load_data_list(setname=data_type)
        self.dataset = load_data(dataset)

        self.file_names = dataset['innames']

    def __getitem__(self, idx):
        mixed = torch.from_numpy(self.dataset['inaudio'][idx]).type(torch.FloatTensor)
        clean = torch.from_numpy(self.dataset['outaudio'][idx]).type(torch.FloatTensor)

        return mixed, clean

    def __len__(self):
        return len(self.file_names)

    def zero_pad_concat(self, inputs):
        max_t = max(inp.shape[0] for inp in inputs)
        shape = (len(inputs), max_t)
        input_mat = np.zeros(shape, dtype=np.float32)
        for e, inp in enumerate(inputs):
            input_mat[e, :inp.shape[0]] = inp
        return input_mat

    def collate(self, inputs):
        mixeds, cleans = zip(*inputs)
        seq_lens = torch.IntTensor([i.shape[0] for i in mixeds])

        x = torch.FloatTensor(self.zero_pad_concat(mixeds))
        y = torch.FloatTensor(self.zero_pad_concat(cleans))

        batch = [x, y, seq_lens]
        return batch
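A hedged usage sketch of the batching helpers above on toy "waveforms" (the __new__ trick just skips the expensive file loading in __init__; collate and zero_pad_concat don't touch instance state): collate zero-pads every clip in the batch to the longest one and keeps the true lengths.

import torch
from se_dataset import AudioDataset

ds = AudioDataset.__new__(AudioDataset)  # bypass __init__, use helpers only
a = (torch.ones(3), torch.ones(3))       # (mixed, clean) pair, length 3
b = (torch.ones(5), torch.ones(5))       # (mixed, clean) pair, length 5
x, y, seq_lens = ds.collate([a, b])
print(x.shape, seq_lens)  # torch.Size([2, 5]) tensor([3, 5], dtype=torch.int32)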
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR

import soundfile as sf
from tqdm import tqdm

import utils
from models.unet import Unet
from models.layers.istft import ISTFT
from se_dataset import AudioDataset
from torch.utils.data import DataLoader


parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', default='experiments/base_model', help="Directory containing params.json")
parser.add_argument('--restore_file', default=None, help="Optional, name of the file in --model_dir containing weights to reload before training")  # 'best' or 'train'
parser.add_argument('--batch_size', default=32, type=int, help='train batch size')
parser.add_argument('--num_epochs', default=100, type=int, help='train epochs number')
args = parser.parse_args()

n_fft, hop_length = 400, 160
window = torch.hann_window(n_fft).cuda()
# return_complex=False keeps the (..., 2) real/imag layout used below
# (newer PyTorch requires the flag to be passed explicitly).
stft = lambda x: torch.stft(x, n_fft, hop_length, window=window, return_complex=False)
istft = ISTFT(n_fft, hop_length, window='hann').cuda()

def wSDRLoss(mixed, clean, clean_est, eps=2e-7):
    # Works at the signal level (time domain), so a backprop-able ISTFT must be used.
    # Expects batched audio of shape (N x T).
    bsum = lambda x: torch.sum(x, dim=1)  # Batch-preserving sum for convenience.
    def mSDRLoss(orig, est):
        # Modified SDR loss: -<x, x_est> / (||x|| * ||x_est||)  (L2 norms).
        # Original SDR loss: -<x, x_est>**2 / <x_est, x_est>  (== ||x_est||**2).
        # > Maximize correlation while producing minimum-energy output.
        correlation = bsum(orig * est)
        energies = torch.norm(orig, p=2, dim=1) * torch.norm(est, p=2, dim=1)
        return -(correlation / (energies + eps))

    noise = mixed - clean
    noise_est = mixed - clean_est

    a = bsum(clean**2) / (bsum(clean**2) + bsum(noise**2) + eps)
    wSDR = a * mSDRLoss(clean, clean_est) + (1 - a) * mSDRLoss(noise, noise_est)
    return torch.mean(wSDR)
# TODO - loader clean speech tempo perturbed as input
# TODO - loader clean speech volume perturbed as input
# TODO - option for (tempo/volume/tempo+volume)
# TODO - loader noise sound as second input
# TODO - loader reverb effect as second input
# TODO - option for (noise/reverb/noise+reverb)

def main():
    json_path = os.path.join(args.model_dir, 'params.json')
    params = utils.Params(json_path)

    net = Unet(params.model).cuda()
    # TODO - check exists
    #checkpoint = torch.load('./final.pth.tar')
    #net.load_state_dict(checkpoint)

    train_dataset = AudioDataset(data_type='train')
    test_dataset = AudioDataset(data_type='val')
    train_data_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size,
                                   collate_fn=train_dataset.collate, shuffle=True, num_workers=4)
    test_data_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size,
                                  collate_fn=test_dataset.collate, shuffle=False, num_workers=4)

    torch.set_printoptions(precision=10, profile="full")

    # Optimizer
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    # Learning rate scheduler
    scheduler = ExponentialLR(optimizer, 0.95)

    for epoch in range(args.num_epochs):
        train_bar = tqdm(train_data_loader)
        for batch in train_bar:
            train_mixed, train_clean, seq_len = map(lambda x: x.cuda(), batch)
            mixed = stft(train_mixed).unsqueeze(dim=1)
            real, imag = mixed[..., 0], mixed[..., 1]
            out_real, out_imag = net(real, imag)
            out_real, out_imag = torch.squeeze(out_real, 1), torch.squeeze(out_imag, 1)
            out_audio = istft(out_real, out_imag, train_mixed.size(1))
            out_audio = torch.squeeze(out_audio, dim=1)
            # Zero out the zero-padded tail of each clip in the batch.
            for i, l in enumerate(seq_len):
                out_audio[i, l:] = 0
            # Debug dumps of the first clip in the batch, overwritten every step.
            # (librosa.output.write_wav was removed in librosa 0.8; soundfile.write
            # is the drop-in replacement.)
            n = seq_len[0].cpu().data.numpy()
            sf.write('mixed.wav', train_mixed[0].cpu().data.numpy()[:n], 16000)
            sf.write('clean.wav', train_clean[0].cpu().data.numpy()[:n], 16000)
            sf.write('out.wav', out_audio[0].cpu().data.numpy()[:n], 16000)
            loss = wSDRLoss(train_mixed, train_clean, out_audio)
            print(epoch, loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
        torch.save(net.state_dict(), './final.pth.tar')

if __name__ == '__main__':
    main()
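A quick sanity check for wSDRLoss — a sketch; note that importing from train.py runs its module-level setup, which parses argv and expects a CUDA device. A perfect estimate drives both SDR terms to their minimum of -1, so the weighted loss is exactly -1, while leaving the mixture untouched scores strictly worse (the noise term contributes 0 instead of -1).

import torch
from train import wSDRLoss  # NOTE: train.py builds CUDA tensors at import time

torch.manual_seed(0)
clean = torch.randn(4, 16000)
noise = 0.5 * torch.randn(4, 16000)
mixed = clean + noise

print(float(wSDRLoss(mixed, clean, clean)))  # ~ -1.0 (perfect estimate)
print(float(wSDRLoss(mixed, clean, mixed)))  # > -1.0 (no enhancement at all)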
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import json
import logging
import os
import shutil

import torch


class Params():
    """Class that loads hyperparameters from a json file.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        with open(json_path) as f:
            params = json.load(f)
            self.__dict__.update(params)

    def save(self, json_path):
        with open(json_path, 'w') as f:
            json.dump(self.__dict__, f, indent=4)

    def update(self, json_path):
        """Loads parameters from json file"""
        with open(json_path) as f:
            params = json.load(f)
            self.__dict__.update(params)

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
        return self.__dict__


class RunningAverage():
    """A simple class that maintains the running average of a quantity

    Example:
    ```
    loss_avg = RunningAverage()
    loss_avg.update(2)
    loss_avg.update(4)
    loss_avg()  # returns 3.0
    ```
    """

    def __init__(self):
        self.steps = 0
        self.total = 0

    def update(self, val):
        self.total += val
        self.steps += 1

    def __call__(self):
        return self.total / float(self.steps)


def set_logger(log_path):
    """Set the logger to log info in terminal and file `log_path`.

    In general, it is useful to have a logger so that every output to the terminal is saved
    in a permanent file. Here we save it to `model_dir/train.log`.

    Example:
    ```
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) where to log
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        # Logging to a file
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
        logger.addHandler(file_handler)

        # Logging to console
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(stream_handler)


def save_dict_to_json(d, json_path):
    """Saves dict of floats in json file

    Args:
        d: (dict) of float-castable values (np.float, int, float, etc.)
        json_path: (string) path to json file
    """
    with open(json_path, 'w') as f:
        # We need to convert the values to float for json (it doesn't accept np.array, np.float, )
        d = {k: float(v) for k, v in d.items()}
        json.dump(d, f, indent=4)
") 122 | torch.save(state, filepath) 123 | if is_best: 124 | shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar')) 125 | 126 | 127 | def load_checkpoint(checkpoint, model, optimizer=None): 128 | """Loads model parameters (state_dict) from file_path. If optimizer is provided, loads state_dict of 129 | optimizer assuming it is present in checkpoint. 130 | 131 | Args: 132 | checkpoint: (string) filename which needs to be loaded 133 | model: (torch.nn.Module) model for which the parameters are loaded 134 | optimizer: (torch.optim) optional: resume optimizer from checkpoint 135 | """ 136 | if not os.path.exists(checkpoint): 137 | raise ("File doesn't exist {}".format(checkpoint)) 138 | checkpoint = torch.load(checkpoint) 139 | model.load_state_dict(checkpoint['state_dict']) 140 | 141 | if optimizer: 142 | optimizer.load_state_dict(checkpoint['optim_dict']) 143 | 144 | return checkpoint --------------------------------------------------------------------------------