├── README.md
├── examples
│   └── filelists
│       ├── tr_list.txt
│       └── tt_list.txt
├── requirements.txt
└── scripts
    ├── configs.py
    ├── measure.py
    ├── run_evaluate.sh
    ├── run_train.sh
    ├── test.py
    ├── train.py
    └── utils
        ├── criteria.py
        ├── data_utils.py
        ├── metrics.py
        ├── models.py
        ├── networks.py
        ├── pipeline_modules.py
        ├── stft.py
        └── utils.py

/README.md:
--------------------------------------------------------------------------------

# A Convolutional Recurrent Neural Network for Real-Time Speech Enhancement

This repository provides an implementation of the convolutional recurrent network (CRN) for monaural speech enhancement proposed in ["A convolutional recurrent neural network for real-time speech enhancement"](https://web.cse.ohio-state.edu/~wang.77/papers/Tan-Wang1.interspeech18.pdf), Proceedings of Interspeech, pp. 3229-3233, 2018. The CRN combines a convolutional encoder-decoder with long short-term memory (LSTM) and performs spectral mapping in a causal manner, which makes it amenable to real-time processing.

## Installation
The program is developed using Python 3.7.
Clone this repo, and install the dependencies:
```
git clone https://github.com/JupiterEthan/CRN-causal.git
cd CRN-causal
pip install -r requirements.txt
```
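To verify that the pinned dependencies (see ```requirements.txt``` below) were installed correctly, a quick import check along the following lines can help; this snippet is illustrative and not part of the repository.
```
import torch
import h5py
import soundfile

# The versions printed here should match those pinned in requirements.txt.
print(torch.__version__)      # expected: 1.7.1
print(h5py.__version__)       # expected: 2.10.0
print(soundfile.__version__)  # expected: 0.10.3.post1
```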
## Data preparation
To use this program, data and file lists need to be prepared. Once everything is prepared correctly, the directory tree should look like this:
```
.
├── data
│   └── datasets
│       ├── cv
│       │   └── cv.ex
│       ├── tr
│       │   ├── tr_0.ex
│       │   ├── tr_1.ex
│       │   ├── tr_2.ex
│       │   ├── tr_3.ex
│       │   └── tr_4.ex
│       └── tt
│           ├── tt_snr0.ex
│           ├── tt_snr-5.ex
│           └── tt_snr5.ex
├── examples
│   └── filelists
│       ├── tr_list.txt
│       └── tt_list.txt
├── filelists
│   ├── tr_list.txt
│   └── tt_list.txt
├── README.md
├── requirements.txt
└── scripts
    ├── configs.py
    ├── measure.py
    ├── run_evaluate.sh
    ├── run_train.sh
    ├── test.py
    ├── train.py
    └── utils
        ├── criteria.py
        ├── data_utils.py
        ├── metrics.py
        ├── models.py
        ├── networks.py
        ├── pipeline_modules.py
        ├── stft.py
        └── utils.py
```
Some of the files above will be missing from your directory tree. Those are for you to prepare, as follows:
1. Write your own scripts to prepare data for training, validation and testing.
    - For the training set, each example needs to be saved into its own HDF5 file, which contains two HDF5 datasets, named ```mix``` and ```sph```. ```mix``` stores a noisy mixture utterance; ```sph``` stores the corresponding clean speech utterance.
    - Example code:
      ```
      import os

      import h5py
      import numpy as np


      # some settings
      ...
      rms = 1.0

      for idx in range(n_tr_ex):  # n_tr_ex is the number of training examples
          # generate a noisy mixture (sph: clean speech, noi: noise)
          ...
          mix = sph + noi
          # normalize the mixture, and scale the clean speech by the same factor
          c = rms * np.sqrt(mix.size / np.sum(mix**2))
          mix *= c
          sph *= c

          filename = 'tr_{}.ex'.format(idx)
          writer = h5py.File(os.path.join(filepath, filename), 'w')
          writer.create_dataset('mix', data=mix.astype(np.float32), shape=mix.shape, chunks=True)
          writer.create_dataset('sph', data=sph.astype(np.float32), shape=sph.shape, chunks=True)
          writer.close()
      ```
    - For the validation set, all examples need to be saved into a single HDF5 file, each stored in its own HDF5 group. Each group contains two HDF5 datasets, one named ```mix``` and the other named ```sph```.
    - Example code:
      ```
      import os

      import h5py
      import numpy as np


      # some settings
      ...
      rms = 1.0

      filename = 'cv.ex'
      writer = h5py.File(os.path.join(filepath, filename), 'w')
      for idx in range(n_cv_ex):  # n_cv_ex is the number of validation examples
          # generate a noisy mixture
          ...
          mix = sph + noi
          # normalize
          c = rms * np.sqrt(mix.size / np.sum(mix**2))
          mix *= c
          sph *= c

          writer_grp = writer.create_group(str(idx))
          writer_grp.create_dataset('mix', data=mix.astype(np.float32), shape=mix.shape, chunks=True)
          writer_grp.create_dataset('sph', data=sph.astype(np.float32), shape=sph.shape, chunks=True)
      writer.close()
      ```
    - For the test set(s), all examples (in each condition) need to be saved into a single HDF5 file, each stored in its own HDF5 group. Each group contains two HDF5 datasets, one named ```mix``` and the other named ```sph```.
    - Example code:
      ```
      import os

      import h5py
      import numpy as np


      # some settings
      ...
      rms = 1.0

      filename = 'tt_snr-5.ex'
      writer = h5py.File(os.path.join(filepath, filename), 'w')
      for idx in range(n_tt_ex):  # n_tt_ex is the number of test examples
          # generate a noisy mixture
          ...
          mix = sph + noi
          # normalize
          c = rms * np.sqrt(mix.size / np.sum(mix**2))
          mix *= c
          sph *= c

          writer_grp = writer.create_group(str(idx))
          writer_grp.create_dataset('mix', data=mix.astype(np.float32), shape=mix.shape, chunks=True)
          writer_grp.create_dataset('sph', data=sph.astype(np.float32), shape=sph.shape, chunks=True)
      writer.close()
      ```
    - In the example code above, the root-mean-square (RMS) power of each mixture is normalized to 1; the same scaling factor is then applied to the clean speech, which preserves the SNR of the mixture.
2. Generate the file lists for the training and test sets, and save them into a folder named ```filelists```. See [examples/filelists](examples/filelists) for examples; a minimal sketch for generating these lists is shown right after this section.

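Below is a minimal sketch of how these file lists could be generated, assuming the directory layout shown above and that the script is run from the repository root; the helper name ```write_filelist``` is hypothetical. Note that ```train.py``` and ```test.py``` are run from ```scripts/```, so each listed path starts with ```../```, as in [examples/filelists](examples/filelists).
```
import os

# Hypothetical helper: list the prepared .ex files, one path per line.
def write_filelist(data_dir, list_path):
    with open(list_path, 'w') as f:
        for name in sorted(os.listdir(data_dir)):
            if name.endswith('.ex'):
                # paths are written relative to the scripts/ working directory
                f.write(os.path.join('..', data_dir, name) + '\n')

os.makedirs('filelists', exist_ok=True)
write_filelist('data/datasets/tr', 'filelists/tr_list.txt')
write_filelist('data/datasets/tt', 'filelists/tt_list.txt')
```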
## How to run
1. Change the directory: ```cd scripts```. Remember that this is your working directory; all paths and commands below are relative to it.
2. Check ```utils/networks.py``` for the network configuration. By default, a two-layer LSTM with 1024 units per layer is used between the convolutional encoder and the decoder.
3. Train the model: ```./run_train.sh```. By default, a directory named ```exp``` will be automatically generated. Two model files will be generated under ```exp/models/```: ```latest.pt``` (the model from the latest checkpoint) and ```best.pt``` (the model that performs best on the validation set so far). ```latest.pt``` can be used to resume training if interrupted, and ```best.pt``` is typically used for testing. You can check the loss values in ```exp/loss.txt```.
4. Evaluate the model: ```./run_evaluate.sh```. WAV files will be generated under ```../data/estimates```. STOI, PESQ and SNR results will be written into three files under ```exp```: ```stoi_scores.log```, ```pesq_scores.log``` and ```snr_scores.log```.


## How to cite
```
@inproceedings{tan2018convolutional,
  title={A Convolutional Recurrent Neural Network for Real-Time Speech Enhancement},
  author={Tan, Ke and Wang, DeLiang},
  booktitle={Interspeech},
  pages={3229--3233},
  year={2018}
}
```
--------------------------------------------------------------------------------
/examples/filelists/tr_list.txt:
--------------------------------------------------------------------------------
../data/datasets/tr/tr_0.ex
../data/datasets/tr/tr_1.ex
../data/datasets/tr/tr_2.ex
../data/datasets/tr/tr_3.ex
../data/datasets/tr/tr_4.ex
--------------------------------------------------------------------------------
/examples/filelists/tt_list.txt:
--------------------------------------------------------------------------------
../data/datasets/tt/tt_snr-5.ex
../data/datasets/tt/tt_snr0.ex
../data/datasets/tt/tt_snr5.ex
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.7.1
numpy==1.19.2
scipy==1.6.1
soundfile==0.10.3.post1
h5py==2.10.0
pystoi
pypesq
--------------------------------------------------------------------------------
/scripts/configs.py:
--------------------------------------------------------------------------------


exp_conf = {
    'in_norm': False,       # whether to normalize the input audio
    'sample_rate': 16000,   # sampling rate (Hz)
    'win_len': 0.020,       # STFT window length (seconds)
    'hop_len': 0.010        # STFT window shift (seconds)
}
--------------------------------------------------------------------------------
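These settings are consumed by ```Model``` in ```utils/models.py``` (shown later), which converts the window length and shift from seconds to samples. A quick sanity check of the resulting STFT framing under the default configuration (run from ```scripts/```); the snippet is illustrative:
```
from configs import exp_conf

# Convert window length/shift from seconds to samples, as Model.__init__ does.
win_size = int(exp_conf['win_len'] * exp_conf['sample_rate'])  # 0.020 * 16000 = 320
hop_size = int(exp_conf['hop_len'] * exp_conf['sample_rate'])  # 0.010 * 16000 = 160
n_bins = win_size // 2 + 1                                     # 161 frequency bins
print(win_size, hop_size, n_bins)
```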
/scripts/measure.py:
--------------------------------------------------------------------------------
import argparse
import pprint

import torch

from utils.metrics import Metric
from utils.utils import getLogger


logger = getLogger(__name__)


def main():
    # parse the configurations
    parser = argparse.ArgumentParser(description='Additional configurations for measurement',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--metric',
                        type=str,
                        required=True,
                        help='Name of the evaluation metric')
    parser.add_argument('--tt_list',
                        type=str,
                        required=True,
                        help='Path to the list of testing files')
    parser.add_argument('--ckpt_dir',
                        type=str,
                        required=True,
                        help='Name of the directory to write logs')
    parser.add_argument('--est_path',
                        type=str,
                        default='../data/estimates',
                        help='Path to the saved estimates')

    args = parser.parse_args()
    logger.info('Arguments in command:\n{}'.format(pprint.pformat(vars(args))))

    metric = Metric(args)
    metric.evaluate()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/scripts/run_evaluate.sh:
--------------------------------------------------------------------------------
#!/bin/bash


test_step=true
eval_step=true

ckpt_dir=exp
gpus=0

if $test_step; then
    python -B ./test.py \
        --gpu_ids=$gpus \
        --tt_list=../filelists/tt_list.txt \
        --ckpt_dir=$ckpt_dir \
        --model_file=./${ckpt_dir}/models/best.pt
fi

if $eval_step; then
    python -B ./measure.py --metric=stoi --tt_list=../filelists/tt_list.txt --ckpt_dir=$ckpt_dir
    python -B ./measure.py --metric=pesq --tt_list=../filelists/tt_list.txt --ckpt_dir=$ckpt_dir
    python -B ./measure.py --metric=snr --tt_list=../filelists/tt_list.txt --ckpt_dir=$ckpt_dir
fi
--------------------------------------------------------------------------------
/scripts/run_train.sh:
--------------------------------------------------------------------------------
#!/bin/bash


ckpt_dir=exp
gpus=0

python -B ./train.py \
    --gpu_ids=$gpus \
    --tr_list=../filelists/tr_list.txt \
    --cv_file=../data/datasets/cv/cv.ex \
    --ckpt_dir=$ckpt_dir \
    --logging_period=1000 \
    --lr=0.0002 \
    --time_log=./time.log \
    --unit=utt \
    --batch_size=16 \
    --buffer_size=32 \
    --max_n_epochs=20
--------------------------------------------------------------------------------
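After training (see ```run_train.sh``` above), the loss history in ```exp/loss.txt``` can be inspected with a few lines of Python. The file is a CSV written by ```lossLog``` in ```utils/utils.py``` (shown later), with the header ```epoch, iter, tr_loss, cv_loss```; a minimal sketch:
```
import numpy as np

# exp/loss.txt has one row per validation pass: epoch, iter, tr_loss, cv_loss.
log = np.genfromtxt('exp/loss.txt', delimiter=',', names=True)
tr_loss = np.atleast_1d(log['tr_loss'])
cv_loss = np.atleast_1d(log['cv_loss'])
print('latest: tr_loss {:.4f}, cv_loss {:.4f}'.format(tr_loss[-1], cv_loss[-1]))
```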
/scripts/test.py:
--------------------------------------------------------------------------------
import argparse
import pprint

import torch

from utils.models import Model
from utils.utils import getLogger


logger = getLogger(__name__)


def main():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # parse the configurations
    parser = argparse.ArgumentParser(description='Additional configurations for testing',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='-1',
                        help='IDs of GPUs to use (please use `,` to split multiple IDs); -1 means CPU only')
    parser.add_argument('--tt_list',
                        type=str,
                        required=True,
                        help='Path to the list of testing files')
    parser.add_argument('--ckpt_dir',
                        type=str,
                        required=True,
                        help='Name of the directory to write logs')
    parser.add_argument('--model_file',
                        type=str,
                        required=True,
                        help='Path to the model file')
    parser.add_argument('--est_path',
                        type=str,
                        default='../data/estimates',
                        help='Path to dump estimates')
    parser.add_argument('--write_ideal',
                        default=False,
                        action='store_true',
                        help='Whether to write ideal signals (the speech signals resynthesized from the ideal training targets; e.g., for time-domain enhancement, it is the same as clean speech)')

    args = parser.parse_args()
    logger.info('Arguments in command:\n{}'.format(pprint.pformat(vars(args))))

    model = Model()
    model.test(args)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
import argparse
import pprint

import torch

from utils.models import Model
from utils.utils import getLogger


logger = getLogger(__name__)


def main():
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # parse the configurations
    parser = argparse.ArgumentParser(description='Additional configurations for training',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--gpu_ids',
                        type=str,
                        default='-1',
                        help='IDs of GPUs to use (please use `,` to split multiple IDs); -1 means CPU only')
    parser.add_argument('--tr_list',
                        type=str,
                        required=True,
                        help='Path to the list of training files')
    parser.add_argument('--cv_file',
                        type=str,
                        required=True,
                        help='Path to the cross validation file')
    parser.add_argument('--ckpt_dir',
                        type=str,
                        required=True,
                        help='Name of the directory to dump checkpoints')
    parser.add_argument('--unit',
                        type=str,
                        required=True,
                        help='Unit of sample, can be either `seg` or `utt`')
    parser.add_argument('--logging_period',
                        type=int,
                        default=1000,
                        help='Logging period (also the period of cross validation) represented by the number of iterations')
    parser.add_argument('--time_log',
                        type=str,
                        default='',
                        help='Log file for timing batch processing')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Minibatch size')
    parser.add_argument('--buffer_size',
                        type=int,
                        default=32,
                        help='Buffer size')
    parser.add_argument('--segment_size',
                        type=float,
                        default=4.0,
                        help='Length of segments used for training (seconds)')
    parser.add_argument('--segment_shift',
                        type=float,
                        default=1.0,
                        help='Shift of segments used for training (seconds)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        help='Initial learning rate for training')
    parser.add_argument('--lr_decay_factor',
                        type=float,
                        default=0.98,
                        help='Decaying factor of learning rate')
    parser.add_argument('--lr_decay_period',
                        type=int,
                        default=2,
                        help='Decaying period of learning rate (epochs)')
    parser.add_argument('--clip_norm',
                        type=float,
                        default=-1.0,
                        help='Gradient clipping (L2-norm); a negative value disables clipping')
    parser.add_argument('--max_n_epochs',
                        type=int,
                        default=100,
                        help='Maximum number of epochs')
    parser.add_argument('--loss_log',
                        type=str,
                        default='loss.txt',
                        help='Filename of the loss log')
    parser.add_argument('--resume_model',
                        type=str,
                        default='',
                        help='Existing model to resume training from')

    args = parser.parse_args()
    logger.info('Arguments in command:\n{}'.format(pprint.pformat(vars(args))))

    model = Model()
    model.train(args)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
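The remaining files under ```scripts/utils/``` implement the training pipeline, starting with the loss function. As a reference for how ```LossFunction``` in ```criteria.py``` below is invoked during training, here is a minimal sketch with made-up shapes; the mask construction mirrors ```lossMask``` in ```utils/utils.py```:
```
import torch
from utils.criteria import LossFunction

# Illustrative shapes: batch of 2 utterances, up to 10 frames, 161 frequency bins
# (161 = 320 // 2 + 1 for a 20 ms window at 16 kHz).
est = torch.rand(2, 10, 161)      # network output (estimated magnitude spectra)
lbl = torch.rand(2, 10, 161)      # clean-speech magnitude spectra
n_frames = torch.tensor([10, 8])  # valid frame counts per utterance

# Zero out the padded frames, as lossMask in utils/utils.py does.
loss_mask = torch.zeros_like(lbl)
for i, t in enumerate(n_frames):
    loss_mask[i, :t, :] = 1.0

loss = LossFunction()(est, lbl, loss_mask, n_frames)
print(loss.item())
```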
/scripts/utils/criteria.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LossFunction(object): 5 | def __call__(self, est, lbl, loss_mask, n_frames): 6 | est_t = est * loss_mask 7 | lbl_t = lbl * loss_mask 8 | 9 | n_feats = est.shape[-1] 10 | 11 | loss = torch.sum((est_t - lbl_t)**2) / float(sum(n_frames) * n_feats) 12 | 13 | return loss 14 | -------------------------------------------------------------------------------- /scripts/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from datetime import datetime 3 | 4 | import h5py 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class WavReader(object): 10 | def __init__(self, in_file, mode): 11 | # if mode is 'train', in_file is a list of filenames; 12 | # if mode is 'eval', in_file is a filename 13 | self.mode = mode 14 | self.in_file = in_file 15 | assert self.mode in {'train', 'eval'} 16 | if self.mode == 'train': 17 | self.wav_dict = {i: wavfile for i, wavfile in enumerate(in_file)} 18 | else: 19 | reader = h5py.File(in_file, 'r') 20 | self.wav_dict = {i: str(i) for i in range(len(reader))} 21 | reader.close() 22 | self.wav_indices = sorted(list(self.wav_dict.keys())) 23 | 24 | def load(self, idx): 25 | if self.mode == 'train': 26 | filename = self.wav_dict[idx] 27 | reader = h5py.File(filename, 'r') 28 | mix = reader['mix'][:] 29 | sph = reader['sph'][:] 30 | reader.close() 31 | else: 32 | reader = h5py.File(self.in_file, 'r') 33 | reader_grp = reader[self.wav_dict[idx]] 34 | mix = reader_grp['mix'][:] 35 | sph = reader_grp['sph'][:] 36 | reader.close() 37 | return mix, sph 38 | 39 | def __iter__(self): 40 | for idx in self.wav_indices: 41 | yield idx, self.load(idx) 42 | 43 | 44 | class PerUttLoader(object): 45 | def __init__(self, in_file, in_norm, shuffle=True, mode='train'): 46 | self.shuffle = shuffle 47 | self.mode = mode 48 | self.wav_reader = WavReader(in_file, mode) 49 | self.in_norm = in_norm 50 | self.eps = np.finfo(np.float32).eps 51 | 52 | def __iter__(self): 53 | if self.shuffle: 54 | random.shuffle(self.wav_reader.wav_indices) 55 | 56 | for idx, utt in self.wav_reader: 57 | utt_eg = dict() 58 | if self.in_norm: 59 | scale = np.max(np.abs(utt[0])) / 0.9 60 | utt_eg['mix'] = utt[0] / scale 61 | utt_eg['sph'] = utt[1] / scale 62 | else: 63 | utt_eg['mix'] = utt[0] 64 | utt_eg['sph'] = utt[1] 65 | utt_eg['n_samples'] = utt[0].shape[0] 66 | yield utt_eg 67 | 68 | 69 | class SegSplitter(object): 70 | def __init__(self, segment_size, sample_rate, hop_size): 71 | self.seg_len = int(sample_rate * segment_size) 72 | self.hop_len = int(sample_rate * hop_size) 73 | 74 | def __call__(self, utt_eg): 75 | n_samples = utt_eg['n_samples'] 76 | segs = [] 77 | if n_samples < self.seg_len: 78 | pad_size = self.seg_len - n_samples 79 | seg = dict() 80 | seg['mix'] = np.pad(utt_eg['mix'], [(0, pad_size)]) 81 | seg['sph'] = np.pad(utt_eg['sph'], [(0, pad_size)]) 82 | seg['n_samples'] = n_samples 83 | segs.append(seg) 84 | else: 85 | s_point = 0 86 | while True: 87 | if s_point + self.seg_len > n_samples: 88 | break 89 | seg = dict() 90 | seg['mix'] = utt_eg['mix'][s_point:s_point+self.seg_len] 91 | seg['sph'] = utt_eg['sph'][s_point:s_point+self.seg_len] 92 | seg['n_samples'] = self.seg_len 93 | s_point += self.hop_len 94 | segs.append(seg) 95 | return segs 96 | 97 | 98 | class AudioLoader(object): 99 | def __init__(self, 100 | in_file, 101 | sample_rate, 102 | unit='seg', 103 | segment_size=4.0, 
104 | segment_shift=1.0, 105 | batch_size=4, 106 | buffer_size=16, 107 | in_norm=True, 108 | mode='train'): 109 | self.mode = mode 110 | assert self.mode in {'train', 'eval'} 111 | self.unit = unit 112 | assert self.unit in {'seg', 'utt'} 113 | if self.mode == 'train': 114 | self.utt_loader = PerUttLoader(in_file, in_norm, shuffle=True, mode='train') 115 | else: 116 | self.utt_loader = PerUttLoader(in_file, in_norm, shuffle=False, mode='eval') 117 | if unit == 'seg': 118 | self.seg_splitter = SegSplitter(segment_size, sample_rate, hop_size=segment_shift) 119 | 120 | self.batch_size = batch_size 121 | self.buffer_size = buffer_size 122 | 123 | def make_batch(self, load_list): 124 | n_batches = len(load_list) // self.batch_size 125 | if n_batches == 0: 126 | return [] 127 | else: 128 | batch_queue = [[] for _ in range(n_batches)] 129 | idx = 0 130 | for seg in load_list[0:n_batches*self.batch_size]: 131 | batch_queue[idx].append(seg) 132 | idx = (idx + 1) % n_batches 133 | if self.unit == 'utt': 134 | for batch in batch_queue: 135 | sig_len = max([eg['mix'].shape[0] for eg in batch]) 136 | for i in range(len(batch)): 137 | pad_size = sig_len - batch[i]['mix'].shape[0] 138 | batch[i]['mix'] = np.pad(batch[i]['mix'], [(0, pad_size)]) 139 | batch[i]['sph'] = np.pad(batch[i]['sph'], [(0, pad_size)]) 140 | return batch_queue 141 | 142 | def to_tensor(self, x) : 143 | return torch.from_numpy(x).float() 144 | 145 | def batch_buffer(self): 146 | while True: 147 | try: 148 | utt_eg = next(self.load_iter) 149 | if self.unit == 'seg': 150 | segs = self.seg_splitter(utt_eg) 151 | self.load_list.extend(segs) 152 | else: 153 | self.load_list.append(utt_eg) 154 | except StopIteration: 155 | self.stop_iter = True 156 | break 157 | if len(self.load_list) >= self.buffer_size: 158 | break 159 | 160 | batch_queue = self.make_batch(self.load_list) 161 | batch_list = [] 162 | for eg_list in batch_queue: 163 | batch = { 164 | 'mix': torch.stack([self.to_tensor(eg['mix']) for eg in eg_list], dim=0), 165 | 'sph': torch.stack([self.to_tensor(eg['sph']) for eg in eg_list], dim=0), 166 | 'n_samples': torch.tensor([eg['n_samples'] for eg in eg_list], dtype=torch.int64) 167 | } 168 | batch_list.append(batch) 169 | # drop used segments and keep remaining segments 170 | rn = len(self.load_list) % self.batch_size 171 | self.load_list = self.load_list[-rn:] if rn else [] 172 | return batch_list 173 | 174 | def __iter__(self): 175 | self.load_iter = iter(self.utt_loader) 176 | self.stop_iter = False 177 | self.load_list = [] 178 | while True: 179 | if self.stop_iter: 180 | break 181 | egs_buffer = self.batch_buffer() 182 | for egs in egs_buffer: 183 | yield egs 184 | 185 | -------------------------------------------------------------------------------- /scripts/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import soundfile as sf 4 | import numpy as np 5 | from pystoi import stoi 6 | from pypesq import pesq 7 | 8 | from configs import exp_conf 9 | from utils.utils import getLogger 10 | 11 | 12 | def snr(ref, est): 13 | ratio = 10 * np.log10(np.sum(ref**2) / np.sum((ref - est)**2)) 14 | return ratio 15 | 16 | 17 | class Metric(object): 18 | def __init__(self, args): 19 | self.sample_rate = exp_conf['sample_rate'] 20 | 21 | self.est_path = args.est_path 22 | self.metric = args.metric 23 | assert self.metric in {'stoi', 'estoi', 'pesq', 'snr'} 24 | 25 | self.ckpt_dir = args.ckpt_dir 26 | if not os.path.isdir(self.ckpt_dir): 27 | os.makedirs(self.ckpt_dir) 
28 | 29 | def evaluate(self): 30 | getattr(self, self.metric)() 31 | 32 | def apply_metric(self, metric_func): 33 | logger = getLogger(os.path.join(self.ckpt_dir, self.metric+'_scores.log'), log_file=True) 34 | 35 | all_scores_dir = os.path.join(self.ckpt_dir, 'all_scores') 36 | if not os.path.isdir(all_scores_dir): 37 | os.makedirs(all_scores_dir) 38 | if not os.path.isdir(os.path.join(all_scores_dir, 'scores_arrays')): 39 | os.makedirs(os.path.join(all_scores_dir, 'scores_arrays')) 40 | 41 | for condition in os.listdir(self.est_path): 42 | mix_scores_array = [] 43 | est_scores_array = [] 44 | 45 | score_name = condition + '_' + self.metric 46 | f = open(os.path.join(all_scores_dir, score_name + '.txt'), 'w') 47 | count = 0 48 | for filename in os.listdir(os.path.join(self.est_path, condition)): 49 | if not filename.endswith('_mix.wav'): 50 | continue 51 | count += 1 52 | mix, _ = sf.read(os.path.join(self.est_path, condition, filename), dtype=np.float32) 53 | sph, _ = sf.read(os.path.join(self.est_path, condition, filename.replace('_mix', '_sph')), dtype=np.float32) 54 | sph_est, _ = sf.read(os.path.join(self.est_path, condition, filename.replace('_mix', '_sph_est')), dtype=np.float32) 55 | mix_score = metric_func(sph, mix) 56 | est_score = metric_func(sph, sph_est) 57 | f.write('utt {}: mix {:.4f}, est {:.4f}\n'.format(filename, mix_score, est_score)) 58 | f.flush() 59 | mix_scores_array.append(mix_score) 60 | est_scores_array.append(est_score) 61 | 62 | mix_scores_array = np.array(mix_scores_array, dtype=np.float32) 63 | est_scores_array = np.array(est_scores_array, dtype=np.float32) 64 | f.write('========================================\n') 65 | f.write('{} results: ({} utts)\n'.format(self.metric, count)) 66 | f.write('mix : {:.4f} +- {:.4f}\n'.format(np.mean(mix_scores_array), np.std(mix_scores_array))) 67 | f.write('est : {:.4f} +- {:.4f}\n'.format(np.mean(est_scores_array), np.std(est_scores_array))) 68 | f.close() 69 | np.save(os.path.join(all_scores_dir, 'scores_arrays', score_name + '_mix.npy'), mix_scores_array) 70 | np.save(os.path.join(all_scores_dir, 'scores_arrays', score_name + '_est.npy'), est_scores_array) 71 | 72 | message = 'Evaluating {}: {} utts: '.format(condition, count) + \ 73 | '{} [ mix: {:.4f}, est: {:.4f} | delta: {:.4f} ]'.format(self.metric, 74 | np.mean(mix_scores_array), np.mean(est_scores_array), 75 | np.mean(est_scores_array)-np.mean(mix_scores_array)) 76 | logger.info(message) 77 | 78 | def stoi(self): 79 | fn = lambda ref, est: stoi(ref, est, self.sample_rate, extended=False) 80 | self.apply_metric(fn) 81 | 82 | def estoi(self): 83 | fn = lambda ref, est: stoi(ref, est, self.sample_rate, extended=True) 84 | self.apply_metric(fn) 85 | 86 | def pesq(self): 87 | fn = lambda ref, est: pesq(ref, est, self.sample_rate) 88 | self.apply_metric(fn) 89 | 90 | def snr(self): 91 | fn = lambda ref, est: snr(ref, est) 92 | self.apply_metric(fn) 93 | -------------------------------------------------------------------------------- /scripts/utils/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import timeit 4 | 5 | import numpy as np 6 | import soundfile as sf 7 | import torch 8 | from torch.nn import DataParallel 9 | from torch.nn.utils import clip_grad_norm_ 10 | from torch.utils.data import DataLoader 11 | from torch.optim import Adam, lr_scheduler 12 | 13 | from configs import exp_conf 14 | from utils.utils import getLogger, numParams, countFrames, lossMask, lossLog, wavNormalize 15 | 
from utils.pipeline_modules import NetFeeder, Resynthesizer 16 | from utils.data_utils import AudioLoader 17 | from utils.networks import Net 18 | from utils.criteria import LossFunction 19 | 20 | 21 | class CheckPoint(object): 22 | def __init__(self, ckpt_info=None, net_state_dict=None, optim_state_dict=None): 23 | self.ckpt_info = ckpt_info 24 | self.net_state_dict = net_state_dict 25 | self.optim_state_dict = optim_state_dict 26 | 27 | def save(self, filename, is_best, best_model=None): 28 | torch.save(self, filename) 29 | if is_best: 30 | shutil.copyfile(filename, best_model) 31 | 32 | def load(self, filename, device): 33 | if not os.path.isfile(filename): 34 | raise FileNotFoundError('No checkpoint found at {}'.format(filename)) 35 | ckpt = torch.load(filename, map_location=device) 36 | self.ckpt_info = ckpt.ckpt_info 37 | self.net_state_dict = ckpt.net_state_dict 38 | self.optim_state_dict = ckpt.optim_state_dict 39 | 40 | 41 | class Model(object): 42 | def __init__(self): 43 | self.in_norm = exp_conf['in_norm'] 44 | self.sample_rate = exp_conf['sample_rate'] 45 | self.win_len = exp_conf['win_len'] 46 | self.hop_len = exp_conf['hop_len'] 47 | 48 | self.win_size = int(self.win_len * self.sample_rate) 49 | self.hop_size = int(self.hop_len * self.sample_rate) 50 | 51 | def train(self, args): 52 | with open(args.tr_list, 'r') as f: 53 | self.tr_list = [line.strip() for line in f.readlines()] 54 | self.tr_size = len(self.tr_list) 55 | self.cv_file = args.cv_file 56 | self.ckpt_dir = args.ckpt_dir 57 | self.logging_period = args.logging_period 58 | self.resume_model = args.resume_model 59 | self.time_log = args.time_log 60 | self.lr = args.lr 61 | self.lr_decay_factor = args.lr_decay_factor 62 | self.lr_decay_period = args.lr_decay_period 63 | self.clip_norm = args.clip_norm 64 | self.max_n_epochs = args.max_n_epochs 65 | self.batch_size = args.batch_size 66 | self.buffer_size = args.buffer_size 67 | self.loss_log = args.loss_log 68 | self.unit = args.unit 69 | self.segment_size = args.segment_size 70 | self.segment_shift = args.segment_shift 71 | 72 | self.gpu_ids = tuple(map(int, args.gpu_ids.split(','))) 73 | if len(self.gpu_ids) == 1 and self.gpu_ids[0] == -1: 74 | # cpu only 75 | self.device = torch.device('cpu') 76 | else: 77 | # gpu 78 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) 79 | 80 | if not os.path.isdir(self.ckpt_dir): 81 | os.makedirs(self.ckpt_dir) 82 | 83 | logger = getLogger(os.path.join(self.ckpt_dir, 'train.log'), log_file=True) 84 | 85 | # create data loaders for training and cross validation 86 | tr_loader = AudioLoader(self.tr_list, self.sample_rate, self.unit, 87 | self.segment_size, self.segment_shift, 88 | self.batch_size, self.buffer_size, 89 | self.in_norm, mode='train') 90 | cv_loader = AudioLoader(self.cv_file, self.sample_rate, unit='utt', 91 | segment_size=None, segment_shift=None, 92 | batch_size=1, buffer_size=10, 93 | in_norm=self.in_norm, mode='eval') 94 | 95 | # create a network 96 | net = Net() 97 | logger.info('Model summary:\n{}'.format(net)) 98 | 99 | net = net.to(self.device) 100 | if len(self.gpu_ids) > 1: 101 | net = DataParallel(net, device_ids=self.gpu_ids) 102 | 103 | # calculate model size 104 | param_count = numParams(net) 105 | logger.info('Trainable parameter count: {:,d} -> {:.2f} MB\n'.format(param_count, param_count*32/8/(2**20))) 106 | 107 | # net feeder 108 | feeder = NetFeeder(self.device, self.win_size, self.hop_size) 109 | 110 | # training criterion and optimizer 111 | criterion = LossFunction() 112 | optimizer = 
Adam(net.parameters(), lr=self.lr, amsgrad=False) 113 | scheduler = lr_scheduler.StepLR(optimizer, step_size=self.lr_decay_period, gamma=self.lr_decay_factor) 114 | 115 | # resume model if needed 116 | if self.resume_model: 117 | logger.info('Resuming model from {}'.format(self.resume_model)) 118 | ckpt = CheckPoint() 119 | ckpt.load(self.resume_model, self.device) 120 | state_dict = {} 121 | for key in ckpt.net_state_dict: 122 | if len(self.gpu_ids) > 1: 123 | state_dict['module.'+key] = ckpt.net_state_dict[key] 124 | else: 125 | state_dict[key] = ckpt.net_state_dict[key] 126 | net.load_state_dict(state_dict) 127 | optimizer.load_state_dict(ckpt.optim_state_dict) 128 | ckpt_info = ckpt.ckpt_info 129 | logger.info('model info: epoch {}, iter {}, cv_loss - {:.4f}\n'.format(ckpt.ckpt_info['cur_epoch']+1, 130 | ckpt.ckpt_info['cur_iter']+1, ckpt.ckpt_info['cv_loss'])) 131 | else: 132 | logger.info('Training from scratch...\n') 133 | ckpt_info = {'cur_epoch': 0, 134 | 'cur_iter': 0, 135 | 'tr_loss': None, 136 | 'cv_loss': None, 137 | 'best_loss': float('inf')} 138 | 139 | start_iter = 0 140 | # train model 141 | while ckpt_info['cur_epoch'] < self.max_n_epochs: 142 | accu_tr_loss = 0. 143 | accu_n_frames = 0 144 | net.train() 145 | for n_iter, egs in enumerate(tr_loader): 146 | n_iter += start_iter 147 | mix = egs['mix'] 148 | sph = egs['sph'] 149 | n_samples = egs['n_samples'] 150 | 151 | mix = mix.to(self.device) 152 | sph = sph.to(self.device) 153 | n_samples = n_samples.to(self.device) 154 | 155 | n_frames = countFrames(n_samples, self.win_size, self.hop_size) 156 | 157 | start_time = timeit.default_timer() 158 | 159 | # prepare features and labels 160 | feat, lbl = feeder(mix, sph) 161 | loss_mask = lossMask(shape=lbl.shape, n_frames=n_frames, device=self.device) 162 | # forward + backward + optimize 163 | optimizer.zero_grad() 164 | with torch.enable_grad(): 165 | est = net(feat) 166 | loss = criterion(est, lbl, loss_mask, n_frames) 167 | loss.backward() 168 | if self.clip_norm >= 0.0: 169 | clip_grad_norm_(net.parameters(), self.clip_norm) 170 | optimizer.step() 171 | # calculate loss 172 | running_loss = loss.data.item() 173 | accu_tr_loss += running_loss * sum(n_frames) 174 | accu_n_frames += sum(n_frames) 175 | 176 | end_time = timeit.default_timer() 177 | batch_time = end_time - start_time 178 | 179 | if self.time_log: 180 | with open(self.time_log, 'a+') as f: 181 | print('Epoch [{}/{}], Iter [{}], tr_loss = {:.4f} / {:.4f}, batch_time (s) = {:.4f}'.format(ckpt_info['cur_epoch']+1, 182 | self.max_n_epochs, n_iter, running_loss, accu_tr_loss / accu_n_frames, batch_time), file=f) 183 | f.flush() 184 | else: 185 | print('Epoch [{}/{}], Iter [{}], tr_loss = {:.4f} / {:.4f}, batch_time (s) = {:.4f}'.format(ckpt_info['cur_epoch']+1, 186 | self.max_n_epochs, n_iter, running_loss, accu_tr_loss / accu_n_frames, batch_time), flush=True) 187 | 188 | 189 | if (n_iter + 1) % self.logging_period == 0: 190 | avg_tr_loss = accu_tr_loss / accu_n_frames 191 | avg_cv_loss = self.validate(net, cv_loader, criterion, feeder) 192 | net.train() 193 | 194 | ckpt_info['cur_iter'] = n_iter 195 | is_best = True if avg_cv_loss < ckpt_info['best_loss'] else False 196 | ckpt_info['best_loss'] = avg_cv_loss if is_best else ckpt_info['best_loss'] 197 | latest_model = 'latest.pt' 198 | best_model = 'best.pt' 199 | ckpt_info['tr_loss'] = avg_tr_loss 200 | ckpt_info['cv_loss'] = avg_cv_loss 201 | if len(self.gpu_ids) > 1: 202 | ckpt = CheckPoint(ckpt_info, net.module.state_dict(), optimizer.state_dict()) 203 | else: 
204 | ckpt = CheckPoint(ckpt_info, net.state_dict(), optimizer.state_dict()) 205 | logger.info('Saving checkpoint into {}'.format(os.path.join(self.ckpt_dir, latest_model))) 206 | if is_best: 207 | logger.info('Saving checkpoint into {}'.format(os.path.join(self.ckpt_dir, best_model))) 208 | logger.info('Epoch [{}/{}], ( tr_loss: {:.4f} | cv_loss: {:.4f} )\n'.format(ckpt_info['cur_epoch']+1, 209 | self.max_n_epochs, avg_tr_loss, avg_cv_loss)) 210 | 211 | model_path = os.path.join(self.ckpt_dir, 'models') 212 | if not os.path.isdir(model_path): 213 | os.makedirs(model_path) 214 | 215 | ckpt.save(os.path.join(model_path, latest_model), 216 | is_best, 217 | os.path.join(model_path, best_model)) 218 | 219 | lossLog(os.path.join(self.ckpt_dir, self.loss_log), ckpt, self.logging_period) 220 | 221 | accu_tr_loss = 0. 222 | accu_n_frames = 0 223 | 224 | if n_iter + 1 == self.tr_size // self.batch_size: 225 | start_iter = 0 226 | ckpt_info['cur_iter'] = 0 227 | break 228 | 229 | ckpt_info['cur_epoch'] += 1 230 | scheduler.step() # learning rate decay 231 | 232 | return 233 | 234 | def validate(self, net, cv_loader, criterion, feeder): 235 | accu_cv_loss = 0. 236 | accu_n_frames = 0 237 | 238 | if len(self.gpu_ids) > 1: 239 | net = net.module 240 | 241 | net.eval() 242 | for k, egs in enumerate(cv_loader): 243 | mix = egs['mix'] 244 | sph = egs['sph'] 245 | n_samples = egs['n_samples'] 246 | 247 | mix = mix.to(self.device) 248 | sph = sph.to(self.device) 249 | n_samples = n_samples.to(self.device) 250 | 251 | n_frames = countFrames(n_samples, self.win_size, self.hop_size) 252 | 253 | feat, lbl = feeder(mix, sph) 254 | 255 | with torch.no_grad(): 256 | loss_mask = lossMask(shape=lbl.shape, n_frames=n_frames, device=self.device) 257 | est = net(feat) 258 | loss = criterion(est, lbl, loss_mask, n_frames) 259 | 260 | accu_cv_loss += loss.data.item() * sum(n_frames) 261 | accu_n_frames += sum(n_frames) 262 | 263 | avg_cv_loss = accu_cv_loss / accu_n_frames 264 | return avg_cv_loss 265 | 266 | def test(self, args): 267 | with open(args.tt_list, 'r') as f: 268 | self.tt_list = [line.strip() for line in f.readlines()] 269 | self.model_file = args.model_file 270 | self.ckpt_dir = args.ckpt_dir 271 | self.est_path = args.est_path 272 | self.write_ideal = args.write_ideal 273 | self.gpu_ids = tuple(map(int, args.gpu_ids.split(','))) 274 | if len(self.gpu_ids) == 1 and self.gpu_ids[0] == -1: 275 | # cpu only 276 | self.device = torch.device('cpu') 277 | else: 278 | # gpu 279 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) 280 | 281 | if not os.path.isdir(self.ckpt_dir): 282 | os.makedirs(self.ckpt_dir) 283 | logger = getLogger(os.path.join(self.ckpt_dir, 'test.log'), log_file=True) 284 | 285 | # create a network 286 | net = Net() 287 | logger.info('Model summary:\n{}'.format(net)) 288 | 289 | net = net.to(self.device) 290 | 291 | # calculate model size 292 | param_count = numParams(net) 293 | logger.info('Trainable parameter count: {:,d} -> {:.2f} MB\n'.format(param_count, param_count*32/8/(2**20))) 294 | 295 | # training criterion and optimizer 296 | criterion = LossFunction() 297 | 298 | # net feeder 299 | feeder = NetFeeder(self.device, self.win_size, self.hop_size) 300 | 301 | # resynthesizer 302 | resynthesizer = Resynthesizer(self.device, self.win_size, self.hop_size) 303 | 304 | # load model 305 | logger.info('Loading model from {}'.format(self.model_file)) 306 | ckpt = CheckPoint() 307 | ckpt.load(self.model_file, self.device) 308 | net.load_state_dict(ckpt.net_state_dict) 309 | 
logger.info('model info: epoch {}, iter {}, cv_loss - {:.4f}\n'.format(ckpt.ckpt_info['cur_epoch']+1, 310 | ckpt.ckpt_info['cur_iter']+1, ckpt.ckpt_info['cv_loss'])) 311 | 312 | net.eval() 313 | for i in range(len(self.tt_list)): 314 | # create a data loader for testing 315 | tt_loader = AudioLoader(self.tt_list[i], self.sample_rate, unit='utt', 316 | segment_size=None, segment_shift=None, 317 | batch_size=1, buffer_size=10, 318 | in_norm=self.in_norm, mode='eval') 319 | logger.info('[{}/{}] Estimating on {}'.format(i+1, len(self.tt_list), self.tt_list[i])) 320 | 321 | est_subdir = os.path.join(self.est_path, self.tt_list[i].split('/')[-1].replace('.ex', '')) 322 | if not os.path.isdir(est_subdir): 323 | os.makedirs(est_subdir) 324 | 325 | accu_tt_loss = 0. 326 | accu_n_frames = 0 327 | for k, egs in enumerate(tt_loader): 328 | mix = egs['mix'] 329 | sph = egs['sph'] 330 | n_samples = egs['n_samples'] 331 | 332 | n_frames = countFrames(n_samples, self.win_size, self.hop_size) 333 | 334 | mix = mix.to(self.device) 335 | sph = sph.to(self.device) 336 | 337 | feat, lbl = feeder(mix, sph) 338 | 339 | with torch.no_grad(): 340 | loss_mask = lossMask(shape=lbl.shape, n_frames=n_frames, device=self.device) 341 | est = net(feat) 342 | loss = criterion(est, lbl, loss_mask, n_frames) 343 | 344 | accu_tt_loss += loss.data.item() * sum(n_frames) 345 | accu_n_frames += sum(n_frames) 346 | 347 | sph_idl = resynthesizer(lbl, mix) 348 | sph_est = resynthesizer(est, mix) 349 | 350 | # save estimates 351 | mix = mix[0].cpu().numpy() 352 | sph = sph[0].cpu().numpy() 353 | sph_est = sph_est[0].cpu().numpy() 354 | sph_idl = sph_idl[0].cpu().numpy() 355 | mix, sph, sph_est, sph_idl = wavNormalize(mix, sph, sph_est, sph_idl) 356 | sf.write(os.path.join(est_subdir, '{}_mix.wav'.format(k)), mix, self.sample_rate) 357 | sf.write(os.path.join(est_subdir, '{}_sph.wav'.format(k)), sph, self.sample_rate) 358 | sf.write(os.path.join(est_subdir, '{}_sph_est.wav'.format(k)), sph_est, self.sample_rate) 359 | if self.write_ideal: 360 | sf.write(os.path.join(est_subdir, '{}_sph_idl.wav'.format(k)), sph_idl, self.sample_rate) 361 | 362 | avg_tt_loss = accu_tt_loss / accu_n_frames 363 | logger.info('loss: {:.4f}'.format(avg_tt_loss)) 364 | 365 | return 366 | -------------------------------------------------------------------------------- /scripts/utils/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Net(nn.Module): 7 | def __init__(self): 8 | super(Net, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(2,3), stride=(1,2), padding=(1,0)) 11 | self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(2,3), stride=(1,2), padding=(1,0)) 12 | self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(2,3), stride=(1,2), padding=(1,0)) 13 | self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(2,3), stride=(1,2), padding=(1,0)) 14 | self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(2,3), stride=(1,2), padding=(1,0)) 15 | 16 | self.lstm = nn.LSTM(256*4, 256*4, 2, batch_first=True) 17 | 18 | self.conv5_t = nn.ConvTranspose2d(in_channels=512, out_channels=128, kernel_size=(2,3), stride=(1,2), padding=(1,0)) 19 | self.conv4_t = nn.ConvTranspose2d(in_channels=256, out_channels=64, kernel_size=(2,3), stride=(1,2), padding=(1,0)) 20 | self.conv3_t = nn.ConvTranspose2d(in_channels=128, 
out_channels=32, kernel_size=(2,3), stride=(1,2), padding=(1,0)) 21 | self.conv2_t = nn.ConvTranspose2d(in_channels=64, out_channels=16, kernel_size=(2,3), stride=(1,2), padding=(1,0), output_padding=(0,1)) 22 | self.conv1_t = nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=(2,3), stride=(1,2), padding=(1,0)) 23 | 24 | self.bn1 = nn.BatchNorm2d(16) 25 | self.bn2 = nn.BatchNorm2d(32) 26 | self.bn3 = nn.BatchNorm2d(64) 27 | self.bn4 = nn.BatchNorm2d(128) 28 | self.bn5 = nn.BatchNorm2d(256) 29 | 30 | self.bn5_t = nn.BatchNorm2d(128) 31 | self.bn4_t = nn.BatchNorm2d(64) 32 | self.bn3_t = nn.BatchNorm2d(32) 33 | self.bn2_t = nn.BatchNorm2d(16) 34 | self.bn1_t = nn.BatchNorm2d(1) 35 | 36 | self.elu = nn.ELU(inplace=True) 37 | self.softplus = nn.Softplus() 38 | 39 | def forward(self, x): 40 | 41 | out = x.unsqueeze(dim=1) 42 | e1 = self.elu(self.bn1(self.conv1(out)[:,:,:-1,:].contiguous())) 43 | e2 = self.elu(self.bn2(self.conv2(e1)[:,:,:-1,:].contiguous())) 44 | e3 = self.elu(self.bn3(self.conv3(e2)[:,:,:-1,:].contiguous())) 45 | e4 = self.elu(self.bn4(self.conv4(e3)[:,:,:-1,:].contiguous())) 46 | e5 = self.elu(self.bn5(self.conv5(e4)[:,:,:-1,:].contiguous())) 47 | 48 | out = e5.contiguous().transpose(1, 2) 49 | q1 = out.size(2) 50 | q2 = out.size(3) 51 | out = out.contiguous().view(out.size(0), out.size(1), -1) 52 | out, _ = self.lstm(out) 53 | out = out.contiguous().view(out.size(0), out.size(1), q1, q2) 54 | out = out.contiguous().transpose(1, 2) 55 | 56 | out = torch.cat([out, e5], dim=1) 57 | 58 | d5 = self.elu(torch.cat([self.bn5_t(F.pad(self.conv5_t(out), [0,0,1,0]).contiguous()), e4], dim=1)) 59 | d4 = self.elu(torch.cat([self.bn4_t(F.pad(self.conv4_t(d5), [0,0,1,0]).contiguous()), e3], dim=1)) 60 | d3 = self.elu(torch.cat([self.bn3_t(F.pad(self.conv3_t(d4), [0,0,1,0]).contiguous()), e2], dim=1)) 61 | d2 = self.elu(torch.cat([self.bn2_t(F.pad(self.conv2_t(d3), [0,0,1,0]).contiguous()), e1], dim=1)) 62 | d1 = self.softplus(self.bn1_t(F.pad(self.conv1_t(d2), [0,0,1,0]).contiguous())) 63 | 64 | out = torch.squeeze(d1, dim=1) 65 | 66 | return out 67 | 68 | 69 | -------------------------------------------------------------------------------- /scripts/utils/pipeline_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from utils.stft import STFT 5 | 6 | 7 | class NetFeeder(object): 8 | def __init__(self, device, win_size=320, hop_size=160): 9 | self.eps = torch.finfo(torch.float32).eps 10 | self.stft = STFT(win_size, hop_size).to(device) 11 | 12 | def __call__(self, mix, sph): 13 | real_mix, imag_mix = self.stft.stft(mix) 14 | mag_mix = torch.sqrt(real_mix**2 + imag_mix**2) 15 | feat = mag_mix 16 | 17 | real_sph, imag_sph = self.stft.stft(sph) 18 | mag_sph = torch.sqrt(real_sph**2 + imag_sph**2) 19 | lbl = mag_sph 20 | 21 | return feat, lbl 22 | 23 | 24 | class Resynthesizer(object): 25 | def __init__(self, device, win_size=320, hop_size=160): 26 | self.stft = STFT(win_size, hop_size).to(device) 27 | 28 | def __call__(self, est, mix): 29 | real_mix, imag_mix = self.stft.stft(mix) 30 | pha_mix = torch.atan2(imag_mix.data, real_mix.data) 31 | real_est = est * torch.cos(pha_mix) 32 | imag_est = est * torch.sin(pha_mix) 33 | sph_est = self.stft.istft(torch.stack([real_est, imag_est], dim=1)) 34 | sph_est = F.pad(sph_est, [0, mix.shape[1]-sph_est.shape[1]]) 35 | 36 | return sph_est 37 | -------------------------------------------------------------------------------- 
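A small round-trip sketch showing how ```NetFeeder``` and ```Resynthesizer``` above cooperate; the input sizes are illustrative, and the STFT implementation they rely on follows below. ```NetFeeder``` extracts magnitude spectra as features and labels, and ```Resynthesizer``` combines an estimated magnitude with the mixture phase to reconstruct a waveform:
```
import torch
from utils.pipeline_modules import NetFeeder, Resynthesizer

device = torch.device('cpu')
feeder = NetFeeder(device, win_size=320, hop_size=160)
resynthesizer = Resynthesizer(device, win_size=320, hop_size=160)

mix = torch.randn(1, 16000)  # 1 s of 16 kHz audio (batch size 1)
sph = torch.randn(1, 16000)

feat, lbl = feeder(mix, sph)       # magnitude spectra, shape (1, 101, 161)
sph_idl = resynthesizer(lbl, mix)  # clean magnitudes + mixture phase -> waveform
print(feat.shape, sph_idl.shape)   # torch.Size([1, 101, 161]) torch.Size([1, 16000])
```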
/scripts/utils/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import scipy 6 | 7 | 8 | class STFT(nn.Module): 9 | def __init__(self, win_size=320, hop_size=160, requires_grad=False): 10 | super(STFT, self).__init__() 11 | 12 | self.win_size = win_size 13 | self.hop_size = hop_size 14 | self.n_overlap = self.win_size // self.hop_size 15 | self.requires_grad = requires_grad 16 | 17 | win = torch.from_numpy(scipy.hamming(self.win_size).astype(np.float32)) 18 | win = F.relu(win) 19 | win = nn.Parameter(data=win, requires_grad=self.requires_grad) 20 | self.register_parameter('win', win) 21 | 22 | fourier_basis = np.fft.fft(np.eye(self.win_size)) 23 | fourier_basis_r = np.real(fourier_basis).astype(np.float32) 24 | fourier_basis_i = np.imag(fourier_basis).astype(np.float32) 25 | 26 | self.register_buffer('fourier_basis_r', torch.from_numpy(fourier_basis_r)) 27 | self.register_buffer('fourier_basis_i', torch.from_numpy(fourier_basis_i)) 28 | 29 | idx = torch.tensor(range(self.win_size//2-1, 0, -1), dtype=torch.long) 30 | self.register_buffer('idx', idx) 31 | 32 | self.eps = torch.finfo(torch.float32).eps 33 | 34 | def kernel_fw(self): 35 | fourier_basis_r = torch.matmul(self.fourier_basis_r, torch.diag(self.win)) 36 | fourier_basis_i = torch.matmul(self.fourier_basis_i, torch.diag(self.win)) 37 | 38 | fourier_basis = torch.stack([fourier_basis_r, fourier_basis_i], dim=-1) 39 | forward_basis = fourier_basis.unsqueeze(dim=1) 40 | 41 | return forward_basis 42 | 43 | def kernel_bw(self): 44 | inv_fourier_basis_r = self.fourier_basis_r / self.win_size 45 | inv_fourier_basis_i = -self.fourier_basis_i / self.win_size 46 | 47 | inv_fourier_basis = torch.stack([inv_fourier_basis_r, inv_fourier_basis_i], dim=-1) 48 | backward_basis = inv_fourier_basis.unsqueeze(dim=1) 49 | return backward_basis 50 | 51 | def window(self, n_frames): 52 | assert n_frames >= 2 53 | seg = sum([self.win[i*self.hop_size:(i+1)*self.hop_size] for i in range(self.n_overlap)]) 54 | seg = seg.unsqueeze(dim=-1).expand((self.hop_size, n_frames-self.n_overlap+1)) 55 | window = seg.contiguous().view(-1).contiguous() 56 | 57 | return window 58 | 59 | def stft(self, sig): 60 | batch_size = sig.shape[0] 61 | n_samples = sig.shape[1] 62 | 63 | cutoff = self.win_size // 2 + 1 64 | 65 | sig = sig.view(batch_size, 1, n_samples) 66 | kernel = self.kernel_fw() 67 | kernel_r = kernel[...,0] 68 | kernel_i = kernel[...,1] 69 | spec_r = F.conv1d(sig, 70 | kernel_r[:cutoff], 71 | stride=self.hop_size, 72 | padding=self.win_size-self.hop_size) 73 | spec_i = F.conv1d(sig, 74 | kernel_i[:cutoff], 75 | stride=self.hop_size, 76 | padding=self.win_size-self.hop_size) 77 | spec_r = spec_r.transpose(-1, -2).contiguous() 78 | spec_i = spec_i.transpose(-1, -2).contiguous() 79 | 80 | mag = torch.sqrt(spec_r**2 + spec_i**2) 81 | pha = torch.atan2(spec_i.data, spec_r.data) 82 | 83 | return spec_r, spec_i 84 | 85 | def istft(self, x): 86 | spec_r = x[:,0,:,:] 87 | spec_i = x[:,1,:,:] 88 | 89 | n_frames = spec_r.shape[1] 90 | 91 | spec_r = torch.cat([spec_r, spec_r.index_select(dim=-1, index=self.idx)], dim=-1) 92 | spec_i = torch.cat([spec_i, -spec_i.index_select(dim=-1, index=self.idx)], dim=-1) 93 | spec_r = spec_r.transpose(-1, -2).contiguous() 94 | spec_i = spec_i.transpose(-1, -2).contiguous() 95 | 96 | kernel = self.kernel_bw() 97 | kernel_r = kernel[...,0].transpose(0, -1) 98 | kernel_i = 
kernel[...,1].transpose(0, -1) 99 | 100 | sig = F.conv_transpose1d(spec_r, 101 | kernel_r, 102 | stride=self.hop_size, 103 | padding=self.win_size-self.hop_size) \ 104 | - F.conv_transpose1d(spec_i, 105 | kernel_i, 106 | stride=self.hop_size, 107 | padding=self.win_size-self.hop_size) 108 | sig = sig.squeeze(dim=1) 109 | 110 | window = self.window(n_frames) 111 | sig = sig / (window + self.eps) 112 | 113 | return sig 114 | -------------------------------------------------------------------------------- /scripts/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def getLogger(name, 10 | format_str='%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s', 11 | date_format='%Y-%m-%d %H:%M:%S', 12 | log_file=False): 13 | logger = logging.getLogger(name) 14 | logger.setLevel(logging.INFO) 15 | # file or console 16 | handler = logging.StreamHandler() if not log_file else logging.FileHandler(name) 17 | formatter = logging.Formatter(fmt=format_str, datefmt=date_format) 18 | handler.setFormatter(formatter) 19 | logger.addHandler(handler) 20 | return logger 21 | 22 | 23 | def numParams(net): 24 | count = sum([int(np.prod(param.shape)) for param in net.parameters()]) 25 | return count 26 | 27 | 28 | def countFrames(n_samples, win_size, hop_size): 29 | n_overlap = win_size // hop_size 30 | fn = lambda x: x // hop_size + n_overlap - 1 31 | n_frames = torch.stack(list(map(fn, n_samples)), dim=0) 32 | return n_frames 33 | 34 | 35 | def lossMask(shape, n_frames, device): 36 | loss_mask = torch.zeros(shape, dtype=torch.float32, device=device) 37 | for i, seq_len in enumerate(n_frames): 38 | loss_mask[i,0:seq_len,:] = 1.0 39 | return loss_mask 40 | 41 | 42 | def lossLog(log_filename, ckpt, logging_period): 43 | if ckpt.ckpt_info['cur_epoch'] == 0 and ckpt.ckpt_info['cur_iter'] + 1 == logging_period: 44 | with open(log_filename, 'w') as f: 45 | f.write('epoch, iter, tr_loss, cv_loss\n') 46 | f.write('{}, {}, {:.4f}, {:.4f}\n'.format(ckpt.ckpt_info['cur_epoch']+1, 47 | ckpt.ckpt_info['cur_iter']+1, ckpt.ckpt_info['tr_loss'], ckpt.ckpt_info['cv_loss'])) 48 | else: 49 | with open(log_filename, 'a') as f: 50 | f.write('{}, {}, {:.4f}, {:.4f}\n'.format(ckpt.ckpt_info['cur_epoch']+1, 51 | ckpt.ckpt_info['cur_iter']+1, ckpt.ckpt_info['tr_loss'], ckpt.ckpt_info['cv_loss'])) 52 | 53 | 54 | def wavNormalize(*sigs): 55 | # sigs is a list of signals to be normalized 56 | scale = max([np.max(np.abs(sig)) for sig in sigs]) + np.finfo(np.float32).eps 57 | sigs_norm = [sig / scale for sig in sigs] 58 | return sigs_norm 59 | 60 | 61 | def dump_json(filename, obj): 62 | with open(filename, 'w') as f: 63 | json.dump(obj, f, indent=4, sort_keys=True) 64 | return 65 | 66 | 67 | def load_json(filename): 68 | if not os.path.isfile(filename): 69 | raise FileNotFoundError('Could not find json file: {}'.format(filename)) 70 | with open(filename, 'r') as f: 71 | obj = json.load(f) 72 | return obj 73 | --------------------------------------------------------------------------------
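As a closing note on ```utils.py``` above: ```countFrames``` uses the same framing convention as ```STFT.stft``` (which pads by ```win_size - hop_size``` on both sides), so the frame counts it returns match the spectrogram lengths produced by the STFT. A quick sketch under the default 320/160 configuration:
```
import torch
from utils.utils import countFrames

# For a 1 s utterance at 16 kHz with a 320-sample window and 160-sample hop:
# 16000 // 160 + (320 // 160) - 1 = 100 + 2 - 1 = 101 frames,
# matching the n_frames dimension produced by STFT.stft.
n_frames = countFrames(torch.tensor([16000]), win_size=320, hop_size=160)
print(n_frames)  # tensor([101])
```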