├── sdr.py
├── README.md
├── generate_wav.py
├── .gitignore
├── test.py
├── discriminator.py
├── embedding.py
├── resnet.py
├── data_io.py
├── aecnn.py
├── plot_wav.py
├── train.py
└── trainer.py

/sdr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | def remove_dc(signal):
4 |     """Normalize to zero mean"""
5 |     mean = np.mean(signal)
6 |     signal -= mean
7 |     return signal
8 | 
9 | 
10 | def pow_np_norm(signal):
11 |     """Compute the squared L2 norm"""
12 |     return np.square(np.linalg.norm(signal, ord=2))
13 | 
14 | 
15 | def pow_norm(s1, s2):
16 |     return np.sum(s1 * s2)
17 | 
18 | 
19 | def si_sdr(estimated, original):
20 |     # estimated = remove_dc(estimated)
21 |     # original = remove_dc(original)
22 |     target = pow_norm(estimated, original) * original / pow_np_norm(original)
23 |     noise = estimated - target
24 |     return 10 * np.log10(pow_np_norm(target) / pow_np_norm(noise))
25 | 
26 | 
27 | def permute_si_sdr(e1, e2, c1, c2):
28 |     sdr1 = si_sdr(e1, c1) + si_sdr(e2, c2)
29 |     sdr2 = si_sdr(e1, c2) + si_sdr(e2, c1)
30 |     if sdr1 > sdr2:
31 |         return sdr1 * 0.5
32 |     else:
33 |         return sdr2 * 0.5
34 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Speech Enhancement with Mimic Loss
2 | 
3 | This project seeks to bring together the surprisingly separate worlds of
4 | speech enhancement and noise-robust ASR by applying phonetic knowledge to
5 | improve a front-end enhancement module. This can improve both intelligibility
6 | metrics (STOI and eSTOI) and the accuracy of ASR performed on the outputs of
7 | this enhancement system.
8 | 
9 | The backbone of this project is work by Pandey et al. [1], which performs
10 | denoising in the time domain, but generates a loss in the spectral domain
11 | and backpropagates through the STFT to improve the denoising model. We
12 | apply mimic loss [2] in the spectral domain and backpropagate to the
13 | time domain.
14 | 
15 | [1] Ashutosh Pandey and DeLiang Wang, "A new framework for supervised
16 | speech enhancement in the time domain," Interspeech, 2018.
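The central idea, as implemented in `trainer.py` and `data_io.py`, is that the time-domain output of the enhancement network is converted to a magnitude spectrogram with a differentiable STFT, the loss is computed in that spectral domain, and the gradient flows back through the STFT into the time-domain model. Below is a minimal, illustrative sketch of that pipeline; it reuses the repository's STFT settings (512-point FFT, 160-sample hop, 400-sample Hann window) but the modern complex `torch.stft` API rather than the exact code in `data_io.mag`, and the `generator`/`noisy`/`clean` names are placeholders:

```python
import torch
import torch.nn.functional as F

def magnitude(wav, n_fft=512, hop=160, win=400):
    """Differentiable magnitude spectrogram (same STFT settings as data_io.mag)."""
    window = torch.hann_window(win, device=wav.device)
    spec = torch.stft(wav, n_fft=n_fft, hop_length=hop, win_length=win,
                      window=window, return_complex=True)
    return spec.abs()

def spectral_l1_loss(enhanced_wav, clean_wav):
    """L1 loss between magnitude spectrograms; gradients reach the waveform model."""
    return F.l1_loss(magnitude(enhanced_wav), magnitude(clean_wav))

# enhanced = generator(noisy).squeeze(1)      # (batch, time) waveform from AECNN
# loss = spectral_l1_loss(enhanced, clean.squeeze(1))
# loss.backward()                             # gradients flow back through the STFT
```

In the actual trainer, the spectral-domain target is either the clean magnitude (spectral-mapping loss) or the activations and senone posteriors of a frozen acoustic model run on the clean magnitude (mimic and texture losses).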
17 | 18 | [2] Deblin Bagchi, Peter Plantinga, Adam Stiff, and Eric Fosler-Lussier, 19 | "Spectral feature mapping with mimic loss for robust ASR," ICASSP, 2018 20 | -------------------------------------------------------------------------------- /generate_wav.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | from aecnn import AECNN 5 | from data_io import wav_dataset 6 | from train import parse_args 7 | import soundfile as sf 8 | import numpy as np 9 | 10 | def run_test(config): 11 | """ Define our model and test it """ 12 | 13 | generator = AECNN( 14 | channel_counts = config.gchan, 15 | kernel_size = config.gkernel, 16 | block_size = config.gblocksize, 17 | dropout = config.gdrop, 18 | ).cuda().eval() 19 | 20 | generator.load_state_dict(torch.load(config.gcheckpoints)) 21 | 22 | # Initialize datasets 23 | #for phase in ['tr', 'dt', 'et']: 24 | 25 | max_ch = 6 if config.phase == 'tr' else 1 26 | 27 | count = 0 28 | for ch in range(max_ch): 29 | dataset = wav_dataset(config, config.phase, ch) 30 | 31 | with torch.no_grad(): 32 | for example in dataset: 33 | data = np.squeeze(generator(example['noisy'].cuda()).cpu().detach().numpy()) 34 | fname = make_filename(config, ch, example['id']) 35 | with sf.SoundFile(fname, 'w', 16000, 1) as w: 36 | w.write(data) 37 | 38 | if count % 1000 == 0: 39 | print("finished #%d" % count) 40 | count += 1 41 | 42 | def make_filename(config, channel, id): 43 | args = [config.output_dir, config.phase, id + '.wav'] 44 | if config.phase == 'tr': 45 | args[-1] = id + '.ch%d' % channel + '.wav' 46 | return os.path.join(*args) 47 | 48 | def main(): 49 | config = parse_args() 50 | run_test(config) 51 | 52 | if __name__=='__main__': 53 | main() 54 | 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # vim 107 | *.sw* 108 | 109 | # Scripts 110 | scripts/ 111 | *.o[1-9][0-9]* 112 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from aecnn import AECNN 4 | from data_io import wav_dataset 5 | from train import parse_args 6 | from pystoi.stoi import stoi 7 | import numpy as np 8 | from sdr import si_sdr 9 | import os 10 | import soundfile as sf 11 | 12 | def run_test(config): 13 | """ Define our model and test it """ 14 | 15 | generator = AECNN( 16 | channel_counts = config.gchan, 17 | kernel_size = config.gkernel, 18 | block_size = config.gblocksize, 19 | dropout = config.gdrop, 20 | ).cuda() 21 | 22 | generator.load_state_dict(torch.load(config.gcheckpoints)) 23 | 24 | # Initialize datasets 25 | ev_dataset = wav_dataset(config, 'et', 4) 26 | 27 | 28 | count = 0 29 | score = {'stoi': 0, 'estoi':0, 'sdr':0} 30 | for example in ev_dataset: 31 | data = np.squeeze(generator(example['noisy'].cuda()).cpu().detach().numpy()) 32 | clean = np.squeeze(example['clean'].numpy()) 33 | noisy = np.squeeze(example['noisy'].numpy()) 34 | score['stoi'] += stoi(clean, data, 16000, extended=False) 35 | score['estoi'] += stoi(clean, data, 16000, extended=True) 36 | score['sdr'] += si_sdr(data, clean) 37 | count += 1 38 | #if count == 1: 39 | # with sf.SoundFile('clean.wav', 'w', 16000, 1) as w: 40 | # w.write(clean) 41 | # with sf.SoundFile('noisy.wav', 'w', 16000, 1) as w: 42 | # w.write(noisy) 43 | # with sf.SoundFile('test.wav', 'w', 16000, 1) as w: 44 | # w.write(data) 45 | # break 46 | 47 | print('stoi: %f' % (score['stoi'] / count)) 48 | print('estoi: %f' % (score['estoi'] / count)) 49 | print('sdr: %f' % (score['sdr'] / count)) 50 | 51 | 52 | def main(): 53 | config = parse_args() 54 | run_test(config) 55 | 56 | if __name__=='__main__': 57 | main() 58 | 59 | -------------------------------------------------------------------------------- /discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Discriminator(nn.Module): 5 | 6 | def __init__( 7 | self, 8 | channel_counts = [64, 128, 256], 9 | kernel_size = 11, 10 | block_size = 3, 11 | activation = lambda x: nn.functional.leaky_relu(x, negative_slope = 0.3), 12 | fc_layers = 2, 13 | 
fc_nodes = 1024, 14 | dropout = 0.2, 15 | training = True, 16 | ): 17 | super(Discriminator, self).__init__() 18 | 19 | # Store hyperparameters 20 | self.kernel_size = kernel_size 21 | self.block_size = block_size 22 | self.dropout = dropout 23 | self.activation = activation 24 | 25 | # Initialize all layer containers 26 | self.downsample_layers = nn.ModuleList() 27 | self.fc_layers = nn.ModuleList() 28 | 29 | in_channels = 1 30 | for out_channels in channel_counts: 31 | for i in range(block_size): 32 | self.downsample_layers.append(self.conv_layer(in_channels, out_channels, downsample = in_channels != 1)) 33 | in_channels = out_channels 34 | 35 | for layer in range(fc_layers): 36 | self.fc_layers.append(nn.Linear(in_channels, fc_nodes)) 37 | in_channels = fc_nodes 38 | 39 | self.fc_layers.append(nn.Linear(in_channels, 1)) 40 | 41 | 42 | # Define a single convolutional layer 43 | def conv_layer(self, in_channels, out_channels, downsample = False): 44 | 45 | return nn.Conv1d( 46 | in_channels = in_channels, 47 | out_channels = out_channels, 48 | kernel_size = self.kernel_size, 49 | stride = 2 if downsample else 1, 50 | padding = self.kernel_size // 2, 51 | ) 52 | 53 | # Define the forward pass of our model, unet-style 54 | def forward(self, x): 55 | 56 | # Apply downsampling layers 57 | for i, layer in enumerate(self.downsample_layers): 58 | x = self.activation(layer(x)) 59 | 60 | # dropout once each block 61 | if i % self.block_size == 0: 62 | x = nn.functional.dropout(x, p = self.dropout) 63 | 64 | x = x.transpose(1, 2) 65 | 66 | for layer in self.fc_layers: 67 | x = self.activation(layer(x)) 68 | 69 | # Reduce to 1 channel and scale from -1 to 1 70 | return x 71 | -------------------------------------------------------------------------------- /embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data_io import mag 3 | 4 | class Embedding(torch.nn.Module): 5 | 6 | def __init__( 7 | self, 8 | embedding_size, 9 | ): 10 | super(Embedding, self).__init__() 11 | 12 | self.embedding_size = embedding_size 13 | 14 | weight = torch.zeros((embedding_size, embedding_size)) 15 | self.embedding = torch.nn.Embedding.from_pretrained(weight) 16 | 17 | def pretrain(self, training_set, mimic_model, device): 18 | """ This is a hack of embeddings and autograd to allow soft 19 | 'prototypical' posterior distributions for each senone. 
""" 20 | 21 | # Enable grad, so we can hack it to update our senone "embedding" 22 | self.embedding.weight.requires_grad = True 23 | 24 | # Initialize counts to a small value, so we don't divide by zero 25 | senone_counts = torch.zeros([self.embedding_size, 1]).to(device) + 0.01 26 | 27 | # Go through dataset, and add up posteriors and counts 28 | for example in training_set: 29 | 30 | # Generate posterior 31 | clean_mag = mag(example['clean'].to(device), truncate=True) 32 | senones = example['senone'].to(device) 33 | 34 | posteriors = mimic_model(clean_mag)[-1] 35 | posteriors = posteriors[:,:,:senones.shape[1]].transpose(1, 2) 36 | 37 | # Embed senone, so we can update the result 38 | embedded = self.embedding(senones) 39 | 40 | # Multiply posteriors so that we can add to the gradient 41 | embedded *= posteriors 42 | 43 | # Propagate gradient to the embedding 44 | embedded.sum().backward() 45 | 46 | # Count instances of senones 47 | example_senone_counts = senones[0].bincount(minlength = self.embedding_size).float().unsqueeze(1) 48 | senone_counts += example_senone_counts 49 | 50 | # Divide and update 51 | with torch.no_grad(): 52 | self.embedding.weight *= (senone_counts - example_senone_counts) / senone_counts 53 | self.embedding.weight += self.embedding.weight.grad / senone_counts 54 | self.embedding.weight.grad.zero_() 55 | 56 | # Turn off grad again 57 | self.embedding.weight.requires_grad = False 58 | 59 | def forward(self, x): 60 | return self.embedding(x) 61 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ResNet(nn.Module): 5 | 6 | def __init__( 7 | self, 8 | input_dim, 9 | output_dim, 10 | channel_counts = [64, 128, 256, 512], 11 | dense_count = 2, 12 | dense_nodes = 1024, 13 | activation = lambda x: nn.functional.leaky_relu(x, negative_slope = 0.3), 14 | dropout = 0.3, 15 | training = True, 16 | ): 17 | super(ResNet, self).__init__() 18 | 19 | # Store hyperparameters 20 | self.activation = activation 21 | self.dropout = dropout 22 | self.training = training 23 | self.output_dim = output_dim 24 | 25 | # Define conv layers 26 | in_channels = 1 27 | self.conv_layers = nn.ModuleList() 28 | for out_channels in channel_counts: 29 | self.conv_layers.extend(conv_block(in_channels, out_channels)) 30 | in_channels = out_channels 31 | 32 | # Define dense layers 33 | self.dense_layers = nn.ModuleList() 34 | self.in_features = out_channels * input_dim // 4 ** len(channel_counts) 35 | for i in range(dense_count): 36 | self.dense_layers.append(nn.Linear(self.in_features, dense_nodes)) 37 | 38 | if output_dim is not None: 39 | self.output_layer = nn.Linear(dense_nodes, output_dim) 40 | 41 | # Define forward pass 42 | def forward(self, x): 43 | 44 | outputs = [] 45 | 46 | # Convolutional part 47 | for i, layer in enumerate(self.conv_layers): 48 | 49 | x = layer(x) 50 | 51 | # Record shortcut 52 | if i % 3 == 0: 53 | downsampled = x 54 | 55 | x = self.activation(x) 56 | x = nn.functional.dropout2d(x, p = self.dropout, training = self.training) 57 | 58 | # Re-add shortcut 59 | if i % 3 == 2: 60 | x += downsampled 61 | outputs.append(x) 62 | 63 | # Smush last two dimensions 64 | if len(self.dense_layers) > 0: 65 | x = x.permute(0, 3, 1, 2) 66 | x = x.contiguous().view(1, -1, self.in_features) 67 | x = nn.functional.dropout(x, p = self.dropout, training = self.training) 68 | 69 | # Fully conntected part 70 | for 
layer in self.dense_layers: 71 | x = self.activation(layer(x)) 72 | 73 | if self.output_dim is not None: 74 | x = self.output_layer(x) 75 | x = x.permute(0, 2, 1) 76 | 77 | outputs.append(x) 78 | 79 | return outputs 80 | 81 | def conv_layer(in_channels, out_channels, downsample = False): 82 | return nn.Conv2d( 83 | in_channels = in_channels, 84 | out_channels = out_channels, 85 | kernel_size = 3, 86 | stride = 2 if downsample else 1, 87 | padding = 1, 88 | ) 89 | 90 | def conv_block(in_channels, out_channels): 91 | return [ 92 | conv_layer(in_channels, out_channels, downsample = True), 93 | conv_layer(out_channels, out_channels), 94 | conv_layer(out_channels, out_channels), 95 | ] 96 | 97 | -------------------------------------------------------------------------------- /data_io.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import json 4 | import torch 5 | import numpy as np 6 | import soundfile as sf 7 | import sys 8 | import copy 9 | 10 | class wav_dataset(Dataset): 11 | 12 | def __init__(self, config, phase, ch = None): 13 | 14 | key_list = [] 15 | self.flists = {} 16 | for ftype, flist in [('clean', config.clean_flist), ('noise', config.noise_flist), ('noisy', config.noisy_flist)]: 17 | if flist: 18 | with open(os.path.join(config.base_dir, phase, flist)) as f: 19 | self.flists[ftype] = json.load(f) 20 | key_list = self.flists[ftype].keys() 21 | ch_count = len(self.flists[ftype][list(key_list)[0]]) 22 | 23 | if config.senone_file: 24 | self.flists['senone'] = {} 25 | with open(os.path.join(config.base_dir, phase, config.senone_file)) as f: 26 | for line in f: 27 | line = line.split() 28 | self.flists['senone'][line[0]] = np.array([int(i) for i in line[1:]], np.int64) 29 | key_list = self.flists['senone'].keys() 30 | 31 | self.flist = [] 32 | for key in key_list: 33 | list_item = {'id': key} 34 | index = np.random.randint(ch_count) if ch is None else ch 35 | for ftype in self.flists: 36 | if ftype == 'senone': 37 | list_item[ftype] = self.flists[ftype][key] 38 | else: 39 | list_item[ftype] = self.flists[ftype][key][index] 40 | self.flist.append(list_item) 41 | 42 | self.base_dir = config.base_dir 43 | 44 | def __len__(self): 45 | return len(self.flist) 46 | 47 | def __getitem__(self, idx): 48 | 49 | data = copy.deepcopy(self.flist[idx]) 50 | 51 | for ftype in ['clean', 'noisy', 'noise']: 52 | if ftype in data: 53 | wav = self._load_wav(data[ftype]) 54 | newlen = len(wav) - len(wav) % 1024 55 | data[ftype] = torch.from_numpy(wav[np.newaxis, np.newaxis, :newlen]) 56 | if ftype == 'noise': 57 | data['noisy'] = data['noise'] + data['clean'] 58 | del data['noise'] 59 | 60 | if 'senone' in data: 61 | target = torch.from_numpy(data['senone']) 62 | newlen = len(target) - len(target) % 16 63 | data['senone'] = target[np.newaxis, :newlen] 64 | 65 | return data 66 | 67 | def _load_wav(self, fname): 68 | data, sr = sf.read(os.path.join(self.base_dir, fname)) 69 | return np.array(data, dtype=np.float32) 70 | 71 | def mag(tensor, truncate = False, log = False): 72 | 73 | spectrogram = torch.stft( 74 | torch.squeeze(tensor), 75 | n_fft = 512, 76 | hop_length = 160, 77 | win_length = 400, 78 | window = torch.hann_window(400), 79 | ) 80 | 81 | real = spectrogram[:, :, 0] 82 | imag = spectrogram[:, :, 1] 83 | 84 | magnitude = torch.sqrt(real * real + imag * imag + 1e-8) 85 | 86 | if truncate: 87 | magnitude = magnitude[None, None, :-1] 88 | 89 | if log: 90 | return torch.log10(magnitude + 0.1) 91 | 92 | 
return magnitude 93 | -------------------------------------------------------------------------------- /aecnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class AECNN(nn.Module): 5 | 6 | def __init__( 7 | self, 8 | channel_counts = [64, 128, 256], 9 | kernel_size = 11, 10 | block_size = 3, 11 | activation = lambda x: nn.functional.leaky_relu(x, negative_slope = 0.3), 12 | dropout = 0.2, 13 | training = True, 14 | ): 15 | super(AECNN, self).__init__() 16 | 17 | # Store hyperparameters 18 | self.kernel_size = kernel_size 19 | self.block_size = block_size 20 | self.dropout = dropout 21 | self.activation = activation 22 | 23 | # Initialize all layer containers 24 | self.encoder_layers = nn.ModuleList() 25 | self.decoder_layers = nn.ModuleList() 26 | 27 | # Encoder uses conv layers with stride to downsample inputs 28 | in_channels = 1 29 | for out_channels in channel_counts: 30 | for i in range(block_size): 31 | self.encoder_layers.append(self.conv_layer(in_channels, out_channels, downsample = in_channels != 1)) 32 | in_channels = out_channels 33 | 34 | # Decoder layers get concatenated with corresponding encoder layers (unet-style) 35 | in_channels = None 36 | for out_channels in reversed(channel_counts): 37 | for i in range(block_size): 38 | if in_channels is None: 39 | in_channels = out_channels 40 | else: 41 | self.decoder_layers.append(self.conv_layer(in_channels, out_channels)) 42 | in_channels = out_channels * 2 43 | 44 | # Final layer doesn't change size, just filters 45 | self.decrease_channels = self.conv_layer(in_channels, out_channels = 1) 46 | 47 | 48 | # Define a single convolutional layer 49 | def conv_layer(self, in_channels, out_channels, downsample = False): 50 | 51 | return nn.Conv1d( 52 | in_channels = in_channels, 53 | out_channels = out_channels, 54 | kernel_size = self.kernel_size, 55 | stride = 2 if downsample else 1, 56 | padding = self.kernel_size // 2, 57 | ) 58 | 59 | # Define the forward pass of our model, unet-style 60 | def forward(self, x): 61 | 62 | # Apply encoder downsampling layers 63 | encoder_outputs = [x] 64 | for i, layer in enumerate(self.encoder_layers): 65 | encoder_outputs.append(self.activation(layer(encoder_outputs[-1]))) 66 | 67 | # dropout once each block 68 | if i % self.block_size == 0: 69 | encoder_outputs[-1] = nn.functional.dropout(encoder_outputs[-1], p = self.dropout) 70 | 71 | # Apply upsampling and decoder layers 72 | decoder_inputs = encoder_outputs[-1] 73 | for i, layer in enumerate(self.decoder_layers): 74 | decoder_inputs = nn.functional.interpolate(decoder_inputs, scale_factor = 2) 75 | decoder_output = self.activation(layer(decoder_inputs)) 76 | 77 | # Concatenate layer with corresponding encoder layer 78 | decoder_inputs = torch.cat((encoder_outputs[-i - 2], decoder_output), dim = 1) 79 | 80 | # Dropout once each block 81 | if i % self.block_size == 0: 82 | decoder_inputs = nn.functional.dropout(decoder_inputs, p = self.dropout) 83 | 84 | # Reduce to 1 channel and scale from -1 to 1 85 | return torch.tanh(self.decrease_channels(decoder_inputs)) 86 | -------------------------------------------------------------------------------- /plot_wav.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from aecnn import AECNN 4 | from data_io import wav_dataset, mag 5 | from train import parse_args 6 | import numpy as np 7 | import os 8 | import soundfile as sf 9 | import matplotlib.pyplot as plt 
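# (the fbank import below is only used by the commented-out spectrogram code at the end of this file)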
10 | from python_speech_features import fbank 11 | 12 | 13 | def print_spec(array, filename, xAxisRange=None, axes='on'): 14 | """ Print a spectrogram to a file """ 15 | 16 | if xAxisRange: 17 | array = np.flipud(array.T)[:-3,xAxisRange[0]:xAxisRange[1]] 18 | extent = [xAxisRange[0] / 100., xAxisRange[1] / 100., 0, 8] 19 | else: 20 | array = np.flipud(array.T) 21 | extent = [0, array.shape[1] / 100., 0, 8] 22 | 23 | fig = plt.figure() 24 | ax = fig.add_subplot(111) 25 | ax.imshow(array,cmap=plt.cm.jet, interpolation='none', extent=extent, aspect=1./14) 26 | 27 | if axes == 'on': 28 | ax.set_xlabel("Time (s)") 29 | ax.set_ylabel("Frequency (kHz)") 30 | fig.savefig(filename, format='pdf', bbox_inches='tight') 31 | else: 32 | ax.axis('off') 33 | fig.savefig(filename, format='pdf', bbox_inches=0) 34 | 35 | plt.close(fig) 36 | 37 | 38 | def energy(data, window = 200): 39 | 40 | e = np.zeros_like(data) 41 | for i in range(len(data)-2*window): 42 | i += window 43 | e[i-window:i+window] += np.sum(data[i-window:i+window] ** 2) / window 44 | 45 | cap = 0.2 46 | e[e > cap] = cap 47 | e[e < cap] = 0 48 | 49 | return e 50 | 51 | def zero_crossings(data, window = 200): 52 | 53 | z = np.zeros_like(data) 54 | for i in range(len(data)-2*window-1): 55 | i += window 56 | crossed = data[i-window:i+window] * data[i-window+1:i+window+1] 57 | crossed[crossed > 0] = 0 58 | crossed[crossed < 0] = 0.3 59 | z[i-window:i+window] += np.sum(crossed) / window / window 60 | 61 | cap = 0.2 62 | z[z > cap] = cap 63 | z[z < cap] = 0 64 | 65 | return -z 66 | 67 | def print_wav(data, fname, sr = 16000.): 68 | 69 | 70 | e = energy(data) 71 | z = zero_crossings(data) 72 | 73 | fig = plt.figure() 74 | ax = fig.add_subplot(111) 75 | xpoints = np.arange(len(data)) / sr 76 | ax.plot(xpoints, data, linewidth=0.5) 77 | ax.plot(xpoints, e, linewidth=0.5) 78 | ax.plot(xpoints, z, linewidth=0.5) 79 | 80 | fig.savefig(fname, format='pdf', bbox_inches='tight') 81 | 82 | plt.close(fig) 83 | 84 | def run_test(config): 85 | """ Define our model and test it """ 86 | 87 | generator = AECNN( 88 | channel_counts = config.gchan, 89 | kernel_size = config.gkernel, 90 | block_size = config.gblocksize, 91 | dropout = config.gdrop, 92 | ).cuda() 93 | 94 | generator.load_state_dict(torch.load(config.gcheckpoints)) 95 | 96 | # Initialize datasets 97 | #ev_dataset = wav_dataset(config, 'et', 4) 98 | ev_dataset = wav_dataset(config, 'et') 99 | 100 | 101 | #count = 0 102 | #score = {'stoi': 0, 'estoi':0, 'sdr':0} 103 | example = ev_dataset[361] 104 | print(example['id']) 105 | data = np.squeeze(generator(example['noisy'].cuda()).cpu().detach().numpy()) 106 | #clean = np.squeeze(example['clean'].numpy()) 107 | noisy = np.squeeze(example['noisy'].numpy()) 108 | #with sf.SoundFile('clean.wav', 'w', 16000, 1) as w: 109 | # w.write(clean) 110 | with sf.SoundFile('noisy.wav', 'w', 16000, 1) as w: 111 | w.write(noisy) 112 | with sf.SoundFile('test.wav', 'w', 16000, 1) as w: 113 | w.write(data) 114 | 115 | #print_wav(noisy, 'noisy_waveform.pdf') 116 | #print_wav(clean, 'clean_waveform.pdf') 117 | #print_wav(data, 'waveform.pdf') 118 | #data = np.squeeze(generator(example['noisy']).detach().numpy()) 119 | #clean = np.squeeze(example['clean'].numpy()) 120 | #noisy = np.squeeze(example['noisy'].numpy()) 121 | 122 | 123 | #data, _ = fbank(data,nfilt=80) 124 | #clean, _ = fbank(clean,nfilt=80) 125 | #noisy, _ = fbank(noisy,nfilt=80) 126 | #data, clean, noisy = np.log(data), np.log(clean), np.log(noisy) 127 | #minimum = min(np.min(data), np.min(clean), np.min(noisy)) 
128 | #data, clean, noisy = data - minimum, clean - minimum, noisy - minimum 129 | #maximum = max(np.max(data), np.max(clean), np.max(noisy)) 130 | #data, clean, noisy = data / maximum, clean / maximum, noisy / maximum 131 | #print_spec(data, 'spectrogram.pdf', xAxisRange=[110,140]) 132 | #print_spec(clean, 'clean_spec.pdf', xAxisRange=[110,140]) 133 | #print_spec(noisy, 'noisy_spec.pdf', xAxisRange=[110,140]) 134 | 135 | 136 | def main(): 137 | config = parse_args() 138 | run_test(config) 139 | 140 | if __name__=='__main__': 141 | main() 142 | 143 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import time 4 | import os 5 | 6 | from aecnn import AECNN 7 | from resnet import ResNet 8 | from discriminator import Discriminator 9 | from embedding import Embedding 10 | from trainer import Trainer 11 | from data_io import wav_dataset 12 | from torch import autograd 13 | 14 | def run_training(config): 15 | """ Define our model and train it """ 16 | 17 | load_generator = config.gpretrain is not None 18 | train_generator = config.gcheckpoints is not None 19 | 20 | load_mimic = config.mpretrain is not None 21 | train_mimic = config.mcheckpoints is not None 22 | 23 | if torch.cuda.is_available(): 24 | config.device = torch.device('cuda') 25 | else: 26 | config.device = torch.device('cpu') 27 | 28 | models = {} 29 | 30 | # Build enhancement model 31 | if load_generator or train_generator: 32 | models['generator'] = AECNN( 33 | channel_counts = config.gchan, 34 | kernel_size = config.gkernel, 35 | block_size = config.gblocksize, 36 | dropout = config.gdrop, 37 | training = train_generator, 38 | ).to(config.device) 39 | 40 | models['generator'].requires_grad = train_generator 41 | 42 | if load_generator: 43 | models['generator'].load_state_dict(torch.load(config.gpretrain, map_location=config.device)) 44 | 45 | # Build acoustic model 46 | if load_mimic or train_mimic: 47 | 48 | if config.mact == 'rrelu': 49 | activation = lambda x: torch.nn.functional.rrelu(x, training = train_mimic) 50 | else: 51 | activation = lambda x: torch.nn.functional.leaky_relu(x, negative_slope = 0.3) 52 | 53 | models['mimic'] = ResNet( 54 | input_dim = 256, 55 | output_dim = config.moutdim, 56 | channel_counts = config.mchan, 57 | dropout = config.mdrop, 58 | training = train_mimic, 59 | activation = activation, 60 | ).to(config.device) 61 | 62 | models['mimic'].requires_grad = train_mimic 63 | 64 | if load_mimic: 65 | models['mimic'].load_state_dict(torch.load(config.mpretrain, map_location=config.device)) 66 | 67 | if config.mimic_weight > 0 or any(config.texture_weights) and train_mimic: 68 | models['teacher'] = ResNet( 69 | input_dim = 256, 70 | output_dim = config.moutdim, 71 | channel_counts = config.mchan, 72 | dropout = 0, 73 | training = False, 74 | ).to(config.device) 75 | 76 | models['teacher'].requires_grad = False 77 | models['teacher'].load_state_dict(torch.load(config.mpretrain, map_location=config.device)) 78 | 79 | if config.gan_weight > 0: 80 | models['discriminator'] = Discriminator( 81 | channel_counts = config.gchan, 82 | kernel_size = config.gkernel, 83 | block_size = config.gblocksize, 84 | dropout = config.gdrop, 85 | training = True, 86 | ).to(config.device) 87 | 88 | # Initialize datasets 89 | tr_dataset = wav_dataset(config, 'tr') 90 | dt_dataset = wav_dataset(config, 'dt', 4) 91 | 92 | if config.soft_senone_weight > 0: 93 | 
print("Pretraining senone embeddings") 94 | models['embedding'] = Embedding(config.moutdim).to(config.device) 95 | models['embedding'].pretrain(tr_dataset, models['mimic'], config.device) 96 | print("Completed embedding pretraining") 97 | 98 | if config.real_senone_file: 99 | real_config = config 100 | real_config.senone_file = config.real_senone_file 101 | real_config.noisy_flist = config.real_flist 102 | real_config.noise_flist = None 103 | real_config.clean_flist = None 104 | tr_real_dataset = wav_dataset(real_config, 'tr') 105 | 106 | trainer = Trainer(config, models) 107 | 108 | # Run the training 109 | best_dev_loss = float('inf') 110 | for epoch in range(config.epochs): 111 | print("Starting epoch %d" % epoch) 112 | 113 | # Train for one epoch 114 | start_time = time.time() 115 | trainer.run_epoch(tr_dataset, training = True) 116 | total_time = time.time() - start_time 117 | 118 | print("Completed epoch %d in %d seconds" % (epoch, int(total_time))) 119 | 120 | dev_loss, dev_losses = trainer.run_epoch(dt_dataset, training = False) 121 | 122 | print("Dev loss: %f" % dev_loss) 123 | for key in dev_losses: 124 | print("%s loss: %f" % (key, dev_losses[key])) 125 | 126 | # Save our model 127 | if dev_loss < best_dev_loss: 128 | best_dev_loss = dev_loss 129 | if train_mimic: 130 | mfile = os.path.join(config.mcheckpoints, config.mfile) 131 | torch.save(models['mimic'].state_dict(), mfile) 132 | if train_generator: 133 | gfile = os.path.join(config.gcheckpoints, config.gfile) 134 | torch.save(models['generator'].state_dict(), gfile) 135 | 136 | def parse_args(): 137 | parser = argparse.ArgumentParser() 138 | 139 | file_args = { 140 | 'base_dir': None, 'clean_flist': None, 'noise_flist': None, 'noisy_flist': None, 'senone_file': None, 141 | 'gpretrain': None, 'gcheckpoints': None, 'mpretrain': None, 'mcheckpoints': None, 142 | 'mfile': 'model.pt', 'gfile': 'model.pt', 'output_dir': None, 'phase': None, 143 | 'real_flist': None, 'real_senone_file': None, 144 | } 145 | train_args = { 146 | 'learn_rate': 2e-4, 'lr_decay': 0.5, 'patience': 1, 'epochs': 25, 'batch_size': 4, 147 | 'l1_weight': 0., 'sm_weight': 0., 'mimic_weight': 0., 'ce_weight': 0., 'real_ce_weight': 0., 148 | 'texture_weights': [0., 0., 0., 0.], 'gan_weight': 0., 'soft_senone_weight': 0., 149 | } 150 | gen_args = { 151 | 'gmodel': 'aecnn', 'gchan': [64, 128, 256], 'gblocksize': 3, 'gdrop': 0.2, 'gkernel': 11, 152 | } 153 | mim_args = { 154 | 'mmodel': 'resnet', 'mchan': [64, 128, 256, 512], 'mdrop': 0.2, 'moutdim': 2023, 'mact': 'lrelu', 155 | } 156 | 157 | for arg_list in [file_args, train_args, gen_args, mim_args]: 158 | for arg, default in arg_list.items(): 159 | if default is None: 160 | parser.add_argument(f"--{arg}") 161 | elif type(default) == list: 162 | parser.add_argument(f"--{arg}", default=default, nargs="+", type=type(default[0])) 163 | else: 164 | parser.add_argument(f"--{arg}", default=default, type=type(default)) 165 | 166 | return parser.parse_args() 167 | 168 | def main(): 169 | config = parse_args() 170 | run_training(config) 171 | 172 | if __name__=='__main__': 173 | main() 174 | 175 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from data_io import mag 4 | from collections import defaultdict 5 | 6 | class Trainer: 7 | def __init__(self, config, models): 8 | 9 | self.config = config 10 | self.models = models 11 | 12 | # 
Initialize optimizer 13 | params = [] 14 | if self.config.gcheckpoints: 15 | params.append({'params': models['generator'].parameters()}) 16 | if self.config.mcheckpoints: 17 | params.append({'params': models['mimic'].parameters()}) 18 | if self.config.gan_weight > 0: 19 | self.optimizerD = torch.optim.Adam(models['discriminator'].parameters(), lr = self.config.learn_rate) 20 | self.optimizerD.zero_grad() 21 | self.schedulerD = torch.optim.lr_scheduler.ReduceLROnPlateau( 22 | self.optimizerD, 23 | patience = self.config.patience, 24 | factor = self.config.lr_decay, 25 | verbose = True, 26 | ) 27 | 28 | self.optimizer = torch.optim.Adam(params, lr = self.config.learn_rate) 29 | self.optimizer.zero_grad() 30 | 31 | # Reduce learning rate if we're not improving dev loss 32 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 33 | self.optimizer, 34 | patience = self.config.patience, 35 | factor = self.config.lr_decay, 36 | verbose = True, 37 | ) 38 | 39 | def run_epoch(self, dataset, training = False, real = False): 40 | 41 | if training: 42 | samples = 0 43 | for sample in dataset: 44 | samples += 1 45 | 46 | outputs = self.forward(sample) 47 | if self.config.gan_weight > 0:# and outputs['d_fake'] > 0.2: 48 | d_loss = self.discriminate_loss(outputs) 49 | d_loss.backward() 50 | 51 | if samples % self.config.batch_size == 0: 52 | self.optimizerD.step() 53 | self.optimizerD.zero_grad() 54 | 55 | outputs = self.forward(sample) 56 | 57 | loss, losses = self.generate_loss(outputs, training, real) 58 | if loss != 0: 59 | loss.backward() 60 | 61 | if samples % self.config.batch_size == 0: 62 | self.optimizer.step() 63 | self.optimizer.zero_grad() 64 | 65 | else: 66 | dev_loss = 0 67 | dev_losses = defaultdict(lambda: 0) 68 | with torch.no_grad(): 69 | for sample in dataset: 70 | outputs = self.forward(sample) 71 | loss, losses = self.generate_loss(outputs, training) 72 | dev_loss += loss / len(dataset) 73 | for key in losses: 74 | dev_losses[key] += losses[key] / len(dataset) 75 | 76 | self.scheduler.step(dev_loss) 77 | if self.config.gan_weight > 0: 78 | self.schedulerD.step(dev_loss) 79 | 80 | return dev_loss, dev_losses 81 | 82 | def forward(self, sample): 83 | 84 | device = self.config.device 85 | 86 | if 'generator' not in self.models: 87 | outputs = normalize(mag(sample['clean'].to(device), truncate = True), sample['senone'].to(device)) 88 | outputs['mimic'] = self.models['mimic'](outputs['clean_mag']) 89 | else: 90 | outputs = { 91 | 'generator': self.models['generator'](sample['noisy'].to(device)), 92 | } 93 | if 'clean' in sample: 94 | outputs['clean_wav'] = sample['clean'].to(device) 95 | 96 | if self.config.sm_weight or 'mimic' in self.models: 97 | outputs['denoised_mag'] = mag(outputs['generator'], truncate = True) 98 | 99 | if 'clean' in sample: 100 | outputs['clean_mag'] = mag(outputs['clean_wav'], truncate = True) 101 | 102 | if 'mimic' in self.models: 103 | outputs['mimic'] = self.models['mimic'](outputs['denoised_mag']) 104 | 105 | mimic_losses = self.config.texture_weights + \ 106 | [self.config.mimic_weight, self.config.soft_senone_weight] 107 | 108 | if 'teacher' in self.models and 'clean' in sample: 109 | outputs['soft_label'] = self.models['teacher'](outputs['clean_mag']) 110 | elif any(mimic_losses) and 'clean' in sample: 111 | outputs['soft_label'] = self.models['mimic'](outputs['clean_mag']) 112 | 113 | 114 | if 'discriminator' in self.models: 115 | outputs['d_real'] = self.models['discriminator'](outputs['clean_wav']) 116 | outputs['d_fake'] = 
self.models['discriminator'](outputs['generator']) 117 | 118 | if 'senone' in sample: 119 | outputs['senone'] = sample['senone'].to(device) 120 | 121 | if self.config.soft_senone_weight: 122 | outputs['embedding'] = self.models['embedding'](outputs['senone']).transpose(1, 2) 123 | 124 | return outputs 125 | 126 | def discriminate_loss(self, outputs): 127 | 128 | #print("Discrim real error: %f" % outputs['d_real'].mean()) 129 | #print("Discrim fake error: %f" % outputs['d_fake'].mean()) 130 | 131 | target_real = torch.ones_like(outputs['d_real']) 132 | loss_real = F.l1_loss(outputs['d_real'], target_real) 133 | 134 | target_fake = torch.zeros_like(outputs['d_fake']) 135 | loss_fake = F.l1_loss(outputs['d_fake'], target_fake) 136 | 137 | return self.config.gan_weight * (loss_real + loss_fake) 138 | 139 | # Compute loss, using weights for each type of loss 140 | def generate_loss(self, outputs, training = False, real = False): 141 | 142 | # Acoustic model training 143 | if 'generator' not in self.models or real: 144 | loss = self.config.ce_weight * truncate_and_ce(outputs['mimic'][-1], outputs['senone']) 145 | losses = {'ce': truncate_and_ce(outputs['mimic'][-1], outputs['senone'])} 146 | 147 | # Enhancement model training 148 | else: 149 | loss = 0 150 | losses = {} 151 | 152 | # Time-domain loss 153 | if self.config.l1_weight > 0: 154 | loss += self.config.l1_weight * F.l1_loss(outputs['generator'], outputs['clean_wav']) 155 | losses['l1'] = F.l1_loss(outputs['generator'], outputs['clean_wav']).detach() 156 | 157 | # Spectral mapping loss 158 | if self.config.sm_weight > 0: 159 | loss += self.config.sm_weight * F.l1_loss(outputs['denoised_mag'], outputs['clean_mag']) 160 | losses['sm'] = F.l1_loss(outputs['denoised_mag'], outputs['clean_mag']).detach() 161 | 162 | # Mimic loss (perceptual loss) 163 | if self.config.mimic_weight > 0: 164 | loss += self.config.mimic_weight * F.l1_loss(outputs['mimic'][-1], outputs['soft_label'][-1]) 165 | losses['mimic'] = F.l1_loss(outputs['mimic'][-1], outputs['soft_label'][-1]).detach() 166 | 167 | # Texture loss at each convolutional block 168 | if any(self.config.texture_weights): 169 | for index in range(len(outputs['mimic']) - 1): 170 | if self.config.texture_weights[index] > 0: 171 | prediction = outputs['mimic'][index] 172 | target = outputs['soft_label'][index] 173 | loss += self.config.texture_weights[index] * F.l1_loss(prediction, target) 174 | losses['texture%d' % index] = F.l1_loss(prediction, target).detach() 175 | 176 | # Cross-entropy loss (for joint training?) 
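            # (This term scores the acoustic model's senone posteriors, computed on the
            # enhanced magnitudes, against the hard senone targets; truncate_and_ce below
            # trims posteriors and targets to a common length first.)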
177 |             if self.config.ce_weight > 0:
178 |                 #norm = normalize(outputs['denoised_mag'], outputs['senone'])
179 |                 #outputs = self.models['mimic'](norm['clean_mag'])[-1]
180 |                 #targets = norm['senone']
181 |                 loss += self.config.ce_weight * truncate_and_ce(outputs['mimic'][-1], outputs['senone'])
182 |                 losses['ce'] = truncate_and_ce(outputs['mimic'][-1], outputs['senone']).detach()
183 | 
184 |             if self.config.gan_weight > 0:
185 |                 target = torch.ones_like(outputs['d_fake'])
186 |                 losses['generator'] = F.mse_loss(outputs['d_fake'], target)
187 |                 #print("Generator prediction: %f" % outputs['d_fake'].mean())
188 |                 #print("Generator loss: %f" % losses['generator'])
189 | 
190 |                 if training:#outputs['d_fake'].mean() < 0.4 and training:
191 |                     loss += self.config.gan_weight * F.l1_loss(outputs['d_fake'], target)
192 | 
193 |             if self.config.soft_senone_weight > 0:
194 |                 losses['soft_senone'] = truncate_and_l1(outputs['mimic'][-1], outputs['embedding']).detach()
195 |                 loss += self.config.soft_senone_weight * truncate_and_l1(outputs['mimic'][-1], outputs['embedding'])
196 | 
197 |         return loss, losses
198 | 
199 | def normalize(inputs, target, factor = 16):
200 | 
201 |     # Ensure equal length
202 |     newlen = min(inputs.shape[3], target.shape[1])
203 |     newlen -= newlen % factor
204 |     inputs = inputs[:, :, :, :newlen]
205 |     target = target[:, :newlen]
206 | 
207 |     return {'clean_mag': inputs, 'senone': target}
208 | 
209 | def get_gram_matrix(x):
210 |     feature_maps = x.shape[1]
211 |     x = x.view(feature_maps, -1)
212 |     x = (x - torch.mean(x)) / torch.std(x)
213 | 
214 |     mat = torch.mm(x, x.t())
215 | 
216 |     return mat
217 | 
218 | def truncate_and_l1(inputs, target):
219 |     newlen = min(inputs.shape[-1], target.shape[-1])
220 | 
221 |     inputs = inputs[:, :, :newlen]
222 |     target = target[:, :, :newlen]
223 | 
224 |     return F.l1_loss(inputs, target)
225 | 
226 | def truncate_and_ce(inputs, target):
227 |     newlen = min(inputs.shape[-1], target.shape[-1])
228 | 
229 |     inputs = inputs[:, :, :newlen]
230 |     target = target[:, :newlen]
231 | 
232 |     return F.cross_entropy(inputs, target)
233 | 
--------------------------------------------------------------------------------
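As a usage note: the file-list and label formats expected by `data_io.wav_dataset` are not documented elsewhere, so the following is a sketch inferred from the loading code (utterance IDs and paths are purely illustrative). `--clean_flist`, `--noise_flist`, and `--noisy_flist` name JSON files under `<base_dir>/<phase>/` (phases `tr`, `dt`, `et`) that map each utterance ID to a list of one or more wav paths (one per channel), relative to `--base_dir`:

```json
{
  "utt0001": ["tr/noisy/utt0001.ch0.wav", "tr/noisy/utt0001.ch1.wav"],
  "utt0002": ["tr/noisy/utt0002.ch0.wav", "tr/noisy/utt0002.ch1.wav"]
}
```

`--senone_file` names a plain-text file in the same directory with one utterance per line, the ID followed by frame-level integer senone labels (e.g. `utt0001 1042 1042 87 87 ...`). Training with spectral-mapping and mimic losses would then typically be invoked with at least `--base_dir`, `--clean_flist`, `--noisy_flist` (or `--noise_flist`), `--gcheckpoints`, and `--mpretrain`, plus nonzero loss weights such as `--sm_weight` and `--mimic_weight`.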