├── .gitignore
├── README.md
├── config.py
├── dataloader
│   └── THCHS30.py
├── debug.py
├── main.py
├── models
│   ├── DCCRN.py
│   ├── complexnn.py
│   ├── conv_stft.py
│   └── loss.py
└── utils
    ├── show.py
    └── synthesizer.py

/.gitignore:
--------------------------------------------------------------------------------
checkpoint
samples
models/pretrained-models
.DS_Store

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DCCRN

Deep Complex Convolution Recurrent Network for Phase-Aware Speech Enhancement

__Authors__: Yanxin Hu, Yun Liu, Shubo Lv, Mengtao Xing, Shimin Zhang, Yihui Fu, Jian Wu, Bihong Zhang, Lei Xie

Paper: https://arxiv.org/abs/2008.00264

Official samples: https://huyanxin.github.io/DeepComplexCRN/
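
## Usage

A minimal sketch of the intended workflow (the dataset, checkpoint and sample paths are hard-coded in `config.py` and have to be adapted to your machine first):

```python
from main import train, denoise

# 1. Synthesize noisy/clean pairs from THCHS-30 (see utils/synthesizer.py).
# 2. Train one of the DCCRN variants: 'E', 'R', 'C' or 'CL'.
train('E')

# 3. Denoise a single 16 kHz wav with a saved checkpoint (runs on CPU).
denoise('E', 'noisy.wav', 'samples/', pth='checkpoint/DCCRN_E.pth')
```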
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import os
import warnings

import torch

os.environ["CUDA_VISIBLE_DEVICES"] = "1"


class DefaultConfig(object):
    project_root = '/data1/zengziyun/Project/DeepComplexCRN'
    data_root = os.path.join('/data1/zengziyun/Project/Dataset')
    checkpoint_root = os.path.join(project_root, 'checkpoint')
    sample_root = os.path.join(project_root, 'samples')
    pretrained_models_root = os.path.join(project_root, 'models/pretrained-models')

    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')
    num_workers = 4

    # training params
    batch_size = 16
    max_epoch = 40
    lr = 1e-3
    lr_decay = 0.1
    weight_decay = 1e-5

    verbose_inter = 20
    save_inter = 5

    def _parse(self, kwargs):
        """
        Update config params according to kwargs.
        """
        for k, v in kwargs.items():
            if not hasattr(self, k):
                warnings.warn("Warning: opt does not have attribute %s" % k)
            setattr(self, k, v)

        self.device = torch.device('cuda') if self.use_gpu else torch.device('cpu')

        print('<===================current config===================>')
        for k, v in self.__class__.__dict__.items():
            if not k.startswith('_'):
                print(k, '=', getattr(self, k))
        print('<===================current config===================>')


opt = DefaultConfig()
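
# A minimal sketch of how _parse is meant to be used: override any default
# from a kwargs dict (e.g. parsed from the command line), then import the
# updated singleton elsewhere via `from config import opt`. _parse itself
# prints the resulting configuration.
if __name__ == '__main__':
    opt._parse({'batch_size': 8, 'lr': 5e-4})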
--------------------------------------------------------------------------------
/dataloader/THCHS30.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import librosa

from torch.utils.data import Dataset

from config import opt

DATAPATH = os.path.join(opt.data_root, 'THCHS-30')


class THCHS30(Dataset):
    def __init__(self, phase='train', sr=16000, dimension=72000):
        assert phase in ['train', 'test'], 'non-supported phase!'

        self.data_dir = None
        self.label_dir = None

        if phase == 'train':
            self.data_dir = os.path.join(DATAPATH, 'data_synthesized/train/noisy')
            self.label_dir = os.path.join(DATAPATH, 'data_synthesized/train/clean')
        elif phase == 'test':
            self.data_dir = os.path.join(DATAPATH, 'data_synthesized/test/noisy')
            self.label_dir = os.path.join(DATAPATH, 'data_synthesized/test/clean')

        self.sr = sr
        self.dim = dimension

        # mapper is used in __getitem__ to make sure every noisy clip
        # finds its corresponding clean label
        self.mapper = {}

        # collect label paths
        self.label_path = []
        for file in os.listdir(self.label_dir):
            if file.endswith('.wav'):
                self.mapper[file[:-4]] = len(self.label_path)
                self.label_path.append(os.path.join(self.label_dir, file))

        # collect data paths
        self.data_path = []
        for file in os.listdir(self.data_dir):
            if file.endswith('.wav'):
                self.data_path.append(os.path.join(self.data_dir, file))

        assert len(self.data_path) == len(self.label_path), 'data or label is corrupted!'

    def __getitem__(self, item):
        data, _ = librosa.load(self.data_path[item], sr=self.sr)
        data_name = os.path.basename(self.data_path[item])
        data_name = data_name[:data_name.rfind('_')]
        label, _ = librosa.load(self.label_path[self.mapper[data_name]], sr=self.sr)
        # randomly crop a fixed-length segment; pad short clips with zeros
        if len(data) > self.dim:
            max_audio_start = len(data) - self.dim
            audio_start = np.random.randint(0, max_audio_start)
            data = data[audio_start: audio_start + self.dim]
            label = label[audio_start: audio_start + self.dim]
        else:
            data = np.pad(data, (0, self.dim - len(data)), "constant")
            label = np.pad(label, (0, self.dim - len(label)), "constant")

        return data, label

    def __len__(self):
        return len(self.data_path)


if __name__ == '__main__':
    ds = THCHS30(phase='train')
    min_dim = 1e8
    max_dim = 0
    for i in range(0, len(ds)):
        data, label = ds[i]
        min_dim = min(min_dim, len(data))
        max_dim = max(max_dim, len(data))
    print('min dim=', min_dim)
    print('max dim=', max_dim)
    print('mid dim=', int((min_dim + max_dim) / 2))
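
# A short sketch of how this dataset plugs into torch's DataLoader (train()
# in main.py does the same with opt.batch_size; requires the synthesized
# THCHS-30 data to be present under opt.data_root):
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    loader = DataLoader(THCHS30(phase='train'), batch_size=4, shuffle=True)
    noisy, clean = next(iter(loader))
    print(noisy.shape, clean.shape)  # torch.Size([4, 72000]) for both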
--------------------------------------------------------------------------------
/debug.py:
--------------------------------------------------------------------------------
import torch

from models.DCCRN import dccrn
from models.loss import SISNRLoss

if __name__ == '__main__':
    torch.manual_seed(10)
    torch.autograd.set_detect_anomaly(True)
    inputs = torch.randn([10, 16000 * 4]).clamp_(-1, 1)
    labels = torch.randn([10, 16000 * 4]).clamp_(-1, 1)

    print(inputs.shape)

    # DCCRN-E
    # model = dccrn('E')
    # DCCRN-R
    # model = dccrn('R')
    # DCCRN-C
    # model = dccrn('C')
    # DCCRN-CL
    model = dccrn('CL')

    outputs = model(inputs)[1]
    # the DCCRN class in this repo does not define a .loss() method,
    # so use the SI-SNR criterion from models/loss.py directly
    loss = SISNRLoss()(outputs, labels)
    print(loss)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import librosa
import soundfile as sf
import time

import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR
from torchnet.meter import AverageValueMeter

from models.DCCRN import dccrn
from models.loss import SISNRLoss

from dataloader.THCHS30 import THCHS30
from config import opt


def train(mode='CL'):
    model = dccrn(mode)
    model.to(opt.device)

    train_data = THCHS30(phase='train')
    train_loader = DataLoader(train_data,
                              batch_size=opt.batch_size,
                              num_workers=opt.num_workers,
                              shuffle=True)

    optimizer = Adam(model.parameters(), lr=opt.lr)
    scheduler = MultiStepLR(optimizer,
                            milestones=[int(opt.max_epoch * 0.5),
                                        int(opt.max_epoch * 0.7),
                                        int(opt.max_epoch * 0.9)],
                            gamma=opt.lr_decay)
    criterion = SISNRLoss()

    loss_meter = AverageValueMeter()

    for epoch in range(0, opt.max_epoch):
        loss_meter.reset()
        for i, (data, label) in enumerate(train_loader):
            data = data.to(opt.device)
            label = label.to(opt.device)

            spec, wav = model(data)

            optimizer.zero_grad()
            loss = criterion(wav, label)
            loss.backward()
            optimizer.step()

            loss_meter.add(loss.item())

            if (i + 1) % opt.verbose_inter == 0:
                print('epoch', epoch + 1, 'batch', i + 1,
                      'SI-SNR', -loss_meter.value()[0])
        if (epoch + 1) % opt.save_inter == 0:
            print('save model at epoch {0} ...'.format(epoch + 1))
            save_path = os.path.join(opt.checkpoint_root,
                                     'DCCRN_{0}_{1}.pth'.format(mode, epoch + 1))
            torch.save(model.state_dict(), save_path)

        scheduler.step()

    save_path = os.path.join(opt.checkpoint_root,
                             'DCCRN_{0}.pth'.format(mode))
    torch.save(model.state_dict(), save_path)


# denoising runs on CPU
def denoise(mode, speech_file, save_dir, pth=None):
    assert os.path.exists(speech_file), 'speech file does not exist!'

    assert speech_file.endswith('.wav'), 'non-supported speech format!'

    if not os.path.exists(save_dir):
        print('warning: save directory does not exist, it will be created automatically!')
        os.makedirs(save_dir)

    model = dccrn(mode)
    if pth is not None:
        model.load_state_dict(torch.load(pth), strict=True)
    # switch BatchNorm layers to inference behavior
    model.eval()

    noisy_wav, _ = librosa.load(speech_file, sr=16000)

    noisy_wav = torch.Tensor(noisy_wav).reshape(1, -1)

    start = time.time()

    with torch.no_grad():
        _, denoised_wav = model(noisy_wav)

    end = time.time()

    print('process time {0}s on device {1}'.format(end - start, 'cpu'))

    speech_name = os.path.basename(speech_file)[:-4]

    noisy_path = os.path.join(save_dir, speech_name + '_' + 'noisy' + '.wav')
    denoised_path = os.path.join(save_dir, speech_name + '_' + 'denoised' + '.wav')

    noisy_wav = noisy_wav.data.numpy().flatten()
    denoised_wav = denoised_wav.data.numpy().flatten()

    sf.write(noisy_path, noisy_wav, 16000)
    sf.write(denoised_path, denoised_wav, 16000)


if __name__ == '__main__':
    # train('E')

    test_speech_base = os.path.join(opt.data_root, 'THCHS-30', 'data_synthesized/test/noisy')
    test_speech = os.path.join(test_speech_base, 'D11_752_car.wav')

    save_dir = os.path.join(opt.sample_root, 'THCHS-30')
    pth = os.path.join(opt.checkpoint_root, 'DCCRN_E.pth')

    denoise('E', test_speech, save_dir, pth=pth)
--------------------------------------------------------------------------------
/models/DCCRN.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from utils.show import show_params, show_model
import torch.nn.functional as F
from models.conv_stft import ConvSTFT, ConviSTFT

from models.complexnn import ComplexConv2d, ComplexConvTranspose2d, NaiveComplexLSTM, complex_cat, ComplexBatchNorm


class DCCRN(nn.Module):

    def __init__(
            self,
            rnn_layers=2,
            rnn_units=128,
            win_len=400,
            win_inc=100,
            fft_len=512,
            win_type='hanning',
            masking_mode='E',
            use_clstm=False,
            use_cbn=False,
            kernel_size=5,
            kernel_num=[16, 32, 64, 128, 256, 256]
    ):
        '''
        rnn_layers: the number of lstm layers in the crn
        rnn_units: for the complex lstm, rnn_units = real + imag
        '''

        super(DCCRN, self).__init__()

        # for fft
        self.win_len = win_len
        self.win_inc = win_inc
        self.fft_len = fft_len
        self.win_type = win_type

        input_dim = win_len
        output_dim = win_len

        self.rnn_units = rnn_units
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_layers = rnn_layers
        self.kernel_size = kernel_size
        # self.kernel_num = [2, 8, 16, 32, 128, 128, 128]
        # self.kernel_num = [2, 16, 32, 64, 128, 256, 256]
        self.kernel_num = [2] + kernel_num
        self.masking_mode = masking_mode
        self.use_clstm = use_clstm

        # bidirectional=True
        bidirectional = False
        fac = 2 if bidirectional else 1

        fix = True
        self.fix = fix
        self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix)
        self.istft = ConviSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix)

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for idx in range(len(self.kernel_num) - 1):
            self.encoder.append(
                nn.Sequential(
                    # nn.ConstantPad2d([0, 0, 0, 0], 0),
                    ComplexConv2d(
                        self.kernel_num[idx],
                        self.kernel_num[idx + 1],
                        kernel_size=(self.kernel_size, 2),
                        stride=(2, 1),
                        padding=(2, 1)
                    ),
                    nn.BatchNorm2d(self.kernel_num[idx + 1]) if not use_cbn else ComplexBatchNorm(
                        self.kernel_num[idx + 1]),
                    nn.PReLU()
                )
            )
        hidden_dim = self.fft_len // (2 ** (len(self.kernel_num)))

        if self.use_clstm:
            rnns = []
            for idx in range(rnn_layers):
                rnns.append(
                    NaiveComplexLSTM(
                        input_size=hidden_dim * self.kernel_num[-1] if idx == 0 else self.rnn_units,
                        hidden_size=self.rnn_units,
                        bidirectional=bidirectional,
                        batch_first=False,
                        projection_dim=hidden_dim * self.kernel_num[-1] if idx == rnn_layers - 1 else None,
                    )
                )
            self.enhance = nn.Sequential(*rnns)
        else:
            self.enhance = nn.LSTM(
                input_size=hidden_dim * self.kernel_num[-1],
                hidden_size=self.rnn_units,
                num_layers=2,
                dropout=0.0,
                bidirectional=bidirectional,
                batch_first=False
            )
            self.tranform = nn.Linear(self.rnn_units * fac, hidden_dim * self.kernel_num[-1])

        for idx in range(len(self.kernel_num) - 1, 0, -1):
            if idx != 1:
                self.decoder.append(
                    nn.Sequential(
                        ComplexConvTranspose2d(
                            self.kernel_num[idx] * 2,
                            self.kernel_num[idx - 1],
                            kernel_size=(self.kernel_size, 2),
                            stride=(2, 1),
                            padding=(2, 0),
                            output_padding=(1, 0)
                        ),
                        nn.BatchNorm2d(self.kernel_num[idx - 1]) if not use_cbn else ComplexBatchNorm(
                            self.kernel_num[idx - 1]),
                        # nn.ELU()
                        nn.PReLU()
                    )
                )
            else:
                self.decoder.append(
                    nn.Sequential(
                        ComplexConvTranspose2d(
                            self.kernel_num[idx] * 2,
                            self.kernel_num[idx - 1],
                            kernel_size=(self.kernel_size, 2),
                            stride=(2, 1),
                            padding=(2, 0),
                            output_padding=(1, 0)
                        ),
                    )
                )

        show_model(self)
        show_params(self)
        self.flatten_parameters()

    def flatten_parameters(self):
        if isinstance(self.enhance, nn.LSTM):
            self.enhance.flatten_parameters()

    def forward(self, inputs, lens=None):
        specs = self.stft(inputs)
        real = specs[:, :self.fft_len // 2 + 1]
        imag = specs[:, self.fft_len // 2 + 1:]
        spec_mags = torch.sqrt(real ** 2 + imag ** 2 + 1e-8)
        spec_phase = torch.atan2(imag, real)
        cspecs = torch.stack([real, imag], 1)
        cspecs = cspecs[:, :, 1:]
        '''
        means = torch.mean(cspecs, [1,2,3], keepdim=True)
        std = torch.std(cspecs, [1,2,3], keepdim=True )
        normed_cspecs = (cspecs-means)/(std+1e-8)
        out = normed_cspecs
        '''

        out = cspecs
        encoder_out = []

        for idx, layer in enumerate(self.encoder):
            out = layer(out)
            # print('encoder', out.size())
            encoder_out.append(out)

        batch_size, channels, dims, lengths = out.size()
        out = out.permute(3, 0, 1, 2)
        if self.use_clstm:
            r_rnn_in = out[:, :, :channels // 2]
            i_rnn_in = out[:, :, channels // 2:]
            r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, channels // 2 * dims])
            i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, channels // 2 * dims])

            r_rnn_in, i_rnn_in = self.enhance([r_rnn_in, i_rnn_in])

            r_rnn_in = torch.reshape(r_rnn_in, [lengths, batch_size, channels // 2, dims])
            i_rnn_in = torch.reshape(i_rnn_in, [lengths, batch_size, channels // 2, dims])
            out = torch.cat([r_rnn_in, i_rnn_in], 2)

        else:
            # to [L, B, C*D]
            out = torch.reshape(out, [lengths, batch_size, channels * dims])
            out, _ = self.enhance(out)
            out = self.tranform(out)
            out = torch.reshape(out, [lengths, batch_size, channels, dims])

        out = out.permute(1, 2, 3, 0)

        for idx in range(len(self.decoder)):
            out = complex_cat([out, encoder_out[-1 - idx]], 1)
            out = self.decoder[idx](out)
            out = out[..., 1:]
            # print('decoder', out.size())
        mask_real = out[:, 0]
        mask_imag = out[:, 1]
        mask_real = F.pad(mask_real, [0, 0, 1, 0])
        mask_imag = F.pad(mask_imag, [0, 0, 1, 0])

        if self.masking_mode == 'E':
            mask_mags = (mask_real ** 2 + mask_imag ** 2) ** 0.5
            real_phase = mask_real / (mask_mags + 1e-8)
            imag_phase = mask_imag / (mask_mags + 1e-8)
            mask_phase = torch.atan2(
                imag_phase,
                real_phase
            )

            # mask_mags = torch.clamp_(mask_mags,0,100)
            mask_mags = torch.tanh(mask_mags)
            est_mags = mask_mags * spec_mags
            est_phase = spec_phase + mask_phase
            real = est_mags * torch.cos(est_phase)
            imag = est_mags * torch.sin(est_phase)
        elif self.masking_mode == 'C':
            real, imag = real * mask_real - imag * mask_imag, real * mask_imag + imag * mask_real
        elif self.masking_mode == 'R':
            real, imag = real * mask_real, imag * mask_imag

        out_spec = torch.cat([real, imag], 1)
        out_wav = self.istft(out_spec)

        out_wav = torch.squeeze(out_wav, 1)
        # out_wav = torch.tanh(out_wav)
        # the trailing _ makes clamp an in-place operation
        out_wav = torch.clamp_(out_wav, -1, 1)
        return out_spec, out_wav

    def get_params(self, weight_decay=0.0):
        # add an L2 penalty on weights only; biases are excluded
        weights, biases = [], []
        for name, param in self.named_parameters():
            if 'bias' in name:
                biases += [param]
            else:
                weights += [param]
        params = [{
            'params': weights,
            'weight_decay': weight_decay,
        }, {
            'params': biases,
            'weight_decay': 0.0,
        }]
        return params


def dccrn(mode='CL'):
    if mode == 'E':
        model = DCCRN(rnn_units=256, masking_mode='E')
    elif mode == 'R':
        model = DCCRN(rnn_units=256, masking_mode='R')
    elif mode == 'C':
        model = DCCRN(rnn_units=256, masking_mode='C')
    elif mode == 'CL':
        model = DCCRN(rnn_units=256, masking_mode='E',
                      use_clstm=True, kernel_num=[32, 64, 128, 256, 256, 256])
    else:
        raise Exception('non-supported mode!')
    return model
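
# A quick smoke test of the factory and the masking modes (illustration only;
# mirrors debug.py and uses the default model sizes from dccrn()):
if __name__ == '__main__':
    x = torch.randn(2, 16000).clamp_(-1, 1)
    for m in ['E', 'R', 'C', 'CL']:
        net = dccrn(m)
        spec, wav = net(x)
        print(m, tuple(spec.shape), tuple(wav.shape))  # wav matches x's length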
--------------------------------------------------------------------------------
/models/complexnn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class cPReLU(nn.Module):

    def __init__(self, complex_axis=1):
        super(cPReLU, self).__init__()
        self.r_prelu = nn.PReLU()
        self.i_prelu = nn.PReLU()
        self.complex_axis = complex_axis

    def forward(self, inputs):
        real, imag = torch.chunk(inputs, 2, self.complex_axis)
        real = self.r_prelu(real)
        imag = self.i_prelu(imag)
        return torch.cat([real, imag], self.complex_axis)


class NaiveComplexLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, projection_dim=None, bidirectional=False, batch_first=False):
        super(NaiveComplexLSTM, self).__init__()

        self.input_dim = input_size // 2
        self.rnn_units = hidden_size // 2
        self.real_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional,
                                 batch_first=False)
        self.imag_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional,
                                 batch_first=False)
        if bidirectional:
            bidirectional = 2
        else:
            bidirectional = 1
        if projection_dim is not None:
            self.projection_dim = projection_dim // 2
            self.r_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
            self.i_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
        else:
            self.projection_dim = None

    def forward(self, inputs):
        if isinstance(inputs, list):
            real, imag = inputs
        elif isinstance(inputs, torch.Tensor):
            real, imag = torch.chunk(inputs, 2, -1)
        r2r_out = self.real_lstm(real)[0]
        r2i_out = self.imag_lstm(real)[0]
        i2r_out = self.real_lstm(imag)[0]
        i2i_out = self.imag_lstm(imag)[0]
        real_out = r2r_out - i2i_out
        imag_out = i2r_out + r2i_out
        if self.projection_dim is not None:
            real_out = self.r_trans(real_out)
            imag_out = self.i_trans(imag_out)
        return [real_out, imag_out]

    def flatten_parameters(self):
        self.imag_lstm.flatten_parameters()
        self.real_lstm.flatten_parameters()


def complex_cat(inputs, axis):
    # split each input into its real/imag halves, then concatenate all real
    # parts followed by all imag parts along the same axis
    real, imag = [], []
    for idx, data in enumerate(inputs):
        r, i = torch.chunk(data, 2, axis)
        real.append(r)
        imag.append(i)
    real = torch.cat(real, axis)
    imag = torch.cat(imag, axis)
    outputs = torch.cat([real, imag], axis)
    return outputs


class ComplexConv2d(nn.Module):

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            dilation=1,
            groups=1,
            causal=True,
            complex_axis=1,
    ):
        '''
        in_channels: real + imag
        out_channels: real + imag
        kernel_size: for input [B, C, D, T], kernel size in [D, T]
        padding: for input [B, C, D, T], padding in [D, T]
        causal: if causal, pad only the left side of the time dimension,
                otherwise pad both sides
        '''
        super(ComplexConv2d, self).__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.causal = causal
        self.groups = groups
        self.dilation = dilation
        self.complex_axis = complex_axis
        self.real_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                   padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
        self.imag_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                   padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)

        nn.init.normal_(self.real_conv.weight.data, std=0.05)
        nn.init.normal_(self.imag_conv.weight.data, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)

    def forward(self, inputs):
        if self.padding[1] != 0 and self.causal:
            inputs = F.pad(inputs, [self.padding[1], 0, 0, 0])
        else:
            inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0])

        if self.complex_axis == 0:
            real = self.real_conv(inputs)
            imag = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)

        else:
            if isinstance(inputs, torch.Tensor):
                real, imag = torch.chunk(inputs, 2, self.complex_axis)

            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)

            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)

        real = real2real - imag2imag
        imag = real2imag + imag2real
        out = torch.cat([real, imag], self.complex_axis)

        return out
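

# A small numeric check (illustration only): the layer realizes the complex
# product (Wr + i*Wi)(xr + i*xi) = (Wr*xr - Wi*xi) + i*(Wi*xr + Wr*xi), so
# assembling the four real-valued convolutions by hand must match its output.
if __name__ == '__main__':
    torch.manual_seed(0)
    layer = ComplexConv2d(4, 4, kernel_size=(3, 2), padding=(1, 1))
    x = torch.randn(1, 4, 8, 10)
    xp = F.pad(x, [1, 0, 0, 0])  # the causal time padding applied in forward()
    pr, pi = torch.chunk(xp, 2, 1)
    ref = torch.cat([layer.real_conv(pr) - layer.imag_conv(pi),
                     layer.imag_conv(pr) + layer.real_conv(pi)], 1)
    print(torch.allclose(layer(x), ref))  # True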


class ComplexConvTranspose2d(nn.Module):

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            output_padding=(0, 0),
            causal=False,
            complex_axis=1,
            groups=1
    ):
        '''
        in_channels: real + imag
        out_channels: real + imag
        '''
        super(ComplexConvTranspose2d, self).__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups

        self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                            padding=self.padding, output_padding=output_padding, groups=self.groups)
        self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                            padding=self.padding, output_padding=output_padding, groups=self.groups)
        self.complex_axis = complex_axis

        nn.init.normal_(self.real_conv.weight, std=0.05)
        nn.init.normal_(self.imag_conv.weight, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)

    def forward(self, inputs):

        if isinstance(inputs, torch.Tensor):
            real, imag = torch.chunk(inputs, 2, self.complex_axis)
        elif isinstance(inputs, tuple) or isinstance(inputs, list):
            real = inputs[0]
            imag = inputs[1]
        if self.complex_axis == 0:
            real = self.real_conv(inputs)
            imag = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)

        else:
            if isinstance(inputs, torch.Tensor):
                real, imag = torch.chunk(inputs, 2, self.complex_axis)

            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)

            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)

        real = real2real - imag2imag
        imag = real2imag + imag2real
        out = torch.cat([real, imag], self.complex_axis)

        return out


# Source: https://github.com/ChihebTrabelsi/deep_complex_networks/tree/pytorch
# from https://github.com/IMLHF/SE_DCUNet/blob/f28bf1661121c8901ad38149ea827693f1830715/models/layers/complexnn.py#L55

class ComplexBatchNorm(torch.nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, complex_axis=1):
        super(ComplexBatchNorm, self).__init__()
        self.num_features = num_features // 2
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        self.complex_axis = complex_axis

        if self.affine:
            self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Br = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features))
        else:
            self.register_parameter('Wrr', None)
            self.register_parameter('Wri', None)
            self.register_parameter('Wii', None)
            self.register_parameter('Br', None)
            self.register_parameter('Bi', None)

        if self.track_running_stats:
            self.register_buffer('RMr', torch.zeros(self.num_features))
            self.register_buffer('RMi', torch.zeros(self.num_features))
            self.register_buffer('RVrr', torch.ones(self.num_features))
            self.register_buffer('RVri', torch.zeros(self.num_features))
            self.register_buffer('RVii', torch.ones(self.num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('RMr', None)
            self.register_parameter('RMi', None)
            self.register_parameter('RVrr', None)
            self.register_parameter('RVri', None)
            self.register_parameter('RVii', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.RMr.zero_()
            self.RMi.zero_()
            self.RVrr.fill_(1)
            self.RVri.zero_()
            self.RVii.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            self.Br.data.zero_()
            self.Bi.data.zero_()
            self.Wrr.data.fill_(1)
            self.Wri.data.uniform_(-.9, +.9)  # W will be positive-definite
            self.Wii.data.fill_(1)

    def _check_input_dim(self, xr, xi):
        assert (xr.shape == xi.shape)
        assert (xr.size(1) == self.num_features)

    def forward(self, inputs):
        # self._check_input_dim(xr, xi)

        xr, xi = torch.chunk(inputs, 2, self.complex_axis)
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        #
        # NOTE: The precise meaning of the "training flag" is:
        #       True:  Normalize using batch statistics, update running statistics
        #              if they are being collected.
        #       False: Normalize using running statistics, ignore batch statistics.
        #
        training = self.training or not self.track_running_stats
        redux = [i for i in reversed(range(xr.dim())) if i != 1]
        vdim = [1] * xr.dim()
        vdim[1] = xr.size(1)

        #
        # Mean M Computation and Centering
        #
        # Includes running mean update if training and running.
        #
        if training:
            Mr, Mi = xr, xi
            for d in redux:
                Mr = Mr.mean(d, keepdim=True)
                Mi = Mi.mean(d, keepdim=True)
            if self.track_running_stats:
                self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
                self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
        else:
            Mr = self.RMr.view(vdim)
            Mi = self.RMi.view(vdim)
        xr, xi = xr - Mr, xi - Mi

        #
        # Variance Matrix V Computation
        #
        # Includes epsilon numerical stabilizer/Tikhonov regularizer.
        # Includes running variance update if training and running.
        #
        if training:
            Vrr = xr * xr
            Vri = xr * xi
            Vii = xi * xi
            for d in redux:
                Vrr = Vrr.mean(d, keepdim=True)
                Vri = Vri.mean(d, keepdim=True)
                Vii = Vii.mean(d, keepdim=True)
            if self.track_running_stats:
                self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
                self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
                self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
        else:
            Vrr = self.RVrr.view(vdim)
            Vri = self.RVri.view(vdim)
            Vii = self.RVii.view(vdim)
        Vrr = Vrr + self.eps
        Vii = Vii + self.eps

        #
        # Matrix Inverse Square Root U = V^-0.5
        #
        # sqrt of a 2x2 matrix,
        # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        tau = Vrr + Vii
        delta = Vrr * Vii - Vri * Vri  # det(V), written out instead of the deprecated addcmul form
        s = delta.sqrt()
        t = (tau + 2 * s).sqrt()

        # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
        rst = (s * t).reciprocal()
        Urr = (s + Vii) * rst
        Uii = (s + Vrr) * rst
        Uri = (-Vri) * rst

        #
        # Optionally left-multiply U by affine weights W to produce combined
        # weights Z, left-multiply the inputs by Z, then optionally bias them.
        #
        # y = Zx + B
        # y = WUx + B
        # y = [Wrr Wri][Urr Uri] [xr] + [Br]
        #     [Wir Wii][Uir Uii] [xi]   [Bi]
        #
        if self.affine:
            Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
            Zrr = (Wrr * Urr) + (Wri * Uri)
            Zri = (Wrr * Uri) + (Wri * Uii)
            Zir = (Wri * Urr) + (Wii * Uri)
            Zii = (Wri * Uri) + (Wii * Uii)
        else:
            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii

        yr = (Zrr * xr) + (Zri * xi)
        yi = (Zir * xr) + (Zii * xi)

        if self.affine:
            yr = yr + self.Br.view(vdim)
            yi = yi + self.Bi.view(vdim)

        outputs = torch.cat([yr, yi], self.complex_axis)
        return outputs

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)
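

# A small sanity sketch (illustration only): with affine=False, one training-
# mode forward pass whitens each channel, so the real/imag parts come out with
# unit variance and near-zero cross-covariance even for correlated inputs.
if __name__ == '__main__':
    torch.manual_seed(0)
    bn = ComplexBatchNorm(8, affine=False)
    x = torch.randn(16, 8, 50, 20) * 3.0 + 1.0
    x[:, 4:] = x[:, 4:] + 0.5 * x[:, :4]  # correlate real and imag parts
    yr, yi = torch.chunk(bn(x), 2, 1)
    print(yr.var().item(), yi.var().item(), (yr * yi).mean().item())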


if __name__ == '__main__':
    # the original smoke test compared these layers against a reference
    # implementation (dc_crn7) that is not shipped with this repository,
    # so run a self-contained shape check instead
    torch.manual_seed(20)
    inputs = torch.randn([1, 12, 12, 10])
    nnet1 = ComplexConv2d(12, 12, kernel_size=(3, 2), padding=(2, 1), causal=True)
    nnet2 = ComplexConvTranspose2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
    print(nnet1(inputs).shape)
    print(nnet2(inputs).shape)
--------------------------------------------------------------------------------
/models/conv_stft.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from scipy.signal import get_window


def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    if win_type == 'None' or win_type is None:
        window = np.ones(win_len)
    else:
        window = get_window(win_type, win_len, fftbins=True)  # **0.5

    N = fft_len
    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
    real_kernel = np.real(fourier_basis)
    imag_kernel = np.imag(fourier_basis)
    kernel = np.concatenate([real_kernel, imag_kernel], 1).T

    # for the inverse transform, use the pseudo-inverse of the analysis kernel
    if invers:
        kernel = np.linalg.pinv(kernel).T

    kernel = kernel * window
    kernel = kernel[:, None, :]
    return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32))


class ConvSTFT(nn.Module):

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
        super(ConvSTFT, self).__init__()

        if fft_len is None:
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len

        kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
        # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
        inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride])
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)

        if self.feature_type == 'complex':
            return outputs
        else:
            dim = self.dim // 2 + 1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real ** 2 + imag ** 2)
            phase = torch.atan2(imag, real)
            return mags, phase


class ConviSTFT(nn.Module):

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
        super(ConviSTFT, self).__init__()
        if fft_len is None:
            self.fft_len = int(2 ** np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)
        # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, phase=None):
        """
        inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
        phase: [B, N//2+1, T] (if not None)
        """

        if phase is not None:
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)

        # overlap-add window normalization, from torch-stft: https://github.com/pseeth/torch-stft
        t = self.window.repeat(1, 1, inputs.size(-1)) ** 2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        # outputs = torch.where(coff == 0, outputs, outputs/coff)
        outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)]

        return outputs


def test_fft():
    torch.manual_seed(20)
    win_len = 320
    win_inc = 160
    fft_len = 512
    inputs = torch.randn([1, 1, 16000 * 4])
    fft = ConvSTFT(win_len, win_inc, fft_len, win_type='hanning', feature_type='real')
    import librosa

    outputs1 = fft(inputs)[0]
    outputs1 = outputs1.numpy()[0]
    np_inputs = inputs.numpy().reshape([-1])
    librosa_stft = librosa.stft(np_inputs, win_length=win_len, n_fft=fft_len, hop_length=win_inc, center=False)
    print(np.mean((outputs1 - np.abs(librosa_stft)) ** 2))


def test_ifft1():
    import soundfile as sf
    N = 400
    inc = 100
    fft_len = 512
    torch.manual_seed(N)
    data = np.random.randn(16000 * 8)[None, None, :]
    # data = sf.read('../ori.wav')[0]
    inputs = data.reshape([1, 1, -1])
    fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
    ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
    inputs = torch.from_numpy(inputs.astype(np.float32))
    outputs1 = fft(inputs)
    print(outputs1.shape)
    outputs2 = ifft(outputs1)
    sf.write('conv_stft.wav', outputs2.numpy()[0, 0, :], 16000)
    print('wav MSE', torch.mean(torch.abs(inputs[..., :outputs2.size(2)] - outputs2) ** 2))

def test_ifft2():
    N = 400
    inc = 100
    fft_len = 512
    np.random.seed(20)
    torch.manual_seed(20)
    t = np.random.randn(16000 * 4) * 0.001
    t = np.clip(t, -1, 1)
    # input = torch.randn([1,16000*4])
    input = torch.from_numpy(t[None, None, :].astype(np.float32))

    fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
    ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')

    out1 = fft(input)
    output = ifft(out1)
    print('random MSE', torch.mean(torch.abs(input - output) ** 2))
    import soundfile as sf
    sf.write('zero.wav', output[0, 0].numpy(), 16000)


if __name__ == '__main__':
    # test_fft()
    test_ifft1()
    # test_ifft2()
--------------------------------------------------------------------------------
/models/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def remove_dc(data):
    mean = torch.mean(data, -1, keepdim=True)
    data = data - mean
    return data


def l2_norm(s1, s2):
    # despite the name, this computes the inner product <s1, s2>
    # along the last axis
    # norm = torch.sqrt(torch.sum(s1*s2, 1, keepdim=True))
    # norm = torch.norm(s1*s2, 1, keepdim=True)
    norm = torch.sum(s1 * s2, -1, keepdim=True)
    return norm


def si_snr(s1, s2, eps=1e-8):
    # s1: estimate, s2: reference
    # s1 = remove_dc(s1)
    # s2 = remove_dc(s2)
    s1_s2_norm = l2_norm(s1, s2)
    s2_s2_norm = l2_norm(s2, s2)
    s_target = s1_s2_norm / (s2_s2_norm + eps) * s2
    e_noise = s1 - s_target
    target_norm = l2_norm(s_target, s_target)
    noise_norm = l2_norm(e_noise, e_noise)
    snr = 10 * torch.log10(target_norm / (noise_norm + eps) + eps)
    return torch.mean(snr)


# the larger the SI-SNR, the better the estimate, hence the negation
class SISNRLoss(nn.Module):
    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        return -si_snr(x, y, eps=self.eps)


class MSELoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        b, d, t = x.shape
        # zero the DC and Nyquist bins of the target on a copy, so the
        # caller's tensor is not mutated in place
        y = y.clone()
        y[:, 0, :] = 0
        y[:, d // 2, :] = 0
        return F.mse_loss(x, y, reduction='mean') * d


class MAELoss(nn.Module):
    def __init__(self, stft):
        super().__init__()
        self.stft = stft

    def forward(self, x, y):
        gth_spec, gth_phase = self.stft(y)
        b, d, t = x.shape
        return torch.mean(torch.abs(x - gth_spec)) * d
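

# A quick numeric illustration of the scale invariance (the "SI" in SI-SNR):
# scaling the estimate by any constant leaves the metric unchanged, so the
# loss cannot be gamed by simply shrinking or amplifying the output.
if __name__ == '__main__':
    torch.manual_seed(0)
    ref = torch.randn(4, 16000)
    est = ref + 0.1 * torch.randn(4, 16000)
    print(si_snr(est, ref).item())        # ~20 dB
    print(si_snr(2.5 * est, ref).item())  # same value up to eps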
--------------------------------------------------------------------------------
/utils/show.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python -u
# -*- coding: utf-8 -*-

# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


def show_params(nnet):
    print("=" * 40, "Model Parameters", "=" * 40)
    num_params = 0
    for module_name, m in nnet.named_modules():
        if module_name == '':
            for name, params in m.named_parameters():
                print(name, params.size())
                i = 1
                for j in params.size():
                    i = i * j
                num_params += i
    print('[*] Parameter Size: {}'.format(num_params))
    print("=" * 98)


def show_model(nnet):
    print("=" * 40, "Model Structures", "=" * 40)
    for module_name, m in nnet.named_modules():
        if module_name == '':
            print(m)
    print("=" * 98)
--------------------------------------------------------------------------------
/utils/synthesizer.py:
--------------------------------------------------------------------------------
from scipy.io import wavfile
import numpy as np
import soundfile as sf
import librosa
import random
import os
from config import opt


# split the original noise file into train/test parts for better generalization
def split_noise(noise_file, save_dir, prop=0.5):
    assert os.path.exists(noise_file), 'noise file does not exist!'

    assert noise_file.endswith('.wav'), 'non-supported noise format!'

    if not os.path.exists(save_dir):
        print('warning: save directory does not exist, it will be created automatically.')
        os.makedirs(save_dir)

    sample_rate, sig = wavfile.read(noise_file)

    train_len = sig.shape[0] * prop

    train_noise = sig[:int(train_len)]
    test_noise = sig[int(train_len):]

    # remove .wav
    noise_name = os.path.basename(noise_file)[:-4]

    train_noise_dir = os.path.join(save_dir, 'train')
    test_noise_dir = os.path.join(save_dir, 'test')

    if not os.path.exists(train_noise_dir):
        os.makedirs(train_noise_dir)
    if not os.path.exists(test_noise_dir):
        os.makedirs(test_noise_dir)

    train_noise_path = os.path.join(train_noise_dir, noise_name + '.wav')
    test_noise_path = os.path.join(test_noise_dir, noise_name + '.wav')

    sf.write(train_noise_path, train_noise, sample_rate)
    sf.write(test_noise_path, test_noise, sample_rate)


def synthesize_noisy_speech(speech_file, noise_file, save_dir, snr=0):
    assert os.path.exists(speech_file), 'speech file does not exist!'
    assert os.path.exists(noise_file), 'noise file does not exist!'

    assert speech_file.endswith('.wav'), 'non-supported speech format!'
    assert noise_file.endswith('.wav'), 'non-supported noise format!'

    if not os.path.exists(save_dir):
        print('warning: save directory does not exist, it will be created automatically.')
        os.makedirs(save_dir)

    speech_name = os.path.basename(speech_file)[:-4]
    noise_name = os.path.basename(noise_file)[:-4]

    # clean speech
    a, a_sr = librosa.load(speech_file, sr=16000)
    # noise
    b, b_sr = librosa.load(noise_file, sr=16000)
    # pick a random noise segment of the same length as the clean speech,
    # making sure the slice stays within bounds
    start = random.randint(0, b.shape[0] - a.shape[0])
    n_b = b[int(start):int(start) + a.shape[0]]

    # signal and noise energies
    sum_s = np.sum(a ** 2)
    sum_n = np.sum(n_b ** 2)
    # noise weight that yields the requested SNR; snr is in dB, hence 10^(snr/10)
    x = np.sqrt(sum_s / (sum_n * pow(10, snr / 10)))

    noise = x * n_b
    noisy_speech = a + noise

    noisy_dir = os.path.join(save_dir, '{0}dB'.format(snr), 'noisy')
    clean_dir = os.path.join(save_dir, '{0}dB'.format(snr), 'clean')

    if not os.path.exists(noisy_dir):
        os.makedirs(noisy_dir)
    if not os.path.exists(clean_dir):
        os.makedirs(clean_dir)

    noisy_speech_path = os.path.join(noisy_dir, speech_name + '_' + noise_name + '.wav')
    clean_speech_path = os.path.join(clean_dir, speech_name + '.wav')

    sf.write(noisy_speech_path, noisy_speech, 16000)
    sf.write(clean_speech_path, a, 16000)
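

# A quick numeric check of the mixing math above (illustration only): with
# weight x = sqrt(Ps / (Pn * 10^(snr/10))), the mixture a + x*n has exactly
# the requested SNR in dB.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    a = rng.randn(16000)
    n = rng.randn(16000)
    snr = 5
    x = np.sqrt(np.sum(a ** 2) / (np.sum(n ** 2) * pow(10, snr / 10)))
    print(10 * np.log10(np.sum(a ** 2) / np.sum((x * n) ** 2)))  # ~= 5.0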


# split noise for train and test
def generate_noise_dataset(noise_base, save_dir):
    print('noise base directory: ', noise_base)
    print('output directory: ', save_dir)
    # find all noise files and split them with the chosen proportion
    for sub_dir in os.listdir(noise_base):
        noise_dir = os.path.join(noise_base, sub_dir)
        for file in os.listdir(noise_dir):
            if file.endswith('.wav'):
                noise_file = os.path.join(noise_dir, file)
                split_noise(noise_file, save_dir, prop=0.5)
    print('successfully generated noise dataset!')


def generate_noisy_dataset(speech_base, noise_base, save_dir):
    print('speech base directory: ', speech_base)
    print('output directory: ', save_dir)
    noise_files = []
    for file in os.listdir(noise_base):
        if file.endswith('.wav'):
            noise_files.append(os.path.join(noise_base, file))
    for file in os.listdir(speech_base):
        if file.endswith('.wav'):
            speech_file = os.path.join(speech_base, file)
            noise_file = random.choice(noise_files)
            synthesize_noisy_speech(speech_file, noise_file, save_dir=save_dir, snr=0)
    print('successfully generated noisy dataset!')


if __name__ == '__main__':
    # origin speech data path
    noise_base = os.path.join(opt.data_root, 'THCHS-30', 'test-noise/noise')
    train_speech_base = os.path.join(opt.data_root, 'THCHS-30', 'data_thchs30/train')
    test_speech_base = os.path.join(opt.data_root, 'THCHS-30', 'data_thchs30/test')

    # synthesized speech data path
    noise_dir = os.path.join(opt.data_root, 'THCHS-30', 'data_synthesized/noise')
    train_dir = os.path.join(opt.data_root, 'THCHS-30', 'data_synthesized/train')
    test_dir = os.path.join(opt.data_root, 'THCHS-30', 'data_synthesized/test')

    # split origin noise for train and test
    # generate_noise_dataset(noise_base=noise_base, save_dir=noise_dir)

    # generate train noisy speech
    generate_noisy_dataset(speech_base=train_speech_base,
                           noise_base=os.path.join(noise_dir, 'train'),
                           save_dir=train_dir)

    # generate test noisy speech
    generate_noisy_dataset(speech_base=test_speech_base,
                           noise_base=os.path.join(noise_dir, 'test'),
                           save_dir=test_dir)
--------------------------------------------------------------------------------