├── .gitignore
├── EarlyStopping.py
├── Readme.md
├── configs
│   └── dpcrn.json
├── conv_stft.py
├── img
│   ├── enhancement.PNG
│   └── noisy.PNG
├── infer.py
├── logs
│   └── DPCRN
│       └── config.json
├── loss.py
├── model.py
├── modules.py
├── pretrained_model
│   └── final_ckp.pth
├── resample_data.py
├── se_dataset.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .buildconfig
2 | .flatpak-builder
3 | .ipynb_checkpoints
4 | _build
5 | __pycache__
6 | \#*
7 | dist
8 | build
9 | *.egg-info
10 | test
11 | tags
12 | *~
13 | *snap
14 | venv/
15 | .idea
16 | dataset/
--------------------------------------------------------------------------------
/EarlyStopping.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | import time
5 | 
6 | class EarlyStopping:
7 |     '''Stop training early if the validation loss doesn't improve after a given patience.'''
8 | 
9 |     def __init__(self, patience=10, verbose=False,
10 |                  delta=0, path='checkpoint.pth'):
11 |         self.patience = patience
12 |         self.verbose = verbose
13 |         self.counter = 0
14 |         self.best_score = None
15 |         self.early_stop = False
16 |         self.val_loss_min = np.inf
17 |         self.delta = delta
18 |         self.path = path
19 |         if not os.path.exists(path):
20 |             os.makedirs(self.path)
21 | 
22 |     def __call__(self, val_loss, model):
23 |         score = -val_loss
24 | 
25 |         if self.best_score is None:
26 |             self.best_score = score
27 |             self.save_checkpoint(val_loss, model)
28 |         elif score < self.best_score + self.delta:
29 |             self.counter += 1
30 |             print(f'EarlyStopping Counter: {self.counter} out of {self.patience}')
31 |             if self.counter >= self.patience:
32 |                 self.early_stop = True
33 |         else:
34 |             self.best_score = score
35 |             self.save_checkpoint(val_loss, model)
36 |             self.counter = 0
37 | 
38 |     def save_checkpoint(self, val_loss, model):
39 |         '''Save the model when the validation loss decreases.'''
40 |         if self.verbose:
41 |             print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
42 |         path = time.strftime("%Y-%m-%d-%H-%M-%S") + '_ckp.pth'
43 |         path = os.path.join(self.path, path)
44 |         torch.save(model.state_dict(), path)
45 |         self.val_loss_min = val_loss
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | **DPCRN**
2 | 
3 | ------------------------
4 | 
5 | An unofficial PyTorch implementation of the DPCRN speech enhancement model.
6 | 
7 | #### 1.Demo
8 | 
9 | noisy speech:
10 | 
11 | ![img](img/noisy.PNG)
12 | 
13 | enhanced speech:
14 | 
15 | ![img](img/enhancement.PNG)
16 | 
17 | #### 2.Train
18 | 
19 | To train the model, **first** point the `folder` variable of `load_data_list()` in `se_dataset.py` at your dataset, **then** run:
20 | 
21 | ```
22 | python train.py --batch_size 8 --num_epochs 100
23 | ```
24 | 
25 | #### 3.Infer
26 | 
27 | To denoise noisy speech, run:
28 | 
29 | ```
30 | python infer.py --ckp_path ./pretrained_model/final_ckp.pth --audio_path your_noisy_speech_path --save_path your_enhanced_speech_save_path
31 | ```
32 | 
33 | #### 4.Pretrained model
34 | 
35 | A pretrained model trained on the **VCTK-Demand** dataset can be found in the `pretrained_model` folder.
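If you want to call the model from Python rather than through `infer.py`, the sketch below loads this checkpoint and enhances one file. It assumes a CUDA device and a 16 kHz mono input, and the audio file names are placeholders:

```
import numpy as np
import torch
import librosa
import soundfile as sf
import utils
from model import DPCRN

hps = utils.get_hparams()
t = hps.train
# same arguments that infer.py passes, taken from configs/dpcrn.json
net = DPCRN(t.encoder_in_channel, t.encoder_channel_size, t.encoder_kernel_size,
            t.encoder_stride_size, t.encoder_padding,
            t.decoder_in_channel, t.decoder_channel_size, t.decoder_kernel_size,
            t.decoder_stride_size, t.dprnn_rnn_type, t.dprnn_input_size,
            t.dprnn_hidden_size, t.frame_len, t.frame_shift).cuda()
net.load_state_dict(torch.load('./pretrained_model/final_ckp.pth'))
net.eval()

noisy, sr = librosa.load('your_noisy_speech.wav', sr=None)    # placeholder path
# zero-pad so that (length - frame_len) / frame_shift is an integer, as infer.py does,
# which keeps the torch.istft output aligned with the torch.stft input
frames = int(np.ceil((len(noisy) - t.frame_len) / t.frame_shift)) + 1
padded = np.zeros((1, (frames - 1) * t.frame_shift + t.frame_len), dtype=np.float32)
padded[0, :len(noisy)] = noisy

with torch.no_grad():
    enhanced, _, _ = net(torch.from_numpy(padded).cuda())      # model returns (waveform, real, imag)
sf.write('your_enhanced_speech.wav', enhanced[0, :len(noisy)].cpu().numpy(), samplerate=sr)
```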
36 | 37 | #### 5.References 38 | 39 | [1]https://github.com/Le-Xiaohuai-speech/DPCRN_DNS3 40 | 41 | [2]https://github.com/chanil1218/DCUnet.pytorch 42 | -------------------------------------------------------------------------------- /configs/dpcrn.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "encoder_in_channel": 2, 4 | "encoder_channel_size": [32,32,32,64,128], 5 | "encoder_kernel_size": [[5,2],[3,2],[3,2],[3,2],[3,2]], 6 | "encoder_stride_size": [[2,1],[2,1],[1,1],[1,1],[1,1]], 7 | "encoder_padding": [[1,0,0,2],[1,0,0,1],[1,0,1,1],[1,0,1,1],[1,0,1,1]], 8 | "decoder_in_channel": 256, 9 | "decoder_channel_size": [64,32,32,32,2], 10 | "decoder_kernel_size": [[3,2],[3,2],[3,2],[3,2],[5,2]], 11 | "decoder_stride_size": [[1,1],[1,1],[1,1],[2,1],[2,1]], 12 | "dprnn_rnn_type": "LSTM", 13 | "dprnn_input_size": 128, 14 | "dprnn_hidden_size": 128, 15 | "dprnn_dropout": 0, 16 | "dprnn_num_layers": 1, 17 | "frame_len": 400, 18 | "frame_shift": 200, 19 | "seed": 1234, 20 | "epochs": 20000, 21 | "learning_rate": 2e-4, 22 | "betas": [0.8, 0.99], 23 | "eps": 1e-9, 24 | "batch_size": 64, 25 | "fp16_run": true, 26 | "lr_decay": 0.999875, 27 | "segment_size": 8192, 28 | "init_lr_ratio": 1, 29 | "warmup_epochs": 0, 30 | "c_mel": 45, 31 | "c_kl": 1.0 32 | }, 33 | "data": { 34 | 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /conv_stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from scipy.signal import get_window 6 | 7 | 8 | def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False): 9 | if win_type == 'None' or win_type is None: 10 | window = np.ones(win_len) 11 | else: 12 | window = get_window(win_type, win_len, fftbins=True) # **0.5 13 | 14 | N = fft_len 15 | fourier_basis = np.fft.rfft(np.eye(N))[:win_len] 16 | real_kernel = np.real(fourier_basis) 17 | imag_kernel = np.imag(fourier_basis) 18 | kernel = np.concatenate([real_kernel, imag_kernel], 1).T 19 | 20 | if invers: 21 | kernel = np.linalg.pinv(kernel).T 22 | 23 | kernel = kernel * window 24 | kernel = kernel[:, None, :] 25 | return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32)) 26 | 27 | 28 | class ConvSTFT(nn.Module): 29 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True): 30 | super(ConvSTFT, self).__init__() 31 | 32 | if fft_len == None: 33 | self.fft_len = np.int(2 ** np.ceil(np.log2(win_len))) 34 | else: 35 | self.fft_len = fft_len 36 | 37 | kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type) 38 | # self.weight = nn.Parameter(kernel, requires_grad=(not fix)) 39 | self.register_buffer('weight', kernel) 40 | self.feature_type = feature_type 41 | self.stride = win_inc 42 | self.win_len = win_len 43 | self.dim = self.fft_len 44 | 45 | def forward(self, inputs): 46 | if inputs.dim() == 2: 47 | inputs = torch.unsqueeze(inputs, 1) 48 | inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride]) 49 | outputs = F.conv1d(inputs, self.weight, stride=self.stride) 50 | 51 | if self.feature_type == 'complex': 52 | return outputs 53 | else: 54 | dim = self.dim // 2 + 1 55 | real = outputs[:, :dim, :] 56 | imag = outputs[:, dim:, :] 57 | if self.feature_type == 'real': 58 | return real, imag 59 | else: 60 | mags = torch.sqrt(real ** 2 + imag 
** 2) 61 | phase = torch.atan2(imag, real) 62 | return mags, phase 63 | 64 | 65 | class ConviSTFT(nn.Module): 66 | def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True): 67 | super(ConviSTFT, self).__init__() 68 | if fft_len == None: 69 | self.fft_len = np.int(2 ** np.ceil(np.log2(win_len))) 70 | else: 71 | self.fft_len = fft_len 72 | kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True) 73 | # self.weight = nn.Parameter(kernel, requires_grad=(not fix)) 74 | self.register_buffer('weight', kernel) 75 | self.feature_type = feature_type 76 | self.win_type = win_type 77 | self.win_len = win_len 78 | self.stride = win_inc 79 | self.stride = win_inc 80 | self.dim = self.fft_len 81 | self.register_buffer('window', window) 82 | self.register_buffer('enframe', torch.eye(win_len)[:, None, :]) 83 | 84 | def forward(self, inputs, phase=None): 85 | """ 86 | inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags) 87 | phase: [B, N//2+1, T] (if not none) 88 | """ 89 | if phase is not None: 90 | real = inputs * torch.cos(phase) 91 | imag = inputs * torch.sin(phase) 92 | inputs = torch.cat([real, imag], 1) 93 | outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride) 94 | 95 | # this is from torch-stft: https://github.com/pseeth/torch-stft 96 | t = self.window.repeat(1, 1, inputs.size(-1)) ** 2 97 | coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) 98 | outputs = outputs / (coff + 1e-8) 99 | # outputs = torch.where(coff == 0, outputs, outputs/coff) 100 | outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)] 101 | 102 | return outputs -------------------------------------------------------------------------------- /img/enhancement.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bear-boy/DPCRN-Pytorch/516cb0951204d750f7b41e2849cc91d83285344c/img/enhancement.PNG -------------------------------------------------------------------------------- /img/noisy.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bear-boy/DPCRN-Pytorch/516cb0951204d750f7b41e2849cc91d83285344c/img/noisy.PNG -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import librosa 3 | import argparse 4 | import utils 5 | import soundfile as sf 6 | import numpy as np 7 | from model import DPCRN 8 | 9 | def infer(args, model, hps): 10 | input, sr = librosa.load(args.audio_path, sr=None) 11 | l = len(input) 12 | frames = int(np.ceil((l - hps.train.frame_len) / hps.train.frame_shift)) + 1 13 | max_t = (frames - 1) * hps.train.frame_shift + hps.train.frame_len 14 | input_mat = np.zeros((1,max_t), dtype=np.float32) 15 | input_mat[0, :l] = input 16 | x = torch.from_numpy(input_mat).type(torch.FloatTensor).cuda() 17 | y,_,_ = model(x) 18 | y = y.detach().cpu().data.numpy() 19 | output = y[:,:l].T 20 | # out_path = os.path.join(os.curdir, args.save_path) 21 | sf.write(args.save_path, output, samplerate=sr, subtype='FLOAT') 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--ckp_path', default='./pretrained_model/final_ckp.pth', type=str, help='Checkpoint path') 27 | parser.add_argument('--audio_path', default='./dataset/valset_noisy/p232_001.wav', type=str, help='infer audio path') 28 
| parser.add_argument('--save_path', default='enhancement_example/p232_001_enh.wav', type=str, help='save audio path') 29 | args = parser.parse_args() 30 | 31 | hps = utils.get_hparams() 32 | net = DPCRN(encoder_in_channel=hps.train.encoder_in_channel, 33 | encoder_channel_size=hps.train.encoder_channel_size, 34 | encoder_kernel_size=hps.train.encoder_kernel_size, 35 | encoder_stride_size=hps.train.encoder_stride_size, 36 | encoder_padding=hps.train.encoder_padding, 37 | decoder_in_channel=hps.train.decoder_in_channel, 38 | decoder_channel_size=hps.train.decoder_channel_size, 39 | decoder_kernel_size=hps.train.decoder_kernel_size, 40 | decoder_stride_size=hps.train.decoder_stride_size, 41 | rnn_type=hps.train.dprnn_rnn_type, 42 | input_size=hps.train.dprnn_input_size, 43 | hidden_size=hps.train.dprnn_hidden_size, 44 | frame_len=hps.train.frame_len, 45 | frame_shift=hps.train.frame_shift).cuda() 46 | net.load_state_dict(torch.load(args.ckp_path)) 47 | 48 | 49 | infer(args, net, hps) -------------------------------------------------------------------------------- /logs/DPCRN/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "encoder_in_channel": 2, 4 | "encoder_channel_size": [32,32,32,64,128], 5 | "encoder_kernel_size": [[5,2],[3,2],[3,2],[3,2],[3,2]], 6 | "encoder_stride_size": [[2,1],[2,1],[1,1],[1,1],[1,1]], 7 | "encoder_padding": [[1,0,0,2],[1,0,0,1],[1,0,1,1],[1,0,1,1],[1,0,1,1]], 8 | "decoder_in_channel": 256, 9 | "decoder_channel_size": [64,32,32,32,2], 10 | "decoder_kernel_size": [[3,2],[3,2],[3,2],[3,2],[5,2]], 11 | "decoder_stride_size": [[1,1],[1,1],[1,1],[2,1],[2,1]], 12 | "dprnn_rnn_type": "LSTM", 13 | "dprnn_input_size": 128, 14 | "dprnn_hidden_size": 128, 15 | "dprnn_dropout": 0, 16 | "dprnn_num_layers": 1, 17 | "frame_len": 400, 18 | "frame_shift": 200, 19 | "seed": 1234, 20 | "epochs": 20000, 21 | "learning_rate": 2e-4, 22 | "betas": [0.8, 0.99], 23 | "eps": 1e-9, 24 | "batch_size": 64, 25 | "fp16_run": true, 26 | "lr_decay": 0.999875, 27 | "segment_size": 8192, 28 | "init_lr_ratio": 1, 29 | "warmup_epochs": 0, 30 | "c_mel": 45, 31 | "c_kl": 1.0 32 | }, 33 | "data": { 34 | 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Loss function for DPCRN 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | def spec_loss(y_true_re, y_true_im, y_pred_re, y_pred_im): 9 | mse_fn = nn.MSELoss() 10 | # real-part loss 11 | real_loss = mse_fn(y_true_re, y_pred_re) 12 | # imag-part loss 13 | imag_loss = mse_fn(y_true_im, y_pred_im) 14 | # magnitude loss 15 | y_true_mag = (y_true_re ** 2 + y_true_im ** 2 + 1e-8) ** 0.5 16 | y_pred_mag = (y_pred_re ** 2 + y_pred_im ** 2 + 1e-8) ** 0.5 17 | mag_loss = mse_fn(y_true_mag, y_pred_mag) 18 | 19 | total_loss = real_loss + imag_loss + mag_loss 20 | total_loss = torch.log(total_loss + 1e-8) 21 | return total_loss 22 | 23 | def snr_loss(y_true, y_pred): 24 | snr = torch.mean(torch.square(y_true),dim=-1,keepdim=True) / (torch.mean(torch.square(y_pred-y_true),dim=-1,keepdim=True) +1e-8) 25 | loss = -10 * torch.log10(snr + 1e-8) 26 | loss = torch.squeeze(loss,-1).mean() 27 | return loss 28 | 29 | 30 | def DPCRNLoss(y_true, y_pred, y_true_re, y_true_im, y_pred_re, y_pred_im): 31 | loss1 = snr_loss(y_true, y_pred) 32 | loss2 = spec_loss(y_true_re, y_true_im, y_pred_re, y_pred_im) 33 | loss = loss1 + loss2 34 | return loss 35 | 36 
| 37 | if __name__ == "__main__": 38 | from conv_stft import ConvSTFT 39 | stft = ConvSTFT(win_len=400,win_inc=100) 40 | 41 | batch_size = 16 42 | y_true = torch.randn((batch_size, 16000 * 2)) # 2s inputs 43 | y_true_re ,y_true_im = stft(y_true) 44 | 45 | y_pred = torch.randn((batch_size, 16000 * 2)) 46 | y_pred_re, y_pred_im = stft(y_pred) 47 | 48 | loss = DPCRNLoss(y_true, y_pred, y_true_re, y_true_im, y_pred_re, y_pred_im) 49 | print(loss) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modules import DPRNN 5 | from conv_stft import ConvSTFT, ConviSTFT 6 | import utils 7 | import scipy.signal as signal 8 | import numpy as np 9 | 10 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 11 | cos_win = torch.from_numpy(signal.windows.cosine(400,False)).type(torch.FloatTensor).cuda() 12 | 13 | 14 | class STFT(nn.Module): 15 | def __init__(self, frame_len, frame_hop, fft_len=None): 16 | super(STFT, self).__init__() 17 | self.eps = torch.finfo(torch.float32).eps 18 | self.frame_len = frame_len 19 | self.frame_hop = frame_hop 20 | 21 | def forward(self, x): 22 | if len(x.shape) != 2: 23 | print("x must be in [B, T]") 24 | y = torch.stft(x, hop_length=self.frame_hop, 25 | n_fft=self.frame_len, window=cos_win, return_complex=True, center=False) 26 | r = y.real 27 | i = y.imag 28 | return r,i 29 | 30 | class ISTFT(nn.Module): 31 | def __init__(self, frame_len, frame_hop, fft_len=None): 32 | super(ISTFT, self).__init__() 33 | self.eps = torch.finfo(torch.float32).eps 34 | self.frame_len = frame_len 35 | self.frame_hop = frame_hop 36 | 37 | def forward(self, real, imag): 38 | x = torch.complex(real, imag) 39 | y = torch.istft(x, hop_length=self.frame_hop, 40 | n_fft=self.frame_len, window=cos_win,center=False) 41 | return y 42 | 43 | class DPCRN(nn.Module): 44 | def __init__(self, encoder_in_channel, encoder_channel_size, encoder_kernel_size, encoder_stride_size, encoder_padding, 45 | decoder_in_channel, decoder_channel_size, decoder_kernel_size, decoder_stride_size, 46 | rnn_type, input_size, hidden_size, 47 | frame_len, frame_shift): 48 | super(DPCRN, self).__init__() 49 | self.encoder_channel_size = encoder_channel_size 50 | self.encoder_kernel_size = encoder_kernel_size 51 | self.encoder_stride_size = encoder_stride_size 52 | self.encoder_padding = encoder_padding 53 | self.decoder_channel_size = decoder_channel_size 54 | self.decoder_kernel_size = decoder_kernel_size 55 | self.decoder_stride_size = decoder_stride_size 56 | self.frame_len = frame_len 57 | self.frame_shift = frame_shift 58 | 59 | # self.stft = ConvSTFT(win_len=frame_len, win_inc=frame_shift) 60 | # self.istft = ConviSTFT(win_len=frame_len, win_inc=frame_shift) 61 | self.stft = STFT(self.frame_len, self.frame_shift) 62 | self.istft = ISTFT(self.frame_len, self.frame_shift) 63 | 64 | self.encoder = Encoder(encoder_in_channel, self.encoder_channel_size, 65 | self.encoder_kernel_size, self.encoder_stride_size, self.encoder_padding) 66 | self.decoder = Decoder(decoder_in_channel, self.decoder_channel_size, 67 | self.decoder_kernel_size, self.decoder_stride_size) 68 | self.dprnn = DPRNN(rnn_type, input_size=input_size, hidden_size=hidden_size) 69 | 70 | def forward(self, x): 71 | re, im = self.stft(x) 72 | inputs = torch.stack([re,im],dim=1) # B x C x F x T 73 | x, skips = self.encoder(inputs) 74 | 75 | x = self.dprnn(x) 
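        # With the default config (configs/dpcrn.json) x is a B x 128 x F' x T feature map at this
        # point: the DPRNN in modules.py has applied a bidirectional intra-frame RNN along the
        # frequency axis and a unidirectional (causal) inter-frame RNN along time, each followed by
        # group normalization and a residual connection. The decoder below mirrors the encoder and
        # concatenates each skip connection on the channel axis, which is why decoder_in_channel
        # is 256 (2 x 128).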
76 | 77 | mask = self.decoder(x, skips) 78 | en_re, en_im = self.mask_speech(mask, inputs) # en_ shape: B x F x T 79 | en_speech = self.istft(en_re, en_im) 80 | return en_speech, en_re, en_im 81 | 82 | def mask_speech(self, mask, x): 83 | mask_re = mask[:,0,:,:] 84 | mask_im = mask[:,1,:,:] 85 | 86 | x_re = x[:,0,:,:] 87 | x_im = x[:,1,:,:] 88 | 89 | en_re = x_re * mask_re - x_im * mask_im 90 | en_im = x_re * mask_im + x_im * mask_re 91 | return en_re, en_im 92 | 93 | class Encoder(nn.Module): 94 | def __init__(self, in_channel_size, channel_size, kernel_size, stride_size, padding): 95 | super(Encoder, self).__init__() 96 | self.channel_size = channel_size 97 | self.kernel_size = kernel_size 98 | self.stride_size = stride_size 99 | self.padding = padding 100 | 101 | self.conv = nn.ModuleList() 102 | self.norm = nn.ModuleList() 103 | in_chan = in_channel_size 104 | for i in range(len(channel_size)): 105 | self.conv.append(nn.Conv2d(in_channels=in_chan,out_channels=channel_size[i], 106 | kernel_size=kernel_size[i], stride=stride_size[i])) 107 | self.norm.append(nn.BatchNorm2d(channel_size[i])) 108 | in_chan = channel_size[i] 109 | self.prelu = nn.PReLU() 110 | 111 | def forward(self, x): 112 | # x shape: B x C x F x T 113 | skips = [] 114 | for i, (layer, norm) in enumerate(zip(self.conv, self.norm)): 115 | x = F.pad(x, pad=self.padding[i]) 116 | x = layer(x) 117 | x = self.prelu(norm(x)) 118 | skips.append(x) 119 | return x, skips 120 | 121 | class Decoder(nn.Module): 122 | def __init__(self, in_channel_size, channel_size, kernel_size, stride_size): 123 | super(Decoder, self).__init__() 124 | self.channel_size = channel_size 125 | self.kernel_size = kernel_size 126 | self.stride_size = stride_size 127 | 128 | self.conv = nn.ModuleList() 129 | self.norm = nn.ModuleList() 130 | in_chan = in_channel_size 131 | for i in range(len(channel_size)): 132 | if i == 3: 133 | self.conv.append(nn.ConvTranspose2d(in_channels=in_chan, out_channels=channel_size[i], 134 | kernel_size=kernel_size[i], stride=stride_size[i], 135 | padding=[1, 0], output_padding=[1, 0])) 136 | else: 137 | self.conv.append(nn.ConvTranspose2d(in_channels=in_chan, out_channels=channel_size[i], 138 | kernel_size=kernel_size[i], stride=stride_size[i], 139 | padding=[1,0])) 140 | self.norm.append(nn.BatchNorm2d(channel_size[i])) 141 | in_chan = channel_size[i] * 2 142 | self.prelu = nn.PReLU() 143 | 144 | def forward(self, x, skips): 145 | # x shape: B x C x F x T 146 | for i, (layer, norm, skip) in enumerate(zip(self.conv, self.norm, reversed(skips))): 147 | x = torch.cat([x,skip], dim=1) 148 | x = layer(x)[:,:,:,:-1] 149 | x = self.prelu(norm(x)) 150 | return x 151 | 152 | def test_model(): 153 | hps = utils.get_hparams() 154 | model = DPCRN(encoder_in_channel=hps.train.encoder_in_channel, 155 | encoder_channel_size=hps.train.encoder_channel_size, 156 | encoder_kernel_size=hps.train.encoder_kernel_size, 157 | encoder_stride_size=hps.train.encoder_stride_size, 158 | encoder_padding=hps.train.encoder_padding, 159 | decoder_in_channel=hps.train.decoder_in_channel, 160 | decoder_channel_size=hps.train.decoder_channel_size, 161 | decoder_kernel_size=hps.train.decoder_kernel_size, 162 | decoder_stride_size=hps.train.decoder_stride_size, 163 | rnn_type=hps.train.dprnn_rnn_type, 164 | input_size=hps.train.dprnn_input_size, 165 | hidden_size=hps.train.dprnn_hidden_size, 166 | frame_len=hps.train.frame_len, 167 | frame_shift=hps.train.frame_shift) 168 | model = model.to(device) 169 | model.eval() 170 | batch_size = 16 171 | x = 
torch.randn((batch_size, 16000*5)).to(device) # 5s inputs 172 | y = model(x) 173 | return y 174 | 175 | def test_stft(): 176 | hps = utils.get_hparams() 177 | stft = STFT(frame_len=hps.train.frame_len, frame_hop=hps.train.frame_shift) 178 | istft = ISTFT(frame_len=hps.train.frame_len, frame_hop=hps.train.frame_shift) 179 | x = torch.randn((8,16100*5)) 180 | x_r, x_i = stft(x) 181 | x_rec = istft(x_r, x_i) 182 | print(x_rec.size(1)) 183 | 184 | def get_model_size(): 185 | hps = utils.get_hparams() 186 | model = DPCRN(encoder_in_channel=hps.train.encoder_in_channel, 187 | encoder_channel_size=hps.train.encoder_channel_size, 188 | encoder_kernel_size=hps.train.encoder_kernel_size, 189 | encoder_stride_size=hps.train.encoder_stride_size, 190 | encoder_padding=hps.train.encoder_padding, 191 | decoder_in_channel=hps.train.decoder_in_channel, 192 | decoder_channel_size=hps.train.decoder_channel_size, 193 | decoder_kernel_size=hps.train.decoder_kernel_size, 194 | decoder_stride_size=hps.train.decoder_stride_size, 195 | rnn_type=hps.train.dprnn_rnn_type, 196 | input_size=hps.train.dprnn_input_size, 197 | hidden_size=hps.train.dprnn_hidden_size, 198 | frame_len=hps.train.frame_len, 199 | frame_shift=hps.train.frame_shift) 200 | model = model.to(device) 201 | para = [p.numel() for p in model.parameters() if p.requires_grad] 202 | total_para_size = sum(para) 203 | print(total_para_size) 204 | return total_para_size 205 | 206 | if __name__ == "__main__": 207 | # get_model_size() 208 | # test_stft() 209 | test_model() -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | EPS = 1e-8 5 | 6 | class SingleRNN(nn.Module): 7 | """ 8 | Container module for a single RNN layer. 9 | 10 | args: 11 | rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'. 12 | input_size: int, dimension of the input feature. The input should have shape 13 | (batch, seq_len, input_size). 14 | hidden_size: int, dimension of the hidden state. 15 | dropout: float, dropout ratio. Default is 0. 16 | bidirectional: bool, whether the RNN layers are bidirectional. Default is False. 17 | """ 18 | 19 | def __init__(self, rnn_type, input_size, hidden_size, dropout=0, bidirectional=False): 20 | super(SingleRNN, self).__init__() 21 | 22 | self.rnn_type = rnn_type 23 | self.input_size = input_size 24 | self.hidden_size = hidden_size 25 | self.num_direction = int(bidirectional) + 1 26 | 27 | self.rnn = getattr(nn, rnn_type)(input_size, hidden_size, 1, dropout=dropout, batch_first=True, 28 | bidirectional=bidirectional) 29 | 30 | # linear projection layer 31 | self.proj = nn.Linear(hidden_size * self.num_direction, input_size) 32 | 33 | def forward(self, input): 34 | # input shape: batch, seq, dim 35 | #input = input.to(device) 36 | output = input 37 | rnn_output, _ = self.rnn(output) 38 | rnn_output = self.proj(rnn_output.contiguous().view(-1, rnn_output.shape[2])).view(output.shape) 39 | return rnn_output 40 | 41 | # dual-path RNN 42 | class DPRNN(nn.Module): 43 | """ 44 | Deep duaL-path RNN. 45 | 46 | args: 47 | rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'. 48 | input_size: int, dimension of the input feature. The input should have shape 49 | (batch, seq_len, input_size). 50 | hidden_size: int, dimension of the hidden state. 51 | output_size: int, dimension of the output size. 52 | dropout: float, dropout ratio. Default is 0. 
53 | num_layers: int, number of stacked RNN layers. Default is 1. 54 | bidirectional: bool, whether the RNN layers are bidirectional. Default is False. 55 | """ 56 | 57 | def __init__(self, rnn_type, input_size, hidden_size, 58 | dropout=0, num_layers=1, bidirectional=False): 59 | super(DPRNN, self).__init__() 60 | 61 | self.input_size = input_size 62 | # self.output_size = output_size 63 | self.hidden_size = hidden_size 64 | 65 | # dual-path RNN 66 | self.row_rnn = nn.ModuleList([]) 67 | self.col_rnn = nn.ModuleList([]) 68 | self.row_norm = nn.ModuleList([]) 69 | self.col_norm = nn.ModuleList([]) 70 | for i in range(num_layers): 71 | self.row_rnn.append(SingleRNN(rnn_type, input_size, hidden_size, dropout, 72 | bidirectional=True)) # intra-segment RNN is always noncausal 73 | self.col_rnn.append(SingleRNN(rnn_type, input_size, hidden_size, dropout, bidirectional=bidirectional)) 74 | self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8)) 75 | # default is to use noncausal LayerNorm for inter-chunk RNN. For causal setting change it to causal normalization techniques accordingly. 76 | self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8)) 77 | 78 | # no output layer in DPCRN 79 | # self.output = nn.Sequential(nn.PReLU(), 80 | # nn.Conv2d(input_size, output_size, 1) 81 | # ) 82 | 83 | def forward(self, input): 84 | # input shape: batch, N, dim1, dim2 85 | # apply RNN on dim1 first and then dim2 86 | # output shape: B, output_size, dim1, dim2 87 | #input = input.to(device) 88 | batch_size, _, dim1, dim2 = input.shape 89 | output = input 90 | for i in range(len(self.row_rnn)): 91 | row_input = output.permute(0, 3, 2, 1).contiguous().view(batch_size * dim2, dim1, -1) # B*dim2, dim1, N 92 | row_output = self.row_rnn[i](row_input) # B*dim2, dim1, H 93 | row_output = row_output.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 94 | 1).contiguous() # B, N, dim1, dim2 95 | row_output = self.row_norm[i](row_output) 96 | output = output + row_output 97 | 98 | col_input = output.permute(0, 2, 3, 1).contiguous().view(batch_size * dim1, dim2, -1) # B*dim1, dim2, N 99 | col_output = self.col_rnn[i](col_input) # B*dim1, dim2, H 100 | col_output = col_output.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 101 | 2).contiguous() # B, N, dim1, dim2 102 | col_output = self.col_norm[i](col_output) 103 | output = output + col_output 104 | 105 | # output = self.output(output) # B, output_size, dim1, dim2 106 | return output -------------------------------------------------------------------------------- /pretrained_model/final_ckp.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bear-boy/DPCRN-Pytorch/516cb0951204d750f7b41e2849cc91d83285344c/pretrained_model/final_ckp.pth -------------------------------------------------------------------------------- /resample_data.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import os 3 | import soundfile as sf 4 | from tqdm import tqdm 5 | 6 | clean_path = './testset_clean_to_be_resampled/' 7 | noisy_path = './testset_noisy_to_be_resampled/' 8 | 9 | clean_samples = os.listdir(clean_path) 10 | noisy_samples = os.listdir(noisy_path) 11 | 12 | clean_samples = [f for f in clean_samples if f.endswith('.wav')] 13 | noisy_samples = [f for f in noisy_samples if f.endswith('.wav')] 14 | 15 | output_clean_path = './dataset/valset_clean/' 16 | output_noisy_path = './dataset/valset_noisy/' 17 | 18 | for id in tqdm(range(len(clean_samples))): 19 | 
# p = clean_path + clean_samples[id] 20 | y, sr = librosa.load(clean_path + clean_samples[id], sr=None) 21 | y_16k = librosa.resample(y, orig_sr=sr, target_sr=16000) 22 | sf.write(output_clean_path + clean_samples[id], y_16k, samplerate=16000) 23 | 24 | for id in tqdm(range(len(noisy_samples))): 25 | # p = noisy_path + noisy_samples[id] 26 | y, sr = librosa.load(noisy_path + noisy_samples[id], sr=None) 27 | y_16k = librosa.resample(y, orig_sr=sr, target_sr=16000) 28 | sf.write(output_noisy_path + noisy_samples[id], y_16k, samplerate=16000) -------------------------------------------------------------------------------- /se_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import librosa 4 | import os 5 | import torch 6 | from torch.utils import data 7 | 8 | # Reference 9 | # DATA LOADING - LOAD FILE LISTS 10 | def load_data_list(folder='./dataset/', setname='train'): 11 | assert(setname in ['train', 'val']) 12 | 13 | dataset = {} 14 | foldername = folder + '/' + setname + 'set' 15 | 16 | print("Loading files...") 17 | dataset['innames'] = [] 18 | dataset['outnames'] = [] 19 | dataset['shortnames'] = [] 20 | 21 | filelist = os.listdir("%s_noisy"%(foldername)) 22 | filelist = [f for f in filelist if f.endswith(".wav")] 23 | for i in tqdm(filelist): 24 | dataset['innames'].append("%s_noisy/%s"%(foldername,i)) 25 | dataset['outnames'].append("%s_clean/%s"%(foldername,i)) 26 | dataset['shortnames'].append("%s"%(i)) 27 | 28 | return dataset 29 | 30 | # DATA LOADING - LOAD FILE DATA 31 | def load_data(dataset): 32 | 33 | dataset['inaudio'] = [None]*len(dataset['innames']) 34 | dataset['outaudio'] = [None]*len(dataset['outnames']) 35 | 36 | for id in tqdm(range(len(dataset['innames']))): 37 | 38 | if dataset['inaudio'][id] is None: 39 | inputData, sr = librosa.load(dataset['innames'][id], sr=None) 40 | outputData, sr = librosa.load(dataset['outnames'][id], sr=None) 41 | 42 | shape = np.shape(inputData) 43 | 44 | dataset['inaudio'][id] = np.float32(inputData) 45 | dataset['outaudio'][id] = np.float32(outputData) 46 | 47 | return dataset 48 | 49 | class AudioDataset(data.Dataset): 50 | """ 51 | Audio sample reader. 
52 | """ 53 | 54 | def __init__(self, data_type, win_len, hop_len): 55 | dataset = load_data_list(setname=data_type) 56 | self.dataset = load_data(dataset) 57 | self.win_len = win_len 58 | self.hop_len = hop_len 59 | self.file_names = dataset['innames'] 60 | 61 | def __getitem__(self, idx): 62 | mixed = torch.from_numpy(self.dataset['inaudio'][idx]).type(torch.FloatTensor) 63 | clean = torch.from_numpy(self.dataset['outaudio'][idx]).type(torch.FloatTensor) 64 | 65 | return mixed, clean 66 | 67 | def __len__(self): 68 | return len(self.file_names) 69 | 70 | def zero_pad_concat(self, inputs): 71 | max_t = max(inp.shape[0] for inp in inputs) 72 | # pad zero at end to make (data_len - win_len)/hop_len is integer, 73 | # which will make sure that torch.istft() outputs' length equal to torch.stft() inputs' 74 | frames = int(np.ceil((max_t - self.win_len) / self.hop_len)) + 1 75 | max_t = (frames - 1) * self.hop_len + self.win_len 76 | shape = (len(inputs), max_t) 77 | input_mat = np.zeros(shape, dtype=np.float32) 78 | for e, inp in enumerate(inputs): 79 | input_mat[e, :inp.shape[0]] = inp 80 | return input_mat 81 | 82 | def collate(self, inputs): 83 | mixeds, cleans = zip(*inputs) 84 | seq_lens = torch.IntTensor([i.shape[0] for i in mixeds]) 85 | 86 | x = torch.FloatTensor(self.zero_pad_concat(mixeds)) 87 | y = torch.FloatTensor(self.zero_pad_concat(cleans)) 88 | 89 | batch = [x, y, seq_lens] 90 | return batch -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | training script for DPCRN 3 | ''' 4 | import argparse 5 | import os 6 | import torch 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from torch.optim.lr_scheduler import ReduceLROnPlateau 10 | from torch.utils.tensorboard import SummaryWriter 11 | from tqdm import tqdm 12 | from model import DPCRN, STFT 13 | import utils 14 | from se_dataset import AudioDataset 15 | from loss import DPCRNLoss 16 | from EarlyStopping import EarlyStopping 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--log_dir', default='./logs/DPCRN/', type=str, help='Log file path to record training status') 20 | parser.add_argument('--batch_size', default=8, type=int, help='train batch size') 21 | parser.add_argument('--num_epochs', default=100, type=int, help='train epochs number') 22 | args = parser.parse_args() 23 | 24 | def get_lr(optim): 25 | return optim.param_groups[0]['lr'] 26 | 27 | def main(): 28 | hps = utils.get_hparams() 29 | net = DPCRN(encoder_in_channel=hps.train.encoder_in_channel, 30 | encoder_channel_size=hps.train.encoder_channel_size, 31 | encoder_kernel_size=hps.train.encoder_kernel_size, 32 | encoder_stride_size=hps.train.encoder_stride_size, 33 | encoder_padding=hps.train.encoder_padding, 34 | decoder_in_channel=hps.train.decoder_in_channel, 35 | decoder_channel_size=hps.train.decoder_channel_size, 36 | decoder_kernel_size=hps.train.decoder_kernel_size, 37 | decoder_stride_size=hps.train.decoder_stride_size, 38 | rnn_type=hps.train.dprnn_rnn_type, 39 | input_size=hps.train.dprnn_input_size, 40 | hidden_size=hps.train.dprnn_hidden_size, 41 | frame_len=hps.train.frame_len, 42 | frame_shift=hps.train.frame_shift).cuda() 43 | 44 | stft = STFT(hps.train.frame_len, hps.train.frame_shift) 45 | 46 | # checkpoints load 47 | if (os.path.exists('./final.pth.tar')): 48 | checkpoint = torch.load('./final.pth.tar') 49 | net.load_state_dict(checkpoint) 50 | 51 | # log 52 | 
writer = SummaryWriter(args.log_dir) 53 | 54 | train_dataset = AudioDataset(data_type='train',win_len=hps.train.frame_len,hop_len=hps.train.frame_shift) 55 | val_dataset = AudioDataset(data_type='val',win_len=hps.train.frame_len,hop_len=hps.train.frame_shift) 56 | train_data_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, 57 | collate_fn=train_dataset.collate, shuffle=True, num_workers=1) 58 | val_data_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, 59 | collate_fn=val_dataset.collate, shuffle=False, num_workers=1) 60 | 61 | torch.set_printoptions(precision=10, profile="full") 62 | 63 | # Optimizer 64 | optimizer = optim.Adam(net.parameters(), lr=1e-3) 65 | # Learning rate scheduler 66 | scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, min_lr=10e-10, cooldown=1) 67 | # Early-Stopping 68 | early_stopping = EarlyStopping(patience=10,verbose=True,path='./logs/DPCRN/ckp/') 69 | 70 | writer_step = 0 71 | for epoch in range(args.num_epochs): 72 | # training 73 | net.train() 74 | train_bar = tqdm(train_data_loader) 75 | step = 0 76 | record_loss = 0.0 77 | for input in train_bar: 78 | train_mixed, train_clean, seq_len = map(lambda x: x.cuda(), input) 79 | # train_mixed, train_clean, seq_len = input 80 | train_clean_re, train_clean_im = stft(train_clean) 81 | en_sp, en_re, en_im = net(train_mixed) 82 | for i, l in enumerate(seq_len): 83 | en_sp[i, l:] = 0 84 | loss = DPCRNLoss(train_clean, en_sp, train_clean_re, train_clean_im, en_re, en_im) 85 | 86 | optimizer.zero_grad() 87 | loss.backward() 88 | optimizer.step() 89 | record_loss += loss.item() 90 | step += 1 91 | if step % 50 == 0: 92 | print('Step {} in Epoch [{}/{}], training_loss: {:.4f}'.format(step, epoch, args.num_epochs, loss.item())) 93 | writer.add_scalar('training loss', loss.item(), writer_step) 94 | writer_step += 1 95 | 96 | train_loss = record_loss / len(train_data_loader) 97 | print('#Epoch [{}/{}], training_loss: {:.4f}'.format(epoch, args.num_epochs, train_loss)) 98 | writer.add_scalar('lr', get_lr(optimizer), epoch) 99 | writer.add_scalar('train_loss', train_loss, epoch) 100 | 101 | # validation 102 | with torch.no_grad(): 103 | net.eval() 104 | val_bar = tqdm(val_data_loader) 105 | val_loss = 0.0 106 | for input in val_bar: 107 | val_mixed, val_clean, seq_len = map(lambda x: x.cuda(), input) 108 | val_clean_re, val_clean_im = stft(val_clean) 109 | en_sp, en_re, en_im = net(val_mixed) 110 | 111 | loss = DPCRNLoss(val_clean, en_sp, val_clean_re, val_clean_im, en_re, en_im) 112 | val_loss += loss.item() 113 | 114 | val_loss = val_loss / len(val_data_loader) 115 | print('Epoch [{}/{}], validation_loss: {:.4f}'.format(epoch, args.num_epochs, val_loss)) 116 | writer.add_scalar('val_loss', val_loss, epoch) 117 | 118 | # learning-rate scheduler 119 | scheduler.step(val_loss) 120 | # Early-Stopping Check 121 | early_stopping(val_loss, net) 122 | if early_stopping.early_stop: 123 | print("Early Stopping") 124 | break 125 | 126 | # torch.save(net.state_dict(), './logs/DPCRN/final.pth') 127 | 128 | if __name__ == '__main__': 129 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | 5 | def get_hparams(init=True): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('-c', '--config', type=str, default="./configs/dpcrn.json", 8 | help='JSON file for configuration') 9 | 
parser.add_argument('-m', '--model', type=str, default="DPCRN", 10 | help='Model name') 11 | 12 | args = parser.parse_args() 13 | model_dir = os.path.join("./logs", args.model) 14 | 15 | if not os.path.exists(model_dir): 16 | os.makedirs(model_dir) 17 | 18 | config_path = args.config 19 | config_save_path = os.path.join(model_dir, "config.json") 20 | if init: 21 | with open(config_path, "r") as f: 22 | data = f.read() 23 | with open(config_save_path, "w") as f: 24 | f.write(data) 25 | else: 26 | with open(config_save_path, "r") as f: 27 | data = f.read() 28 | config = json.loads(data) 29 | 30 | hparams = HParams(**config) 31 | hparams.model_dir = model_dir 32 | return hparams 33 | 34 | 35 | class HParams(): 36 | def __init__(self, **kwargs): 37 | for k, v in kwargs.items(): 38 | if type(v) == dict: 39 | v = HParams(**v) 40 | self[k] = v 41 | 42 | def keys(self): 43 | return self.__dict__.keys() 44 | 45 | def items(self): 46 | return self.__dict__.items() 47 | 48 | def values(self): 49 | return self.__dict__.values() 50 | 51 | def __len__(self): 52 | return len(self.__dict__) 53 | 54 | def __getitem__(self, key): 55 | return getattr(self, key) 56 | 57 | def __setitem__(self, key, value): 58 | return setattr(self, key, value) 59 | 60 | def __contains__(self, key): 61 | return key in self.__dict__ 62 | 63 | def __repr__(self): 64 | return self.__dict__.__repr__() --------------------------------------------------------------------------------
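A usage note on `utils.py`: `get_hparams()` copies the chosen JSON config into `./logs/<model>/config.json` and wraps it in nested `HParams` objects, so hyper-parameters can be read either as attributes or as keys. A small sketch, assuming the default `configs/dpcrn.json`:

```
import utils

hps = utils.get_hparams()              # parses -c/--config and -m/--model from the command line
print(hps.train.frame_len)             # attribute access -> 400
print(hps['train']['frame_shift'])     # item access      -> 200
print('train' in hps)                  # __contains__ also works
```

This is the same object that `train.py` and `infer.py` build at start-up.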