├── .gitattributes ├── .gitignore ├── MelGAN.py ├── README.md ├── SEANet.py ├── config.yaml ├── datamodule.py ├── demo ├── wav1_bwe.wav ├── wav1_nb.wav ├── wav1_wb.wav ├── wav2_bwe.wav ├── wav2_nb.wav ├── wav2_wb.wav ├── wav3_bwe.wav ├── wav3_nb.wav ├── wav3_wb.wav ├── wav4_bwe.wav ├── wav4_nb.wav └── wav4_wb.wav ├── inference.py ├── main.py ├── requirements.txt ├── train.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | demo/final_model.ckpt filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | logger 2 | output 3 | __pycache__ 4 | *.ipynb 5 | logger_* 6 | output_* 7 | dockerfile -------------------------------------------------------------------------------- /MelGAN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch as th 3 | import torch.nn as nn 4 | 5 | class Discriminator_1D(nn.Module): 6 | def __init__(self, bias= True): 7 | 8 | super().__init__() 9 | self.conv1 = nn.Sequential( 10 | th.nn.utils.weight_norm(nn.Conv1d(in_channels = 1, out_channels = 16, 11 | kernel_size = 15, stride= 1, padding= 7, bias= bias)), 12 | nn.LeakyReLU(0.2)) 13 | self.conv2 = nn.Sequential( 14 | th.nn.utils.weight_norm(nn.Conv1d(in_channels = 16*1, out_channels = 64, 15 | kernel_size = 41, stride= 4, groups = 4, padding = 20, bias= bias)), 16 | nn.LeakyReLU(0.2)) 17 | self.conv3 = nn.Sequential( 18 | th.nn.utils.weight_norm(nn.Conv1d(in_channels = 64, out_channels = 256, 19 | kernel_size = 41, stride= 4, groups = 4, padding = 20, bias= bias)), 20 | nn.LeakyReLU(0.2)) 21 | self.conv4 = nn.Sequential( 22 | th.nn.utils.weight_norm(nn.Conv1d(in_channels = 256, out_channels = 1024, 23 | kernel_size = 41, stride= 4, groups = 4, padding = 20, bias= bias)), 24 | nn.LeakyReLU(0.2)) 25 | self.conv5 = nn.Sequential( 26 | th.nn.utils.weight_norm(nn.Conv1d(in_channels = 1024, out_channels = 1024, 27 | kernel_size = 41, stride= 4, groups = 4, padding = 20, bias= bias)), 28 | nn.LeakyReLU(0.2)) 29 | self.conv6 = nn.Sequential( 30 | th.nn.utils.weight_norm(nn.Conv1d(in_channels = 1024, out_channels = 1024, 31 | kernel_size = 5, stride= 1, groups = 1, padding = 2, bias= bias)), 32 | nn.LeakyReLU(0.2)) 33 | self.conv7 = nn.Sequential( 34 | th.nn.utils.weight_norm(nn.Conv1d(in_channels = 1024, out_channels = 1, 35 | kernel_size = 3, stride= 1, groups = 1, padding = 1, bias= bias))) 36 | 37 | 38 | def forward(self, x): 39 | while len(x.size()) <= 2: 40 | x = x.unsqueeze(-2) 41 | xs = [] 42 | x = self.conv1(x) 43 | xs.append(x) 44 | x = self.conv2(x) 45 | xs.append(x) 46 | x = self.conv3(x) 47 | xs.append(x) 48 | x = self.conv4(x) 49 | xs.append(x) 50 | x = self.conv5(x) 51 | xs.append(x) 52 | x = self.conv6(x) 53 | xs.append(x) 54 | x = self.conv7(x) 55 | xs.append(x) 56 | return x, xs 57 | 58 | class Discriminator_MelGAN(nn.Module): 59 | def __init__(self, **kwargs): 60 | super().__init__() 61 | 62 | self.disc = nn.ModuleList([Discriminator_1D(bias = True) for i in range(3)]) 63 | self.pool1 = nn.AvgPool1d(kernel_size=4, stride=2, padding=1, count_include_pad=False) 64 | self.pool2 = nn.AvgPool1d(kernel_size=4, stride=2, padding=1, count_include_pad=False) 65 | 66 | def forward(self, x): 67 | 68 | while len(x.size()) <= 2: 69 | x = x.unsqueeze(-2) 70 | 71 | x1 = x 72 | x2 = self.pool1(x1) 73 | x3 = self.pool2(x2) 74 | 75 | d1, f1 
= self.disc[0](x1) 76 | d2, f2 = self.disc[1](x2) 77 | d3, f3 = self.disc[2](x3) 78 | return (d1, d2, d3), (f1, f2, f3) 79 | 80 | def loss_D(self, x_proc, x_orig, *args, **kwargs): 81 | x_proc = x_proc.squeeze()[...,:x_orig.shape[-1]].detach() 82 | x_orig = x_orig.squeeze()[...,:x_proc.shape[-1]] 83 | 84 | D_proc, F_proc = self(x_proc) 85 | D_orig, F_orig = self(x_orig) 86 | 87 | loss = 0 88 | 89 | loss_GAN = [] 90 | for r in range(len(D_proc)): 91 | dist = (1-D_orig[r]).relu().mean() + (1+D_proc[r]).relu().mean() # Hinge loss 92 | 93 | loss_GAN.append(dist) 94 | loss_GAN = sum(loss_GAN)/len(loss_GAN) 95 | 96 | loss += loss_GAN 97 | 98 | return loss 99 | 100 | def loss_G(self, x_proc, x_orig, *args, **kwargs): 101 | x_proc = x_proc.squeeze()[...,:x_orig.shape[-1]] 102 | x_orig = x_orig.squeeze()[...,:x_proc.shape[-1]] 103 | 104 | D_proc, F_proc = self(x_proc) 105 | D_orig, F_orig = self(x_orig) 106 | 107 | loss_GAN = [] 108 | loss_FM = [] 109 | 110 | 111 | for r in range(len(D_proc)): 112 | 113 | loss_GAN.append((1-D_proc[r]).relu().mean()) 114 | 115 | for l in range(len(F_proc[r])-1): 116 | loss_FM.append((F_proc[r][l] - F_orig[r][l].detach()).abs().mean()) 117 | 118 | loss_GAN = sum(loss_GAN)/len(loss_GAN) 119 | loss_FM = sum(loss_FM)/len(loss_FM) 120 | 121 | loss = 100*loss_FM + loss_GAN 122 | return loss 123 | 124 | def get_name(self): 125 | return self.__class__.__name__ 126 | 127 | if __name__ == "__main__": 128 | audio = torch.rand(4,1,64000) 129 | noisy = torch.rand(4,1,64000) 130 | 131 | melgan = Discriminator_MelGAN() 132 | print(melgan.loss_G(audio, noisy)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Real-time Speech Frequency Bandwidth Extension 2 | 3 | This repository contains an unofficial PyTorch Lightning implementation of the model described in the paper [Real-Time Speech Frequency Bandwidth Extension](https://arxiv.org/pdf/2010.10677.pdf) by Yunpeng Li et al. (2021). 4 | 5 | ## Requirements 6 | 7 | To run this code, you will need: 8 | 9 | - torch==2.0.0 10 | - pytorch_lightning==2.0.0 11 | - numpy==1.23.5 12 | - pesq==0.0.4 13 | - PyYAML==6.0 14 | - torchaudio==2.0.0 15 | 16 | To automatically install these libraries, run the following command: 17 | 18 | ```pip install -r requirements.txt``` 19 | 20 | ## Usage 21 | 22 | To run the code on your own machine, follow these steps: 23 | 24 | 1. Open the 'config.yaml' file and modify the file paths, training configurations, and hyperparameters as needed. 25 | 2. Run the 'main.py' file to start training the model. 26 | 27 | The trained model will be saved as a ckpt file in the 'logger' directory. You can then use the trained model to perform real-time speech frequency bandwidth extension on your own narrowband wav file by running 'inference.py' as 28 | 29 | ```python inference.py --mode wav --path_ckpt <path_to_ckpt> --path_in <path_to_wav>``` 30 | 31 | This repository also supports directory-level inference, where inference is performed on a directory of wav files. Before running directory-level inference, modify the 'predict' section of the config.yaml file, then run 32 | 33 | ```python inference.py --mode dir --path_ckpt <path_to_ckpt>``` 34 | 35 | ## Note 36 | - 2023.5.1 This code now supports Distributed Data Parallel (DDP) training! 37 | - This implementation does not include streaming convolution and uses the conventional causal convolution instead.
Although this deviates from the contributions of the original paper, I am focusing on verifying the bandwidth extension performance of this model. 38 | - The original paper conducted training for 1 million steps, whereas this implementation trained for 350 epochs for personal research convenience. The number of epochs can be adjusted arbitrarily. 39 | - Feel free to provide issues! 40 | 41 | ## Citation 42 | 43 | ```bibtex 44 | @inproceedings{SEANetBWE21, 45 | title={Real-time speech frequency bandwidth extension}, 46 | author={Li, Yunpeng and Tagliasacchi, Marco and Rybakov, Oleg and Ungureanu, Victor and Roblek, Dominik}, 47 | booktitle={ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 48 | pages={691--695}, 49 | year={2021}, 50 | organization={IEEE} 51 | } 52 | ``` 53 | 54 | -------------------------------------------------------------------------------- /SEANet.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class SEANet(nn.Module): 8 | def __init__(self, min_dim=8, **kwargs): 9 | super().__init__() 10 | 11 | self.min_dim = min_dim 12 | 13 | self.conv_in = Conv1d( 14 | in_channels = 1, 15 | out_channels = min_dim, 16 | kernel_size = 7, 17 | stride = 1 18 | ) 19 | 20 | self.encoder = nn.ModuleList([ 21 | EncBlock(min_dim*2, 2), 22 | EncBlock(min_dim*4, 2), 23 | EncBlock(min_dim*8, 8), 24 | EncBlock(min_dim*16, 8) 25 | ]) 26 | 27 | self.conv_bottle = nn.Sequential( 28 | Conv1d( 29 | in_channels=min_dim*16, 30 | out_channels = min_dim*16//4, 31 | kernel_size = 7, 32 | stride = 1, 33 | ), 34 | 35 | Conv1d( 36 | in_channels=min_dim*16//4, 37 | out_channels = min_dim*16, 38 | kernel_size = 7, 39 | stride = 1, 40 | ), 41 | ) 42 | 43 | self.decoder = nn.ModuleList([ 44 | DecBlock(min_dim*8, 8), 45 | DecBlock(min_dim*4, 8), 46 | DecBlock(min_dim*2, 2), 47 | DecBlock(min_dim, 2), 48 | ]) 49 | 50 | self.conv_out = Conv1d( 51 | in_channels = min_dim, 52 | out_channels = 1, 53 | kernel_size = 7, 54 | stride = 1, 55 | ) 56 | 57 | def forward(self, x): 58 | 59 | while len(x.size()) < 3: 60 | x = x.unsqueeze(-2) 61 | 62 | skip = [x] 63 | 64 | x = self.conv_in(x) 65 | skip.append(x) 66 | 67 | for encoder in self.encoder: 68 | x = encoder(x) 69 | skip.append(x) 70 | 71 | x = self.conv_bottle(x) 72 | 73 | skip = skip[::-1] 74 | 75 | for l in range(len(self.decoder)): 76 | x = x + skip[l] 77 | x = self.decoder[l](x) 78 | 79 | x = x + skip[4] 80 | x = self.conv_out(x) 81 | 82 | x = x + skip[5] 83 | return x 84 | 85 | 86 | class EncBlock(nn.Module): 87 | def __init__(self, out_channels, stride): 88 | super().__init__() 89 | 90 | 91 | self.res_units = nn.ModuleList([ 92 | ResUnit(out_channels//2, 1), 93 | ResUnit(out_channels//2, 3), 94 | ResUnit(out_channels//2, 9) 95 | ]) 96 | 97 | self.conv = nn.Sequential( 98 | nn.ELU(), 99 | Pad((2 * stride - 1, 0)), 100 | nn.Conv1d(in_channels = out_channels//2, 101 | out_channels = out_channels, 102 | kernel_size = 2 * stride, 103 | stride = stride, padding = 0), 104 | ) 105 | 106 | 107 | def forward(self, x): 108 | 109 | for res_unit in self.res_units: 110 | x = res_unit(x) 111 | x = self.conv(x) 112 | 113 | return x 114 | 115 | 116 | class DecBlock(nn.Module): 117 | def __init__(self, out_channels, stride): 118 | super().__init__() 119 | 120 | 121 | self.conv = ConvTransposed1d( 122 | in_channels = out_channels*2, 123 | out_channels = out_channels, 124 | kernel_size 
= 2*stride, stride= stride, 125 | dilation = 1, 126 | ) 127 | 128 | 129 | self.res_units = nn.ModuleList([ 130 | ResUnit(out_channels, 1), 131 | ResUnit(out_channels, 3), 132 | ResUnit(out_channels, 9) 133 | ]) 134 | 135 | self.stride = stride 136 | 137 | 138 | def forward(self, x): 139 | x = self.conv(x) 140 | for res_unit in self.res_units: 141 | x = res_unit(x) 142 | return x 143 | 144 | 145 | class ResUnit(nn.Module): 146 | def __init__(self, channels, dilation = 1): 147 | super().__init__() 148 | 149 | 150 | self.conv_in = Conv1d( 151 | in_channels = channels, 152 | out_channels = channels, 153 | kernel_size = 3, stride= 1, 154 | dilation = dilation, 155 | ) 156 | 157 | self.conv_out = Conv1d( 158 | in_channels = channels, 159 | out_channels = channels, 160 | kernel_size = 1, stride= 1, 161 | ) 162 | 163 | self.conv_shortcuts = Conv1d( 164 | in_channels = channels, 165 | out_channels = channels, 166 | kernel_size = 1, stride= 1, 167 | ) 168 | 169 | 170 | 171 | def forward(self, x): 172 | y = self.conv_in(x) 173 | y = self.conv_out(y) 174 | x = self.conv_shortcuts(x) 175 | return x + y 176 | 177 | 178 | class Conv1d(nn.Module): 179 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation = 1, groups = 1): 180 | super().__init__() 181 | 182 | self.conv = nn.Conv1d( 183 | in_channels = in_channels, 184 | out_channels = out_channels, 185 | kernel_size= kernel_size, 186 | stride= stride, 187 | dilation = dilation, 188 | groups = groups 189 | ) 190 | self.conv = nn.utils.weight_norm(self.conv) 191 | 192 | self.pad = Pad(((kernel_size-1)*dilation, 0)) 193 | self.activation = nn.ELU() 194 | 195 | 196 | def forward(self, x): 197 | 198 | x = self.pad(x) 199 | x = self.conv(x) 200 | x = self.activation(x) 201 | 202 | return x 203 | 204 | class ConvTransposed1d(nn.Module): 205 | def __init__(self, in_channels, out_channels, kernel_size = 1, stride = 1, dilation = 1): 206 | super().__init__() 207 | self.conv = nn.ConvTranspose1d( 208 | in_channels = in_channels, 209 | out_channels = out_channels, 210 | kernel_size = kernel_size, 211 | stride =stride, 212 | dilation = dilation 213 | ) 214 | self.conv = nn.utils.weight_norm(self.conv) 215 | 216 | self.activation = nn.ELU() 217 | self.pad = dilation * (kernel_size - 1) - dilation * (stride - 1) 218 | 219 | def forward(self, x): 220 | x = self.conv(x) 221 | x = x[..., :-self.pad] 222 | x = self.activation(x) 223 | return x 224 | 225 | class Pad(nn.Module): 226 | def __init__(self, pad): 227 | super().__init__() 228 | self.pad = pad 229 | 230 | def forward(self, x): 231 | return F.pad(x, pad=self.pad) 232 | 233 | 234 | if __name__ == "__main__": 235 | 236 | model = SEANet(in_channels=1, out_channels=1, min_dim = 32) 237 | wav = torch.rand(4,1,55400) 238 | output = model(wav) 239 | print(f"output_shape: {output.shape}") -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | 2 | #----------------------------------------------- 3 | #Config that does not have impact on performance 4 | #----------------------------------------------- 5 | 6 | 7 | random_seed: 0b011011 8 | 9 | #----------------------------------------------- 10 | #1. Dataset 11 | #----------------------------------------------- 12 | 13 | dataset: 14 | 15 | #directory that have every dataset in it. 
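#(Note: narrowband/wideband pairs are matched by sorted filename order in datamodule.py, and
# wavs are resampled on load to 8 kHz for narrowband and 16 kHz for wideband if needed.)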
16 | data_dir: "/media/zeroone" 17 | 18 | 19 | nb_train: "downsampled/train" 20 | nb_val: "downsampled/test" 21 | 22 | wb_train: "target/train" 23 | wb_val: "target/test" 24 | 25 | #So in this case, the wideband test dataset should be located at "/media/zeroone/target/test" 26 | 27 | batch_size: 16 28 | seg_len: 2 29 | 30 | num_workers: 16 31 | 32 | #----------------------------------------------- 33 | #2. Model 34 | #----------------------------------------------- 35 | 36 | #No information 37 | 38 | #----------------------------------------------- 39 | #3. Loss 40 | #----------------------------------------------- 41 | 42 | #No information 43 | 44 | #----------------------------------------------- 45 | #4. Optimizer (ADAM) 46 | #----------------------------------------------- 47 | 48 | optim: 49 | learning_rate: 0.0001 50 | 51 | B1: 0.5 52 | B2: 0.9 53 | 54 | 55 | #----------------------------------------------- 56 | #Training 57 | #----------------------------------------------- 58 | 59 | train: 60 | epoch_save_start: 0 61 | val_epoch: 50 62 | 63 | #Path where validation outputs are saved. 64 | output_dir_path: "./output" 65 | logger_path: "./logger" 66 | 67 | max_epochs: 350 68 | 69 | devices: 70 | - 0 71 | #- 1 72 | # -2 ... if you are using DDP 73 | 74 | #----------------------------------------------- 75 | #Predict 76 | #----------------------------------------------- 77 | predict: 78 | nb_pred_path: "/media/youngwon/Neo/NeoChoi/Projects/test/pred_nb" #The path to the directory containing the WAV files 79 | pred_output_path: "/media/youngwon/Neo/NeoChoi/Projects/test/pred_output" #The path to the directory where the output files will be saved. -------------------------------------------------------------------------------- /datamodule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.utils.data import Dataset, DataLoader 4 | import torchaudio as ta 5 | import numpy as np 6 | import os 7 | 8 | import pytorch_lightning as pl 9 | 10 | 11 | from utils import * 12 | 13 | 14 | class RTBWEDataset(Dataset): 15 | #Dataset that loads narrowband/wideband wav pairs and preprocesses them 16 | def __init__(self, path_dir_nb, path_dir_wb, seg_len, mode="train"): 17 | self.path_dir_nb = path_dir_nb 18 | self.path_dir_wb = path_dir_wb 19 | 20 | self.seg_len = seg_len 21 | self.mode = mode 22 | 23 | self.wavs={} 24 | self.filenames= [] 25 | 26 | paths_wav_wb= get_wav_paths(self.path_dir_wb) 27 | paths_wav_nb= get_wav_paths(self.path_dir_nb) 28 | 29 | if mode == "pred": 30 | for path_wav_nb in paths_wav_nb: 31 | filename=get_filename(path_wav_nb)[0] 32 | 33 | wav_nb, self.sr_nb = ta.load(path_wav_nb) 34 | 35 | if self.sr_nb != 8000: 36 | wav_nb = ta.functional.resample(wav_nb, self.sr_nb, 8000) 37 | 38 | self.wavs[filename]=(None , wav_nb) 39 | self.filenames.append(filename) 40 | print(f'\r {mode}: {len(self.filenames)} th file loaded', end='') 41 | 42 | else: 43 | for path_wav_wb, path_wav_nb in zip(paths_wav_wb, paths_wav_nb): 44 | filename=get_filename(path_wav_wb)[0] 45 | wav_nb, self.sr_nb = ta.load(path_wav_nb) 46 | wav_wb, self.sr_wb = ta.load(path_wav_wb) 47 | 48 | if self.sr_nb != 8000: 49 | wav_nb = ta.functional.resample(wav_nb, self.sr_nb, 8000) 50 | if self.sr_wb != 16000: 51 | wav_wb = ta.functional.resample(wav_wb, self.sr_wb, 16000) 52 | 53 | self.wavs[filename]=(wav_wb, wav_nb) 54 | self.filenames.append(filename) 55 | print(f'\r {mode}: {len(self.filenames)} th file loaded', end='') 56 | 57 | self.filenames.sort() 58 | 59 | 60 | 61 | # Return the total number of samples in the dataset
62 | def __len__(self): 63 | return len(self.filenames) 64 | 65 | 66 | # Given an index, return the corresponding input/output pair as PyTorch tensors 67 | def __getitem__(self, idx): 68 | 69 | filename = self.filenames[idx] 70 | (wav_wb, wav_nb) = self.wavs[filename] 71 | 72 | 73 | if self.seg_len > 0 and self.mode == "train": 74 | duration= int(self.seg_len * 16000) 75 | 76 | wav_wb= wav_wb.view(1,-1) 77 | wav_nb = wav_nb.view(1,-1) 78 | 79 | sig_len = wav_wb.shape[-1] 80 | 81 | t_start = np.random.randint( 82 | low = 0, 83 | high= np.max([1, sig_len- duration - 2]), 84 | size = 1 85 | )[0] 86 | 87 | if t_start % 2 ==1: 88 | t_start -= 1 89 | 90 | t_end = t_start + duration 91 | 92 | 93 | wav_nb = wav_nb.repeat(1, t_end// sig_len + 1) [ ..., t_start//2 : t_end//2] 94 | wav_wb = wav_wb.repeat(1, t_end // sig_len + 1) [ ..., t_start : t_end] 95 | else: 96 | wav_wb= wav_wb.view(1,-1) 97 | wav_nb = wav_nb.view(1,-1) 98 | 99 | #Pad wav_nb to a multiple of 256 samples 100 | nb_padding = 256 - len(wav_nb[-1])%256 101 | wav_nb = torch.cat([wav_nb, torch.zeros((1, nb_padding))], dim=1) 102 | 103 | #Pad wav_wb to twice the padded wav_nb length 104 | wb_len = wav_nb.shape[1]*2 105 | wb_padding = wb_len - len(wav_wb[-1]) 106 | wav_wb = torch.cat([wav_wb, torch.zeros((1, wb_padding))], dim=1) 107 | 108 | 109 | return wav_nb, wav_wb, filename 110 | 111 | 112 | class RTBWEDataModule(pl.LightningDataModule): 113 | def __init__(self, config): 114 | super().__init__() 115 | 116 | self.data_dir = config['dataset']['data_dir'] 117 | 118 | self.path_dir_nb_train = config['dataset']['nb_train'] 119 | self.path_dir_nb_val = config['dataset']['nb_val'] 120 | self.path_dir_wb_train = config['dataset']['wb_train'] 121 | self.path_dir_wb_val = config['dataset']['wb_val'] 122 | 123 | self.path_dir_nb_pred = config['predict']['nb_pred_path'] 124 | 125 | 126 | 127 | self.batch_size = config['dataset']['batch_size'] 128 | self.seg_len = config['dataset']['seg_len'] 129 | 130 | self.num_workers = config['dataset']['num_workers'] 131 | 132 | 133 | def prepare_data(self): 134 | pass 135 | 136 | def setup(self, stage=None): 137 | self.train_dataset =RTBWEDataset( 138 | path_dir_nb = os.path.join(self.data_dir, self.path_dir_nb_train), 139 | path_dir_wb = os.path.join(self.data_dir, self.path_dir_wb_train), 140 | seg_len = self.seg_len, 141 | mode = "train" 142 | ) 143 | 144 | self.val_dataset = RTBWEDataset( 145 | path_dir_nb = os.path.join(self.data_dir, self.path_dir_nb_val), 146 | path_dir_wb = os.path.join(self.data_dir, self.path_dir_wb_val), 147 | seg_len = self.seg_len, 148 | mode = "val" 149 | ) 150 | 151 | 152 | def train_dataloader(self): 153 | return DataLoader(self.train_dataset, batch_size = self.batch_size, shuffle = True, num_workers = self.num_workers) 154 | 155 | def val_dataloader(self): 156 | return DataLoader(self.val_dataset, batch_size = 1, num_workers = self.num_workers) 157 | -------------------------------------------------------------------------------- /demo/wav1_bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroone-universe/RealTimeBWE/7515a7ffa693ee65c3c40a25ee9173cd4028c3a1/demo/wav1_bwe.wav -------------------------------------------------------------------------------- /demo/wav1_nb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroone-universe/RealTimeBWE/7515a7ffa693ee65c3c40a25ee9173cd4028c3a1/demo/wav1_nb.wav -------------------------------------------------------------------------------- /demo/wav1_wb.wav:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroone-universe/RealTimeBWE/7515a7ffa693ee65c3c40a25ee9173cd4028c3a1/demo/wav1_wb.wav -------------------------------------------------------------------------------- /demo/wav2_bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroone-universe/RealTimeBWE/7515a7ffa693ee65c3c40a25ee9173cd4028c3a1/demo/wav2_bwe.wav -------------------------------------------------------------------------------- /demo/wav2_nb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroone-universe/RealTimeBWE/7515a7ffa693ee65c3c40a25ee9173cd4028c3a1/demo/wav2_nb.wav -------------------------------------------------------------------------------- /demo/wav2_wb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroone-universe/RealTimeBWE/7515a7ffa693ee65c3c40a25ee9173cd4028c3a1/demo/wav2_wb.wav -------------------------------------------------------------------------------- /demo/wav3_bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroone-universe/RealTimeBWE/7515a7ffa693ee65c3c40a25ee9173cd4028c3a1/demo/wav3_bwe.wav -------------------------------------------------------------------------------- /demo/wav3_nb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroone-universe/RealTimeBWE/7515a7ffa693ee65c3c40a25ee9173cd4028c3a1/demo/wav3_nb.wav -------------------------------------------------------------------------------- /demo/wav3_wb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroone-universe/RealTimeBWE/7515a7ffa693ee65c3c40a25ee9173cd4028c3a1/demo/wav3_wb.wav -------------------------------------------------------------------------------- /demo/wav4_bwe.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroone-universe/RealTimeBWE/7515a7ffa693ee65c3c40a25ee9173cd4028c3a1/demo/wav4_bwe.wav -------------------------------------------------------------------------------- /demo/wav4_nb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroone-universe/RealTimeBWE/7515a7ffa693ee65c3c40a25ee9173cd4028c3a1/demo/wav4_nb.wav -------------------------------------------------------------------------------- /demo/wav4_wb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroone-universe/RealTimeBWE/7515a7ffa693ee65c3c40a25ee9173cd4028c3a1/demo/wav4_wb.wav -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torchaudio as ta 3 | import pytorch_lightning as pl 4 | 5 | from train import RTBWETrain 6 | from datamodule import * 7 | from utils import * 8 | import yaml 9 | 10 | def inference(config, args): 11 | 12 | rtbwe_train = RTBWETrain.load_from_checkpoint(args.path_ckpt, config = config) 13 | 14 | if args.mode == 'wav': 15 | wav_nb, sr_nb = ta.load(args.path_in) 16 | wav_nb = wav_nb.unsqueeze(0) 17 | 
rtbwe_train.generator.eval() 18 | wav_bwe = rtbwe_train.forward(wav_nb) 19 | 20 | filename = get_filename(args.path_in) 21 | ta.save(os.path.join(os.path.dirname(args.path_in),filename[0]+"_bwe"+filename[1]), wav_bwe.squeeze(0), sr_nb*2) 22 | 23 | if args.mode == 'dir': 24 | 25 | pred_dataset = RTBWEDataset( 26 | path_dir_nb = config["predict"]["nb_pred_path"], 27 | path_dir_wb = config["predict"]["nb_pred_path"], 28 | mode = "pred" 29 | ) 30 | trainer = pl.Trainer(devices=1, accelerator="gpu", logger = False) 31 | 32 | trainer.predict(rtbwe_train, pred_dataset) 33 | 34 | 35 | 36 | if __name__ == "__main__": 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--mode", type = str, help = "wav/dir", default = "wav") 39 | parser.add_argument("--path_ckpt", type = str) 40 | parser.add_argument("--path_in", type = str, help = "path of wav file or directory") 41 | args = parser.parse_args() 42 | 43 | config = yaml.load(open("./config.yaml", 'r'), Loader=yaml.FullLoader) 44 | 45 | inference(config, args) 46 | 47 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | from datamodule import RTBWEDataModule 3 | 4 | from train import RTBWETrain 5 | 6 | from pytorch_lightning import loggers as pl_loggers 7 | import pytorch_lightning as pl 8 | from pytorch_lightning.callbacks import ModelCheckpoint 9 | 10 | from utils import * 11 | import yaml 12 | 13 | 14 | def main(config): 15 | 16 | pl.seed_everything(config['random_seed'], workers=True) 17 | rtbwe_datamodule = RTBWEDataModule(config) 18 | rtbwe_train = RTBWETrain(config) 19 | 20 | check_dir_exist(config['train']['output_dir_path']) 21 | check_dir_exist(config['train']['logger_path']) 22 | 23 | tb_logger = pl_loggers.TensorBoardLogger(config['train']['logger_path'], name=f"RTBWE_logs") 24 | 25 | checkpoint_callback = ModelCheckpoint( 26 | filename = "{epoch}-{val_pesq_wb:.2f}-{val_pesq_nb:.2f}", 27 | save_top_k = -1, 28 | every_n_epochs = config['train']['val_epoch']) 29 | 30 | tb_logger.log_hyperparams(config) 31 | 32 | trainer=pl.Trainer(devices=config['train']['devices'], accelerator="gpu", strategy='ddp_find_unused_parameters_true', 33 | callbacks= [checkpoint_callback], 34 | max_epochs=config['train']['max_epochs'], 35 | logger=tb_logger, 36 | check_val_every_n_epoch=config['train']['val_epoch'] 37 | ) 38 | 39 | trainer.fit(rtbwe_train, rtbwe_datamodule) 40 | trainer.save_checkpoint(os.path.join(config['train']['output_dir_path'],'final_model.ckpt')) 41 | 42 | if __name__ == "__main__": 43 | config = yaml.load(open("./config.yaml", 'r'), Loader=yaml.FullLoader) 44 | 45 | main(config) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.5 2 | pesq==0.0.4 3 | pytorch_lightning==2.0.0 4 | PyYAML==6.0 5 | torch==2.0.0 6 | torchaudio==2.0.0 7 | tensorboardX -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn.functional as F 4 | import torchaudio.transforms as T 5 | 6 | import torchaudio as ta 7 | 8 | import os 9 | 10 | from MelGAN import Discriminator_MelGAN 11 | from SEANet import SEANet 12 | 13 | 14 | from utils import * 15 | 16 | from pesq import pesq 17 | 18 | class 
RTBWETrain(pl.LightningModule): 19 | def __init__(self, config): 20 | super(RTBWETrain, self).__init__() 21 | self.config = config 22 | 23 | self.lr = config['optim']['learning_rate'] 24 | self.B1 = config['optim']['B1'] 25 | self.B2 = config['optim']['B2'] 26 | 27 | self.resampler = T.Resample(8000, 16000) 28 | 29 | self.generator = SEANet(min_dim = 8, causality = True) 30 | 31 | self.output_dir_path = config['train']['output_dir_path'] 32 | 33 | self.epoch_save_start = config['train']['epoch_save_start'] 34 | self.val_epoch = config['train']['val_epoch'] 35 | 36 | self.path_dir_bwe_pred = config['predict']['pred_output_path'] 37 | 38 | self.discriminator = Discriminator_MelGAN() 39 | 40 | 41 | self.automatic_optimization = False 42 | 43 | def forward(self,x): 44 | x = self.resampler(x) 45 | output = self.generator(x) 46 | 47 | return output 48 | 49 | def configure_optimizers(self): 50 | optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas = (self.B1, self.B2)) 51 | optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas = (self.B1, self.B2)) 52 | 53 | return optimizer_d, optimizer_g #, [lrscheduler_d, lr_scheduler_g]) 54 | 55 | 56 | def training_step(self, batch, batch_idx): 57 | optimizer_d, optimizer_g = self.optimizers() 58 | 59 | wav_nb, wav_wb, _ = batch 60 | 61 | wav_bwe = self.forward(wav_nb) 62 | 63 | #optimize discriminator 64 | 65 | self.toggle_optimizer(optimizer_d) 66 | 67 | loss_d =self.discriminator.loss_D(wav_bwe, wav_wb) 68 | 69 | optimizer_d.zero_grad() 70 | self.manual_backward(loss_d) 71 | optimizer_d.step() 72 | 73 | self.untoggle_optimizer(optimizer_d) 74 | 75 | #optimize generator 76 | 77 | self.toggle_optimizer(optimizer_g) 78 | 79 | loss_g = self.discriminator.loss_G(wav_bwe, wav_wb) 80 | 81 | optimizer_g.zero_grad() 82 | self.manual_backward(loss_g) 83 | optimizer_g.step() 84 | 85 | self.untoggle_optimizer(optimizer_g) 86 | 87 | self.log("train_loss_d", loss_d, prog_bar = True, batch_size = self.config['dataset']['batch_size']) 88 | self.log("train_loss_g", loss_g, prog_bar = True, batch_size = self.config['dataset']['batch_size']) 89 | 90 | 91 | 92 | def validation_step(self, batch, batch_idx): 93 | 94 | wav_nb, wav_wb, filename = batch 95 | 96 | wav_bwe = self.forward(wav_nb) 97 | 98 | 99 | loss_d = self.discriminator.loss_D(wav_bwe, wav_wb) 100 | loss_g = self.discriminator.loss_G(wav_bwe, wav_wb) 101 | 102 | 103 | wav_bwe_cpu = wav_bwe.squeeze(0).cpu() 104 | val_dir_path = f"{self.output_dir_path}/epoch_current" 105 | check_dir_exist(val_dir_path) 106 | ta.save(os.path.join(val_dir_path, f"{filename[0]}.wav"), wav_bwe_cpu, 16000) 107 | 108 | wav_wb = wav_wb.squeeze().cpu().numpy() 109 | wav_bwe = wav_bwe.squeeze().cpu().numpy() 110 | 111 | val_pesq_wb = pesq(fs = 16000, ref = wav_wb, deg = wav_bwe, mode = "wb") 112 | val_pesq_nb = pesq(fs = 16000, ref = wav_wb, deg = wav_bwe, mode = "nb") 113 | 114 | self.log_dict({"val_loss/val_loss_d": loss_d, "val_loss/val_loss_g": loss_g}, batch_size = 1, sync_dist=True) 115 | self.log('val_pesq_wb', val_pesq_wb, batch_size = 1, sync_dist=True) 116 | self.log('val_pesq_nb', val_pesq_nb, batch_size = 1, sync_dist=True) 117 | 118 | 119 | def test_step(self, batch, batch_idx): 120 | pass 121 | 122 | 123 | 124 | def predict_step(self, batch, batch_idx): 125 | wav_nb, _, filename = batch 126 | 127 | wav_bwe = self.forward(wav_nb) 128 | 129 | wav_bwe_cpu = wav_bwe.squeeze(0).cpu() 130 | test_dir_path = self.path_dir_bwe_pred 131 | check_dir_exist(test_dir_path) 132 | 
ta.save(os.path.join(test_dir_path, f"{filename}.wav"), wav_bwe_cpu, 16000) 133 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def get_wav_paths(paths: list): 4 | wav_paths=[] 5 | if type(paths)==str: 6 | paths=[paths] 7 | 8 | for path in paths: 9 | for root, dirs, files in os.walk(path): 10 | wav_paths += [os.path.join(root,file) for file in files if os.path.splitext(file)[-1]=='.wav'] 11 | 12 | wav_paths.sort(key=lambda x: os.path.split(x)[-1]) 13 | 14 | return wav_paths 15 | 16 | def check_dir_exist(path_list): 17 | if type(path_list) == str: 18 | path_list = [path_list] 19 | 20 | for path in path_list: 21 | if type(path) == str and os.path.splitext(path)[-1] == '' and not os.path.exists(path): 22 | os.makedirs(path) 23 | 24 | def get_filename(path): 25 | return os.path.splitext(os.path.basename(path)) 26 | 27 | --------------------------------------------------------------------------------
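
The README notes that this implementation uses conventional causal convolution rather than streaming convolution; the causality comes from the left-only padding of (kernel_size-1)*dilation applied by the Conv1d wrapper in SEANet.py. The snippet below is a minimal sanity-check sketch, assuming it is run from the repository root so that SEANet.py is importable; it checks that the wrapper preserves sequence length at stride 1 and that changing "future" samples does not change earlier outputs.

```python
import torch
from SEANet import Conv1d  # the causal, weight-normalized Conv1d wrapper defined in SEANet.py

torch.manual_seed(0)
conv = Conv1d(in_channels=1, out_channels=1, kernel_size=3, stride=1, dilation=2)

x = torch.randn(1, 1, 100)
y = conv(x)
print(y.shape)  # torch.Size([1, 1, 100]): left padding of (kernel_size-1)*dilation keeps the length

x_future = x.clone()
x_future[..., 50:] = 0.0   # perturb only "future" samples
y_future = conv(x_future)
print(torch.allclose(y[..., :50], y_future[..., :50]))  # True: outputs before t=50 are unaffected
```

This one-sided padding is what would allow low-latency, frame-by-frame operation in the original paper's streaming setting; in this repository the generator is simply run on whole utterances.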