├── .gitattributes
├── exp
│   ├── DNS-large-full
│   │   └── checkpoint
│   │       └── pretrained.pkl
│   └── DNS-large-high
│       └── checkpoint
│           └── pretrained.pkl
├── Dockerfile
├── incl_licenses
│   ├── LICENSE_4
│   ├── LICENSE_3
│   ├── LICENSE_5
│   ├── LICENSE_2
│   └── LICENSE_1
├── LICENSE
├── configs
│   ├── DNS-large-full.json
│   └── DNS-large-high.json
├── python_eval.py
├── dataset.py
├── util.py
├── stft_loss.py
├── denoise.py
├── README.md
├── distributed.py
├── train.py
└── network.py
/.gitattributes: -------------------------------------------------------------------------------- 1 | *.pkl filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /exp/DNS-large-full/checkpoint/pretrained.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:145c101eb5bbfa3ba52fb2b4ec7e5b64a361c102f89291f75e1dd42601d95dc9 3 | size 184336765 4 | -------------------------------------------------------------------------------- /exp/DNS-large-high/checkpoint/pretrained.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:513d9e4f69483bf2bcc3059dd6b3644140763bf3f22df41d7ee366cc2cbd1829 3 | size 184336765 4 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:20.12-py3 2 | RUN apt-get update --fix-missing 3 | 4 | RUN pip install pillow==6.2.0 5 | RUN pip install torchaudio==0.8.0 6 | RUN pip install inflect==4.1.0 7 | RUN pip install scipy==1.5.0 8 | RUN pip install tqdm 9 | RUN pip install pesq 10 | RUN pip install pystoi -------------------------------------------------------------------------------- /incl_licenses/LICENSE_4: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017-Present OpenNMT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 NVIDIA CORPORATION.
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /incl_licenses/LICENSE_3: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Victor Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /incl_licenses/LICENSE_5: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /incl_licenses/LICENSE_2: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2020 Tomoki Hayashi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. -------------------------------------------------------------------------------- /incl_licenses/LICENSE_1: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, NVIDIA Corporation 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /configs/DNS-large-full.json: -------------------------------------------------------------------------------- 1 | { 2 | "network_config": { 3 | "channels_input": 1, 4 | "channels_output": 1, 5 | "channels_H": 64, 6 | "max_H": 768, 7 | "encoder_n_layers": 8, 8 | "kernel_size": 4, 9 | "stride": 2, 10 | "tsfm_n_layers": 5, 11 | "tsfm_n_head": 8, 12 | "tsfm_d_model": 512, 13 | "tsfm_d_inner": 2048 14 | }, 15 | "train_config": { 16 | "exp_path": "DNS-large-full", 17 | "log":{ 18 | "directory": "./exp", 19 | "ckpt_iter": "max", 20 | "iters_per_ckpt": 10000, 21 | "iters_per_valid": 500 22 | }, 23 | "optimization":{ 24 | "n_iters": 250000, 25 | "learning_rate": 2e-4, 26 | "batch_size_per_gpu": 8 27 | }, 28 | "loss_config":{ 29 | "ell_p": 1, 30 | "ell_p_lambda": 1, 31 | "stft_lambda": 1, 32 | "stft_config":{ 33 | "sc_lambda": 0.5, 34 | "mag_lambda": 0.5, 35 | "band": "full", 36 | "hop_sizes": [50, 120, 240], 37 | "win_lengths": [240, 600, 1200], 38 | "fft_sizes": [512, 1024, 2048] 39 | } 40 | } 41 | }, 42 | "trainset_config": { 43 | "root": "./dns", 44 | "crop_length_sec": 10, 45 | "sample_rate": 16000 46 | }, 47 | "gen_config":{ 48 | "output_directory": "./exp" 49 | }, 50 | "dist_config": { 51 | "dist_backend": "nccl", 52 | "dist_url": "tcp://localhost:54321" 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /configs/DNS-large-high.json: -------------------------------------------------------------------------------- 1 | { 2 | "network_config": { 3 | "channels_input": 1, 4 | "channels_output": 1, 5 | "channels_H": 64, 6 | "max_H": 768, 7 | "encoder_n_layers": 8, 8 | "kernel_size": 4, 9 | "stride": 2, 10 | "tsfm_n_layers": 5, 11 | "tsfm_n_head": 8, 12 | "tsfm_d_model": 512, 13 | "tsfm_d_inner": 2048 14 | }, 15 | "train_config": { 16 | "exp_path": "DNS-large-high", 17 | "log":{ 18 | "directory": "./exp", 19 | "ckpt_iter": "max", 20 | "iters_per_ckpt": 10000, 21 | "iters_per_valid": 500 22 | }, 23 | "optimization":{ 24 | "n_iters": 250000, 25 | "learning_rate": 2e-4, 26 | "batch_size_per_gpu": 8 27 | }, 28 | "loss_config":{ 29 | "ell_p": 1, 30 | "ell_p_lambda": 1, 31 | "stft_lambda": 1, 32 | "stft_config":{ 33 | "sc_lambda": 0.5, 34 | "mag_lambda": 0.5, 35 | "band": "high", 36 | "hop_sizes": [50, 120, 240], 37 | "win_lengths": [240, 600, 1200], 38 | "fft_sizes": [512, 1024, 2048] 39 | } 40 | } 41 | }, 42 | "trainset_config": { 43 | "root": "./dns", 44 | "crop_length_sec": 10, 45 | "sample_rate": 16000 46 | }, 47 | "gen_config":{ 48 | "output_directory": "./exp" 49 | }, 50 | "dist_config": { 51 | "dist_backend": "nccl", 52 | "dist_url": "tcp://localhost:54321" 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /python_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 
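# Evaluates enhanced (or noisy) speech against the clean references on the DNS no-reverb test set, # reporting length-weighted wide-band PESQ, narrow-band PESQ, and STOI over the 300 test files.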
3 | 4 | import os 5 | import sys 6 | from collections import defaultdict 7 | from tqdm import tqdm 8 | import argparse 9 | import warnings 10 | warnings.filterwarnings("ignore") 11 | 12 | import numpy as np 13 | from scipy.io import wavfile 14 | 15 | from pesq import pesq 16 | from pystoi import stoi 17 | 18 | 19 | def evaluate_dns(testset_path, enhanced_path, target): 20 | reverb = 'no' 21 | result = defaultdict(int) 22 | 23 | for i in tqdm(range(300)): 24 | try: 25 | rate, clean = wavfile.read(os.path.join(testset_path, "clean", "clean_fileid_{}.wav".format(i))) 26 | if target == 'noisy': 27 | rate, target_wav = wavfile.read(os.path.join(testset_path, "noisy", "noisy_fileid_{}.wav".format(i))) 28 | else: 29 | rate, target_wav = wavfile.read(os.path.join(enhanced_path, "enhanced_fileid_{}.wav".format(i))) 30 | except: 31 | continue 32 | 33 | length = target_wav.shape[-1] 34 | 35 | result['pesq_wb'] += pesq(16000, clean, target_wav, 'wb') * length # wide band 36 | result['pesq_nb'] += pesq(16000, clean, target_wav, 'nb') * length # narrow band 37 | result['stoi'] += stoi(clean, target_wav, rate) * length 38 | result['count'] += 1 * length 39 | 40 | return result 41 | 42 | 43 | if __name__ == '__main__': 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('-d', '--dataset', type=str, default='dns', help='dataset') 46 | parser.add_argument('-e', '--enhanced_path', type=str, help='enhanced audio path') 47 | parser.add_argument('-t', '--testset_path', type=str, help='testset path') 48 | args = parser.parse_args() 49 | 50 | enhanced_path = args.enhanced_path 51 | testset_path = args.testset_path 52 | target = 'enhanced' 53 | 54 | if args.dataset == 'dns': 55 | result = evaluate_dns(testset_path, enhanced_path, target) 56 | 57 | # logging 58 | for key in result: 59 | if key != 'count': 60 | print('{} = {:.3f}'.format(key, result[key]/result['count']), end=", ") 61 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | 4 | import os 5 | import numpy as np 6 | 7 | from scipy.io.wavfile import read as wavread 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | 11 | import torch 12 | from torch.utils.data import Dataset 13 | from torch.utils.data.distributed import DistributedSampler 14 | 15 | import random 16 | random.seed(0) 17 | torch.manual_seed(0) 18 | np.random.seed(0) 19 | 20 | from torchvision import datasets, models, transforms 21 | import torchaudio 22 | 23 | 24 | class CleanNoisyPairDataset(Dataset): 25 | """ 26 | Create a Dataset of clean and noisy audio pairs. 
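Clean and noisy files are paired by file id; see the README ("Datasets") for the expected directory layout.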
27 | Each element is a tuple of the form (clean waveform, noisy waveform, file_id) 28 | """ 29 | 30 | def __init__(self, root='./', subset='training', crop_length_sec=0): 31 | super().__init__() 32 | 33 | assert subset is None or subset in ["training", "testing"] 34 | self.crop_length_sec = crop_length_sec 35 | self.subset = subset 36 | 37 | N_clean = len(os.listdir(os.path.join(root, 'training_set/clean'))) 38 | N_noisy = len(os.listdir(os.path.join(root, 'training_set/noisy'))) 39 | assert N_clean == N_noisy 40 | 41 | if subset == "training": 42 | self.files = [(os.path.join(root, 'training_set/clean', 'fileid_{}.wav'.format(i)), 43 | os.path.join(root, 'training_set/noisy', 'fileid_{}.wav'.format(i))) for i in range(N_clean)] 44 | 45 | elif subset == "testing": 46 | sortkey = lambda name: '_'.join(name.split('_')[-2:]) # specific for dns due to test sample names 47 | _p = os.path.join(root, 'datasets/test_set/synthetic/no_reverb') # path for DNS 48 | 49 | clean_files = os.listdir(os.path.join(_p, 'clean')) 50 | noisy_files = os.listdir(os.path.join(_p, 'noisy')) 51 | 52 | clean_files.sort(key=sortkey) 53 | noisy_files.sort(key=sortkey) 54 | 55 | self.files = [] 56 | for _c, _n in zip(clean_files, noisy_files): 57 | assert sortkey(_c) == sortkey(_n) 58 | self.files.append((os.path.join(_p, 'clean', _c), 59 | os.path.join(_p, 'noisy', _n))) 60 | self.crop_length_sec = 0 61 | 62 | else: 63 | raise NotImplementedError 64 | 65 | def __getitem__(self, n): 66 | fileid = self.files[n] 67 | clean_audio, sample_rate = torchaudio.load(fileid[0]) 68 | noisy_audio, sample_rate = torchaudio.load(fileid[1]) 69 | clean_audio, noisy_audio = clean_audio.squeeze(0), noisy_audio.squeeze(0) 70 | assert len(clean_audio) == len(noisy_audio) 71 | 72 | crop_length = int(self.crop_length_sec * sample_rate) 73 | assert crop_length < len(clean_audio) 74 | 75 | # random crop 76 | if self.subset != 'testing' and crop_length > 0: 77 | start = np.random.randint(low=0, high=len(clean_audio) - crop_length + 1) 78 | clean_audio = clean_audio[start:(start + crop_length)] 79 | noisy_audio = noisy_audio[start:(start + crop_length)] 80 | 81 | clean_audio, noisy_audio = clean_audio.unsqueeze(0), noisy_audio.unsqueeze(0) 82 | return (clean_audio, noisy_audio, fileid) 83 | 84 | def __len__(self): 85 | return len(self.files) 86 | 87 | 88 | def load_CleanNoisyPairDataset(root, subset, crop_length_sec, batch_size, sample_rate, num_gpus=1): 89 | """ 90 | Get dataloader with distributed sampling 91 | """ 92 | dataset = CleanNoisyPairDataset(root=root, subset=subset, crop_length_sec=crop_length_sec) 93 | kwargs = {"batch_size": batch_size, "num_workers": 4, "pin_memory": False, "drop_last": False} 94 | 95 | if num_gpus > 1: 96 | train_sampler = DistributedSampler(dataset) 97 | dataloader = torch.utils.data.DataLoader(dataset, sampler=train_sampler, **kwargs) 98 | else: 99 | dataloader = torch.utils.data.DataLoader(dataset, sampler=None, shuffle=True, **kwargs) 100 | 101 | return dataloader 102 | 103 | 104 | if __name__ == '__main__': 105 | import json 106 | with open('./configs/DNS-large-full.json') as f: 107 | data = f.read() 108 | config = json.loads(data) 109 | trainset_config = config["trainset_config"] 110 | 111 | trainloader = load_CleanNoisyPairDataset(**trainset_config, subset='training', batch_size=2, num_gpus=1) 112 | testloader = load_CleanNoisyPairDataset(**trainset_config, subset='testing', batch_size=2, num_gpus=1) 113 | print(len(trainloader), len(testloader)) 114 | 115 | for clean_audio,
noisy_audio, fileid in trainloader: 116 | clean_audio = clean_audio.cuda() 117 | noisy_audio = noisy_audio.cuda() 118 | print(clean_audio.shape, noisy_audio.shape, fileid) 119 | break 120 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import functools 4 | import numpy as np 5 | from math import cos, pi, floor, sin 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from stft_loss import MultiResolutionSTFTLoss 13 | 14 | 15 | def flatten(v): 16 | return [x for y in v for x in y] 17 | 18 | 19 | def rescale(x): 20 | return (x - x.min()) / (x.max() - x.min()) 21 | 22 | 23 | def find_max_epoch(path): 24 | """ 25 | Find latest checkpoint 26 | 27 | Returns: 28 | maximum iteration, -1 if there is no (valid) checkpoint 29 | """ 30 | 31 | files = os.listdir(path) 32 | epoch = -1 33 | for f in files: 34 | if len(f) <= 4: 35 | continue 36 | if f[-4:] == '.pkl': 37 | number = f[:-4] 38 | try: 39 | epoch = max(epoch, int(number)) 40 | except: 41 | continue 42 | return epoch 43 | 44 | 45 | def print_size(net, keyword=None): 46 | """ 47 | Print the number of parameters of a network 48 | """ 49 | 50 | if net is not None and isinstance(net, torch.nn.Module): 51 | module_parameters = filter(lambda p: p.requires_grad, net.parameters()) 52 | params = sum([np.prod(p.size()) for p in module_parameters]) 53 | 54 | print("{} Parameters: {:.6f}M".format( 55 | net.__class__.__name__, params / 1e6), flush=True, end="; ") 56 | 57 | if keyword is not None: 58 | keyword_parameters = [p for name, p in net.named_parameters() if p.requires_grad and keyword in name] 59 | params = sum([np.prod(p.size()) for p in keyword_parameters]) 60 | print("{} Parameters: {:.6f}M".format( 61 | keyword, params / 1e6), flush=True, end="; ") 62 | 63 | print(" ") 64 | 65 | 66 | ####################### lr scheduler: Linear Warmup then Cosine Decay ############################# 67 | 68 | # Adapted from https://github.com/rosinality/vq-vae-2-pytorch 69 | 70 | # Original Copyright 2019 Kim Seonghyeon 71 | # MIT License (https://opensource.org/licenses/MIT) 72 | 73 | 74 | def anneal_linear(start, end, proportion): 75 | return start + proportion * (end - start) 76 | 77 | 78 | def anneal_cosine(start, end, proportion): 79 | cos_val = cos(pi * proportion) + 1 80 | return end + (start - end) / 2 * cos_val 81 | 82 | 83 | class Phase: 84 | def __init__(self, start, end, n_iter, cur_iter, anneal_fn): 85 | self.start, self.end = start, end 86 | self.n_iter = n_iter 87 | self.anneal_fn = anneal_fn 88 | self.n = cur_iter 89 | 90 | def step(self): 91 | self.n += 1 92 | 93 | return self.anneal_fn(self.start, self.end, self.n / self.n_iter) 94 | 95 | def reset(self): 96 | self.n = 0 97 | 98 | @property 99 | def is_done(self): 100 | return self.n >= self.n_iter 101 | 102 | 103 | class LinearWarmupCosineDecay: 104 | def __init__( 105 | self, 106 | optimizer, 107 | lr_max, 108 | n_iter, 109 | iteration=0, 110 | divider=25, 111 | warmup_proportion=0.3, 112 | phase=('linear', 'cosine'), 113 | ): 114 | self.optimizer = optimizer 115 | 116 | phase1 = int(n_iter * warmup_proportion) 117 | phase2 = n_iter - phase1 118 | lr_min = lr_max / divider 119 | 120 | phase_map = {'linear': anneal_linear, 'cosine': anneal_cosine} 121 | 122 | cur_iter_phase1 = iteration 123 | cur_iter_phase2 = max(0, iteration - phase1) 124 | self.lr_phase = [ 125 | 
Phase(lr_min, lr_max, phase1, cur_iter_phase1, phase_map[phase[0]]), 126 | Phase(lr_max, lr_min / 1e4, phase2, cur_iter_phase2, phase_map[phase[1]]), 127 | ] 128 | 129 | if iteration < phase1: 130 | self.phase = 0 131 | else: 132 | self.phase = 1 133 | 134 | def step(self): 135 | lr = self.lr_phase[self.phase].step() 136 | 137 | for group in self.optimizer.param_groups: 138 | group['lr'] = lr 139 | 140 | if self.lr_phase[self.phase].is_done: 141 | self.phase += 1 142 | 143 | if self.phase >= len(self.lr_phase): 144 | for phase in self.lr_phase: 145 | phase.reset() 146 | 147 | self.phase = 0 148 | 149 | return lr 150 | 151 | 152 | ####################### model util ############################# 153 | 154 | def std_normal(size): 155 | """ 156 | Generate the standard Gaussian variable of a certain size 157 | """ 158 | 159 | return torch.normal(0, 1, size=size).cuda() 160 | 161 | 162 | def weight_scaling_init(layer): 163 | """ 164 | weight rescaling initialization from https://arxiv.org/abs/1911.13254 165 | """ 166 | w = layer.weight.detach() 167 | alpha = 10.0 * w.std() 168 | layer.weight.data /= torch.sqrt(alpha) 169 | layer.bias.data /= torch.sqrt(alpha) 170 | 171 | 172 | @torch.no_grad() 173 | def sampling(net, noisy_audio): 174 | """ 175 | Perform denoising (forward) step 176 | """ 177 | 178 | return net(noisy_audio) 179 | 180 | 181 | def loss_fn(net, X, ell_p, ell_p_lambda, stft_lambda, mrstftloss, **kwargs): 182 | """ 183 | Loss function in CleanUNet 184 | 185 | Parameters: 186 | net: network 187 | X: training data pair (clean audio, noisy_audio) 188 | ell_p: \ell_p norm (1 or 2) of the AE loss 189 | ell_p_lambda: factor of the AE loss 190 | stft_lambda: factor of the STFT loss 191 | mrstftloss: multi-resolution STFT loss function 192 | 193 | Returns: 194 | loss: value of objective function 195 | output_dic: values of each component of loss 196 | """ 197 | 198 | assert type(X) == tuple and len(X) == 2 199 | 200 | clean_audio, noisy_audio = X 201 | B, C, L = clean_audio.shape 202 | output_dic = {} 203 | loss = 0.0 204 | 205 | # AE loss 206 | denoised_audio = net(noisy_audio) 207 | 208 | if ell_p == 2: 209 | ae_loss = nn.MSELoss()(denoised_audio, clean_audio) 210 | elif ell_p == 1: 211 | ae_loss = F.l1_loss(denoised_audio, clean_audio) 212 | else: 213 | raise NotImplementedError 214 | loss += ae_loss * ell_p_lambda 215 | output_dic["reconstruct"] = ae_loss.data * ell_p_lambda 216 | 217 | if stft_lambda > 0: 218 | sc_loss, mag_loss = mrstftloss(denoised_audio.squeeze(1), clean_audio.squeeze(1)) 219 | loss += (sc_loss + mag_loss) * stft_lambda 220 | output_dic["stft_sc"] = sc_loss.data * stft_lambda 221 | output_dic["stft_mag"] = mag_loss.data * stft_lambda 222 | 223 | return loss, output_dic 224 | 225 | -------------------------------------------------------------------------------- /stft_loss.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/kan-bayashi/ParallelWaveGAN 2 | 3 | # Original Copyright 2019 Tomoki Hayashi 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | """STFT-based Loss modules.""" 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from distutils.version import LooseVersion 12 | 13 | is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7") 14 | 15 | 16 | def stft(x, fft_size, hop_size, win_length, window): 17 | """Perform STFT and convert to magnitude spectrogram. 18 | Args: 19 | x (Tensor): Input signal tensor (B, T). 20 | fft_size (int): FFT size. 
21 | hop_size (int): Hop size. 22 | win_length (int): Window length. 23 | window (str): Window function type. 24 | Returns: 25 | Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). 26 | 27 | """ 28 | if is_pytorch_17plus: 29 | x_stft = torch.stft( 30 | x, fft_size, hop_size, win_length, window, return_complex=False 31 | ) 32 | else: 33 | x_stft = torch.stft(x, fft_size, hop_size, win_length, window) 34 | real = x_stft[..., 0] 35 | imag = x_stft[..., 1] 36 | 37 | # NOTE(kan-bayashi): clamp is needed to avoid nan or inf 38 | return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1) 39 | 40 | 41 | class SpectralConvergenceLoss(torch.nn.Module): 42 | """Spectral convergence loss module.""" 43 | 44 | def __init__(self): 45 | """Initialize spectral convergence loss module.""" 46 | super(SpectralConvergenceLoss, self).__init__() 47 | 48 | def forward(self, x_mag, y_mag): 49 | """Calculate forward propagation. 50 | 51 | Args: 52 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 53 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 54 | 55 | Returns: 56 | Tensor: Spectral convergence loss value. 57 | 58 | """ 59 | return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") 60 | 61 | 62 | class LogSTFTMagnitudeLoss(torch.nn.Module): 63 | """Log STFT magnitude loss module.""" 64 | 65 | def __init__(self): 66 | """Initialize log STFT magnitude loss module.""" 67 | super(LogSTFTMagnitudeLoss, self).__init__() 68 | 69 | def forward(self, x_mag, y_mag): 70 | """Calculate forward propagation. 71 | 72 | Args: 73 | x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). 74 | y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). 75 | 76 | Returns: 77 | Tensor: Log STFT magnitude loss value. 78 | 79 | """ 80 | return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) 81 | 82 | 83 | class STFTLoss(torch.nn.Module): 84 | """STFT loss module.""" 85 | 86 | def __init__( 87 | self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", 88 | band="full" 89 | ): 90 | """Initialize STFT loss module.""" 91 | super(STFTLoss, self).__init__() 92 | self.fft_size = fft_size 93 | self.shift_size = shift_size 94 | self.win_length = win_length 95 | self.band = band 96 | 97 | self.spectral_convergence_loss = SpectralConvergenceLoss() 98 | self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() 99 | # NOTE(kan-bayashi): Use register_buffer to fix #223 100 | self.register_buffer("window", getattr(torch, window)(win_length)) 101 | 102 | def forward(self, x, y): 103 | """Calculate forward propagation. 104 | 105 | Args: 106 | x (Tensor): Predicted signal (B, T). 107 | y (Tensor): Groundtruth signal (B, T). 108 | 109 | Returns: 110 | Tensor: Spectral convergence loss value. 111 | Tensor: Log STFT magnitude loss value.
112 | 113 | """ 114 | x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) 115 | y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) 116 | 117 | if self.band == "high": 118 | freq_mask_ind = x_mag.shape[1] // 2 # only select high frequency bands 119 | sc_loss = self.spectral_convergence_loss(x_mag[:,freq_mask_ind:,:], y_mag[:,freq_mask_ind:,:]) 120 | mag_loss = self.log_stft_magnitude_loss(x_mag[:,freq_mask_ind:,:], y_mag[:,freq_mask_ind:,:]) 121 | elif self.band == "full": 122 | sc_loss = self.spectral_convergence_loss(x_mag, y_mag) 123 | mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) 124 | else: 125 | raise NotImplementedError 126 | 127 | return sc_loss, mag_loss 128 | 129 | 130 | class MultiResolutionSTFTLoss(torch.nn.Module): 131 | """Multi resolution STFT loss module.""" 132 | 133 | def __init__( 134 | self, fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240], 135 | window="hann_window", sc_lambda=0.1, mag_lambda=0.1, band="full" 136 | ): 137 | """Initialize Multi resolution STFT loss module. 138 | 139 | Args: 140 | fft_sizes (list): List of FFT sizes. 141 | hop_sizes (list): List of hop sizes. 142 | win_lengths (list): List of window lengths. 143 | window (str): Window function type. 144 | *_lambda (float): a balancing factor across different losses. 145 | band (str): high-band or full-band loss 146 | 147 | """ 148 | super(MultiResolutionSTFTLoss, self).__init__() 149 | self.sc_lambda = sc_lambda 150 | self.mag_lambda = mag_lambda 151 | 152 | assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) 153 | self.stft_losses = torch.nn.ModuleList() 154 | for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): 155 | self.stft_losses += [STFTLoss(fs, ss, wl, window, band)] 156 | 157 | def forward(self, x, y): 158 | """Calculate forward propagation. 159 | 160 | Args: 161 | x (Tensor): Predicted signal (B, T) or (B, #subband, T). 162 | y (Tensor): Groundtruth signal (B, T) or (B, #subband, T). 163 | 164 | Returns: 165 | Tensor: Multi resolution spectral convergence loss value. 166 | Tensor: Multi resolution log STFT magnitude loss value. 167 | 168 | """ 169 | if len(x.shape) == 3: 170 | x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T) 171 | y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T) 172 | sc_loss = 0.0 173 | mag_loss = 0.0 174 | for f in self.stft_losses: 175 | sc_l, mag_l = f(x, y) 176 | sc_loss += sc_l 177 | mag_loss += mag_l 178 | 179 | sc_loss *= self.sc_lambda 180 | sc_loss /= len(self.stft_losses) 181 | mag_loss *= self.mag_lambda 182 | mag_loss /= len(self.stft_losses) 183 | 184 | return sc_loss, mag_loss 185 | -------------------------------------------------------------------------------- /denoise.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/NVIDIA/waveglow under the BSD 3-Clause License. 2 | 3 | # ***************************************************************************** 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | # * Redistributions of source code must retain the above copyright 9 | # notice, this list of conditions and the following disclaimer. 
10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # * Neither the name of the NVIDIA CORPORATION nor the 14 | # names of its contributors may be used to endorse or promote products 15 | # derived from this software without specific prior written permission. 16 | # 17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 21 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | # 28 | # ***************************************************************************** 29 | 30 | import os 31 | import argparse 32 | import json 33 | from tqdm import tqdm 34 | from copy import deepcopy 35 | 36 | import numpy as np 37 | import torch 38 | import torch.nn as nn 39 | # from torch.utils.tensorboard import SummaryWriter 40 | 41 | import random 42 | random.seed(0) 43 | torch.manual_seed(0) 44 | np.random.seed(0) 45 | 46 | from scipy.io.wavfile import write as wavwrite 47 | from scipy.io.wavfile import read as wavread 48 | 49 | from dataset import load_CleanNoisyPairDataset 50 | from util import rescale, find_max_epoch, print_size, sampling 51 | from network import CleanUNet 52 | 53 | 54 | def denoise(output_directory, ckpt_iter, subset, dump=False): 55 | """ 56 | Denoise audio 57 | 58 | Parameters: 59 | output_directory (str): save generated speech to this path 60 | ckpt_iter (int or 'max'): the pretrained checkpoint to be loaded; 61 | automatically selects the maximum iteration if 'max' is selected 62 | subset (str): training, testing, validation 63 | dump (bool): whether to save the enhanced (denoised) audio 64 | """ 65 | 66 | # setup local experiment path 67 | exp_path = train_config["exp_path"] 68 | print('exp_path:', exp_path) 69 | 70 | # load data 71 | loader_config = deepcopy(trainset_config) 72 | loader_config["crop_length_sec"] = 0 73 | dataloader = load_CleanNoisyPairDataset( 74 | **loader_config, 75 | subset=subset, 76 | batch_size=1, 77 | num_gpus=1 78 | ) 79 | 80 | # predefine model 81 | net = CleanUNet(**network_config).cuda() 82 | print_size(net) 83 | 84 | # load checkpoint 85 | ckpt_directory = os.path.join(train_config["log"]["directory"], exp_path, 'checkpoint') 86 | if ckpt_iter == 'max': 87 | ckpt_iter = find_max_epoch(ckpt_directory) 88 | if ckpt_iter != 'pretrained': 89 | ckpt_iter = int(ckpt_iter) 90 | model_path = os.path.join(ckpt_directory, '{}.pkl'.format(ckpt_iter)) 91 | checkpoint = torch.load(model_path, map_location='cpu') 92 | net.load_state_dict(checkpoint['model_state_dict']) 93 | net.eval() 94 | 95 | # get output directory ready 96 | if ckpt_iter == "pretrained": 97 | speech_directory = os.path.join(output_directory, exp_path, 'speech', ckpt_iter) 98 | else: 99 | speech_directory = os.path.join(output_directory,
exp_path, 'speech', '{}k'.format(ckpt_iter//1000)) 100 | if dump and not os.path.isdir(speech_directory): 101 | os.makedirs(speech_directory) 102 | os.chmod(speech_directory, 0o775) 103 | print("speech_directory: ", speech_directory, flush=True) 104 | 105 | # inference 106 | all_generated_audio = [] 107 | all_clean_audio = [] 108 | sortkey = lambda name: '_'.join(name.split('/')[-1].split('_')[1:]) 109 | for clean_audio, noisy_audio, fileid in tqdm(dataloader): 110 | filename = sortkey(fileid[0][0]) 111 | 112 | noisy_audio = noisy_audio.cuda() 113 | LENGTH = len(noisy_audio[0].squeeze()) 114 | generated_audio = sampling(net, noisy_audio) 115 | 116 | if dump: 117 | wavwrite(os.path.join(speech_directory, 'enhanced_{}'.format(filename)), 118 | trainset_config["sample_rate"], 119 | generated_audio[0].squeeze().cpu().numpy()) 120 | else: 121 | all_clean_audio.append(clean_audio[0].squeeze().cpu().numpy()) 122 | all_generated_audio.append(generated_audio[0].squeeze().cpu().numpy()) 123 | 124 | return all_clean_audio, all_generated_audio 125 | 126 | 127 | if __name__ == "__main__": 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('-c', '--config', type=str, default='config.json', 130 | help='JSON file for configuration') 131 | parser.add_argument('-ckpt_iter', '--ckpt_iter', default='max', 132 | help='Which checkpoint to use; assign a number or "max" or "pretrained"') 133 | parser.add_argument('-subset', '--subset', type=str, choices=['training', 'testing', 'validation'], 134 | default='testing', help='subset for denoising') 135 | args = parser.parse_args() 136 | 137 | # Parse configs. Globals nicer in this case 138 | with open(args.config) as f: 139 | data = f.read() 140 | config = json.loads(data) 141 | gen_config = config["gen_config"] 142 | global network_config 143 | network_config = config["network_config"] # to define the network 144 | global train_config 145 | train_config = config["train_config"] # train config 146 | global trainset_config 147 | trainset_config = config["trainset_config"] # to read trainset configurations 148 | 149 | torch.backends.cudnn.enabled = True 150 | torch.backends.cudnn.benchmark = True 151 | 152 | if args.subset == "testing": 153 | denoise(gen_config["output_directory"], 154 | subset=args.subset, 155 | ckpt_iter=args.ckpt_iter, 156 | dump=True) 157 | 158 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of CleanUNet 2 | 3 | This repo contains the official PyTorch implementation of CleanUNet: [Speech Denoising in the Waveform Domain with Self-Attention](https://arxiv.org/abs/2202.07790). CleanUNet is a causal speech denoising 4 | model that operates on the raw waveform. It is based 5 | on an encoder-decoder architecture combined with several 6 | self-attention blocks that refine its bottleneck representations, 7 | which is crucial for obtaining good results. The model is optimized 8 | through a set of losses defined over both the waveform and multi-resolution spectrograms. The proposed method outperforms 9 | state-of-the-art models in terms of denoised speech quality 10 | according to various objective and subjective evaluation metrics. Sound demos can be found in [this blog](https://nv-adlr.github.io/projects/cleanunet/). 11 | 12 | ## Datasets 13 | 14 | - [Microsoft DNS 2020](https://arxiv.org/ftp/arxiv/papers/2005/2005.13981.pdf) dataset.
The dataset, pre-processing code, and instructions for generating training data can be found at [this link](https://github.com/microsoft/DNS-Challenge/tree/interspeech2020/master). Assume the dataset is stored under ```./dns```. Before generating clean-noisy data pairs, modify the following parameters in their ```noisyspeech_synthesizer.cfg``` file: 15 | ``` 16 | total_hours: 500, 17 | snr_lower: -5, 18 | snr_upper: 25, 19 | total_snrlevels: 31 20 | ``` 21 | Also update the paths as follows (their original code uses Windows-style paths): 22 | ``` 23 | noise_dir: ./datasets/noise 24 | speech_dir: ./datasets/clean 25 | noisy_destination: ./training_set/noisy 26 | clean_destination: ./training_set/clean 27 | noise_destination: ./training_set/noise 28 | log_dir: ./logs 29 | unit_tests_log_dir: ./unittests_logs 30 | ``` 31 | Then, for conciseness and to comply with our data loading code, modify the file names (lines 198-201) in their ```noisyspeech_synthesizer_singleprocess.py``` to 32 | ``` 33 | noisyfilename = 'fileid_' + str(file_num) + '.wav' 34 | cleanfilename = 'fileid_' + str(file_num) + '.wav' 35 | noisefilename = 'fileid_' + str(file_num) + '.wav' 36 | ``` 37 | To generate training data, run 38 | ``` 39 | python noisyspeech_synthesizer_singleprocess.py 40 | ``` 41 | It is also recommended to rename the files in the test set for conciseness: 42 | ``` 43 | cd ./dns/datasets/test_set/synthetic/no_reverb/noisy/ 44 | for NAME in $(ls ./); do arr=(${NAME//fileid_/ }); mv ${NAME} noisy_fileid_${arr[1]}; done 45 | ``` 46 | 47 | After these steps, we assume the structure of the dataset folder is: 48 | ``` 49 | Training sets: 50 | ./dns/training_set/clean/fileid_{0..59999}.wav 51 | ./dns/training_set/noisy/fileid_{0..59999}.wav 52 | ./dns/training_set/noise/fileid_{0..59999}.wav 53 | 54 | Testing sets (no-reverb): 55 | ./dns/datasets/test_set/synthetic/no_reverb/clean/clean_fileid_{0..299}.wav 56 | ./dns/datasets/test_set/synthetic/no_reverb/noisy/noisy_fileid_{0..299}.wav 57 | ``` 58 | 59 | - Other datasets are also supported; lines 49-50 of ```dataset.py``` need to be carefully changed to handle paths and file names. 60 | 61 | ## Training 62 | 63 | The ```$EXP``` variable can be any config name in ```./configs/```, such as ```DNS-large-full``` and ```DNS-large-high```. The default experiment path is ```./exp```; it can be changed by modifying ```train_config["log"]["directory"]``` in the config files. ```trainset_config["root"]``` needs to be set to the root path of the dataset. The training command is then 64 | 65 | ```python3 distributed.py -c configs/${EXP}.json``` 66 | 67 | We use 8 GPUs for training. The global batch size is 64 and we train the models for 250K iterations. Note that this is different from the training setup in our paper (1M iterations with a batch size of 16); we find a negligible difference in objective and subjective evaluation, and the current setup is faster. 68 | 69 | **Pre-trained** models for denoising are provided in ```./exp/${EXP}/checkpoint/pretrained.pkl``` (each is ~176 MB; use ```git lfs``` to download them). Note that these models are not trained to remove reverb. 70 | 71 | ## Denoising 72 | 73 | We perform denoising on the DNS no-reverb test dataset. The output path is ```gen_config["output_directory"]```, which is ```./exp``` by default.
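Under the hood, ```denoise.py``` builds the model from ```network_config```, loads the checkpoint's ```model_state_dict```, and runs a plain forward pass over each noisy waveform. A minimal sketch of that logic for a single file (the wav file names here are placeholders, and this assumes the LFS checkpoint has been downloaded):
```
import json, torch, torchaudio
from network import CleanUNet

with open('./configs/DNS-large-high.json') as f:
    config = json.load(f)

net = CleanUNet(**config["network_config"]).cuda().eval()
ckpt = torch.load('./exp/DNS-large-high/checkpoint/pretrained.pkl', map_location='cpu')
net.load_state_dict(ckpt['model_state_dict'])

noisy, sr = torchaudio.load('noisy.wav')      # (1, T) waveform; 16 kHz expected
with torch.no_grad():
    denoised = net(noisy.unsqueeze(0).cuda()) # add batch dim -> (B=1, C=1, T)
torchaudio.save('enhanced.wav', denoised.squeeze(0).cpu(), sr)
```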
The denoising command is 74 | 75 | ```python denoise.py -c configs/${EXP}.json --ckpt_iter ${ITERATION}``` 76 | 77 | For example, to denoise with the pre-trained models, run: 78 | 79 | ```python denoise.py -c configs/DNS-large-high.json --ckpt_iter pretrained``` 80 | 81 | 1 GPU is used for denoising. 82 | 83 | ## Evaluation 84 | 85 | The following evaluation code generates [PESQ](https://www.itu.int/rec/T-REC-P.862) and [STOI](https://ceestaal.nl/code/) scores. More evaluation metrics can be found in the [SEGAN (PyTorch)](https://github.com/santi-pdp/segan_pytorch) repo. 86 | 87 | ```python python_eval.py -d dns -e ${PATH_TO_DENOISED_SPEECH} -t ${PATH_TO_TESTSET_PATH} >> eval.log``` 88 | 89 | 1 GPU is used for evaluation. 90 | 91 | ## Requirements 92 | 93 | To synthesize [Microsoft DNS 2020](https://arxiv.org/ftp/arxiv/papers/2005/2005.13981.pdf) training data, you need [these dependencies](https://github.com/microsoft/DNS-Challenge/blob/interspeech2020/master/requirements.txt). If you just want to evaluate our pre-trained models on the test data, you may skip this step. 94 | 95 | Our code is tested on 8 NVIDIA V100 GPUs. You only need to install standard dependencies: ```numpy``` and ```scipy``` for scientific computing, ```torch, torchvision, torchaudio``` for deep learning and data loading, ```pesq, pystoi``` for audio evaluation, and ```tqdm``` for visualization. 96 | 97 | ## References 98 | 99 | The code structure and distributed training are adapted from [WaveGlow (PyTorch)](https://github.com/NVIDIA/waveglow) (BSD-3-Clause license). ```stft_loss.py``` is adapted from [ParallelWaveGAN (PyTorch)](https://github.com/kan-bayashi/ParallelWaveGAN) (MIT license). The self-attention blocks in ```network.py``` are adapted from [Attention is all you need (PyTorch)](https://github.com/jadore801120/attention-is-all-you-need-pytorch) (MIT license), which borrows from [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py) (MIT license). The learning rate scheduler in ```util.py``` is adapted from [VQVAE2 (PyTorch)](https://github.com/rosinality/vq-vae-2-pytorch) (MIT license). Some utility functions are borrowed from [DiffWave (PyTorch)](https://github.com/philsyn/DiffWave-Vocoder) (MIT license) and [WaveGlow (PyTorch)](https://github.com/NVIDIA/waveglow) (BSD-3-Clause license). 100 | 101 | For more evaluation methods, we refer readers to [SEGAN (PyTorch)](https://github.com/santi-pdp/segan_pytorch/blob/master/segan/utils.py) (MIT license). For more data augmentation methods, we refer readers to [FAIR-denoiser](https://github.com/facebookresearch/denoiser/blob/main/denoiser/augment.py) (CC-BY-NC 4.0 license). 102 | 103 | ## Citation 104 | 105 | ``` 106 | @inproceedings{kong2022speech, 107 | title={Speech Denoising in the Waveform Domain with Self-Attention}, 108 | author={Kong, Zhifeng and Ping, Wei and Dantrey, Ambrish and Catanzaro, Bryan}, 109 | booktitle={ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 110 | pages={7867--7871}, 111 | year={2022}, 112 | organization={IEEE} 113 | } 114 | ``` 115 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/NVIDIA/waveglow under the BSD 3-Clause License. 2 | 3 | # ***************************************************************************** 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | # * Redistributions of source code must retain the above copyright 9 | # notice, this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # * Neither the name of the NVIDIA CORPORATION nor the 14 | # names of its contributors may be used to endorse or promote products 15 | # derived from this software without specific prior written permission. 16 | # 17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 21 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | # 28 | # ***************************************************************************** 29 | 30 | import os 31 | import sys 32 | import time 33 | import subprocess 34 | import argparse 35 | import warnings 36 | warnings.filterwarnings("ignore") 37 | 38 | import torch 39 | import torch.distributed as dist 40 | from torch.autograd import Variable 41 | 42 | def reduce_tensor(tensor, num_gpus): 43 | rt = tensor.clone() 44 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 45 | rt /= num_gpus 46 | return rt 47 | 48 | def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url): 49 | assert torch.cuda.is_available(), "Distributed mode requires CUDA." 50 | print("Initializing Distributed") 51 | 52 | # Set cuda device so everything is done on the right GPU. 53 | torch.cuda.set_device(rank % torch.cuda.device_count()) 54 | 55 | # Initialize distributed communication 56 | dist.init_process_group(dist_backend, init_method=dist_url, 57 | world_size=num_gpus, rank=rank, 58 | group_name=group_name) 59 | 60 | def _flatten_dense_tensors(tensors): 61 | """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of 62 | same dense type. 63 | Since inputs are dense, the resulting tensor will be a concatenated 1D 64 | buffer. Element-wise operation on this buffer will be equivalent to 65 | operating individually. 66 | Arguments: 67 | tensors (Iterable[Tensor]): dense tensors to flatten. 68 | Returns: 69 | A contiguous 1D buffer containing input tensors. 70 | """ 71 | if len(tensors) == 1: 72 | return tensors[0].contiguous().view(-1) 73 | flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0) 74 | return flat 75 | 76 | def _unflatten_dense_tensors(flat, tensors): 77 | """View a flat buffer using the sizes of tensors. Assume that tensors are of 78 | same dense type, and that flat is given by _flatten_dense_tensors. 79 | Arguments: 80 | flat (Tensor): flattened dense tensors to unflatten. 
81 | tensors (Iterable[Tensor]): dense tensors whose sizes will be used to 82 | unflatten flat. 83 | Returns: 84 | Unflattened dense tensors with sizes same as tensors and values from 85 | flat. 86 | """ 87 | outputs = [] 88 | offset = 0 89 | for tensor in tensors: 90 | numel = tensor.numel() 91 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) 92 | offset += numel 93 | return tuple(outputs) 94 | 95 | def apply_gradient_allreduce(module): 96 | """ 97 | Modifies existing model to do gradient allreduce, but doesn't change class 98 | so you don't need "module" 99 | """ 100 | if not hasattr(dist, '_backend'): 101 | module.warn_on_half = True 102 | else: 103 | module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 104 | 105 | for p in module.state_dict().values(): 106 | if not torch.is_tensor(p): 107 | continue 108 | dist.broadcast(p, 0) 109 | 110 | def allreduce_params(): 111 | if(module.needs_reduction): 112 | module.needs_reduction = False 113 | buckets = {} 114 | for param in module.parameters(): 115 | if param.requires_grad and param.grad is not None: 116 | tp = type(param.data) 117 | if tp not in buckets: 118 | buckets[tp] = [] 119 | buckets[tp].append(param) 120 | if module.warn_on_half: 121 | if torch.cuda.HalfTensor in buckets: 122 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 123 | " It is recommended to use the NCCL backend in this case. This currently requires " + 124 | "PyTorch built from top of tree master.") 125 | module.warn_on_half = False 126 | 127 | for tp in buckets: 128 | bucket = buckets[tp] 129 | grads = [param.grad.data for param in bucket] 130 | coalesced = _flatten_dense_tensors(grads) 131 | dist.all_reduce(coalesced) 132 | coalesced /= dist.get_world_size() 133 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 134 | buf.copy_(synced) 135 | 136 | for param in list(module.parameters()): 137 | def allreduce_hook(*unused): 138 | Variable._execution_engine.queue_callback(allreduce_params) 139 | if param.requires_grad: 140 | param.register_hook(allreduce_hook) 141 | dir(param) 142 | 143 | def set_needs_reduction(self, input, output): 144 | self.needs_reduction = True 145 | 146 | module.register_forward_hook(set_needs_reduction) 147 | return module 148 | 149 | 150 | def main(config, stdout_dir, args_str): 151 | args_list = ['train.py'] 152 | args_list += args_str.split(' ') if len(args_str) > 0 else [] 153 | 154 | args_list.append('--config={}'.format(config)) 155 | 156 | num_gpus = torch.cuda.device_count() 157 | print('num_gpus: {}'.format(num_gpus)) 158 | args_list.append('--num_gpus={}'.format(num_gpus)) 159 | args_list.append("--group_name=group_{}".format(time.strftime("%Y_%m_%d-%H%M%S"))) 160 | 161 | if not os.path.isdir(stdout_dir): 162 | os.makedirs(stdout_dir) 163 | os.chmod(stdout_dir, 0o775) 164 | 165 | workers = [] 166 | 167 | for i in range(num_gpus): 168 | args_list[-2] = '--rank={}'.format(i) 169 | stdout = None if i == 0 else open( 170 | os.path.join(stdout_dir, "GPU_{}.log".format(i)), "w") 171 | print(args_list) 172 | p = subprocess.Popen([str(sys.executable)]+args_list, stdout=stdout) 173 | workers.append(p) 174 | 175 | for p in workers: 176 | p.wait() 177 | 178 | 179 | if __name__ == '__main__': 180 | parser = argparse.ArgumentParser() 181 | parser.add_argument('-c', '--config', type=str, default='config.json', 182 | help='JSON file for configuration') 183 | parser.add_argument('-s', '--stdout_dir', type=str, default="./logs/", 184 | help='directory to
save stdout logs') 185 | parser.add_argument('-a', '--args_str', type=str, default='', 186 | help='double quoted string with space separated key value pairs') 187 | 188 | args = parser.parse_args() 189 | main(args.config, args.stdout_dir, args.args_str) 190 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/NVIDIA/waveglow under the BSD 3-Clause License. 2 | 3 | # ***************************************************************************** 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | # * Redistributions of source code must retain the above copyright 9 | # notice, this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # * Neither the name of the NVIDIA CORPORATION nor the 14 | # names of its contributors may be used to endorse or promote products 15 | # derived from this software without specific prior written permission. 16 | # 17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY 21 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | # 28 | # ***************************************************************************** 29 | 30 | import os 31 | import time 32 | import argparse 33 | import json 34 | 35 | import numpy as np 36 | import torch 37 | import torch.nn as nn 38 | from torch.utils.tensorboard import SummaryWriter 39 | 40 | import random 41 | random.seed(0) 42 | torch.manual_seed(0) 43 | np.random.seed(0) 44 | 45 | 46 | from distributed import init_distributed, apply_gradient_allreduce, reduce_tensor 47 | 48 | from dataset import load_CleanNoisyPairDataset 49 | from stft_loss import MultiResolutionSTFTLoss 50 | from util import rescale, find_max_epoch, print_size 51 | from util import LinearWarmupCosineDecay, loss_fn 52 | 53 | from network import CleanUNet 54 | 55 | 56 | def train(num_gpus, rank, group_name, 57 | exp_path, log, optimization, loss_config): 58 | 59 | # setup local experiment path 60 | if rank == 0: 61 | print('exp_path:', exp_path) 62 | 63 | # Create tensorboard logger.
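# (events are written to <log directory>/<exp_path>/tensorboard; monitor with `tensorboard --logdir ./exp`)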
64 |     log_directory = os.path.join(log["directory"], exp_path)
65 |     if rank == 0:
66 |         tb = SummaryWriter(os.path.join(log_directory, 'tensorboard'))
67 | 
68 |     # distributed running initialization
69 |     if num_gpus > 1:
70 |         init_distributed(rank, num_gpus, group_name, **dist_config)
71 | 
72 |     # Get shared ckpt_directory ready
73 |     ckpt_directory = os.path.join(log_directory, 'checkpoint')
74 |     if rank == 0:
75 |         if not os.path.isdir(ckpt_directory):
76 |             os.makedirs(ckpt_directory)
77 |             os.chmod(ckpt_directory, 0o775)
78 |         print("ckpt_directory: ", ckpt_directory, flush=True)
79 | 
80 |     # load training data
81 |     trainloader = load_CleanNoisyPairDataset(**trainset_config,
82 |                                              subset='training',
83 |                                              batch_size=optimization["batch_size_per_gpu"],
84 |                                              num_gpus=num_gpus)
85 |     print('Data loaded')
86 | 
87 |     # instantiate the model
88 |     net = CleanUNet(**network_config).cuda()
89 |     print_size(net)
90 | 
91 |     # apply gradient allreduce
92 |     if num_gpus > 1:
93 |         net = apply_gradient_allreduce(net)
94 | 
95 |     # define optimizer
96 |     optimizer = torch.optim.Adam(net.parameters(), lr=optimization["learning_rate"])
97 | 
98 |     # load checkpoint
99 |     time0 = time.time()
100 |     if log["ckpt_iter"] == 'max':
101 |         ckpt_iter = find_max_epoch(ckpt_directory)
102 |     else:
103 |         ckpt_iter = log["ckpt_iter"]
104 |     if ckpt_iter >= 0:
105 |         try:
106 |             # load checkpoint file
107 |             model_path = os.path.join(ckpt_directory, '{}.pkl'.format(ckpt_iter))
108 |             checkpoint = torch.load(model_path, map_location='cpu')
109 | 
110 |             # feed model dict and optimizer state
111 |             net.load_state_dict(checkpoint['model_state_dict'])
112 |             optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
113 | 
114 |             # record training time based on elapsed time
115 |             time0 -= checkpoint['training_time_seconds']
116 |             print('Model at iteration %s has been trained for %s seconds' % (ckpt_iter, checkpoint['training_time_seconds']))
117 |             print('checkpoint model loaded successfully')
118 |         except Exception:  # narrower than a bare except, which would also swallow KeyboardInterrupt
119 |             ckpt_iter = -1
120 |             print('No valid checkpoint model found, starting training from initialization.')
121 |     else:
122 |         ckpt_iter = -1
123 |         print('No valid checkpoint model found, starting training from initialization.')
124 | 
125 |     # training
126 |     n_iter = ckpt_iter + 1
127 | 
128 |     # define learning rate scheduler and STFT loss
129 |     scheduler = LinearWarmupCosineDecay(
130 |         optimizer,
131 |         lr_max=optimization["learning_rate"],
132 |         n_iter=optimization["n_iters"],
133 |         iteration=n_iter,
134 |         divider=25,
135 |         warmup_proportion=0.05,
136 |         phase=('linear', 'cosine'),
137 |     )
138 | 
139 |     if loss_config["stft_lambda"] > 0:
140 |         mrstftloss = MultiResolutionSTFTLoss(**loss_config["stft_config"]).cuda()
141 |     else:
142 |         mrstftloss = None
143 | 
144 |     while n_iter < optimization["n_iters"] + 1:
145 |         # for each epoch
146 |         for clean_audio, noisy_audio, _ in trainloader:
147 | 
148 |             clean_audio = clean_audio.cuda()
149 |             noisy_audio = noisy_audio.cuda()
150 | 
151 |             # If you have a data augmentation function augment()
152 |             # noise = noisy_audio - clean_audio
153 |             # noise, clean_audio = augment((noise, clean_audio))
154 |             # noisy_audio = noise + clean_audio
155 | 
156 |             # back-propagation
157 |             optimizer.zero_grad()
158 |             X = (clean_audio, noisy_audio)
159 |             loss, loss_dic = loss_fn(net, X, **loss_config, mrstftloss=mrstftloss)
160 |             if num_gpus > 1:
161 |                 reduced_loss = reduce_tensor(loss.data, num_gpus).item()
162 |             else:
163 |                 reduced_loss = loss.item()
164 |             loss.backward()
165 |             grad_norm = nn.utils.clip_grad_norm_(net.parameters(), 1e9)  # threshold 1e9: effectively just records the gradient norm for logging
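            # LinearWarmupCosineDecay (from util.py) is a per-iteration scheduler: linear warmup, then cosine decay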
166 |             scheduler.step()
167 |             optimizer.step()
168 | 
169 |             # output to log
170 |             if n_iter % log["iters_per_valid"] == 0:
171 |                 print("iteration: {} \treduced loss: {:.7f} \tloss: {:.7f}".format(
172 |                     n_iter, reduced_loss, loss.item()), flush=True)
173 | 
174 |                 if rank == 0:
175 |                     # save to tensorboard
176 |                     tb.add_scalar("Train/Train-Loss", loss.item(), n_iter)
177 |                     tb.add_scalar("Train/Train-Reduced-Loss", reduced_loss, n_iter)
178 |                     tb.add_scalar("Train/Gradient-Norm", grad_norm, n_iter)
179 |                     tb.add_scalar("Train/learning-rate", optimizer.param_groups[0]["lr"], n_iter)
180 | 
181 |             # save checkpoint
182 |             if n_iter > 0 and n_iter % log["iters_per_ckpt"] == 0 and rank == 0:
183 |                 checkpoint_name = '{}.pkl'.format(n_iter)
184 |                 torch.save({'iter': n_iter,
185 |                             'model_state_dict': net.state_dict(),
186 |                             'optimizer_state_dict': optimizer.state_dict(),
187 |                             'training_time_seconds': int(time.time() - time0)},
188 |                            os.path.join(ckpt_directory, checkpoint_name))
189 |                 print('model at iteration %s is saved' % n_iter)
190 | 
191 |             n_iter += 1
192 | 
193 |     # After training, close TensorBoard.
194 |     if rank == 0:
195 |         tb.close()
196 | 
197 |     return 0
198 | 
199 | 
200 | if __name__ == "__main__":
201 |     parser = argparse.ArgumentParser()
202 |     parser.add_argument('-c', '--config', type=str, default='config.json',
203 |                         help='JSON file for configuration')
204 |     parser.add_argument('-r', '--rank', type=int, default=0,
205 |                         help='rank of process for distributed training')
206 |     parser.add_argument('-g', '--group_name', type=str, default='',
207 |                         help='name of group for distributed training')
208 |     args = parser.parse_args()
209 | 
210 |     # Parse configs; module-level globals are convenient here
211 |     with open(args.config) as f:
212 |         data = f.read()
213 |     config = json.loads(data)
214 |     train_config = config["train_config"]        # training parameters
215 |     global dist_config  # note: 'global' is a no-op at module scope; kept from the original
216 |     dist_config = config["dist_config"]          # to initialize distributed training
217 |     global network_config
218 |     network_config = config["network_config"]    # to define network
219 |     global trainset_config
220 |     trainset_config = config["trainset_config"]  # to load trainset
221 | 
222 |     num_gpus = torch.cuda.device_count()
223 |     if num_gpus > 1:
224 |         if args.group_name == '':
225 |             print("WARNING: Multiple GPUs detected but no distributed group set")
226 |             print("Only running 1 GPU. Use distributed.py for multiple GPUs")
227 |             num_gpus = 1
228 | 
229 |     if num_gpus == 1 and args.rank != 0:
230 |         raise Exception("Doing single-GPU training on rank > 0")
231 | 
232 |     torch.backends.cudnn.enabled = True
233 |     torch.backends.cudnn.benchmark = True
234 |     train(num_gpus, args.rank, args.group_name, **train_config)
235 | 
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 NVIDIA CORPORATION.
2 | # Licensed under the MIT license.
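# This module implements the Transformer encoder components (adapted from
# attention-is-all-you-need-pytorch, credited below) and the CleanUNet model.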
3 | 
4 | import numpy as np
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | 
10 | from util import weight_scaling_init
11 | 
12 | 
13 | # Transformer (encoder) https://github.com/jadore801120/attention-is-all-you-need-pytorch
14 | # Original Copyright 2017 Victor Huang
15 | # MIT License (https://opensource.org/licenses/MIT)
16 | 
17 | class ScaledDotProductAttention(nn.Module):
18 |     ''' Scaled Dot-Product Attention '''
19 | 
20 |     def __init__(self, temperature, attn_dropout=0.1):
21 |         super().__init__()
22 |         self.temperature = temperature
23 |         self.dropout = nn.Dropout(attn_dropout)
24 | 
25 |     def forward(self, q, k, v, mask=None):
26 | 
27 |         attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
28 | 
29 |         if mask is not None:
30 |             attn = attn.masked_fill(mask == 0, -1e9)  # block masked positions before the softmax
31 | 
32 |         attn = self.dropout(F.softmax(attn, dim=-1))
33 |         output = torch.matmul(attn, v)
34 | 
35 |         return output, attn
36 | 
37 | 
38 | class MultiHeadAttention(nn.Module):
39 |     ''' Multi-Head Attention module '''
40 | 
41 |     def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
42 |         super().__init__()
43 | 
44 |         self.n_head = n_head
45 |         self.d_k = d_k
46 |         self.d_v = d_v
47 | 
48 |         self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
49 |         self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
50 |         self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
51 |         self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
52 | 
53 |         self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
54 | 
55 |         self.dropout = nn.Dropout(dropout)
56 |         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
57 | 
58 | 
59 |     def forward(self, q, k, v, mask=None):
60 | 
61 |         d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
62 |         sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
63 | 
64 |         residual = q
65 | 
66 |         # Pass through the pre-attention projection: b x lq x (n*dv)
67 |         # Separate different heads: b x lq x n x dv
68 |         q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
69 |         k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
70 |         v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
71 | 
72 |         # Transpose for attention dot product: b x n x lq x dv
73 |         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
74 | 
75 |         if mask is not None:
76 |             mask = mask.unsqueeze(1)   # For head axis broadcasting.
77 | 
78 |         q, attn = self.attention(q, k, v, mask=mask)
79 | 
80 |         # Transpose to move the head dimension back: b x lq x n x dv
81 |         # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
82 |         q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
83 |         q = self.dropout(self.fc(q))
84 |         q += residual
85 | 
86 |         q = self.layer_norm(q)
87 | 
88 |         return q, attn
89 | 
90 | 
91 | class PositionwiseFeedForward(nn.Module):
92 |     ''' A two-layer position-wise feed-forward module '''
93 | 
94 |     def __init__(self, d_in, d_hid, dropout=0.1):
95 |         super().__init__()
96 |         self.w_1 = nn.Linear(d_in, d_hid)  # position-wise
97 |         self.w_2 = nn.Linear(d_hid, d_in)  # position-wise
98 |         self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
99 |         self.dropout = nn.Dropout(dropout)
100 | 
101 |     def forward(self, x):
102 | 
103 |         residual = x
104 | 
105 |         x = self.w_2(F.relu(self.w_1(x)))
106 |         x = self.dropout(x)
107 |         x += residual
108 | 
109 |         x = self.layer_norm(x)
110 | 
111 |         return x
112 | 
113 | 
114 | def get_subsequent_mask(seq):
115 |     ''' For masking out the subsequent info.
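    Returns a (1, len_s, len_s) boolean mask that is True at (i, j) exactly when
    j <= i, so each position may attend only to itself and earlier positions.
    For len_s = 3 the mask is [[1, 0, 0], [1, 1, 0], [1, 1, 1]].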
    '''
116 |     sz_b, len_s = seq.size()
117 |     subsequent_mask = (1 - torch.triu(
118 |         torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool()
119 |     return subsequent_mask
120 | 
121 | 
122 | class PositionalEncoding(nn.Module):
123 | 
124 |     def __init__(self, d_hid, n_position=200):
125 |         super(PositionalEncoding, self).__init__()
126 | 
127 |         # Not a parameter
128 |         self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
129 | 
130 |     def _get_sinusoid_encoding_table(self, n_position, d_hid):
131 |         ''' Sinusoid position encoding table '''
132 |         # TODO: make it with torch instead of numpy
133 | 
134 |         def get_position_angle_vec(position):
135 |             return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
136 | 
137 |         sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
138 |         sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
139 |         sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
140 | 
141 |         return torch.FloatTensor(sinusoid_table).unsqueeze(0)
142 | 
143 |     def forward(self, x):
144 |         return x + self.pos_table[:, :x.size(1)].clone().detach()
145 | 
146 | 
147 | class EncoderLayer(nn.Module):
148 |     ''' A self-attention sub-layer composed with a position-wise feed-forward sub-layer '''
149 | 
150 |     def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.0):
151 |         super(EncoderLayer, self).__init__()
152 |         self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
153 |         self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
154 | 
155 |     def forward(self, enc_input, slf_attn_mask=None):
156 |         enc_output, enc_slf_attn = self.slf_attn(
157 |             enc_input, enc_input, enc_input, mask=slf_attn_mask)
158 |         enc_output = self.pos_ffn(enc_output)
159 |         return enc_output, enc_slf_attn
160 | 
161 | 
162 | class TransformerEncoder(nn.Module):
163 |     ''' An encoder model with a self-attention mechanism.
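    Unlike the original text Transformer, the token-embedding lookup is disabled
    (see the commented-out src_word_emb below): src_seq is already a
    (batch, length, d_model) sequence of continuous features, produced in
    CleanUNet by the 1x1 bottleneck convolution tsfm_conv1.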
    '''
164 | 
165 |     def __init__(
166 |             self, d_word_vec=512, n_layers=2, n_head=8, d_k=64, d_v=64,
167 |             d_model=512, d_inner=2048, dropout=0.1, n_position=624, scale_emb=False):
168 | 
169 |         super().__init__()
170 | 
171 |         # self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
172 |         if n_position > 0:
173 |             self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
174 |         else:
175 |             self.position_enc = lambda x: x
176 |         self.dropout = nn.Dropout(p=dropout)
177 |         self.layer_stack = nn.ModuleList([
178 |             EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
179 |             for _ in range(n_layers)])
180 |         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
181 |         self.scale_emb = scale_emb
182 |         self.d_model = d_model
183 | 
184 |     def forward(self, src_seq, src_mask, return_attns=False):
185 | 
186 |         enc_slf_attn_list = []
187 | 
188 |         # -- Forward
189 |         # enc_output = self.src_word_emb(src_seq)
190 |         enc_output = src_seq
191 |         if self.scale_emb:
192 |             enc_output *= self.d_model ** 0.5
193 |         enc_output = self.dropout(self.position_enc(enc_output))
194 |         enc_output = self.layer_norm(enc_output)
195 | 
196 |         for enc_layer in self.layer_stack:
197 |             enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
198 |             enc_slf_attn_list += [enc_slf_attn] if return_attns else []
199 | 
200 |         if return_attns:
201 |             return enc_output, enc_slf_attn_list
202 |         return enc_output
203 | 
204 | 
205 | # CleanUNet architecture
206 | 
207 | 
208 | def padding(x, D, K, S):
209 |     """Zero-pad x on the right so that D strided convs (kernel K, stride S) followed by D transposed convs reproduce the input length."""
210 | 
211 |     L = x.shape[-1]
212 |     for _ in range(D):
213 |         if L < K:
214 |             L = 1
215 |         else:
216 |             L = 1 + np.ceil((L - K) / S)  # number of frames after one strided conv
217 | 
218 |     for _ in range(D):
219 |         L = (L - 1) * S + K  # length reconstructed by one transposed conv
220 | 
221 |     L = int(L)
222 |     x = F.pad(x, (0, L - x.shape[-1]))
223 |     return x
224 | 
225 | 
226 | class CleanUNet(nn.Module):
227 |     """ CleanUNet architecture.
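    A U-Net on raw waveforms: encoder_n_layers strided Conv1d+GLU blocks downsample
    the signal, a TransformerEncoder with a causal self-attention mask processes
    the bottleneck, and mirrored ConvTranspose1d blocks upsample it back, adding
    the cached encoder outputs as skip connections along the way.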
""" 228 | 229 | def __init__(self, channels_input=1, channels_output=1, 230 | channels_H=64, max_H=768, 231 | encoder_n_layers=8, kernel_size=4, stride=2, 232 | tsfm_n_layers=3, 233 | tsfm_n_head=8, 234 | tsfm_d_model=512, 235 | tsfm_d_inner=2048): 236 | 237 | """ 238 | Parameters: 239 | channels_input (int): input channels 240 | channels_output (int): output channels 241 | channels_H (int): middle channels H that controls capacity 242 | max_H (int): maximum H 243 | encoder_n_layers (int): number of encoder/decoder layers D 244 | kernel_size (int): kernel size K 245 | stride (int): stride S 246 | tsfm_n_layers (int): number of self attention blocks N 247 | tsfm_n_head (int): number of heads in each self attention block 248 | tsfm_d_model (int): d_model of self attention 249 | tsfm_d_inner (int): d_inner of self attention 250 | """ 251 | 252 | super(CleanUNet, self).__init__() 253 | 254 | self.channels_input = channels_input 255 | self.channels_output = channels_output 256 | self.channels_H = channels_H 257 | self.max_H = max_H 258 | self.encoder_n_layers = encoder_n_layers 259 | self.kernel_size = kernel_size 260 | self.stride = stride 261 | 262 | self.tsfm_n_layers = tsfm_n_layers 263 | self.tsfm_n_head = tsfm_n_head 264 | self.tsfm_d_model = tsfm_d_model 265 | self.tsfm_d_inner = tsfm_d_inner 266 | 267 | # encoder and decoder 268 | self.encoder = nn.ModuleList() 269 | self.decoder = nn.ModuleList() 270 | 271 | for i in range(encoder_n_layers): 272 | self.encoder.append(nn.Sequential( 273 | nn.Conv1d(channels_input, channels_H, kernel_size, stride), 274 | nn.ReLU(), 275 | nn.Conv1d(channels_H, channels_H * 2, 1), 276 | nn.GLU(dim=1) 277 | )) 278 | channels_input = channels_H 279 | 280 | if i == 0: 281 | # no relu at end 282 | self.decoder.append(nn.Sequential( 283 | nn.Conv1d(channels_H, channels_H * 2, 1), 284 | nn.GLU(dim=1), 285 | nn.ConvTranspose1d(channels_H, channels_output, kernel_size, stride) 286 | )) 287 | else: 288 | self.decoder.insert(0, nn.Sequential( 289 | nn.Conv1d(channels_H, channels_H * 2, 1), 290 | nn.GLU(dim=1), 291 | nn.ConvTranspose1d(channels_H, channels_output, kernel_size, stride), 292 | nn.ReLU() 293 | )) 294 | channels_output = channels_H 295 | 296 | # double H but keep below max_H 297 | channels_H *= 2 298 | channels_H = min(channels_H, max_H) 299 | 300 | # self attention block 301 | self.tsfm_conv1 = nn.Conv1d(channels_output, tsfm_d_model, kernel_size=1) 302 | self.tsfm_encoder = TransformerEncoder(d_word_vec=tsfm_d_model, 303 | n_layers=tsfm_n_layers, 304 | n_head=tsfm_n_head, 305 | d_k=tsfm_d_model // tsfm_n_head, 306 | d_v=tsfm_d_model // tsfm_n_head, 307 | d_model=tsfm_d_model, 308 | d_inner=tsfm_d_inner, 309 | dropout=0.0, 310 | n_position=0, 311 | scale_emb=False) 312 | self.tsfm_conv2 = nn.Conv1d(tsfm_d_model, channels_output, kernel_size=1) 313 | 314 | # weight scaling initialization 315 | for layer in self.modules(): 316 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): 317 | weight_scaling_init(layer) 318 | 319 | def forward(self, noisy_audio): 320 | # (B, L) -> (B, C, L) 321 | if len(noisy_audio.shape) == 2: 322 | noisy_audio = noisy_audio.unsqueeze(1) 323 | B, C, L = noisy_audio.shape 324 | assert C == 1 325 | 326 | # normalization and padding 327 | std = noisy_audio.std(dim=2, keepdim=True) + 1e-3 328 | noisy_audio /= std 329 | x = padding(noisy_audio, self.encoder_n_layers, self.kernel_size, self.stride) 330 | 331 | # encoder 332 | skip_connections = [] 333 | for downsampling_block in self.encoder: 334 | x = downsampling_block(x) 335 | 
335 |             skip_connections.append(x)
336 |         skip_connections = skip_connections[::-1]
337 | 
338 |         # attention mask for causal inference; for non-causal (bidirectional) attention, set attn_mask to None
339 |         len_s = x.shape[-1]  # length at bottleneck
340 |         attn_mask = (1 - torch.triu(torch.ones((1, len_s, len_s), device=x.device), diagonal=1)).bool()
341 | 
342 |         x = self.tsfm_conv1(x)  # bottleneck channels (768 with the default max_H=768) -> d_model=512
343 |         x = x.permute(0, 2, 1)
344 |         x = self.tsfm_encoder(x, src_mask=attn_mask)
345 |         x = x.permute(0, 2, 1)
346 |         x = self.tsfm_conv2(x)  # d_model=512 -> bottleneck channels (768 by default)
347 | 
348 |         # decoder
349 |         for i, upsampling_block in enumerate(self.decoder):
350 |             skip_i = skip_connections[i]
351 |             x += skip_i[:, :, :x.shape[-1]]
352 |             x = upsampling_block(x)
353 | 
354 |         x = x[:, :, :L] * std  # trim the padding and undo the input normalization
355 |         return x
356 | 
357 | 
358 | if __name__ == '__main__':
359 |     import json
360 |     import argparse
361 |     import os
362 | 
363 |     parser = argparse.ArgumentParser()
364 |     parser.add_argument('-c', '--config', type=str, default='configs/DNS-large-full.json',
365 |                         help='JSON file for configuration')
366 |     args = parser.parse_args()
367 | 
368 |     with open(args.config) as f:
369 |         data = f.read()
370 |     config = json.loads(data)
371 |     network_config = config["network_config"]
372 | 
373 |     model = CleanUNet(**network_config).cuda()
374 |     from util import print_size
375 |     print_size(model, keyword="tsfm")
376 | 
377 |     input_data = torch.ones([4, 1, int(4.5 * 16000)]).cuda()
378 |     output = model(input_data)
379 |     print(output.shape)
380 | 
381 |     y = torch.rand([4, 1, int(4.5 * 16000)]).cuda()
382 |     loss = torch.nn.MSELoss()(y, output)
383 |     loss.backward()
384 |     print(loss.item())
385 | 
--------------------------------------------------------------------------------