├── DCUNet ├── __init__.py ├── complex_nn.py ├── constant.py ├── criterion.py ├── inference.py ├── metric.py ├── noisedataset.py ├── sedataset.py ├── source_separator.py ├── unet.py └── utils.py ├── README.md ├── assets ├── estimated │ ├── p232_001.flac │ ├── p232_005.flac │ ├── p257_001.flac │ └── p257_006.flac ├── gt │ ├── p232_001.flac │ ├── p232_005.flac │ ├── p257_001.flac │ └── p257_006.flac ├── images │ └── melspectrogram.png └── noisy │ ├── p232_001.flac │ ├── p232_005.flac │ ├── p257_001.flac │ └── p257_006.flac ├── downsample.sh ├── estimate_directory.py ├── evaluation.py ├── train.py └── train_dcunet.py /DCUNet/__init__.py: -------------------------------------------------------------------------------- 1 | from .constant import * 2 | from .complex_nn import * 3 | from .unet import * -------------------------------------------------------------------------------- /DCUNet/complex_nn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class ComplexConv2d(nn.Module): 6 | # https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py 7 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, **kwargs): 8 | super().__init__() 9 | 10 | ## Model components 11 | self.conv_re = nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride, padding=padding, 12 | dilation=dilation, groups=groups, bias=bias, **kwargs) 13 | self.conv_im = nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride, padding=padding, 14 | dilation=dilation, groups=groups, bias=bias, **kwargs) 15 | 16 | def forward(self, x): # shpae of x : [batch,channel,axis1,axis2,2] 17 | real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1]) 18 | imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0]) 19 | output = torch.stack((real, imaginary), dim=-1) 20 | return output 21 | 22 | 23 | class ComplexConvTranspose2d(nn.Module): 24 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True, **kwargs): 25 | super().__init__() 26 | 27 | ## Model components 28 | self.tconv_re = nn.ConvTranspose2d(in_channel, out_channel, 29 | kernel_size=kernel_size, 30 | stride=stride, 31 | padding=padding, 32 | output_padding=output_padding, 33 | groups=groups, 34 | bias=bias, 35 | dilation=dilation, 36 | **kwargs) 37 | self.tconv_im = nn.ConvTranspose2d(in_channel, out_channel, 38 | kernel_size=kernel_size, 39 | stride=stride, 40 | padding=padding, 41 | output_padding=output_padding, 42 | groups=groups, 43 | bias=bias, 44 | dilation=dilation, 45 | **kwargs) 46 | 47 | def forward(self, x): # shpae of x : [batch,channel,axis1,axis2,2] 48 | real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1]) 49 | imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0]) 50 | output = torch.stack((real, imaginary), dim=-1) 51 | return output 52 | 53 | 54 | class ComplexBatchNorm2d(nn.Module): 55 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 56 | track_running_stats=True, **kwargs): 57 | super().__init__() 58 | self.bn_re = nn.BatchNorm2d(num_features=num_features, momentum=momentum, affine=affine, eps=eps, track_running_stats=track_running_stats, **kwargs) 59 | self.bn_im = nn.BatchNorm2d(num_features=num_features, momentum=momentum, affine=affine, eps=eps, track_running_stats=track_running_stats, **kwargs) 60 | 61 | def forward(self, x): 62 | real = 
self.bn_re(x[..., 0]) 63 | imag = self.bn_im(x[..., 1]) 64 | output = torch.stack((real, imag), dim=-1) 65 | return output 66 | 67 | -------------------------------------------------------------------------------- /DCUNet/constant.py: -------------------------------------------------------------------------------- 1 | SAMPLE_RATE = 16000 2 | N_FFT = SAMPLE_RATE * 64 // 1000 3 | NUM_MELS = 229 4 | MIN_FREQ = 10 5 | MAX_FREQ = SAMPLE_RATE // 2 6 | HOP_LENGTH = SAMPLE_RATE * 16 // 1000 -------------------------------------------------------------------------------- /DCUNet/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class WeightedSDR: 5 | def __init__(self): 6 | self.loss = weighted_signal_distortion_ratio_loss 7 | 8 | def __call__(self, output, bd): 9 | return self.loss(output, bd) 10 | 11 | 12 | def dotproduct(y, y_hat): 13 | # batch x channel x nsamples 14 | return torch.bmm(y.view(y.shape[0], 1, y.shape[-1]), y_hat.view(y_hat.shape[0], y_hat.shape[-1], 1)).reshape(-1) 15 | 16 | 17 | def weighted_signal_distortion_ratio_loss(output, bd): 18 | y = bd['y'] # target signal 19 | z = bd['z'] # noise signal 20 | 21 | y_hat = output 22 | z_hat = bd['x'] - y_hat # expected noise signal 23 | 24 | # mono channel only... 25 | # can i fix this? 26 | y_norm = torch.norm(y, dim=-1).squeeze(1) 27 | z_norm = torch.norm(z, dim=-1).squeeze(1) 28 | y_hat_norm = torch.norm(y_hat, dim=-1).squeeze(1) 29 | z_hat_norm = torch.norm(z_hat, dim=-1).squeeze(1) 30 | 31 | def loss_sdr(a, a_hat, a_norm, a_hat_norm): 32 | return dotproduct(a, a_hat) / (a_norm * a_hat_norm + 1e-8) 33 | 34 | alpha = y_norm.pow(2) / (y_norm.pow(2) + z_norm.pow(2) + 1e-8) 35 | loss_wSDR = -alpha * loss_sdr(y, y_hat, y_norm, y_hat_norm) - (1 - alpha) * loss_sdr(z, z_hat, z_norm, z_hat_norm) 36 | 37 | return loss_wSDR.mean() 38 | -------------------------------------------------------------------------------- /DCUNet/inference.py: -------------------------------------------------------------------------------- 1 | from .constant import * 2 | from .utils import load_audio 3 | 4 | def any_audio_inference(path, net, sequence_length=None, normalize=True): 5 | audio = load_audio(path, SAMPLE_RATE)['audio'] 6 | y_hat = net.inference_one_audio(audio, normalize=normalize) 7 | return y_hat.cpu().numpy() 8 | -------------------------------------------------------------------------------- /DCUNet/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.functional as F 3 | 4 | from .constant import * 5 | from .utils import istft, realimag 6 | from pypesq import pesq 7 | 8 | 9 | class PESQ: 10 | def __init__(self): 11 | self.pesq = pesq_metric 12 | 13 | def __call__(self, output, bd): 14 | return self.pesq(output, bd) 15 | 16 | 17 | def pesq_metric(y_hat, bd): 18 | # PESQ 19 | with torch.no_grad(): 20 | y_hat = y_hat.cpu().numpy() 21 | y = bd['y'].cpu().numpy() # target signal 22 | 23 | sum = 0 24 | for i in range(len(y)): 25 | sum += pesq(y[i, 0], y_hat[i, 0], SAMPLE_RATE) 26 | 27 | sum /= len(y) 28 | return torch.tensor(sum) -------------------------------------------------------------------------------- /DCUNet/noisedataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from tqdm import tqdm 4 | 5 | import numpy as np 6 | from .constant import * 7 | from .utils import load_audio, cut_padding 8 | 9 | 10 | 
class NoiseDataset(Dataset): 11 | def __init__(self, signals, noises, 12 | seed=0, 13 | sequence_length=16384, 14 | is_validation=False, 15 | snr_range=(-10, 20), 16 | preload=False): 17 | 18 | super(self.__class__, self).__init__() 19 | self.signals = signals # ['path/001.wav', 'path/002.flac', ... ] 20 | self.noises = noises # ['path/a.wav', 'path/b.flac', ... ] 21 | self.is_validation = is_validation 22 | self.snr_range = snr_range 23 | self.sequence_length = sequence_length 24 | self.random = np.random.RandomState(seed) 25 | self.preload = preload # load wav files on RAM. It reduces Disk I/O, but may consume huge memory. 26 | 27 | print("Got", len(signals), "signals And", len(noises), "noises.") 28 | 29 | if self.preload: 30 | self.data_y = [] 31 | print("Loading Signal Data") 32 | for signal in tqdm(self.signals): 33 | self.data_y.append(load_audio(signal, SAMPLE_RATE, assert_sr=True, channel=1)) 34 | 35 | self.data_z = [] 36 | print("Loading Noise Data") 37 | for noise in tqdm(self.noises): 38 | self.data_z.append(load_audio(noise, SAMPLE_RATE, assert_sr=True, channel=1)) 39 | 40 | def __len__(self): 41 | return len(self.signals) 42 | 43 | def __getitem__(self, idx): 44 | 45 | if self.preload: 46 | y = self.data_y[idx]['audio'] # channel x samples 47 | noise_idx = self.random.randint(len(self.data_z)) # channel x samples 48 | z = self.data_z[noise_idx]['audio'] 49 | 50 | else: 51 | y = load_audio(self.signals[idx], SAMPLE_RATE, assert_sr=True, channel=1)['audio'] 52 | noise_idx = self.random.randint(len(self.noises)) # channel x samples 53 | 54 | # Pitch augmentation 55 | if not self.is_validation and self.random.uniform(0., 1.) < 0.3: 56 | pitch = np.random.uniform(0.6, 1.3) 57 | z = load_audio(self.noises[noise_idx], int(SAMPLE_RATE*pitch), assert_sr=False, channel=1)['audio'] 58 | else: 59 | z = load_audio(self.noises[noise_idx], SAMPLE_RATE, assert_sr=True, channel=1)['audio'] 60 | 61 | if self.sequence_length is not None: 62 | y = cut_padding(y, self.sequence_length, self.random, self.is_validation) 63 | 64 | power_y = y.pow(2).mean(dim=-1).squeeze(0) 65 | power_z = z.pow(2).mean(dim=-1).squeeze(0) 66 | # TODO SNR에 맞춰 볼륨조절 어딘가로 모듈화, 제대로된 validation을 위해 일정하게 SNR 합성하기 67 | target_SNR = self.random.randint(*self.snr_range) 68 | noise_factor = torch.sqrt(power_y / (power_z) / (10 ** (target_SNR / 20))) 69 | 70 | audio_length = y.shape[-1] 71 | noise_length = z.shape[-1] 72 | if noise_length < audio_length: 73 | z = cut_padding(z, audio_length, self.random, self.is_validation) 74 | noise_length = z.shape[-1] 75 | 76 | if self.is_validation: 77 | noise_begin = 0 78 | else: 79 | noise_begin = self.random.randint(noise_length - audio_length + 1) 80 | 81 | noise_end = noise_begin + audio_length 82 | z = z[:, noise_begin:noise_end] 83 | 84 | z *= noise_factor 85 | 86 | x = y + z 87 | 88 | x_max = x.max(dim=-1)[0].view(x.shape[0], -1) 89 | x_min = x.min(dim=-1)[0].view(x.shape[0], -1) 90 | 91 | # Inverse : x = x + 1 (x + x_min ) / 2 92 | x = 2 * (x - x_min) / (x_max - x_min) - 1. 93 | y = 2 * (y - x_min) / (x_max - x_min) - 1. 94 | z = 2 * (z - x_min) / (x_max - x_min) - 1. 
95 | 96 | rt = dict(x=x, 97 | y=y, 98 | z=z, 99 | x_max=x_max, 100 | x_min=x_min, 101 | power_y=power_y, 102 | power_z=power_z, 103 | SNR=target_SNR) 104 | # signal_path=self.data_y[idx]['path'], 105 | # noise_path=self.data_z[noise_idx]['path']) 106 | 107 | return rt 108 | -------------------------------------------------------------------------------- /DCUNet/sedataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from tqdm import tqdm 4 | 5 | import numpy as np 6 | from .constant import * 7 | from .utils import load_audio, cut_padding 8 | 9 | """ 10 | Note: 11 | Target Dataset: https://datashare.is.ed.ac.uk/handle/10283/2791 12 | 13 | psquare 14 | bus 15 | cafe 16 | living 17 | office 18 | """ 19 | class SEDataset(Dataset): 20 | def __init__(self, signals, mixtures, 21 | seed=0, 22 | sequence_length=16384, 23 | is_validation=False, 24 | preload=False, 25 | ): 26 | 27 | super(self.__class__, self).__init__() 28 | self.signals = signals 29 | self.mixtures = mixtures 30 | self.is_validation = is_validation 31 | self.sequence_length = sequence_length 32 | self.random = np.random.RandomState(seed) 33 | self.preload = preload 34 | 35 | print("Got", len(signals), "signals and", len(mixtures), "mixtures.") 36 | 37 | if self.preload: 38 | self.data_y = [] 39 | print("Loading Signal Data") 40 | for signal in tqdm(self.signals): 41 | self.data_y.append(load_audio(signal, SAMPLE_RATE, assert_sr=True, channel=1)) 42 | 43 | self.data_x = [] 44 | print("Loading Mixture Data") 45 | for noise in tqdm(self.mixtures): 46 | self.data_x.append(load_audio(noise, SAMPLE_RATE, assert_sr=True, channel=1)) 47 | 48 | def __len__(self): 49 | return len(self.signals) 50 | 51 | def __getitem__(self, idx): 52 | if self.preload: 53 | x = self.data_x[idx]['audio'] 54 | y = self.data_y[idx]['audio'] # channel x samples 55 | z = x - y 56 | else: 57 | x = load_audio(self.mixtures[idx], SAMPLE_RATE, assert_sr=True, channel=1)['audio'] 58 | y = load_audio(self.signals[idx], SAMPLE_RATE, assert_sr=True, channel=1)['audio'] 59 | z = x - y 60 | 61 | if self.sequence_length is not None: 62 | x, y, z = cut_padding([x, y, z], self.sequence_length, self.random, deterministic=self.is_validation) 63 | 64 | x_max = x.max(dim=-1)[0].view(x.shape[0], -1) 65 | x_min = x.min(dim=-1)[0].view(x.shape[0], -1) 66 | 67 | # Inverse : x = x + 1 (x + x_min ) / 2 68 | x = 2 * (x - x_min) / (x_max - x_min) - 1. 69 | y = 2 * (y - x_min) / (x_max - x_min) - 1. 70 | z = 2 * (z - x_min) / (x_max - x_min) - 1. 71 | 72 | rt = dict(x=x, 73 | y=y, 74 | z=z, 75 | x_max=x_max, 76 | x_min=x_min) 77 | 78 | return rt -------------------------------------------------------------------------------- /DCUNet/source_separator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchaudio_contrib as audio_nn 5 | import numpy as np 6 | from .constant import * 7 | from .utils import realimag, istft, cut_padding 8 | from .unet import UNet 9 | 10 | 11 | class SourceSeparator(nn.Module): 12 | def __init__(self, complex, model_complexity, model_depth, log_amp, padding_mode): 13 | """ 14 | :param complex: Whether to use complex networks. 
15 | :param model_complexity: 16 | :param model_depth: Only two options are available : 10, 20 17 | :param log_amp: Whether to use log amplitude to estimate signals 18 | :param padding_mode: Encoder's convolution filter. 'zeros', 'reflect' 19 | """ 20 | super().__init__() 21 | self.net = nn.Sequential( 22 | STFT(complex=complex, log_amp=log_amp), 23 | UNet(1, complex=complex, model_complexity=model_complexity, model_depth=model_depth, padding_mode=padding_mode), 24 | ApplyMask(complex=complex, log_amp=log_amp), 25 | ISTFT(complex=complex, log_amp=log_amp) 26 | ) 27 | 28 | def forward(self, x, istft=True): 29 | if istft: 30 | return self.net(x) 31 | else: 32 | x = self.net[0](x) 33 | x = self.net[1](x) 34 | x = self.net[2](x) 35 | return x 36 | 37 | def inference_one_audio(self, audio, normalize=True): 38 | """ 39 | :param audio: channel x samples (tensor, float) 40 | :return: 41 | """ 42 | audict = SourceSeparator.preprocess_audio(audio, sequence_length=16384) 43 | with torch.no_grad(): 44 | for k, v in audict.items(): 45 | audict[k] = v.unsqueeze(1).cuda() 46 | Y_hat = self.forward(audict, istft=False).squeeze(1) 47 | y_hat = istft(Y_hat, HOP_LENGTH, length=audio.shape[-1]) 48 | if normalize: 49 | mx = y_hat.max(dim=-1)[0].view(y_hat.shape[0], -1) 50 | mn = y_hat.min(dim=-1)[0].view(y_hat.shape[0], -1) 51 | y_hat = 2 * (y_hat - mn) / (mx - mn) - 1. 52 | return y_hat 53 | 54 | @staticmethod 55 | def preprocess_audio(x, sequence_length=None): 56 | assert sequence_length is not None 57 | audio_length = x.shape[-1] 58 | 59 | if sequence_length is not None: 60 | if audio_length % sequence_length > 0: 61 | target_length = (audio_length // sequence_length + 1) * sequence_length 62 | else: 63 | target_length = audio_length 64 | 65 | x = cut_padding(x, target_length, np.random.RandomState(0), deterministic=True) 66 | 67 | x_max = x.max(dim=-1)[0].view(x.shape[0], -1) 68 | x_min = x.min(dim=-1)[0].view(x.shape[0], -1) 69 | x = 2 * (x - x_min) / (x_max - x_min) - 1. 70 | 71 | rt = dict(x=x, 72 | x_max=x_max, 73 | x_min=x_min) 74 | 75 | return rt 76 | 77 | 78 | class STFT(nn.Module): 79 | def __init__(self, complex=True, log_amp=False): 80 | super(self.__class__, self).__init__() 81 | self.stft = audio_nn.STFT(fft_length=N_FFT, hop_length=HOP_LENGTH) 82 | self.amp2db = audio_nn.AmplitudeToDb() 83 | 84 | self.complex = complex 85 | self.log_amp = log_amp 86 | window = torch.hann_window(N_FFT) 87 | self.register_buffer('window', window) 88 | 89 | def forward(self, bd): 90 | with torch.no_grad(): 91 | bd['X'] = self.stft(bd['x']) 92 | 93 | if not self.complex: 94 | bd['mag_X'], bd['phase_X'] = audio_nn.magphase(bd['X'], power=1.) 
95 | if self.log_amp: 96 | bd['X'] = self.amp2db(bd['X']) 97 | return bd 98 | 99 | 100 | class ApplyMask(nn.Module): 101 | def __init__(self, complex=True, log_amp=False): 102 | super().__init__() 103 | self.amp2db = audio_nn.DbToAmplitude() 104 | self.complex = complex 105 | self.log_amp = log_amp 106 | 107 | def forward(self, bd): 108 | if not self.complex: 109 | Y_hat = bd['mag_X'] * bd['M_hat'] 110 | Y_hat = realimag(Y_hat, bd['phase_X']) 111 | if self.log_amp: 112 | raise NotImplementedError 113 | else: 114 | Y_hat = bd['X'] * bd['M_hat'] 115 | if self.log_amp: 116 | Y_hat = self.amp2db(Y_hat) 117 | 118 | return Y_hat 119 | 120 | 121 | class ISTFT(nn.Module): 122 | def __init__(self, complex=True, log_amp=False, length=16384): 123 | super().__init__() 124 | self.amp2db = audio_nn.DbToAmplitude() 125 | self.complex = complex 126 | self.log_amp = log_amp 127 | self.length = length 128 | 129 | def forward(self, Y_hat): 130 | # Y_hat : batch x channel x freq x time x 2 131 | num_batch = Y_hat.shape[0] 132 | num_channel = Y_hat.shape[1] 133 | Y_hat = Y_hat.view(Y_hat.shape[0] * Y_hat.shape[1], Y_hat.shape[2], Y_hat.shape[3], Y_hat.shape[4]) 134 | y_hat = istft(Y_hat, hop_length=HOP_LENGTH, win_length=N_FFT, length=self.length) # expected target signal 135 | y_hat = y_hat.view(num_batch, num_channel, -1) 136 | 137 | return y_hat 138 | 139 | -------------------------------------------------------------------------------- /DCUNet/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import DCUNet.complex_nn as complex_nn 5 | 6 | 7 | class Encoder(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding=None, complex=False, padding_mode="zeros"): 9 | super().__init__() 10 | if padding is None: 11 | padding = [(i - 1) // 2 for i in kernel_size] # 'SAME' padding 12 | 13 | if complex: 14 | conv = complex_nn.ComplexConv2d 15 | bn = complex_nn.ComplexBatchNorm2d 16 | else: 17 | conv = nn.Conv2d 18 | bn = nn.BatchNorm2d 19 | 20 | self.conv = conv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode) 21 | self.bn = bn(out_channels) 22 | self.relu = nn.LeakyReLU(inplace=True) 23 | 24 | def forward(self, x): 25 | x = self.conv(x) 26 | x = self.bn(x) 27 | x = self.relu(x) 28 | return x 29 | 30 | 31 | class Decoder(nn.Module): 32 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding=(0, 0), complex=False): 33 | super().__init__() 34 | if complex: 35 | tconv = complex_nn.ComplexConvTranspose2d 36 | bn = complex_nn.ComplexBatchNorm2d 37 | else: 38 | tconv = nn.ConvTranspose2d 39 | bn = nn.BatchNorm2d 40 | 41 | self.transconv = tconv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) 42 | self.bn = bn(out_channels) 43 | self.relu = nn.LeakyReLU(inplace=True) 44 | 45 | def forward(self, x): 46 | x = self.transconv(x) 47 | x = self.bn(x) 48 | x = self.relu(x) 49 | return x 50 | 51 | 52 | class UNet(nn.Module): 53 | def __init__(self, input_channels=1, 54 | complex=False, 55 | model_complexity=45, 56 | model_depth=20, 57 | padding_mode="zeros"): 58 | super().__init__() 59 | 60 | if complex: 61 | model_complexity = int(model_complexity // 1.414) 62 | 63 | self.set_size(model_complexity=model_complexity, input_channels=input_channels, model_depth=model_depth) 64 | self.encoders = [] 65 | self.model_length = model_depth // 2 66 | 67 | for i in 
range(self.model_length): 68 | module = Encoder(self.enc_channels[i], self.enc_channels[i + 1], kernel_size=self.enc_kernel_sizes[i], 69 | stride=self.enc_strides[i], padding=self.enc_paddings[i], complex=complex, padding_mode=padding_mode) 70 | self.add_module("encoder{}".format(i), module) 71 | self.encoders.append(module) 72 | 73 | self.decoders = [] 74 | 75 | for i in range(self.model_length): 76 | module = Decoder(self.dec_channels[i] + self.enc_channels[self.model_length - i], self.dec_channels[i + 1], kernel_size=self.dec_kernel_sizes[i], 77 | stride=self.dec_strides[i], padding=self.dec_paddings[i], complex=complex) 78 | self.add_module("decoder{}".format(i), module) 79 | self.decoders.append(module) 80 | 81 | if complex: 82 | conv = complex_nn.ComplexConv2d 83 | else: 84 | conv = nn.Conv2d 85 | 86 | linear = conv(self.dec_channels[-1], 1, 1) 87 | 88 | self.add_module("linear", linear) 89 | self.complex = complex 90 | self.padding_mode = padding_mode 91 | 92 | self.decoders = nn.ModuleList(self.decoders) 93 | self.encoders = nn.ModuleList(self.encoders) 94 | 95 | def forward(self, bd): 96 | if self.complex: 97 | x = bd['X'] 98 | else: 99 | x = bd['mag_X'] 100 | # go down 101 | xs = [] 102 | for i, encoder in enumerate(self.encoders): 103 | xs.append(x) 104 | #print("x{}".format(i), x.shape) 105 | x = encoder(x) 106 | # xs : x0=input x1 ... x9 107 | 108 | #print(x.shape) 109 | p = x 110 | for i, decoder in enumerate(self.decoders): 111 | p = decoder(p) 112 | if i == self.model_length - 1: 113 | break 114 | #print(f"p{i}, {p.shape} + x{self.model_length - 1 - i}, {xs[self.model_length - 1 -i].shape}, padding {self.dec_paddings[i]}") 115 | p = torch.cat([p, xs[self.model_length - 1 - i]], dim=1) 116 | 117 | #print(p.shape) 118 | mask = self.linear(p) 119 | mask = torch.tanh(mask) 120 | bd['M_hat'] = mask 121 | return bd 122 | 123 | def set_size(self, model_complexity, model_depth=20, input_channels=1): 124 | if model_depth == 10: 125 | self.enc_channels = [input_channels, 126 | model_complexity, 127 | model_complexity * 2, 128 | model_complexity * 2, 129 | model_complexity * 2, 130 | model_complexity * 2, 131 | ] 132 | self.enc_kernel_sizes = [(7, 5), 133 | (7, 5), 134 | (5, 3), 135 | (5, 3), 136 | (5, 3)] 137 | self.enc_strides = [(2, 2), 138 | (2, 2), 139 | (2, 2), 140 | (2, 2), 141 | (2, 1)] 142 | self.enc_paddings = [(2, 1), 143 | None, 144 | None, 145 | None, 146 | None] 147 | 148 | self.dec_channels = [0, 149 | model_complexity * 2, 150 | model_complexity * 2, 151 | model_complexity * 2, 152 | model_complexity * 2, 153 | model_complexity * 2] 154 | 155 | self.dec_kernel_sizes = [(4, 3), 156 | (4, 4), 157 | (6, 4), 158 | (6, 4), 159 | (7, 5)] 160 | 161 | self.dec_strides = [(2, 1), 162 | (2, 2), 163 | (2, 2), 164 | (2, 2), 165 | (2, 2)] 166 | 167 | self.dec_paddings = [(1, 1), 168 | (1, 1), 169 | (2, 1), 170 | (2, 1), 171 | (2, 1)] 172 | 173 | elif model_depth == 20: 174 | self.enc_channels = [input_channels, 175 | model_complexity, 176 | model_complexity, 177 | model_complexity * 2, 178 | model_complexity * 2, 179 | model_complexity * 2, 180 | model_complexity * 2, 181 | model_complexity * 2, 182 | model_complexity * 2, 183 | model_complexity * 2, 184 | 128] 185 | 186 | self.enc_kernel_sizes = [(7, 1), 187 | (1, 7), 188 | (6, 4), 189 | (7, 5), 190 | (5, 3), 191 | (5, 3), 192 | (5, 3), 193 | (5, 3), 194 | (5, 3), 195 | (5, 3)] 196 | 197 | self.enc_strides = [(1, 1), 198 | (1, 1), 199 | (2, 2), 200 | (2, 1), 201 | (2, 2), 202 | (2, 1), 203 | (2, 2), 204 | (2, 1), 205 | (2, 2), 206 
| (2, 1)] 207 | 208 | self.enc_paddings = [(3, 0), 209 | (0, 3), 210 | None, 211 | None, 212 | None, 213 | None, 214 | None, 215 | None, 216 | None, 217 | None] 218 | 219 | self.dec_channels = [0, 220 | model_complexity * 2, 221 | model_complexity * 2, 222 | model_complexity * 2, 223 | model_complexity * 2, 224 | model_complexity * 2, 225 | model_complexity * 2, 226 | model_complexity * 2, 227 | model_complexity * 2, 228 | model_complexity * 2, 229 | model_complexity * 2, 230 | model_complexity * 2] 231 | 232 | self.dec_kernel_sizes = [(4, 3), 233 | (4, 2), 234 | (4, 3), 235 | (4, 2), 236 | (4, 3), 237 | (4, 2), 238 | (6, 3), 239 | (7, 5), 240 | (1, 7), 241 | (7, 1)] 242 | 243 | self.dec_strides = [(2, 1), 244 | (2, 2), 245 | (2, 1), 246 | (2, 2), 247 | (2, 1), 248 | (2, 2), 249 | (2, 1), 250 | (2, 2), 251 | (1, 1), 252 | (1, 1)] 253 | 254 | self.dec_paddings = [(1, 1), 255 | (1, 0), 256 | (1, 1), 257 | (1, 0), 258 | (1, 1), 259 | (1, 0), 260 | (2, 1), 261 | (2, 1), 262 | (0, 3), 263 | (3, 0)] 264 | else: 265 | raise ValueError("Unknown model depth : {}".format(model_depth)) 266 | -------------------------------------------------------------------------------- /DCUNet/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import soundfile 4 | import librosa 5 | import torch.nn.functional as F 6 | 7 | 8 | def cut_padding(y, required_length, random_state, deterministic=False): 9 | 10 | if isinstance(y, list): 11 | audio_length = y[0].shape[-1] 12 | else: 13 | audio_length = y.shape[-1] 14 | 15 | if audio_length < required_length: 16 | if deterministic: 17 | pad_left = 0 18 | else: 19 | pad_left = random_state.randint(required_length - audio_length + 1) # 0 ~ 50 random 20 | pad_right = required_length - audio_length - pad_left # 50~ 0 21 | 22 | if isinstance(y, list): 23 | for i in range(len(y)): 24 | y[i] = F.pad(y[i], (pad_left, pad_right)) 25 | audio_length = y[0].shape[-1] 26 | else: 27 | y = F.pad(y, (pad_left, pad_right)) 28 | audio_length = y.shape[-1] 29 | 30 | if deterministic: 31 | audio_begin = 0 32 | else: 33 | audio_begin = random_state.randint(audio_length - required_length + 1) 34 | audio_end = required_length + audio_begin 35 | if isinstance(y, list): 36 | for i in range(len(y)): 37 | y[i] = y[i][..., audio_begin:audio_end] 38 | else: 39 | y = y[..., audio_begin:audio_end] 40 | return y 41 | 42 | 43 | def load_audio(path, sample_rate, assert_sr=False, channel=None): 44 | if path[-3:] == "pcm": 45 | audio, sr = soundfile.read(path, format="RAW", samplerate=16000, channels=1, subtype="PCM_16", 46 | dtype="float32") 47 | else: 48 | audio, sr = soundfile.read(path, dtype="float32") 49 | 50 | if len(audio.shape) == 1: # if mono 51 | audio = np.expand_dims(audio, 1) 52 | 53 | # samples x channel 54 | # sr이 16000이 아니면 resample 55 | if assert_sr: 56 | assert sr == sample_rate 57 | 58 | if sr != sample_rate: 59 | audio = librosa.core.resample(audio.T, sr, sample_rate).T 60 | 61 | # assert sr == SAMPLE_RATE 62 | audio = torch.FloatTensor(audio).permute(1, 0) 63 | if channel is not None: 64 | audio = audio[:channel] 65 | 66 | return dict(audio=audio, path=path) # channel x samples 67 | 68 | def get_audio_by_magphase(mag, phase, hop_length, n_fft, length=None): 69 | # mag : channel x freq x time 70 | # phase : channel x freq x time 71 | mono_audio_stft = realimag(mag, phase) 72 | # channel x freq x time x 2 73 | 74 | mono_audio = istft(mono_audio_stft, hop_length, n_fft, length=length) 75 | return mono_audio 
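# Usage sketch (illustrative, not part of the original module): how the helpers in this
# file chain together on the real-valued (magnitude/phase) path of the model, assuming
# `X` is a complex spectrogram shaped (batch, freq, time, 2) and `mask` is the real-valued
# mask predicted by the UNet:
#
#   mag_X, phase_X = magphase(X, power=1.)       # split into magnitude and phase
#   Y_hat = realimag(mag_X * mask, phase_X)      # scale the magnitude, keep the noisy phase
#   y_hat = istft(Y_hat, hop_length=HOP_LENGTH, win_length=N_FFT, length=16384)
#
# HOP_LENGTH and N_FFT are defined in DCUNet/constant.py (this module does not import them);
# 16384 matches the default training sequence_length used elsewhere in the repository.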
76 | 77 | 78 | def _get_time_values(sig_length, sr, hop): 79 | """ 80 | Get the time axis values given the signal length, sample 81 | rate and hop size. 82 | """ 83 | return torch.linspace(0, sig_length/sr, sig_length//hop+1) 84 | 85 | 86 | def _get_freq_values(n_fft, sr): 87 | """ 88 | Get the frequency axis values given the number of FFT bins 89 | and sample rate. 90 | """ 91 | return torch.linspace(0, sr/2, n_fft//2 + 1) 92 | 93 | 94 | def get_spectrogram_axis(sig_length, sr, n_fft=2048, hop=512): 95 | t = _get_time_values(sig_length, sr, hop) 96 | f = _get_freq_values(n_fft, sr) 97 | return t, f 98 | 99 | 100 | def istft(stft_matrix, hop_length=None, win_length=None, window='hann', 101 | center=True, normalized=False, onesided=True, length=None): 102 | # keunwoochoi's implementation 103 | # https://gist.github.com/keunwoochoi/2f349e72cc941f6f10d4adf9b0d3f37e 104 | 105 | """stft_matrix = (batch, freq, time, complex) 106 | 107 | All based on librosa 108 | - http://librosa.github.io/librosa/_modules/librosa/core/spectrum.html#istft 109 | What's missing? 110 | - normalize by sum of squared window --> do we need it here? 111 | Actually the result is ok by simply dividing y by 2. 112 | """ 113 | assert normalized == False 114 | assert onesided == True 115 | assert window == "hann" 116 | assert center == True 117 | 118 | device = stft_matrix.device 119 | n_fft = 2 * (stft_matrix.shape[-3] - 1) 120 | 121 | batch = stft_matrix.shape[0] 122 | 123 | # By default, use the entire frame 124 | if win_length is None: 125 | win_length = n_fft 126 | 127 | if hop_length is None: 128 | hop_length = int(win_length // 4) 129 | 130 | istft_window = torch.hann_window(n_fft).to(device).view(1, -1) # (batch, freq) 131 | 132 | n_frames = stft_matrix.shape[-2] 133 | expected_signal_len = n_fft + hop_length * (n_frames - 1) 134 | 135 | y = torch.zeros(batch, expected_signal_len, device=device) 136 | for i in range(n_frames): 137 | sample = i * hop_length 138 | spec = stft_matrix[:, :, i] 139 | iffted = torch.irfft(spec, signal_ndim=1, signal_sizes=(win_length,)) 140 | 141 | ytmp = istft_window * iffted 142 | y[:, sample:(sample + n_fft)] += ytmp 143 | 144 | y = y[:, n_fft // 2:] 145 | 146 | if length is not None: 147 | if y.shape[1] > length: 148 | y = y[:, :length] 149 | elif y.shape[1] < length: 150 | y = F.pad(y, (0, length - y.shape[1])) 151 | # y = torch.cat((y[:, :length], torch.zeros(y.shape[0], length - y.shape[1]))) 152 | 153 | coeff = n_fft / float( 154 | hop_length) / 2.0 # -> this might go wrong if curretnly asserted values (especially, `normalized`) changes. 155 | return y / coeff 156 | 157 | 158 | def angle(tensor): 159 | """ 160 | Return angle of a complex tensor with shape (*, 2). 161 | """ 162 | return torch.atan2(tensor[...,1], tensor[...,0]) 163 | 164 | 165 | def magphase(spec, power=1.): 166 | """ 167 | Separate a complex-valued spectrogram with shape (*,2) 168 | into its magnitude and phase. 
169 | """ 170 | mag = spec.pow(2).sum(-1).pow(power/2) 171 | phase = angle(spec) 172 | return mag, phase 173 | 174 | 175 | def realimag(mag, phase): 176 | """ 177 | Combine a magnitude spectrogram and a phase spectrogram into a complex-valued spectrogram with shape (*, 2) 178 | """ 179 | spec_real = mag * torch.cos(phase) 180 | spec_imag = mag * torch.sin(phase) 181 | spec = torch.stack([spec_real, spec_imag], dim=-1) 182 | return spec 183 | 184 | 185 | def get_snr(y, z): 186 | y_power = y.pow(2).mean(dim=-1) 187 | z_power = z.pow(2).mean(dim=-1) 188 | snr = 20*torch.log10(y_power/z_power) 189 | return snr 190 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Complex U-Net 2 | --- 3 | **Unofficial** PyTorch Implementation of [Phase-Aware Speech Enhancement with Deep Complex U-Net](https://openreview.net/forum?id=SkeRTsAcYm), (H. Choi et al., 2018) 4 | > **Note** 5 | > 6 | > This is NOT the author's implementation. 7 | 8 | 9 | ## Architecture 10 | --- 11 | (To be added) 12 | 13 | 14 | ## Requirements 15 | ```text 16 | torch==1.1 17 | soundfile==0.9.0 18 | easydict==1.9 19 | git+https://github.com/keunwoochoi/torchaudio-contrib@61fc6a804c941dec3cf8a06478704d19fc5e415a 20 | git+https://github.com/sweetcocoa/PinkBlack@e45a65623c1b511181f7ea697ca841a7b2900f17 21 | torchcontrib==0.0.2 22 | git+https://github.com/vBaiCai/python-pesq 23 | # gputil # if you need to run multiple training processes 24 | ``` 25 | 26 | ## Train 27 | --- 28 | 1. Download the dataset: 29 | - [https://datashare.is.ed.ac.uk/handle/10283/2791](https://datashare.is.ed.ac.uk/handle/10283/2791) 30 | 31 | 2. Separate the train / test wavs 32 | 33 | 3. Downsample the wavs 34 | ```bash 35 | # prerequisite : ffmpeg. 36 | # sudo apt-get install ffmpeg (Ubuntu) 37 | bash downsample.sh # all wavs below $PWD will be converted to .flac at a 16k sample rate 38 | ``` 39 | 40 | 4. Train 41 | ```bash 42 | python train_dcunet.py --batch_size 12 \ 43 | --train_signal /path/to/train/clean/speech/ \ 44 | --train_noise /path/to/train/noisy/speech/ \ 45 | --test_signal /path/to/test/clean/speech/ \ 46 | --test_noise /path/to/test/noisy/speech/ \ 47 | --ckpt /path/to/save/checkpoint.pth \ 48 | --num_step 50000 \ 49 | --validation_interval 500 \ 50 | --complex 51 | 52 | # You can check the other arguments in the source code. (Sorry for the lack of description.) 53 | ``` 54 | 55 | ## Test 56 | --- 57 | ```bash 58 | python estimate_directory.py --input_dir /path/to/noisy/speech/ \ 59 | --output_dir /path/to/estimate/dir/ \ 60 | --ckpt /path/to/saved/checkpoint.pth 61 | ``` 62 | 63 | 64 | ## Results 65 | --- 66 | | PESQ(cRMCn/cRMRn) | Paper | Mine* | 67 | | -------------------- | ----- | ---- | 68 | | DCUNet - 10 | **2.72**/2.51 | 3.03/**3.07** | 69 | | DCUNet - 20 | **3.24**/2.74 | **3.12**/3.11 | 70 | 71 | - *cRMCn* : Complex-valued input/output 72 | - *cRMRn* : Real-valued input/output 73 | 74 | Directly comparing the two sets of values above (the paper's and mine) is inappropriate for the following reasons: 75 | 76 | - \* I did not use the MATLAB code that the authors used to calculate PESQ, but instead used [pypesq](https://github.com/vBaiCai/python-pesq). 77 | 78 | - \* The architecture of the model is slightly different from the original paper's
(such as the kernel sizes of the convolution filters). 79 | 80 | - Mel spectrogram 81 | ![img](./assets/images/melspectrogram.png) 82 | 83 | ## Notes 84 | --- 85 | - Log-amplitude estimation has slightly worse performance than non-log-amplitude estimation 86 | - The complex-valued network does not improve the metric. 87 | 88 | ## Sample Wavs 89 | --- 90 | | Mixture | Estimated Speech | GT (Clean Speech) | 91 | | --------|-----------|-------------| 92 | |[mixture1.wav](./assets/noisy/p232_001.flac?raw=true)|[Estimated1.wav](./assets/estimated/p232_001.flac?raw=true)|[GroundTruth1.wav](./assets/gt/p232_001.flac?raw=true)| 93 | |[mixture2.wav](./assets/noisy/p232_005.flac?raw=true)|[Estimated2.wav](./assets/estimated/p232_005.flac?raw=true)|[GroundTruth2.wav](./assets/gt/p232_005.flac?raw=true)| 94 | 95 | 96 | ## Contact 97 | --- 98 | - Jongho Choi (sweetcocoa@snu.ac.kr / Seoul National Univ., ESTsoft) 99 | -------------------------------------------------------------------------------- /assets/estimated/p232_001.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/estimated/p232_001.flac -------------------------------------------------------------------------------- /assets/estimated/p232_005.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/estimated/p232_005.flac -------------------------------------------------------------------------------- /assets/estimated/p257_001.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/estimated/p257_001.flac -------------------------------------------------------------------------------- /assets/estimated/p257_006.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/estimated/p257_006.flac -------------------------------------------------------------------------------- /assets/gt/p232_001.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/gt/p232_001.flac -------------------------------------------------------------------------------- /assets/gt/p232_005.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/gt/p232_005.flac -------------------------------------------------------------------------------- /assets/gt/p257_001.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/gt/p257_001.flac -------------------------------------------------------------------------------- /assets/gt/p257_006.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/gt/p257_006.flac
-------------------------------------------------------------------------------- /assets/images/melspectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/images/melspectrogram.png -------------------------------------------------------------------------------- /assets/noisy/p232_001.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/noisy/p232_001.flac -------------------------------------------------------------------------------- /assets/noisy/p232_005.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/noisy/p232_005.flac -------------------------------------------------------------------------------- /assets/noisy/p257_001.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/noisy/p257_001.flac -------------------------------------------------------------------------------- /assets/noisy/p257_006.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sweetcocoa/DeepComplexUNetPyTorch/c68510a4d822f19fa366f1da84eff8c0c25ff88a/assets/noisy/p257_006.flac -------------------------------------------------------------------------------- /downsample.sh: -------------------------------------------------------------------------------- 1 | echo Converting the audio files to FLAC ... 2 | FOLDER=$PWD 3 | COUNTER=$(find -name *.wav|wc -l) 4 | for f in $PWD/**/**/*.wav; do 5 | COUNTER=$((COUNTER - 1)) 6 | echo -ne "\rConverting ($COUNTER) : $f..." 7 | ffmpeg -y -loglevel fatal -i $f -ac 1 -ar 16000 ${f/\.wav/.flac} 8 | done 9 | -------------------------------------------------------------------------------- /estimate_directory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage :: 3 | python estimate_directory.py --input_dir /path/to/input/wavs/ --output_dir /path/to/output/wavs/ --ckpt ckpt/ckpt.pth 4 | 5 | input_dir (폴더 안의 폴더 recursively 포함)의 *.wav, *.flac을 노이즈 제거하여 6 | output_dir 에 출력한다. 
7 | 8 | 알려진 문제점 9 | 1) 많은 노이즈가 제대로 제거되지 않는 문제 10 | 2) 느린 문제 11 | """ 12 | 13 | import os, glob 14 | import soundfile 15 | import numpy as np 16 | import torch 17 | import json 18 | from tqdm import tqdm 19 | from DCUNet.constant import * 20 | from DCUNet.source_separator import SourceSeparator 21 | from DCUNet.inference import any_audio_inference 22 | from easydict import EasyDict 23 | 24 | import argparse 25 | 26 | def get_arg(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--input_dir', default="/path/to/data/wav_dir/") 29 | parser.add_argument('--output_dir', default="/path/to/data/wav_dir/") 30 | parser.add_argument('--ckpt', default="ckpt/190628_False_mix_20_sz60.pth") 31 | args = parser.parse_args() 32 | 33 | args = EasyDict(args.__dict__) 34 | 35 | ckpt = args.ckpt 36 | args.sequence_length = 16384 # 고정 37 | 38 | model_spec = sorted(glob.glob(ckpt + "*.args"))[-1] 39 | with open(model_spec) as f: 40 | specs = EasyDict(json.load(f)) 41 | 42 | args.update(specs) 43 | args.ckpt = ckpt # 저장된 args의 ckpt로 덮어써지므로. 44 | 45 | if not hasattr(args, "padding_mode"): 46 | print("No 'padding_mode' is specified, 'zeros' will be used as padding_mode") 47 | args.padding_mode = "zeros" 48 | 49 | return args 50 | 51 | args = get_arg() 52 | input_files = [] 53 | input_files.extend(glob.glob(args.input_dir + "/**/*.wav", recursive=True)) 54 | input_files.extend(glob.glob(args.input_dir + "/**/*.flac", recursive=True)) 55 | 56 | net = SourceSeparator(complex=args.complex, 57 | log_amp=args.log_amp, 58 | model_complexity=args.model_complexity, 59 | model_depth=args.model_depth, 60 | padding_mode=args.padding_mode 61 | ).cuda() 62 | 63 | net.load_state_dict(torch.load(args.ckpt, map_location='cuda')) 64 | net.eval() 65 | 66 | os.makedirs(args.output_dir, exist_ok=True) 67 | 68 | for file in tqdm(sorted(input_files)): 69 | y_hat = any_audio_inference(file, net, sequence_length=args.sequence_length, normalize=True).transpose() 70 | output_file = file.replace(args.input_dir, args.output_dir) 71 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 72 | soundfile.write(output_file, y_hat, SAMPLE_RATE) -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | from pypesq import pesq 2 | import os, glob 3 | import soundfile 4 | from tqdm import tqdm 5 | 6 | refs = sorted(glob.glob("path/to/clean/*.flac")) 7 | evals = sorted(glob.glob("path/to/estimated/*.flac")) 8 | 9 | results = dict() 10 | for i, (ref, eval) in tqdm(enumerate(zip(refs, evals))): 11 | assert os.path.basename(ref) == os.path.basename(eval) 12 | y, sr = soundfile.read(ref) 13 | y_hat, sr = soundfile.read(eval) 14 | 15 | results[os.path.basename(ref)] = pesq(y, y_hat, sr) 16 | 17 | # print(results) 18 | print(sum(results.values())/len(results)) 19 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | from easydict import EasyDict 4 | import GPUtil 5 | import sys, time 6 | 7 | gpu = GPUtil.getAvailable(limit=4, excludeID=[6,7]) 8 | 9 | print("gpu available : ", gpu) 10 | 11 | args = EasyDict(dict(gpu="2", 12 | batch_size=12, 13 | train_signal="/data/jongho/data/2019challenge/dcunet/clean_speech/train/", 14 | train_noise="/data/jongho/data/2019challenge/dcunet/noisy_noise/train/", 15 | 
test_signal="/data/jongho/data/2019challenge/ss/reference/16k/clean_testset_wav/", 16 | test_noise="/data/jongho/data/2019challenge/ss/reference/16k/noisy_testset_wav/", 17 | sequence_length=16384, 18 | num_step=40000, 19 | validation_interval=500, 20 | num_workers=12, 21 | ckpt="unet/ckpt.pth", 22 | model_complexity=45, 23 | lr=0.01, 24 | num_signal=0, 25 | num_noise=0, 26 | optimizer="adam", 27 | lr_decay=0.5, 28 | momentum=0, 29 | multi_gpu=False, 30 | complex=True, 31 | model_depth=10, 32 | swa=False, 33 | loss="wsdr", 34 | log_amp=False, 35 | metric="pesq", 36 | train_dataset="se", 37 | valid_dataset="se", 38 | preload=False, 39 | padding_mode="reflect")) 40 | 41 | se_y_train = "/data/jongho/data/2019challenge/ss/reference/16k/clean_trainset_28spk_wav/" 42 | se_x_train = "/data/jongho/data/2019challenge/ss/reference/16k/noisy_trainset_28spk_wav/" 43 | 44 | # mix_y_train = "/data/jongho/data/2019challenge/ss/clean_speech/train/" 45 | # mix_x_train = "/data/jongho/data/2019challenge/ss/noisy_noise/train/" 46 | 47 | mix_y_train = "/data/jongho/data/2019challenge/dcunet/clean_speech/train/" 48 | mix_x_train = "/data/jongho/data/2019challenge/dcunet/demand/train/" 49 | 50 | # mix_y_train = "/data/jongho/data/2019challenge/ss/dataset/train/speech/" 51 | # mix_x_train = "/data/jongho/data/2019challenge/ss/dataset/train/noise/" 52 | 53 | for model_complexity in [45, 90]: 54 | for model_depth in [10, 20]: 55 | # skip cnfig 56 | if model_complexity == 90 and model_depth == 10: 57 | continue 58 | if model_complexity == 45 and model_depth == 20: 59 | continue 60 | 61 | for complex in [False, True]: 62 | for log in [False]: 63 | for train_dataset in ['se']: 64 | for optimizer, lr in [('adam', 0.01)]: 65 | for padding_mode in ['zeros']: 66 | while not gpu: 67 | sleep_sec = 600 68 | print(f"no gpu available, sleep {sleep_sec}s...") 69 | time.sleep(sleep_sec) 70 | gpu = GPUtil.getAvailable(limit=4, excludeID=[6,7]) 71 | 72 | command = [f"/miniconda/bin/python", f"{os.getcwd()}/train_dcunet.py"] 73 | args.train_dataset = train_dataset 74 | 75 | if train_dataset == "se": 76 | args.train_signal = se_y_train 77 | args.train_noise = se_x_train 78 | else: 79 | args.train_signal = mix_y_train 80 | args.train_noise = mix_x_train 81 | 82 | args.padding_mode = padding_mode 83 | args.model_depth = model_depth 84 | args.gpu = str(gpu.pop()) 85 | args.model_complexity = model_complexity 86 | args.ckpt = f"demand_experiment_report/190717_{log}_dp{model_depth}_{train_dataset}_sz{model_complexity}_{padding_mode}_comp_{complex}.pth" 87 | args.optimizer = optimizer 88 | args.lr = lr 89 | 90 | for k,v in args.items(): 91 | if isinstance(v, bool): 92 | pass 93 | else: 94 | command.append(f"--{k}") 95 | command.append(f"{v}") 96 | 97 | if log: 98 | command.append("--log_amp") 99 | 100 | if args.preload: 101 | command.append("--preload") 102 | 103 | if complex: 104 | command.append("--complex") 105 | 106 | print("command : {", command, "}") 107 | subprocess.Popen(command, shell=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) 108 | time.sleep(1) 109 | # exit() 110 | 111 | time.sleep(86400*10) -------------------------------------------------------------------------------- /train_dcunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import glob, os 5 | import numpy as np 6 | 7 | from torch.utils.data import DataLoader 8 | 9 | from torchcontrib.optim import SWA 10 | import PinkBlack.trainer 11 | 
12 | from DCUNet.constant import * 13 | from DCUNet.noisedataset import NoiseDataset 14 | from DCUNet.sedataset import SEDataset 15 | from DCUNet.source_separator import SourceSeparator 16 | from DCUNet.criterion import WeightedSDR 17 | from DCUNet.metric import PESQ 18 | 19 | args = PinkBlack.io.setup(default_args=dict(gpu="0", 20 | batch_size=12, 21 | train_signal="/data/jongho/data/2019challenge/ss/reference/16k/clean_trainset_28spk_wav/", 22 | train_noise="/data/jongho/data/2019challenge/ss/reference/16k/noisy_trainset_28spk_wav/", 23 | test_signal="/data/jongho/data/2019challenge/ss/reference/16k/clean_testset_wav/", 24 | test_noise="/data/jongho/data/2019challenge/ss/reference/16k/noisy_testset_wav/", 25 | sequence_length=16384, 26 | num_step=100000, 27 | validation_interval=500, 28 | num_workers=0, 29 | ckpt="unet/ckpt.pth", 30 | model_complexity=45, 31 | lr=0.01, 32 | num_signal=0, 33 | num_noise=0, 34 | optimizer="adam", 35 | lr_decay=0.5, 36 | momentum=0, 37 | multi_gpu=False, 38 | complex=False, 39 | model_depth=20, 40 | swa=False, 41 | loss="wsdr", 42 | log_amp=False, 43 | metric="pesq", 44 | train_dataset="se", 45 | valid_dataset="se", 46 | preload=False, # Whether to load datasets on memory 47 | padding_mode="reflect", # conv2d's padding mode 48 | )) 49 | 50 | 51 | def get_dataset(args): 52 | def get_wav(dir): 53 | wavs = [] 54 | wavs.extend(glob.glob(os.path.join(dir, "**/*.wav"), recursive=True)) 55 | wavs.extend(glob.glob(os.path.join(dir, "**/*.flac"), recursive=True)) 56 | wavs.extend(glob.glob(os.path.join(dir, "**/*.pcm"), recursive=True)) 57 | return wavs 58 | 59 | if args.train_dataset == "mix": 60 | train_signals = get_wav(args.train_signal) 61 | train_noises = get_wav(args.train_noise) 62 | else: 63 | train_signals = get_wav(args.train_signal) 64 | train_noises = [signal.replace("clean", "noisy") for signal in train_signals] 65 | 66 | if args.valid_dataset == "mix": 67 | test_signals = get_wav(args.test_signal) 68 | test_noises = get_wav(args.test_noise) 69 | else: 70 | test_signals = get_wav(args.test_signal) 71 | test_noises = [signal.replace("clean", "noisy") for signal in test_signals] 72 | 73 | if args.num_signal > 0: 74 | train_signals = train_signals[:args.num_signal] 75 | test_signals = test_signals[:args.num_signal] 76 | 77 | if args.num_noise > 0: 78 | train_noises = train_noises[:args.num_noise] 79 | test_noises = test_noises[:args.num_noise] 80 | 81 | if args.train_dataset == "mix": 82 | train_dset = NoiseDataset(train_signals, train_noises, sequence_length=args.sequence_length, is_validation=False, preload=args.preload) 83 | else: 84 | train_noises = [signal.replace("clean", "noisy") for signal in train_signals] 85 | train_dset = SEDataset(train_signals, train_noises, sequence_length=args.sequence_length, is_validation=False) 86 | 87 | if args.valid_dataset == "mix": 88 | rand = np.random.RandomState(0) 89 | rand.shuffle(test_signals) 90 | test_signals = test_signals[:1000] 91 | valid_dset = NoiseDataset(test_signals, test_noises, sequence_length=args.sequence_length, is_validation=True, preload=args.preload) 92 | else: 93 | test_noises = [signal.replace("clean", "noisy") for signal in test_signals] 94 | valid_dset = SEDataset(test_signals, test_noises, sequence_length=args.sequence_length, is_validation=True) 95 | 96 | return dict(train_dset=train_dset, 97 | valid_dset=valid_dset) 98 | 99 | dset = get_dataset(args) 100 | train_dset, valid_dset = dset['train_dset'], dset['valid_dset'] 101 | 102 | train_dl = DataLoader(train_dset, 
batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, 103 | pin_memory=False) 104 | valid_dl = DataLoader(valid_dset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, 105 | pin_memory=False) 106 | 107 | if args.loss == "wsdr": 108 | loss = WeightedSDR() 109 | else: 110 | raise NotImplementedError(f"unknown loss ({args.loss})") 111 | 112 | if args.metric == "pesq": 113 | metric = PESQ() 114 | else: 115 | def metric(output, bd): 116 | with torch.no_grad(): 117 | return -loss(output, bd) 118 | 119 | net = SourceSeparator(complex=args.complex, model_complexity=args.model_complexity, model_depth=args.model_depth, log_amp=args.log_amp, padding_mode=args.padding_mode).cuda() 120 | print(net) 121 | 122 | if args.multi_gpu: 123 | net = nn.DataParallel(net).cuda() 124 | 125 | if args.optimizer == "adam": 126 | optimizer = optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=args.lr) 127 | elif args.optimizer == "sgd": 128 | optimizer = optim.SGD(filter(lambda x: x.requires_grad, net.parameters()), lr=args.lr, momentum=args.momentum) 129 | else: 130 | raise ValueError(f"Unknown optimizer - {args.optimizer}") 131 | 132 | if args.swa: 133 | steps_per_epoch = args.validation_interval 134 | optimizer = SWA(optimizer, swa_start=int(20) * steps_per_epoch, swa_freq=1 * steps_per_epoch) 135 | 136 | if args.lr_decay >= 1 or args.lr_decay <= 0: 137 | scheduler = None 138 | else: 139 | if args.swa: # SWA wraps the base optimizer, so attach the scheduler to the inner optimizer 140 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer.optimizer, mode="max", patience=5, factor=args.lr_decay) 141 | else: 142 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=5, factor=args.lr_decay) 143 | 144 | trainer = PinkBlack.trainer.Trainer(net, 145 | criterion=loss, 146 | metric=metric, 147 | train_dataloader=train_dl, 148 | val_dataloader=valid_dl, 149 | ckpt=args.ckpt, 150 | optimizer=optimizer, 151 | lr_scheduler=scheduler, 152 | is_data_dict=True, 153 | logdir="log_se") 154 | 155 | trainer.train(step=args.num_step, validation_interval=args.validation_interval) 156 | 157 | if args.swa: 158 | trainer.swa_apply(bn_update=True) 159 | trainer.train(1, phases=['val']) 160 | --------------------------------------------------------------------------------
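Below is a minimal single-file inference sketch distilled from `estimate_directory.py` and `DCUNet/inference.py` above; it is not part of the repository. The checkpoint path, input file, and model flags are placeholders (the actual script restores the flags from the `*.args` JSON saved next to the checkpoint), and a CUDA device is required because the model and helpers call `.cuda()`.

```python
import soundfile
import torch

from DCUNet.constant import SAMPLE_RATE
from DCUNet.source_separator import SourceSeparator
from DCUNet.inference import any_audio_inference

CKPT = "ckpt/checkpoint.pth"            # hypothetical checkpoint path
NOISY = "assets/noisy/p232_001.flac"    # any wav/flac readable by soundfile

# These flags must match the ones the checkpoint was trained with;
# estimate_directory.py restores them from the checkpoint's *.args JSON
# instead of hard-coding them as done here.
net = SourceSeparator(complex=True,
                      log_amp=False,
                      model_complexity=45,
                      model_depth=20,
                      padding_mode="zeros").cuda()
net.load_state_dict(torch.load(CKPT, map_location="cuda"))
net.eval()

# any_audio_inference loads the file, pads it to a multiple of 16384 samples,
# runs the separator, and returns the (optionally peak-normalized) estimate
# as a (channel, samples) numpy array.
y_hat = any_audio_inference(NOISY, net, normalize=True).transpose()
soundfile.write("denoised.flac", y_hat, SAMPLE_RATE)
```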