├── model
│   ├── __init__.py
│   └── nsnet_model.py
├── dataloader
│   ├── __init__.py
│   └── wav_dataset.py
├── train_nn.py
├── test_nn.py
├── README.md
└── .gitignore

/model/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/train_nn.py:
--------------------------------------------------------------------------------
from argparse import Namespace

from pytorch_lightning import Trainer

from model.nsnet_model import NSNetModel

train_dir = './WAVs/dataset/training'
val_dir = './WAVs/dataset/validation'

hparams = {'train_dir': train_dir,
           'val_dir': val_dir,
           'batch_size': 64,
           'n_fft': 512,
           'n_gru_layers': 3,
           'gru_dropout': 0.2,
           'alpha': 0.35}

model = NSNetModel(hparams=Namespace(**hparams))

trainer = Trainer(gpus=1)
trainer.fit(model)
--------------------------------------------------------------------------------
/test_nn.py:
--------------------------------------------------------------------------------
from pathlib import Path

import soundfile as sf
import torch
from torch.utils.data import DataLoader
from torchaudio.functional import angle, istft

from dataloader.wav_dataset import WAVDataset
from model.nsnet_model import NSNetModel

model = NSNetModel.load_from_checkpoint(Path('/home/guillaume/Downloads/epoch=7.ckpt'))

testing_dir = Path('/home/guillaume/Github/EHNet/WAVs/MS-SNSD-test/testing_seen_noise')
n_fft = 512
dataset = WAVDataset(dir=testing_dir, n_fft=n_fft, test=True)
dataloader = DataLoader(dataset, batch_size=16, drop_last=False, shuffle=True)
noisy_waveform, clean_waveform, x_stft, _, x_lps, x_ms, _, _ = next(iter(dataloader))

# enable eval mode
model.zero_grad()
model.eval()
model.freeze()

# disable gradients to save memory
torch.set_grad_enabled(False)

gain_mask = model(x_lps)
y_spectrogram_hat = x_ms * gain_mask

# reuse the noisy phase to rebuild a complex STFT from the masked magnitudes
y_stft_hat = torch.stack([y_spectrogram_hat * torch.cos(angle(x_stft)),
                          y_spectrogram_hat * torch.sin(angle(x_stft))], dim=-1)

window = torch.hamming_window(n_fft)
y_waveform_hat = istft(y_stft_hat, n_fft=n_fft, hop_length=n_fft // 4, win_length=n_fft, window=window, length=clean_waveform.shape[-1])
for i, waveform in enumerate(y_waveform_hat.numpy()):
    sf.write('denoised' + str(i) + '.wav', waveform, 16000)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# NSNet
This is an implementation of NSNet [1] in PyTorch and PyTorch Lightning.
NSNet is a recurrent neural network for single-channel speech enhancement.
It was implemented as part of my thesis for the Master in Electrical Engineering at Ghent University.

## Prerequisites
* torch 1.4
* pytorch_lightning 0.7.6
* torchaudio 0.4
* soundfile 0.10.3.post1

## How to train
A dataset containing both clean speech and the corresponding noisy speech (i.e. the clean speech with noise added) is required.

Running _train_nn.py_ starts the training.

The _train_dir_ variable should point to a folder containing a _clean_ and a _noisy_ subfolder, holding the clean WAV files and the noisy WAV files respectively. The filename of a noisy WAV file must match that of its corresponding clean WAV file, optionally with a suffix appended after a _+_ delimiter, e.g. clean01.wav → clean01+noise.wav.

_val_dir_ follows the same convention; this folder is used for validation.
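For example, a minimal training folder could look like this (filenames are hypothetical):

```
WAVs/dataset/training
├── clean
│   ├── clean01.wav
│   └── clean02.wav
└── noisy
    ├── clean01+noise.wav
    └── clean02+babble.wav
```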
## How to test
Running _test_nn.py_ writes the denoised output WAV files to the working directory. Adjust the checkpoint path and _testing_dir_ at the top of the script to match your setup.

_testing_dir_ should point to a folder with the same structure as _train_dir_ and _val_dir_.
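For reference, the denoising core of _test_nn.py_ boils down to a few lines (a condensed sketch that reuses the variables defined in that script; it is not a standalone program):

```python
gain_mask = model(x_lps)           # predicted per-bin gains in [0, 1]
y_ms_hat = x_ms * gain_mask        # masked (denoised) magnitude spectrogram
# reuse the noisy phase to rebuild a complex STFT, then invert it
y_stft_hat = torch.stack([y_ms_hat * torch.cos(angle(x_stft)),
                          y_ms_hat * torch.sin(angle(x_stft))], dim=-1)
y_waveform_hat = istft(y_stft_hat, n_fft=n_fft, hop_length=n_fft // 4,
                       win_length=n_fft, window=torch.hamming_window(n_fft))
```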
## Acknowledgements
[1] Y. Xia, S. Braun, C. K. A. Reddy, H. Dubey, R. Cutler, and I. Tashev, “Weighted Speech Distortion Losses for Neural-network-based Real-time Speech Enhancement,” arXiv:2001.10601 [cs, eess], Feb. 2020.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.vscode
lightning_logs
output

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/model/nsnet_model.py:
--------------------------------------------------------------------------------
import logging as log
from argparse import Namespace
from collections import OrderedDict
from pathlib import Path

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader

from dataloader.wav_dataset import WAVDataset


class NSNetModel(pl.LightningModule):
    def __init__(self, hparams=Namespace(**{'train_dir': Path(), 'val_dir': Path(), 'batch_size': 4, 'n_fft': 512,
                                            'n_gru_layers': 3, 'gru_dropout': 0, 'alpha': 0.35})):
        super(NSNetModel, self).__init__()

        self.hparams = hparams
        self.train_dir = Path(self.hparams.train_dir)
        self.val_dir = Path(self.hparams.val_dir)
        self.batch_size = self.hparams.batch_size
        self.n_fft = self.hparams.n_fft
        self.n_frequency_bins = self.n_fft // 2 + 1
        self.n_gru_layers = self.hparams.n_gru_layers
        self.gru_dropout = self.hparams.gru_dropout
        self.alpha = self.hparams.alpha

        # build model
        self.__build_model()

    # ---------------------
    # MODEL SETUP
    # ---------------------
    def __build_model(self):
        """Lay out the model: a stack of GRU layers followed by a dense layer with sigmoid output."""
        self.gru = nn.GRU(input_size=self.n_frequency_bins, hidden_size=self.n_frequency_bins, num_layers=self.n_gru_layers,
                          batch_first=True, dropout=self.gru_dropout)
        self.dense = nn.Linear(in_features=self.n_frequency_bins, out_features=self.n_frequency_bins)

    # ---------------------
    # TRAINING
    # ---------------------
    def forward(self, x):
        x = x.permute(0, 2, 1)  # (batch_size, time, n_frequency_bins)
        x, _ = self.gru(x)  # (batch_size, time, n_frequency_bins)
        x = torch.sigmoid(self.dense(x))  # (batch_size, time, n_frequency_bins)
        x = x.permute(0, 2, 1)  # (batch_size, n_frequency_bins, time)

        return x

    def loss(self, target, prediction):
        loss = F.mse_loss(prediction, target)
        return loss

    def training_step(self, batch, batch_idx):
        # forward pass
        x_lps, x_ms, y_ms, noise_ms, VAD = batch
        VAD_expanded = torch.unsqueeze(VAD, dim=1).expand_as(y_ms)

        y_hat = self.forward(x_lps)

        loss_speech = self.loss(y_ms[VAD_expanded], (y_hat * y_ms)[VAD_expanded])
        loss_noise = self.loss(torch.zeros_like(y_hat), y_hat * noise_ms)

        loss_val = self.alpha * loss_speech + (1 - self.alpha) * loss_noise
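        # Annotation (not in the original source): the two terms above form the
        # weighted speech-distortion loss of [1]. With gain mask G = y_hat, clean
        # magnitude |S| = y_ms and noise magnitude |N| = noise_ms:
        #   L = alpha * MSE(|S|, G|S|) over speech-active bins  (speech distortion)
        #     + (1 - alpha) * MSE(0, G|N|)                      (residual noise)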
        tqdm_dict = {'train_loss': loss_val}
        output = OrderedDict({
            'loss': loss_val,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })

        return output

    def validation_step(self, batch, batch_idx):
        # forward pass
        x_lps, x_ms, y_ms, noise_ms, VAD = batch
        VAD_expanded = torch.unsqueeze(VAD, dim=1).expand_as(y_ms)

        y_hat = self.forward(x_lps)

        loss_speech = self.loss(y_ms[VAD_expanded], (y_hat * y_ms)[VAD_expanded])
        loss_noise = self.loss(torch.zeros_like(y_hat), y_hat * noise_ms)

        loss_val = self.alpha * loss_speech + (1 - self.alpha) * loss_noise

        output = OrderedDict({
            'val_loss': loss_val,
        })

        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def validation_epoch_end(self, outputs):
        val_loss_mean = 0
        for output in outputs:
            val_loss = output['val_loss']
            val_loss_mean += val_loss

        val_loss_mean /= len(outputs)
        tqdm_dict = {'val_loss': val_loss_mean}
        result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': val_loss_mean}
        return result

    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, verbose=True, min_lr=1e-6, patience=5)
        return [optimizer], [scheduler]

    def __dataloader(self, train):
        # init data generators
        if train:
            dataset = WAVDataset(self.train_dir, n_fft=self.n_fft)
        else:
            dataset = WAVDataset(self.val_dir, n_fft=self.n_fft)

        loader = DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=train,  # only shuffle the training set
            num_workers=24,
        )

        return loader

    def train_dataloader(self):
        log.info('Training data loader called.')
        return self.__dataloader(train=True)

    def val_dataloader(self):
        log.info('Validation data loader called.')
        return self.__dataloader(train=False)
--------------------------------------------------------------------------------
/dataloader/wav_dataset.py:
--------------------------------------------------------------------------------
import math
import os
from pathlib import Path

import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset

torchaudio.set_audio_backend("soundfile")  # the default backend (SoX) has bugs when loading WAVs


class WAVDataset(Dataset):
    """
    A PyTorch Dataset built from a directory containing clean and noisy WAV files.
    """
    def __init__(self, dir: Path, n_fft, test=False):
        self.clean_dir = dir.joinpath('clean')
        self.noisy_dir = dir.joinpath('noisy')
        self.n_fft = n_fft
        self.test = test

        assert os.path.exists(self.clean_dir), 'No clean WAV file folder found!'
        assert os.path.exists(self.noisy_dir), 'No noisy WAV file folder found!'

        self.clean_WAVs = {}
        for i, filename in enumerate(sorted(os.listdir(self.clean_dir))):
            self.clean_WAVs[i] = self.clean_dir.joinpath(filename)

        self.noisy_WAVs = {}
        for i, filename in enumerate(sorted(os.listdir(self.noisy_dir))):
            self.noisy_WAVs[i] = self.noisy_dir.joinpath(filename)

        # VAD: mark the frequency bins between 300 Hz and 5000 Hz (assumes 16 kHz audio)
        step = 16000 / self.n_fft
        frequency_bins = np.arange(0, (self.n_fft // 2 + 1) * step, step=step)
        self.VAD_frequencies = np.where((frequency_bins >= 300) & (frequency_bins <= 5000), True, False)

        # mean normalization
        frameshift = (self.n_fft / 16000) / 4  # frameshift between STFT frames, in seconds
        t_init = 0.1  # init time
        tauFeat = 3  # time constant
        tauFeatInit = 0.1  # time constant during init
        self.n_init_frames = math.ceil(t_init / frameshift)
        self.alpha_feat_init = math.exp(-frameshift / tauFeatInit)
        self.alpha_feat = math.exp(-frameshift / tauFeat)
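        # Annotation (not in the original source): these constants set up a
        # first-order recursive averager, alpha = exp(-frameshift / tau), used in
        # __getitem__ to track a running per-frequency mean and variance of the
        # log-power features; the smaller init time constant lets the estimates
        # adapt faster during the first t_init seconds.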
    def __len__(self):
        return len(self.noisy_WAVs)

    def __getitem__(self, idx):
        noisy_path = self.noisy_WAVs[idx]
        clean_path = self.clean_dir.joinpath(noisy_path.name.split('+')[0] + '.wav')  # derive the clean filename from the noisy filename
        # retry a few times on transient load failures (see the backend note above), then give up
        for attempt in range(3):
            try:
                clean_waveform, _ = torchaudio.load(clean_path, normalization=2**15)
                noisy_waveform, _ = torchaudio.load(noisy_path, normalization=2**15)
                break
            except (RuntimeError, OSError):
                if attempt == 2:
                    raise

        assert clean_waveform.shape[0] == 1 and noisy_waveform.shape[0] == 1, 'WAV file is not single channel!'

        window = torch.hamming_window(self.n_fft)
        x_stft = torch.stft(noisy_waveform.view(-1), n_fft=self.n_fft, hop_length=self.n_fft // 4, win_length=self.n_fft, window=window)
        y_stft = torch.stft(clean_waveform.view(-1), n_fft=self.n_fft, hop_length=self.n_fft // 4, win_length=self.n_fft, window=window)

        x_ps = x_stft.pow(2).sum(-1)
        x_lps = LogTransform()(x_ps)

        x_ms = x_ps.sqrt()
        y_ms = y_stft.pow(2).sum(-1).sqrt()

        noise_ms = (x_stft - y_stft).pow(2).sum(-1).sqrt()

        # VAD: a frame is speech-active if its smoothed in-band energy is within 30 dB of the peak
        y_ms_filtered = y_ms[self.VAD_frequencies]
        y_energy_filtered = y_ms_filtered.pow(2).mean(dim=0)
        y_energy_filtered_averaged = self.__moving_average(y_energy_filtered)
        y_peak_energy = y_energy_filtered_averaged.max()
        VAD = torch.where(y_energy_filtered_averaged > y_peak_energy / 1000, torch.ones_like(y_energy_filtered), torch.zeros_like(y_energy_filtered))
        VAD = VAD.bool()

        # mean normalization
        frames = []
        x_lps = x_lps.transpose(0, 1)  # (time, frequency)
        n_init_frames = self.n_init_frames
        alpha_feat_init = self.alpha_feat_init
        alpha_feat = self.alpha_feat
        for frame_counter, frame_feature in enumerate(x_lps):
            if frame_counter < n_init_frames:
                alpha = alpha_feat_init
            else:
                alpha = alpha_feat
            if frame_counter == 0:
                mu = frame_feature
                sigmasquare = frame_feature.pow(2)
            mu = alpha * mu + (1 - alpha) * frame_feature
            sigmasquare = alpha * sigmasquare + (1 - alpha) * frame_feature.pow(2)
            sigma = torch.sqrt(torch.clamp(sigmasquare - mu.pow(2), min=1e-12))  # limit for sqrt
            norm_feature = (frame_feature - mu) / sigma
            frames.append(norm_feature)

        x_lps = torch.stack(frames, dim=0).transpose(0, 1)  # (frequency, time)

        if not self.test:
            return x_lps, x_ms, y_ms, noise_ms, VAD
        else:
            return noisy_waveform.view(-1), clean_waveform.view(-1), x_stft, y_stft, x_lps, x_ms, y_ms, VAD

    def __moving_average(self, a, n=3):
        ret = torch.cumsum(a, dim=0)
        ret[n:] = ret[n:] - ret[:-n]
        ret[:n - 1] = a[:n - 1]
        ret[n - 1:] = ret[n - 1:] / n
        return ret


class LogTransform(torch.nn.Module):
    def __init__(self, floor=10**-12):
        super().__init__()
        self.floor = floor

    def forward(self, specgram):
        return torch.log(torch.clamp(specgram, min=self.floor))
--------------------------------------------------------------------------------