├── .DS_Store
├── .vscode
    └── c_cpp_properties.json
├── README.md
├── dataset
    ├── dataset.py
    └── preprocess_dataset.py
├── loss_func
    ├── __pycache__
    │   └── distrib.cpython-38.pyc
    ├── balancer.py
    ├── distrib.py
    └── loss.py
├── model
    ├── based_model
    │   └── cust_conv.py
    ├── cruse.py
    ├── cruse_net.py
    ├── deep_filter.py
    ├── dfsmn.py
    └── mtfaa.py
├── test
    ├── __init__.py
    ├── __pycache__
    │   └── test_loss.cpython-38.pyc
    ├── testBSRNN.py
    ├── testRandSecFilter.py
    ├── test_erb.py
    ├── test_loss.py
    ├── test_model.py
    ├── test_norm.py
    └── test_pqmf.py
├── tools
    └── train_stand.py
├── train
    └── trainer_casual.py
├── train_base
    ├── acoustics
    │   ├── audioAug.py
    │   ├── conv_stft.py
    │   ├── feature.py
    │   └── mask.py
    ├── constant.py
    ├── dataset
    │   └── base_dataset.py
    ├── inferencer
    │   └── base_inferencer.py
    ├── loss.py
    ├── metrics.py
    ├── model
    │   └── base_model.py
    ├── trainer
    │   └── base_trainer.py
    └── utils.py
└── utils
    ├── logger.py
    ├── plot.py
    ├── utils.py
    └── utils_base.py

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Okrio/CRUSE/bcd5607953e0f76bac82a5cd322af4e0d3d23342/.DS_Store
--------------------------------------------------------------------------------
/.vscode/c_cpp_properties.json:
--------------------------------------------------------------------------------
{
    "configurations": [
        {
            "name": "Mac",
            "includePath": [
                "${default}"
            ],
            "defines": [],
            "macFrameworkPath": [],
            "compilerPath": "/opt/homebrew/bin/gcc-10",
            "intelliSenseMode": "macos-gcc-arm64"
        }
    ],
    "version": 4
}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CRUSE
This repo provides a lightweight network for speech enhancement.

Completed list:
- test weighted loss

## folder introduction

tools folder: includes the
main train_stand.py 18 | 19 | train_base folder: the basic train scripts e.g. feature-extraction, filter-designer, base-train, base-inference, base-dataset 20 | 21 | utils folder: something about logger print and others 22 | 23 | CRUSE_plus folder: the -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Okrio 3 | Date: 2022-02-17 00:14:37 4 | LastEditTime: 2022-02-25 00:23:53 5 | LastEditors: Please set LastEditors 6 | Description: training dataset generating and validation 7 | FilePath: /CRUSE/dataset/dataset.py 8 | ''' 9 | 10 | import os 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import Dataset, DataLoader 14 | from joblib import Parallel, delayed 15 | from tqdm import tqdm 16 | from scipy import signal 17 | import random 18 | from pathlib import Path 19 | import librosa as lib 20 | import soundfile as sf 21 | 22 | 23 | class BaseDataset(Dataset): 24 | def __init__(self) -> None: 25 | super(BaseDataset, self).__init__() 26 | 27 | @staticmethod 28 | def _offset_and_limit(dataset_list, offset, limit): 29 | dataset_list = dataset_list[offset:] 30 | if limit: 31 | dataset_list = dataset_list[:limit] 32 | return dataset_list 33 | 34 | @staticmethod 35 | def _parse_snr_range(snr_range): 36 | assert len( 37 | snr_range 38 | ) == 2, f"The range of snr should be [low, high], not{snr_range}" 39 | assert snr_range[0] <= snr_range[ 40 | -1], "The low snr should not larger than high snr" 41 | 42 | low, high = snr_range 43 | snr_list = [] 44 | for i in range(low, high + 1, 1): 45 | snr_list.append(i) 46 | return snr_list 47 | 48 | 49 | class SynDataset(BaseDataset): 50 | def __init__(self, 51 | clean_dataset, 52 | clean_dataset_limit, 53 | clean_dataset_offset, 54 | noise_dataset, 55 | noise_dataset_limit, 56 | noise_dataset_offset, 57 | rir_dataset, 58 | rir_dataset_limit, 59 | rir_dataset_offset, 60 | 
rir_noise_dataset, 61 | rir_noise_dataset_limit, 62 | rir_noise_dataset_offset, 63 | snr_range, 64 | reverb_proportion, 65 | reverb_noise_proportion, 66 | silence_length, 67 | target_dB_FS, 68 | target_dB_FS_floating_val, 69 | sub_sample_length, 70 | sr, 71 | dataset_length, 72 | pre_load_clean_dataset, 73 | pre_load_noise, 74 | pre_load_rir, 75 | num_workers, 76 | valid_mode=False) -> None: 77 | super(SynDataset, self).__init__() 78 | self.sr = sr 79 | self.num_workers = num_workers 80 | clean_dataset_list = [ 81 | line.rstrip('\n') for line in open( 82 | os.path.abspath(os.path.expanduser(clean_dataset)), 'r') 83 | ] 84 | noise_dataset_list = [ 85 | line.rstrip('\n') for line in open( 86 | os.path.abspath(os.path.expanduser(noise_dataset)), 'r') 87 | ] 88 | rir_dataset_list = [ 89 | line.rstrip('\n') for line in open( 90 | os.path.abspath(os.path.expanduser(rir_dataset)), 'r') 91 | ] 92 | rir_noise_dataset_list = [ 93 | line.rstrip('\n') for line in open( 94 | os.path.abspath(os.path.expanduser(rir_noise_dataset)), 'r') 95 | ] 96 | 97 | clean_dataset_list = self._offset_and_limit(clean_dataset_list, 98 | clean_dataset_offset, 99 | clean_dataset_limit) 100 | noise_dataset_list = self._offset_and_limit(noise_dataset_list, 101 | noise_dataset_offset, 102 | noise_dataset_limit) 103 | rir_dataset_list = self._offset_and_limit(rir_dataset_list, 104 | rir_dataset_offset, 105 | rir_dataset_limit) 106 | rir_noise_dataset_list = self._offset_and_limit( 107 | rir_noise_dataset_list, rir_noise_dataset_limit, 108 | rir_noise_dataset_offset) 109 | 110 | # if pre_load_clean_dataset: 111 | # clean_dataset_list = self._ 112 | 113 | self.clean_dataset_list = clean_dataset_list 114 | self.noise_dataset_list = noise_dataset_list 115 | self.rir_dataset_list = rir_dataset_list 116 | self.rir_noise_dataset_list = rir_noise_dataset_list 117 | 118 | self.dataset_length = dataset_length 119 | self.valid_mode = valid_mode 120 | snr_list = self._parse_snr_range(snr_range=snr_range) 121 | 
self.snr_list = snr_list 122 | 123 | assert 0 <= reverb_proportion <= 1, "reverbation proportion should be in [0,1]" 124 | self.reverb_proportion = reverb_proportion 125 | 126 | assert 0 <= reverb_noise_proportion <= 1, "reverb_noise proportion should be in [0,1]" 127 | self.reverb_noise_proportion = reverb_noise_proportion 128 | self.silence_length = silence_length 129 | self.target_dB_FS = target_dB_FS 130 | self.target_dB_FS_floating_val = target_dB_FS_floating_val 131 | self.sub_sample_length = sub_sample_length 132 | self.length = len(self.clean_dataset_list) if bool( 133 | self.dataset_length) is False else int(dataset_length) 134 | self.general_mix_dataset_list = np.random.randint( 135 | 0, len(self.clean_dataset_list), self.length) 136 | 137 | def __len__(self): 138 | return self.length 139 | 140 | # def _preload_dataset(self,file_path_list, remark=''): 141 | # waveform_list = Parallel(n_jobs=self.num_workers)(delayed((l))) 142 | 143 | @staticmethod 144 | def _random_select_from(dataset_list): 145 | return random.choice(dataset_list) 146 | 147 | def _select_clean_y(self, sig, target_length): 148 | sig_len = len(sig) 149 | clean_y = sig 150 | silence = np.zeros(int(self.sr * self.silence_length), 151 | dtype=np.float32) 152 | remain_length = target_length - sig_len 153 | while remain_length > 0: 154 | clean_file = self._random_select_from(self.clean_dataset_list) 155 | clean_new_added, _ = lib.load(clean_file, sr=self.sr) 156 | clean_new_added = clean_new_added if clean_new_added.ndim < 2 else clean_new_added[:, 157 | np 158 | . 159 | random 160 | . 161 | randint( 162 | clean_new_added 163 | . 
164 | shape[ 165 | -1] 166 | )] 167 | clean_y = np.append(clean_y, clean_new_added) 168 | remain_length = remain_length - len(clean_new_added) 169 | 170 | if remain_length > 0: 171 | silence_len = min(remain_length, len(silence)) 172 | clean_y = np.append(clean_y, silence[:silence_len]) 173 | remain_length -= silence_len 174 | 175 | if len(clean_y) > target_length: 176 | idx_start = np.random.randint(len(clean_y) - target_length) 177 | clean_y = clean_y[idx_start:idx_start + target_length] 178 | 179 | assert len( 180 | clean_y 181 | ) == target_length, "the clean_y length equals target_length" 182 | return clean_y 183 | 184 | def _select_noise_y(self, target_length): 185 | noise_y = np.zeros(0, dtype=np.float32) 186 | silence = np.zeros(int(self.sr * self.silence_length), 187 | dtype=np.float32) 188 | remaining_length = target_length 189 | while remaining_length > 0: 190 | noise_file = self._random_select_from(self.noise_dataset_list) 191 | noise_new_added, _ = lib.load( 192 | noise_file, sr=self.sr) # todo: check multi-channel wav 193 | noise_y = np.append(noise_y, noise_new_added) 194 | remaining_length = remaining_length - len(noise_new_added) 195 | if remaining_length > 0: 196 | silence_len = min(remaining_length, len(silence)) 197 | noise_y = np.append(noise_y, silence[:silence_len]) 198 | remaining_length -= silence_len 199 | if len(noise_y) > target_length: 200 | idx_start = np.random.randint(len(noise_y) - target_length) 201 | noise_y = noise_y[idx_start:idx_start + target_length] 202 | 203 | return noise_y 204 | 205 | def _select_rir(self, rir_proportion, rir_dataset_list): 206 | use_reverb = bool(np.random.random(1) < rir_proportion) 207 | if use_reverb: 208 | rir_path = self._random_select_from(rir_dataset_list) 209 | rir_y, _ = lib.load( 210 | rir_path, sr=self.sr) # todo: check multi-channel rir wav 211 | else: 212 | rir_y = None 213 | return rir_y 214 | 215 | @staticmethod 216 | def add_reverb(cln_wav, rir_wav, channels=1, predelay=50, sr=16000): 217 
| rir_len = rir_wav.shape[0] 218 | wav_tgt = np.zeros([channels, cln_wav.shape[0] + rir_len - 1]) 219 | dt = np.argmax(rir_wav, 0).min() 220 | et = dt + (predelay * sr) // 1000 221 | et_rir = rir_wav[:et] 222 | wav_early_tgt = np.zeros( 223 | [channels, cln_wav.shape[0] + et_rir.shape[0] - 1]) 224 | cln_wav = cln_wav if cln_wav.ndim < 2 else cln_wav[np.random.randint( 225 | cln_wav.shape[-1])] 226 | for i in range(channels): 227 | wav_tgt[i] = signal.fftconvolve(cln_wav, rir_wav[:, i]) 228 | wav_early_tgt[i] = signal.fftconvolve(cln_wav, et_rir[:, i]) 229 | wav_tgt = np.transpose(wav_tgt) 230 | wav_tgt = wav_tgt[:cln_wav.shape[0]] 231 | wav_early_tgt = np.transpose(wav_early_tgt) 232 | wav_early_tgt = wav_early_tgt[:cln_wav.shape[0]] 233 | return wav_tgt, wav_early_tgt 234 | 235 | @staticmethod 236 | def snr_mix(clean_y, 237 | noise_y, 238 | snr, 239 | target_dB_FS, 240 | target_dB_FS_floating_val, 241 | rir=None, 242 | rir_noise=None, 243 | eps=1e-7): 244 | if rir is not None: 245 | clean_y = signal.fftconvolve(clean_y, rir)[:len(clean_y)] 246 | if rir_noise is not None: 247 | noise_y = signal.fftconvolve(noise_y, rir_noise)[:len(noise_y)] 248 | 249 | # todo spectral augmentation procedure 250 | 251 | clean_y = clean_y / (np.max(np.abs(clean_y)) + eps) 252 | 253 | clean_rms = (clean_y**2).mean()**0.5 254 | 255 | noise_y = noise_y / (np.max(np.abs(noise_y)) + eps) 256 | 257 | noise_rms = (noise_y**2).mean()**0.5 258 | snr_scalar = clean_rms / (10**(snr / 20)) / (noise_rms + eps) 259 | noise_y *= snr_scalar 260 | noisy_y = clean_y + noise_y 261 | 262 | noisy_target_dB_FS = np.random.randint( 263 | target_dB_FS - target_dB_FS_floating_val, 264 | target_dB_FS + target_dB_FS_floating_val) 265 | -------------------------------------------------------------------------------- /dataset/preprocess_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Okrio 3 | Date: 2022-02-15 22:21:55 4 | LastEditTime: 2022-03-28 
22:53:13 5 | LastEditors: Please set LastEditors 6 | Description: Okrio 7 | FilePath: /CRUSE/dataset/preprocess_dataset.py 8 | ''' 9 | from dataclasses import replace 10 | import os 11 | # import sys 12 | import colorful 13 | from pathlib import Path 14 | 15 | import librosa 16 | from tqdm import tqdm 17 | import numpy as np 18 | import csv 19 | 20 | main_dir_name = "/dahufs/groupdata/codec" 21 | candidate_datasets = [ 22 | os.path.join(main_dir_name, 23 | 'audio/data_source/audio_classification/audio_noise') 24 | ] 25 | 26 | candidate_merge_datasets = ['asc_noise_txt', 'dns_2021_noise_train.txt'] 27 | dist_file = Path("sns_noise.txt").expanduser().absolute() 28 | dis_merge_file = Path("sns_merge.txt").expanduser().absolute() 29 | low_activity_file = Path("sns_low_activity.txt").expanduser().absolute() 30 | 31 | # speech parameters 32 | sr = 16000 33 | wav_in_second = 3 34 | activity_threshold = 0.6 35 | total_hrs = 10000.0 36 | csv_flag = False 37 | 38 | 39 | def read_csv(csv_name): 40 | with open(csv_name, 'r') as f: 41 | tmp = csv.reader(f) 42 | file_list = [] 43 | for i in tmp: 44 | file_list.append(i) 45 | print(os.path.basename(csv_name), ": Total file: ", len(file_list)) 46 | return file_list 47 | 48 | 49 | def offset_and_limit(data_list, offset, limit): 50 | data_list = data_list[offset:] 51 | if limit: 52 | data_list = data_list[:limit] 53 | return data_list 54 | 55 | 56 | def select_specify_file(file_path, ext=''): 57 | data_list = [] 58 | for i, name in enumerate(file_path): 59 | for j, type_name in enumerate(ext): 60 | if type_name in name: 61 | data_list.append(name) 62 | 63 | return data_list 64 | 65 | 66 | def multi_txt_file_merge(file_path): 67 | data_list = [] 68 | for i in file_path: 69 | data_list += [ 70 | line.rstrip('\n') 71 | for line in open(os.path.abspath(os.path.expanduser(i)), 'r') 72 | ] 73 | with open(dis_merge_file.as_posix(), 'w') as f: 74 | f.writelines(f"{file_path}\n" for file_path in data_list) 75 | return data_list 76 | 77 | 78 
| dataset_offset = 0 79 | dataset_limit = None 80 | 81 | if __name__ == "__main__": 82 | # ss = multi_txt_file_merge(candidate_merge_datasets) 83 | all_wav_path_list = [] 84 | output_wav_path_list = [] 85 | accumulated_time = 0.0 86 | 87 | is_clipped_wav_list = [] 88 | is_low_activity_list = [] 89 | is_too_short_list = [] 90 | is_large_rt60_list = [] 91 | 92 | clean_all_wav_path_list = [] 93 | clean_output_wav_path_list = [] 94 | clean_accumulated_time = 0.0 95 | 96 | clean_is_clipped_wav_list = [] 97 | clean_is_low_activity_list = [] 98 | clean_is_too_short_list = [] 99 | 100 | rir_is_clipped_wav_list = [] 101 | rir_is_low_activity_list = [] 102 | rir_is_too_short_list = [] 103 | 104 | for dataset_path in candidate_datasets: 105 | if csv_flag: 106 | dataset_path1 = read_csv(dataset_path) 107 | dataset_path1 = dataset_path1[1:] 108 | dataset_path1 = dataset_path1[:] 109 | all_wav_path_list += dataset_path1 110 | else: 111 | dataset_path = Path(dataset_path).expanduser().absolute() 112 | all_wav_path_list += librosa.util.find_files( 113 | dataset_path.as_posix(), ext=["wav", "flac"]) 114 | 115 | all_wav_path_list = offset_and_limit(all_wav_path_list, dataset_offset, 116 | dataset_limit) 117 | 118 | for wav_file_path in tqdm(all_wav_path_list, desc="Checking..."): 119 | y, _ = librosa.load(wav_file_path, sr=sr) 120 | y = y if len(y.shape) > 1 else y[:, np.random.randint(y.shape[-1] + 1)] 121 | length = np.max(y.shape) 122 | if length == 0: 123 | print(wav_file_path) 124 | continue 125 | 126 | wav_duration = length / sr 127 | wav_file_user_path = wav_file_path, replace( 128 | Path(wav_file_path).home().as_posix(), "~") 129 | 130 | is_clipped_wav = 0 131 | is_low_activity = 0 132 | is_too_short = 0 133 | is_large_r60 = 0 134 | 135 | if is_too_short: 136 | is_too_short_list.append(wav_file_user_path) 137 | continue 138 | if is_clipped_wav: 139 | is_clipped_wav_list.append(wav_file_user_path) 140 | continue 141 | if is_low_activity: 142 | 
is_low_activity_list.append(wav_file_user_path) 143 | continue 144 | if is_large_r60: 145 | is_large_rt60_list.append(wav_file_user_path) 146 | 147 | if (not is_clipped_wav) and (not is_low_activity) and ( 148 | not is_too_short) and (not is_large_r60): 149 | accumulated_time += wav_duration 150 | output_wav_path_list.append(wav_file_user_path) 151 | 152 | if accumulated_time >= (total_hrs * 3600): 153 | break 154 | 155 | with open(dist_file.as_posix(), "w") as f: 156 | f.writelines(f"{file_path}\n" for file_path in output_wav_path_list) 157 | with open(low_activity_file.as_posix(), 'w') as f: 158 | f.writelines(f"{file_path}\n" for file_path in is_low_activity_list) 159 | 160 | color_spec = colorful 161 | color_spec.use_style("solarized") 162 | 163 | print("=" * 70) 164 | print("Speech Preprocessing") 165 | print(f"\t Original files: {len(all_wav_path_list)}") 166 | print( 167 | color_spec.red( 168 | f"\t Selected files: {accumulated_time / 3600} hrs, {len(output_wav_path_list)} files. 
" 169 | )) 170 | print(f"\t is_clipped_wav: {len(is_clipped_wav_list)}") 171 | print(f"\t is_low_activity:{len(is_low_activity_list)}") 172 | print(f"\t is_too_short: {len(is_too_short_list)}") 173 | print(f"\t is_too_largeRT60: {len(is_large_rt60_list)}") 174 | 175 | print("succeed") 176 | -------------------------------------------------------------------------------- /loss_func/__pycache__/distrib.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Okrio/CRUSE/bcd5607953e0f76bac82a5cd322af4e0d3d23342/loss_func/__pycache__/distrib.cpython-38.pyc -------------------------------------------------------------------------------- /loss_func/balancer.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import typing as tp 3 | 4 | import torch 5 | from torch import autograd 6 | 7 | # from distrib import average_metrics 8 | 9 | def average_metrics(metrics: tp.Dict[str, float], count=1.): 10 | """ 11 | Average a dictionary of metrics across all workers, using the optional 12 | `count` as unormalized weight. 13 | """ 14 | # if not is_distributed(): 15 | # return metrics 16 | keys, values = zip(*metrics.items()) 17 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 18 | tensor = torch.tensor(list(values) + [1], 19 | device=device, 20 | dtype=torch.float32) 21 | tensor *= count 22 | # all_reduce(tensor) 23 | averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() 24 | return dict(zip(keys, averaged)) 25 | 26 | def averager(beta: float = 1): 27 | """ 28 | Exponential Moving Average callback. 29 | Returns a single function that can be called to repeatidly update the EMA 30 | with a dict of metrics. The callback will return 31 | the new averaged dict of metrics. 32 | Note that for `beta=1`, this is just plain averaging. 
33 | """ 34 | fix: tp.Dict[str, float] = defaultdict(float) 35 | total: tp.Dict[str, float] = defaultdict(float) 36 | 37 | def _update(metrics: tp.Dict[str, tp.Any], 38 | weight: float = 1) -> tp.Dict[str, float]: 39 | nonlocal total, fix 40 | for key, value in metrics.items(): 41 | total[key] = total[key] * beta + weight * float(value) 42 | fix[key] = fix[key] * beta + weight 43 | return {key: tot / fix[key] for key, tot in total.items()} 44 | 45 | return _update 46 | 47 | 48 | class Balancer: 49 | """Loss balancer. 50 | The loss balancer combines losses together to compute gradients for the backward. 51 | A call to the balancer will weight the losses according the specified weight coefficients. 52 | A call to the backward method of the balancer will compute the gradients, combining all the losses and 53 | potentially rescaling the gradients, which can help stabilize the training and reasonate 54 | about multiple losses with varying scales. 55 | Expected usage: 56 | weights = {'loss_a': 1, 'loss_b': 4} 57 | balancer = Balancer(weights, ...) 58 | losses: dict = {} 59 | losses['loss_a'] = compute_loss_a(x, y) 60 | losses['loss_b'] = compute_loss_b(x, y) 61 | if model.training(): 62 | balancer.backward(losses, x) 63 | ..Warning:: It is unclear how this will interact with DistributedDataParallel, 64 | in particular if you have some losses not handled by the balancer. In that case 65 | you can use `encodec.distrib.sync_grad(model.parameters())` and 66 | `encodec.distrib.sync_buffwers(model.buffers())` as a safe alternative. 67 | Args: 68 | weights (Dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys 69 | from the backward method to match the weights keys to assign weight to each of the provided loss. 70 | rescale_grads (bool): Whether to rescale gradients or not, without. If False, this is just 71 | a regular weighted sum of losses. 72 | total_norm (float): Reference norm when rescaling gradients, ignored otherwise. 
73 | emay_decay (float): EMA decay for averaging the norms when `rescale_grads` is True. 74 | per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds 75 | when rescaling the gradients. 76 | epsilon (float): Epsilon value for numerical stability. 77 | monitor (bool): Whether to store additional ratio for each loss key in metrics. 78 | """ 79 | def __init__(self, 80 | weights: tp.Dict[str, float], 81 | rescale_grads: bool = True, 82 | total_norm: float = 1., 83 | ema_decay: float = 0.999, 84 | per_batch_item: bool = True, 85 | epsilon: float = 1e-12, 86 | monitor: bool = False): 87 | self.weights = weights 88 | self.per_batch_item = per_batch_item 89 | self.total_norm = total_norm 90 | self.averager = averager(ema_decay) 91 | self.epsilon = epsilon 92 | self.monitor = monitor 93 | self.rescale_grads = rescale_grads 94 | self._metrics: tp.Dict[str, tp.Any] = {} 95 | 96 | @property 97 | def metrics(self): 98 | return self._metrics 99 | 100 | def backward(self, losses: tp.Dict[str, torch.Tensor], 101 | input: torch.Tensor): 102 | norms = {} 103 | grads = {} 104 | for name, loss in losses.items(): 105 | grad, = autograd.grad(loss, [input], retain_graph=True) 106 | if self.per_batch_item: 107 | dims = tuple(range(1, grad.dim())) 108 | norm = grad.norm(dim=dims).mean() 109 | else: 110 | norm = grad.norm() 111 | norms[name] = norm 112 | grads[name] = grad 113 | 114 | count = 1 115 | if self.per_batch_item: 116 | count = len(grad) 117 | avg_norms = average_metrics(self.averager(norms), count) 118 | total = sum(avg_norms.values()) 119 | 120 | self._metrics = {} 121 | if self.monitor: 122 | for k, v in avg_norms.items(): 123 | self._metrics[f'ratio_{k}'] = v / total 124 | 125 | total_weights = sum([self.weights[k] for k in avg_norms]) 126 | ratios = {k: w / total_weights for k, w in self.weights.items()} 127 | 128 | out_grad: tp.Any = 0 129 | for name, avg_norm in avg_norms.items(): 130 | if self.rescale_grads: 131 | scale = 
ratios[name] * self.total_norm / (self.epsilon + 132 | avg_norm) 133 | grad = grads[name] * scale 134 | else: 135 | grad = self.weights[name] * grads[name] 136 | out_grad += grad 137 | input.backward(out_grad) 138 | 139 | 140 | def test(): 141 | from torch.nn import functional as F 142 | x = torch.zeros(1, requires_grad=True) 143 | one = torch.ones_like(x) 144 | loss_1 = F.l1_loss(x, one) 145 | loss_2 = 100 * F.l1_loss(x, -one) 146 | losses = {'1': loss_1, '2': loss_2} 147 | 148 | balancer = Balancer(weights={'1': 1, '2': 1}, rescale_grads=False) 149 | balancer.backward(losses, x) 150 | assert torch.allclose(x.grad, torch.tensor(99.)), x.grad 151 | 152 | loss_1 = F.l1_loss(x, one) 153 | loss_2 = 100 * F.l1_loss(x, -one) 154 | losses = {'1': loss_1, '2': loss_2} 155 | x.grad = None 156 | balancer = Balancer(weights={'1': 1, '2': 1}, rescale_grads=True) 157 | balancer.backward({'1': loss_1, '2': loss_2}, x) 158 | assert torch.allclose(x.grad, torch.tensor(0.)), x.grad 159 | 160 | 161 | if __name__ == '__main__': 162 | test() 163 | -------------------------------------------------------------------------------- /loss_func/distrib.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | """Torch distributed utilities.""" 7 | 8 | import typing as tp 9 | 10 | import torch 11 | 12 | 13 | def rank(): 14 | if torch.distributed.is_initialized(): 15 | return torch.distributed.get_rank() 16 | else: 17 | return 0 18 | 19 | 20 | def world_size(): 21 | if torch.distributed.is_initialized(): 22 | return torch.distributed.get_world_size() 23 | else: 24 | return 1 25 | 26 | 27 | def is_distributed(): 28 | return world_size() > 1 29 | 30 | 31 | def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): 32 | if is_distributed(): 33 | return torch.distributed.all_reduce(tensor, op) 34 | 35 | 36 | def _is_complex_or_float(tensor): 37 | return torch.is_floating_point(tensor) or torch.is_complex(tensor) 38 | 39 | 40 | def _check_number_of_params(params: tp.List[torch.Tensor]): 41 | # utility function to check that the number of params in all workers is the same, 42 | # and thus avoid a deadlock with distributed all reduce. 43 | if not is_distributed() or not params: 44 | return 45 | tensor = torch.tensor([len(params)], 46 | device=params[0].device, 47 | dtype=torch.long) 48 | all_reduce(tensor) 49 | if tensor.item() != len(params) * world_size(): 50 | # If not all the workers have the same number, for at least one of them, 51 | # this inequality will be verified. 52 | raise RuntimeError( 53 | f"Mismatch in number of params: ours is {len(params)}, " 54 | "at least one worker has a different one.") 55 | 56 | 57 | def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): 58 | """Broadcast the tensors from the given parameters to all workers. 59 | This can be used to ensure that all workers have the same model to start with. 
60 | """ 61 | if not is_distributed(): 62 | return 63 | tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] 64 | _check_number_of_params(tensors) 65 | handles = [] 66 | for tensor in tensors: 67 | handle = torch.distributed.broadcast(tensor.data, 68 | src=src, 69 | async_op=True) 70 | handles.append(handle) 71 | for handle in handles: 72 | handle.wait() 73 | 74 | 75 | def sync_buffer(buffers, average=True): 76 | """ 77 | Sync grad for buffers. If average is False, broadcast instead of averaging. 78 | """ 79 | if not is_distributed(): 80 | return 81 | handles = [] 82 | for buffer in buffers: 83 | if torch.is_floating_point(buffer.data): 84 | if average: 85 | handle = torch.distributed.all_reduce( 86 | buffer.data, 87 | op=torch.distributed.ReduceOp.SUM, 88 | async_op=True) 89 | else: 90 | handle = torch.distributed.broadcast(buffer.data, 91 | src=0, 92 | async_op=True) 93 | handles.append((buffer, handle)) 94 | for buffer, handle in handles: 95 | handle.wait() 96 | if average: 97 | buffer.data /= world_size 98 | 99 | 100 | def sync_grad(params): 101 | """ 102 | Simpler alternative to DistributedDataParallel, that doesn't rely 103 | on any black magic. For simple models it can also be as fast. 104 | Just call this on your model parameters after the call to backward! 105 | """ 106 | if not is_distributed(): 107 | return 108 | handles = [] 109 | for p in params: 110 | if p.grad is not None: 111 | handle = torch.distributed.all_reduce( 112 | p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) 113 | handles.append((p, handle)) 114 | for p, handle in handles: 115 | handle.wait() 116 | p.grad.data /= world_size() 117 | 118 | 119 | def average_metrics(metrics: tp.Dict[str, float], count=1.): 120 | """ 121 | Average a dictionary of metrics across all workers, using the optional 122 | `count` as unormalized weight. 
123 | """ 124 | # if not is_distributed(): 125 | # return metrics 126 | keys, values = zip(*metrics.items()) 127 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 128 | tensor = torch.tensor(list(values) + [1], 129 | device=device, 130 | dtype=torch.float32) 131 | tensor *= count 132 | # all_reduce(tensor) 133 | averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() 134 | return dict(zip(keys, averaged)) 135 | -------------------------------------------------------------------------------- /loss_func/loss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Okrio 3 | Date: 2022-02-12 11:12:24 4 | LastEditTime: 2022-03-04 23:51:43 5 | LastEditors: Please set LastEditors 6 | Description: loss function 7 | FilePath: /CRUSE/loss_func/loss.py 8 | ''' 9 | 10 | import torch 11 | from utils.utils import activity_detector_amp, activity_detector_tf_frame 12 | # from utils.utils import active_rms 13 | # import numpy as np 14 | 15 | 16 | class loss_func: 17 | def __init__(self, loss_mode) -> None: 18 | assert loss_mode in [ 19 | 'SI-SNR', 'SS-SNR', 'MSE', 'Normal_MSE', 'CN_MSE', 'D_MSE', 20 | 'WO_MALE', 'C_MSE' 21 | ], "Loss mode must be one of ***" 22 | self.loss_mode = loss_mode 23 | 24 | def loss(self, inputs, labels, noisy=None): 25 | if self.loss_mode == 'SI-SNR': 26 | return -(sisnr(inputs, labels)) 27 | elif self.loss_mode == 'SS-SNR': 28 | return 0 29 | elif self.loss_mode == 'WO_MALE': 30 | return wo_male(labels, inputs, noisy) 31 | elif self.loss_mode == 'C_MSE': 32 | return c_rmse(labels, inputs) 33 | elif self.loss_mode == 'MSE': 34 | return rmse(labels, inputs) 35 | 36 | 37 | def l2_norm(s1, s2): 38 | norm = torch.sum(s1 * s2, -1, keepdim=True) 39 | return norm 40 | 41 | 42 | def remove_dc(data): 43 | mean = torch.mean(data, -1, keepdim=True) 44 | data = data - mean 45 | return data 46 | 47 | 48 | def sisnr(s1, s2, eps=1e-8): 49 | s1_s2_norm = l2_norm(s1, s2) 50 | s2_s2_norm = l2_norm(s2, s2) 51 | s_target 
= s1_s2_norm / (s2_s2_norm + eps) * s2 52 | e_noise = s1 - s_target 53 | targer_norm = l2_norm(s_target, s_target) 54 | noise_norm = l2_norm(e_noise, e_noise) 55 | snr = 10 * torch.log10((targer_norm) / (noise_norm + eps) + eps) 56 | return torch.mean(snr) 57 | 58 | 59 | def rmse(ref, est, eps=1e-8): 60 | """ 61 | ref: BCTF 62 | est: BCTF 63 | 64 | """ 65 | if ref.shape != est.shape: 66 | raise RuntimeError( 67 | f"Dimension mismatch when calculate rmse, {ref.shape} vs {est.shape}" 68 | ) 69 | 70 | B, C, T, F = torch.size(ref) 71 | # real_ref = ref[:, 0, :, :] 72 | # imag_ref = ref[:, 1, :, :] 73 | 74 | # real_est = est[:, 0, :, :] 75 | # imag_est = est[:, 1, :, :] 76 | err = est - ref 77 | mse = torch.sum(torch.sqrt(err**2)) / (B * T * F) 78 | return mse 79 | 80 | 81 | def cn_rmse(ref, est, unproc=None, eps=1e-8): 82 | if ref.shape != est.shape: 83 | raise RuntimeError( 84 | f"Dimension mismatch when calculate c_mse, {ref.shape} vs {est.shape}" 85 | ) 86 | 87 | 88 | def c_rmse(ref, est, unproc=None, norm=False, eps=1e-8): 89 | if ref.shape != est.shape: 90 | raise RuntimeError( 91 | f"Dimension mismatch when calculate c_mse, {ref.shape} vs {est.shape}" 92 | ) 93 | c = 0.3 94 | beta = 0.3 95 | B, C, T, F = torch.size(ref) 96 | real_ref = ref[:, 0, :, :] 97 | imag_ref = ref[:, 1, :, :] 98 | 99 | real_est = est[:, 0, :, :] 100 | imag_est = est[:, 1, :, :] 101 | 102 | mag_ref = torch.sqrt(real_ref**2 + imag_ref**2) 103 | phase_ref = torch.atan2(imag_ref, real_ref) 104 | 105 | mag_est = torch.sqrt(real_est**2 + imag_est**2) 106 | phase_est = torch.atan2(imag_est, real_est) 107 | tmp1 = torch.pow(mag_est, c) 108 | tmp2 = torch.pow(mag_ref, c) 109 | tmp3 = tmp1 * torch.cos(phase_ref) + tmp1 * torch.sin(phase_ref) * 1j 110 | 111 | tmp4 = tmp2 * torch.cos(phase_est) + tmp1 * torch.sin(phase_est) * 1j 112 | tmp5 = tmp3 - tmp4 113 | tmp5 = torch.abs(tmp5) 114 | 115 | loss1 = (torch.pow(mag_ref, c) - torch.pow(mag_est, c))**2 116 | loss2 = tmp5**2 117 | loss = (1 - beta) * 
torch.sum(loss1) + beta * torch.sum(loss2) 118 | return loss 119 | 120 | 121 | def wo_male(ref, est, unproc, norm=False, eps=1e-8): 122 | if ref.shape != est.shape: 123 | raise RuntimeError( 124 | f"Dimension mismatch when calculate wo-male, {ref.shape} vs {est.shape}" 125 | ) 126 | alpha = 2 127 | beta = 1 128 | gamma = 1 129 | B, C, T, F = torch.size(ref) 130 | real_ref = ref[:, 0, :, :] 131 | imag_ref = ref[:, 1, :, :] 132 | 133 | real_est = est[:, 0, :, :] 134 | imag_est = est[:, 1, :, :] 135 | mag_ref = torch.sqrt(real_ref**2 + imag_ref**2) 136 | mag_est = torch.sqrt(real_est**2 + imag_est**2) 137 | # phase_ref = torch.atan2(imag_ref, real_ref) 138 | # phase_est = torch.atan2(imag_est, real_est) 139 | mag_unproc = torch.sqrt(unproc[:, 0, :, :]**2 + unproc[:, 1, :, 1]**2) 140 | # phase_unproc = torch.atan2(unproc[:, 1, :, :], unproc[:, 0, :, :]) 141 | 142 | iam = (mag_ref / mag_unproc)**gamma 143 | W_iam = torch.exp(alpha / (beta + iam)) 144 | 145 | loss = W_iam * torch.abs( 146 | torch.log10(mag_est + 1) - torch.log10(mag_ref + 1)) 147 | loss = torch.sum(loss) / (B * T * F * 1.0) 148 | return loss 149 | 150 | 151 | def sdnr(ref_clean, 152 | est_g, 153 | ref_noise, 154 | snr=None, 155 | beta=20, 156 | alpha=None, 157 | norm=False, 158 | eps=1e-8): 159 | """ 160 | ref_clean: B*C*T*F 161 | est_g: B*C*T*F 162 | ref_noise: BCTF 163 | """ 164 | B, C, T, F = torch.size(ref_clean) 165 | 166 | L_noise = torch.mean(torch.norm(ref_noise * est_g, p=2, dim=(1, 2))**2) 167 | 168 | vad = activity_detector_tf_frame(ref_clean) 169 | S_sa = vad * ref_clean 170 | L_speech = torch.mean(torch.norm(S_sa - est_g * S_sa, p=2, dim=(1, 2))**2) 171 | snr_tmp = 10**(snr / 10) 172 | beta_tmp = 10**(beta / 10) 173 | alpha = snr_tmp / (snr_tmp + beta_tmp) 174 | loss_out = alpha * L_speech + (1 - alpha) * L_noise 175 | return loss_out 176 | 177 | 178 | if __name__ == "__main__": 179 | # inda = np.linspa 180 | x = torch.ones((2, 1, 3, 4)) 181 | y = torch.sum(x, dim=0) 182 | print(f"y 
shape:{y.shape}") 183 | print('sc') -------------------------------------------------------------------------------- /model/based_model/cust_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | 4 | from typing import Callable, Iterable, List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor, nn 9 | from torch.nn import functional as F 10 | from torch.nn import init 11 | from torch.nn.parameter import Parameter 12 | from typing_extensions import Final 13 | 14 | 15 | class Conv2dNormAct(nn.Sequential): 16 | def __init__(self, 17 | in_ch, 18 | out_ch, 19 | kernel_size, 20 | fstride=1, 21 | dilation=1, 22 | fpad=True, 23 | bias=True, 24 | separable=False, 25 | norm_layer=torch.nn.BatchNorm2d, 26 | activation_layer=torch.nn.ReLU): 27 | """ 28 | [B C T F] 29 | """ 30 | lookahead = 0 31 | kernel_size = ((kernel_size, kernel_size) if isinstance( 32 | kernel_size, int) else tuple(kernel_size)) 33 | 34 | if fpad: 35 | fpad_ = kernel_size[1] // 2 + dilation - 1 36 | else: 37 | fpad_ = 0 38 | pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) 39 | layers = [] 40 | if any(x > 0 for x in pad): 41 | layers.append(nn.ConstantPad2d(pad, 0.0)) 42 | groups = math.gcd(in_ch, out_ch) if separable else 1 43 | if groups == 1: 44 | separable = False 45 | if max(kernel_size) == 1: 46 | separable = False 47 | layers.append( 48 | nn.Conv2d(in_ch, 49 | out_ch, 50 | kernel_size, 51 | padding=(0, fpad_), 52 | stride=(1, fstride), 53 | dilation=(1, dilation), 54 | groups=groups, 55 | bias=bias)) 56 | if separable: 57 | layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False)) 58 | if norm_layer is not None: 59 | layers.append(norm_layer(out_ch)) 60 | if activation_layer is not None: 61 | layers.append(activation_layer()) 62 | super().__init__(*layers) 63 | 64 | 65 | class ConvTranspose2dNormAct(nn.Sequential): 66 | def __init__(self, 67 | 
in_ch, 68 | out_ch, 69 | kernel_size, 70 | fstride=1, 71 | dilation=1, 72 | fpad=True, 73 | bias=True, 74 | separable=False, 75 | norm_layer=torch.nn.BatchNorm2d, 76 | activation_layer=torch.nn.ReLU): 77 | lookahead = 0 78 | kernel_size = (kernel_size, kernel_size) if isinstance( 79 | kernel_size, int) else kernel_size 80 | if fpad: 81 | fpad_ = kernel_size[1] // 2 82 | else: 83 | fpad_ = 0 84 | 85 | pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) 86 | layers = [] 87 | 88 | if any(x > 0 for x in pad): 89 | layers.append(nn.ConstantPad2d(pad, 0.0)) 90 | groups = math.gcd(in_ch, out_ch) if separable else 1 91 | 92 | if groups == 1: 93 | separable = False 94 | layers.append( 95 | nn.ConvTranspose2d(in_ch, 96 | out_ch, 97 | kernel_size=kernel_size, 98 | padding=(kernel_size[0] - 1, 99 | fpad_ + dilation - 1), 100 | output_padding=(0, fpad_), 101 | stride=(1, fstride), 102 | dilation=(1, dilation), 103 | groups=groups, 104 | bias=bias)) 105 | if separable: 106 | layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False)) 107 | if norm_layer is not None: 108 | layers.append(norm_layer(out_ch)) 109 | if activation_layer is not None: 110 | layers.append(activation_layer()) 111 | super().__init__(*layers) 112 | 113 | 114 | def convkxf(in_ch, 115 | out_ch, 116 | k=1, 117 | f=3, 118 | fstride=2, 119 | lookahead=0, 120 | batch_norm=False, 121 | act=torch.nn.ReLU(inplace=True), 122 | mode="normal", 123 | depthwise=True, 124 | complex_in=False): 125 | bias = batch_norm is False 126 | assert f % 2 == 1 127 | stride = 1 if f == 1 else (1, fstride) 128 | if out_ch is None: 129 | out_ch = in_ch * 2 if mode == "normal" else in_ch // 2 130 | fpad = (f - 1) // 2 131 | convpad = (0, fpad) 132 | 133 | modules = [] 134 | pad = [0, 0, k - 1 - lookahead, lookahead] 135 | if any(p > 0 for p in pad): 136 | modules.append(("pad", nn.ConstantPad2d(pad, 0.0))) 137 | if depthwise: 138 | groups = min(in_ch, out_ch) 139 | else: 140 | groups = 1 141 | if in_ch % groups != 0 or out_ch 
% groups != 0: 142 | groups = 1 143 | if complex_in and groups % 2 == 0: 144 | groups //= 2 145 | 146 | convkwargs = { 147 | "in_channels": in_ch, 148 | "out_channels": out_ch, 149 | "kernel_size": (k, f), 150 | "stride": stride, 151 | "groups": groups, 152 | "bias": bias 153 | } 154 | if mode == "normal": 155 | modules.append(("sconv", nn.Conv2d(padding=convpad, **convkwargs))) 156 | 157 | elif mode == "transposed": 158 | padding = (k - 1, fpad) 159 | modules.append(("sconv", 160 | nn.ConvTranspose2d(padding=padding, 161 | output_padding=convpad, 162 | **convkwargs))) 163 | elif mode == "upsample": 164 | modules.append(("upsample", FreqUpsample(fstride))) 165 | convkwargs["stride"] = 1 166 | modules.append(("sconv", nn.Conv2d(padding=convpad, **convkwargs))) 167 | else: 168 | raise NotImplementedError() 169 | if groups > 1: 170 | modules.append(("1x1conv", nn.Conv2d(out_ch, out_ch, 1, bias=False))) 171 | if batch_norm: 172 | modules.append(("norm", nn.BatchNorm2d(out_ch))) 173 | modules.append(("act", act)) 174 | return nn.Sequential(OrderedDict(modules)) 175 | 176 | 177 | class FreqUpsample(nn.Module): 178 | def __init__(self, factor, mode="nearest"): 179 | super().__init__() 180 | self.f = float(factor) 181 | self.mode = mode 182 | 183 | def forward(self, x): 184 | return F.interpolate(x, scale_factor=(1., self.f), mode=self.mode) 185 | 186 | 187 | def erb_fb_use(width: np.ndarray, 188 | sr: int, 189 | normalized: bool = True, 190 | inverse: bool = False) -> Tensor: 191 | """ 192 | construct freq2erb transform matrix 193 | """ 194 | n_freqs = int(np.sum(width)) 195 | all_freqs = torch.linspace(0, sr // 2, n_freqs + 1)[:-1] 196 | b_pts = np.cumsum([0] + width.tolist()).astype(int)[:-1] 197 | fb = torch.zeros((all_freqs.shape[0], b_pts.shape[0])) 198 | for i, (b, w) in enumerate(zip(b_pts.tolist(), width.tolist())): 199 | fb[b:b + w, i] = 1 200 | if inverse: 201 | fb = fb.t() 202 | if not normalized: 203 | fb /= fb.sum(dim=1, keepdim=True) 204 | else: 205 | if 
normalized: 206 | fb /= fb.sum(dim=0) 207 | return fb 208 | 209 | 210 | def freq2erb(freq_hz): 211 | return 9.265 * torch.log1p(freq_hz / (24.7 * 9.265)) 212 | 213 | 214 | def erb2freq(n_erb): 215 | return 24.7 * 9.265 * (torch.exp(n_erb / 9.265) - 1.) 216 | 217 | 218 | def erb_fb(sr, fft_size, nb_bands, min_nb_freqs): 219 | """ 220 | slice frequency band to erb, which is not overlap 221 | """ 222 | nyq_freq = sr / 2 223 | freq_width = sr / fft_size 224 | erb_low = freq2erb(torch.Tensor([0.])) 225 | erb_high = freq2erb(torch.Tensor([nyq_freq])) 226 | erb = torch.zeros([nb_bands], dtype=torch.int16) 227 | step = (erb_high - erb_low) / nb_bands 228 | prev_freq = 0 229 | freq_over = 0 230 | for i in range(nb_bands): 231 | f = erb2freq(erb_low + (i + 1) * step) 232 | fb = int(torch.round(f / freq_width)) 233 | nb_freqs = fb - prev_freq - freq_over 234 | if nb_freqs < min_nb_freqs: 235 | freq_over = min_nb_freqs - nb_freqs 236 | nb_freqs = min_nb_freqs 237 | else: 238 | freq_over = 0 239 | erb[i] = nb_freqs 240 | prev_freq = fb 241 | erb[nb_bands - 1] += 1 242 | too_large = torch.sum(erb) - (fft_size / 2 + 1) 243 | if too_large > 0: 244 | erb[nb_bands - 1] -= too_large 245 | assert torch.sum(erb) == (fft_size / 2 + 1) 246 | 247 | return erb 248 | 249 | 250 | class GroupedGRULayer(nn.Module): 251 | input_size: Final[int] 252 | hidden_size: Final[int] 253 | out_size: Final[int] 254 | bidirectional: Final[bool] 255 | num_directions: Final[int] 256 | groups: Final[int] 257 | batch_first: Final[bool] 258 | 259 | def __init__(self, 260 | input_size: int, 261 | hidden_size: int, 262 | groups: int, 263 | batch_first: bool = True, 264 | bias=True, 265 | dropout: float = 0, 266 | bidirectional=False): 267 | super().__init__() 268 | 269 | assert input_size % groups == 0 270 | assert hidden_size % groups == 0 271 | kwargs = { 272 | "bias": bias, 273 | "batch_first": batch_first, 274 | "dropout": dropout, 275 | "bidirectional": bidirectional, 276 | } 277 | self.input_size = 
input_size // groups 278 | self.hidden_size = hidden_size // groups 279 | self.out_size = hidden_size 280 | self.bidirectional = bidirectional 281 | self.num_directions = 2 if bidirectional else 1 282 | self.groups = groups 283 | self.batch_first = batch_first 284 | assert (self.hidden_size % 285 | groups) == 0, "Hidden size must be divisible by groups" 286 | self.layers = nn.ModuleList((nn.GRU(self.input_size, self.hidden_size, 287 | **kwargs) for _ in range(groups))) 288 | 289 | def flatten_parameters(self): 290 | for layer in self.layers: 291 | layer.flatten_parameters() 292 | 293 | def get_h0(self, 294 | batch_size: int = 1, 295 | device: torch.device = torch.device("cpu")): 296 | return torch.zeros( 297 | self.groups * self.num_directions, 298 | batch_size, 299 | self.hidden_size, 300 | device=device, 301 | ) 302 | 303 | def forward(self, 304 | input: Tensor, 305 | h0: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: 306 | # input shape: [B, T, I] if batch_first else [T, B, I], B: batch_size, I: input_size 307 | # state shape: [G*D, B, H], where G: groups, D: num_directions, H: hidden_size 308 | 309 | if h0 is None: 310 | dim0, dim1 = input.shape[:2] 311 | bs = dim0 if self.batch_first else dim1 312 | h0 = self.get_h0(bs, device=input.device) 313 | outputs: List[Tensor] = [] 314 | outstates: List[Tensor] = [] 315 | for i, layer in enumerate(self.layers): 316 | o, s = layer( 317 | input[..., i * self.input_size:(i + 1) * self.input_size], 318 | h0[i * self.num_directions:(i + 1) * 319 | self.num_directions].detach(), 320 | ) 321 | outputs.append(o) 322 | outstates.append(s) 323 | output = torch.cat(outputs, dim=-1) 324 | h = torch.cat(outstates, dim=0) 325 | return output, h 326 | 327 | 328 | class GroupGRU(nn.Module): 329 | groups: Final[int] 330 | num_layers: Final[int] 331 | batch_first: Final[bool] 332 | hidden_size: Final[int] 333 | bidirectional: Final[bool] 334 | num_directions: Final[int] 335 | shuffle: Final[int] 336 | add_outputs: Final[bool] 337 | 
338 | def __init__(self, 339 | input_size, 340 | hidden_size, 341 | num_layers=1, 342 | groups=4, 343 | bias=True, 344 | batch_first=True, 345 | dropout=0., 346 | bidirectional=False, 347 | shuffle=True, 348 | add_outputs=False): 349 | super().__init__() 350 | kwargs = { 351 | "groups": groups, 352 | "bias": bias, 353 | "batch_first": batch_first, 354 | "dropout": dropout, 355 | "bidirectional": bidirectional 356 | } 357 | assert input_size % groups == 0 358 | assert hidden_size % groups == 0 359 | assert num_layers > 0 360 | self.input_size = input_size 361 | self.groups = groups 362 | self.num_layers = num_layers 363 | self.batch_first = batch_first 364 | self.hidden_size = hidden_size // groups 365 | self.bidirectional = bidirectional 366 | self.num_directions = 2 if bidirectional else 1 367 | 368 | if self.groups == 1: 369 | shuffle = False 370 | self.shuffle = shuffle 371 | self.add_outputs = add_outputs 372 | self.grus: List[GroupedGRULayer] = nn.ModuleList() 373 | self.grus.append(GroupedGRULayer(input_size, hidden_size, **kwargs)) 374 | for _ in range(1, num_layers): 375 | self.grus.append( 376 | GroupedGRULayer(hidden_size, hidden_size, **kwargs)) 377 | self.flatten_parameters() 378 | 379 | def flatten_parameters(self): 380 | for gru in self.grus: 381 | gru.flatten_parameters() 382 | 383 | def get_h0( 384 | self, 385 | batch_size: int, 386 | ) -> Tensor: 387 | return torch.zeros( 388 | (self.num_layers * self.groups * self.num_directions, batch_size, 389 | self.hidden_size), ) 390 | 391 | def forward(self, 392 | input: Tensor, 393 | state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: 394 | dim0, dim1, _ = input.shape 395 | b = dim0 if self.batch_first else dim1 396 | if state is None: 397 | state = self.get_h0(b, input.device) 398 | output = torch.zeros(dim0, 399 | dim1, 400 | self.hidden_size * self.num_directions * 401 | self.groups, 402 | device=input.device) 403 | outstates = [] 404 | h = self.groups * self.num_directions 405 | for i, gru in 
enumerate(self.grus): 406 | input, s = gru(input, state[i * h:(i + 1) * h]) 407 | outstates.append(s) 408 | if self.shuffle and i < self.num_layers - 1: 409 | input = (input.view(dim0, dim1, -1, self.groups).transpose( 410 | 2, 3).reshape(dim0, dim1, -1)) 411 | if self.add_outputs: 412 | output += input 413 | else: 414 | output = input 415 | outstate = torch.cat(outstates, dim=0) 416 | return output, outstate 417 | 418 | 419 | class SqueezedGRU(nn.Module): 420 | input_size: Final[int] 421 | hidden_size: Final[int] 422 | 423 | def __init__( 424 | self, 425 | input_size: int, 426 | hidden_size: int, 427 | output_size: Optional[int] = None, 428 | num_layers: int = 1, 429 | linear_groups: int = 8, 430 | batch_first: bool = True, 431 | gru_skip_op: Optional[Callable[..., torch.nn.Module]] = None, 432 | linear_act_layer: Callable[..., torch.nn.Module] = nn.Identity, 433 | ): 434 | super().__init__() 435 | self.input_size = input_size 436 | self.hidden_size = hidden_size 437 | self.linear_in = nn.Sequential( 438 | GroupedLinearEinsum(input_size, hidden_size, linear_groups), 439 | linear_act_layer()) 440 | self.gru = nn.GRU(hidden_size, 441 | hidden_size, 442 | num_layers=num_layers, 443 | batch_first=batch_first) 444 | self.gru_skip = gru_skip_op() if gru_skip_op is not None else None 445 | if output_size is not None: 446 | self.linear_out = nn.Sequential( 447 | GroupedLinearEinsum(hidden_size, output_size, linear_groups), 448 | linear_act_layer()) 449 | else: 450 | self.linear_out = nn.Identity() 451 | 452 | def forward(self, input: Tensor, h=None) -> Tuple[Tensor, Tensor]: 453 | input = self.linear_in(input) 454 | x, h = self.gru(input, h) 455 | if self.gru_skip is not None: 456 | x = x + self.gru_skip(input) 457 | x = self.linear_out(x) 458 | return x, h 459 | 460 | 461 | class SqueezedGRU_S(nn.Module): 462 | input_size: Final[int] 463 | hidden_size: Final[int] 464 | 465 | def __init__( 466 | self, 467 | input_size: int, 468 | hidden_size: int, 469 | output_size: 
Optional[int] = None, 470 | num_layers: int = 1, 471 | linear_groups: int = 8, 472 | batch_first: bool = True, 473 | gru_skip_op: Optional[Callable[..., torch.nn.Module]] = None, 474 | linear_act_layer: Callable[..., torch.nn.Module] = nn.Identity, 475 | ): 476 | super().__init__() 477 | self.input_size = input_size 478 | self.hidden_size = hidden_size 479 | self.linear_in = nn.Sequential( 480 | GroupedLinearEinsum(input_size, hidden_size, linear_groups), 481 | linear_act_layer()) 482 | self.gru = nn.GRU(hidden_size, 483 | hidden_size, 484 | num_layers=num_layers, 485 | batch_first=batch_first) 486 | self.gru_skip = gru_skip_op() if gru_skip_op is not None else None 487 | if output_size is not None: 488 | self.linear_out = nn.Sequential( 489 | GroupedLinearEinsum(hidden_size, output_size, linear_groups), 490 | linear_act_layer()) 491 | else: 492 | self.linear_out = nn.Identity() 493 | 494 | def forward(self, input: Tensor, h=None) -> Tuple[Tensor, Tensor]: 495 | x = self.linear_in(input) 496 | x, h = self.gru(x, h) 497 | x = self.linear_out(x) 498 | if self.gru_skip is not None: 499 | x = x + self.gru_skip(input) 500 | return x, h 501 | 502 | 503 | class GroupedLinearEinsum(nn.Module): 504 | input_size: Final[int] 505 | hidden_size: Final[int] 506 | groups: Final[int] 507 | 508 | def __init__(self, input_size: int, hidden_size: int, groups: int = 1): 509 | super().__init__() 510 | # self.weight: Tensor 511 | self.input_size = input_size 512 | self.hidden_size = hidden_size 513 | self.groups = groups 514 | assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}" 515 | assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}" 516 | self.ws = input_size // groups 517 | self.register_parameter( 518 | "weight", 519 | Parameter(torch.zeros(groups, input_size // groups, 520 | hidden_size // groups), 521 | requires_grad=True), 522 | ) 523 | self.reset_parameters() 524 | 525 | def reset_parameters(self): 526 | 
init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore 527 | 528 | def forward(self, x: Tensor) -> Tensor: 529 | # x: [..., I] 530 | b, t, _ = x.shape 531 | # new_shape = list(x.shape)[:-1] + [self.groups, self.ws] 532 | new_shape = (b, t, self.groups, self.ws) 533 | x = x.view(new_shape) 534 | # The better way, but not supported by torchscript 535 | # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] 536 | x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G] 537 | x = x.flatten(2, 3) # [B, T, H] 538 | return x 539 | 540 | def __repr__(self): 541 | cls = self.__class__.__name__ 542 | return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" 543 | 544 | 545 | class GroupedLinear(nn.Module): 546 | input_size: Final[int] 547 | hidden_size: Final[int] 548 | groups: Final[int] 549 | shuffle: Final[bool] 550 | 551 | def __init__(self, 552 | input_size: int, 553 | hidden_size: int, 554 | groups: int = 1, 555 | shuffle: bool = True): 556 | super().__init__() 557 | assert input_size % groups == 0 558 | assert hidden_size % groups == 0 559 | self.groups = groups 560 | self.input_size = input_size // groups 561 | self.hidden_size = hidden_size // groups 562 | if groups == 1: 563 | shuffle = False 564 | self.shuffle = shuffle 565 | self.layers = nn.ModuleList( 566 | nn.Linear(self.input_size, self.hidden_size) 567 | for _ in range(groups)) 568 | 569 | def forward(self, x: Tensor) -> Tensor: 570 | outputs: List[Tensor] = [] 571 | for i, layer in enumerate(self.layers): 572 | outputs.append( 573 | layer(x[..., i * self.input_size:(i + 1) * self.input_size])) 574 | output = torch.cat(outputs, dim=-1) 575 | if self.shuffle: 576 | orig_shape = output.shape 577 | output = (output.view(-1, self.hidden_size, self.groups).transpose( 578 | -1, -2).reshape(orig_shape)) 579 | return output 580 | 581 | 582 | def test_grouped_gru(): 583 | 584 | g = 2 585 | h = 4 586 | i = 2 587 | b = 1 588 | t = 5 589 | 590 | m 
= GroupedGRULayer(i, h, g, batch_first=True) 591 | 592 | input = torch.randn((b, t, i)) 593 | h0 = m.get_h0(b) 594 | assert list(h0.shape) == [g, b, h // g] 595 | out, hout = m(input, h0) 596 | 597 | num = 2 598 | m1 = GroupGRU(i, h, num, g, batch_first=True, shuffle=True) 599 | h0 = m1.get_h0(b) 600 | out1, hout1 = m1(input, h0) 601 | 602 | 603 | def test_erb(): 604 | from matplotlib import pyplot as plt 605 | import colorful 606 | colortool = colorful 607 | colortool.use_style("solarized") 608 | 609 | sr = 48000 610 | fft_size = 960 611 | nb_bands = 32 612 | min_nb_freqs = 2 613 | erb = erb_fb(sr, fft_size, nb_bands, min_nb_freqs) 614 | erb = erb.numpy().astype(int) 615 | fb = erb_fb_use(erb, sr, normalized=True) 616 | fb_inverse = erb_fb_use(erb, sr, normalized=True, inverse=True) 617 | print(colortool.red(f"fb:{fb.shape} {fb}")) 618 | print(colortool.yellow(f"fb_inverse:{fb_inverse.shape}, {fb_inverse}")) 619 | n_freqs = fft_size // 2 + 1 620 | input = torch.randn((1, 1, 1, n_freqs), dtype=torch.complex64) 621 | input_abs = input.abs().square() 622 | # erb_widths = erb 623 | py_erb = torch.matmul(input_abs, fb) 624 | 625 | py_out = torch.matmul(py_erb, fb_inverse) 626 | print(f"py_out:{torch.allclose(input_abs, py_out)}" 627 | ) # todo[okrio]: erb transform is not equal inverse erb transform 628 | 629 | print(colortool.red(f"py_out:{py_out.shape} ")) 630 | print(colortool) 631 | plt.figure() 632 | plt.plot(fb) 633 | plt.figure() 634 | plt.plot(fb_inverse.transpose(1, 0)) 635 | print('sc') 636 | 637 | 638 | if __name__ == "__main__": 639 | test_erb() 640 | -------------------------------------------------------------------------------- /model/cruse.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Okrio/CRUSE/bcd5607953e0f76bac82a5cd322af4e0d3d23342/model/cruse.py -------------------------------------------------------------------------------- /model/cruse_net.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | Author: your name 3 | Date: 2022-02-13 19:35:14 4 | LastEditTime: 2022-02-13 23:28:52 5 | LastEditors: Please set LastEditors 6 | Description: In User Settings Edit 7 | FilePath: /CRUSE/model/cruse_net.py 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class GGRU(nn.Module): 15 | def __init__(self, 16 | in_features=None, 17 | out_features=None, 18 | mid_features=None, 19 | hidden_size=1024, 20 | groups=2): 21 | super(GGRU, self).__init__() 22 | hidden_size_t = hidden_size // groups 23 | self.gru_list1 = nn.ModuleList([ 24 | nn.GRU(hidden_size_t, hidden_size_t, 1, batch_first=True) 25 | for i in range(groups) 26 | ]) 27 | 28 | self.gru_list2 = nn.ModuleList([ 29 | nn.GRU(hidden_size_t, hidden_size_t, 1, batch_first=True) 30 | for i in range(groups) 31 | ]) 32 | self.ln1 = nn.LayerNorm(hidden_size) 33 | self.ln2 = nn.LayerNorm(hidden_size) 34 | self.groups = groups 35 | self.mid_features = mid_features 36 | 37 | def forward(self, x): 38 | out = x 39 | out = out.transpose(1, 2).contiguous() 40 | out = out.view(out.size(0), out.size(1), -1).contiguous() 41 | 42 | out = torch.chunk(out, self.groups, dim=-1) 43 | out = torch.stack( 44 | [self.gru_list1[i](out[i])[0] for i in range(self.groups)], dim=-1) 45 | out = torch.flatten(out, start_dim=-2, end_dim=-1) 46 | out = self.ln1(out) 47 | 48 | out = torch.chunk(out, self.groups, dim=-1) 49 | out = torch.cat( 50 | [self.gru_list2[i](out[i])[0] for i in range(self.groups)], dim=-1) 51 | out = self.ln2(out) 52 | 53 | out = self.view(out.size(0), out.size(1), x.size(1), -1).contiguous() 54 | out = out.transpose(1, 2).contiguous() 55 | return out 56 | 57 | 58 | class unet1(nn.Modules): 59 | def __init__( 60 | self, 61 | in_features, 62 | in_channels, 63 | out_channels, 64 | kernel_size, 65 | stride, 66 | padding, 67 | ) -> None: 68 | super(unet1, self).__init__() 69 | 
self.in_features = in_features 70 | self.in_channels = in_channels 71 | self.out_channels = out_channels 72 | self.stride = stride 73 | self.kernel_size = kernel_size 74 | self.padding = padding 75 | 76 | self.conv1 = nn.Conv2d( 77 | in_channels=in_channels, 78 | out_channels=out_channels, 79 | stride=stride, 80 | kernel_size=kernel_size, 81 | ) 82 | 83 | self.conv2 = nn.Conv2d( 84 | in_channels=in_channels, 85 | out_channels=out_channels, 86 | kernel_size=kernel_size, 87 | stride=stride, 88 | ) 89 | 90 | self.conv3 = nn.Conv2d(in_channels=in_channels, 91 | out_channels=out_channels, 92 | kernel_size=kernel_size, 93 | stride=stride) 94 | 95 | self.gru = GGRU(groups=2) 96 | 97 | self.conv2_t_1 = nn.ConvTranspose2d(in_channels=in_channels, 98 | out_channels=out_channels, 99 | kernel_size=kernel_size, 100 | stride=stride) 101 | 102 | self.conv1_t_1 = nn.ConvTranspose2d(in_channels=in_channels, 103 | out_channels=out_channels, 104 | kernel_size=kernel_size, 105 | stride=stride) 106 | self.fc = nn.Linear(in_features=in_features, out_features=in_features) 107 | 108 | self.bn1 = nn.BatchNorm2d(in_channels) 109 | self.bn2 = nn.BatchNorm2d(in_channels) 110 | self.bn3 = nn.BatchNorm2d(in_channels) 111 | self.bn2_t_1 = nn.BatchNorm2d(in_channels) 112 | self.bn1_t_1 = nn.BatchNorm2d(in_channels) 113 | self.elu = nn.ELU(inplace=True) 114 | 115 | def forward(self, x): 116 | out = x 117 | e1 = self.elu(self.bn1(self.conv1(out))) 118 | e2 = self.elu(self.bn2(self.conv2(e1))) 119 | e3 = self.elu(self.bn3(self.conv3(e2))) 120 | 121 | out = e3 122 | out = self.gru(out) 123 | out1 = self.elu(self.bn2_t_1(self.conv2_t_1(out))) 124 | out1 = self.elu(self.bn1_t_1(self.conv2_t_1(out1))) 125 | out1 = self.fc(out1) 126 | 127 | return out 128 | 129 | class unet_2(nn.Module): 130 | def __init__(self,in_feat=161, ch=(1,8,16,32,64), stride=(1,2), rnn_groups=4): 131 | super(unet_2,self).__init__() 132 | self.laynum = len(ch) -1 133 | hidden_size = (in_feat//2**self.laynum*ch[-1]) 134 | 
self.ker_x = 2 135 | self.stride = stride 136 | self.padding = [self.ker_x -stride[0], 3 - stride[1]] 137 | for i in range(len(ch)-1): 138 | setattr(self,"conv{}".format(i+1), nn.Conv2d(ch[i], ch[i+1],(self.ker_x,3), self.stride, self.padding)) 139 | tmp = len(ch) - 1 -i 140 | setattr(self,"conv{}".format(tmp), nn.Conv2d(ch[tmp-1], ch[i+1],(1,3), self.stride)) 141 | setattr(self,"bn{}".format(i+1), nn.BatchNorm2d(ch[i+1])) 142 | setattr(self,"bn{}_t".format(i+1), nn.BatchNorm2d(ch[tmp-1])) 143 | setattr(self,"skip_connect_{}".format(i+1), nn.Conv2d(ch[i+1], ch[i+1],(1,3),bias=False)) 144 | self.gru = GGRU(groups=rnn_groups) 145 | self.elu = nn.ReLU() 146 | self.fc = nn.Linear(in_feat, in_feat) 147 | def forward(self,x): 148 | out =x 149 | e1_tmp = self.elu(self.bn1(self.conv1(out)[...,-self.padding[0],:])) 150 | e2 = self.elu(self.bn2(self.conv2(e1_tmp)[...,-self.padding[0],:])) 151 | e3 = self.elu(self.bn2(self.conv2(e2)[...,-self.padding[0],:])) 152 | e4 = self.elu(self.bn2(self.conv2(e3)[...,-self.padding[0],:])) 153 | skip1 = self.skip_connect_1(e1_tmp) 154 | skip2 = self.skip_connect_1(e2) 155 | ski3 = self.skip_connect_1(e3) 156 | skip4 = self.skip_connect_1(e4) 157 | b,c,t,f = e3.size() 158 | out_gru = self.gru(e4) 159 | out2 = out_gru 160 | out = out2 + skip4 161 | d4_1 = self.elu(self.bn4_t(self.conv4_t(out)[...,:-1]))+skip3 162 | d3_1 = self.elu(self.bn3_t(self.conv3_t(out)[...,:-1]))+skip2 163 | d2_1 = self.elu(self.bn2_t(self.conv2_t(out)[...,:-1]))+skip1 164 | d1_1 = nn.Sigmoid()(self.conv1_t(out)[...,:-1]) 165 | return d1_1 166 | 167 | 168 | if __name__ == "__main__": 169 | x = torch.randn((2, 1, 10, 161)) 170 | net = unet1() 171 | y = net(x) 172 | print(f"{x.shape}->{y.shape}") 173 | -------------------------------------------------------------------------------- /model/deep_filter.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Okrio 3 | Date: 2022-02-24 22:29:22 4 | LastEditTime: 2022-02-24 
23:10:17 5 | LastEditors: Please set LastEditors 6 | Description: deep filter module 7 | FilePath: /CRUSE/model/deep_filter.py 8 | ''' 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class DeepFilter(nn.Module): 16 | def __init__(self, t_dim, f_dim): 17 | super(DeepFilter, self).__init__() 18 | self.t_dim = t_dim 19 | self.f_dim = f_dim 20 | t_width = t_dim * 2 + 1 21 | f_width = f_dim * 2 + 1 22 | kernel = torch.eye(t_width * f_width) 23 | # self.kernel = kernel 24 | self.register_buffer( 25 | 'kernel', 26 | torch.reshape(kernel, [t_width(f_width, 1, f_width, t_width)])) 27 | 28 | def forward(self, inputs, filters): 29 | chunked_inputs = F.conv2d(torch.cat(inputs, 0)[:, None], 30 | self.kernel, 31 | padding=[self.f_dim, self.t_dim]) 32 | inputs_r, inputs_i = torch.chunk(chunked_inputs, 2, 0) 33 | chunked_filters = F.conv2d(torch.cat(filters, 0)[:, None], 34 | self.kernel, 35 | padding=[self.f_dim, self.t_dim]) 36 | filters_r, filters_i = torch.chunk(chunked_filters, 2, 0) 37 | outputs_r = inputs_r * filters_r - inputs_i * filters_i 38 | output_i = inputs_r * filters_i + inputs_r * filters_i 39 | outputs_r = torch.sum(outputs_r, 1) 40 | output_i = torch.sum(output_i, 1) 41 | return torch.cat([outputs_r, output_i], dim=1) 42 | 43 | 44 | if __name__ == "__main__": 45 | inputs = [torch.randn(10, 256, 100), torch.randn(10, 256, 100)] 46 | mask = [torch.randn(10, 256, 100), torch.randn(10, 256, 100)] 47 | net = DeepFilter(1, 5) 48 | outputs = net(inputs, mask) 49 | print(outputs.shape) 50 | print('sc') 51 | -------------------------------------------------------------------------------- /model/dfsmn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Okrio 3 | Date: 2022-02-24 22:06:47 4 | LastEditTime: 2022-02-24 23:10:38 5 | LastEditors: Okrio 6 | Description: dfsmn module 7 | FilePath: /CRUSE/model/dfsmn.py 8 | ''' 9 | 10 | import torch 11 | import torch.nn as nn 
12 | import torch.nn.functional as F 13 | 14 | 15 | class DFSMN(nn.Module): 16 | def __init__(self, 17 | input_dim, 18 | hidden_dim, 19 | output_dim, 20 | left_frames=1, 21 | left_dilation=1, 22 | right_frames=1, 23 | right_dilation=1): 24 | super(DFSMN, self).__init__() 25 | self.left_frames = left_frames 26 | self.right_frames = right_frames 27 | self.in_conv = nn.Conv1d(input_dim, hidden_dim, kernel_size=1) 28 | # self.norm=nn.InstanceNorm1d(hidden_dim) 29 | # nn.init.normal_(self.in_conv.weight.data, std=0.05) 30 | if left_frames > 0: 31 | self.left_conv = nn.Sequential( 32 | nn.ConstantPad1d([left_dilation * left_frames, 0], 0), 33 | nn.Conv1d(hidden_dim, 34 | hidden_dim, 35 | kernel_size=left_frames + 1, 36 | dilation=left_dilation, 37 | bias=False, 38 | groups=hidden_dim)) 39 | # nn.init.normal_(self.left_conv[1].weight.data,std=0.05) 40 | if right_frames > 0: 41 | self.right_conv = nn.Sequential( 42 | nn.ConstantPad1d( 43 | [-right_dilation, right_frames * right_dilation], 0), 44 | nn.Conv1d(hidden_dim, 45 | hidden_dim, 46 | kernel_size=right_frames, 47 | dilation=right_dilation, 48 | bias=False, 49 | groups=hidden_dim)) 50 | # nn.init.normal_(self.right_conv[1].weight.data,std=0.05) 51 | self.out_conv = nn.Conv1d(hidden_dim, output_dim, kernel_size=1) 52 | # nn.init.normal_(self.out_conv.weight.data,std=0.05) 53 | self.weight = nn.Parameter(torch.Tensor([0]), requires_grad=True) 54 | 55 | def forward(self, inputs, hidden=None): 56 | out = self.in_conv(inputs) 57 | # out = F.relu(out) 58 | # out = self.norm(out) 59 | if self.left_frames > 0: 60 | left = self.left_conv(out) 61 | else: 62 | left = 0 63 | if self.right_frames > 0: 64 | right = self.right_conv(out) 65 | else: 66 | right = 0 67 | out_p = out + left + right 68 | if hidden is not None: 69 | out_p = hidden + F.relu(out_p) * self.weight 70 | out = self.out_conv(out_p) 71 | return out, out_p 72 | 73 | 74 | if __name__ == "__main__": 75 | inputs = torch.randn(10, 257, 199) 76 | net = DFSMN(257, 128, 
137, left_dilation=2, right_dilation=3) 77 | print(net(inputs)[0].shape) 78 | print('sc') 79 | -------------------------------------------------------------------------------- /model/mtfaa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from spafe.fbank import linear_fbanks 5 | import einops 6 | 7 | 8 | class STFT(nn.Module): 9 | def __init__(self, win_len, hop_len, fft_len, win_type) -> None: 10 | super(STFT, self).__init__() 11 | self.win, self.hop = win_len, hop_len 12 | self.nfft = fft_len 13 | window = { 14 | "hann": torch.hann_window(win_len), 15 | "hamm": torch.hamming_window(win_len), 16 | } 17 | assert win_type in window.keys() 18 | self.window = window[win_type] 19 | 20 | def transform(self, inp): 21 | cspec = torch.stft(inp, 22 | self.nfft, 23 | self.hop, 24 | self.win, 25 | self.window.to(inp.device), 26 | return_complex=False) 27 | cspec = einops.rearrange(cspec, "b f t c -> b c f t") 28 | return cspec 29 | 30 | def inverse(self, real, imag): 31 | """ 32 | real, imag:BFT 33 | """ 34 | inp = torch.stack([real, imag], dim=-1) 35 | return torch.istft(inp, self.nfft, self.hop, 36 | self.window.to(real.device)) 37 | 38 | 39 | class ComplexConv2d(nn.Module): 40 | def __init__(self, 41 | in_channels, 42 | out_channels, 43 | kernel_size=(1, 1), 44 | stride=(1, 1), 45 | padding=(0, 0), 46 | dilation=1, 47 | groups=1, 48 | casual=True, 49 | complex_axis=1): 50 | super(ComplexConv2d, self).__init__() 51 | self.in_channels = in_channels // 2 52 | self.out_channels = out_channels // 2 53 | self.kernel_size = kernel_size 54 | self.stride = stride 55 | self.padding = padding 56 | self.causal = casual 57 | self.groups = groups 58 | self.dilation = dilation 59 | self.complex_axis = complex_axis 60 | 61 | self.real_conv = nn.Conv2d(self.in_channels, 62 | self.out_channels, 63 | kernel_size, 64 | self.stride, 65 | padding=(self.padding[0], 0), 66 | 
dilation=self.dilation, 67 | groups=self.groups) 68 | self.imag_conv = nn.Conv2d(self.in_channels, 69 | self.out_channels, 70 | kernel_size, 71 | self.stride, 72 | padding=(self.padding[0], 0), 73 | dilation=self.dilation, 74 | groups=self.groups) 75 | 76 | nn.init.normal_(self.real_conv.weight.data, std=0.05) 77 | nn.init.normal_(self.imag_conv.weight.data, std=0.05) 78 | nn.init.normal_(self.real_conv.bias, 0.) 79 | nn.init.normal_(self.imag_conv.bias, 0.) 80 | 81 | def forward(self, inputs): 82 | if self.padding[1] != 0 and self.causal: 83 | inputs = F.pad(inputs, [self.padding[1], 0, 0, 0]) # causal 84 | else: 85 | inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0]) 86 | 87 | if self.complex_axis == 0: 88 | real = self.real_conv(inputs) 89 | imag = self.imag_conv(inputs) 90 | real2real, imag2real = torch.chunk(real, 2, self.complex_axis) 91 | real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis) 92 | 93 | else: 94 | if isinstance(inputs, torch.Tensor): 95 | real, imag = torch.chunk(inputs, 2, self.complex_axis) 96 | 97 | real2real = self.real_conv(real, ) 98 | imag2imag = self.imag_conv(imag) 99 | real2imag = self.imag_conv(real) 100 | imag2real = self.real_conv(imag) 101 | 102 | real = real2real - imag2imag 103 | imag = real2imag + imag2real 104 | 105 | out = torch.cat((real, imag), self.complex_axis) 106 | 107 | return out 108 | 109 | 110 | def complex_cat(inps, dim=1): 111 | reals, imags = [], [] 112 | 113 | for inp in inps: 114 | real, imag = inp.chunk(2, dim) 115 | reals.append(real) 116 | imags.append(imag) 117 | 118 | reals = torch.cat(reals, dim) 119 | imags = torch.cat(imags, dim) 120 | return reals, imags 121 | 122 | 123 | class ComplexLinearProjection(nn.Module): 124 | def __init__(self, cin): 125 | super(ComplexLinearProjection, self).__init__() 126 | self.clp = ComplexConv2d(cin, cin) 127 | 128 | def forward(self, real, imag): 129 | """ 130 | real, imag: BCFT 131 | """ 132 | 133 | inputs = torch.cat((real, imag), 1) 134 | 
outputs = self.clp(inputs) 135 | 136 | real, imag = outputs.chunk(2, dim=1) 137 | outputs = torch.sqrt(real**2 + imag**2 + 1e-8) 138 | return outputs 139 | 140 | 141 | class PhaseEncoder(nn.Module): 142 | def __init__(self, cout, n_sig, cin=2, alpha=0.5): 143 | super(PhaseEncoder, self).__init__() 144 | self.complexnn = nn.ModuleList() 145 | 146 | for _ in range(n_sig): 147 | self.complexnn.append( 148 | nn.Sequential(nn.ConstantPad2d((2, 0, 0, 0), 0.0), 149 | ComplexConv2d(cin, cout, (1, 3)))) 150 | self.clp = ComplexLinearProjection(cout * n_sig) 151 | self.alpha = alpha 152 | 153 | def forward(self, cspecs): 154 | """ 155 | cspecs: BCFT 156 | """ 157 | outs = [] 158 | for idx, layer in enumerate(self.complexnn): 159 | outs.append(layer(cspecs[idx])) 160 | real, imag = complex_cat(outs, dim=1) 161 | 162 | amp = self.clp(real, imag) 163 | return amp**self.alpha 164 | 165 | 166 | class TFCM_Block(nn.Module): 167 | def __init__(self, cin=24, K=(3, 3), dila=1, causal=True): 168 | super(TFCM_Block, self).__init__() 169 | self.pconv1 = nn.Sequential(nn.Conv2d(cin, cin, kernel_size=(1, 1)), 170 | nn.BatchNorm2d(cin), nn.PReLU(cin)) 171 | dila_pad = dila * (K[1] - 1) 172 | if causal: 173 | self.dila_conv = nn.Sequential( 174 | nn.ConstantPad2d((dila_pad, 0, 1, 1), 0.0), 175 | nn.Conv2d(cin, cin, K, 1, dilation=(1, dila), groups=cin), 176 | nn.BatchNorm2d(cin), nn.PReLU(cin)) 177 | else: 178 | self.dila_conv = nn.Sequential( 179 | nn.ConstantPad2d((dila_pad // 2, dila_pad // 2, 1, 1), 0, 0), 180 | nn.Conv2d(cin, cin, K, 1, dilation=(1, dila)), 181 | nn.BatchNorm2d(cin), nn.PReLU(cin)) 182 | self.pconv2 = nn.Conv2d(cin, cin, kernel_size=(1, 1)) 183 | self.causal = causal 184 | self.dila_pad = dila_pad 185 | 186 | def forward(self, inps): 187 | """ 188 | inp : BCFT 189 | """ 190 | outs = self.pconv1(inps) 191 | outs = self.dila_conv(outs) 192 | outs = self.pconv2(outs) 193 | return outs + inps 194 | 195 | 196 | class TFCM(nn.Module): 197 | def __init__(self, cin=24, K=(3, 
3), tfcm_layer=6, causal=True) -> None: 198 | super(TFCM).__init__() 199 | self.tfcm = nn.ModuleList() 200 | for idx in range(tfcm_layer): 201 | self.tfcm.append(TFCM_Block(cin, K, 2**idx, causal=causal)) 202 | 203 | def forward(self, inp): 204 | out = inp 205 | for idx in range(len(self.tfcm)): 206 | out = self.tfcm[idx](out) 207 | return out 208 | 209 | 210 | class Banks(nn.Module): 211 | def __init__(self, 212 | nfilters, 213 | nfft, 214 | fs, 215 | low_freq=None, 216 | high_freq=None, 217 | learnable=False) -> None: 218 | super(Banks, self).__init__() 219 | self.nfilters, self.nfft, self.fs = nfilters, nfft, fs 220 | filter = linear_fbanks.linear_filter_banks(nfilts=self.nfilters, 221 | nfft=self.nfft, 222 | low_freq=low_freq, 223 | high_freq=high_freq, 224 | fs=self.fs) 225 | filter = torch.from_numpy(filter).float() 226 | if not learnable: 227 | self.register_buffer('filter', filter * 1.3) 228 | self.register_buffer('filter_inv', torch.pinverse(filter)) 229 | 230 | else: 231 | self.filter = nn.Parameter(filter) 232 | self.filter_inv = nn.Parameter(torch.pinverse(filter)) 233 | 234 | def amp2bank(self, amp): 235 | amp_feature = torch.einsum('bcft,kf->bckt', amp, self.filter) 236 | return amp_feature 237 | 238 | def bank2amp(self, inputs): 239 | return torch.einsum("bckt,kt->bcft", inputs, self.filter_inv) 240 | 241 | 242 | def test_bank(): 243 | import soundfile as sf 244 | import numpy as np 245 | 246 | stft = STFT(32 * 48, 8 * 48, 32 * 48, "hann") 247 | net = Banks(256, 32 * 48, 48000) 248 | 249 | sig_raw, sr = sf.read(".wav") 250 | sig = torch.from_numpy(sig_raw)[None, :].float() 251 | cspec = stft.transform(sig) 252 | 253 | mag = torch.norm(cspec, dim=1) 254 | phase = torch.atan2(cspec[:, 1, :, :], cspec[:, 0, :, :]) 255 | mag = mag.unsqueeze(dim=1) 256 | outs = net.amp2bank(mag) 257 | 258 | outs = net.bank2amp(outs) 259 | print(F.mse_loss(outs, mag)) 260 | 261 | outs = outs.squeeze(dim=1) 262 | real = outs * torch.cos(phase) 263 | imag = outs * 
torch.sin(phase) 264 | 265 | sig_rec = stft.inverse(real, imag) 266 | sig_rec = sig_rec.cpu().data.numpy()[0] 267 | min_len = min(len(sig_rec), len(sig_raw)) 268 | sf.write('rs.wav', np.stack([sig_rec[:min_len], sig_raw[:min_len]], 269 | axis=1), sr) 270 | print(np.mean(np.square(sig_rec[:min_len] - sig_raw[:min_len]))) 271 | 272 | 273 | def test_tfcm(): 274 | nnet = TFCM(24) 275 | inp = torch.randn(2, 24, 256, 101) 276 | out = nnet(inp) 277 | print(out.shape) 278 | 279 | 280 | if __name__ == "__main__": 281 | net = PhaseEncoder(cout=4, n_sig=1) 282 | inps = torch.randn(3, 2, 769, 126) 283 | outs = net([inps]) 284 | print(outs.shape) 285 | print('sc') 286 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: your name 3 | Date: 2022-03-03 23:45:17 4 | LastEditTime: 2022-03-03 23:45:18 5 | LastEditors: Please set LastEditors 6 | Description: 打开koroFileHeader查看配置 进行设置: https://github.com/OBKoro1/koro1FileHeader/wiki/%E9%85%8D%E7%BD%AE 7 | FilePath: /CRUSE/test/__init__.py 8 | ''' 9 | -------------------------------------------------------------------------------- /test/__pycache__/test_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Okrio/CRUSE/bcd5607953e0f76bac82a5cd322af4e0d3d23342/test/__pycache__/test_loss.cpython-38.pyc -------------------------------------------------------------------------------- /test/testBSRNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import matplotlib.pyplot as plt 4 | from joblib import Parallel, delayed 5 | from pesq import pesq 6 | import numpy as np 7 | 8 | class LearnableSigmoid(nn.Module): 9 | def __init__(self, in_features, beta=1.2): 10 | super().__init__() 11 | self.beta = beta 12 | self.slope = 
nn.Parameter(torch.ones(in_features)) 13 | self.slope.requiresGrad = True 14 | 15 | def forward(self, x): 16 | return self.beta * torch.sigmoid(self.slope * x) 17 | 18 | def pesq_loss(clean, noisy, sr=16000): 19 | try: 20 | pesq_score = pesq(sr, clean, noisy, 'wb') 21 | except: 22 | # error can happen due to silent period 23 | pesq_score = -1 24 | return pesq_score 25 | 26 | def batch_pesq(clean, noisy): 27 | pesq_score = Parallel(n_jobs=-1)(delayed(pesq_loss)(c, n) for c, n in zip(clean, noisy)) 28 | pesq_score = np.array(pesq_score) 29 | if -1 in pesq_score: 30 | return None 31 | pesq_score = (pesq_score + 0.5) / 5 32 | return torch.FloatTensor(pesq_score).to('cuda') 33 | 34 | class BSRNN(nn.Module): 35 | def __init__(self, num_channel=128, num_layer=6): 36 | super(BSRNN, self).__init__() 37 | self.num_layer = num_layer 38 | self.band_split = BandSplit(channels=num_channel) 39 | 40 | for i in range(self.num_layer): 41 | setattr(self, 'norm_t{}'.format(i + 1), nn.GroupNorm(1,num_channel)) 42 | setattr(self, 'lstm_t{}'.format(i + 1), nn.LSTM(num_channel,2*num_channel,batch_first=True)) 43 | setattr(self, 'fc_t{}'.format(i + 1), nn.Linear(2*num_channel,num_channel)) 44 | 45 | for i in range(self.num_layer): 46 | setattr(self, 'norm_k{}'.format(i + 1), nn.GroupNorm(1,num_channel)) 47 | setattr(self, 'lstm_k{}'.format(i + 1), nn.LSTM(num_channel,2*num_channel,batch_first=True,bidirectional=True)) 48 | setattr(self, 'fc_k{}'.format(i + 1), nn.Linear(4*num_channel,num_channel)) 49 | 50 | self.mask_decoder = MaskDecoder(channels=num_channel) 51 | 52 | # Perform initialization 53 | for m in self.modules(): 54 | if type(m) in [nn.LSTM]: 55 | for name, param in m.named_parameters(): 56 | if 'weight_ih' in name: 57 | torch.nn.init.xavier_uniform_(param.data) 58 | elif 'weight_hh' in name: 59 | torch.nn.init.orthogonal_(param.data) 60 | elif 'bias' in name: 61 | param.data.fill_(0) 62 | if isinstance(m, torch.nn.Linear): 63 | m.weight.data = 
torch.nn.init.xavier_uniform_(m.weight.data, gain=1.0) 64 | if m.bias is not None: 65 | m.bias.data.zero_() 66 | 67 | def forward(self, x): 68 | x = torch.view_as_real(x) 69 | z = self.band_split(x).transpose(1,2) 70 | 71 | B, N, T, K = z.shape 72 | 73 | skip = z 74 | for i in range(self.num_layer): 75 | out = getattr(self, 'norm_t{}'.format(i + 1))(skip) 76 | out = out.transpose(1,3).reshape(B*K, T, N) 77 | out, _ = getattr(self, 'lstm_t{}'.format(i + 1))(out) 78 | out = getattr(self, 'fc_t{}'.format(i + 1))(out) 79 | out = out.reshape(B, K, T, N).transpose(1,3) 80 | skip = skip + out 81 | 82 | for i in range(self.num_layer): 83 | out = getattr(self, 'norm_k{}'.format(i + 1))(skip) 84 | out = out.permute(0,2,3,1).contiguous().reshape(B*T, K, N) 85 | out, _ = getattr(self, 'lstm_k{}'.format(i + 1))(out) 86 | out = getattr(self, 'fc_k{}'.format(i + 1))(out) 87 | out = out.reshape(B, T, K, N).permute(0,3,1,2).contiguous() 88 | skip = skip + out 89 | 90 | m = self.mask_decoder(skip) 91 | m = torch.view_as_complex(m) 92 | x = torch.view_as_complex(x) 93 | 94 | s = m[:,:,1:-1,0]*x[:,:,:-2]+m[:,:,1:-1,1]*x[:,:,1:-1]+m[:,:,1:-1,2]*x[:,:,2:] 95 | s_f = m[:,:,0,1]*x[:,:,0]+m[:,:,0,2]*x[:,:,1] 96 | s_l = m[:,:,-1,0]*x[:,:,-2]+m[:,:,-1,1]*x[:,:,-1] 97 | s = torch.cat((s_f.unsqueeze(2),s,s_l.unsqueeze(2)),dim=2) 98 | 99 | return s 100 | 101 | class BandSplit(nn.Module): 102 | def __init__(self, channels=128): 103 | super(BandSplit, self).__init__() 104 | self.band = torch.Tensor([ 105 | 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 106 | 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 107 | 16, 16, 16, 16, 16, 16, 16, 17]) 108 | for i in range(len(self.band)): 109 | setattr(self, 'norm{}'.format(i + 1), nn.GroupNorm(1,int(self.band[i]*2))) 110 | setattr(self, 'fc{}'.format(i + 1), nn.Linear(int(self.band[i]*2),channels)) 111 | 112 | def forward(self, x): 113 | hz_band = 0 114 | x = x.transpose(1,2) 115 | for i in range(len(self.band)): 116 | x_band = x[:,:,hz_band:hz_band+int(self.band[i]),:] 117 | 
x_band = torch.reshape(x_band,[x_band.size(0),x_band.size(1),x_band.size(2)*x_band.size(3)]) 118 | out = getattr(self, 'norm{}'.format(i + 1))(x_band.transpose(1,2)) 119 | out = getattr(self, 'fc{}'.format(i + 1))(out.transpose(1,2)) 120 | 121 | if i == 0: 122 | z = out.unsqueeze(3) 123 | else: 124 | z = torch.cat((z,out.unsqueeze(3)),dim=3) 125 | hz_band = hz_band+int(self.band[i]) 126 | return z 127 | 128 | class MaskDecoder(nn.Module): 129 | def __init__(self, channels=128): 130 | super(MaskDecoder, self).__init__() 131 | self.band = torch.Tensor([ 132 | 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 133 | 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 134 | 16, 16, 16, 16, 16, 16, 16, 17]) 135 | for i in range(len(self.band)): 136 | setattr(self, 'norm{}'.format(i + 1), nn.GroupNorm(1,channels)) 137 | setattr(self, 'fc1{}'.format(i + 1), nn.Linear(channels,4*channels)) 138 | setattr(self, 'tanh{}'.format(i + 1), nn.Tanh()) 139 | setattr(self, 'fc2{}'.format(i + 1), nn.Linear(4*channels,int(self.band[i]*12))) 140 | setattr(self, 'glu{}'.format(i + 1), nn.GLU()) 141 | 142 | def forward(self, x): 143 | for i in range(len(self.band)): 144 | x_band = x[:,:,:,i] 145 | out = getattr(self, 'norm{}'.format(i + 1))(x_band) 146 | out = getattr(self, 'fc1{}'.format(i + 1))(out.transpose(1,2)) 147 | out = getattr(self, 'tanh{}'.format(i + 1))(out) 148 | out = getattr(self, 'fc2{}'.format(i + 1))(out) 149 | out = getattr(self, 'glu{}'.format(i + 1))(out) 150 | out = torch.reshape(out,[out.size(0),out.size(1),int(out.size(2)/6), 3, 2]) 151 | if i == 0: 152 | m = out 153 | else: 154 | m = torch.cat((m,out),dim=2) 155 | return m.transpose(1,2) 156 | 157 | class Discriminator(nn.Module): 158 | def __init__(self, ndf, in_channel=2): 159 | super().__init__() 160 | self.layers = nn.Sequential( 161 | nn.utils.spectral_norm(nn.Conv2d(in_channel, ndf, (4,4), (2,2), (1,1), bias=False)), 162 | nn.InstanceNorm2d(ndf, affine=True), 163 | nn.PReLU(ndf), 164 | nn.utils.spectral_norm(nn.Conv2d(ndf, ndf*2, (4,4), 
(2,2), (1,1), bias=False)), 165 | nn.InstanceNorm2d(ndf*2, affine=True), 166 | nn.PReLU(2*ndf), 167 | nn.utils.spectral_norm(nn.Conv2d(ndf*2, ndf*4, (4,4), (2,2), (1,1), bias=False)), 168 | nn.InstanceNorm2d(ndf*4, affine=True), 169 | nn.PReLU(4*ndf), 170 | nn.utils.spectral_norm(nn.Conv2d(ndf*4, ndf*8, (4,4), (2,2), (1,1), bias=False)), 171 | nn.InstanceNorm2d(ndf*8, affine=True), 172 | nn.PReLU(8*ndf), 173 | nn.AdaptiveMaxPool2d(1), 174 | nn.Flatten(), 175 | nn.utils.spectral_norm(nn.Linear(ndf*8, ndf*4)), 176 | nn.Dropout(0.3), 177 | nn.PReLU(4*ndf), 178 | nn.utils.spectral_norm(nn.Linear(ndf*4, 1)), 179 | LearnableSigmoid(1) 180 | ) 181 | 182 | def forward(self, x, y): 183 | xy = torch.cat([x, y], dim=1) 184 | return self.layers(xy) 185 | 186 | 187 | if __name__ == "__main__": 188 | x =torch.randn(2, 2048) 189 | n_fft = 512 190 | hop = 256 191 | y = torch.stft(x, n_fft, hop, window=torch.hann_window(n_fft).to(x.device), 192 | onesided=True,return_complex=True) 193 | net = BSRNN() 194 | out = net(y) 195 | print('sc') -------------------------------------------------------------------------------- /test/testRandSecFilter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | # import os 4 | import numpy as np 5 | from scipy.stat import loguniform 6 | from typing import List 7 | from torch import Tensor 8 | 9 | 10 | def high_shelf(center_freq: Tensor, gain_db: Tensor, q_factor: Tensor, 11 | sr: float): 12 | w0 = Tensor([2. * np.pi * center_freq / sr]) 13 | amp = torch.pow(10, gain_db / 40.) 14 | alpha = torch.sin(w0) / 2. 
/ q_factor 15 | b0 = amp * ((amp + 1) + 16 | (amp - 1) * torch.cos(w0) + 2 * torch.sqrt(amp) * alpha) 17 | b1 = -2 * amp * ((amp - 1) + (amp + 1) * torch.cos(w0)) 18 | b2 = amp * ((amp + 1) + 19 | (amp - 1) * torch.cos(w0) - 2 * torch.sqrt(amp) * alpha) 20 | a0 = (amp + 1) - (amp - 1) * torch.cos(w0) + 2 * torch.sqrt(amp) * alpha 21 | a1 = 2 * ((amp - 1) - (amp + 1) * torch.cos(w0)) 22 | a2 = (amp + 1) - (amp - 1) * torch.cos(w0) - 2 * torch.sqrt(amp) * alpha 23 | 24 | b = torch.cat((b0, b1, b2), dim=-1) 25 | a = torch.cat((a0, a1, a2), -1) 26 | coef = torch.cat((b, a), 0) 27 | return coef 28 | 29 | 30 | def high_pass(center_freq: Tensor, q_factor: Tensor, sr: Tensor): 31 | 32 | w0 = Tensor([2. * np.pi * center_freq / sr]) 33 | alpha = torch.sin(w0) / 2. / q_factor 34 | 35 | b0 = (1 + torch.cos(w0)) / 2. 36 | b1 = -(1 + torch.cos(w0)) 37 | b2 = b0 38 | 39 | a0 = 1 + alpha 40 | a1 = -2 * torch.cos(w0) 41 | a2 = 1 - alpha 42 | 43 | b = torch.cat((b0, b1, b2)) 44 | a = torch.cat((a0, a1, a2)) 45 | coef = torch.stack((b, a), 0) 46 | return coef 47 | 48 | 49 | def low_shelf(center_freq: Tensor, gain_db: Tensor, q_factor: Tensor, 50 | sr: float): 51 | w0 = Tensor([2. * np.pi * center_freq / sr]) 52 | amp = torch.pow(10, gain_db / 40.) 53 | alpha = torch.sin(w0) / 2. / q_factor 54 | 55 | b0 = amp * ((amp + 1) - 56 | (amp - 1) * torch.cos(w0) + 2 * torch.sqrt(amp) * alpha) 57 | b1 = 2 * amp * ((amp - 1) - (amp + 1) * torch.cos(w0)) 58 | b2 = amp * ((amp + 1) - 59 | (amp - 1) * torch.cos(w0) - 2 * torch.sqrt(amp) * alpha) 60 | a0 = (amp + 1) + (amp - 1) * torch.cos(w0) + 2 * torch.sqrt(amp) * alpha 61 | a1 = -2 * ((amp - 1) + (amp + 1) * torch.cos(w0)) 62 | a2 = (amp + 1) + (amp - 1) * torch.cos(w0) - 2 * torch.sqrt(amp) * alpha 63 | 64 | b = torch.cat((b0, b1, b2), -1) 65 | a = torch.cat((a0, a1, a2), -1) 66 | coef = torch.cat((b, a), 0) 67 | return coef 68 | 69 | 70 | def low_pass(center_freq: Tensor, q_factor: Tensor, sr: float): 71 | w0 = Tensor([2. 
* np.pi * center_freq / sr]) 72 | alpha = torch.sin(w0) / 2. / q_factor 73 | 74 | b0 = (1 - torch.cos(w0)) / 2 75 | b1 = 1 - torch.cos(w0) 76 | b2 = b0 77 | 78 | a0 = 1 + alpha 79 | a1 = -2 * torch.cos(w0) 80 | a2 = 1 - alpha 81 | 82 | b = torch.cat((b0, b1, b2)) 83 | a = torch.cat((a0, a1, a2)) 84 | 85 | coef = torch.stack((b, a), 0) 86 | return coef 87 | 88 | 89 | def peaking_eq(center_freq: Tensor, gain_db: Tensor, q_factor: Tensor, 90 | sr: float): 91 | 92 | w0 = Tensor([2. * np.pi * center_freq / sr]) 93 | amp = torch.pow(10, gain_db / 40.) 94 | alpha = torch.sin(w0) / 2. / q_factor 95 | 96 | b0 = 1 + alpha * amp 97 | b1 = -2 * torch.cos(w0) 98 | b2 = 1 - alpha * amp 99 | 100 | a0 = 1 + alpha / amp 101 | a1 = -2 * torch.cos(w0) 102 | a2 = 1 - alpha / amp 103 | 104 | b = torch.cat((b0, b1, b2)) 105 | a = torch.cat((a0, a1, a2)) 106 | coef = torch.stack((b, a), 0) 107 | return coef 108 | 109 | 110 | def notch(center_freq: Tensor, q_factor: Tensor, sr: float): 111 | w0 = Tensor([2. * np.pi * center_freq / sr]) 112 | alpha = torch.sin(w0) / 2. / q_factor 113 | 114 | b0 = Tensor(1.) 115 | b1 = -2. * torch.cos(w0) 116 | b2 = b0 117 | 118 | a0 = 1. + alpha 119 | a1 = -2 * torch.cos(w0) 120 | a2 = 1. - alpha 121 | 122 | b = torch.cat((b0, b1, b2)) 123 | a = torch.cat((a0, a1, a2)) 124 | 125 | coef = torch.stack((b, a), 0) 126 | return coef 127 | 128 | 129 | def randFilt(thr: float = 0.375): 130 | r = torch.FloatTensor(4).uniform_(-thr, thr) 131 | 132 | b = torch.ones([1, 3]) 133 | a = torch.ones_like(b) 134 | b[0, 1:] = r[:2] 135 | a[0, 1:] = r[2:] 136 | coef = torch.cat((b, a), 0) 137 | return coef 138 | 139 | 140 | def randClipping(db_range=None, c_range=(0.01, 0.25), eps=1e-10, eps_c=0.001): 141 | pass 142 | 143 | 144 | def suppress_late(rir: np.ndarray, sr: float, rt60: float, offset: int): 145 | len = rir.shape[-1] 146 | decay = torch.ones(1, len) 147 | dt = 1. 
/ sr 148 | rt60_level = np.power(10., -60 / 20) 149 | tau = -rt60 / np.log10(rt60_level) 150 | if offset >= len: 151 | return rir 152 | for v in range(0, len - offset): 153 | decay[v] = np.exp(-v * dt / tau) 154 | 155 | rir = rir * decay 156 | return rir 157 | 158 | 159 | def trim(rir: np.ndarray, ref_idx: int): 160 | min_db = -80 161 | len = rir.shape[-1] 162 | rir_mono = rir 163 | ref_level = rir_mono[ref_idx] 164 | min_level = np.power(10, (min_db + np.log10(ref_level) * 20.) / 20.) 165 | idx = len 166 | pass 167 | 168 | 169 | def as_windowed(x: torch.Tensor, win_len, hop_len=1, dim=1): 170 | """ 171 | input: B, T 172 | output: B, T//win_len, win_len 173 | """ 174 | shape: List[int] = list(x.shape) 175 | stride: List[int] = list(x.stride()) 176 | shape[dim] = int((shape[dim] - win_len + hop_len) // hop_len) 177 | shape.insert(dim + 1, win_len) 178 | stride.insert(dim + 1, stride[dim]) 179 | stride[dim] = stride[dim] * hop_len 180 | y = x.as_strided(shape, stride) 181 | return y 182 | 183 | 184 | def airAbsorption(sig, sr=16000): 185 | center_freqs = [125, 250, 500, 1000, 2000, 4000, 8000, 16000, 24000] 186 | air_absorption = [0.1, 0.2, 0.5, 1.1, 2.7, 9.4, 29.0, 91.5, 289.0] 187 | air_absorption_table = [x * 1e-3 for x in air_absorption] 188 | distance_low = 1.0 189 | distance_high = 20.0 190 | d = np.random.uniform(distance_low, distance_high, 1) 191 | 192 | atten_vals = np.exp(-d * air_absorption_table) 193 | atten_vals_db = 20 * np.log10(atten_vals) 194 | atten_interp_db = interp_atten(atten_interp_db, 161) 195 | atten_interp = 10**(atten_interp_db / 20.) 
196 | sig_stft = torch.stft(sig, 197 | window=torch.hann_window(320), 198 | n_fft=320, 199 | win_length=320, 200 | hop_length=160, 201 | return_complex=True).squeeze() 202 | att_interp_tile = torch.tiel(atten_interp, 203 | (sig_stft.shape[-1], 1)).tranpose(1, 0) 204 | masked = sig_stft * att_interp_tile 205 | masked = masked.unsqeeze(0) 206 | rc = torch.istft(masked, 207 | window=torch.hann_window(512), 208 | n_fft=320, 209 | win_length=320, 210 | hop_length=160, 211 | length=sig.shape[-1]) 212 | torchaudio.save('ost_air.wav', rc, sr) 213 | 214 | 215 | def interp_atten(atten_vals, n_freqs, center_freqs, sr=16000): 216 | atten_vals1 = atten_vals[0] + atten_vals + atten_vals[-1] 217 | freqs = np.linspace(0., sr / 2., n_freqs) 218 | atten_vals_interp = np.zeros(n_freqs) 219 | 220 | center_freqs = [0] + center_freqs + [sr / 2] 221 | i = 0 222 | center_freqs_win = as_windowed(Tensor([center_freqs]), 2, 1).squeeze() 223 | atten_vals_win = as_windowed(Tensor([atten_vals1]), 2, 1).squeeze() 224 | 225 | for k, (c, a) in enumerate( 226 | zip(center_freqs_win.tolist(), atten_vals_win.tolist())): 227 | c0, c1 = c[0], c[1] 228 | a0, a1 = a[0], a[1] 229 | while i < n_freqs and freqs[i] <= c1: 230 | x = (freqs[i] - c1) / (c0 - c1) 231 | atten_vals_interp[i] = a0 * x + a1 * (1. 
- x) 232 | i += 1 233 | return atten_vals_interp 234 | 235 | 236 | if __name__ == "__main__": 237 | psthq = "ssf.wav" 238 | sig, sr = torchaudio.load(psthq) 239 | 240 | sr = 16000 241 | gain_db = torch.FloatTensor(1).uniform_(-15, 15) 242 | q_factor = torch.FloatTensor(1).uniform_(0.5, 1.5) 243 | # high_shelf 244 | center_freq1 = loguniform.rvs(1000, 6000, size=1) 245 | 246 | hs_coef = high_shelf(center_freq1, gain_db, q_factor, sr) 247 | out_hs = torchaudio.functional.lfilter(sig, hs_coef[1, :], hs_coef[0, :]) 248 | torchaudio.save('ost.wav', out_hs, sr) 249 | # hig_pass 250 | center_freq1 = loguniform.rvs(40, 400, size=1) 251 | # low_shelf 252 | center_freq1 = loguniform.rvs(40, 1000, size=1) 253 | # low_pass 254 | center_freq1 = loguniform.rvs(3000, 8000, size=1) 255 | # peaking_eq 256 | center_freq1 = loguniform.rvs(40, 4000, size=1) 257 | # notch 258 | center_freq1 = loguniform.rvs(40, 4000, size=1) -------------------------------------------------------------------------------- /test/test_erb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import Tensor 4 | import colorful 5 | 6 | colortool = colorful 7 | colortool.use_style("solarized") 8 | 9 | 10 | def freq2erb(freq_hz): 11 | return 9.265 * torch.log1p(freq_hz / (24.7 * 9.265)) 12 | 13 | 14 | def erb2freq(n_erb): 15 | return 24.7 * 9.265 * (torch.exp(n_erb / 9.265) - 1.) 
16 | 17 | 18 | def erb_fb(sr, fft_size, nb_bands, min_nb_freqs): 19 | """ 20 | slice frequency band to erb, which is not overlap 21 | """ 22 | nyq_freq = sr / 2 23 | freq_width = sr / fft_size 24 | erb_low = freq2erb(torch.Tensor([0.])) 25 | erb_high = freq2erb(torch.Tensor([nyq_freq])) 26 | erb = torch.zeros([nb_bands], dtype=torch.int16) 27 | step = (erb_high - erb_low) / nb_bands 28 | prev_freq = 0 29 | freq_over = 0 30 | for i in range(nb_bands): 31 | f = erb2freq(erb_low + (i + 1) * step) 32 | fb = int(torch.round(f / freq_width)) 33 | nb_freqs = fb - prev_freq - freq_over 34 | if nb_freqs < min_nb_freqs: 35 | freq_over = min_nb_freqs - nb_freqs 36 | nb_freqs = min_nb_freqs 37 | else: 38 | freq_over = 0 39 | erb[i] = nb_freqs 40 | prev_freq = fb 41 | erb[nb_bands - 1] += 1 42 | too_large = torch.sum(erb) - (fft_size / 2 + 1) 43 | if too_large > 0: 44 | erb[nb_bands - 1] -= too_large 45 | assert torch.sum(erb) == (fft_size / 2 + 1) 46 | 47 | return erb 48 | 49 | 50 | def erb_fb_use(width: np.ndarray, 51 | sr: int, 52 | normalized: bool = True, 53 | inverse: bool = False) -> Tensor: 54 | """ 55 | construct freq2erb transform matrix 56 | """ 57 | n_freqs = int(np.sum(width)) 58 | all_freqs = torch.linspace(0, sr // 2, n_freqs + 1)[:-1] 59 | b_pts = np.cumsum([0] + width.tolist()).astype(int)[:-1] 60 | fb = torch.zeros((all_freqs.shape[0], b_pts.shape[0])) 61 | for i, (b, w) in enumerate(zip(b_pts.tolist(), width.tolist())): 62 | fb[b:b + w, i] = 1 63 | if inverse: 64 | fb = fb.t() 65 | if not normalized: 66 | fb /= fb.sum(dim=1, keepdim=True) 67 | else: 68 | if normalized: 69 | fb /= fb.sum(dim=0) 70 | return fb 71 | 72 | 73 | def compute_band_corr(out, x, p, erb_fb): 74 | bcsum = 0 75 | for i, (band_size, out_b) in enumerate(zip(erb_fb, out)): 76 | k = 1. 
/ band_size 77 | for j in range(band_size): 78 | idx = bcsum + j 79 | out_b += (x[idx].real * p[idx].real + 80 | x[idx].imag * p[idx].imag) * k 81 | bcsum += band_size 82 | 83 | 84 | def band_mean_norm_freq(xs, xout, state, alpha): 85 | """ 86 | xs: complex32 87 | others: f32 88 | """ 89 | out_state = torch.zeros_like(state) 90 | out_xout = torch.zeros_like(xout) 91 | for i, (x, s, xo) in enumerate(zip(xs, state, xout)): 92 | xabs = torch.norm(x) 93 | out_state[i] = xabs * (1. - alpha) + s * alpha 94 | out_xout[i] = xabs - out_state[i] 95 | 96 | 97 | def band_mean_norm_erb(xs, state, alpha): 98 | """ 99 | all: f32 100 | """ 101 | out_xs = torch.zeros_like(xs) 102 | out_state = torch.zeros_like(state) 103 | for i, (x, s) in enumerate(zip(xs, state)): 104 | out_state[i] = x * (1. - alpha) + alpha * s 105 | out_xs[i] -= out_state[i] 106 | out_xs[i] /= 40. 107 | 108 | 109 | def band_unit_norm(xs, state, alpha): 110 | """ 111 | xs: complex32 112 | """ 113 | out_xs = torch.zeros_like(xs) 114 | out_state = torch.zeros_like(state) 115 | for i, (x, s) in enumerate(zip(xs, state)): 116 | out_state[i] = torch.norm(x) * (1. 
- alpha + s * alpha) 117 | out_xs[i] /= torch.sqrt(out_state[i]) 118 | 119 | 120 | def apply_interp_band_gain(out, band_e, erb_fb): 121 | bcsum = 0 122 | out_out = torch.zeros_like(out) 123 | for i, (band_size, b) in enumerate(zip(erb_fb, band_e)): 124 | for j in range(band_size): 125 | idx = bcsum + j 126 | out_out[idx] = out[idx] * b 127 | bcsum += band_size 128 | return out_out 129 | 130 | 131 | def interp_band_gain(out, band_e, erb_fb): 132 | bcsum = 0 133 | for i, (band_size, b) in enumerate(zip(erb_fb, band_e)): 134 | for j in range(band_size): 135 | idx = bcsum + j 136 | out[idx] = b 137 | bcsum += band_size 138 | 139 | 140 | def apply_band_gain(out, band_e, erb_fb): 141 | bcsum = 0 142 | out_out = torch.zeros_like(out) 143 | for i, (band_size, b) in enumerate(zip(erb_fb, band_e)): 144 | for j in range(0, band_size): 145 | idx = bcsum + j 146 | out_out[idx] = out[idx] * b # NOTE 147 | bcsum += band_size 148 | return out_out 149 | 150 | 151 | def post_filter(gain): 152 | beta = 0.02 153 | eps = 1e-12 * torch.ones_like(gain) 154 | pi = torch.pi 155 | g_sin = torch.zeros_like(gain) 156 | out_gain = torch.zeros_like(gain) 157 | g_sin = torch.maximum(gain * torch.sin(pi / 2. 
* gain), eps) 158 | out_gain = (1.0 + beta) * gain / (1.0 + beta * torch.pow(gain / g_sin, 2)) 159 | return out_gain 160 | 161 | 162 | def test_erb_fb(): 163 | 164 | sr = 48000 165 | fft_size = 960 166 | nb_bands = 34 167 | min_nb_freqs = 3 # 2 168 | 169 | erb = erb_fb(sr=sr, 170 | fft_size=fft_size, 171 | nb_bands=nb_bands, 172 | min_nb_freqs=min_nb_freqs) 173 | print(colortool.red(f"erb.shape:{erb.shape}")) 174 | print(colortool.yellow("sc")) 175 | 176 | 177 | def test_erb_fb_use(): 178 | sr = 48000 179 | fft_size = 960 180 | nb_bands = 32 181 | min_nb_freqs = 2 182 | erb = erb_fb(sr, fft_size, nb_bands, min_nb_freqs) 183 | erb = erb.numpy().astype(int) 184 | fb = erb_fb_use(erb, sr) 185 | fb_inverse = erb_fb_use(erb, sr, inverse=True) 186 | print(colortool.red(f"fb:{fb.shape} {fb}")) 187 | print(colortool.yellow(f"fb_inverse:{fb_inverse.shape}, {fb_inverse}")) 188 | print('sc') 189 | 190 | 191 | def test_apply_band_gain(): 192 | """ 193 | test erb band gain for data 194 | """ 195 | # import random 196 | sr = 24000 197 | fft_size = 192 198 | n_freqs = fft_size // 2 + 1 199 | nb_bands = 24 200 | min_nb_freqs = 1 201 | erb = erb_fb(sr, fft_size, nb_bands, min_nb_freqs) 202 | # band_e = torch.randint(0, 10, erb.shape) 203 | mask = torch.ones(nb_bands) 204 | mask[3] = 0.3 205 | mask[nb_bands - 1] = 0.5 206 | input_real = torch.rand(n_freqs) 207 | input_imag = torch.rand(n_freqs) 208 | input = torch.complex(input_real, input_imag) 209 | out_out = apply_band_gain(input, mask, erb) 210 | 211 | out_out1 = torch.zeros_like(out_out) 212 | cumsum = 0 213 | for erb_idx, erb_w in enumerate(erb): 214 | for i in range(cumsum, cumsum + erb_w): 215 | out_out1[i] = input[i] * mask[erb_idx] 216 | cumsum += erb_w 217 | 218 | print('sc') 219 | 220 | 221 | if __name__ == "__main__": 222 | test_apply_band_gain() 223 | # test_erb_fb() 224 | # test_erb_fb_use() 225 | a = [1, 4, 5] 226 | b = ['sf', 'fs', 'e'] 227 | for i, (b, w) in enumerate(zip(a, b)): 228 | print(b, w) 229 | print('sc') 
230 | -------------------------------------------------------------------------------- /test/test_loss.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: your name 3 | Date: 2022-03-02 22:19:22 4 | LastEditTime: 2022-03-03 22:58:14 5 | LastEditors: Please set LastEditors 6 | Description: refer to "weighted speech distortion lossed for neural network based real-time speech enhancement" 7 | FilePath: /CRUSE/test/test_loss.py 8 | ''' 9 | import os 10 | 11 | from matplotlib import pyplot as plt 12 | import numpy as np 13 | import librosa as lib 14 | import soundfile as sf 15 | import torch 16 | import torch.nn.functional as F 17 | from typing import Dict, Final, Iterable, List 18 | from torch import Tensor, nn 19 | from torch.autograd import Function 20 | 21 | 22 | def plot_mesh(img, title="", save_home=""): 23 | img = img 24 | fig, ax = plt.subplots() 25 | plt.title(title) 26 | fig.colorbar(plt.pcolormesh(range(img.shape[1]), range(img.shape[0]), img)) 27 | if save_home != "": 28 | print(os.path.join(save_home, "%s.jpg" % title)) 29 | plt.savefig(os.path.join(save_home, "%s.jpg" % title)) 30 | return 31 | 32 | 33 | def snr_weight_loss_adjust(): 34 | global_snr = np.linspace(-10, 30, 410) 35 | beta = np.array([1, 5, 10, 20]) 36 | # global_snr = np.expand_dims(global_snr, axis=0) 37 | # beta = np.expand_dims(beta, axis=-1) 38 | alpha = np.zeros((len(beta), len(global_snr))) 39 | for i in range(len(beta)): 40 | tmp = 10**(global_snr / 10) 41 | beta_tmp = 10**(beta[i] / 10) 42 | alpha[i] = tmp / (tmp + beta_tmp) 43 | plt.xlim(xmin=-10, xmax=30) 44 | plt.ylim(ymin=0., ymax=1.0) 45 | plt.plot(global_snr, alpha.T) 46 | plt.plot(global_snr, np.ones_like(global_snr) * 0.5, 'k:') 47 | plt.xlabel('global snr of the noisy utterence in training(dB)') 48 | plt.ylabel('speech distortion weight(alpha)') 49 | plt.legend(['b=1dB', 'b=5dB', 'b=10dB', 'b=20dB']) 50 | plt.show() 51 | print('sc') 52 | 53 | 54 | def 
speech_distortion_test(): 55 | main_path = '/Users/audio_source/GaGNet/First_DNS_no_reverb' 56 | file_name = 'clnsp1_train_69005_1_snr15_tl-21_fileid_158.wav' 57 | clean_name = os.path.join(main_path, 'no_reverb_clean', 58 | 'clean_fileid_158.wav') 59 | noise_name = os.path.join(main_path, 'no_reverb_mix', file_name) 60 | clean, _ = lib.load(clean_name, sr=16000) 61 | noisy, _ = lib.load(noise_name, sr=16000) 62 | noise = noisy - clean 63 | 64 | win_len = 512 65 | hop_len = 128 66 | n_fft = 512 67 | stft_clean = lib.stft(clean, 68 | win_length=win_len, 69 | hop_length=hop_len, 70 | n_fft=n_fft) 71 | stft_noise = lib.stft(noise, 72 | win_length=win_len, 73 | hop_length=hop_len, 74 | n_fft=n_fft) 75 | stft_noisy = lib.stft(noisy, 76 | win_length=win_len, 77 | hop_length=hop_len, 78 | n_fft=n_fft) 79 | clean_mag, _ = lib.magphase(stft_clean) 80 | noise_mag, _ = lib.magphase(stft_noise) 81 | noisy_mag, noisy_phase = lib.magphase(stft_noisy) 82 | # plot_mesh(np.log(clean_mag), 'clean_mag') 83 | # plot_mesh(np.log(noise_mag), 'noise_mag') 84 | # plt.show() 85 | 86 | snr = clean_mag / (noise_mag + clean_mag) 87 | # snr = clean_mag / noise_mag 88 | # snr = clean_mag / noisy_mag 89 | enhance_noisy = noisy_mag * snr 90 | enhance_noise = noise_mag * snr 91 | enhance_clean = clean_mag * snr 92 | # plot_mesh(np.log(enhance_noisy), 'enhance_noisy') 93 | # plot_mesh(np.log(enhance_noise), 'enhance_noise') 94 | # plot_mesh(np.log(enhance_clean), 'enhance_clean') 95 | # plt.show() 96 | 97 | enhance_noisy_tmp = enhance_noisy * noisy_phase 98 | enhance_noise_tmp = enhance_noise * noisy_phase 99 | enhance_clean_tmp = enhance_clean * noisy_phase 100 | len_noisy = len(noisy) 101 | enhance_noisy_t = lib.istft(enhance_noisy_tmp, 102 | win_length=win_len, 103 | hop_length=hop_len, 104 | length=len_noisy) 105 | enhance_clean_t = lib.istft(enhance_clean_tmp, 106 | win_length=win_len, 107 | hop_length=hop_len, 108 | length=len_noisy) 109 | enhance_noise_t = lib.istft(enhance_noise_tmp, 110 | 
win_length=win_len, 111 | hop_length=hop_len, 112 | length=len_noisy) 113 | sf.write("./enhance_noisy_t.wav", enhance_noisy_t, samplerate=16000) 114 | sf.write("./enhance_clean_t.wav", enhance_clean_t, samplerate=16000) 115 | sf.write('./enhance_noise_t.wav', enhance_noise_t, samplerate=16000) 116 | 117 | print('sc') 118 | 119 | 120 | def wg(S: Tensor, X: Tensor, eps: float = 1e-10): 121 | N = X - S 122 | SS = S.abs().square() 123 | NN = N.abs().square() 124 | return (SS / (SS + NN + eps)).clamp(0, 1) 125 | 126 | 127 | def irm(S: Tensor, X: Tensor, eps: float = 1e-10): 128 | N = X - S 129 | SS_mag = S.abs() 130 | NN_mag = N.abs() 131 | return (SS_mag / (SS_mag + NN_mag + eps)).clamp(0, 1) 132 | 133 | 134 | def iam(S: Tensor, X: Tensor, eps: float = 1e-10): 135 | SS_mag = S.abs() 136 | XX_mag = X.abs() 137 | return (SS_mag / (XX_mag + eps)).clamp(0, 1) 138 | 139 | 140 | class Stft(nn.Module): 141 | def __init__(self, 142 | n_fft: int, 143 | hop: int = None, 144 | window: Tensor = None) -> None: 145 | super().__init__() 146 | self.n_fft = n_fft 147 | self.hop = hop or n_fft // 4 148 | if window is not None: 149 | assert window.shape[0] == n_fft 150 | else: 151 | window = torch.hann_window(self.n_fft) 152 | self.w: torch.Tensor 153 | self.register_buffer("w", window) 154 | 155 | def forward(self, input: Tensor): 156 | t = input.shape[-1] 157 | sh = input.shape[:-1] 158 | out = torch.stft(input.reshape(-1, t), 159 | n_fft=self.n_fft, 160 | hop_length=self.hop, 161 | window=self.w, 162 | normalized=True, 163 | return_complex=True) 164 | out = out.view(*sh, *out.shape[-2:]) 165 | return out 166 | 167 | 168 | class Istft(nn.Module): 169 | def __init__(self, n_fft_inv: int, hop_inv: int, window_inv: Tensor): 170 | super().__init__() 171 | self.n_fft_inv = n_fft_inv 172 | self.hop_inv = hop_inv 173 | # self.window_inv = window_inv 174 | self.w_inv: torch.Tensor 175 | self.register_buffer("w_inv", window_inv) 176 | 177 | def forward(self, input: Tensor): 178 | t, f = 
input.shape[-2:] 179 | sh = input.shape[:-2] 180 | 181 | out = torch.istft(F.pad( 182 | input.reshape(-1, t, f).transpose(1, 2), (0, 1)), 183 | n_fft=self.n_fft_inv, 184 | hop_length=self.hop_inv, 185 | window=self.w_inv, 186 | normalized=True) 187 | if input.ndim > 2: 188 | out = out.view(*sh, out.shape[-1]) 189 | 190 | return out 191 | 192 | 193 | class MultResSpecLoss(nn.Module): 194 | gamma: Final[float] 195 | f: Final[float] 196 | f_complex: Final[List[float]] 197 | 198 | def __init__(self, n_ffts, gamma, factor, f_complex=None): 199 | super().__init__() 200 | self.gamma = gamma 201 | self.f = factor 202 | self.stfts = nn.ModuleDict( 203 | {str(n_fft): Stft(n_fft) 204 | for n_fft in n_ffts}) 205 | if f_complex is None or f_complex == 0: 206 | self.f_complex = None 207 | elif isinstance(f_complex, Iterable): 208 | self.f_complex = list(f_complex) 209 | else: 210 | self.f_complex = [f_complex] * len(self.stfts) 211 | 212 | def forward(self, input: Tensor, target: Tensor): 213 | loss = torch.zeros() 214 | for i, stft in enumerate(self.stfts.values()): 215 | Y = stft(input) 216 | S = stft(target) 217 | Y_abs = Y.abs() 218 | S_abs = S.abs() 219 | if self.gamma != 1: 220 | Y_abs = Y_abs.clamp_min(1e-12).pow(self.gamma) 221 | S_abs = S_abs.clamp_min(1e-12).pow(self.gamma) 222 | loss += F.mse_loss(Y_abs, S_abs) * self.f 223 | if self.f_complex is not None: 224 | if self.gamma != 1: 225 | Y = Y_abs * torch.exp(1j * angle.apply(Y)) 226 | S = S_abs * torch.exp(1j * angle.apply(S)) 227 | loss += F.mse_loss(torch.view_as_real(Y), 228 | torch.view_as_real(S)) * self.f_complex[i] 229 | return loss 230 | 231 | 232 | class angle(Function): 233 | @staticmethod 234 | def forward(ctx, x: Tensor): 235 | ctx.save_for_backend(x) 236 | return torch.atan2(x.imag, x.real) 237 | 238 | @staticmethod 239 | def backward(ctx, grad: Tensor): 240 | (x, ) = ctx.saved_tensors 241 | grad_inv = grad / (x.real.square() + x.imag.square()).clamp_min_(1e-10) 242 | return torch.view_as_complex( 243 | 
torch.stack((-x.imag * grad_inv, x.real * grad_inv), dim=-1)) 244 | 245 | 246 | if __name__ == "__main__": 247 | # snr_weight_loss_adjust() 248 | speech_distortion_test() 249 | -------------------------------------------------------------------------------- /test/test_model.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torchaudio 5 | from torch import nn, Tensor 6 | from typing import Callable, Iterable, List, Optional, Tuple, Union 7 | from torch.nn import functional as F 8 | from torch.nn import init 9 | from torch.nn.parameter import Parameter 10 | from typing_extensions import Final 11 | import math 12 | class GroupedGRULayer(nn.Module): 13 | input_size: Final[int] 14 | hidden_size: Final[int] 15 | out_size: Final[int] 16 | bidirectional: Final[bool] 17 | num_directions: Final[int] 18 | groups: Final[int] 19 | batch_first: Final[bool] 20 | 21 | def __init__( 22 | self, 23 | input_size: int, 24 | hidden_size: int, 25 | groups: int, 26 | batch_first: bool = True, 27 | bias: bool = True, 28 | dropout: float = 0, 29 | bidirectional: bool = False, 30 | ): 31 | super().__init__() 32 | assert input_size % groups == 0 33 | assert hidden_size % groups == 0 34 | kwargs = { 35 | "bias": bias, 36 | "batch_first": batch_first, 37 | "dropout": dropout, 38 | "bidirectional": bidirectional, 39 | } 40 | self.input_size = input_size // groups 41 | self.hidden_size = hidden_size // groups 42 | self.out_size = hidden_size 43 | self.bidirectional = bidirectional 44 | self.num_directions = 2 if bidirectional else 1 45 | self.groups = groups 46 | self.batch_first = batch_first 47 | assert (self.hidden_size % groups) == 0, "Hidden size must be divisible by groups" 48 | self.layers = nn.ModuleList( 49 | (nn.GRU(self.input_size, self.hidden_size, **kwargs) for _ in range(groups)) 50 | ) 51 | 52 | def flatten_parameters(self): 53 | for layer in self.layers: 54 | layer.flatten_parameters() 55 | 56 | def 
get_h0(self, batch_size: int = 1, device: torch.device = torch.device("cpu")): 57 | return torch.zeros( 58 | self.groups * self.num_directions, 59 | batch_size, 60 | self.hidden_size, 61 | device=device, 62 | ) 63 | 64 | def forward(self, input: Tensor, h0: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: 65 | # input shape: [B, T, I] if batch_first else [T, B, I], B: batch_size, I: input_size 66 | # state shape: [G*D, B, H], where G: groups, D: num_directions, H: hidden_size 67 | 68 | if h0 is None: 69 | dim0, dim1 = input.shape[:2] 70 | bs = dim0 if self.batch_first else dim1 71 | h0 = self.get_h0(bs, device=input.device) 72 | outputs: List[Tensor] = [] 73 | outstates: List[Tensor] = [] 74 | for i, layer in enumerate(self.layers): 75 | o, s = layer( 76 | input[..., i * self.input_size : (i + 1) * self.input_size], 77 | h0[i * self.num_directions : (i + 1) * self.num_directions].detach(), 78 | ) 79 | outputs.append(o) 80 | outstates.append(s) 81 | output = torch.cat(outputs, dim=-1) 82 | h = torch.cat(outstates, dim=0) 83 | return output, h 84 | 85 | 86 | class GroupedGRU(nn.Module): 87 | groups: Final[int] 88 | num_layers: Final[int] 89 | batch_first: Final[bool] 90 | hidden_size: Final[int] 91 | bidirectional: Final[bool] 92 | num_directions: Final[int] 93 | shuffle: Final[bool] 94 | add_outputs: Final[bool] 95 | 96 | def __init__( 97 | self, 98 | input_size: int, 99 | hidden_size: int, 100 | num_layers: int = 1, 101 | groups: int = 4, 102 | bias: bool = True, 103 | batch_first: bool = True, 104 | dropout: float = 0, 105 | bidirectional: bool = False, 106 | shuffle: bool = True, 107 | add_outputs: bool = False, 108 | ): 109 | super().__init__() 110 | kwargs = { 111 | "groups": groups, 112 | "bias": bias, 113 | "batch_first": batch_first, 114 | "dropout": dropout, 115 | "bidirectional": bidirectional, 116 | } 117 | assert input_size % groups == 0 118 | assert hidden_size % groups == 0 119 | assert num_layers > 0 120 | self.input_size = input_size 121 | 
self.groups = groups 122 | self.num_layers = num_layers 123 | self.batch_first = batch_first 124 | self.hidden_size = hidden_size // groups 125 | self.bidirectional = bidirectional 126 | self.num_directions = 2 if bidirectional else 1 127 | if groups == 1: 128 | shuffle = False # Fully connected, no need to shuffle 129 | self.shuffle = shuffle 130 | self.add_outputs = add_outputs 131 | self.grus: List[GroupedGRULayer] = nn.ModuleList() # type: ignore 132 | self.grus.append(GroupedGRULayer(input_size, hidden_size, **kwargs)) 133 | for _ in range(1, num_layers): 134 | self.grus.append(GroupedGRULayer(hidden_size, hidden_size, **kwargs)) 135 | self.flatten_parameters() 136 | 137 | def flatten_parameters(self): 138 | for gru in self.grus: 139 | gru.flatten_parameters() 140 | 141 | def get_h0(self, batch_size: int, device: torch.device = torch.device("cpu")) -> Tensor: 142 | return torch.zeros( 143 | (self.num_layers * self.groups * self.num_directions, batch_size, self.hidden_size), 144 | device=device, 145 | ) 146 | 147 | def forward(self, input: Tensor, state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: 148 | dim0, dim1, _ = input.shape 149 | b = dim0 if self.batch_first else dim1 150 | if state is None: 151 | state = self.get_h0(b, input.device) 152 | output = torch.zeros( 153 | dim0, dim1, self.hidden_size * self.num_directions * self.groups, device=input.device 154 | ) 155 | outstates = [] 156 | h = self.groups * self.num_directions 157 | for i, gru in enumerate(self.grus): 158 | input, s = gru(input, state[i * h : (i + 1) * h]) 159 | outstates.append(s) 160 | if self.shuffle and i < self.num_layers - 1: 161 | input = ( 162 | input.view(dim0, dim1, -1, self.groups).transpose(2, 3).reshape(dim0, dim1, -1) 163 | ) 164 | if self.add_outputs: 165 | output += input 166 | else: 167 | output = input 168 | outstate = torch.cat(outstates, dim=0) 169 | return output, outstate 170 | 171 | 172 | class SqueezedGRU(nn.Module): 173 | input_size: Final[int] 174 | 
hidden_size: Final[int] 175 | 176 | def __init__( 177 | self, 178 | input_size: int, 179 | hidden_size: int, 180 | output_size: Optional[int] = None, 181 | num_layers: int = 1, 182 | linear_groups: int = 8, 183 | batch_first: bool = True, 184 | gru_skip_op: Optional[Callable[..., torch.nn.Module]] = None, 185 | linear_act_layer: Callable[..., torch.nn.Module] = nn.Identity, 186 | ): 187 | super().__init__() 188 | self.input_size = input_size 189 | self.hidden_size = hidden_size 190 | self.linear_in = nn.Sequential( 191 | GroupedLinearEinsum(input_size, hidden_size, linear_groups), linear_act_layer() 192 | ) 193 | self.gru = nn.GRU(hidden_size, hidden_size, num_layers=num_layers, batch_first=batch_first) 194 | self.gru_skip = gru_skip_op() if gru_skip_op is not None else None 195 | if output_size is not None: 196 | self.linear_out = nn.Sequential( 197 | GroupedLinearEinsum(hidden_size, output_size, linear_groups), linear_act_layer() 198 | ) 199 | else: 200 | self.linear_out = nn.Identity() 201 | 202 | def forward(self, input: Tensor, h=None) -> Tuple[Tensor, Tensor]: 203 | input = self.linear_in(input) 204 | x, h = self.gru(input, h) 205 | if self.gru_skip is not None: 206 | x = x + self.gru_skip(input) 207 | x = self.linear_out(x) 208 | return x, h 209 | 210 | 211 | class SqueezedGRU_S(nn.Module): 212 | input_size: Final[int] 213 | hidden_size: Final[int] 214 | 215 | def __init__( 216 | self, 217 | input_size: int, 218 | hidden_size: int, 219 | output_size: Optional[int] = None, 220 | num_layers: int = 1, 221 | linear_groups: int = 8, 222 | batch_first: bool = True, 223 | gru_skip_op: Optional[Callable[..., torch.nn.Module]] = None, 224 | linear_act_layer: Callable[..., torch.nn.Module] = nn.Identity, 225 | ): 226 | super().__init__() 227 | self.input_size = input_size 228 | self.hidden_size = hidden_size 229 | self.linear_in = nn.Sequential( 230 | GroupedLinearEinsum(input_size, hidden_size, linear_groups), linear_act_layer() 231 | ) 232 | self.gru = 
nn.GRU(hidden_size, hidden_size, num_layers=num_layers, batch_first=batch_first) 233 | self.gru_skip = gru_skip_op() if gru_skip_op is not None else None 234 | if output_size is not None: 235 | self.linear_out = nn.Sequential( 236 | GroupedLinearEinsum(hidden_size, output_size, linear_groups), linear_act_layer() 237 | ) 238 | else: 239 | self.linear_out = nn.Identity() 240 | 241 | def forward(self, input: Tensor, h=None) -> Tuple[Tensor, Tensor]: 242 | x = self.linear_in(input) 243 | x, h = self.gru(x, h) 244 | x = self.linear_out(x) 245 | if self.gru_skip is not None: 246 | x = x + self.gru_skip(input) 247 | return x, h 248 | 249 | 250 | class GroupedLinearEinsum(nn.Module): 251 | input_size: Final[int] 252 | hidden_size: Final[int] 253 | groups: Final[int] 254 | 255 | def __init__(self, input_size: int, hidden_size: int, groups: int = 1): 256 | super().__init__() 257 | # self.weight: Tensor 258 | self.input_size = input_size 259 | self.hidden_size = hidden_size 260 | self.groups = groups 261 | assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}" 262 | assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}" 263 | self.ws = input_size // groups 264 | self.register_parameter( 265 | "weight", 266 | Parameter( 267 | torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True 268 | ), 269 | ) 270 | self.reset_parameters() 271 | 272 | def reset_parameters(self): 273 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore 274 | 275 | def forward(self, x: Tensor) -> Tensor: 276 | # x: [..., I] 277 | b, t, _ = x.shape 278 | # new_shape = list(x.shape)[:-1] + [self.groups, self.ws] 279 | new_shape = (b, t, self.groups, self.ws) 280 | x = x.view(new_shape) 281 | # The better way, but not supported by torchscript 282 | # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] 283 | x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G] 284 | x = x.flatten(2, 
3) # [B, T, H] 285 | return x 286 | 287 | def __repr__(self): 288 | cls = self.__class__.__name__ 289 | return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" 290 | 291 | 292 | class GroupedLinear(nn.Module): 293 | input_size: Final[int] 294 | hidden_size: Final[int] 295 | groups: Final[int] 296 | shuffle: Final[bool] 297 | 298 | def __init__(self, input_size: int, hidden_size: int, groups: int = 1, shuffle: bool = True): 299 | super().__init__() 300 | assert input_size % groups == 0 301 | assert hidden_size % groups == 0 302 | self.groups = groups 303 | self.input_size = input_size // groups 304 | self.hidden_size = hidden_size // groups 305 | if groups == 1: 306 | shuffle = False 307 | self.shuffle = shuffle 308 | self.layers = nn.ModuleList( 309 | nn.Linear(self.input_size, self.hidden_size) for _ in range(groups) 310 | ) 311 | 312 | def forward(self, x: Tensor) -> Tensor: 313 | outputs: List[Tensor] = [] 314 | for i, layer in enumerate(self.layers): 315 | outputs.append(layer(x[..., i * self.input_size : (i + 1) * self.input_size])) 316 | output = torch.cat(outputs, dim=-1) 317 | if self.shuffle: 318 | orig_shape = output.shape 319 | output = ( 320 | output.view(-1, self.hidden_size, self.groups).transpose(-1, -2).reshape(orig_shape) 321 | ) 322 | return output 323 | 324 | 325 | 326 | def test_grouped_gru(): 327 | from icecream import ic 328 | 329 | g = 2 # groups 330 | h = 4 # hidden_size 331 | i = 2 # input_size 332 | b = 1 # batch_size 333 | t = 5 # time_steps 334 | m = GroupedGRULayer(i, h, g, batch_first=True) 335 | m1 = SqueezedGRU(i,h,i,1,g) 336 | m2 = GroupedGRU(i,h,2,g) 337 | ic(m) 338 | ic(m1) 339 | ic(m2) 340 | input = torch.randn((b, t, i)) 341 | h0 = m.get_h0(b) 342 | assert list(h0.shape) == [g, b, h // g] 343 | out, hout = m(input, h0) 344 | h0_1 = m2.get_h0(b) 345 | h1_1 = m 346 | 347 | # Should be exportable as raw nn.Module 348 | torch.onnx.export( 349 | m, (input, h0), "out/grouped.onnx", 
example_outputs=(out, hout), opset_version=13 350 | ) 351 | # Should be exportable as traced 352 | m = torch.jit.trace(m, (input, h0)) 353 | torch.onnx.export( 354 | m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13 355 | ) 356 | # and as scripted module 357 | m = torch.jit.script(m) 358 | torch.onnx.export( 359 | m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13 360 | ) 361 | 362 | # now grouped gru 363 | num = 2 364 | m = GroupedGRU(i, h, num, g, batch_first=True, shuffle=True) 365 | ic(m) 366 | h0 = m.get_h0(b) 367 | assert list(h0.shape) == [num * g, b, h // g] 368 | out, hout = m(input, h0) 369 | 370 | # Should be exportable as traced 371 | m = torch.jit.trace(m, (input, h0)) 372 | torch.onnx.export( 373 | m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13 374 | ) 375 | # and scripted module 376 | m = torch.jit.script(m) 377 | torch.onnx.export( 378 | m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13 379 | ) 380 | 381 | 382 | if __name__ == "__main__": 383 | test_grouped_gru() 384 | -------------------------------------------------------------------------------- /test/test_norm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | # import numpy as np 4 | # import sys, os 5 | from torch import nn, Tensor 6 | # from torch.nn import functional as F 7 | from typing_extensions import Final 8 | from typing import List 9 | import math 10 | 11 | 12 | def get_norm_alpha(sr: int = 48000, 13 | hop_size: int = 480, 14 | tau: float = 1, 15 | log: bool = True): 16 | 17 | a_ = _calculate_norm_alpha(sr=sr, hop_size=hop_size, tau=tau) 18 | precision = 3 19 | a = 1.0 20 | while a >= 1.0: 21 | a = round(a_, precision) 22 | precision += 1 23 | if log: 24 | print(f"Running with normalization window alpha = '{a}'") 25 | return a 26 | 27 | 28 | def _calculate_norm_alpha(sr: int, hop_size: 
int, tau: float): 29 | dt = hop_size / sr 30 | return math.exp(-dt / tau) 31 | 32 | 33 | class ExponentialUnitNorm(nn.Module): 34 | 35 | alpha: Final[float] 36 | eps: Final[float] 37 | 38 | def __init__(self, alpha: float, num_freq_bins: int, eps: float = 1e-14): 39 | super().__init__() 40 | self.alpha = alpha 41 | self.eps = eps 42 | self.init_state: Tensor 43 | self.UNIT_NORN_INIT = [0.001, 0.0001] 44 | self.unit_norm_init = torch.linspace(self.UNIT_NORN_INIT[0], 45 | self.UNIT_NORN_INIT[1], 46 | num_freq_bins).unsqueeze(0) 47 | s = self.unit_norm_init 48 | s = s.view(1, 1, num_freq_bins, 1) 49 | self.register_buffer("init_state", s) 50 | # s = torch.from_numpy() 51 | 52 | def forward(self, x: Tensor): 53 | b, c, t, f, _ = x.shape 54 | x_abs = x.square().sum(dim=-1, keepdim=True).clamp_min( 55 | self.eps).sqrt() # square sum 56 | state = self.init_state.clone().expand(b, c, f, 1) 57 | out_state: List[Tensor] = [] 58 | for t in range(t): 59 | state = x_abs[:, :, t] * (1 - self.alpha) + state * self.alpha 60 | out_state.append(state) 61 | return x / torch.stack(out_state, 2).sqrt() 62 | 63 | 64 | if __name__ == "__main__": 65 | """ 66 | test norm method 67 | """ 68 | F1 = 96 69 | sr = 48000 70 | hop_size = 480 71 | tau = 1. 72 | alpha = get_norm_alpha(log=False, sr=sr, hop_size=hop_size, tau=tau) 73 | tmp = ExponentialUnitNorm(0.8, 96) 74 | spec = torch.randn(2, 1, 100, F1, 2) 75 | x = tmp(spec) 76 | norm_torch = torch.view_as_complex(x.squeeze(1)) 77 | print(f"norm_torch :{norm_torch}") 78 | print('sc') 79 | -------------------------------------------------------------------------------- /test/test_pqmf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from scipy.signal import kaiser 5 | 6 | 7 | def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0): 8 | """ 9 | Design prototype filter for PQMF. 
10 | This method is based on `A Kaiser window approach for the design of prototype 11 | filters of cosine modulated filterbanks`_. 12 | """ 13 | assert taps % 2 == 0, "The number of taps mush be even number" 14 | assert 0.0 < cutoff_ratio < 1.0, 'Cutoff ratio must be > 0.0 and < 1.0.' 15 | omega_c = np.pi * cutoff_ratio 16 | with np.errstate(invalid='ignore'): 17 | h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / ( 18 | np.pi * (np.arange(taps + 1) - 0.5 * taps)) 19 | h_i[taps // 2] = np.cos(0) * cutoff_ratio 20 | w = kaiser(taps + 1, beta) 21 | h = h_i * w 22 | return h 23 | 24 | 25 | class PQMF(torch.nn.Module): 26 | def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0): 27 | super(PQMF, self).__init__() 28 | h_proto = design_prototype_filter(taps, cutoff_ratio, beta) 29 | h_analysis = np.zeros((subbands, len(h_proto))) 30 | h_synthesis = np.zeros((subbands, len(h_proto))) 31 | for k in range(subbands): 32 | h_analysis[k] = 2 * h_proto * np.cos((2 * k + 1) * 33 | (np.pi / (2 * subbands)) * 34 | (np.arange(taps + 1) - 35 | ((taps - 1) / 2)) + 36 | (-1)**k * np.pi / 4) 37 | h_synthesis[k] = 2 * h_proto * np.cos((2 * k + 1) * 38 | (np.pi / (2 * subbands)) * 39 | (np.arange(taps + 1) - 40 | ((taps - 1) / 2)) - 41 | (-1)**k * np.pi / 4) 42 | 43 | # convert to tensor 44 | analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1) 45 | synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0) 46 | 47 | # register coefficients as beffer 48 | self.register_buffer("analysis_filter", analysis_filter) 49 | self.register_buffer("synthesis_filter", synthesis_filter) 50 | 51 | # filter for downsampling & upsampling 52 | updown_filter = torch.zeros((subbands, subbands, subbands)).float() 53 | for k in range(subbands): 54 | updown_filter[k, k, 0] = 1.0 55 | self.register_buffer("updown_filter", updown_filter) 56 | self.subbands = subbands 57 | 58 | # keep padding info 59 | self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) 60 | 61 
| def analysis(self, x): 62 | """Analysis with PQMF. 63 | Args: 64 | x (Tensor): Input tensor (B, 1, T). 65 | Returns: 66 | Tensor: Output tensor (B, subbands, T // subbands). 67 | """ 68 | x = F.conv1d(self.pad_fn(x), self.analysis_filter) 69 | return F.conv1d(x, self.updown_filter, stride=self.subbands) 70 | 71 | def synthesis(self, x): 72 | """Synthesis with PQMF. 73 | Args: 74 | x (Tensor): Input tensor (B, subbands, T // subbands). 75 | Returns: 76 | Tensor: Output tensor (B, 1, T). 77 | """ 78 | # NOTE(): Power will be dreased so here multipy by # subbands. 79 | # Not sure this is the correct way, it is better to check again. 80 | # TODO(): Understand the reconstruction procedure 81 | x = F.conv_transpose1d(x, 82 | self.updown_filter * self.subbands, 83 | stride=self.subbands) 84 | return F.conv1d(self.pad_fn(x), self.synthesis_filter) 85 | 86 | 87 | def test_pqmf(): 88 | import colorful 89 | import librosa as lib 90 | import soundfile as sf 91 | colortool = colorful 92 | colortool.use_style("solarized") 93 | 94 | sr = 16000 95 | t_len = sr * 2 96 | B = 1 97 | sig = torch.randn(B, 1, t_len) 98 | sig_path = "/Users/okrio/codes/nearend_sparse.wav" 99 | sig, ss = lib.load(sig_path, sr=16000, mono=False) 100 | sig = torch.Tensor(sig) 101 | sig = sig.unsqueeze(0).unsqueeze(0) 102 | 103 | pqmf = PQMF() 104 | ananly_out = pqmf.analysis(sig) 105 | 106 | out = ananly_out.squeeze(0) 107 | out = out.transpose(1, 0) 108 | out = out.numpy() 109 | sf.write("out_pqmf.wav", out, samplerate=16000) 110 | print(colortool.red(f"analy_out:{ananly_out.shape}")) 111 | print('sc') 112 | 113 | 114 | if __name__ == "__main__": 115 | test_pqmf() -------------------------------------------------------------------------------- /tools/train_stand.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import sys 5 | from socket import socket 6 | 7 | import numpy as np 8 | import toml 9 | import torch 10 | 
import torch.distributed as dist 11 | import torch.multiprocessing as mp 12 | from torch.utils.data import DataLoader, DistributedSampler 13 | import train_base.loss as loss 14 | from train_base.utils import initialize_module 15 | from utils.logger import init 16 | 17 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) 18 | 19 | # get free gpu automatically 20 | # import GPUtil 21 | 22 | 23 | def entry(rank, world_size, config, resume, only_validation): 24 | torch.manual_seed(config["meta"]["seed"]) # For both CPU and GPU 25 | np.random.seed(config["meta"]["seed"]) 26 | random.seed(config["meta"]["seed"]) 27 | 28 | os.environ["MASTER_ADDR"] = "localhost" 29 | s = socket() 30 | s.bind(("", 0)) 31 | os.environ["MASTER_PORT"] = "1111" # A random local port 32 | 33 | # Initialize the process group 34 | dist.init_process_group("gloo", rank=rank, world_size=world_size) 35 | 36 | # init log file 37 | if rank == 0: 38 | os.makedirs(os.path.join(config["meta"]["save_dir"]), exist_ok=True) 39 | init(os.path.join(config["meta"]["save_dir"], "train.log"), 40 | "train", 41 | slack_url=None) 42 | 43 | # The DistributedSampler will split the dataset into the several cross-process parts. 44 | # On the contrary, "Sampler=None, shuffle=True", each GPU will get all data in the whole dataset. 
45 | 46 | train_dataset = initialize_module(config["train_dataset"]["path"], 47 | args=config["train_dataset"]["args"]) 48 | sampler = DistributedSampler(dataset=train_dataset, 49 | num_replicas=world_size, 50 | rank=rank, 51 | shuffle=True) 52 | train_dataloader = DataLoader( 53 | dataset=train_dataset, 54 | sampler=sampler, 55 | shuffle=False, 56 | **config["train_dataset"]["dataloader"], 57 | ) 58 | 59 | valid_dataloader = DataLoader(dataset=initialize_module( 60 | config["validation_dataset"]["path"], 61 | args=config["validation_dataset"]["args"]), 62 | num_workers=0, 63 | batch_size=1) 64 | 65 | model = initialize_module(config["model"]["path"], 66 | args=config["model"]["args"]) 67 | 68 | optimizer = torch.optim.Adam(params=model.parameters(), 69 | lr=config["optimizer"]["lr"], 70 | betas=(config["optimizer"]["beta1"], 71 | config["optimizer"]["beta2"])) 72 | 73 | loss_function = getattr( 74 | loss, 75 | config["loss_function"]["name"])(**config["loss_function"]["args"]) 76 | trainer_class = initialize_module(config["trainer"]["path"], 77 | initialize=False) 78 | 79 | trainer = trainer_class(dist=dist, 80 | rank=rank, 81 | config=config, 82 | resume=resume, 83 | only_validation=only_validation, 84 | model=model, 85 | loss_function=loss_function, 86 | optimizer=optimizer, 87 | train_dataloader=train_dataloader, 88 | validation_dataloader=valid_dataloader) 89 | 90 | trainer.train() 91 | 92 | 93 | if __name__ == '__main__': 94 | parser = argparse.ArgumentParser(description="FullSubNet") 95 | parser.add_argument("-C", 96 | "--configuration", 97 | required=True, 98 | type=str, 99 | help="Configuration (*.toml).") 100 | parser.add_argument("-R", 101 | "--resume", 102 | action="store_true", 103 | help="Resume the experiment from latest checkpoint.") 104 | parser.add_argument( 105 | "-V", 106 | "--only_validation", 107 | action="store_true", 108 | help="Only run validation. 
It is used for debugging validation.") 109 | parser.add_argument("-N", 110 | "--num_gpus", 111 | type=int, 112 | default=0, 113 | help="The number of GPUs you are using for training.") 114 | parser.add_argument("-P", 115 | "--preloaded_model_path", 116 | type=str, 117 | help="Path of the *.pth file of a model.") 118 | args = parser.parse_args() 119 | 120 | # set the gpu auto 121 | if args.num_gpus == 0: 122 | device_ids = GPUtil.getAvailable(order='first', 123 | limit=8, 124 | maxLoad=0.5, 125 | maxMemory=0.5, 126 | includeNan=False, 127 | excludeID=[], 128 | excludeUUID=[]) 129 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( 130 | [str(device_id) for device_id in device_ids]) 131 | args.num_gpus = len(device_ids) 132 | print(f"gpus: {os.environ['CUDA_VISIBLE_DEVICES']}") 133 | 134 | if args.preloaded_model_path: 135 | assert not args.resume, "The 'resume' conflicts with the 'preloaded_model_path'." 136 | 137 | configuration = toml.load(args.configuration) 138 | 139 | configuration["meta"]["experiment_name"], _ = os.path.splitext( 140 | os.path.basename(args.configuration)) 141 | configuration["meta"]["config_path"] = args.configuration 142 | configuration["meta"]["preloaded_model_path"] = args.preloaded_model_path 143 | 144 | # Expand python search path to "recipes" 145 | # sys.path.append(os.path.join(os.getcwd(), "..")) 146 | 147 | # One training job is corresponding to one group (world). 148 | # The world size is the number of processes for training, which is usually the number of GPUs you are using for distributed training. 149 | # the rank is the unique ID given to a process. 150 | # Find more information about DistributedDataParallel (DDP) in https://pytorch.org/tutorials/intermediate/ddp_tutorial.html. 
151 | mp.spawn(entry, 152 | args=(args.num_gpus, configuration, args.resume, 153 | args.only_validation), 154 | nprocs=args.num_gpus, 155 | join=True) 156 | -------------------------------------------------------------------------------- /train/trainer_casual.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Okrio/CRUSE/bcd5607953e0f76bac82a5cd322af4e0d3d23342/train/trainer_casual.py -------------------------------------------------------------------------------- /train_base/acoustics/audioAug.py: -------------------------------------------------------------------------------- 1 | from itertools import filterfalse 2 | import torch 3 | import torchaudio 4 | # import os 5 | import numpy as np 6 | import scipy.stats as ss 7 | # from scipy.stat import loguniform 8 | from typing import List 9 | from torch import Tensor 10 | import random 11 | 12 | 13 | def high_shelf(center_freq: Tensor, gain_db: Tensor, q_factor: Tensor, 14 | sr: float): 15 | w0 = Tensor([2. * np.pi * center_freq / sr]) 16 | amp = torch.pow(10, gain_db / 40.) 17 | alpha = torch.sin(w0) / 2. / q_factor 18 | b0 = amp * ((amp + 1) + 19 | (amp - 1) * torch.cos(w0) + 2 * torch.sqrt(amp) * alpha) 20 | b1 = -2 * amp * ((amp - 1) + (amp + 1) * torch.cos(w0)) 21 | b2 = amp * ((amp + 1) + 22 | (amp - 1) * torch.cos(w0) - 2 * torch.sqrt(amp) * alpha) 23 | a0 = (amp + 1) - (amp - 1) * torch.cos(w0) + 2 * torch.sqrt(amp) * alpha 24 | a1 = 2 * ((amp - 1) - (amp + 1) * torch.cos(w0)) 25 | a2 = (amp + 1) - (amp - 1) * torch.cos(w0) - 2 * torch.sqrt(amp) * alpha 26 | 27 | b = torch.cat((b0, b1, b2), dim=-1) 28 | a = torch.cat((a0, a1, a2), -1) 29 | coef = torch.cat((b, a), 0) 30 | return coef 31 | 32 | 33 | def high_pass(center_freq: Tensor, gain_db: Tensor, q_factor: Tensor, sr: Tensor): 34 | 35 | w0 = Tensor([2. * np.pi * center_freq / sr]) 36 | alpha = torch.sin(w0) / 2. / q_factor 37 | 38 | b0 = (1 + torch.cos(w0)) / 2. 
39 | b1 = -(1 + torch.cos(w0)) 40 | b2 = b0 41 | 42 | a0 = 1 + alpha 43 | a1 = -2 * torch.cos(w0) 44 | a2 = 1 - alpha 45 | 46 | b = torch.cat((b0, b1, b2)) 47 | a = torch.cat((a0, a1, a2)) 48 | coef = torch.stack((b, a), 0) 49 | return coef 50 | 51 | 52 | def low_shelf(center_freq: Tensor, gain_db: Tensor, q_factor: Tensor, 53 | sr: float): 54 | w0 = Tensor([2. * np.pi * center_freq / sr]) 55 | amp = torch.pow(10, gain_db / 40.) 56 | alpha = torch.sin(w0) / 2. / q_factor 57 | 58 | b0 = amp * ((amp + 1) - 59 | (amp - 1) * torch.cos(w0) + 2 * torch.sqrt(amp) * alpha) 60 | b1 = 2 * amp * ((amp - 1) - (amp + 1) * torch.cos(w0)) 61 | b2 = amp * ((amp + 1) - 62 | (amp - 1) * torch.cos(w0) - 2 * torch.sqrt(amp) * alpha) 63 | a0 = (amp + 1) + (amp - 1) * torch.cos(w0) + 2 * torch.sqrt(amp) * alpha 64 | a1 = -2 * ((amp - 1) + (amp + 1) * torch.cos(w0)) 65 | a2 = (amp + 1) + (amp - 1) * torch.cos(w0) - 2 * torch.sqrt(amp) * alpha 66 | 67 | b = torch.cat((b0, b1, b2), -1) 68 | a = torch.cat((a0, a1, a2), -1) 69 | coef = torch.cat((b, a), 0) 70 | return coef 71 | 72 | 73 | def low_pass(center_freq: Tensor, gain_db: Tensor, q_factor: Tensor, sr: float): 74 | w0 = Tensor([2. * np.pi * center_freq / sr]) 75 | alpha = torch.sin(w0) / 2. / q_factor 76 | 77 | b0 = (1 - torch.cos(w0)) / 2 78 | b1 = 1 - torch.cos(w0) 79 | b2 = b0 80 | 81 | a0 = 1 + alpha 82 | a1 = -2 * torch.cos(w0) 83 | a2 = 1 - alpha 84 | 85 | b = torch.cat((b0, b1, b2)) 86 | a = torch.cat((a0, a1, a2)) 87 | 88 | coef = torch.stack((b, a), 0) 89 | return coef 90 | 91 | 92 | def peaking_eq(center_freq: Tensor, gain_db: Tensor, q_factor: Tensor, 93 | sr: float): 94 | 95 | w0 = Tensor([2. * np.pi * center_freq / sr]) 96 | amp = torch.pow(10, gain_db / 40.) 97 | alpha = torch.sin(w0) / 2. 
/ q_factor 98 | 99 | b0 = 1 + alpha * amp 100 | b1 = -2 * torch.cos(w0) 101 | b2 = 1 - alpha * amp 102 | 103 | a0 = 1 + alpha / amp 104 | a1 = -2 * torch.cos(w0) 105 | a2 = 1 - alpha / amp 106 | 107 | b = torch.cat((b0, b1, b2)) 108 | a = torch.cat((a0, a1, a2)) 109 | coef = torch.stack((b, a), 0) 110 | return coef 111 | 112 | 113 | def notch(center_freq: Tensor, gain_db: Tensor, q_factor: Tensor, sr: float): 114 | w0 = Tensor([2. * np.pi * center_freq / sr]) 115 | alpha = torch.sin(w0) / 2. / q_factor 116 | 117 | b0 = Tensor(1.) 118 | b1 = -2. * torch.cos(w0) 119 | b2 = b0 120 | 121 | a0 = 1. + alpha 122 | a1 = -2 * torch.cos(w0) 123 | a2 = 1. - alpha 124 | 125 | b = torch.cat((b0, b1, b2)) 126 | a = torch.cat((a0, a1, a2)) 127 | 128 | coef = torch.stack((b, a), 0) 129 | return coef 130 | 131 | 132 | REGISTERED_SecFilter = { 133 | "high_shelf": high_shelf, 134 | "high_pass": high_pass, 135 | "low_shelf": low_shelf, 136 | "low_pass": low_pass, 137 | "peaking_eq": peaking_eq, 138 | "notch": notch 139 | } 140 | REGISTERED_SecFilter_freq = { 141 | "high_shelf": [1000,4000], 142 | "high_pass":[40, 400], 143 | "low_shelf":[40, 1000], 144 | "low_pass":[3000, 8000], 145 | "peaking_eq": [40, 4000], 146 | "notch": [40, 4000] 147 | } 148 | 149 | def compositeSecFilt(indata,filter_num=3, sr=16000): 150 | filter_list = [ 151 | "high_shelf", "high_pass", "low_shelf", "low_pass", "peaking_eq", 152 | "notch" 153 | ] 154 | assert filter_num < len(filter_list), " filter_num is error" 155 | filter_idx = list(np.linspace(0, 5, 6, dtype=np.int16)) 156 | sel_filter = random.sample(filter_idx, filter_num) 157 | indata_tmp = indata 158 | for i in range(0,filter_num): 159 | filter_type = filter_list[sel_filter[i]] 160 | center_freq = ss.loguniform.rvs(REGISTERED_SecFilter_freq[filter_type][0], REGISTERED_SecFilter_freq[filter_type][1], 1) 161 | gain_db = np.random.uniform(-15,15, 1) 162 | q_factor = np.random.uniform(0.5, 1.5, 1) 163 | selFilt_coef = 
REGISTERED_SecFilter[filter_type](center_freq,gain_db,q_factor,sr) 164 | indata_tmp = torchaudio.functional.lfilter(indata_tmp, selFilt_coef[1,:], selFilt_coef[0,:]) 165 | return indata_tmp 166 | 167 | 168 | def hp_filter(indata,filte_num=1, sr=16000): 169 | """ 170 | fixed frequency high pass filter 171 | """ 172 | center_freq = 150. 173 | q_factor = np.random.uniform(0.5,1.5,1) 174 | filt_coef = high_pass(center_freq,0,q_factor,sr) 175 | out = indata 176 | for i in range(0, filte_num): 177 | out = torchaudio.functional.lfilter(out, filt_coef[1,:],filt_coef[0,:]) 178 | return out 179 | 180 | def airAbsorption(sig, sr=16000): 181 | center_freq = [125,250,500,1000,2000,4000,8000, 16000,24000] 182 | air_absorption = [0.1,0.2,0.5,1.1,2.7,9.4,29.0,91.5,289.0] 183 | air_absorption_table = Tensor([x * 1e-3 for x in air_absorption]) 184 | distance_low = 1.0 185 | distance_high = 20.0 186 | d = torch.FloatTensor(1).uniform_(distance_low,distance_high) 187 | atten_val = torch.exp(-d * air_absorption_table) 188 | atten_val_db = 20 * torch.log10(atten_val) 189 | att_interp_db = interp_atten(att_interp_db, 161) 190 | att_interp = 10 ** (att_interp_db / 20) 191 | sig_stft = torch.stft(sig, window=torch.hann_window(320), n_fft=320, win_length=320, hop_length=160, return_complex=True).squeeze() 192 | att_interp_tile = torch.tile(att_interp, (sig_stft.shape[-1], 1)).transpose(1,0) 193 | masked = sig_stft * att_interp_tile 194 | masked = masked.unsqueeze() 195 | rc = torch.istft(masked, window = torch.hann_window(320), n_fft=320, win_length=320,hop_length=320, length=sig.shape[-1]) 196 | return rc 197 | 198 | def interp_atten(atten_vals=None, n_freq=None, center_freq=None, sr=16000): 199 | center_freq = [125,250,500, 1000, 2000, 4000, 8000, 16000, 24000] 200 | sr=16000 201 | atten_vals1 = [atten_vals[0].tolist()] + atten_vals.tolist() + [atten_vals[-1].tolist()] 202 | freqs = torch.linspace(0, sr/2, n_freq) 203 | atten_vals_interp = torch.zeros(n_freq) 204 | center_freq = [0] + 
center_freq + [sr/2] 205 | i = 0 206 | center_freq_win = as_windowed(Tensor([center_freq]), 2,1).squeeze() 207 | atten_vals_win = as_windowed(Tensor([atten_vals1]), 2,1).squeeze() 208 | gf = center_freq_win.tolist() 209 | for k,(c,a) in enumerate(zip(center_freq_win.tolist(), atten_vals_win.tolist())): 210 | c0, c1 = c[0], c[1] 211 | a0, a1 = a[0], a[1] 212 | while i None: 13 | super(STFT, self).__init__() 14 | 15 | self.win_size = win_size 16 | self.hop_size = hop_size 17 | self.n_overlap = self.win_size // self.hop_size 18 | self.requires_grad = requires_grad 19 | 20 | win = torch.from_numpy(scipy.hamming(self.win_size).astype(np.float32)) 21 | win = F.relu(win) 22 | 23 | win = nn.parameter(data=win, requires_grad=self.requires_grad) 24 | self.register_parameter('win', win) 25 | 26 | fourier_basis = np.fft.fft(np.eye(self.win_size)) 27 | fourier_basis_r = np.real(fourier_basis).astype(np.float32) 28 | fourier_basis_i = np.imag(fourier_basis).astype(np.float32) 29 | 30 | self.register_buffer('fourier_basis_r', 31 | torch.from_numpy(fourier_basis_r)) 32 | self.register_buffer('fourier_basis_i', 33 | torch.from_numpy(fourier_basis_i)) 34 | 35 | idx = torch.tensor(range(self.win_size // 2 - 1, 0, -1), 36 | dtype=torch.long) 37 | self.register_buffer('idx', idx) 38 | 39 | self.eps = torch.finfo(torch.float32).eps 40 | 41 | def kernel_fw(self): 42 | fourier_basis_r = torch.matmul(self.fourier_basis_r, 43 | torch.diag(self.win)) 44 | fourier_basis_i = torch.matmul(self.fourier_basis_i, 45 | torch.diag(self.win)) 46 | 47 | fourier_basis = torch.stack([fourier_basis_r, fourier_basis_i], dim=-1) 48 | forward_basis = fourier_basis.unsqueeze(dim=1) 49 | return forward_basis 50 | 51 | def kernel_bw(self): 52 | inv_fourier_basis_r = self.fourier_basis_r / self.win_size 53 | inv_fourier_basis_i = self.fourier_basis_i / self.win_size 54 | 55 | inv_fourier_basis = torch.stack( 56 | [inv_fourier_basis_r, inv_fourier_basis_i], dim=-1) 57 | backward_basis = 
inv_fourier_basis.unsqueeze(dim=1) 58 | return backward_basis 59 | 60 | def window(self, n_frames): 61 | assert n_frames >= 2 62 | seg = sum([ 63 | self.win[i * self.hop_size:(i + 1) * self.hop_size] 64 | for i in range(self.n_overlap) 65 | ]) 66 | seg = seg.unsqueeze(dim=-1).expand( 67 | (self.hop_size, n_frames - self.n_overlap + 1)) 68 | window = seg.contiguous().view(-1).contiguous() 69 | 70 | return window 71 | 72 | def stft(self, sig): 73 | batch_size = sig.shape[0] 74 | n_samples = sig.shape[1] 75 | 76 | cutoff = self.win_size // 2 + 1 77 | sig = sig.view(batch_size, 1, n_samples) 78 | kernel = self.kernel_fw() 79 | 80 | kernel_r = kernel[..., 0] 81 | kernel_i = kernel[..., 1] 82 | 83 | spec_r = F.conv1d(sig, 84 | kernel_r[:cutoff], 85 | stride=self.hop_size, 86 | padding=self.win_size - self.hop_size) 87 | spec_i = F.conv1d(sig, 88 | kernel_i[:cutoff], 89 | stride=self.hop_size, 90 | padding=self.win_size - self.hop_size) 91 | 92 | spec_r = spec_r.transpose(-1, -2).contiguous() 93 | spec_i = spec_i.transpose(-1, -2).contiguous() 94 | 95 | mag = torch.sqrt(spec_r**2 + spec_i**2) 96 | pha = torch.atan2(spec_i.data, spec_r.data) 97 | 98 | return spec_r, spec_i, mag, pha 99 | 100 | def istft(self, x): 101 | spec_r = x[:, 0, :, :] 102 | spec_i = x[:, 0, :, :] 103 | 104 | n_frames = spec_r.shape[1] 105 | spec_r = torch.cat( 106 | [spec_r, spec_r.index_select(dim=-1, index=self.idx)], dim=-1) 107 | spec_i = torch.cat( 108 | [spec_r, -spec_i.index_select(dim=-1, index=self.idx)], dim=-1) 109 | 110 | spec_r = spec_r.transpose(-1, -2).contiguous() 111 | spec_i = spec_i.transpose(-1, -2).contiguous() 112 | 113 | kernel = self.kernel_bw() 114 | kernel_r = kernel[..., 0].transpose(0, -1) 115 | kernel_i = kernel[..., 1].transpose(0, -1) 116 | sig = F.conv_transpose1d( 117 | spec_r, 118 | kernel_r, 119 | stride=self.hop_size, 120 | padding=self.win_size - self.hop_size) - F.conv_transpose1d( 121 | spec_i, 122 | kernel_i, 123 | stride=self.hop_size, 124 | 
padding=self.win_size - self.hop_size) 125 | 126 | sig = sig.squeeze(dim=1) 127 | window = self.window(n_frames) 128 | sig = sig / (window + self.eps) 129 | return sig 130 | -------------------------------------------------------------------------------- /train_base/acoustics/feature.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Okrio/CRUSE/bcd5607953e0f76bac82a5cd322af4e0d3d23342/train_base/acoustics/feature.py -------------------------------------------------------------------------------- /train_base/acoustics/mask.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | 5 | from train_base.constant import EPSILON 6 | 7 | 8 | def build_ideal_ratio_mask(noisy_mag, clean_mag) -> torch.Tensor: 9 | """ 10 | Args: 11 | noisy_mag: [B, F, T], noisy magnitude 12 | clean_mag: [B, F, T], clean magnitude 13 | Returns: 14 | [B, F, T, 1] 15 | """ 16 | # noisy_mag_finetune = torch.sqrt(torch.square(noisy_mag) + EPSILON) 17 | # ratio_mask = clean_mag / noisy_mag_finetune 18 | ratio_mask = clean_mag / (noisy_mag + EPSILON) 19 | ratio_mask = ratio_mask[..., None] 20 | return compress_cIRM(ratio_mask, K=10, C=0.1) 21 | 22 | 23 | def build_complex_ideal_ratio_mask(noisy: torch.complex64, clean: torch.complex64) -> torch.Tensor: 24 | """ 25 | Args: 26 | noisy: [B, F, T], noisy complex-valued stft coefficients 27 | clean: [B, F, T], clean complex-valued stft coefficients 28 | Returns: 29 | [B, F, T, 2] 30 | """ 31 | denominator = torch.square(noisy.real) + torch.square(noisy.imag) + EPSILON 32 | 33 | mask_real = (noisy.real * clean.real + noisy.imag * clean.imag) / denominator 34 | mask_imag = (noisy.real * clean.imag - noisy.imag * clean.real) / denominator 35 | 36 | complex_ratio_mask = torch.stack((mask_real, mask_imag), dim=-1) 37 | 38 | return compress_cIRM(complex_ratio_mask, K=10, C=0.1) 39 | 40 | 41 | def compress_cIRM(mask, 
K=10, C=0.1): 42 | """ 43 | Compress from (-inf, +inf) to [-K ~ K] 44 | """ 45 | if torch.is_tensor(mask): 46 | mask = -100 * (mask <= -100) + mask * (mask > -100) 47 | mask = K * (1 - torch.exp(-C * mask)) / (1 + torch.exp(-C * mask)) 48 | else: 49 | mask = -100 * (mask <= -100) + mask * (mask > -100) 50 | mask = K * (1 - np.exp(-C * mask)) / (1 + np.exp(-C * mask)) 51 | return mask 52 | 53 | 54 | def decompress_cIRM(mask, K=10, limit=9.9): 55 | mask = limit * (mask >= limit) - limit * (mask <= -limit) + mask * (torch.abs(mask) < limit) 56 | mask = -K * torch.log((K - mask) / (K + mask)) 57 | return mask 58 | 59 | 60 | def complex_mul(noisy_r, noisy_i, mask_r, mask_i): 61 | r = noisy_r * mask_r - noisy_i * mask_i 62 | i = noisy_r * mask_i + noisy_i * mask_r 63 | return r, i -------------------------------------------------------------------------------- /train_base/constant.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | 5 | NEG_INF = torch.finfo(torch.float32).min 6 | PI = math.pi 7 | SOUND_SPEED = 343 # m/s 8 | EPSILON = np.finfo(np.float32).eps 9 | MAX_INT16 = np.iinfo(np.int16).max 10 | -------------------------------------------------------------------------------- /train_base/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | 3 | 4 | class BaseDataset(data.Dataset): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | @staticmethod 9 | def _offset_and_limit(dataset_list, offset, limit): 10 | dataset_list = dataset_list[offset:] 11 | if limit: 12 | dataset_list = dataset_list[:limit] 13 | return dataset_list 14 | 15 | @staticmethod 16 | def _parse_snr_range(snr_range): 17 | assert len(snr_range) == 2, f"The range of SNR should be [low, high], not {snr_range}." 18 | assert snr_range[0] <= snr_range[-1], "The low SNR should not larger than high SNR." 
19 | 20 | low, high = snr_range 21 | snr_list = [] 22 | for i in range(low, high + 1, 1): 23 | snr_list.append(i) 24 | 25 | return snr_list 26 | -------------------------------------------------------------------------------- /train_base/inferencer/base_inferencer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Okrio/CRUSE/bcd5607953e0f76bac82a5cd322af4e0d3d23342/train_base/inferencer/base_inferencer.py -------------------------------------------------------------------------------- /train_base/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | l1_loss = torch.nn.L1Loss 4 | mse_loss = torch.nn.MSELoss 5 | 6 | 7 | def si_snr_loss(): 8 | def si_snr(x, s, eps=1e-8): 9 | def l2norm(mat, keep_dim=False): 10 | return torch.norm(mat, dim=-1, keepdim=keep_dim) 11 | 12 | if x.shape != s.shape: 13 | raise RuntimeError( 14 | f"Dimension mismatch when calculate si_snr, {x.shape} vs {s.shape}" 15 | ) 16 | 17 | x_zm = x - torch.mean(x, dim=-1, keepdim=True) 18 | s_zm = s - torch.mean(s, dim=-1, keepdim=True) 19 | t = torch.sum(x_zm * s_zm, dim=-1, keepdim=True) * s_zm / ( 20 | l2norm(s_zm, keep_dim=True)**2 + eps) 21 | 22 | return -torch.mean(20 * torch.log10(eps + l2norm(t) / 23 | (l2norm(x_zm - t) + eps))) 24 | 25 | return si_snr 26 | 27 | 28 | # class BaseLoss: 29 | 30 | 31 | def ccmse_loss(noisy, clean, mask, eps=1e-8): 32 | pass 33 | -------------------------------------------------------------------------------- /train_base/metrics.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Okrio/CRUSE/bcd5607953e0f76bac82a5cd322af4e0d3d23342/train_base/metrics.py -------------------------------------------------------------------------------- /train_base/model/base_model.py: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/Okrio/CRUSE/bcd5607953e0f76bac82a5cd322af4e0d3d23342/train_base/model/base_model.py
# --------------------------------------------------------------------------------
# /train_base/trainer/base_trainer.py:
# --------------------------------------------------------------------------------
import time
from functools import partial
from pathlib import Path

import colorful
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import toml
import torch
from joblib import Parallel, delayed
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter

import train_base.metrics as metrics
from train_base.acoustics.feature import stft, istft
from train_base.utils import prepare_empty_dir, ExecutionTime
from utils.logger import log

plt.switch_backend('agg')


class BaseTrainer:
    """Shared training loop: DDP wrapping, AMP, checkpointing and TensorBoard logging.

    Subclasses implement ``_train_epoch(epoch)`` and ``_validation_epoch(epoch)``;
    the latter returns the metric score used to pick the best checkpoint.
    """

    def __init__(self, dist, rank, config, resume, only_validation, model,
                 loss_function, optimizer):
        """
        Args:
            dist: distributed backend providing ``barrier()`` (presumably
                ``torch.distributed`` -- TODO confirm).
            rank: process rank, also used as the CUDA device id.
            config: parsed configuration with "meta", "acoustics" and
                "trainer" sections.
            resume: resume from ``checkpoints/latest_model.tar``.
            only_validation: run validation only, skipping training.
            model: the network; it is wrapped in DistributedDataParallel here.
            loss_function: loss used by subclasses.
            optimizer: optimizer over the model parameters.
        """
        self.color_tool = colorful
        self.color_tool.use_style("solarized")

        model = DistributedDataParallel(model.to(rank), device_ids=[rank])
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

        # DistributedDataParallel (DDP)
        self.rank = rank
        self.dist = dist

        # Automatic mixed precision (AMP)
        self.use_amp = config["meta"]["use_amp"]
        self.scaler = GradScaler(enabled=self.use_amp)

        # Acoustics
        self.acoustic_config = config["acoustics"]

        # Supported STFT
        n_fft = self.acoustic_config["n_fft"]
        hop_length = self.acoustic_config["hop_length"]
        win_length = self.acoustic_config["win_length"]

        self.torch_stft = partial(stft,
                                  n_fft=n_fft,
                                  hop_length=hop_length,
                                  win_length=win_length)
        self.torch_istft = partial(istft,
                                   n_fft=n_fft,
                                   hop_length=hop_length,
                                   win_length=win_length)
        self.librosa_stft = partial(librosa.stft,
                                    n_fft=n_fft,
                                    hop_length=hop_length,
                                    win_length=win_length)
        self.librosa_istft = partial(librosa.istft,
                                     hop_length=hop_length,
                                     win_length=win_length)

        # Trainer.train in the config
        self.train_config = config["trainer"]["train"]
        # "alpha" is optional: only set the attribute when a truthy value is
        # configured (indexing the key directly raised KeyError when absent).
        alpha = self.train_config.get("alpha")
        if alpha:
            self.alpha = alpha
        self.epochs = self.train_config["epochs"]
        self.save_checkpoint_interval = self.train_config[
            "save_checkpoint_interval"]
        self.clip_grad_norm_value = self.train_config["clip_grad_norm_value"]
        assert self.save_checkpoint_interval >= 1, "Check the 'save_checkpoint_interval' parameter in the config. It should be large than one."

        # Trainer.validation in the config
        self.validation_config = config["trainer"]["validation"]
        self.validation_interval = self.validation_config[
            "validation_interval"]
        self.save_max_metric_score = self.validation_config[
            "save_max_metric_score"]
        assert self.validation_interval >= 1, "Check the 'validation_interval' parameter in the config. It should be large than one."

        # Trainer.visualization in the config
        self.visualization_config = config["trainer"]["visualization"]

        # In the 'train.py' file, if the 'resume' item is 'True', we will update the following args:
        self.start_epoch = 1
        self.best_score = -np.inf if self.save_max_metric_score else np.inf
        self.save_dir = Path(config["meta"]["save_dir"]).expanduser().absolute(
        ) / config["meta"]["experiment_name"]
        self.checkpoints_dir = self.save_dir / "checkpoints"
        self.logs_dir = self.save_dir / "logs"

        if resume:
            self._resume_checkpoint()

        # Debug validation, which skips training
        self.only_validation = only_validation

        # The path lives under "meta", the same key the condition checks.
        if config["meta"]["preloaded_model_path"]:
            self._preload_model(Path(config["meta"]["preloaded_model_path"]))

        if self.rank == 0:
            prepare_empty_dir([self.checkpoints_dir, self.logs_dir],
                              resume=resume)

            self.writer = SummaryWriter(self.logs_dir.as_posix(),
                                        max_queue=5,
                                        flush_secs=30)
            # <pre> keeps the TOML layout when TensorBoard renders the text as markdown.
            self.writer.add_text(
                tag="Configuration",
                text_string=f"<pre>  \n{toml.dumps(config)}  \n</pre>",
                global_step=1)

            print(self.color_tool.cyan("The configurations are as follows: "))
            print(self.color_tool.cyan("=" * 40))
            print(self.color_tool.cyan(toml.dumps(config)[:-1]))  # except "\n"
            print(self.color_tool.cyan("=" * 40))

            with open(
                    (self.save_dir /
                     f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(),
                    "w") as handle:
                toml.dump(config, handle)

            self._print_networks([self.model])

    def _unwrapped_model(self):
        """Return the bare model (without the DistributedDataParallel wrapper)."""
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            return self.model.module
        return self.model

    def _preload_model(self, model_path):
        """
        Preload model parameters (in "*.tar" format) at the start of experiment.
        Args:
            model_path (Path): The file path of the *.tar file
        """
        model_path = model_path.expanduser().absolute()
        assert model_path.exists(
        ), f"The file {model_path.as_posix()} is not exist. please check path."

        model_checkpoint = torch.load(model_path.as_posix(),
                                      map_location="cpu")
        # Checkpoints are saved from the unwrapped module (no "module." key
        # prefix); loading them into the DDP wrapper with strict=False would
        # silently match nothing.
        self._unwrapped_model().load_state_dict(model_checkpoint["model"],
                                                strict=False)
        self.model.to(self.rank)

        if self.rank == 0:
            print(
                f"Model preloaded successfully from {model_path.as_posix()}.")

    def _resume_checkpoint(self):
        """
        Resume the experiment from the latest checkpoint.
        """
        latest_model_path = self.checkpoints_dir.expanduser().absolute(
        ) / "latest_model.tar"
        assert latest_model_path.exists(
        ), f"{latest_model_path} does not exist, can not load latest checkpoint."

        # Make sure all processes (GPUs) do not start loading before the saving is finished.
        # see https://stackoverflow.com/questions/59760328/how-does-torch-distributed-barrier-work
        self.dist.barrier()

        # Load it on the CPU and later use .to(device) on the model
        # Maybe slightly slow than use map_location="cuda:<...>"
        # https://stackoverflow.com/questions/61642619/pytorch-distributed-data-parallel-confusion
        checkpoint = torch.load(latest_model_path.as_posix(),
                                map_location="cpu")

        self.start_epoch = checkpoint["epoch"] + 1
        self.best_score = checkpoint["best_score"]
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scaler.load_state_dict(checkpoint["scaler"])

        self._unwrapped_model().load_state_dict(checkpoint["model"])

        if self.rank == 0:
            print(
                f"Model checkpoint loaded. Training will begin at {self.start_epoch} epoch."
            )

    def _save_checkpoint(self, epoch, is_best_epoch=False):
        """
        Save a checkpoint to "<save_dir>/<experiment_name>/checkpoints", consisting of:
            - epoch
            - best metric score in historical epochs
            - optimizer parameters
            - AMP grad-scaler state
            - model parameters
        Args:
            is_best_epoch (bool): In the current epoch, if the model get a best metric score (is_best_epoch=True),
                                the checkpoint of model will also be saved as "checkpoints/best_model.tar".
        """
        print(f"\t Saving {epoch} epoch model checkpoint...")

        state_dict = {
            "epoch": epoch,
            "best_score": self.best_score,
            "optimizer": self.optimizer.state_dict(),
            "scaler": self.scaler.state_dict(),
            "model": self._unwrapped_model().state_dict(),
        }

        # Saved in "latest_model.tar"
        # Contains all checkpoint information, including the optimizer parameters, the model parameters, etc.
        # New checkpoint will overwrite the older one.
        torch.save(state_dict,
                   (self.checkpoints_dir / "latest_model.tar").as_posix())

        # "model_{epoch_number}.pth"
        # Contains only model.
        torch.save(state_dict["model"],
                   (self.checkpoints_dir /
                    f"model_{str(epoch).zfill(4)}.pth").as_posix())

        # If the model get a best metric score (means "is_best_epoch=True") in the current epoch,
        # the model checkpoint will be saved as "best_model.tar"
        # The newer best-scored checkpoint will overwrite the older one.
        if is_best_epoch:
            print(
                self.color_tool.red(
                    f"\t Found a best score in the {epoch} epoch, saving..."))
            log(f"\t Found a best score in the {epoch} epoch, saving...")
            torch.save(state_dict,
                       (self.checkpoints_dir / "best_model.tar").as_posix())

    def _is_best_epoch(self, score, save_max_metric_score=True):
        """
        Check if the current model got the best metric score; update best_score if so.
        Ties count as a new best.
        """
        if save_max_metric_score and score >= self.best_score:
            self.best_score = score
            return True
        elif not save_max_metric_score and score <= self.best_score:
            self.best_score = score
            return True
        else:
            return False

    @staticmethod
    def _print_networks(models: list):
        """Print the parameter count (in millions) of each model and their total."""
        print(
            f"This project contains {len(models)} models, the number of the parameters is: "
        )

        params_of_all_networks = 0
        for idx, model in enumerate(models, start=1):
            params_of_network = 0
            for param in model.parameters():
                params_of_network += param.numel()

            print(f"\tNetwork {idx}: {params_of_network / 1e6} million.")
            params_of_all_networks += params_of_network

        print(
            f"The amount of parameters in the project is {params_of_all_networks / 1e6} million."
        )

    def _set_models_to_train_mode(self):
        self.model.train()

    def _set_models_to_eval_mode(self):
        self.model.eval()

    def spec_audio_visualization(self,
                                 noisy,
                                 enhanced,
                                 clean,
                                 name,
                                 epoch,
                                 mark=""):
        """Log the three waveforms as audio and their magnitude spectrograms as one figure."""
        sr = self.acoustic_config["sr"]

        self.writer.add_audio(f"{mark}_Speech/{name}_Noisy",
                              noisy,
                              epoch,
                              sample_rate=sr)
        self.writer.add_audio(f"{mark}_Speech/{name}_Enhanced",
                              enhanced,
                              epoch,
                              sample_rate=sr)
        self.writer.add_audio(f"{mark}_Speech/{name}_Clean",
                              clean,
                              epoch,
                              sample_rate=sr)

        # Visualize the spectrogram of noisy speech, clean speech, and enhanced speech.
        # self.librosa_stft already carries n_fft / hop_length / win_length.
        noisy_mag, _ = librosa.magphase(self.librosa_stft(noisy))
        enhanced_mag, _ = librosa.magphase(self.librosa_stft(enhanced))
        clean_mag, _ = librosa.magphase(self.librosa_stft(clean))
        fig, axes = plt.subplots(3, 1, figsize=(6, 6))
        for k, mag in enumerate([noisy_mag, enhanced_mag, clean_mag]):
            axes[k].set_title(f"mean: {np.mean(mag):.3f}, "
                              f"std: {np.std(mag):.3f}, "
                              f"max: {np.max(mag):.3f}, "
                              f"min: {np.min(mag):.3f}")
            librosa.display.specshow(librosa.amplitude_to_db(mag),
                                     cmap="magma",
                                     y_axis="linear",
                                     ax=axes[k],
                                     sr=sr)
        plt.tight_layout()
        self.writer.add_figure(f"{mark}_Spectrogram/{name}", fig, epoch)

    def metrics_visualization(self,
                              noisy_list,
                              clean_list,
                              enhanced_list,
                              metrics_list,
                              epoch,
                              num_workers=10,
                              mark=""):
        """
        Get metrics on validation dataset by paralleling.
        Notes:
            1. You can register other metrics, but STOI and WB_PESQ metrics must be existence. These two metrics are
             used for checking if the current epoch is a "best epoch."
            2. If you want to use a new metric, you must register it in "util.metrics" file.
        Returns:
            The mean of the enhanced STOI and the enhanced WB-PESQ mapped by (x + 0.5) / 5.
        """
        assert "STOI" in metrics_list and "WB_PESQ" in metrics_list, "'STOI' and 'WB_PESQ' must be existence."

        # Check if the metric is registered in "util.metrics" file.
        for i in metrics_list:
            assert i in metrics.REGISTERED_METRICS.keys(
            ), f"{i} is not registered, please check 'util.metrics' file."

        stoi_mean = 0.0
        wb_pesq_mean = 0.0
        for metric_name in metrics_list:
            score_on_noisy = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(
                    ref, est, sr=self.acoustic_config["sr"])
                for ref, est in zip(clean_list, noisy_list))
            score_on_enhanced = Parallel(n_jobs=num_workers)(
                delayed(metrics.REGISTERED_METRICS[metric_name])(
                    ref, est, sr=self.acoustic_config["sr"])
                for ref, est in zip(clean_list, enhanced_list))

            # Add the mean value of the metric to tensorboard
            mean_score_on_noisy = np.mean(score_on_noisy)
            mean_score_on_enhanced = np.mean(score_on_enhanced)
            self.writer.add_scalars(f"{mark}_Validation/{metric_name}", {
                "Noisy": mean_score_on_noisy,
                "Enhanced": mean_score_on_enhanced
            }, epoch)

            if metric_name == "STOI":
                stoi_mean = mean_score_on_enhanced

            if metric_name == "WB_PESQ":
                wb_pesq_mean = (mean_score_on_enhanced + 0.5) / 5

        return (stoi_mean + wb_pesq_mean) / 2

    def train(self):
        """Run epochs ``start_epoch .. epochs``: train, checkpoint and validate on schedule."""
        for epoch in range(self.start_epoch, self.epochs + 1):
            if self.rank == 0:
                print(
                    self.color_tool.yellow(
                        f"{'=' * 15} {epoch} epoch {'=' * 15}"))
                print("[0 seconds] Begin training...")

            # [debug validation] Only run validation (only use the first GPU (process))
            # inference + calculating metrics + saving checkpoints
            if self.only_validation:
                if self.rank == 0:
                    self._set_models_to_eval_mode()
                    metric_score = self._validation_epoch(epoch)

                    if self._is_best_epoch(
                            metric_score,
                            save_max_metric_score=self.save_max_metric_score):
                        self._save_checkpoint(epoch, is_best_epoch=True)

                # Skip regular training on every rank: if the other ranks
                # trained alone, their DDP gradient sync would wait on rank 0.
                continue

            # Regular training
            timer = ExecutionTime()
            self._set_models_to_train_mode()
            self._train_epoch(epoch)

            # Regular save checkpoints
            if self.rank == 0 and self.save_checkpoint_interval != 0 and (
                    epoch % self.save_checkpoint_interval == 0):
                self._save_checkpoint(epoch)

            # Regular validation
            if self.rank == 0 and (epoch % self.validation_interval == 0):
                print(
                    f"[{timer.duration()} seconds] Training has finished, validation is in progress..."
                )

                self._set_models_to_eval_mode()
                metric_score = self._validation_epoch(epoch)

                if self._is_best_epoch(
                        metric_score,
                        save_max_metric_score=self.save_max_metric_score):
                    self._save_checkpoint(epoch, is_best_epoch=True)

                print(f"[{timer.duration()} seconds] This epoch is finished.")

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _validation_epoch(self, epoch):
        raise NotImplementedError

# --------------------------------------------------------------------------------
# /train_base/utils.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Okrio/CRUSE/bcd5607953e0f76bac82a5cd322af4e0d3d23342/train_base/utils.py
# --------------------------------------------------------------------------------
# /utils/logger.py:
# --------------------------------------------------------------------------------
import atexit
from datetime import datetime
import json
from threading import Thread
from urllib.request import Request, urlopen
import os

_format = '%Y-%m-%d %H:%M:%S.%f'
_file = None
_run_name = None
_slack_url = None


def _close_logfile():
    """Close the current log file, if any (also registered with atexit)."""
    global _file
    if _file is not None:
        _file.close()
        _file = None


def _send_slack(msg):
    """POST ``msg`` to the configured Slack webhook URL."""
    req = Request(_slack_url)
    req.add_header('Content-Type', 'application/json')
    payload = json.dumps({
        'username': 'tacotron',
        'icon_emoji': 'taco',
        'text': '*%s*: %s' % (_run_name, msg)
    }).encode()
    # Close the HTTP response instead of leaking the connection.
    with urlopen(req, payload):
        pass


def init(filename, run_name, slack_url=None):
    """Open ``filename`` for appending log lines and write a run banner.

    Args:
        filename: log file path; its parent directory is created if needed.
        run_name: name prefixed to Slack messages.
        slack_url: optional Slack webhook URL used by ``log(..., slack=True)``.
    """
    global _file, _run_name, _slack_url
    log_dir = os.path.dirname(filename)
    # A bare filename has no directory part; os.makedirs('') would raise.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    _close_logfile()
    # The file was never opened before, so the writes below hit None.
    _file = open(filename, 'a', encoding='utf-8')
    _file.write('\n---------------\n')
    _file.write('Starting new training run \n')
    _file.write('-----------------\n')
    _file.flush()
    _run_name = run_name
    _slack_url = slack_url


def log(msg, slack=False):
    """Print ``msg`` with a millisecond timestamp, append it to the log file and
    optionally post it to Slack on a background thread."""
    cur_time = datetime.now().strftime(_format)[:-3]
    print('[%s] %s' % (cur_time, msg), end='\n', flush=True)
    if _file is not None:
        _file.write('[%s] %s\n' % (cur_time, msg))
        _file.flush()
        if slack and _slack_url is not None:
            Thread(target=_send_slack, args=(msg,)).start()


atexit.register(_close_logfile)

# --------------------------------------------------------------------------------
# /utils/plot.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Okrio/CRUSE/bcd5607953e0f76bac82a5cd322af4e0d3d23342/utils/plot.py
# --------------------------------------------------------------------------------
# /utils/utils.py:
# --------------------------------------------------------------------------------
'''
Author: your name
Date: 2022-02-12 16:16:12
LastEditTime: 2022-03-28 22:53:14
LastEditors: Please set LastEditors
Description: In User Settings Edit
FilePath: /CRUSE/utils/utils.py
'''

# from ntpath import join

from pathlib import Path
import numpy as np
# from sklearn.linear_model import LogisticRegressionCV
# from scipy.fft import fft
import torch
# import matplotlib.pyplot as plt
import os
import csv
import glob
import librosa as lib
import statistics as stats
import soundfile as sf
import scipy.signal as scs
import sys
from typing import List
import colorful

# Make the sibling "test" directory importable (for plot_mesh below).
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "test")))
print(os.path.abspath(os.path.join(os.getcwd(), "test")))

from test_loss import plot_mesh

EPS = np.finfo(float).eps


def normalize(audio, target_level=-25):
    """Scale ``audio`` so its RMS equals ``10 ** (target_level / 20)``.

    The 1e-8 term guards against division by zero for silent input.
    """
    rms = (audio**2).mean()**0.5
    scalar = 10**(target_level / 20) / (rms + 1e-8)
    audio = audio * scalar
    return audio
def as_windowed(x: torch.Tensor, win_len, hop_len=1, dim=1):
    """Return a strided view of ``x`` cut into windows along ``dim``.

    The result has ``(size - win_len + hop_len) // hop_len`` windows of
    ``win_len`` samples, consecutive windows starting ``hop_len`` samples
    apart; the window axis is inserted right after ``dim``. For ``x`` of shape
    ``(B, T)`` and ``dim=1`` this is ``(B, (T - win_len) // hop_len + 1, win_len)``.
    Trailing samples that do not fill a whole window are dropped.
    """
    shape: List[int] = list(x.shape)
    stride: List[int] = list(x.stride())
    shape[dim] = int((shape[dim] - win_len + hop_len) // hop_len)
    shape.insert(dim + 1, win_len)
    stride.insert(dim + 1, stride[dim])
    stride[dim] = stride[dim] * hop_len
    y = x.as_strided(shape, stride)
    return y


def torch_active_rms(audio: torch.Tensor,
                     sr=16000,
                     thr=-120,
                     frame_length=100):
    """Per-batch RMS over the "active" frames of ``audio`` (shape ``(B, T)``).

    ``audio`` is cut into non-overlapping ``frame_length`` ms frames. A frame
    is active when ``20 * log10(mean power + eps)`` exceeds ``thr``. The RMS is
    taken over every sample of the active frames; samples that are exactly zero
    inside an active frame still count. When no frame is active the float32
    eps is returned instead of NaN.

    Returns:
        Tensor of shape ``(B, 1)``.
    """
    window_samples = int((sr * frame_length) // 1000)
    EPS = torch.finfo(torch.float32).eps

    frames = as_windowed(audio, window_samples, window_samples)  # B, N, W
    frame_power = torch.mean(frames**2, dim=-1, keepdim=True)  # B, N, 1
    frame_level = 20 * torch.log10(frame_power + EPS)
    active = (frame_level > thr).to(frames.dtype)  # 1 for active frames

    sum_sq = torch.sum(frames**2 * active, dim=(1, 2))
    n_active = torch.sum(active, dim=(1, 2)) * window_samples
    rms = torch.sqrt(sum_sq / n_active.clamp(min=1))
    rms = torch.where(n_active > 0, rms, torch.full_like(rms, EPS))
    return rms.unsqueeze(dim=1)


def active_rms(audio, sr=16000, energy_thresh=-120):
    """
    signal active rms

    RMS over the 100 ms windows whose level ``10 * log10(mean power + eps)``
    exceeds ``energy_thresh``; returns the float eps when no window is active.
    """
    window_sample = int(sr * 100 / 1000)
    EPS = np.finfo(float).eps

    # Accumulate the sum of squares and sample count instead of repeatedly
    # appending to an array (quadratic in the signal length).
    sum_sq = 0.0
    n_active = 0
    for start in range(0, len(audio), window_sample):
        audio_win = audio[start:start + window_sample]
        audio_seg_rms = 10 * np.log10((audio_win**2).mean() + EPS)
        if audio_seg_rms > energy_thresh:
            sum_sq += float(np.sum(audio_win**2))
            n_active += len(audio_win)

    if n_active == 0:
        return EPS
    return (sum_sq / n_active)**0.5


def vad_simplify(audio, win_len=256, hop_len=160, fs=16000, target_level=-25):
    """
    refer to "weighted speech distortion losses for neural network-based real-time speech enhancement"

    Debug/plotting helper. It normalizes ``audio``, sums the log-power
    spectrum over the 300-5000 Hz bins of each frame, smooths that curve with
    a fixed IIR filter and plots both curves. Returns None.
    """
    # This module does not import matplotlib at top level; only this helper needs it.
    import matplotlib.pyplot as plt

    audio = normalize(audio, target_level)
    freq_res = fs * 1. / win_len
    f300Hz_point = int(np.floor(300 / freq_res))
    f5000Hz_point = int(np.ceil(5000 / freq_res))

    stft_audio = lib.stft(audio,
                          n_fft=win_len,
                          hop_length=hop_len,
                          win_length=win_len,
                          center=True)
    stft_mag, _ = lib.magphase(stft_audio)  # F * T
    stft_mag_log = 10 * np.log10(stft_mag**2 + 1e-12)
    plot_mesh(stft_mag_log, 'stft_mag_log')
    stft_300_5000_sum = np.sum(stft_mag_log[f300Hz_point:f5000Hz_point, :],
                               axis=0)
    stft_300_5000_sum_smooth = scs.lfilter([0.1], [1, -0.5, -0.2, -0.2],
                                           stft_300_5000_sum)
    plt.figure()
    plt.plot(stft_300_5000_sum)
    plt.plot(stft_300_5000_sum_smooth)
    plt.legend(['stft_sum', 'stft_sum_smooth'])
    plt.show()

    return


def activitydetector(audio, fs=16000, energy_thresh=0.13, target_level=-25):
    """Energy-based activity detector on 50 ms frames.

    Each frame's energy ``10 * log10(sum of squares + EPS)`` is mapped to a
    probability with the logistic curve ``1 / (1 + exp(-(a + b * energy)))``.
    It is then blended with the previous frame's raw probability, using the
    attack coefficient when rising and the release coefficient otherwise.
    Frames whose blended probability exceeds ``energy_thresh`` are active.

    Returns:
        (fraction of active frames, per-sample 0/1 array, per-frame 0/1 array)
    """
    audio = normalize(audio, target_level)
    window_size = 50
    window_samples = int(fs * window_size / 1000)
    sample_start = 0
    cnt = 0
    prev_energy_prob = 0
    active_frames = 0
    a = -1  # logistic intercept
    b = 0.2  # logistic slope per dB
    alpha_rel = 0.05
    alpha_att = 0.8
    vad_val = np.zeros_like(audio)
    vad_frame_val = []

    while sample_start < len(audio):
        sample_end = min(sample_start + window_samples, len(audio))
        audio_win = audio[sample_start:sample_end]
        frame_rms = 10 * np.log10(np.sum(audio_win**2) + EPS)
        # b is a slope, so it scales the energy (was "a + b + frame_rms").
        frame_energy_prob = 1. / (1 + np.exp(-(a + b * frame_rms)))

        if frame_energy_prob > prev_energy_prob:
            smooth_energy_prob = frame_energy_prob * alpha_att + prev_energy_prob * (
                1 - alpha_att)
        else:
            smooth_energy_prob = frame_energy_prob * alpha_rel + prev_energy_prob * (
                1 - alpha_rel)

        if smooth_energy_prob > energy_thresh:
            vad_val[sample_start:sample_end] = 1
            vad_frame_val.append(1)
            active_frames += 1
        else:
            vad_frame_val.append(0)

        prev_energy_prob = frame_energy_prob
        sample_start += window_samples
        cnt += 1

    # Empty input has no frames; report no activity instead of dividing by zero.
    perc_active = active_frames / cnt if cnt else 0.0
    return perc_active, vad_val, np.array(vad_frame_val)


def activity_detector_amp(audio, fs=16000, thre=9):
    """
    rnnoise vad method

    """
    # NOTE(review): this definition continues past the end of this chunk and
    # is left unchanged. As visible, ``start`` never advances inside the loop
    # -- confirm against the rest of the function.
    window_size = 20  # ms
    window_sample = int(fs * window_size / 1000)
    start = 0
    data_len = len(audio)
    audio = audio * 32768
    frameshift = window_sample // 2
    nframe = (data_len - window_sample + frameshift) // frameshift
    Energy = np.zeros(data_len)
    vad_seq = np.zeros_like(Energy)
    vad1 = 0
    vad_cnt = 0
    for i in range(0, nframe):
        tmp = np.array(audio[start:i + window_sample])
        E_val = np.sum(tmp * tmp)
        if E_val > 1e9:
            vad_cnt = 0
        elif E_val > 1e8:
            vad_cnt = vad_cnt - 5
        elif E_val > 1e7:
            vad_cnt = vad_cnt + 1
        else:
            vad_cnt += 2
        # todo ...
214 | pass 215 | 216 | 217 | def activity_detector_tf_frame(audio, fs=16000, thr=9): 218 | 219 | pass 220 | 221 | 222 | def write_log_file(log_dir, log_filename, data): 223 | data = zip(*data) 224 | with open(os.path.join(log_dir, log_filename), mode="w", 225 | newline='') as csvfile: 226 | csvwriter = csv.writer(csvfile, 227 | delimiter=' ', 228 | quotechar='|', 229 | quoting=csv.QUOTE_MINIMAL) 230 | for row in data: 231 | csvwriter.writerow([row]) 232 | 233 | 234 | def get_dir(cfg, param_name, new_dir_name): 235 | if param_name in cfg: 236 | dir_name = cfg[param_name] 237 | 238 | else: 239 | dir_name = os.path.join(os.path.dirname(__file__), new_dir_name) 240 | 241 | if not os.path.exists(dir_name): 242 | os.makedirs(dir_name) 243 | return dir_name 244 | 245 | 246 | def statist_vad(data_path): 247 | noise_filename = glob.glob(os.path.join(data_path), '*.wav') 248 | noise_filename_list = [] 249 | vad_results_list = [] 250 | total_clips = len(noise_filename) 251 | for noisepath in noise_filename: 252 | noise_filename_list.append(os.path.basename(noisepath)) 253 | noise_signal, sr_noise = lib.load(noisepath, sr=16000) 254 | per_act, _ = activitydetector(noise_signal) 255 | vad_results_list.append(per_act) 256 | 257 | pc_vad_passed = round( 258 | vad_results_list.count('True') / total_clips * 100, 1) 259 | print('% noise clips that passed vad tests: ', pc_vad_passed) 260 | dir_name = os.path.join(os.path.dirname(__file__), 'Unit_tests_logs') 261 | if not os.path.exists(dir_name): 262 | os.makedirs(dir_name) 263 | if not os.path.exists(dir_name): 264 | dir_name = os.path.join(os.path.dirname(__file__), "Unit_tests_logs") 265 | os.makedirs(dir_name) 266 | write_log_file(dir_name, 'unit_test_results.csv', 267 | [noise_filename_list, vad_results_list]) 268 | 269 | 270 | def cal_rt60(y): 271 | freq_third = [ 272 | 400, 500, 630, 800, 1000, 1250, 1600, 2000, 2500, 3150, 4000, 5000, 273 | 6300, 8000, 10000 274 | ] 275 | freqbands = [ 276 | 355, 447, 562, 708, 891, 1122, 
1413, 1778, 2239, 2818, 3548, 4467, 277 | 5623, 7079, 8913, 11220 278 | ] 279 | maxlev = 2**15 - 1 280 | dbscale = 20 281 | # medavgtime = 0.3 282 | ratiofmax = 0.7 283 | convolven = 2500 284 | rt60raw = [0.0] * len(freq_third) 285 | sig, sr = lib.load(y) 286 | sig = sig[0, :] if sig.ndim > 1 else sig # todo 287 | da = sig 288 | for k in range(len(freq_third)): 289 | daf = np.fft.rfft(da) 290 | lofreq = round((freqbands[k + 0] / (sr / 2)) * (len(daf) - 1)) 291 | hifreq = round((freqbands[k + 1] / (sr / 2)) * (len(daf) - 1)) 292 | daf[:lofreq] = 0 293 | daf[:hifreq] = 0 294 | nda = np.fft.ifft(daf, len(da)) 295 | nda = abs(nda) 296 | ndalog = [0.0] * len(nda) 297 | ndapre = [0.0] * len(nda) 298 | for i in range(len(nda)): 299 | if nda[i] != 0: 300 | ndalog[i] = dbscale * np.log10(nda[i] / maxlev) 301 | else: 302 | ndalog[i] = dbscale * np.log10(1 / maxlev) 303 | ndapre[i] = ndalog[i] 304 | 305 | ndalog = np.convolve(ndalog, 306 | np.ones((convolven, )) / convolven, 307 | mode='valid') 308 | ndalog_min, ndalog_max = min(ndalog), max(ndalog) 309 | ndalog_cut_apx = ndalog_max - (ndalog_max - ndalog_min) * ratiofmax 310 | ndalog_cut_ind = (np.abs(ndalog - ndalog_cut_apx)).argmin() 311 | ndalog = ndalog[0:ndalog_cut_ind] 312 | 313 | temp_index = np.arange(0, len(ndalog)) 314 | slope, intercept, r_value, p_value, std_err = stats.Linregress( 315 | temp_index, ndalog) 316 | # dBlossline = slope * temp_index + intercept 317 | rt60 = -60.0 / (slope * sr) 318 | rt60raw[k] = rt60 319 | # print('rt60_median:{},{}'.format(np.mean(rt60raw), np.median(rt60raw))) 320 | return rt60raw 321 | 322 | 323 | def statist_rt60(data_path, savefile_path): 324 | rir_filename = lib.util.find_files(data_path, ext=["wav"]) 325 | rir_filename_list = [] 326 | rt60_filename_list = [] 327 | total_clips = len(rir_filename) 328 | 329 | for rirpath in rir_filename: 330 | rir_filename_list.append(rirpath) 331 | rt60 = cal_rt60(rirpath) 332 | rt60_filename_list.append(np.median(rt60)) 333 | 334 | 
target_file_name = 'Unit_test_{}_logs'.format(os.path.basename(data_path)) 335 | dir_name = os.path.join(savefile_path, target_file_name) 336 | if not os.path.exists(dir_name): 337 | os.makedirs(dir_name) 338 | if not os.path.exists(dir_name): 339 | dir_name = os.path.join(savefile_path, target_file_name) 340 | os.makedirs(dir_name) 341 | write_log_file(dir_name, 'Unit_test_rt60_results.csv', 342 | [rir_filename_list, rt60_filename_list]) 343 | 344 | 345 | def postfiltering(mask, indata=None, tao=0.02): 346 | iam_sin = mask * np.sin(np.pi * mask / 2) 347 | iam_pf = (1 + tao) * mask / (1 + tao * mask**2 / (iam_sin**2)) 348 | 349 | return iam_pf 350 | 351 | 352 | def envelope_postfiltering(unproc, mask, tao=0.02): 353 | """ 354 | perceptually-motivated 355 | Note: only for irm, iam cannot work 356 | """ 357 | g_hat_b_w = mask * np.sin(np.pi * 0.5 * mask) 358 | e0 = mask * unproc 359 | e1 = g_hat_b_w * unproc 360 | tmp = e0 / (e1 + np.finfo(float).eps) 361 | g = np.sqrt((1 + tao) * tmp / (1 + tao * tmp**2)) 362 | return g * g_hat_b_w 363 | 364 | 365 | class PreProcess: 366 | def __init__(self, 367 | win_len, 368 | win_inc, 369 | fft_len, 370 | win_type, 371 | post_process_mode, 372 | loss_mode, 373 | use_cuda=False): 374 | self.win_len = win_len 375 | self.win_inc = win_inc 376 | self.fft_len = fft_len 377 | self.win_type = win_type 378 | self.post_process_mode = post_process_mode 379 | self.loss_mode = loss_mode 380 | self.use_cuda = use_cuda 381 | 382 | if win_type == "hanning": 383 | self.window = torch.hann_window(self.fft_len) 384 | else: 385 | raise ValueError("ERROR window type") 386 | if use_cuda: 387 | self.window = self.window.cuda() 388 | 389 | def pre_stft(self, inputs): 390 | stft_inputs = torch.stft(inputs, 391 | n_fft=self.fft_len, 392 | hop_length=self.win_inc, 393 | win_length=self.win_len, 394 | window=self.window, 395 | center=True, 396 | pad_mode="constant") 397 | stft_inputs = stft_inputs.transpose(1, 3).contiguous() 398 | real = stft_inputs[:, 0, 
:, :] 399 | imag = stft_inputs[:, 1, :, :] 400 | spec_mags = torch.sqrt(real**2 + imag**2 + 1e-8) 401 | spec_phase = torch.atan2(imag, real) 402 | 403 | real = torch.unsqueeze(real, dim=1) 404 | imag = torch.unsqueeze(imag, dim=1) 405 | spec_mags = torch.unsqueeze(spec_mags, dim=1) 406 | spec_phase = torch.unsqueeze(spec_phase, dim=1) 407 | 408 | self.real = real 409 | self.imag = imag 410 | self.spec_mags = spec_mags 411 | self.spec_phase = spec_phase 412 | return stft_inputs, real, imag, spec_mags, spec_phase 413 | 414 | def log_transform(self): 415 | self.spec_mags = torch.log(self.spec_mags) 416 | 417 | def masking(self, mask_real, mask_imag=None): 418 | if self.post_process_mode == "mag_mapping": 419 | out_real = mask_real * self.real 420 | out_imag = mask_real * self.imag 421 | elif self.post_process_mode == "complex_mapping": 422 | out_real = mask_real * self.real 423 | out_imag = mask_imag * self.imag 424 | elif self.post_process_mode == "mapping": 425 | out_real = mask_real 426 | out_imag = mask_imag 427 | else: 428 | NotImplementedError 429 | 430 | out_real = out_real.squeeze(1) 431 | out_imag = out_imag.squeeze(1) 432 | out_spec = torch.stack([out_real, out_imag], dim=-1).contiguous() 433 | return out_spec 434 | 435 | def refsig_process(self, indatas): 436 | if self.loss_mode == "freq": 437 | out, _, _, _, _ = self.pre_stft(indatas) 438 | 439 | elif self.loss_mode == "time": 440 | out = indatas 441 | return out 442 | 443 | def reconstruction(self, stft_outputs, sig_len=None): 444 | if not isinstance(stft_outputs, torch.Tensor): 445 | stft_outputs = torch.from_numpy(stft_outputs).type( 446 | torch.FloatTensor) 447 | 448 | estimated_audio = torch.istft(stft_outputs, 449 | n_fft=self.fft_len, 450 | hop_length=self.win_inc, 451 | win_length=self.win_len, 452 | window=self.window, 453 | center=True, 454 | length=sig_len) 455 | return estimated_audio 456 | 457 | 458 | def test_torch_activate_rms(): 459 | x = torch.randn(3, 64) 460 | y = 
torch_active_rms(audio=x, frame_length=1, thr=0) 461 | colortool = colorful 462 | colortool.use_style("solarized") 463 | print(colortool.red(f'sc: {y.shape}')) 464 | 465 | 466 | if __name__ == "__main__": 467 | test_torch_activate_rms() 468 | import matplotlib.pyplot as plt 469 | import librosa.display 470 | 471 | wav_path_dir = "/Users/audio_source/GaGNet/First_DNS_no_reverb/" 472 | clean_wav_folder_name = "no_reverb_clean" 473 | mix_wav_folder_name = "no_reverb_mix" 474 | 475 | mix_wav_path_name = os.path.join(wav_path_dir, mix_wav_folder_name) 476 | mix_dataset_path = Path(mix_wav_path_name).expanduser().absolute() 477 | mix_all_lists = lib.util.find_files(mix_dataset_path.as_posix(), 478 | ext=['wav'], 479 | limit=1) 480 | # getting clean signle location 481 | mix_name = os.path.basename(mix_all_lists[0]) 482 | mix_name_split = mix_name.split('_') 483 | clean_name = 'clean_' + mix_name_split[-2] + '_' + mix_name_split[-1] 484 | clean_wav_path_name = os.path.join(wav_path_dir, clean_wav_folder_name, 485 | clean_name) 486 | 487 | mix_sig, _ = lib.load(mix_all_lists[0], sr=16000) 488 | clean_sig, _ = lib.load(clean_wav_path_name, sr=16000) 489 | noise_sig = mix_sig - clean_sig 490 | 491 | # test vad_simplify 492 | vad_simplify(mix_sig, win_len=512, hop_len=128) 493 | 494 | mix_sig_fd = lib.stft(mix_sig, n_fft=256, hop_length=160, win_length=256) 495 | mix_sig_mag_fd, mix_sig_phase_fd = lib.magphase(mix_sig_fd) 496 | clean_sig_fd = lib.stft(clean_sig, 497 | n_fft=256, 498 | hop_length=160, 499 | win_length=256) 500 | clean_sig_mag_fd, clean_sig_phase_fd = lib.magphase(clean_sig_fd) 501 | 502 | noise_sig_fd = lib.stft(noise_sig, 503 | n_fft=256, 504 | hop_length=160, 505 | win_length=256) 506 | noise_sig_mag_fd, noise_sig_phase_fd = lib.magphase(noise_sig_fd) 507 | 508 | iam = clean_sig_mag_fd / mix_sig_mag_fd 509 | irm = clean_sig_mag_fd / (clean_sig_mag_fd + noise_sig_mag_fd) 510 | iam_filter_sig = iam * mix_sig_fd 511 | iam_filter_sig = lib.istft(iam_filter_sig, 
512 | hop_length=160, 513 | win_length=256, 514 | length=len(mix_sig)) 515 | ks_filter_sig = postfiltering(mask=iam) * mix_sig_fd 516 | ks_filter_sig = lib.istft(ks_filter_sig, 517 | hop_length=160, 518 | win_length=256, 519 | length=len(mix_sig)) 520 | g = envelope_postfiltering(unproc=mix_sig_mag_fd, mask=irm) 521 | amazon_filter_sig = g * mix_sig_fd 522 | amazon_filter_sig = lib.istft(amazon_filter_sig, 523 | hop_length=160, 524 | win_length=256, 525 | length=len(mix_sig)) 526 | 527 | sf.write('/Users/audio_source/result/iam_out.wav', 528 | iam_filter_sig, 529 | samplerate=16000) 530 | sf.write('/Users/audio_source/result/ks_filter_sig_out.wav', 531 | ks_filter_sig, 532 | samplerate=16000) 533 | sf.write('/Users/audio_source/result/amazon_filter_sig.wav', 534 | amazon_filter_sig, 535 | samplerate=16000) 536 | 537 | plt.figure(1) 538 | 539 | librosa.display.specshow(lib.amplitude_to_db(clean_sig_fd, ref=np.max), 540 | fmax=8000, 541 | y_axis='linear', 542 | x_axis='time') 543 | plt.title('clean signal') 544 | plt.colorbar(format='%+2.0f dB') 545 | plt.tight_layout() 546 | # plt.show() 547 | 548 | plt.figure(2) 549 | 550 | librosa.display.specshow(lib.amplitude_to_db(mix_sig_fd, ref=np.max), 551 | fmax=8000, 552 | y_axis='linear', 553 | x_axis='time') 554 | plt.title('mix signal') 555 | plt.colorbar(format='%+2.0f dB') 556 | plt.tight_layout() 557 | plt.show() 558 | 559 | indata = np.linspace(0, 1, 1001) 560 | beta = 0.02 561 | y = np.sqrt((1 + beta) * indata / (1 + beta * indata**2)) 562 | y = indata * np.sin(np.pi * 0.5 * indata) 563 | plt.figure() 564 | plt.plot(indata, y, color='r') 565 | plt.show() 566 | print('sc') 567 | -------------------------------------------------------------------------------- /utils/utils_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import yaml 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from .logger import log 9 | 
def touch_dir(d):
    """Create directory ``d`` (and parents) if it does not exist yet."""
    os.makedirs(d, exist_ok=True)


def is_file_exists(f):
    """Return True if path ``f`` exists (file or directory)."""
    return os.path.exists(f)


def check_file_exists(f):
    """Log and raise ``AssertionError`` when path ``f`` does not exist.

    The error is raised explicitly, not via ``assert``, so the check also runs
    under ``python -O``. The exception type is unchanged for callers.
    """
    if not os.path.exists(f):
        log(f"not found file: {f}")
        raise AssertionError(f"not found file: {f}")


def read_lines(data_path):
    """Return the stripped, non-blank lines of a UTF-8 text file."""
    with open(data_path, encoding="utf-8") as fr:
        return [line.strip() for line in fr if line.strip()]


def write_lines(data_path, lines):
    """Write each item of ``lines`` on its own line to a UTF-8 text file."""
    with open(data_path, "w", encoding="utf-8") as fw:
        for line in lines:
            fw.write("{}\n".format(line))
    return


def get_name_from_path(abs_path):
    """Return the basename of ``abs_path`` without its last extension.

    ``"/a/b.tar.gz"`` -> ``"b.tar"``; ``"/a/noext"`` -> ``"noext"``.
    """
    return os.path.splitext(os.path.basename(abs_path))[0]


class AttrDict(dict):
    """``dict`` whose (top-level) keys are also readable as attributes."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
        return


def load_hparams(yaml_path):
    """Load a YAML file with ``yaml.safe_load`` into an ``AttrDict``.

    An empty file yields an empty ``AttrDict``.
    """
    with open(yaml_path, encoding="utf-8") as yaml_file:
        hparams = yaml.safe_load(yaml_file)
    return AttrDict(hparams or {})


def dump_hparams(yaml_path, hparams):
    """Dump ``hparams`` to ``yaml_path``, creating its parent directory if needed."""
    parent_dir = os.path.dirname(yaml_path)
    if parent_dir:  # a bare filename has no directory to create
        touch_dir(parent_dir)
    with open(yaml_path, "w", encoding="utf-8") as fw:
        yaml.dump(hparams, fw)
    log("save hparams to {}".format(yaml_path))
    return


def get_all_wav_path(file_dir):
    """Recursively collect ``.wav`` / ``.WAV`` files under ``file_dir``, sorted."""
    wav_list = []
    for path, _dir_list, file_list in os.walk(file_dir):
        for file_name in file_list:
            if file_name.endswith((".wav", ".WAV")):
                wav_list.append(os.path.join(path, file_name))
    return sorted(wav_list)


def clean_and_new_dir(data_dir):
    """Delete ``data_dir`` if present, then recreate it empty."""
    if os.path.exists(data_dir):
        shutil.rmtree(data_dir)
    os.makedirs(data_dir)
    return


def generate_dir_tree(synth_dir, dir_name_list, del_old=False):
    """Create ``synth_dir`` and one sub-directory per name in ``dir_name_list``.

    If ``del_old`` is true, the existing ``synth_dir`` is removed first. The
    root directory always exists afterwards, even for an empty name list.

    Returns:
        The list of created sub-directory paths, in input order.
    """
    if del_old:
        shutil.rmtree(synth_dir, ignore_errors=True)
    os.makedirs(synth_dir, exist_ok=True)
    dir_path_list = []
    for name in dir_name_list:
        dir_path = os.path.join(synth_dir, name)
        dir_path_list.append(dir_path)
        os.makedirs(dir_path, exist_ok=True)
    return dir_path_list


def str2bool(v):
    """Parse a yes/no style string (for argparse ``type=``) into a bool.

    Bools are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a recognised value.
    """
    import argparse

    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def pad(input_ele, mel_max_length=None):
    """Zero-pad 1-D or 2-D tensors along dim 0 and stack them.

    Args:
        input_ele: sequence of tensors of shape ``(L,)`` or ``(L, D)``.
        mel_max_length: target length. Defaults to the longest ``L``.

    Raises:
        ValueError: for tensors with more than 2 dimensions.
    """
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(ele.size(0) for ele in input_ele)

    out_list = []
    for batch in input_ele:
        if batch.dim() == 1:
            one_batch_padded = F.pad(
                batch, (0, max_len - batch.size(0)), "constant", 0.0
            )
        elif batch.dim() == 2:
            # F.pad pads the last dim first: (0, 0) for D, then dim 0.
            one_batch_padded = F.pad(
                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
            )
        else:
            raise ValueError(
                f"pad expects 1-D or 2-D tensors, got shape {tuple(batch.shape)}")
        out_list.append(one_batch_padded)
    return torch.stack(out_list)


def pad_1D(inputs, PAD=0):
    """Right-pad 1-D arrays with ``PAD`` to the longest length and stack them."""
    def pad_data(x, length, PAD):
        return np.pad(
            x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
        )

    max_len = max(len(x) for x in inputs)
    return np.stack([pad_data(x, max_len, PAD) for x in inputs])


def pad_2D(inputs, maxlen=None):
    """Zero-pad 2-D arrays along axis 0 to a common length and stack them.

    Args:
        inputs: sequence of arrays of shape ``(L, D)``.
        maxlen: target length. When falsy, the longest ``L`` is used.

    Raises:
        ValueError: if an input is longer than the target length.
    """
    def pad(x, max_len):
        PAD = 0
        if np.shape(x)[0] > max_len:
            raise ValueError("not max_len")
        pad_width = [(0, max_len - np.shape(x)[0])] + [(0, 0)] * (np.ndim(x) - 1)
        return np.pad(x, pad_width, mode="constant", constant_values=PAD)

    max_len = maxlen if maxlen else max(np.shape(x)[0] for x in inputs)
    return np.stack([pad(x, max_len) for x in inputs])


def get_mask_from_lengths(lengths, max_len=None):
    """Return a ``(B, max_len)`` bool mask that is True at padded positions.

    Position ``j`` of row ``i`` is True when ``j >= lengths[i]``.
    """
    if max_len is None:
        max_len = torch.max(lengths).item()

    ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0)
    return ids >= lengths.unsqueeze(1)


if __name__ == '__main__':
    pass
# --------------------------------------------------------------------------------