├── .gitignore
├── requirements.txt
├── utils.py
├── main.py
├── loss.py
├── inference.py
├── config.yaml
├── README.md
├── datamodule.py
├── AECNN.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
*.wav
test
__py
*.ipynb
logger
output
__pycache__
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
pytorch_lightning==2.0.0
PyYAML==6.0
torch==2.0.0
torchaudio==2.0.0
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os

def segmentation(x, frame_size, hop_size):
    # Split a waveform into overlapping frames of length frame_size with stride hop_size.
    x_seg = x.unfold(-1, frame_size, hop_size)
    x_seg = x_seg.transpose(0, 1).contiguous()
    return x_seg


def get_wav_paths(paths: list):
    # Recursively collect every .wav file under the given directory (or list of directories).
    wav_paths = []
    if type(paths) == str:
        paths = [paths]

    for path in paths:
        for root, dirs, files in os.walk(path):
            wav_paths += [os.path.join(root, file) for file in files if os.path.splitext(file)[-1] == '.wav']

    wav_paths.sort(key=lambda x: os.path.split(x)[-1])

    return wav_paths

def check_dir_exist(path_list):
    # Create every directory path in the list that does not exist yet.
    if type(path_list) == str:
        path_list = [path_list]

    for path in path_list:
        if type(path) == str and os.path.splitext(path)[-1] == '' and not os.path.exists(path):
            os.makedirs(path)

def get_filename(path):
    # Return (stem, extension) of the file name.
    return os.path.splitext(os.path.basename(path))

def get_one_sample_path(dir_noisy_path, dir_clean_path):
    wav_noisy_path = get_wav_paths(dir_noisy_path)[0]
    wav_clean_path = get_wav_paths(dir_clean_path)[0]
    return wav_noisy_path, wav_clean_path
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from train import SETrain

from pytorch_lightning import loggers as pl_loggers
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from datamodule import SEDataModule

import yaml
from utils import *

def main(config):
    pl.seed_everything(config['random_seed'], workers=True)
    se_datamodule = SEDataModule(config)
    se_train = SETrain(config)

    check_dir_exist(config['train']['output_dir_path'])
    check_dir_exist(config['train']['logger_path'])

    tb_logger = pl_loggers.TensorBoardLogger(config['train']['logger_path'], name="SE_logs")

    tb_logger.log_hyperparams(config)

    checkpoint_callback = ModelCheckpoint(
        filename="{epoch}-{val_loss:.4f}",
        save_top_k=1,
        mode='min',
        monitor="val_loss"
    )

    trainer = pl.Trainer(devices=config['train']['devices'], accelerator="gpu", strategy='ddp',
                         max_epochs=config['train']['total_epoch'],
                         callbacks=[checkpoint_callback],
                         logger=tb_logger,
                         profiler="simple"
                         )

    trainer.fit(se_train, se_datamodule)

if __name__ == "__main__":

    config = yaml.load(open("./config.yaml", 'r'), Loader=yaml.FullLoader)
    main(config)
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import torch

class STFTLoss:
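    # Frame-level spectral loss: the mean absolute difference between the (|real| + |imag|)
    # STFT features of the enhanced signal and the target signal, averaged over the batch.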
    def __init__(self, config):
        self.stft_risum = stft_RIsum(
            nfft=config['loss']['window_size'],
            window_size=config['loss']['window_size'],
            hop_size=config['loss']['hop_size']
        )

    def __call__(self, x_proc, x_orig):

        total_num = x_proc.shape[0]
        total_loss = 0

        for idx in range(total_num):
            x_enh = x_proc[idx]
            x_target = x_orig[idx]
            loss = torch.mean(torch.abs(self.stft_risum(x_target) - self.stft_risum(x_enh)))
            total_loss += loss

        return total_loss / total_num

class stft_RIsum:
    def __init__(self, nfft, window_size, hop_size):
        self.nfft = nfft
        self.window_size = window_size
        self.hop_size = hop_size

    def __call__(self, x):

        window = torch.hann_window(self.window_size).to(x.device)
        x_stft = torch.stft(x, n_fft=self.nfft, hop_length=self.hop_size, win_length=self.window_size,
                            window=window, return_complex=True)
        # With return_complex=True the STFT is a complex tensor, so take its real and imaginary
        # parts directly (indexing the last dimension would slice frames, not components).
        real = x_stft.real
        imag = x_stft.imag

        return torch.abs(real) + torch.abs(imag)
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import argparse
import torchaudio as ta
import pytorch_lightning as pl
import torch
import torch.nn.functional as F

from train import SETrain

from utils import *
import yaml

def inference(config, args):

    se_train = SETrain.load_from_checkpoint(args.path_ckpt, config=config)
    se_train.aecnn.eval()

    if args.mode == "wav":
        wav_noisy, _ = ta.load(args.path_in)
        wav_enh = se_train.synth_one_sample(wav_noisy)
        wav_enh = wav_enh.cpu()

        filename = get_filename(args.path_in)
        ta.save(os.path.join(os.path.dirname(args.path_in), filename[0] + "_proc" + filename[1]), wav_enh, 16000)

    elif args.mode == "dir":
        check_dir_exist(args.path_out)

        path_wavs = get_wav_paths(args.path_in)
        for path_wav in path_wavs:
            wav_noisy, _ = ta.load(path_wav)
            wav_enh = se_train.synth_one_sample(wav_noisy)
            wav_enh = wav_enh.cpu()
            ta.save(os.path.join(args.path_out, os.path.basename(path_wav)), wav_enh, 16000)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--path_ckpt", type=str)
    parser.add_argument("--mode", type=str, help='wav/dir', default='wav')
    parser.add_argument("--path_in", type=str, help="path of the input wav file or directory")
    parser.add_argument("--path_out", type=str, help="path of the output directory")

    args = parser.parse_args()
    config = yaml.load(open("./config.yaml", 'r'), Loader=yaml.FullLoader)

    inference(config, args)
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------

#-----------------------------------------------
# Config that does not affect performance
#-----------------------------------------------


random_seed: 0b011011

#-----------------------------------------------
# 1. Dataset
#-----------------------------------------------

dataset:

  # Directory that contains every dataset.
  data_dir: /media/zeroone

  clean_train: "clean_trainset_28spk_wav_16k"
  clean_val: "clean_testset_wav_16k"

  noisy_train: "noisy_trainset_28spk_wav_16k"
  noisy_val: "noisy_testset_wav_16k"

  # With these defaults, the noisy validation dataset should be located at "/media/zeroone/noisy_testset_wav_16k".

  frame_size: 2048
  hop_size: 256

  batch_size: 256
  num_workers: 16

#-----------------------------------------------
# 2. Model
#-----------------------------------------------

model:
  kernel_size: 11

#-----------------------------------------------
# 3. Loss
#-----------------------------------------------
# For the STFT loss
loss:
  window_size: 512
  hop_size: 256

#-----------------------------------------------
# 4. Optimizer (Adam)
#-----------------------------------------------
optim:
  initial_lr: 0.0002

  B1: 0.5
  B2: 0.9

  lr_gamma: 1

#-----------------------------------------------
# Training
#-----------------------------------------------

train:
  total_epoch: 100

  # Path for the validation output samples.
  output_dir_path: "./output"
  logger_path: "./logger"

  devices:
    - 0
    #- 1
    #- 2 ... add more device indices if you are using DDP

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# A New Framework for CNN-Based Speech Enhancement in the Time Domain

This repository contains an unofficial PyTorch Lightning implementation of the model described in the paper [A New Framework for CNN-Based Speech Enhancement in the Time Domain](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8701652) by Ashutosh Pandey and DeLiang Wang.

## Requirements

To run this code, you will need:

- pytorch_lightning==2.0.0
- PyYAML==6.0
- torch==2.0.0
- torchaudio==2.0.0

To install these libraries automatically, run the following command:

```pip install -r requirements.txt```

## Usage

To run the code on your own machine, follow these steps:

1. Open the `config.yaml` file and modify the file paths (and hyperparameters as needed); an example dataset layout is sketched further below.
2. Run `main.py` to start training the model.

The trained model is saved as a ckpt file in the `logger` directory. You can then use the trained model to perform speech enhancement on your own wav file by running `inference.py` as

```python inference.py --mode "wav" --path_ckpt <path of ckpt> --path_in <path of input wav>```

This repository also supports directory-level inference, where inference is performed on every wav file in a directory. You can use the following example to perform directory-level inference:

```python inference.py --mode "dir" --path_ckpt <path of ckpt> --path_in <input directory> --path_out <output directory>```

## Note
- 2023.5.1 This code now supports Distributed Data Parallel (DDP) training!
- 2023.4.28 The code has been modified to be compatible with the PyTorch Lightning 2.0 environment! It includes support for inference as well.
- Feel free to open issues!
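
## Example dataset layout

For reference, this is a sketch of the directory layout that the default `config.yaml` above assumes; the root path and folder names are only the config defaults and can be changed freely:

```
/media/zeroone
├── clean_trainset_28spk_wav_16k/
├── noisy_trainset_28spk_wav_16k/
├── clean_testset_wav_16k/
└── noisy_testset_wav_16k/
```

`SEDataModule` joins `data_dir` with each of these folder names, so only the folder names (not full paths) go in the dataset entries of the config.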

## Citation

```bibtex
@ARTICLE{8701652,
  author={Pandey, Ashutosh and Wang, DeLiang},
  journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
  title={A New Framework for CNN-Based Speech Enhancement in the Time Domain},
  year={2019},
  volume={27},
  number={7},
  pages={1179-1188},
  doi={10.1109/TASLP.2019.2913512}}
```
--------------------------------------------------------------------------------
/datamodule.py:
--------------------------------------------------------------------------------
from torch.utils.data import Dataset, DataLoader
import torchaudio as ta
import os

import pytorch_lightning as pl

from utils import *

class SEDataset(Dataset):
    # Preprocesses the dataset: loads noisy/clean wav pairs and segments them into frames.
    def __init__(self, path_dir_noisy, path_dir_clean, frame_size, hop_size):
        self.path_dir_noisy = path_dir_noisy
        self.path_dir_clean = path_dir_clean

        self.wavs = []

        paths_wav_noisy = get_wav_paths(self.path_dir_noisy)
        paths_wav_clean = get_wav_paths(self.path_dir_clean)

        for wav_idx, (path_wav_clean, path_wav_noisy) in enumerate(zip(paths_wav_clean, paths_wav_noisy)):
            print(f'\r{wav_idx}th file loaded', end='')
            wav_noisy, _ = ta.load(path_wav_noisy)
            wav_clean, _ = ta.load(path_wav_clean)

            wav_noisy_seg = segmentation(wav_noisy, frame_size, hop_size)
            wav_clean_seg = segmentation(wav_clean, frame_size, hop_size)

            for idx in range(wav_clean_seg.shape[0]):
                self.wavs.append([wav_noisy_seg[idx], wav_clean_seg[idx]])

    # Return the total number of segments.
    def __len__(self):
        return len(self.wavs)

    # Given an index, return the corresponding (noisy, clean) segment pair as PyTorch tensors.
    def __getitem__(self, idx):
        return self.wavs[idx]


class SEDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()

        self.data_dir = config['dataset']['data_dir']

        self.path_dir_noisy_train = config['dataset']['noisy_train']
        self.path_dir_noisy_val = config['dataset']['noisy_val']

        self.path_dir_clean_train = config['dataset']['clean_train']
        self.path_dir_clean_val = config['dataset']['clean_val']

        self.frame_size = config["dataset"]["frame_size"]
        self.hop_size = config["dataset"]["hop_size"]

        self.batch_size = config['dataset']['batch_size']
        self.num_workers = config['dataset']['num_workers']

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        self.train_dataset = SEDataset(
            path_dir_noisy=os.path.join(self.data_dir, self.path_dir_noisy_train),
            path_dir_clean=os.path.join(self.data_dir, self.path_dir_clean_train),
            frame_size=self.frame_size,
            hop_size=self.hop_size
        )

        self.val_dataset = SEDataset(
            path_dir_noisy=os.path.join(self.data_dir, self.path_dir_noisy_val),
            path_dir_clean=os.path.join(self.data_dir, self.path_dir_clean_val),
            frame_size=self.frame_size,
            hop_size=self.hop_size
        )

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        pass
--------------------------------------------------------------------------------
/AECNN.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class AECNN(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        self.encoder = AECNNEncoder(kernel_size=kernel_size)
        self.decoder = AECNNDecoder(kernel_size=kernel_size)

    def forward(self, x):
        x, skip = self.encoder(x)
        output = self.decoder(x, skip)

        return output

class AECNNEncoder(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size

        self.out_channels = [2**(6 + idx//3) for idx in range(8)]  # [64, 64, 64, 128, 128, 128, 256, 256]
        self.encoder_layers = nn.ModuleList(
            [nn.Conv1d(
                in_channels=1 if idx == 0 else self.out_channels[idx-1],
                out_channels=self.out_channels[idx],
                kernel_size=self.kernel_size,
                stride=1 if idx == 0 else 2,
                padding=(self.kernel_size-1) // 2
            ) for idx in range(8)])
        self.prelu_layers = nn.ModuleList(
            [nn.PReLU() for _ in range(8)]
        )

        self.bottleneck = nn.Conv1d(
            in_channels=256,
            out_channels=256,
            kernel_size=self.kernel_size,
            stride=2,
            padding=(self.kernel_size - 1) // 2
        )

    def forward(self, x):
        skip = []
        for idx in range(8):
            x = self.encoder_layers[idx](x)
            x = self.prelu_layers[idx](x)
            if idx % 3 == 2:
                # Pass the module's training flag so dropout is disabled at eval time.
                x = F.dropout(x, p=0.2, training=self.training)
            skip.append(x)

        x = self.bottleneck(x)
        x = F.dropout(x, p=0.2, training=self.training)

        return x, skip

class AECNNDecoder(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        out_channels = [2**(6 + idx//3) for idx in range(8)]
        self.out_channels = out_channels[::-1]

        self.decoder_layers = nn.ModuleList(
            [nn.ConvTranspose1d(
                in_channels=256 if idx == 0 else self.out_channels[idx-1]*2,
                out_channels=self.out_channels[idx],
                kernel_size=self.kernel_size,
                stride=2,
                padding=(self.kernel_size-1) // 2,
                output_padding=1
            ) for idx in range(8)])

        self.prelu_layers = nn.ModuleList(
            [nn.PReLU() for _ in range(8)]
        )

        self.output_layer = nn.Conv1d(
            in_channels=128,
            out_channels=1,
            kernel_size=self.kernel_size,
            stride=1,
            padding=(self.kernel_size - 1) // 2
        )

    def forward(self, x, skip):
        skip = skip[::-1]

        for idx in range(8):
            # Concatenate the matching encoder feature map (skip connection) before each upsampling layer.
            x = x if idx == 0 else torch.cat([x, skip[idx-1]], dim=1)
            x = self.decoder_layers[idx](x)
            x = self.prelu_layers[idx](x)
            if idx % 3 == 2:
                x = F.dropout(x, p=0.2, training=self.training)

        x = torch.cat([x, skip[7]], dim=1)
        x = self.output_layer(x)
        output = torch.tanh(x)

        return output
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import pytorch_lightning as pl
import torch

import os
import torch.nn.functional as F
import torchaudio as ta
from loss import *
from utils import *
from AECNN import AECNN

class SETrain(pl.LightningModule):
    def __init__(self, config):
        super(SETrain, self).__init__()
        self.automatic_optimization = False
        self.config = config

        self.kernel_size = config['model']['kernel_size']
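        # AECNN: a fully convolutional 1-D autoencoder with skip connections that operates
        # directly on time-domain frames (8 Conv1d encoder layers plus a bottleneck, and
        # 8 ConvTranspose1d decoder layers with PReLU activations and dropout; see AECNN.py).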
        self.aecnn = AECNN(kernel_size=self.kernel_size)

        self.criterion = STFTLoss(config=config)

        # Optimizer & scheduler parameters
        self.initial_lr = config['optim']['initial_lr']
        self.lr_gamma = config['optim']['lr_gamma']

        # Frame segmentation parameters
        self.frame_size = config["dataset"]["frame_size"]
        self.hop_size = config["dataset"]["hop_size"]

        # Validation sample used for logging
        self.data_dir = config['dataset']['data_dir']
        self.path_dir_noisy_val = config['dataset']['noisy_val']
        self.path_dir_clean_val = config['dataset']['clean_val']

        self.output_dir_path = config['train']['output_dir_path']

        self.path_sample_noisy, self.path_sample_clean = get_one_sample_path(dir_noisy_path=os.path.join(self.data_dir, self.path_dir_noisy_val), dir_clean_path=os.path.join(self.data_dir, self.path_dir_clean_val))

    def forward(self, x):
        output = self.aecnn(x)
        return output


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.aecnn.parameters(), lr=self.initial_lr)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.lr_gamma, verbose=True)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()
        scheduler = self.lr_schedulers()

        wav_noisy, wav_clean = batch
        wav_enh = self.forward(wav_noisy)

        loss = self.criterion(wav_enh, wav_clean)

        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()

        self.log("train_loss", loss, prog_bar=True, batch_size=self.config['dataset']['batch_size'])

        # Step the LR scheduler once per epoch (manual optimization is enabled).
        if self.trainer.is_last_batch:
            scheduler.step()

    def validation_step(self, batch, batch_idx):
        wav_noisy, wav_clean = batch
        wav_enh = self.forward(wav_noisy)

        loss = self.criterion(wav_enh, wav_clean)

        self.log("val_loss", loss, batch_size=self.config['dataset']['batch_size'], sync_dist=True)

    def on_validation_epoch_end(self):

        sample_noisy, _ = ta.load(self.path_sample_noisy)
        sample_clean, _ = ta.load(self.path_sample_clean)
        sample_noisy = sample_noisy.to(self.device)
        sample_clean = sample_clean.to(self.device)

        sample_enh = self.synth_one_sample(sample_noisy)
        sample_enh = sample_enh.cpu()

        ta.save(f"{self.output_dir_path}/sample_{self.current_epoch}.wav", sample_enh, 16000)

        # My implementation shows an error when logging audio.
        # It seems to be either an issue with the conda environment or with the code itself.
        # If you manage to resolve it, please leave a comment on the issue tracker. Thank you.

        # self.logger.experiment.add_audio(
        #     tag='sample/enhanced',
        #     snd_tensor=sample_enh.squeeze().detach(),
        #     global_step=self.global_step,
        #     sample_rate=16000
        # )

        # self.logger.experiment.add_audio(
        #     tag='sample/clean',
        #     snd_tensor=sample_clean.squeeze().detach(),
        #     global_step=self.global_step,
        #     sample_rate=16000
        # )

    def test_step(self, batch, batch_idx):
        pass

    def predict_step(self, batch, batch_idx):
        pass

    def synth_one_sample(self, wav):
        # Enhance an arbitrary-length waveform by splitting it into overlapping frames,
        # running the network on each frame, and reconstructing with windowed overlap-add.
        wav = wav.unsqueeze(1)
        wav_padded = F.pad(wav, (0, self.frame_size), "constant", 0)
        wav_seg = wav_padded.unfold(-1, self.frame_size, self.hop_size)
        B, C, T, L = wav_seg.shape

        wav_seg = wav_seg.transpose(1, 2).contiguous()
        wav_seg = wav_seg.view(B*T, C, L)

        wav_seg = self.forward(wav_seg)
        wav_seg = wav_seg.view(B, T, C, L).transpose(1, 2).contiguous()
        wav_seg = wav_seg.view(B, C*T, L)

        # Overlap-add the enhanced frames with a Hann synthesis window.
        wav_rec = F.fold(
            wav_seg.transpose(1, 2).contiguous() * torch.hann_window(self.frame_size, device=wav_seg.device).view(1, -1, 1),
            output_size=[1, (wav_seg.shape[-2]-1)*self.hop_size + self.frame_size],
            kernel_size=(1, self.frame_size),
            stride=(1, self.hop_size)
        ).squeeze(-2)

        # Compensate for the constant overlap-add gain of the Hann window at this hop size.
        wav_rec = wav_rec / (self.frame_size / (2*self.hop_size))

        wav_rec = wav_rec.squeeze(0)
        return wav_rec
--------------------------------------------------------------------------------