├── README.md
├── ckpt_dir_stoi
│   └── 5_448_stacked_dropout0_2.pt
├── clean.wav
├── inference.py
├── load_dataset.py
├── load_single_data.py
├── mixed.wav
├── models
│   ├── attention.py
│   └── layers
│       └── istft.py
├── run_inference.sh
├── test.py
├── test_single.py
├── train.py
├── train_utils.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # Attention-SE.pytorch
2 | AN ATTENTION-BASED NEURAL NETWORK APPROACH FOR SINGLE CHANNEL SPEECH ENHANCEMENT
3 | 
4 | ## Requirements
5 | 
6 | * PYSTOI
7 |   * Toolbox : [mpariente/pystoi](https://github.com/mpariente/pystoi)
8 | * PYPESQ
9 |   * Toolbox : [vBaiCai/python-pesq](https://github.com/vBaiCai/python-pesq)
10 | * Other
11 | ```python3
12 | torch == 1.2.0
13 | numpy
14 | librosa
15 | ```
16 | 
17 | ## Dataset
18 | Noise recordings are taken from the MUSAN corpus (DEMAND is used for Test3 and Test4).
19 | We generate the noisy mixtures by adding noise to clean speech at the SNRs listed below.
20 | 
21 | |Set|Train|Valid|Test|Test2|Test3|Test4|
22 | |---|-----|-----|----|-----|-----|-----|
23 | |Noise|MUSAN|MUSAN|MUSAN|MUSAN|DEMAND|DEMAND|
24 | |SNR|0 to 20 dB|0 to 20 dB|0 to 20 dB|-5 to 0 dB|0 to 20 dB|-5 to 0 dB|
25 | 
26 | #### MUSAN dataset
27 | Speech, music, and noise recordings
28 | - URL : http://www.openslr.org/17/
29 | - Citation
30 | ```BibTeX
31 | @misc{musan2015,
32 |   author = {David Snyder and Guoguo Chen and Daniel Povey},
33 |   title = {{MUSAN}: {A} {M}usic, {S}peech, and {N}oise {C}orpus},
34 |   year = {2015},
35 |   eprint = {1510.08484},
36 |   note = {arXiv:1510.08484v1}
37 | }
38 | ```
39 | 
40 | #### Data Loading Script
41 | Reference URL :
42 | - https://github.com/sweetcocoa/DeepComplexUNetPyTorch
43 | - https://github.com/jtkim-kaist/Speech-enhancement
44 | - https://github.com/chanil1218/DCUnet.pytorch
45 | 
46 | - load_dataset.py
47 | ```Python3
48 | # load_dataset
49 | # Below is how to use the data loader
50 | # data_type = ['val', 'train', 'test', 'test2', 'test3', 'test4']
51 | 
52 | from torch.utils.data import DataLoader
53 | from load_dataset import AudioDataset
54 | 
55 | train_dataset = AudioDataset(data_type='train')
56 | train_data_loader = DataLoader(dataset=train_dataset, batch_size=4, collate_fn=train_dataset.collate, shuffle=True, num_workers=4)
57 | ```
58 | 
59 | - load_single_data.py
60 | ```Python3
61 | # load_single_data
62 | # Below is how to use the single-file data loader
63 | # data_type = ['val', 'train', 'test', 'test2', 'test3', 'test4']
64 | 
65 | from torch.utils.data import DataLoader
66 | from load_single_data import AudioDataset
67 | 
68 | train_data = AudioDataset(data_type='train', data_name="p287_171")
69 | train_data_loader = DataLoader(dataset=train_data, collate_fn=train_data.collate, shuffle=True, num_workers=4)
70 | ```
71 | 
72 | ## Attention-based SE Model
73 | Reference : Xiang Hao et al., 'An Attention-based Neural Network Approach for Single Channel
74 | Speech Enhancement', ICASSP 2019
75 | URL : http://lxie.nwpu-aslp.org/papers/2019ICASSP-XiangHao.pdf
76 | 
77 | 
78 | ## Train
79 | Arguments :
80 | - batch_size : Training batch size. default = 128
81 | - dropout_p : Dropout rate of the attention model. default = 0
82 | - attn_use : Use the attention model; if False, a plain LSTM model is trained. default = False
83 | - stacked_encoder : Use the stacked attention model; if False, the extended attention model is used. default = False
84 | - attn_len : Length of the causal attention window; 0 means full causal attention. default = 0
85 | - hidden_size : Hidden size of the RNN encoders. default = 112
86 | - num_epochs : Number of training epochs. default = 100
87 | - learning_rate : Learning rate. default = 5e-4
88 | - ck_name : Checkpoint file name to save/load. default = 'SEckpt.pt'
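The train, test, and single-file scripts below all share one enhancement pipeline: the noisy waveform goes through an STFT, the attention model predicts a mask over the magnitude spectrogram, and the masked magnitude is recombined with the noisy phase before the inverse STFT. Below is a minimal sketch of that pipeline, following `inference.py`; here `net` is assumed to be an `AttentionModel` already restored from a checkpoint, and `mixed_audio` a 1-D waveform tensor.

```Python3
import torch
from models.layers.istft import ISTFT

n_fft, hop_length = 512, 128
window = torch.hann_window(n_fft)
stft = lambda x: torch.stft(x, n_fft, hop_length, window=window)
istft = ISTFT(n_fft, hop_length, window='hanning')

spec = stft(mixed_audio).unsqueeze(0).transpose(1, 2)   # (B, T, F, 2)
real, imag = spec[..., 0], spec[..., 1]
mag = torch.sqrt(real**2 + imag**2)                     # magnitude spectrogram
phase = torch.atan2(imag, real)                         # noisy phase is reused

out_mag, attn_weights = net(mag)                        # model returns the masked magnitude
out_real = (out_mag * torch.cos(phase)).transpose(1, 2)
out_imag = (out_mag * torch.sin(phase)).transpose(1, 2)
enhanced = istft(out_real, out_imag, mixed_audio.shape[0]).squeeze()
```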
89 | ```bash
90 | CUDA_VISIBLE_DEVICES=GPUNUMBERS \
91 | python3 train.py --batch_size 128 \
92 | --dropout_p 0.2 \
93 | --attn_use True \
94 | --stacked_encoder True \
95 | --attn_len 5 \
96 | --hidden_size 448 \
97 | --num_epochs 61 \
98 | --ck_name '5_448_stacked_dropout0_2.pt'
99 | 
100 | # You can check other arguments from the source code.
101 | ```
102 | 
103 | ## Test
104 | The test script prints the mean loss, PESQ, and STOI over the chosen test set.
105 | Arguments :
106 | - batch_size : Test batch size. default = 128
107 | - dropout_p : Dropout rate of the attention model. default = 0
108 | - attn_use : Use the attention model; if False, the plain LSTM model is tested. default = False
109 | - stacked_encoder : Use the stacked attention model; if False, the extended attention model is used. default = False
110 | - attn_len : Length of the causal attention window; 0 means full causal attention. default = 0
111 | - hidden_size : Hidden size of the RNN encoders. default = 112
112 | - ck_name : Checkpoint file name to load. default = 'SEckpt.pt'
113 | - test_set : Data split to evaluate ('test', 'test2', 'test3', or 'test4')
114 | ```bash
115 | CUDA_VISIBLE_DEVICES=GPUNUMBERS \
116 | python3 test.py --batch_size 128 \
117 | --dropout_p 0.2 \
118 | --attn_use True \
119 | --stacked_encoder True \
120 | --attn_len 5 \
121 | --hidden_size 448 \
122 | --test_set 'test' \
123 | --ck_name '5_448_stacked_dropout0_2.pt'
124 | 
125 | # You can check other arguments from the source code.
126 | ```
127 | 
128 | ## Single file test
129 | The single-file test writes sample outputs as .wav files:
130 | - clean.wav : the selected clean utterance
131 | - mixed.wav : the corresponding noisy utterance
132 | - out.wav : the enhanced output of the model
133 | Arguments :
134 | - batch_size : Test batch size. default = 128
135 | - dropout_p : Dropout rate of the attention model. default = 0
136 | - attn_use : Use the attention model; if False, the plain LSTM model is tested. default = False
137 | - stacked_encoder : Use the stacked attention model; if False, the extended attention model is used. default = False
138 | - attn_len : Length of the causal attention window; 0 means full causal attention. default = 0
139 | - hidden_size : Hidden size of the RNN encoders. default = 112
140 | - ck_name : Checkpoint file name to load. default = 'SEckpt.pt'
141 | - test_set : Data split to evaluate ('test', 'test2', 'test3', or 'test4')
142 | - wav : Utterance ID of the file to test (e.g. 'p232_238')
143 | ```bash
144 | CUDA_VISIBLE_DEVICES=GPUNUMBERS \
145 | python3 test_single.py --batch_size 128 \
146 | --dropout_p 0.2 \
147 | --attn_use True \
148 | --stacked_encoder True \
149 | --attn_len 5 \
150 | --hidden_size 448 \
151 | --test_set 'test' \
152 | --wav 'p232_238' \
153 | --ck_name '5_448_stacked_dropout0_2.pt'
154 | 
155 | 
156 | # You can check other arguments from the source code.
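# Note: test_single.py writes its mixed.wav, clean.wav, and out.wav outputs
# under Test_wav_stoi/<ck_name>/<test_set>/ (see test_single.py).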
157 | ``` 158 | -------------------------------------------------------------------------------- /ckpt_dir_stoi/5_448_stacked_dropout0_2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chanil1218/Attention-SE.pytorch/5ded948c572bc81d6b0ebcc09c83edfe136a1a4e/ckpt_dir_stoi/5_448_stacked_dropout0_2.pt -------------------------------------------------------------------------------- /clean.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chanil1218/Attention-SE.pytorch/5ded948c572bc81d6b0ebcc09c83edfe136a1a4e/clean.wav -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | 6 | import torch 7 | import torch.optim as optim 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import librosa 12 | 13 | from pystoi.stoi import stoi 14 | from pypesq import pesq 15 | 16 | from models.layers.istft import ISTFT 17 | from models.attention import AttentionModel 18 | import utils 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--dropout_p', default = 0, type=float, help='Attention model drop out rate') 22 | parser.add_argument('--stacked_encoder', default = False, type = bool) 23 | parser.add_argument('--attn_len', default = 0, type = int) 24 | parser.add_argument('--hidden_size', default = 112, type = int) 25 | parser.add_argument('--ckpt_path', default = 'ckpt_dir', help = 'ck path') 26 | parser.add_argument('--test_set', help = 'test_set') 27 | parser.add_argument('--attn_use', default = False, type=bool) 28 | parser.add_argument('--noisy_wav') 29 | parser.add_argument('--clean_wav') 30 | args = parser.parse_args() 31 | 32 | 33 | n_fft, hop_length = 512, 128 34 | window = torch.hann_window(n_fft) 35 | # STFT 36 | stft = lambda x: torch.stft(x, n_fft, hop_length, window=window) 37 | # ISTFT 38 | istft = ISTFT(n_fft, hop_length, window='hanning') 39 | 40 | def main(): 41 | #model select 42 | print('Model initializing\n') 43 | net = torch.nn.DataParallel(AttentionModel(257, hidden_size = args.hidden_size, dropout_p = args.dropout_p, use_attn = args.attn_use, stacked_encoder = args.stacked_encoder, attn_len = args.attn_len)) 44 | 45 | #Check point load 46 | print('Trying Checkpoint Load\n') 47 | best_PESQ = 0. 48 | best_STOI = 0. 
49 |     ckpt_path = args.ckpt_path
50 | 
51 |     if os.path.exists(ckpt_path):
52 |         ckpt = torch.load(ckpt_path)
53 |         try:
54 |             net.load_state_dict(ckpt['model'])
55 |             net = net.module  # unwrap DataParallel
56 |             best_STOI = ckpt['best_STOI']
57 | 
58 |             print('checkpoint is loaded !')
59 |             print('current best STOI : %.4f' % best_STOI)
60 |         except RuntimeError as e:
61 |             print('wrong checkpoint\n')
62 |     else:
63 |         print('checkpoint not exist!')
64 |         print('current best STOI : %.4f' % best_STOI)
65 | 
66 |     #test phase
67 |     net.eval()
68 |     with torch.no_grad():
69 |         inputData, sr = librosa.load(args.noisy_wav, sr=None)
70 |         outputData, sr = librosa.load(args.clean_wav, sr=None)
71 |         inputData = np.float32(inputData)
72 |         outputData = np.float32(outputData)
73 |         mixed_audio = torch.from_numpy(inputData).type(torch.FloatTensor)
74 |         clean_audio = torch.from_numpy(outputData).type(torch.FloatTensor)
75 | 
76 |         mixed = stft(mixed_audio)
77 |         mixed = mixed.unsqueeze(0)
78 |         mixed = mixed.transpose(1,2)
79 |         cleaned = stft(clean_audio)
80 |         cleaned = cleaned.unsqueeze(0)
81 |         cleaned = cleaned.transpose(1,2)
82 |         real, imag = mixed[..., 0], mixed[..., 1]
83 |         clean_real, clean_imag = cleaned[..., 0], cleaned[..., 1]
84 |         mag = torch.sqrt(real**2 + imag**2)
85 |         clean_mag = torch.sqrt(clean_real**2 + clean_imag**2)
86 |         phase = torch.atan2(imag, real)
87 | 
88 |         logits_mag, logits_attn_weight = net(mag)
89 |         logits_real = logits_mag * torch.cos(phase)
90 |         logits_imag = logits_mag * torch.sin(phase)
91 |         logits_real, logits_imag = torch.squeeze(logits_real, 1), torch.squeeze(logits_imag, 1)
92 |         logits_real = logits_real.transpose(1,2)
93 |         logits_imag = logits_imag.transpose(1,2)
94 | 
95 |         logits_audio = istft(logits_real, logits_imag, inputData.shape[0])
96 |         logits_audio = torch.squeeze(logits_audio, dim=1)
97 | 
98 |         print(logits_audio[0])
99 |         librosa.output.write_wav('./out.wav', logits_audio[0].cpu().data.numpy(), 16000)
100 |         test_loss = F.mse_loss(logits_mag, clean_mag, True)
101 |         test_PESQ = pesq(outputData, logits_audio[0].detach().cpu().numpy(), 16000)
102 |         test_STOI = stoi(outputData, logits_audio[0].detach().cpu().numpy(), 16000, extended=False)
103 | 
104 |         print("Saved attention weight visualization to attention_viz.png")
105 |         utils.plot_head_map(logits_attn_weight[0])
106 | 
107 |         # FIXME - Issue with pcm_f32le.
Require pcm_s16le 108 | print("Saved clean spectrogram visualization to spec_clean.png") 109 | clean_spect = utils.make_spectrogram_array(args.clean_wav) 110 | utils.save_spectrogram(clean_spect, 'clean') 111 | 112 | print("Saved noisy spectrogram visualization to spec_noisy.png") 113 | noisy_spect = utils.make_spectrogram_array(args.noisy_wav) 114 | utils.save_spectrogram(noisy_spect, 'noisy') 115 | 116 | print("Saved enhanced spectrogram visualization to spec_enhanced.png") 117 | enhanced_spect = utils.make_spectrogram_array('./out.wav') 118 | utils.save_spectrogram(enhanced_spect, 'enhanced') 119 | 120 | #test accuracy 121 | print('test loss : {:.4f} PESQ : {:.4f} STOI : {:.4f}'.format(test_loss, test_PESQ, test_STOI)) 122 | 123 | if __name__ == '__main__': 124 | main() 125 | -------------------------------------------------------------------------------- /load_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import os 5 | import argparse 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torch.optim.lr_scheduler import ExponentialLR 12 | from scipy.io import wavfile 13 | import librosa 14 | from tqdm import tqdm 15 | from torch.utils.data import DataLoader 16 | import numpy as np 17 | from tqdm import tqdm 18 | import librosa 19 | import csv 20 | import torch 21 | from torch.utils import data 22 | import subprocess 23 | # if 'google_drive_downloader' not in sys.modules: 24 | # subprocess.call('pip install googledrivedownloader'.split()) 25 | # from google_drive_downloader import GoogleDriveDownloader as gdd 26 | 27 | 28 | # dataset_path = os.path.join(gdrive_root, 'dataset') 29 | dataset_path = "/home/under/ddw02141/Attention-SE.pytorch/dataset" 30 | if not os.path.exists(dataset_path): 31 | os.makedirs(dataset_path) 32 | 33 | train_tar_path = os.path.join(dataset_path, "train.tar.gz") 34 | valid_tar_path = os.path.join(dataset_path, "valid.tar.gz") 35 | test1_tar_path = os.path.join(dataset_path, "test1.tar.gz") 36 | test2_tar_path = os.path.join(dataset_path, "test2.tar.gz") 37 | test3_tar_path = os.path.join(dataset_path, "test3.tar.gz") 38 | test4_tar_path = os.path.join(dataset_path, "test4.tar.gz") 39 | test5_tar_path = os.path.join(dataset_path, "test5.tar.gz") 40 | 41 | # # Download Train set 42 | # if not os.path.exists(train_tar_path): 43 | 44 | # gdd.download_file_from_google_drive(file_id='1CULVCAq0T3wqZTPGIqPja6OtwjYJkGAy', 45 | # dest_path=train_tar_path, 46 | # unzip=False, showsize=True) 47 | # train_clean_path = os.path.join(dataset_path, "train_clean") 48 | # train_noisy_path = os.path.join(dataset_path, "train_noisy") 49 | # if not os.path.exists(train_clean_path) or not os.path.exists(train_noisy_path): 50 | # tar xvzf train.tar.gz 51 | 52 | # # Download Valid set 53 | # if not os.path.exists(valid_tar_path): 54 | 55 | # gdd.download_file_from_google_drive(file_id='1WE229Jt9WV2iZbxY7YjkYIfSZyCHz9Iq', 56 | # dest_path=valid_tar_path, 57 | # unzip=False, showsize=True) 58 | # valid_clean_path = os.path.join(dataset_path, "valid_clean") 59 | # valid_noisy_path = os.path.join(dataset_path, "valid_noisy") 60 | # if not os.path.exists(valid_clean_path) or not os.path.exists(valid_noisy_path): 61 | # !tar xvzf valid.tar.gz 62 | 63 | # # Download Test set 1 64 | # if not os.path.exists(test1_tar_path): 65 | 66 | # 
gdd.download_file_from_google_drive(file_id='16JIMca-JVXgQd7dltYdTducMek4pA58_', 67 | # dest_path=test_tar_path, 68 | # unzip=False, showsize=True) 69 | # # %cd /gdrive/My\ Drive/dataset 70 | # test1_clean_path = os.path.join(dataset_path, "test_clean") 71 | # test1_noisy_path = os.path.join(dataset_path, "test_noisy") 72 | # if not os.path.exists(test1_clean_path) or not os.path.exists(test1_noisy_path): 73 | # !tar xvzf /gdrive/My\ Drive/dataset/test1.tar.gz 74 | 75 | # # Download Test set 2 76 | # if not os.path.exists(test1_tar_path): 77 | 78 | # gdd.download_file_from_google_drive(file_id='10YSj0u_9ni_sVriiNNs97OrKKSeshvr7', 79 | # dest_path=test_tar_path, 80 | # unzip=False, showsize=True) 81 | # # %cd /gdrive/My\ Drive/dataset 82 | # test1_clean_path = os.path.join(dataset_path, "test_clean") 83 | # test1_noisy_path = os.path.join(dataset_path, "test_noisy") 84 | # if not os.path.exists(test1_clean_path) or not os.path.exists(test1_noisy_path): 85 | # !tar xvzf /gdrive/My\ Drive/dataset/test1.tar.gz 86 | 87 | # # Download Test set 3 88 | # if not os.path.exists(test1_tar_path): 89 | 90 | # gdd.download_file_from_google_drive(file_id='1J-UNPZ9SkJECih9SjQDMm6j51nus5Z4q', 91 | # dest_path=test_tar_path, 92 | # unzip=False, showsize=True) 93 | # # %cd /gdrive/My\ Drive/dataset 94 | # test1_clean_path = os.path.join(dataset_path, "test_clean") 95 | # test1_noisy_path = os.path.join(dataset_path, "test_noisy") 96 | # if not os.path.exists(test1_clean_path) or not os.path.exists(test1_noisy_path): 97 | # !tar xvzf /gdrive/My\ Drive/dataset/test1.tar.gz 98 | 99 | # # Download Test set 4 100 | # if not os.path.exists(test1_tar_path): 101 | 102 | # gdd.download_file_from_google_drive(file_id='1j3lrPTp-2gQudNfJgayK-zUp2x19rwPQ', 103 | # dest_path=test_tar_path, 104 | # unzip=False, showsize=True) 105 | # # %cd /gdrive/My\ Drive/dataset 106 | # test1_clean_path = os.path.join(dataset_path, "test_clean") 107 | # test1_noisy_path = os.path.join(dataset_path, "test_noisy") 108 | # if not os.path.exists(test1_clean_path) or not os.path.exists(test1_noisy_path): 109 | # !tar xvzf /gdrive/My\ Drive/dataset/test1.tar.gz 110 | 111 | # # Download Test set 5 112 | # if not os.path.exists(test1_tar_path): 113 | 114 | # gdd.download_file_from_google_drive(file_id='1HQXw26XOtg186QApNVx8DR85wKrFBjcW', 115 | # dest_path=test_tar_path, 116 | # unzip=False, showsize=True) 117 | # # %cd /gdrive/My\ Drive/dataset 118 | # test1_clean_path = os.path.join(dataset_path, "test_clean") 119 | # test1_noisy_path = os.path.join(dataset_path, "test_noisy") 120 | # if not os.path.exists(test1_clean_path) or not os.path.exists(test1_noisy_path): 121 | # !tar xvzf /gdrive/My\ Drive/dataset/test1.tar.gz 122 | 123 | 124 | # Data Loader Part 125 | 126 | # DATA LOADING - LOAD FILE LISTS 127 | def load_data_list(folder=dataset_path, setname='train'): 128 | assert(setname in ['train', 'val', 'test', 'test2', 'test3', 'test4']) 129 | 130 | dataset = {} 131 | 132 | if "test" in setname: 133 | clean_foldername = folder + '/testset' 134 | else: 135 | clean_foldername = folder + '/' + setname + "set" 136 | noisy_foldername = folder + '/' + setname + "set" 137 | 138 | 139 | print("Loading files...") 140 | dataset['innames'] = [] 141 | dataset['outnames'] = [] 142 | dataset['shortnames'] = [] 143 | 144 | noisy_filelist = os.listdir("%s_noisy"%(noisy_foldername)) 145 | noisy_filelist.sort() 146 | # filelist = [f for f in filelist if f.endswith(".wav")] 147 | for i in tqdm(noisy_filelist): 148 | 
dataset['innames'].append("%s_noisy/%s"%(noisy_foldername,i)) 149 | dataset['shortnames'].append("%s"%(i)) 150 | 151 | clean_filelist = os.listdir("%s_clean"%(clean_foldername)) 152 | clean_filelist.sort() 153 | for i in tqdm(clean_filelist): 154 | dataset['outnames'].append("%s_clean/%s"%(clean_foldername,i)) 155 | 156 | return dataset 157 | 158 | # DATA LOADING - LOAD FILE DATA 159 | def load_data(dataset): 160 | 161 | dataset['inaudio'] = [None]*len(dataset['innames']) 162 | dataset['outaudio'] = [None]*len(dataset['outnames']) 163 | 164 | for id in tqdm(range(len(dataset['innames']))): 165 | 166 | if dataset['inaudio'][id] is None: 167 | inputData, sr = librosa.load(dataset['innames'][id], sr=None) 168 | outputData, sr = librosa.load(dataset['outnames'][id], sr=None) 169 | 170 | shape = np.shape(inputData) 171 | 172 | dataset['inaudio'][id] = np.float32(inputData) 173 | dataset['outaudio'][id] = np.float32(outputData) 174 | 175 | return dataset 176 | 177 | class AudioDataset(data.Dataset): 178 | """ 179 | Audio sample reader. 180 | """ 181 | 182 | def __init__(self, data_type): 183 | dataset = load_data_list(setname=data_type) 184 | self.dataset = load_data(dataset) 185 | 186 | self.file_names = dataset['innames'] 187 | 188 | def __getitem__(self, idx): 189 | mixed = torch.from_numpy(self.dataset['inaudio'][idx]).type(torch.FloatTensor) 190 | clean = torch.from_numpy(self.dataset['outaudio'][idx]).type(torch.FloatTensor) 191 | 192 | return mixed, clean 193 | 194 | def __len__(self): 195 | return len(self.file_names) 196 | 197 | def zero_pad_concat(self, inputs): 198 | max_t = max(inp.shape[0] for inp in inputs) 199 | shape = (len(inputs), max_t) 200 | input_mat = np.zeros(shape, dtype=np.float32) 201 | for e, inp in enumerate(inputs): 202 | input_mat[e, :inp.shape[0]] = inp 203 | return input_mat 204 | 205 | def collate(self, inputs): 206 | mixeds, cleans = zip(*inputs) 207 | seq_lens = torch.IntTensor([i.shape[0] for i in mixeds]) 208 | 209 | x = torch.FloatTensor(self.zero_pad_concat(mixeds)) 210 | y = torch.FloatTensor(self.zero_pad_concat(cleans)) 211 | 212 | batch = [x, y, seq_lens] 213 | return batch 214 | 215 | # Below is how to use data loader 216 | 217 | # train_dataset = AudioDataset(data_type='train') 218 | # train_data_loader = DataLoader(dataset=train_dataset, batch_size=4, 219 | # collate_fn=train_dataset.collate, shuffle=True, num_workers=4) 220 | 221 | # # valid_dataset = AudioDataset(data_type='valid') 222 | # # valid_data_loader = DataLoader(dataset=valid_dataset, batch_size=4, 223 | # # collate_fn=valid_dataset.collate, shuffle=False, num_workers=4) 224 | # train_bar = tqdm(train_data_loader) 225 | 226 | # test_dataset = AudioDataset(data_type='test') 227 | # test_data_loader = DataLoader(dataset=test_dataset, batch_size=4, 228 | # collate_fn=test_dataset.collate, shuffle=True, num_workers=4) 229 | 230 | # test_dataset2 = AudioDataset(data_type='test2') 231 | # test_data_loader2 = DataLoader(dataset=test_dataset2, batch_size=4, 232 | # collate_fn=test_dataset2.collate, shuffle=True, num_workers=4) 233 | 234 | -------------------------------------------------------------------------------- /load_single_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import os 5 | import argparse 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torch.optim.lr_scheduler import ExponentialLR 12 | from 
scipy.io import wavfile 13 | import librosa 14 | from tqdm import tqdm 15 | from torch.utils.data import DataLoader 16 | import numpy as np 17 | from tqdm import tqdm 18 | import librosa 19 | import csv 20 | import torch 21 | from torch.utils import data 22 | import subprocess 23 | # if 'google_drive_downloader' not in sys.modules: 24 | # subprocess.call('pip install googledrivedownloader'.split()) 25 | # from google_drive_downloader import GoogleDriveDownloader as gdd 26 | 27 | 28 | # dataset_path = os.path.join(gdrive_root, 'dataset') 29 | dataset_path = "/home/under/ddw02141/Attention-SE.pytorch/dataset" 30 | if not os.path.exists(dataset_path): 31 | os.makedirs(dataset_path) 32 | 33 | # Data Loader Part 34 | 35 | # DATA LOADING - LOAD FILE LISTS 36 | def load_data_list(folder=dataset_path, setname='train', data_name="Invalid"): 37 | assert(setname in ['train', 'val', 'test', 'test2', 'test3', 'test4']) 38 | 39 | dataset = {} 40 | 41 | if "test" in setname: 42 | clean_foldername = folder + '/testset' 43 | else: 44 | clean_foldername = folder + '/' + setname + "set" 45 | noisy_foldername = folder + '/' + setname + "set" 46 | 47 | 48 | print("Loading files...") 49 | print("data_name") 50 | print(data_name) 51 | dataset['innames'] = [] 52 | dataset['outnames'] = [] 53 | dataset['shortnames'] = [] 54 | 55 | noisy_file = "" 56 | clean_file = "" 57 | noisy_filelist = os.listdir("%s_noisy"%(noisy_foldername)) 58 | noisy_filelist.sort() 59 | for file in noisy_filelist: 60 | if data_name in file: 61 | noisy_file = file 62 | break 63 | # filelist = [f for f in filelist if f.endswith(".wav")] 64 | 65 | 66 | clean_filelist = os.listdir("%s_clean"%(clean_foldername)) 67 | clean_filelist.sort() 68 | for file in clean_filelist: 69 | if data_name in file: 70 | clean_file = file 71 | break 72 | if noisy_file=="" or clean_file=="": 73 | print("****************************************") 74 | print("File name with %s does not exist"%(data_name)) 75 | dataset['innames'].append("%s_noisy/%s"%(noisy_foldername,noisy_file)) 76 | dataset['shortnames'].append("%s"%(noisy_file)) 77 | dataset['outnames'].append("%s_clean/%s"%(clean_foldername,clean_file)) 78 | 79 | return dataset 80 | 81 | # DATA LOADING - LOAD FILE DATA 82 | def load_data(dataset): 83 | 84 | dataset['inaudio'] = [None]*len(dataset['innames']) 85 | dataset['outaudio'] = [None]*len(dataset['outnames']) 86 | 87 | for id in tqdm(range(len(dataset['innames']))): 88 | 89 | if dataset['inaudio'][id] is None: 90 | inputData, sr = librosa.load(dataset['innames'][id], sr=None) 91 | outputData, sr = librosa.load(dataset['outnames'][id], sr=None) 92 | 93 | shape = np.shape(inputData) 94 | 95 | dataset['inaudio'][id] = np.float32(inputData) 96 | dataset['outaudio'][id] = np.float32(outputData) 97 | 98 | return dataset 99 | 100 | class AudioDataset(data.Dataset): 101 | """ 102 | Audio sample reader. 
103 | """ 104 | 105 | def __init__(self, data_type, data_name): 106 | dataset = load_data_list(setname=data_type, data_name = data_name) 107 | self.dataset = load_data(dataset) 108 | 109 | self.file_names = dataset['innames'] 110 | 111 | def __getitem__(self, idx): 112 | mixed = torch.from_numpy(self.dataset['inaudio'][idx]).type(torch.FloatTensor) 113 | clean = torch.from_numpy(self.dataset['outaudio'][idx]).type(torch.FloatTensor) 114 | 115 | return mixed, clean 116 | 117 | def __len__(self): 118 | return len(self.file_names) 119 | 120 | def zero_pad_concat(self, inputs): 121 | max_t = max(inp.shape[0] for inp in inputs) 122 | shape = (len(inputs), max_t) 123 | input_mat = np.zeros(shape, dtype=np.float32) 124 | for e, inp in enumerate(inputs): 125 | input_mat[e, :inp.shape[0]] = inp 126 | return input_mat 127 | 128 | def collate(self, inputs): 129 | mixeds, cleans = zip(*inputs) 130 | seq_lens = torch.IntTensor([i.shape[0] for i in mixeds]) 131 | 132 | x = torch.FloatTensor(self.zero_pad_concat(mixeds)) 133 | y = torch.FloatTensor(self.zero_pad_concat(cleans)) 134 | 135 | batch = [x, y, seq_lens] 136 | return batch 137 | 138 | # Below is how to use single data loader 139 | 140 | # train_data = AudioDataset(data_type='train', data_name="p287_171") 141 | # train_data_loader = DataLoader(dataset=train_data, 142 | # collate_fn=train_data.collate, shuffle=True, num_workers=4) 143 | 144 | # for train in train_data_loader: 145 | # print(train) 146 | 147 | 148 | -------------------------------------------------------------------------------- /mixed.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chanil1218/Attention-SE.pytorch/5ded948c572bc81d6b0ebcc09c83edfe136a1a4e/mixed.wav -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class AttentionModel(nn.Module): 8 | def __init__(self, input_size, hidden_size, stacked_encoder=False, use_attn=True, attn_len=0, dropout_p=0): 9 | # attn_len 0 for full attention(Causal Dynamic Attention), 10 | # Else x_t-w, ..., x_t only used(Causeal Local Attention). 
11 |         super(AttentionModel, self).__init__()
12 |         self.stacked_encoder = stacked_encoder if use_attn else True
13 |         self.use_attn = use_attn
14 |         self.attn_len = attn_len
15 | 
16 |         # Encoder
17 |         self.feat = nn.Linear(input_size, hidden_size)
18 |         self.dropout = nn.Dropout(dropout_p)
19 |         self.k_enc = nn.LSTM(hidden_size, hidden_size, batch_first=True)
20 |         self.q_enc = nn.LSTM(hidden_size, hidden_size, batch_first=True)
21 | 
22 |         # Attention
23 |         self.score = nn.Linear(hidden_size, hidden_size, bias=False)
24 | 
25 |         # Generator
26 |         enhance_in = hidden_size * (2 if use_attn else 1)
27 |         self.enhance = nn.Linear(enhance_in, hidden_size)
28 |         self.mask = nn.Linear(hidden_size, input_size)
29 | 
30 |     def forward(self, x):
31 |         # x dim (B, T, F)
32 |         input_x = x
33 | 
34 |         # Encoder
35 |         x = self.feat(x).tanh()
36 |         # TODO - Not sure this is a good place for dropout
37 |         x = self.dropout(x)
38 |         self.k_enc.flatten_parameters()
39 |         self.q_enc.flatten_parameters()
40 |         k, _ = self.k_enc(x)
41 |         q, _ = self.q_enc(k if self.stacked_encoder else x)
42 | 
43 |         # Attention
44 |         out = q
45 |         attn_weights = None
46 |         if self.use_attn:
47 |             # attn_score dim (B x T x T'(k))
48 |             attn_score = torch.bmm(self.score(q), k.transpose(1, 2))
49 |             attn_max, _ = torch.max(attn_score, dim=-1, keepdim=True)  # For numerical stability
50 |             exp_score = torch.exp(attn_score - attn_max)
51 | 
52 |             # Causal constraint (attend only to t' <= t)
53 |             attn_weights = torch.tril(exp_score)
54 |             if self.attn_len > 0:
55 |                 # Local window constraint (t - w <= t')
56 |                 attn_weights = torch.triu(attn_weights, diagonal=-self.attn_len)
57 |             weights_denom = torch.sum(attn_weights, dim=-1, keepdim=True)
58 |             #attn_weights = attn_weights / (weights_denom + 1e-10)
59 |             attn_weights = attn_weights / (weights_denom + 1e-30)
60 | 
61 |             c = torch.bmm(attn_weights, k)
62 | 
63 |             # concat query and context
64 |             out = torch.cat((c, q), -1)
65 | 
66 |         # Generator
67 |         out = self.enhance(out).tanh()
68 |         out = self.mask(out).sigmoid()
69 | 
70 |         return input_x * out, attn_weights
71 | 
--------------------------------------------------------------------------------
/models/layers/istft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import scipy.signal
6 | import librosa
7 | 
8 | class ISTFT(torch.nn.Module):
9 |     def __init__(self, filter_length=1024, hop_length=512, window='hanning', center=True):
10 |         super(ISTFT, self).__init__()
11 | 
12 |         self.filter_length = filter_length
13 |         self.hop_length = hop_length
14 |         self.center = center
15 | 
16 |         win_cof = scipy.signal.get_window(window, filter_length)
17 |         self.inv_win = self.inverse_stft_window(win_cof, hop_length)
18 | 
19 |         fourier_basis = np.fft.fft(np.eye(self.filter_length))
20 |         cutoff = int((self.filter_length / 2 + 1))
21 |         fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
22 |                                    np.imag(fourier_basis[:cutoff, :])])
23 |         inverse_basis = torch.FloatTensor(self.inv_win * \
24 |                                           np.linalg.pinv(fourier_basis).T[:, None, :])
25 | 
26 |         self.register_buffer('inverse_basis', inverse_basis.float())
27 | 
28 |     # Use equation 8 from Griffin, Lim.
29 | # Paper: "Signal Estimation from Modified Short-Time Fourier Transform" 30 | # Reference implementation: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/signal/spectral_ops.py 31 | # librosa use equation 6 from paper: https://github.com/librosa/librosa/blob/0dcd53f462db124ed3f54edf2334f28738d2ecc6/librosa/core/spectrum.py#L302-L311 32 | def inverse_stft_window(self, window, hop_length): 33 | window_length = len(window) 34 | denom = window ** 2 35 | overlaps = -(-window_length // hop_length) # Ceiling division. 36 | denom = np.pad(denom, (0, overlaps * hop_length - window_length), 'constant') 37 | denom = np.reshape(denom, (overlaps, hop_length)).sum(0) 38 | denom = np.tile(denom, (overlaps, 1)).reshape(overlaps * hop_length) 39 | return window / denom[:window_length] 40 | 41 | def forward(self, real_part, imag_part, length=None): 42 | if (real_part.dim() == 2): 43 | real_part = real_part.unsqueeze(0) 44 | imag_part = imag_part.unsqueeze(0) 45 | 46 | recombined = torch.cat([real_part, imag_part], dim=1) 47 | 48 | inverse_transform = F.conv_transpose1d(recombined, 49 | self.inverse_basis, 50 | stride=self.hop_length, 51 | padding=0) 52 | 53 | padded = int(self.filter_length // 2) 54 | if length is None: 55 | if self.center: 56 | inverse_transform = inverse_transform[:, :, padded:-padded] 57 | else: 58 | if self.center: 59 | inverse_transform = inverse_transform[:, :, padded:] 60 | inverse_transform = inverse_transform[:, :, :length] 61 | 62 | return inverse_transform -------------------------------------------------------------------------------- /run_inference.sh: -------------------------------------------------------------------------------- 1 | python3 inference.py --clean_wav clean.wav --noisy_wav mixed.wav --ckpt_path ckpt_dir_stoi/5_448_stacked_dropout0_2.pt --stacked_encoder True --hidden_size 448 --attn_use True --attn_len 5 2 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import sys 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | import torchvision 13 | from torchvision import transforms 14 | from torch.optim.lr_scheduler import ExponentialLR 15 | 16 | import tensorboardX 17 | from tensorboardX import SummaryWriter 18 | 19 | 20 | from scipy.io import wavfile 21 | import librosa 22 | 23 | import soundfile as sf 24 | from pystoi.stoi import stoi 25 | from pypesq import pesq 26 | 27 | from tqdm import tqdm 28 | from models.layers.istft import ISTFT 29 | import train_utils 30 | from load_dataset import AudioDataset 31 | from models.attention import AttentionModel 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--batch_size', default=128, type=int, help='train batch size') 35 | parser.add_argument('--num_epochs', default=100, type=int, help='train epochs number') 36 | parser.add_argument('--dropout_p', default = 0, type=float, help='Attention model drop out rate') 37 | parser.add_argument('--learning_rate', default = 5e-4, type=float, help = 'Learning rate') 38 | parser.add_argument('--stacked_encoder', default = False, type = bool) 39 | parser.add_argument('--attn_len', default = 0, type = int) 40 | parser.add_argument('--hidden_size', default = 112, type = int) 41 | parser.add_argument('--ck_dir', default = 'ckpt_dir', help = 
'ck path')
42 | parser.add_argument('--ck_name', help = 'ck file')
43 | parser.add_argument('--test_set', help = 'test_set')
44 | parser.add_argument('--attn_use', default = False, type=bool)
45 | args = parser.parse_args()
46 | 
47 | 
48 | n_fft, hop_length = 512, 128
49 | window = torch.hann_window(n_fft).cuda()
50 | # STFT
51 | stft = lambda x: torch.stft(x, n_fft, hop_length, window=window)
52 | # ISTFT
53 | istft = ISTFT(n_fft, hop_length, window='hanning').cuda()
54 | 
55 | def main():
56 |     test_dataset = AudioDataset(data_type=args.test_set)
57 |     test_data_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, collate_fn=test_dataset.collate, shuffle=False, num_workers=4)
58 | 
59 | 
60 |     #model select
61 |     print('Model initializing\n')
62 |     net = torch.nn.DataParallel(AttentionModel(257, hidden_size = args.hidden_size, dropout_p = args.dropout_p, use_attn = args.attn_use, stacked_encoder = args.stacked_encoder, attn_len = args.attn_len))
63 |     #net = AttentionModel(257, 112, dropout_p = args.dropout_p, use_attn = args.attn_use)
64 |     net = net.cuda()
65 |     print(net)
66 | 
67 |     optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)
68 | 
69 |     #Check point load
70 | 
71 |     print('Trying Checkpoint Load\n')
72 |     ckpt_dir = args.ck_dir
73 |     best_PESQ = 0.
74 |     best_STOI = 0.
75 |     ckpt_path = os.path.join(ckpt_dir, args.ck_name)
76 | 
77 | 
78 |     if os.path.exists(ckpt_path):
79 |         ckpt = torch.load(ckpt_path)
80 |         try:
81 |             net.load_state_dict(ckpt['model'])
82 |             optimizer.load_state_dict(ckpt['optimizer'])
83 |             best_STOI = ckpt['best_STOI']
84 | 
85 |             print('checkpoint is loaded !')
86 |             print('current best STOI : %.4f' % best_STOI)
87 |         except RuntimeError as e:
88 |             print('wrong checkpoint\n')
89 |     else:
90 |         print('checkpoint not exist!')
91 |         print('current best STOI : %.4f' % best_STOI)
92 | 
93 |     #test phase
94 |     n = 0
95 |     avg_test_loss = 0
96 | 
97 | 
98 |     net.eval()
99 |     with torch.no_grad():
100 |         test_bar = tqdm(test_data_loader)
101 |         for input in test_bar:
102 |             test_mixed, test_clean, seq_len = map(lambda x: x.cuda(), input)
103 |             mixed = stft(test_mixed)
104 |             cleaned = stft(test_clean)
105 |             mixed = mixed.transpose(1,2)
106 |             cleaned = cleaned.transpose(1,2)
107 |             real, imag = mixed[..., 0], mixed[..., 1]
108 |             clean_real, clean_imag = cleaned[..., 0], cleaned[..., 1]
109 |             mag = torch.sqrt(real**2 + imag**2)
110 |             clean_mag = torch.sqrt(clean_real**2 + clean_imag**2)
111 |             phase = torch.atan2(imag, real)
112 | 
113 |             logits_mag, logits_attn_weight = net(mag)
114 |             logits_real = logits_mag * torch.cos(phase)
115 |             logits_imag = logits_mag * torch.sin(phase)
116 |             logits_real, logits_imag = torch.squeeze(logits_real, 1), torch.squeeze(logits_imag, 1)
117 |             logits_real = logits_real.transpose(1,2)
118 |             logits_imag = logits_imag.transpose(1,2)
119 | 
120 |             logits_audio = istft(logits_real, logits_imag, test_mixed.size(1))
121 |             logits_audio = torch.squeeze(logits_audio, dim=1)
122 |             for i, l in enumerate(seq_len):
123 |                 logits_audio[i, l:] = 0
124 | 
125 |             test_loss = 0
126 |             test_PESQ = 0
127 |             test_STOI = 0
128 | 
129 |             test_loss = F.mse_loss(logits_mag, clean_mag, True)
130 | 
131 | 
132 | 
133 |             for i in range(len(test_mixed)):
134 | 
135 | 
136 |                 librosa.output.write_wav('test_out.wav', logits_audio[i].cpu().data.numpy()[:seq_len[i].cpu().data.numpy()], 16000)
137 |                 cur_PESQ = pesq(test_clean[i].detach().cpu().numpy(), logits_audio[i].detach().cpu().numpy(), 16000)
138 |                 cur_STOI = stoi(test_clean[i].detach().cpu().numpy(), logits_audio[i].detach().cpu().numpy(), 16000,
extended=False) 139 | 140 | test_PESQ += cur_PESQ 141 | test_STOI += cur_STOI 142 | 143 | test_PESQ /= len(test_mixed) 144 | test_STOI /= len(test_mixed) 145 | avg_test_loss += test_loss 146 | n += 1 147 | 148 | #test accuracy 149 | #test_pesq = pesq('test_clean.wav', 'test_out.wav', 16000) 150 | #test_stoi = stoi('test_clean.wav', 'test_out.wav', 16000) 151 | 152 | avg_test_loss /= n 153 | #summary.add_scalar('Test Loss', avg_test_loss.item(), iteration) 154 | print('test loss : {:.4f} PESQ : {:.4f} STOI : {:.4f}'.format(avg_test_loss, test_PESQ, test_STOI)) 155 | 156 | if __name__ == '__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /test_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import sys 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | import torchvision 13 | from torchvision import transforms 14 | from torch.optim.lr_scheduler import ExponentialLR 15 | 16 | import tensorboardX 17 | from tensorboardX import SummaryWriter 18 | 19 | 20 | from scipy.io import wavfile 21 | import librosa 22 | 23 | import soundfile as sf 24 | from pystoi.stoi import stoi 25 | from pypesq import pesq 26 | 27 | from tqdm import tqdm 28 | from models.layers.istft import ISTFT 29 | import train_utils 30 | from load_single_data import AudioDataset 31 | from models.attention import AttentionModel 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--batch_size', default=128, type=int, help='train batch size') 35 | parser.add_argument('--num_epochs', default=100, type=int, help='train epochs number') 36 | parser.add_argument('--dropout_p', default = 0, type=float, help='Attention model drop out rate') 37 | parser.add_argument('--learning_rate', default = 5e-4, type=float, help = 'Learning rate') 38 | parser.add_argument('--stacked_encoder', default = False, type = bool) 39 | parser.add_argument('--attn_len', default = 0, type = int) 40 | parser.add_argument('--hidden_size', default = 112, type = int) 41 | parser.add_argument('--ck_dir', default = 'ckpt_dir', help = 'ck path') 42 | parser.add_argument('--ck_name', help = 'ck file') 43 | parser.add_argument('--test_set', help = 'test_set') 44 | parser.add_argument('--attn_use', default = False, type=bool) 45 | parser.add_argument('--wav') 46 | args = parser.parse_args() 47 | 48 | 49 | n_fft, hop_length = 512, 128 50 | window = torch.hann_window(n_fft).cuda() 51 | # STFT 52 | stft = lambda x: torch.stft(x, n_fft, hop_length, window=window) 53 | # ISTFT 54 | istft = ISTFT(n_fft, hop_length, window='hanning').cuda() 55 | 56 | def main(): 57 | 58 | test_data_set = AudioDataset(data_type=args.test_set, data_name=args.wav) 59 | test_data_loader = DataLoader(dataset=test_data_set, collate_fn=test_data_set.collate, num_workers=4) 60 | 61 | 62 | 63 | #model select 64 | print('Model initializing\n') 65 | net = torch.nn.DataParallel(AttentionModel(257, hidden_size = args.hidden_size, dropout_p = args.dropout_p, use_attn = args.attn_use, stacked_encoder = args.stacked_encoder, attn_len = args.attn_len)) 66 | #net = AttentionModel(257, 112, dropout_p = args.dropout_p, use_attn = args.attn_use) 67 | net = net.cuda() 68 | 69 | optimizer = optim.Adam(net.parameters(), lr=args.learning_rate) 70 | 71 | #Check point load 72 | 73 | print('Trying Checkpoint Load\n') 74 | 
    ckpt_dir = args.ck_dir
75 |     best_PESQ = 0.
76 |     best_STOI = 0.
77 |     ckpt_path = os.path.join(ckpt_dir, args.ck_name)
78 | 
79 | 
80 |     if not os.path.exists('Test_wav_stoi'):
81 |         os.makedirs('Test_wav_stoi')
82 | 
83 |     test_dir = os.path.join('Test_wav_stoi', args.ck_name)
84 |     if not os.path.exists(test_dir):
85 |         os.makedirs(test_dir)
86 | 
87 |     test_dir = os.path.join(test_dir, args.test_set)
88 |     if not os.path.exists(test_dir):
89 |         os.makedirs(test_dir)
90 | 
91 |     if os.path.exists(ckpt_path):
92 |         ckpt = torch.load(ckpt_path)
93 |         try:
94 |             net.load_state_dict(ckpt['model'])
95 |             optimizer.load_state_dict(ckpt['optimizer'])
96 |             best_STOI = ckpt['best_STOI']
97 | 
98 |             print('checkpoint is loaded !')
99 |             print('current best STOI : %.4f' % best_STOI)
100 |         except RuntimeError as e:
101 |             print('wrong checkpoint\n')
102 |     else:
103 |         print('checkpoint not exist!')
104 |         print('current best STOI : %.4f' % best_STOI)
105 | 
106 |     #test phase
107 |     n = 0
108 |     avg_test_loss = 0
109 | 
110 | 
111 |     net.eval()
112 |     with torch.no_grad():
113 |         test_bar = tqdm(test_data_loader)
114 |         for input in test_bar:
115 |             test_mixed, test_clean, seq_len = map(lambda x: x.cuda(), input)
116 |             mixed = stft(test_mixed)
117 |             cleaned = stft(test_clean)
118 |             mixed = mixed.transpose(1,2)
119 |             cleaned = cleaned.transpose(1,2)
120 |             real, imag = mixed[..., 0], mixed[..., 1]
121 |             clean_real, clean_imag = cleaned[..., 0], cleaned[..., 1]
122 |             mag = torch.sqrt(real**2 + imag**2)
123 |             clean_mag = torch.sqrt(clean_real**2 + clean_imag**2)
124 |             phase = torch.atan2(imag, real)
125 | 
126 |             logits_mag, logits_attn_weight = net(mag)
127 |             logits_real = logits_mag * torch.cos(phase)
128 |             logits_imag = logits_mag * torch.sin(phase)
129 |             logits_real, logits_imag = torch.squeeze(logits_real, 1), torch.squeeze(logits_imag, 1)
130 |             logits_real = logits_real.transpose(1,2)
131 |             logits_imag = logits_imag.transpose(1,2)
132 | 
133 |             logits_audio = istft(logits_real, logits_imag, test_mixed.size(1))
134 |             logits_audio = torch.squeeze(logits_audio, dim=1)
135 |             for i, l in enumerate(seq_len):
136 |                 logits_audio[i, l:] = 0
137 | 
138 |             test_loss = 0
139 |             test_PESQ = 0
140 |             test_STOI = 0
141 | 
142 |             test_loss = F.mse_loss(logits_mag, clean_mag, True)
143 | 
144 |             mixed_wav = os.path.join(test_dir, 'mixed.wav')
145 | 
146 |             clean_wav = os.path.join(test_dir, 'clean.wav')
147 | 
148 |             out_wav = os.path.join(test_dir, 'out.wav')
149 | 
150 |             for i in range(len(test_mixed)):
151 |                 librosa.output.write_wav(mixed_wav, test_mixed[i].cpu().data.numpy()[:seq_len[i].cpu().data.numpy()], 16000)
152 |                 librosa.output.write_wav(clean_wav, test_clean[i].cpu().data.numpy()[:seq_len[i].cpu().data.numpy()], 16000)
153 |                 librosa.output.write_wav(out_wav, logits_audio[i].cpu().data.numpy()[:seq_len[i].cpu().data.numpy()], 16000)
154 |                 test_PESQ += pesq(test_clean[i].detach().cpu().numpy(), logits_audio[i].detach().cpu().numpy(), 16000)
155 |                 test_STOI += stoi(test_clean[i].detach().cpu().numpy(), logits_audio[i].detach().cpu().numpy(), 16000, extended=False)
156 | 
157 |             test_PESQ /= len(test_mixed)
158 |             test_STOI /= len(test_mixed)
159 |             avg_test_loss += test_loss
160 |             n += 1
161 | 
162 |     #test accuracy
163 |     #test_pesq = pesq('test_clean.wav', 'test_out.wav', 16000)
164 |     #test_stoi = stoi('test_clean.wav', 'test_out.wav', 16000)
165 | 
166 |     avg_test_loss /= n
167 |     #summary.add_scalar('Test Loss', avg_test_loss.item(), iteration)
168 |     print('test loss : {:.4f} PESQ : {:.4f} STOI : {:.4f}'.format(avg_test_loss, test_PESQ, test_STOI))
169 | 
170 | if
__name__ == '__main__': 171 | main() 172 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import sys 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | import torchvision 13 | from torchvision import transforms 14 | from torch.optim.lr_scheduler import ExponentialLR 15 | 16 | import tensorboardX 17 | from tensorboardX import SummaryWriter 18 | 19 | 20 | from scipy.io import wavfile 21 | import librosa 22 | 23 | import soundfile as sf 24 | from pystoi.stoi import stoi 25 | from pypesq import pesq 26 | 27 | from tqdm import tqdm 28 | from models.layers.istft import ISTFT 29 | import train_utils 30 | from load_dataset import AudioDataset 31 | from models.attention import AttentionModel 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--model_dir', default='experiment/SE_model.json', help="Directory containing params.json") 35 | parser.add_argument('--restore_file', default=None, help="Optional, name of the file in --model_dir containing weights to reload before training") # 'best' or 'train' 36 | parser.add_argument('--batch_size', default=128, type=int, help='train batch size') 37 | parser.add_argument('--num_epochs', default=100, type=int, help='train epochs number') 38 | parser.add_argument('--dropout_p', default = 0, type=float, help='Attention model drop out rate') 39 | parser.add_argument('--learning_rate', default = 5e-4, type=float, help = 'Learning rate') 40 | parser.add_argument('--attn_use', default = False, type=bool) 41 | parser.add_argument('--stacked_encoder', default = False, type = bool) 42 | parser.add_argument('--attn_len', default = 0, type = int) 43 | parser.add_argument('--hidden_size', default = 112, type = int) 44 | parser.add_argument('--ck_name', default = 'SEckpt.pt') 45 | args = parser.parse_args() 46 | 47 | 48 | n_fft, hop_length = 512, 128 49 | window = torch.hann_window(n_fft).cuda() 50 | # STFT 51 | stft = lambda x: torch.stft(x, n_fft, hop_length, window=window) 52 | # ISTFT 53 | istft = ISTFT(n_fft, hop_length, window='hanning').cuda() 54 | 55 | def normalized(tensor): 56 | output = [[] for i in range(len(tensor))] 57 | 58 | for i in range(len(tensor)): 59 | nummer = tensor[i] - torch.min(tensor[i]) 60 | denomi = torch.max(tensor[i]) - torch.min(tensor[i]) 61 | 62 | output[i] = (nummer / (denomi + 1e-5)).tolist() 63 | 64 | 65 | return torch.tensor(output) 66 | 67 | 68 | def main(): 69 | #summary = SummaryWriter() 70 | #os.system('tensorboard --logdir=path_of_log_file') 71 | 72 | #set Hyper parameter 73 | json_path = os.path.join(args.model_dir) 74 | params = train_utils.Params(json_path) 75 | 76 | #data loader 77 | train_dataset = AudioDataset(data_type='train') 78 | train_data_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, collate_fn=train_dataset.collate, shuffle=True, num_workers=4) 79 | test_dataset = AudioDataset(data_type='test') 80 | test_data_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, collate_fn=test_dataset.collate, shuffle=False, num_workers=4) 81 | #model select 82 | print('Model initializing\n') 83 | net = torch.nn.DataParallel(AttentionModel(257, hidden_size = args.hidden_size, dropout_p = args.dropout_p, use_attn = args.attn_use, stacked_encoder = args.stacked_encoder, 
attn_len = args.attn_len))
84 |     #net = AttentionModel(257, 112, dropout_p = args.dropout_p, use_attn = args.attn_use)
85 |     net = net.cuda()
86 |     print(net)
87 | 
88 |     optimizer = optim.Adam(net.parameters(), lr=args.learning_rate)
89 | 
90 |     scheduler = ExponentialLR(optimizer, 0.5)
91 | 
92 |     #Check point load
93 | 
94 | 
95 |     print('Trying Checkpoint Load\n')
96 |     ckpt_dir = 'ckpt_dir'
97 |     if not os.path.exists(ckpt_dir):
98 |         os.makedirs(ckpt_dir)
99 | 
100 |     best_PESQ = 0.
101 |     best_STOI = 0.
102 |     best_loss = 200000.
103 |     ckpt_path = os.path.join(ckpt_dir, args.ck_name)
104 |     if os.path.exists(ckpt_path):
105 |         ckpt = torch.load(ckpt_path)
106 |         try:
107 |             net.load_state_dict(ckpt['model'])
108 |             optimizer.load_state_dict(ckpt['optimizer'])
109 |             best_loss = ckpt['best_loss']
110 | 
111 |             print('checkpoint is loaded !')
112 |             print('current best loss : %.4f' % best_loss)
113 |         except RuntimeError as e:
114 |             print('wrong checkpoint\n')
115 |     else:
116 |         print('checkpoint not exist!')
117 |         print('current best loss : %.4f' % best_loss)
118 | 
119 |     print('Training Start!')
120 |     #train
121 |     iteration = 0
122 |     train_losses = []
123 |     test_losses = []
124 |     for epoch in range(args.num_epochs):
125 |         train_bar = tqdm(train_data_loader)
126 |         # train_bar = train_data_loader
127 |         n = 0
128 |         avg_loss = 0
129 |         net.train()
130 |         for input in train_bar:
131 |             iteration += 1
132 |             #load data
133 |             train_mixed, train_clean, seq_len = map(lambda x: x.cuda(), input)
134 | 
135 |             mixed = stft(train_mixed)
136 |             cleaned = stft(train_clean)
137 |             mixed = mixed.transpose(1,2)
138 |             cleaned = cleaned.transpose(1,2)
139 |             real, imag = mixed[..., 0], mixed[..., 1]
140 |             clean_real, clean_imag = cleaned[..., 0], cleaned[..., 1]
141 |             mag = torch.sqrt(real**2 + imag**2)
142 |             clean_mag = torch.sqrt(clean_real**2 + clean_imag**2)
143 |             phase = torch.atan2(imag, real)
144 | 
145 | 
146 |             #feed data
147 |             out_mag, attn_weight = net(mag)
148 |             out_real = out_mag * torch.cos(phase)
149 |             out_imag = out_mag * torch.sin(phase)
150 |             out_real, out_imag = torch.squeeze(out_real, 1), torch.squeeze(out_imag, 1)
151 |             out_real = out_real.transpose(1,2)
152 |             out_imag = out_imag.transpose(1,2)
153 | 
154 |             out_audio = istft(out_real, out_imag, train_mixed.size(1))
155 |             out_audio = torch.squeeze(out_audio, dim=1)
156 |             for i, l in enumerate(seq_len):
157 |                 out_audio[i, l:] = 0
158 | 
159 |             loss = 0
160 |             PESQ = 0
161 |             STOI = 0
162 | 
163 |             loss = F.mse_loss(out_mag, clean_mag, True)
164 |             if torch.any(torch.isnan(loss)):
165 |                 torch.save({'clean_mag': clean_mag, 'out_mag': out_mag, 'mag': mag}, 'nan_mag')
166 |                 raise ValueError('loss is NaN')
167 |             avg_loss += loss
168 |             n += 1
169 |             #gradient optimizer
170 |             optimizer.zero_grad()
171 | 
172 | 
173 |             #backpropagate LOSS
174 |             loss.backward()
175 | 
176 | 
177 |             #update weight
178 |             optimizer.step()
179 | 
180 |             #for i in range(len(train_mixed)):
181 |             #    PESQ += pesq(train_clean[i].cpu().data.numpy(), out_audio[i].cpu().data.numpy(), 16000)
182 |             #    STOI += stoi(train_clean[i].cpu().data.numpy(), out_audio[i].cpu().data.numpy(), 16000, extended=False)
183 |             #PESQ /= len(train_mixed)
184 |             #STOI /= len(train_mixed)
185 | 
186 |             # log training progress
187 |             if iteration % 100 == 0 :
188 |                 print('[epoch: {}, iteration: {}] train loss : {:.4f} PESQ : {:.4f} STOI : {:.4f}'.format(epoch, iteration, loss, PESQ, STOI))
189 | 
190 |         avg_loss /= n
191 |         #summary.add_scalar('Train Loss', avg_loss.item(), iteration)
192 | 
193 |         train_losses.append(avg_loss)
194 |         if (len(train_losses) > 2) and
(train_losses[-2] < avg_loss): 195 | print("Learning rate Decay") 196 | scheduler.step() 197 | 198 | #test phase 199 | n = 0 200 | avg_test_loss = 0 201 | test_bar = tqdm(test_data_loader) 202 | 203 | net.eval() 204 | with torch.no_grad(): 205 | for input in test_bar: 206 | test_mixed, test_clean, seq_len = map(lambda x: x.cuda(), input) 207 | mixed = stft(test_mixed) 208 | cleaned = stft(test_clean) 209 | mixed = mixed.transpose(1,2) 210 | cleaned = cleaned.transpose(1,2) 211 | real, imag = mixed[..., 0], mixed[..., 1] 212 | clean_real, clean_imag = cleaned[..., 0], cleaned[..., 1] 213 | mag = torch.sqrt(real**2 + imag**2) 214 | clean_mag = torch.sqrt(clean_real**2 + clean_imag**2) 215 | phase = torch.atan2(imag, real) 216 | 217 | logits_mag, logits_attn_weight = net(mag) 218 | logits_real = logits_mag * torch.cos(phase) 219 | logits_imag = logits_mag * torch.sin(phase) 220 | logits_real, logits_imag = torch.squeeze(logits_real, 1), torch.squeeze(logits_imag, 1) 221 | logits_real = logits_real.transpose(1,2) 222 | logits_imag = logits_imag.transpose(1,2) 223 | 224 | logits_audio = istft(logits_real, logits_imag, test_mixed.size(1)) 225 | logits_audio = torch.squeeze(logits_audio, dim=1) 226 | for i, l in enumerate(seq_len): 227 | logits_audio[i, l:] = 0 228 | 229 | test_loss = 0 230 | test_PESQ = 0 231 | test_STOI = 0 232 | 233 | test_loss = F.mse_loss(logits_mag, clean_mag, True) 234 | #for i in range(len(test_mixed)): 235 | #librosa.output.write_wav('test_out.wav', logits_audio[i].cpu().data.numpy()[:seq_len[i].cpu().data.numpy()], 16000) 236 | # test_PESQ += pesq(test_clean[i].detach().cpu().numpy(), logits_audio[i].detach().cpu().numpy(), 16000) 237 | # test_STOI += stoi(test_clean[i].detach().cpu().numpy(), logits_audio[i].detach().cpu().numpy(), 16000, extended=False) 238 | 239 | #test_STOI /= len(test_mixed) 240 | avg_test_loss += test_loss 241 | n += 1 242 | #test loss 243 | #test_loss = wSDRLoss(test_mixed, test_clean, out_audio) 244 | #test_loss = torch.nn.MSELoss(out_audio, test_clean) 245 | 246 | #test accuracy 247 | #test_pesq = pesq('test_clean.wav', 'test_out.wav', 16000) 248 | #test_stoi = stoi('test_clean.wav', 'test_out.wav', 16000) 249 | 250 | avg_test_loss /= n 251 | test_losses.append(avg_test_loss) 252 | #summary.add_scalar('Test Loss', avg_test_loss.item(), iteration) 253 | print('[epoch: {}, iteration: {}] test loss : {:.4f} PESQ : {:.4f} STOI : {:.4f}'.format(epoch, iteration, avg_test_loss, test_PESQ, test_STOI)) 254 | if avg_test_loss < best_loss: 255 | best_PESQ = test_PESQ 256 | best_STOI = test_STOI 257 | best_loss = avg_test_loss 258 | # Note: optimizer also has states ! don't forget to save them as well. 259 | ckpt = {'model':net.state_dict(), 260 | 'optimizer':optimizer.state_dict(), 261 | 'best_loss':best_loss} 262 | torch.save(ckpt, ckpt_path) 263 | print('checkpoint is saved !') 264 | 265 | if __name__ == '__main__': 266 | main() 267 | -------------------------------------------------------------------------------- /train_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import shutil 5 | 6 | import torch 7 | 8 | 9 | class Params(): 10 | """Class that loads hyperparameters from a json file. 
11 | Example: 12 | ``` 13 | params = Params(json_path) 14 | print(params.learning_rate) 15 | params.learning_rate = 0.5 # change the value of learning_rate in params 16 | ``` 17 | """ 18 | 19 | def __init__(self, json_path): 20 | with open(json_path) as f: 21 | params = json.load(f) 22 | self.__dict__.update(params) 23 | 24 | def save(self, json_path): 25 | with open(json_path, 'w') as f: 26 | json.dump(self.__dict__, f, indent=4) 27 | 28 | def update(self, json_path): 29 | """Loads parameters from json file""" 30 | with open(json_path) as f: 31 | params = json.load(f) 32 | self.__dict__.update(params) 33 | 34 | @property 35 | def dict(self): 36 | """Gives dict-like access to Params instance by `params.dict['learning_rate']""" 37 | return self.__dict__ 38 | 39 | 40 | class RunningAverage(): 41 | """A simple class that maintains the running average of a quantity 42 | Example: 43 | ``` 44 | loss_avg = RunningAverage() 45 | loss_avg.update(2) 46 | loss_avg.update(4) 47 | loss_avg() = 3 48 | ``` 49 | """ 50 | 51 | def __init__(self): 52 | self.steps = 0 53 | self.total = 0 54 | 55 | def update(self, val): 56 | self.total += val 57 | self.steps += 1 58 | 59 | def __call__(self): 60 | return self.total / float(self.steps) 61 | 62 | 63 | def set_logger(log_path): 64 | """Set the logger to log info in terminal and file `log_path`. 65 | In general, it is useful to have a logger so that every output to the terminal is saved 66 | in a permanent file. Here we save it to `model_dir/train.log`. 67 | Example: 68 | ``` 69 | logging.info("Starting training...") 70 | ``` 71 | Args: 72 | log_path: (string) where to log 73 | """ 74 | logger = logging.getLogger() 75 | logger.setLevel(logging.INFO) 76 | 77 | if not logger.handlers: 78 | # Logging to a file 79 | file_handler = logging.FileHandler(log_path) 80 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 81 | logger.addHandler(file_handler) 82 | 83 | # Logging to console 84 | stream_handler = logging.StreamHandler() 85 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 86 | logger.addHandler(stream_handler) 87 | 88 | 89 | def save_dict_to_json(d, json_path): 90 | """Saves dict of floats in json file 91 | Args: 92 | d: (dict) of float-castable values (np.float, int, float, etc.) 93 | json_path: (string) path to json file 94 | """ 95 | with open(json_path, 'w') as f: 96 | # We need to convert the values to float for json (it doesn't accept np.array, np.float, ) 97 | d = {k: float(v) for k, v in d.items()} 98 | json.dump(d, f, indent=4) 99 | 100 | 101 | def save_checkpoint(state, is_best, checkpoint): 102 | """Saves model and training parameters at checkpoint + 'last.pth.tar'. If is_best==True, also saves 103 | checkpoint + 'best.pth.tar' 104 | Args: 105 | state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict 106 | is_best: (bool) True if it is the best model seen till now 107 | checkpoint: (string) folder where parameters are to be saved 108 | """ 109 | filepath = os.path.join(checkpoint, 'last.pth.tar') 110 | if not os.path.exists(checkpoint): 111 | print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint)) 112 | os.mkdir(checkpoint) 113 | else: 114 | print("Checkpoint Directory exists! 
") 115 | torch.save(state, filepath) 116 | if is_best: 117 | shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar')) 118 | 119 | 120 | def load_checkpoint(checkpoint, model, optimizer=None): 121 | """Loads model parameters (state_dict) from file_path. If optimizer is provided, loads state_dict of 122 | optimizer assuming it is present in checkpoint. 123 | Args: 124 | checkpoint: (string) filename which needs to be loaded 125 | model: (torch.nn.Module) model for which the parameters are loaded 126 | optimizer: (torch.optim) optional: resume optimizer from checkpoint 127 | """ 128 | if not os.path.exists(checkpoint): 129 | raise ("File doesn't exist {}".format(checkpoint)) 130 | checkpoint = torch.load(checkpoint) 131 | model.load_state_dict(checkpoint['state_dict']) 132 | 133 | if optimizer: 134 | optimizer.load_state_dict(checkpoint['optim_dict']) 135 | 136 | return checkpoint -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import numpy 5 | import matplotlib.pyplot as plt 6 | import json 7 | import argparse 8 | import librosa 9 | import numpy as np 10 | from PIL import Image 11 | from matplotlib import cm 12 | 13 | def plot_head_map(mma, num = 5): 14 | 15 | space = len(mma)/num 16 | fig, ax = plt.subplots() 17 | heatmap = ax.pcolor(mma, cmap=plt.cm.jet) 18 | 19 | # put the major ticks at the middle of each cell 20 | ax.set_xticks(numpy.arange(mma.shape[1])*space + 0.5, minor=False) 21 | ax.set_yticks(numpy.arange(mma.shape[0])*space + 0.5, minor=False) 22 | 23 | # without this I get some extra columns rows 24 | ax.set_xlim(0, int(mma.shape[1])) 25 | ax.set_ylim(0, int(mma.shape[0])) 26 | 27 | # want a more natural, table-like display 28 | #ax.invert_yaxis() 29 | #ax.xaxis.tick_top() 30 | 31 | # source words -> column labels 32 | ax.set_xticklabels(numpy.arange(mma.shape[1])*space, minor=False) 33 | # target words -> row labels 34 | ax.set_yticklabels(numpy.arange(mma.shape[0])*space, minor=False) 35 | 36 | plt.xticks(rotation=45) 37 | 38 | plt.colorbar(heatmap, ax=ax) 39 | # plt.tight_layout() 40 | plt.show() 41 | #plt.savefig('result.png') 42 | 43 | 44 | #ntt = numpy.array([[0.01,0.02,0.03,0.04,0.05,0.06,0.07,0.08,0.09,0.1],[0.11,0.12,0.13,0.14,0.15,0.16,0.17,0.18,0.19,0.2],[0.21,0.22,0.23,0.24,0.25,0.26,0.27,0.28,0.29,0.3],[0.31,0.32,0.33,0.34,0.35,0.36,0.37,0.38,0.39,0.4],[0.41,0.42,0.43,0.44,0.45,0.46,0.47,0.48,0.49,0.5],[0.51,0.52,0.53,0.54,0.55,0.56,0.57,0.58,0.59,0.6],[0.61,0.62,0.63,0.64,0.65,0.66,0.67,0.68,0.69,0.7],[0.71,0.72,0.73,0.74,0.75,0.76,0.77,0.78,0.79,0.8],[0.81,0.82,0.83,0.84,0.85,0.86,0.87,0.88,0.89,0.9],[0.91,0.92,0.93,0.94,0.95,0.96,0.97,0.98,0.99,1]]) 45 | #plot_head_map(ntt) 46 | 47 | 48 | def get_spectrogram(base): 49 | S = librosa.amplitude_to_db(librosa.core.magphase(librosa.stft(base, hop_length=128, win_length=512, n_fft=512))[0], ref=np.max) 50 | S = prepare_spec_image(S) 51 | return S 52 | 53 | def prepare_spec_image(spectrogram): 54 | spectrogram = (spectrogram - np.min(spectrogram)) / ((np.max(spectrogram)) - np.min(spectrogram)) 55 | spectrogram = np.flip(spectrogram, axis=0) 56 | return np.uint8(cm.gist_heat(spectrogram) * 255) 57 | 58 | def read_raw(input_file_dir): 59 | data = np.fromfile(input_file_dir, dtype=np.int16) # (# total frame, feature_size) 60 | data = np.float32(data) / 32767. 
61 | data = np.squeeze(data) 62 | return data 63 | 64 | def make_spectrogram_array(file_name): 65 | y = read_raw(file_name) 66 | y_spec = get_spectrogram(y).transpose(2,0,1)[0:3].transpose(1,2,0) 67 | return y_spec 68 | 69 | def save_spectrogram(spec, file_name): 70 | img = Image.fromarray(spec, 'RGB') 71 | img.save('spec_'+file_name+'.png') 72 | 73 | #input = "clean4.wav" 74 | #spt_array = make_spectrogram_array(input) 75 | #save_spectrogram(spt_array, input) --------------------------------------------------------------------------------
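A note on evaluation: the PESQ and STOI numbers printed by the scripts come from the pypesq and pystoi toolboxes listed in the README. A minimal, self-contained sketch of scoring an enhanced file against its clean reference offline (assuming 16 kHz mono wavs of matching length, as the scripts produce):

```python
import librosa
from pystoi.stoi import stoi
from pypesq import pesq

# Load the reference and the enhanced output at their native rate (16 kHz here).
clean, _ = librosa.load('clean.wav', sr=None)
enhanced, _ = librosa.load('out.wav', sr=None)

print('PESQ : %.4f' % pesq(clean, enhanced, 16000))
print('STOI : %.4f' % stoi(clean, enhanced, 16000, extended=False))
```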