├── LICENSE ├── README.md ├── __pycache__ └── solver.cpython-37.pyc ├── configs └── train_config.toml ├── dataset └── data.py ├── main.py ├── nets └── g2net.py ├── solver.py └── utils ├── loss.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 AndongLi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # G2Net 2 | The implementation of G2Net, which is the extension of GaGNet and is in submission to T-ASLP. 
3 | -------------------------------------------------------------------------------- /__pycache__/solver.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Andong-Li-speech/G2Net/8c0de9c0d834ade59b984c0328bd8474125bb5c4/__pycache__/solver.cpython-37.pyc -------------------------------------------------------------------------------- /configs/train_config.toml: -------------------------------------------------------------------------------- 1 | # for loading and saving paths 2 | [path] 3 | data_type = "DNS-Challenge_3000h" 4 | is_checkpoint = true 5 | is_resume_reload = false 6 | checkpoint_load_path = "CheckpointPath" 7 | checkpoint_load_filename = "" 8 | loss_save_path = "Loss" 9 | model_best_path = "BestModel" 10 | logging_path = "Logger" 11 | 12 | 13 | [path.train] 14 | mix_file_path = "/media/liandong/CTS_TASLP_for_dns_3000h_dataset/train/mix" 15 | target_file_path = "/media/liandong/CTS_TASLP_for_dns_3000h_dataset/train/clean" 16 | 17 | [path.val] 18 | mix_file_path = "/media/liandong/CTS_TASLP_for_dns_3000h_dataset/dev/mix" 19 | target_file_path = "/media/liandong/CTS_TASLP_for_dns_3000h_dataset/dev/clean" 20 | 21 | 22 | [gpu] 23 | gpu_ids = [1] 24 | # signal settings before sending into the network 25 | [signal] 26 | sr = 16000 27 | is_chunk = true 28 | chunk_length = 8.0 29 | win_size = 0.02 30 | win_shift = 0.01 31 | fft_num = 320 32 | is_variance_norm = true 33 | is_compress = true 34 | 35 | 36 | # chosen loss function 37 | [loss_function] 38 | path = "utils.loss" 39 | prev_weight = 0.1 40 | curr_weight = 1.0 41 | alpha = 0.5 42 | l_type = "L2" 43 | [loss_function.stagewise] 44 | classname = "StagewiseComMagEuclideanLoss" 45 | 46 | 47 | # chosen optimizer 48 | [optimizer] 49 | name = "adam" 50 | lr = 2e-4 51 | beta1 = 0.9 52 | beta2 = 0.999 53 | l2 = 1e-7 54 | gradient_norm = 5.0 55 | epochs = 60 56 | halve_lr = true 57 | early_stop = true 58 | halve_freq = 2 59 |
early_stop_freq = 3 60 | print_freq = 200 61 | metric_options = ["SISNR"] # only one metric is supported in the current version, choices: [NB-PESQ, ESTOI] 62 | 63 | # reproducibility settings 64 | [reproducibility] 65 | seed = 1234 66 | 67 | # Dataset 68 | [dataset] 69 | [dataset.train] 70 | json_path = "/media/liandong/CTS_TASLP_for_dns_3000h_dataset/Json/train" 71 | batch_size = 24 72 | is_shuffle = true 73 | 74 | [dataset.val] 75 | json_path = "/media/liandong/CTS_TASLP_for_dns_3000h_dataset/Json/dev" 76 | batch_size = 24 77 | is_shuffle = true 78 | 79 | [dataloader] 80 | [dataloader.train] 81 | num_workers = 6 82 | pin_memory = true 83 | drop_last = false 84 | shuffle = false 85 | 86 | [dataloader.val] 87 | num_workers = 6 88 | pin_memory = true 89 | drop_last = false 90 | shuffle = false 91 | 92 | # network configs 93 | [net] 94 | choice="G2Net" 95 | path = "nets.g2net" 96 | classname = "G2Net" 97 | 98 | [net.G2Net.args] 99 | k1 = [2,3] 100 | k2 = [1,3] 101 | c = 64 102 | intra_connect = "cat" 103 | d_feat = 256 104 | kd1 = 3 105 | cd1 = 64 106 | tcn_num = 2 107 | dilas = [1,2,5,9] 108 | fft_num = 320 109 | is_causal = true 110 | acti_type = "sigmoid" 111 | crm_type = "crm1" 112 | stage_num = 3 113 | u_type = "u2" 114 | head_type = "RI+MAG" 115 | norm_type = "IN" 116 | 117 | 118 | [save] 119 | loss_filename = "DNS-Challenge_3000h_g2net_IN_causal_loss.mat" 120 | best_model_filename = "DNS-Challenge_3000h_g2net_IN_causal_model.pth" 121 | checkpoint_filename = "DNS-Challenge_3000h_g2net_IN_causal_model.pth.tar" 122 | logger_filename = "DNS-Challenge_3000h_gagnet_IN_causal.txt" 123 | #tensorboard_filename = "librispeech_taylorbeamformer_mic_linear_mid_target_timvdr_order0_param_nonshared_bf_embedding64_hidnode_64_u2_risqueezed_norm2d_BN_norm1d_BN_causal" 124 | 125 | -------------------------------------------------------------------------------- /dataset/data.py: -------------------------------------------------------------------------------- 1 | import json 2 |
import json
import os
import random
from random import shuffle

import numpy as np
import soundfile as sf
import librosa as lib
from torch.utils.data import Dataset, DataLoader

from utils.utils import BatchInfo, pad_to_longest, logger_print, ToTensor


class InstanceDataset(Dataset):
    """Dataset whose items are whole mini-batches of (mixture, target) waveforms.

    Mixture names are loaded from a JSON list, sorted (and optionally shuffled
    with a fixed seed), then grouped into chunks of ``batch_size`` names.  One
    ``__getitem__`` call reads one such chunk from disk, which is why
    ``InstanceDataloader`` wraps this dataset with ``batch_size=1``.
    """

    def __init__(self,
                 mix_file_path,
                 target_file_path,
                 mix_json_path,
                 batch_size,
                 is_shuffle,
                 is_variance_norm,
                 is_chunk,
                 chunk_length,
                 sr,
                 ):
        """
        :param mix_file_path: directory holding the mixture ``<name>.wav`` files
        :param target_file_path: directory holding the ``clean_fileid_<n>.wav`` files
        :param mix_json_path: path of a JSON file holding the list of mixture names
        :param batch_size: number of utterances per dataset item; must be positive
        :param is_shuffle: shuffle the sorted name list once with seed 1234
        :param is_variance_norm: rescale mixture and target so the mixture has unit mean power
        :param is_chunk: randomly crop utterances longer than ``chunk_length`` seconds
        :param chunk_length: crop length in seconds
        :param sr: target sampling rate; files at other rates are resampled
        :raises ValueError: if ``batch_size`` is not positive
        """
        super(InstanceDataset, self).__init__()
        if batch_size <= 0:
            # the previous while-loop never terminated for a non-positive batch size
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.mix_file_path = mix_file_path
        self.target_file_path = target_file_path
        self.mix_json_path = mix_json_path
        self.batch_size = batch_size
        self.is_shuffle = is_shuffle
        self.is_variance_norm = is_variance_norm
        self.is_chunk = is_chunk
        self.chunk_length = chunk_length
        self.sr = sr

        with open(mix_json_path, "r") as f:
            mix_json_list = json.load(f)
        # sort first so the (optional) shuffle is independent of the JSON order
        mix_json_list.sort()
        if is_shuffle:
            # NOTE: this reseeds the global ``random`` module (fixed for reproducibility)
            random.seed(1234)
            shuffle(mix_json_list)

        # consecutive groups of batch_size names; the last group may be shorter.
        # An empty name list now yields no minibatches instead of one empty one.
        self.mix_minibatch = [mix_json_list[start:start + batch_size]
                              for start in range(0, len(mix_json_list), batch_size)]
        self.length = len(self.mix_minibatch)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Load every utterance of the ``index``-th minibatch.

        :return: (mix_wav_list, target_wav_list, wav_len_list): per-utterance
                 outputs of ``ToTensor`` and the sample length of each utterance
        """
        to_tensor = ToTensor()
        chunk_samples = int(self.sr * self.chunk_length)
        mix_wav_list, target_wav_list, wav_len_list = [], [], []
        for mix_filename in self.mix_minibatch[index]:
            # the clean file shares the trailing "_<number>" field of the mixture name
            file_number = mix_filename.split("_")[-1]
            target_filename = f"clean_fileid_{file_number}"
            # read speech
            mix_wav, mix_sr = sf.read(os.path.join(self.mix_file_path, f"{mix_filename}.wav"))
            target_wav, tar_sr = sf.read(os.path.join(self.target_file_path, f"{target_filename}.wav"))
            if mix_sr != self.sr or tar_sr != self.sr:
                # librosa >= 0.10 accepts the sampling rates only as keyword arguments
                mix_wav = lib.resample(mix_wav, orig_sr=mix_sr, target_sr=self.sr)
                target_wav = lib.resample(target_wav, orig_sr=tar_sr, target_sr=self.sr)
            if self.is_variance_norm:
                energy = np.sum(mix_wav ** 2.0)
                # an all-zero mixture would otherwise produce inf/NaN samples
                if energy > 0:
                    c = np.sqrt(len(mix_wav) / energy)
                    mix_wav, target_wav = mix_wav * c, target_wav * c
            if self.is_chunk and len(mix_wav) > chunk_samples:
                wav_start = random.randint(0, len(mix_wav) - chunk_samples)
                mix_wav = mix_wav[wav_start:wav_start + chunk_samples]
                target_wav = target_wav[wav_start:wav_start + chunk_samples]
            mix_wav_list.append(to_tensor(mix_wav))
            target_wav_list.append(to_tensor(target_wav))
            wav_len_list.append(len(mix_wav))
        return mix_wav_list, target_wav_list, wav_len_list

    @staticmethod
    def check_align(mix_list, target_list):
        """Check that mixture and target names pair up element by element.

        Two names match when they agree after dropping their last "_"-separated field.

        :raises Exception: if the lists differ in length or any pair does not match
        """
        logger_print("checking.................")
        if len(mix_list) != len(target_list):
            # previously an IndexError (longer mix list) or a silent pass (longer target list)
            print(f"mix list has {len(mix_list)} entries, target list has {len(target_list)}")
            raise Exception("Datasets between mix and target are not aligned!")
        mismatches = []
        for mix_name, target_name in zip(mix_list, target_list):
            mix_key = "_".join(mix_name.split("_")[:-1])
            target_key = "_".join(target_name.split("_")[:-1])
            if mix_key != target_key:
                mismatches.append((mix_key, target_key))
        if mismatches:
            for mix_key, target_key in mismatches:
                print("mix_file_name:{}, target_file_name:{}".format(mix_key, target_key))
            raise Exception("Datasets between mix and target are not aligned!")
        logger_print("checking finished..............")


class InstanceDataloader(object):
    """Wrap ``InstanceDataset`` in a ``DataLoader``.

    Each dataset item is already a full minibatch, so the loader uses
    ``batch_size=1`` and ``collate_fn`` post-processes that single item.
    """

    def __init__(self,
                 data_set,
                 num_workers,
                 pin_memory,
                 drop_last,
                 shuffle,
                 ):
        self.data_set = data_set
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.shuffle = shuffle

        self.data_loader = DataLoader(dataset=data_set,
                                      num_workers=num_workers,
                                      pin_memory=pin_memory,
                                      drop_last=drop_last,
                                      shuffle=shuffle,
                                      collate_fn=self.collate_fn,
                                      batch_size=1
                                      )

    @staticmethod
    def collate_fn(batch):
        """Pad the batch with ``pad_to_longest`` and wrap the result in ``BatchInfo``."""
        feats, labels, frame_mask_list = pad_to_longest(batch)
        return BatchInfo(feats, labels, frame_mask_list)

    def get_data_loader(self):
        return self.data_loader

# -------------------------------------------------------------------------------- /main.py:

import os
import torch
import importlib
import logging
import numpy as np
import random
import time
import argparse
from dataset.data import InstanceDataset, InstanceDataloader
from solver import Solver
import warnings
import toml
from utils.utils import json_extraction, numParams, logger_print
warnings.filterwarnings('ignore')


# fix random seed
def setup_seed(seed):
    """Seed torch (CPU and all CUDA devices), numpy and Python's ``random``.

    :param seed: integer seed
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # torch.backends.cudnn.deterministic = True


def _build_dataset(config, split, mix_json):
    """Create the InstanceDataset for ``split`` ("train" or "val") from ``config``."""
    return InstanceDataset(mix_file_path=config["path"][split]["mix_file_path"],
                           target_file_path=config["path"][split]["target_file_path"],
                           mix_json_path=mix_json,
                           is_variance_norm=config["signal"]["is_variance_norm"],
                           is_chunk=config["signal"]["is_chunk"],
                           chunk_length=config["signal"]["chunk_length"],
                           sr=config["signal"]["sr"],
                           batch_size=config["dataset"][split]["batch_size"],
                           is_shuffle=config["dataset"][split]["is_shuffle"])


def main(config):
    """Build logger, network, data pipeline and optimizer from ``config``, then train.

    :param config: parsed TOML configuration dictionary
    :raises ValueError: if ``config["optimizer"]["name"]`` is not supported
    """
    # define seeds
    setup_seed(config["reproducibility"]["seed"])

    # set logger
    logging_path = config["path"]["logging_path"]
    os.makedirs(logging_path, exist_ok=True)
    logging.basicConfig(filename=os.path.join(logging_path, config["save"]["logger_filename"]),
                        filemode='w',
                        level=logging.INFO,
                        format="%(message)s")
    start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    logger_print(f"start logging time:\t{start_time}")

    # set gpus
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_id) for gpu_id in config["gpu"]["gpu_ids"])
    logger_print(f"gpus: {os.environ['CUDA_VISIBLE_DEVICES']}")

    # set network: the class is looked up dynamically from config["net"]
    net_choice = config["net"]["choice"]
    module = importlib.import_module(config["net"]["path"])
    net_args = config["net"][net_choice]["args"]
    net = getattr(module, config["net"]["classname"])(**net_args)
    logger_print(f"The number of trainable parameters: {numParams(net)}")

    # paths generation
    for path_key in ("checkpoint_load_path", "loss_save_path", "model_best_path"):
        os.makedirs(config["path"][path_key], exist_ok=True)

    # save filename
    save_name_dict = {name_key: config["save"][name_key]
                      for name_key in ("loss_filename", "best_model_filename", "checkpoint_filename")}

    # determine file json
    train_mix_json = json_extraction(config["path"]["train"]["mix_file_path"], config["dataset"]["train"]["json_path"], "mix")
    val_mix_json = json_extraction(config["path"]["val"]["mix_file_path"], config["dataset"]["val"]["json_path"], "mix")

    # define train/validation
    train_dataset = _build_dataset(config, "train", train_mix_json)
    val_dataset = _build_dataset(config, "val", val_mix_json)
    train_dataloader = InstanceDataloader(train_dataset,
                                          **config["dataloader"]["train"])
    val_dataloader = InstanceDataloader(val_dataset,
                                        **config["dataloader"]["val"])

    # define optimizer
    optimizer_name = config["optimizer"]["name"]
    if optimizer_name == "adam":
        optimizer = torch.optim.Adam(
            net.parameters(),
            lr=config["optimizer"]["lr"],
            betas=(config["optimizer"]["beta1"], config["optimizer"]["beta2"]),
            weight_decay=config["optimizer"]["l2"])
    else:
        # previously fell through and failed later with an unbound ``optimizer``
        raise ValueError(f"Unsupported optimizer: {optimizer_name!r} (only 'adam' is implemented)")

    data = {'train_loader': train_dataloader, 'val_loader': val_dataloader}
    solver = Solver(data, net, optimizer, save_name_dict, config)
    solver.train()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--C", "--config", type=str, required=False, default="configs/train_config.toml",
                        help="toml format")
    args = parser.parse_args()
    config = toml.load(args.C)
    print(config)
    main(config)

# -------------------------------------------------------------------------------- /nets/g2net.py:

import torch
import torch.nn as nn
import numpy as np
from typing import Sequence
from torch.autograd import Variable
from utils.utils import NormSwitch

# head types accepted by G2Net, GlanceBranch and GazeBranch
_HEAD_TYPES = ("RI", "MAG", "RI+MAG", "PHASE+MAG")


class G2Net(nn.Module):
    """Multi-stage glance-and-gaze network operating on a two-channel (real/imag) spectrum.

    One or two 2-D encoders (``u_type`` selects U2Net_Encoder "u2" or UNet_Encoder
    "u"; ``head_type`` selects which views of the input are encoded) produce a
    (B, C, T) feature sequence that ``stage_num`` cascaded GGModules use to refine
    the spectrum estimate stage by stage.
    """

    def __init__(self,
                 k1: Sequence[int] = (2, 3),
                 k2: Sequence[int] = (1, 3),
                 c: int = 64,
                 intra_connect: str = "cat",
                 d_feat: int = 256,
                 kd1: int = 3,
                 cd1: int = 64,
                 tcn_num: int = 2,
                 dilas: Sequence[int] = (1, 2, 5, 9),
                 fft_num: int = 320,
                 is_causal: bool = True,
                 acti_type: str = "sigmoid",
                 crm_type: str = "crm1",
                 stage_num: int = 3,
                 u_type: str = "u2",
                 head_type: str = "RI+MAG",
                 norm_type: str = "IN",  # switch to cLN leads to mild performance degradation but is still Ok. BN is the worst among the listed norm options.
                 ):
        """
        :raises ValueError: for an unknown ``u_type`` or ``head_type`` (previously the
            encoders were silently left undefined and ``forward`` failed later)
        """
        super(G2Net, self).__init__()
        if u_type not in ("u2", "u"):
            raise ValueError(f"Unsupported u_type: {u_type!r}, choose from ('u2', 'u')")
        if head_type not in _HEAD_TYPES:
            raise ValueError(f"Unsupported head_type: {head_type!r}, choose from {_HEAD_TYPES}")
        self.k1 = tuple(k1)
        self.k2 = tuple(k2)
        self.c = c
        self.intra_connect = intra_connect
        self.d_feat = d_feat
        self.kd1 = kd1
        self.cd1 = cd1
        self.tcn_num = tcn_num
        self.dilas = list(dilas)
        self.fft_num = fft_num
        self.is_causal = is_causal
        self.acti_type = acti_type
        self.crm_type = crm_type
        self.stage_num = stage_num
        self.u_type = u_type
        self.head_type = head_type
        self.norm_type = norm_type
        stride = (1, 2)

        def make_encoder(cin):
            if u_type == "u2":
                return U2Net_Encoder(cin, self.k1, self.k2, stride, c, intra_connect, norm_type)
            return UNet_Encoder(cin, self.k1, stride, c, norm_type)

        # construction order is kept fixed so seeded weight initialisation is reproducible
        if head_type == "RI":
            self.ri_en = make_encoder(2)
        elif head_type == "MAG":
            self.mag_en = make_encoder(1)
        elif head_type == "RI+MAG":
            self.ri_en = make_encoder(2)
            self.mag_en = make_encoder(1)
        else:  # "PHASE+MAG"
            self.phase_en = make_encoder(1)
            # was misspelled ``mag_rn`` for u_type="u", which broke forward()
            self.mag_en = make_encoder(1)

        self.ggms = nn.ModuleList([GGModule(d_feat,
                                            kd1,
                                            cd1,
                                            tcn_num,
                                            self.dilas,
                                            fft_num,
                                            is_causal,
                                            acti_type,
                                            crm_type,
                                            head_type,
                                            norm_type,
                                            ) for _ in range(stage_num)])

    @staticmethod
    def _flatten(feat):
        """(B, C, T, F') -> (B, C*F', T): merge channel and frequency axes per frame."""
        b_size, _, seq_len, _ = feat.shape
        return feat.transpose(-2, -1).contiguous().view(b_size, -1, seq_len)

    def forward(self, inpt):
        """
        :param inpt: (B,2,T,F)
        :return: list, [est1, est2, ..., estQ], estq: (B,2,F,T)
        """
        if self.head_type == "MAG":
            x = self._flatten(self.mag_en(torch.norm(inpt, dim=1, keepdim=True)))
        elif self.head_type == "RI":
            x = self._flatten(self.ri_en(inpt))
        elif self.head_type == "RI+MAG":
            mag_x = self.mag_en(torch.norm(inpt, dim=1, keepdim=True))
            ri_x = self.ri_en(inpt)
            x = torch.cat((self._flatten(ri_x), self._flatten(mag_x)), dim=1)
        elif self.head_type == "PHASE+MAG":
            inpt_phase = torch.atan2(inpt[:, -1, ...], inpt[:, 0, ...]).unsqueeze(dim=1)
            phase_x = self.phase_en(inpt_phase)
            mag_x = self.mag_en(torch.norm(inpt, dim=1, keepdim=True))
            x = torch.cat((self._flatten(phase_x), self._flatten(mag_x)), dim=1)
        else:
            raise ValueError(f"Unsupported head_type: {self.head_type!r}")
        out_list = []
        pre_x = inpt.transpose(-2, -1).contiguous()
        for ggm in self.ggms:
            pre_x = ggm(x, pre_x)
            out_list.append(pre_x)
        return out_list


class GGModule(nn.Module):
    """One glance-and-gaze stage: a magnitude gain (glance) plus a complex residual (gaze)."""

    def __init__(self,
                 d_feat: int,
                 kd1: int,
                 cd1: int,
                 tcn_num: int,
                 dilas: list,
                 fft_num: int,
                 is_causal: bool,
                 acti_type: str,
                 crm_type: str,
                 head_type: str,
                 norm_type: str,
                 ):
        """
        :raises ValueError: for an unknown ``crm_type`` (previously forward silently
            returned its input feature tensor)
        """
        super(GGModule, self).__init__()
        if crm_type not in ("crm1", "crm2"):
            raise ValueError(f"Unsupported crm_type: {crm_type!r}, choose from ('crm1', 'crm2')")
        self.d_feat = d_feat
        self.kd1 = kd1
        self.cd1 = cd1
        self.tcn_num = tcn_num
        self.dilas = dilas
        self.fft_num = fft_num
        self.is_causal = is_causal
        self.acti_type = acti_type
        self.crm_type = crm_type
        self.head_type = head_type
        self.norm_type = norm_type

        # Components
        self.glance_branch = GlanceBranch(d_feat, kd1, cd1, tcn_num, dilas, fft_num, is_causal, acti_type, head_type, norm_type)
        self.gaze_branch = GazeBranch(d_feat, kd1, cd1, tcn_num, dilas, fft_num, is_causal, head_type, norm_type)

    def forward(self, x, pre_x):
        """
        :param x: (B,C1,T)
        :param pre_x: (B,2,C,T)
        :return: (B,2,C,T)
        """
        batch_num, _, c, seq_len = pre_x.size()
        pre_mag, pre_phase = torch.norm(pre_x, dim=1), torch.atan2(pre_x[:, -1, ...], pre_x[:, 0, ...])
        pre_com = pre_x.view(batch_num, -1, seq_len)

        gain_filter = self.glance_branch(x, pre_mag)  # (B, C, T)
        com_resi = self.gaze_branch(x, pre_com)  # (B, 2, C, T)
        x_mag = pre_mag * gain_filter
        if self.crm_type == "crm1":  # crm1 yields better performance
            x_r, x_i = x_mag * torch.cos(pre_phase), x_mag * torch.sin(pre_phase)
            return torch.stack((x_r, x_i), 1) + com_resi
        if self.crm_type == "crm2":
            resi_phase = torch.atan2(com_resi[:, -1, ...], com_resi[:, 0, ...])
            resi_mag = torch.norm(com_resi, dim=1)
            x_mag = x_mag + resi_mag
            x_phase = pre_phase + resi_phase
            x_r, x_i = x_mag * torch.cos(x_phase), x_mag * torch.sin(x_phase)
            return torch.stack((x_r, x_i), 1)
        raise ValueError(f"Unsupported crm_type: {self.crm_type!r}")


class GlanceBranch(nn.Module):
    """Estimate a magnitude gain from the encoder feature and the previous magnitude."""

    def __init__(self,
                 d_feat: int,
                 kd1: int,
                 cd1: int,
                 tcn_num: int,
                 dilas: list,
                 fft_num: int,
                 is_causal: bool,
                 acti_type: str,
                 head_type: str,
                 norm_type: str,
                 ):
        """
        :raises Exception: for an unknown ``head_type``
        :raises ValueError: for an unknown ``acti_type`` (previously ``self.acti`` was
            left undefined and forward failed)
        """
        super(GlanceBranch, self).__init__()
        self.d_feat = d_feat
        self.kd1 = kd1
        self.cd1 = cd1
        self.tcn_num = tcn_num
        self.dilas = dilas
        self.fft_num = fft_num
        self.is_causal = is_causal
        self.acti_type = acti_type
        self.head_type = head_type
        self.norm_type = norm_type

        # Components
        if head_type in ("RI", "MAG"):
            cin = (fft_num // 2 + 1) + d_feat
        elif head_type in ("RI+MAG", "PHASE+MAG"):
            cin = (fft_num // 2 + 1) + d_feat * 2
        else:
            raise Exception("Only RI, MAG, RI+MAG and PHASE+MAG are supported at present")
        self.in_conv = nn.Conv1d(cin, d_feat, 1)
        self.tcn_list = nn.ModuleList([SqueezedTCNList(d_feat, kd1, cd1, norm_type, dilas, is_causal)
                                       for _ in range(tcn_num)])
        self.linear_mag = nn.Conv1d(d_feat, fft_num // 2 + 1, 1)
        activations = {"relu": nn.ReLU, "sigmoid": nn.Sigmoid, "tanh": nn.Tanh}
        if acti_type not in activations:
            raise ValueError(f"Unsupported acti_type: {acti_type!r}, choose from {tuple(activations)}")
        self.acti = activations[acti_type]()

    def forward(self, x, mag_x):
        """
        :param x: (B, C1, T)
        :param mag_x: (B, C2, T)
        :return: (B, C2, T)
        """
        x = self.in_conv(torch.cat((x, mag_x), dim=1))
        # sum of the outputs of every TCN group
        acc_x = torch.zeros_like(x)
        for tcn in self.tcn_list:
            x = tcn(x)
            acc_x = acc_x + x
        return self.acti(self.linear_mag(acc_x))


class GazeBranch(nn.Module):
    """Estimate a complex (real/imag) residual from the encoder feature and the previous spectrum."""

    def __init__(self,
                 d_feat: int,
                 kd1: int,
                 cd1: int,
                 tcn_num: int,
                 dilas: list,
                 fft_num: int,
                 is_causal: bool,
                 head_type: str,
                 norm_type: str,
                 ):
        """
        :raises Exception: for an unknown ``head_type``
        """
        super(GazeBranch, self).__init__()
        self.d_feat = d_feat
        self.kd1 = kd1
        self.cd1 = cd1
        self.tcn_num = tcn_num
        self.dilas = dilas
        self.fft_num = fft_num
        self.is_causal = is_causal
        self.head_type = head_type
        self.norm_type = norm_type

        # Components
        if head_type in ("RI", "MAG"):
            cin = (fft_num // 2 + 1) * 2 + d_feat
        elif head_type in ("RI+MAG", "PHASE+MAG"):
            cin = (fft_num // 2 + 1) * 2 + d_feat * 2
        else:
            raise Exception("Only RI, MAG, RI+MAG and PHASE+MAG are supported at present")
        self.in_conv_r = nn.Conv1d(cin, d_feat, 1)
        self.in_conv_i = nn.Conv1d(cin, d_feat, 1)
        tcn_list_r, tcn_list_i = [], []
        # real/imag modules are created interleaved, as before, to keep seeded init unchanged
        for _ in range(tcn_num):
            tcn_list_r.append(SqueezedTCNList(d_feat, kd1, cd1, norm_type, dilas, is_causal))
            tcn_list_i.append(SqueezedTCNList(d_feat, kd1, cd1, norm_type, dilas, is_causal))

        self.tcn_r = nn.ModuleList(tcn_list_r)
        self.tcn_i = nn.ModuleList(tcn_list_i)
        self.linear_r, self.linear_i = nn.Linear(d_feat, fft_num // 2 + 1), nn.Linear(d_feat, fft_num // 2 + 1)

    def forward(self, x, com_x):
        """
        x: the abstract feature from the branches, C1 = 256*2
        com_x: the flatten feature from the previous stage
        :param x: (B, C1, T)
        :param com_x: (B, C2, T)
        :return: (B,2,C,T)
        """
        x = torch.cat((x, com_x), dim=1)
        x_r, x_i = self.in_conv_r(x), self.in_conv_i(x)
        acc_r, acc_i = torch.zeros_like(x_r), torch.zeros_like(x_i)
        for tcn_r, tcn_i in zip(self.tcn_r, self.tcn_i):
            x_r, x_i = tcn_r(x_r), tcn_i(x_i)
            acc_r = acc_r + x_r
            acc_i = acc_i + x_i
        # nn.Linear acts on the last axis, so move features last: (B,T,d_feat)
        out_r = self.linear_r(acc_r.transpose(-2, -1)).transpose(-2, -1)
        out_i = self.linear_i(acc_i.transpose(-2, -1)).transpose(-2, -1)
        return torch.stack((out_r, out_i), dim=1).contiguous()


class SqueezedTCNList(nn.Module):
    """A chain of SqueezedTCM blocks, one per dilation rate in ``dilas``."""

    def __init__(self,
                 d_feat: int,
                 kd1: int,
                 cd1: int,
                 norm_type: str,
                 dilas: Sequence[int] = (1, 2, 5, 9),
                 is_causal: bool = True):
        super(SqueezedTCNList, self).__init__()
        self.d_feat = d_feat
        self.kd1 = kd1
        self.cd1 = cd1
        self.norm_type = norm_type
        self.dilas = list(dilas)
        self.is_causal = is_causal
        self.tcm_list = nn.ModuleList([SqueezedTCM(d_feat, kd1, cd1, dila, is_causal, norm_type) for dila in self.dilas])

    def forward(self, x):
        for tcm in self.tcm_list:
            x = tcm(x)
        return x


class SqueezedTCM(nn.Module):
    """Residual gated dilated 1-D conv block with a channel bottleneck (d_feat -> cd1 -> d_feat)."""

    def __init__(self,
                 d_feat: int,
                 kd1: int,
                 cd1: int,
                 dilation: int,
                 is_causal: bool,
                 norm_type: str,
                 ):
        super(SqueezedTCM, self).__init__()
        self.d_feat = d_feat
        self.kd1 = kd1
        self.cd1 = cd1
        self.dilation = dilation
        self.is_causal = is_causal
        self.norm_type = norm_type
        total_pad = (kd1 - 1) * dilation
        if is_causal:
            pad = nn.ConstantPad1d((total_pad, 0), value=0.)
        else:
            # an odd total is split unevenly so the output length always equals the
            # input length (the residual addition requires it)
            pad = nn.ConstantPad1d((total_pad // 2, total_pad - total_pad // 2), value=0.)
        self.in_conv = nn.Conv1d(d_feat, cd1, kernel_size=1, bias=False)
        self.dd_conv_main = nn.Sequential(
            nn.PReLU(cd1),
            NormSwitch(norm_type, "1D", cd1),
            pad,
            nn.Conv1d(cd1, cd1, kernel_size=kd1, dilation=dilation, bias=False))
        self.dd_conv_gate = nn.Sequential(
            nn.PReLU(cd1),
            NormSwitch(norm_type, "1D", cd1),
            pad,
            nn.Conv1d(cd1, cd1, kernel_size=kd1, dilation=dilation, bias=False),
            nn.Sigmoid()
        )
        self.out_conv = nn.Sequential(
            nn.PReLU(cd1),
            NormSwitch(norm_type, "1D", cd1),
            nn.Conv1d(cd1, d_feat, kernel_size=1, bias=False)
        )

    def forward(self, x):
        resi = x
        x = self.in_conv(x)
        x = self.dd_conv_main(x) * self.dd_conv_gate(x)
        x = self.out_conv(x)
        return x + resi


class U2Net_Encoder(nn.Module):
    """Four nested U-Net encoder modules (scales 4..1) followed by a gated output conv."""

    def __init__(self,
                 cin: int,
                 k1: tuple,
                 k2: tuple,
                 stride: tuple,
                 c: int,
                 intra_connect: str,
                 norm_type: str,
                 ):
        super(U2Net_Encoder, self).__init__()
        self.cin = cin
        self.k1 = k1
        self.k2 = k2
        self.stride = stride
        self.c = c
        self.intra_connect = intra_connect
        self.norm_type = norm_type
        k_begin = (2, 5)
        c_end = 64

        self.meta_unet_list = nn.ModuleList([
            En_unet_module(cin, k_begin, k2, stride, c, intra_connect, norm_type, scale=4, de_flag=False, is_first=True),
            En_unet_module(cin, k1, k2, stride, c, intra_connect, norm_type, scale=3, de_flag=False),
            En_unet_module(cin, k1, k2, stride, c, intra_connect, norm_type, scale=2, de_flag=False),
            En_unet_module(cin, k1, k2, stride, c, intra_connect, norm_type, scale=1, de_flag=False),
        ])
        self.last_conv = nn.Sequential(
            Gate2dconv(c, c_end, k1, stride, de_flag=False, pad=(0, 0, k1[0] - 1, 0)),
            NormSwitch(norm_type, "2D", c_end),
            nn.PReLU(c_end)
        )

    def forward(self, x):
        for meta_unet in self.meta_unet_list:
            x = meta_unet(x)
        return self.last_conv(x)


class UNet_Encoder(nn.Module):
    """Plain encoder: five gated 2-D conv layers, each followed by norm and PReLU."""

    def __init__(self,
                 cin: int,
                 k1: tuple,
                 stride: tuple,
                 c: int,
                 norm_type: str,
                 ):
        super(UNet_Encoder, self).__init__()
        self.cin = cin
        self.k1, self.c = k1, c
        self.stride = stride
        self.norm_type = norm_type
        k_begin = (2, 5)
        c_end = 64  # 64 by default
        unet = []
        unet.append(nn.Sequential(
            Gate2dconv(cin, c, k_begin, stride, de_flag=False, pad=(0, 0, k_begin[0] - 1, 0)),
            NormSwitch(norm_type, "2D", c),
            nn.PReLU(c)))
        for _ in range(3):
            unet.append(nn.Sequential(
                Gate2dconv(c, c, k1, stride, de_flag=False, pad=(0, 0, k1[0] - 1, 0)),
                NormSwitch(norm_type, "2D", c),
                nn.PReLU(c)))
        unet.append(nn.Sequential(
            Gate2dconv(c, c_end, k1, stride, de_flag=False, pad=(0, 0, k1[0] - 1, 0)),
            NormSwitch(norm_type, "2D", c_end),
            nn.PReLU(c_end)))
        self.unet_list = nn.ModuleList(unet)

    def forward(self, x):
        for layer in self.unet_list:
            x = layer(x)
        return x


class En_unet_module(nn.Module):
    """Gated input conv followed by an inner U-Net of depth ``scale``, with a residual sum."""

    def __init__(self,
                 cin: int,
                 k1: tuple,
                 k2: tuple,
                 stride: tuple,
                 c: int,
                 intra_connect: str,
                 norm_type: str,
                 scale: int,
                 de_flag: bool = False,
                 is_first: bool = False,
                 ):
        super(En_unet_module, self).__init__()
        self.cin, self.k1, self.k2 = cin, k1, k2
        self.stride = stride
        self.c = c
        self.intra_connect = intra_connect
        self.norm_type = norm_type
        self.scale = scale
        self.de_flag = de_flag
        self.is_first = is_first

        # only the first module sees the raw ``cin`` input channels
        first_cin = cin if self.is_first else c
        self.in_conv = nn.Sequential(
            Gate2dconv(first_cin, c, k1, stride, de_flag, pad=(0, 0, k1[0] - 1, 0)),
            NormSwitch(norm_type, "2D", c),
            nn.PReLU(c))

        enco_list = [Conv2dunit(k2, stride, c, norm_type) for _ in range(scale)]
        # the deepest decoder has no skip input, so it always uses "add" channel sizing
        deco_list = [Deconv2dunit(k2, stride, c, "add" if i == 0 else intra_connect, norm_type)
                     for i in range(scale)]
        self.enco = nn.ModuleList(enco_list)
        self.deco = nn.ModuleList(deco_list)
        self.skip_connect = Skip_connect(intra_connect)

    def forward(self, x):
        x_resi = self.in_conv(x)
        x = x_resi
        x_list = []
        for enco in self.enco:
            x = enco(x)
            x_list.append(x)

        for i, deco in enumerate(self.deco):
            if i == 0:
                x = deco(x)
            else:
                x = deco(self.skip_connect(x, x_list[-(i + 1)]))
        return x_resi + x


class Conv2dunit(nn.Module):
    """Conv2d -> norm -> PReLU with ``c`` channels in and out."""

    def __init__(self,
                 k: tuple,
                 stride: tuple,
                 c: int,
                 norm_type: str,
                 ):
        super(Conv2dunit, self).__init__()
        self.k, self.c = k, c
        self.stride = stride
        self.norm_type = norm_type
        self.conv = nn.Sequential(
            nn.Conv2d(c, c, k, stride),
            NormSwitch(norm_type, "2D", c),
            nn.PReLU(c)
        )

    def forward(self, x):
        return self.conv(x)


class Deconv2dunit(nn.Module):
    """ConvTranspose2d -> norm -> PReLU; input has 2*c channels when ``intra_connect`` is "cat"."""

    def __init__(self,
                 k: tuple,
                 stride: tuple,
                 c: int,
                 intra_connect: str,
                 norm_type: str,
                 ):
        """
        :raises ValueError: for an ``intra_connect`` other than "add" or "cat"
        """
        super(Deconv2dunit, self).__init__()
        self.k, self.c = k, c
        self.stride = stride
        self.intra_connect = intra_connect
        self.norm_type = norm_type
        if self.intra_connect == "add":
            deconv = nn.ConvTranspose2d(c, c, k, stride)
        elif self.intra_connect == "cat":
            deconv = nn.ConvTranspose2d(2 * c, c, k, stride)
        else:
            # previously the transposed conv was silently omitted
            raise ValueError(f"Unsupported intra_connect: {intra_connect!r}, choose from ('add', 'cat')")
        self.deconv = nn.Sequential(deconv, NormSwitch(norm_type, "2D", c), nn.PReLU(c))

    def forward(self, x):
        return self.deconv(x)


class Gate2dconv(nn.Module):
    """Gated 2-D (transposed) convolution: conv(x) * sigmoid(gate_conv(x))."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: tuple,
                 stride: tuple,
                 de_flag: bool,
                 pad: tuple = (0, 0, 0, 0),
                 chomp=1,
                 ):
        super(Gate2dconv, self).__init__()
        if not de_flag:
            self.conv = nn.Sequential(
                nn.ConstantPad2d(pad, value=0.),
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride))
            self.gate_conv = nn.Sequential(
                nn.ConstantPad2d(pad, value=0.),
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride),
                nn.Sigmoid())
        else:
            self.conv = nn.Sequential(
                nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride),
                Chomp_T(chomp))
            self.gate_conv = nn.Sequential(
                nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride),
                Chomp_T(chomp),
                nn.Sigmoid())

    def forward(self, x):
        return self.conv(x) * self.gate_conv(x)


class SelfAttention(nn.Module):
    """Pre-norm multi-head self-attention over time with a residual connection (not used by G2Net)."""

    def __init__(self, in_feat, d_feat, n_head=4, is_causal=True):
        super(SelfAttention, self).__init__()
        self.in_feat = in_feat
        self.d_feat = d_feat
        self.n_head = n_head
        self.is_causal = is_causal
        self.scale_factor = np.sqrt(d_feat // n_head)
        self.softmax = nn.Softmax(dim=-1)

        self.norm = nn.LayerNorm([in_feat])
        self.q_linear = nn.Linear(in_feat, d_feat)
        self.k_linear = nn.Linear(in_feat, d_feat)
        self.v_linear = nn.Linear(in_feat, d_feat)
        self.out_linear = nn.Linear(d_feat, in_feat)

    def Sequence_masl(self, seq):
        """Return a (B, N, T, T) mask that is 1 strictly above the diagonal (future frames)."""
        b_size, n_heads, seq_len, sub_d = seq.size()
        mask = torch.triu(torch.ones((b_size, n_heads, seq_len, seq_len), device=seq.device), diagonal=1)
        return mask

    def forward(self, x):
        """
        :param x: (B,Cin,T)
        :return: (B,F,T)
        """
        resi = x
        x = x.transpose(-2, -1).contiguous()
        x = self.norm(x)
        x_q = self.q_linear(x)
        x_k = self.k_linear(x)
        x_v = self.v_linear(x)

        b_size, seq_len, d_feat = x_q.shape
        x_q = x_q.view(b_size, seq_len, self.n_head, -1).transpose(1, 2).contiguous()
        x_k = x_k.view(b_size, seq_len, self.n_head, -1).transpose(1, 2).contiguous()
        x_v = x_v.view(b_size, seq_len, self.n_head, -1).transpose(1, 2).contiguous()
        scores = torch.matmul(x_q, x_k.transpose(-2, -1)) / self.scale_factor
        if self.is_causal is True:
            scores = scores + (-1e9 * self.Sequence_masl(x_q))
        attn = self.softmax(scores)

        context = torch.matmul(attn, x_v)  # (B,N,T,D)
        context = context.permute(0, 2, 1, 3).contiguous().view(b_size, seq_len, -1)
        context = self.out_linear(context).transpose(-2, -1).contiguous()
        return resi + context


class Conv1dunit(nn.Module):
    """Padded dilated Conv1d -> norm -> PReLU."""

    def __init__(self,
                 ci: int,
                 co: int,
                 k: int,
                 dila: int,
                 is_causal: bool,
                 norm_type: str,
                 ):
        super(Conv1dunit, self).__init__()
        self.ci, self.co, self.k, self.dila = ci, co, k, dila
        self.is_causal = is_causal
        total_pad = (k - 1) * dila
        if self.is_causal:
            pad = nn.ConstantPad1d((total_pad, 0), value=0.)
        else:
            # uneven split for odd totals keeps the output length equal to the input length
            pad = nn.ConstantPad1d((total_pad // 2, total_pad - total_pad // 2), value=0.)

        self.unit = nn.Sequential(
            pad,
            nn.Conv1d(ci, co, k, dilation=dila),
            NormSwitch(norm_type, "1D", co),
            nn.PReLU(co)
        )

    def forward(self, x):
        return self.unit(x)


class Skip_connect(nn.Module):
    """Merge a decoder input with its encoder skip tensor by addition ("add") or channel concat ("cat")."""

    def __init__(self, connect):
        super(Skip_connect, self).__init__()
        self.connect = connect

    def forward(self, x_main, x_aux):
        if self.connect == "add":
            return x_main + x_aux
        if self.connect == "cat":
            return torch.cat((x_main, x_aux), dim=1)
        # previously an UnboundLocalError
        raise ValueError(f"Unsupported connect type: {self.connect!r}, choose from ('add', 'cat')")


class Chomp_T(nn.Module):
    """Drop the last ``t`` entries along axis 2 (time) of a 4-D tensor."""

    def __init__(self,
                 t: int):
        super(Chomp_T, self).__init__()
        self.t = t

    def forward(self, x):
        # ``x[:, :, :-0]`` would be empty, so compute the stop index explicitly
        return x[:, :, :max(x.shape[2] - self.t, 0), :]


if __name__ == '__main__':
    net = G2Net(k1=[2, 3],
                k2=[1, 3],
                c=64,
                intra_connect="cat",
                d_feat=256,
                kd1=3,
                cd1=64,
                tcn_num=2,
                dilas=[1, 2, 5, 9],
                fft_num=320,
                is_causal=True,
                acti_type="sigmoid",
                crm_type="crm1",
                stage_num=3,
                u_type="u2",
                head_type="RI+MAG",
                norm_type="IN",
                ).cuda()
    from utils.utils import numParams
    print(f"The number of parameters of the model is:{numParams(net)}")
    x = torch.rand([4, 2, 101, 161]).cuda()
    y = net(x)
    print(f"{x.shape}->{y[-1].shape}")
    from ptflops.flops_counter import get_model_complexity_info
    macs, params = get_model_complexity_info(net, (2, 101, 161), as_strings=True,
                                             print_per_layer_stat=True, verbose=True)

# -------------------------------------------------------------------------------- /solver.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import time 6 | import importlib 7 | from utils.utils import logger_print 8 | import hdf5storage 9 | from utils.utils import cal_pesq 10 | from utils.utils import cal_stoi 11 | from utils.utils import cal_sisnr 12 | train_epoch, val_epoch, val_metric_epoch = [], [], [] # for loss, loss and metric score 13 | # from torch.utils.tensorboard import SummaryWriter 14 | 15 | 16 | class Solver(object): 17 | def __init__(self, 18 | data, 19 | net, 20 | optimizer, 21 | save_name_dict, 22 | args, 23 | ): 24 | self.train_dataloader = data["train_loader"] 25 | self.val_dataloader = data["val_loader"] 26 | self.net = net 27 | # optimizer part 28 | self.optimizer = optimizer 29 | self.lr = args["optimizer"]["lr"] 30 | self.gradient_norm = args["optimizer"]["gradient_norm"] 31 | self.epochs = args["optimizer"]["epochs"] 32 | self.halve_lr = args["optimizer"]["halve_lr"] 33 | self.early_stop = args["optimizer"]["early_stop"] 34 | self.halve_freq = args["optimizer"]["halve_freq"] 35 | self.early_stop_freq = args["optimizer"]["early_stop_freq"] 36 | self.print_freq = args["optimizer"]["print_freq"] 37 | self.metric_options = args["optimizer"]["metric_options"] 38 | # loss part 39 | self.loss_path = args["loss_function"]["path"] 40 | self.stagewise_loss = args["loss_function"]["stagewise"]["classname"] 41 | self.prev_weight = args["loss_function"]["prev_weight"] 42 | self.curr_weight = args["loss_function"]["curr_weight"] 43 | self.alpha = args["loss_function"]["alpha"] 44 | self.l_type = args["loss_function"]["l_type"] 45 | # signal part 46 | self.sr = args["signal"]["sr"] 47 | self.win_size = args["signal"]["win_size"] 48 | self.win_shift = args["signal"]["win_shift"] 49 | self.fft_num = args["signal"]["fft_num"] 50 | self.is_compress = args["signal"]["is_compress"] 51 | # path part 52 | self.is_checkpoint = 
args["path"]["is_checkpoint"] 53 | self.is_resume_reload = args["path"]["is_resume_reload"] 54 | self.checkpoint_load_path = args["path"]["checkpoint_load_path"] 55 | self.checkpoint_load_filename = args["path"]["checkpoint_load_filename"] 56 | self.loss_save_path = args["path"]["loss_save_path"] 57 | self.model_best_path = args["path"]["model_best_path"] 58 | # sava name 59 | self.loss_save_filename = save_name_dict["loss_filename"] 60 | self.best_model_save_filename = save_name_dict["best_model_filename"] 61 | self.checkpoint_save_filename = save_name_dict["checkpoint_filename"] 62 | 63 | self.train_loss = torch.Tensor(self.epochs) 64 | self.val_loss = torch.Tensor(self.epochs) 65 | # set loss funcs 66 | loss_module = importlib.import_module(self.loss_path) 67 | self.stagewise_loss = getattr(loss_module, self.stagewise_loss)(self.prev_weight, self.curr_weight, self.alpha, 68 | self.l_type) 69 | self._reset() 70 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 71 | # summarywriter 72 | # self.tensorboard_path = "./" + args["path"]["logging_path"] + "/" + args["save"]["tensorboard_filename"] 73 | # if not os.path.exists(self.tensorboard_path): 74 | # os.makedirs(self.tensorboard_path) 75 | #self.writer = SummaryWriter(self.tensorboard_path, max_queue=5, flush_secs=30) 76 | 77 | def _reset(self): 78 | # Reset 79 | if self.is_resume_reload: 80 | checkpoint = torch.load(os.path.join(self.checkpoint_load_path, self.checkpoint_load_filename)) 81 | self.net.load_state_dict(checkpoint["model_state_dict"]) 82 | self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 83 | self.start_epoch = checkpoint["start_epoch"] 84 | self.prev_val_loss = checkpoint["val_loss"] # val loss 85 | self.prev_val_metric = checkpoint["val_metric"] 86 | self.best_val_metric = checkpoint["best_val_metric"] 87 | self.val_no_impv = checkpoint["val_no_impv"] 88 | self.halving = checkpoint["halving"] 89 | else: 90 | self.start_epoch = 0 91 | self.prev_val_loss = 
float("inf") 92 | self.prev_val_metric = -float("inf") 93 | self.best_val_metric = -float("inf") 94 | self.val_no_impv = 0 95 | self.halving = False 96 | 97 | def train(self): 98 | logger_print("Begin to train....") 99 | self.net.to(self.device) 100 | for epoch in range(self.start_epoch, self.epochs): 101 | begin_time = time.time() 102 | # training phase 103 | logger_print("-" * 90) 104 | start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 105 | logger_print(f"Epoch id:{int(epoch + 1)}, Training phase, Start time:{start_time}") 106 | self.net.train() 107 | train_avg_loss = self._run_one_epoch(epoch, val_opt=False) 108 | # self.writer.add_scalar(f"Loss/Training_Loss", train_avg_loss, epoch) 109 | end_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 110 | logger_print(f"Epoch if:{int(epoch + 1)}, Training phase, End time:{end_time}, " 111 | f"Training loss:{train_avg_loss}") 112 | 113 | # Cross val 114 | start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 115 | logger_print(f"Epoch id:{int(epoch + 1)}, Validation phase, Start time:{start_time}") 116 | self.net.eval() # norm and dropout is off 117 | val_avg_loss, val_avg_metric = self._run_one_epoch(epoch, val_opt=True) 118 | # self.writer.add_scalar(f"Loss/Validation_Loss", val_avg_loss, epoch) 119 | # self.writer.add_scalar(f"Loss/Validation_Metric", val_avg_metric, epoch) 120 | end_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 121 | logger_print(f"Epoch if:{int(epoch + 1)}, Validation phase, End time:{end_time}, " 122 | f"Validation loss:{val_avg_loss}, Validation metric score:{val_avg_metric}") 123 | end_time = time.time() 124 | print(f"{end_time-begin_time}s in {epoch+1}th epoch") 125 | logger_print("-" * 90) 126 | 127 | # whether to save checkpoint at current epoch 128 | if self.is_checkpoint: 129 | cpk_dic = {} 130 | cpk_dic["model_state_dict"] = self.net.state_dict() 131 | cpk_dic["optimizer_state_dict"] = self.optimizer.state_dict() 132 | 
cpk_dic["train_loss"] = train_avg_loss 133 | cpk_dic["val_loss"] = val_avg_loss 134 | cpk_dic["val_metric"] = val_avg_metric 135 | cpk_dic["best_val_metric"] = self.best_val_metric 136 | cpk_dic["start_epoch"] = epoch+1 137 | cpk_dic["val_no_impv"] = self.val_no_impv 138 | cpk_dic["halving"] = self.halving 139 | torch.save(cpk_dic, os.path.join(self.checkpoint_load_path, "Epoch_{}_{}_{}".format(epoch+1, 140 | self.net.__class__.__name__, self.checkpoint_save_filename))) 141 | # record loss 142 | # self.train_loss[epoch] = train_avg_loss 143 | # self.val_loss[epoch] = val_avg_loss 144 | 145 | train_epoch.append(train_avg_loss) 146 | val_epoch.append(val_avg_loss) 147 | val_metric_epoch.append(val_avg_metric) 148 | 149 | # save loss 150 | loss = {} 151 | loss["train_loss"] = train_epoch 152 | loss["val_loss"] = val_epoch 153 | loss["val_metric"] = val_metric_epoch 154 | 155 | if not self.is_resume_reload: 156 | hdf5storage.savemat(os.path.join(self.loss_save_path, self.loss_save_filename), loss) 157 | else: 158 | hdf5storage.savemat(os.path.join(self.loss_save_path, "resume_cpk_{}".format(self.loss_save_filename)), 159 | loss) 160 | 161 | # lr halve and Early stop 162 | if self.halve_lr: 163 | if val_avg_metric <= self.prev_val_metric: 164 | self.val_no_impv += 1 165 | if self.val_no_impv == self.halve_freq: 166 | self.halving = True 167 | if (self.val_no_impv >= self.early_stop_freq) and self.early_stop: 168 | logger_print("No improvements and apply early-stopping") 169 | break 170 | else: 171 | self.val_no_impv = 0 172 | 173 | if self.halving: 174 | optim_state = self.optimizer.state_dict() 175 | optim_state["param_groups"][0]["lr"] = optim_state["param_groups"][0]["lr"] / 2.0 176 | self.optimizer.load_state_dict(optim_state) 177 | logger_print("Learning rate is adjusted to %5f" % (optim_state["param_groups"][0]["lr"])) 178 | self.halving = False 179 | self.prev_val_metric = val_avg_metric 180 | 181 | if val_avg_metric > self.best_val_metric: 182 | 
self.best_val_metric = val_avg_metric 183 | torch.save(self.net.state_dict(), os.path.join(self.model_best_path, self.best_model_save_filename)) 184 | logger_print(f"Find better model, saving to {self.best_model_save_filename}") 185 | else: 186 | logger_print("Did not find better model") 187 | 188 | @torch.no_grad() 189 | def _val_batch(self, batch_info): 190 | batch_mix_wav = batch_info.feats.to(self.device) # (B,L) 191 | batch_target_wav = batch_info.labels.to(self.device) # (B,L) 192 | batch_wav_len_list = batch_info.frame_mask_list 193 | real_len = batch_mix_wav.shape[-1] 194 | 195 | # stft 196 | b_size, wav_len = batch_mix_wav.shape 197 | win_size, win_shift = int(self.sr*self.win_size), int(self.sr*self.win_shift) 198 | batch_mix_stft = torch.stft( 199 | batch_mix_wav, 200 | n_fft=self.fft_num, 201 | hop_length=win_shift, 202 | win_length=win_size, 203 | window=torch.hann_window(win_size).to(self.device)) # (B,F,T,2) 204 | batch_target_stft = torch.stft( 205 | batch_target_wav, 206 | n_fft=self.fft_num, 207 | hop_length=win_shift, 208 | win_length=win_size, 209 | window=torch.hann_window(win_size).to(self.device)) # (B,F,T,2) 210 | 211 | batch_frame_list = [] 212 | for i in range(len(batch_wav_len_list)): 213 | curr_frame_num = (batch_wav_len_list[i]-win_size+win_size)//win_shift+1 # center case 214 | batch_frame_list.append(curr_frame_num) 215 | _, freq_num, seq_len, _ = batch_mix_stft.shape 216 | if self.is_compress: 217 | # mix 218 | batch_mix_mag, batch_mix_phase = torch.norm(batch_mix_stft, dim=-1)**0.5, \ 219 | torch.atan2(batch_mix_stft[..., -1], batch_mix_stft[..., 0]) 220 | batch_mix_stft = torch.stack((batch_mix_mag*torch.cos(batch_mix_phase), 221 | batch_mix_mag*torch.sin(batch_mix_phase)), dim=-1) 222 | # target 223 | batch_target_mag, batch_target_phase = torch.norm(batch_target_stft, dim=-1)**0.5, \ 224 | torch.atan2(batch_target_stft[...,-1], 225 | batch_target_stft[...,0]) 226 | batch_target_stft = 
torch.stack((batch_target_mag*torch.cos(batch_target_phase), 227 | batch_target_mag*torch.sin(batch_target_phase)), dim=-1) 228 | 229 | # convert, mix: (B,2,T,F), target: (B,2,F,T) 230 | batch_mix_stft = batch_mix_stft.permute(0,3,2,1) # (B,2,T,F) 231 | batch_target_stft = batch_target_stft.permute(0,3,1,2) # (B,2,F,T) 232 | 233 | # net predict 234 | batch_est_list = self.net(batch_mix_stft) # (B,2,F,T) 235 | # cal stagewise loss 236 | batch_stagewise_loss = self.stagewise_loss(batch_est_list, batch_target_stft, batch_frame_list) 237 | # cal metric loss, (B,F,T,2) 238 | batch_mix_stft = batch_mix_stft.permute(0,3,2,1) # (B,F,T,2) 239 | batch_est_stft = batch_est_list[-1].permute(0,2,3,1) # (B,F,T,2) 240 | batch_target_stft = batch_target_stft.permute(0,2,3,1) # (B,F,T,2) 241 | if self.is_compress: 242 | # mix 243 | batch_mix_mag, batch_mix_phase = torch.norm(batch_mix_stft, dim=-1)**2.0, \ 244 | torch.atan2(batch_mix_stft[..., -1], batch_mix_stft[...,0]) 245 | batch_mix_stft = torch.stack((batch_mix_mag*torch.cos(batch_mix_phase), 246 | batch_mix_mag*torch.sin(batch_mix_phase)), dim=-1) 247 | # est 248 | batch_spec_mag, batch_spec_phase = torch.norm(batch_est_stft, dim=-1)**2.0,\ 249 | torch.atan2(batch_est_stft[..., -1], batch_est_stft[...,0]) 250 | batch_est_stft = torch.stack((batch_spec_mag*torch.cos(batch_spec_phase), 251 | batch_spec_mag*torch.sin(batch_spec_phase)), dim=-1) 252 | # target 253 | batch_target_mag, batch_target_phase = torch.norm(batch_target_stft, dim=-1)**2.0, \ 254 | torch.atan2(batch_target_stft[...,-1], batch_target_stft[...,0]) 255 | batch_target_stft = torch.stack((batch_target_mag*torch.cos(batch_target_phase), 256 | batch_target_mag*torch.sin(batch_target_phase)), dim=-1) 257 | batch_mix_wav = torch.istft(batch_mix_stft, 258 | n_fft=self.fft_num, 259 | hop_length=win_shift, 260 | win_length=win_size, 261 | window=torch.hann_window(win_size).to(self.device), 262 | length=real_len 263 | ) # (B,L) 264 | batch_est_wav = 
torch.istft(batch_est_stft, 265 | n_fft=self.fft_num, 266 | hop_length=win_shift, 267 | win_length=win_size, 268 | window=torch.hann_window(win_size).to(self.device), 269 | length=real_len 270 | ) # (B,L) 271 | batch_target_wav = torch.istft(batch_target_stft, 272 | n_fft=self.fft_num, 273 | hop_length=win_shift, 274 | win_length=win_size, 275 | window=torch.hann_window(win_size).to(self.device), 276 | length=real_len 277 | ) # (B,L) 278 | loss_dict = {} 279 | loss_dict["mse_loss"] = batch_stagewise_loss.item() 280 | # create mask 281 | mask_list = [] 282 | for id in range(b_size): 283 | mask_list.append(torch.ones((batch_wav_len_list[id]))) 284 | wav_mask = torch.nn.utils.rnn.pad_sequence(mask_list, batch_first=True).to(batch_mix_stft.device) # (B,L) 285 | batch_mix_wav, batch_target_wav, batch_est_wav = (batch_mix_wav*wav_mask).cpu().numpy(), \ 286 | (batch_target_wav*wav_mask).cpu().numpy(), \ 287 | (batch_est_wav*wav_mask).cpu().numpy() 288 | if "SISNR" in self.metric_options: 289 | unpro_score_list, pro_score_list = [], [] 290 | for id in range(batch_mix_wav.shape[0]): 291 | unpro_score_list.append(cal_sisnr(id, batch_mix_wav, batch_target_wav, self.sr)) 292 | pro_score_list.append(cal_sisnr(id, batch_est_wav, batch_target_wav, self.sr)) 293 | unpro_score_list, pro_score_list = np.asarray(unpro_score_list), np.asarray(pro_score_list) 294 | unpro_sisnr_mean_score, pro_sisnr_mean_score = np.mean(unpro_score_list), np.mean(pro_score_list) 295 | loss_dict["unpro_metric"] = unpro_sisnr_mean_score 296 | loss_dict["pro_metric"] = pro_sisnr_mean_score 297 | if "NB-PESQ" in self.metric_options: 298 | unpro_score_list, pro_score_list = [], [] 299 | for id in range(batch_mix_wav.shape[0]): 300 | unpro_score_list.append(cal_pesq(id, batch_mix_wav, batch_target_wav, self.sr)) 301 | pro_score_list.append(cal_pesq(id, batch_est_wav, batch_target_wav, self.sr)) 302 | unpro_score_list, pro_score_list = np.asarray(unpro_score_list), \ 303 | np.asarray(pro_score_list) 304 | 
unpro_pesq_mean_score, pro_pesq_mean_score = np.mean(unpro_score_list), np.mean(pro_score_list) 305 | loss_dict["unpro_metric"] = unpro_pesq_mean_score 306 | loss_dict["pro_metric"] = pro_pesq_mean_score 307 | if "ESTOI" in self.metric_options: 308 | unpro_score_list, pro_score_list = [], [] 309 | for id in range(batch_mix_wav.shape[0]): 310 | unpro_score_list.append(cal_stoi(id, batch_mix_wav, batch_target_wav, self.sr)) 311 | pro_score_list.append(cal_stoi(id, batch_mix_wav, batch_target_wav, self.sr)) 312 | unpro_score_list, pro_score_list = np.asarray(unpro_score_list), \ 313 | np.asarray(pro_score_list) 314 | unpro_estoi_mean_score, pro_estoi_mean_score = np.mean(unpro_score_list), np.mean(pro_score_list) 315 | loss_dict["unpro_metric"] = unpro_estoi_mean_score 316 | loss_dict["pro_metric"] = pro_estoi_mean_score 317 | return loss_dict 318 | 319 | def _train_batch(self, batch_info): 320 | batch_mix_wav = batch_info.feats.to(self.device) # (B,L) 321 | batch_target_wav = batch_info.labels.to(self.device) # (B,L) 322 | batch_wav_len_list = batch_info.frame_mask_list 323 | 324 | # stft 325 | b_size, wav_len = batch_mix_wav.shape 326 | win_size, win_shift = int(self.sr*self.win_size), int(self.sr*self.win_shift) 327 | batch_mix_stft = torch.stft( 328 | batch_mix_wav, 329 | n_fft=self.fft_num, 330 | hop_length=win_shift, 331 | win_length=win_size, 332 | window=torch.hann_window(win_size).to(batch_mix_wav.device)) # (B,F,T,2) 333 | batch_target_stft = torch.stft( 334 | batch_target_wav, 335 | n_fft=self.fft_num, 336 | hop_length=win_shift, 337 | win_length=win_size, 338 | window=torch.hann_window(win_size).to(batch_target_wav.device)) # (B,F,T,2) 339 | batch_frame_list = [] 340 | for i in range(len(batch_wav_len_list)): 341 | curr_frame_num = (batch_wav_len_list[i] - win_size + win_size) // win_shift + 1 342 | batch_frame_list.append(curr_frame_num) 343 | 344 | if self.is_compress: # here only apply to target and bf as feat-compression has been applied within the 
network 345 | # mix 346 | batch_mix_mag, batch_mix_phase = torch.norm(batch_mix_stft, dim=-1)**0.5, \ 347 | torch.atan2(batch_mix_stft[..., -1], batch_mix_stft[..., 0]) 348 | batch_mix_stft = torch.stack((batch_mix_mag*torch.cos(batch_mix_phase), 349 | batch_mix_mag*torch.sin(batch_mix_phase)), dim=-1) 350 | # target 351 | batch_target_mag, batch_target_phase = torch.norm(batch_target_stft, dim=-1)**0.5, \ 352 | torch.atan2(batch_target_stft[..., -1], 353 | batch_target_stft[..., 0]) 354 | batch_target_stft = torch.stack((batch_target_mag*torch.cos(batch_target_phase), 355 | batch_target_mag*torch.sin(batch_target_phase)), dim=-1) 356 | 357 | # convert, (B,2,T,F) 358 | batch_mix_stft = batch_mix_stft.permute(0,3,2,1) # (B,2,T,F) 359 | batch_target_stft = batch_target_stft.permute(0,3,1,2) # (B,2,F,T) 360 | 361 | with torch.enable_grad(): 362 | batch_est_list = self.net(batch_mix_stft) # (B,2,F,T) 363 | 364 | # stagewise loss 365 | batch_stagewise_loss = self.stagewise_loss(batch_est_list, batch_target_stft, batch_frame_list) 366 | # params update 367 | self.update_params(batch_stagewise_loss) 368 | loss_dict = {} 369 | loss_dict["stagewise_loss"] = batch_stagewise_loss.item() 370 | return loss_dict 371 | 372 | def _run_one_epoch(self, epoch, val_opt=False): 373 | # training phase 374 | if not val_opt: 375 | data_loader = self.train_dataloader 376 | total_stagewise_loss = 0. 377 | start_time = time.time() 378 | for batch_id, batch_info in enumerate(data_loader.get_data_loader()): 379 | loss_dict = self._train_batch(batch_info) 380 | total_stagewise_loss += loss_dict["stagewise_loss"] 381 | if batch_id % self.print_freq == 0: 382 | logger_print( 383 | "Epoch:{:d}, Iter:{:d}, Average loss:{:.4f}, Time: {:d}ms/batch". 
384 | format(epoch+1, int(batch_id), total_stagewise_loss/(batch_id+1), 385 | int(1000*(time.time()-start_time)/(batch_id+1)))) 386 | return total_stagewise_loss / (batch_id+1) 387 | 388 | else: # validation phase 389 | data_loder = self.val_dataloader 390 | total_sp_loss, total_pro_metric_loss, total_unpro_metric_loss = 0., 0., 0. 391 | start_time = time.time() 392 | for batch_id, batch_info in enumerate(data_loder.get_data_loader()): 393 | loss_dict = self._val_batch(batch_info) 394 | assert len(self.metric_options) == 1, "only one metric is supported to output in the val phase" 395 | total_sp_loss += loss_dict["mse_loss"] 396 | total_unpro_metric_loss += loss_dict["unpro_metric"] 397 | total_pro_metric_loss += loss_dict["pro_metric"] 398 | if batch_id % self.print_freq == 0: 399 | logger_print( 400 | "Epoch:{:d}, Iter:{:d}, Average spectral loss:{:.4f}, Average unpro metric score:{:.4f}, " 401 | "Average pro metric score:{:.4f}, Time: {:d}ms/batch". 402 | format(epoch+1, int(batch_id), total_sp_loss/(batch_id+1), total_unpro_metric_loss/(batch_id+1), 403 | total_pro_metric_loss/(batch_id+1), int(1000*(time.time()-start_time)/(batch_id+1)))) 404 | return total_sp_loss / (batch_id+1), total_pro_metric_loss / (batch_id) 405 | 406 | def update_params(self, loss): 407 | self.optimizer.zero_grad() 408 | loss.backward() 409 | if self.gradient_norm >= 0.0: 410 | nn.utils.clip_grad_norm_(self.net.parameters(), self.gradient_norm) 411 | has_nan_inf = 0 412 | for params in self.net.parameters(): 413 | if params.requires_grad: 414 | has_nan_inf += torch.sum(torch.isnan(params.grad)) 415 | has_nan_inf += torch.sum(torch.isinf(params.grad)) 416 | if has_nan_inf == 0: 417 | self.optimizer.step() 418 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MagnitudeLoss(object): 5 | def __init__(self, l_type): 6 | 
self.l_type = l_type 7 | 8 | def __call__(self, esti, label, frame_list): 9 | """ 10 | esti: (B,T,F) 11 | label: (B,T,F) 12 | frame_list: list 13 | """ 14 | b_size, seq_len, freq_num = esti.shape 15 | mask_for_loss = [] 16 | with torch.no_grad(): 17 | for i in range(b_size): 18 | tmp_mask = torch.ones((frame_list[i], freq_num), dtype=esti.dtype) 19 | mask_for_loss.append(tmp_mask) 20 | mask_for_loss = torch.nn.utils.rnn.pad_sequence(mask_for_loss, batch_first=True).to(esti.device) 21 | 22 | if self.l_type == "L1" or self.l_type == "l1": 23 | loss_mag = (torch.abs(esti - label) * mask_for_loss).sum() / mask_for_loss.sum() 24 | elif self.l_type == "L2" or self.l_type == "l2": 25 | loss_mag = (torch.square(esti - label) * mask_for_loss).sum() / mask_for_loss.sum() 26 | else: 27 | raise RuntimeError("only L1 and L2 are supported") 28 | return loss_mag 29 | 30 | 31 | class ComMagEuclideanLoss(object): 32 | def __init__(self, alpha, l_type): 33 | self.alpha = alpha 34 | self.l_type = l_type 35 | 36 | def __call__(self, est, label, frame_list): 37 | """ 38 | est: (B,2,T,F) 39 | label: (B,2,T,F) 40 | frame_list: list 41 | alpha: scalar 42 | l_type: str, L1 or L2 43 | """ 44 | b_size, _, seq_len, freq_num = est.shape 45 | mask_for_loss = [] 46 | with torch.no_grad(): 47 | for i in range(b_size): 48 | tmp_mask = torch.ones((frame_list[i], freq_num, 2), dtype=est.dtype) 49 | mask_for_loss.append(tmp_mask) 50 | mask_for_loss = torch.nn.utils.rnn.pad_sequence(mask_for_loss, batch_first=True).to(est.device) 51 | mask_for_loss = mask_for_loss.permute(0,3,1,2) # (B,2,T,F) 52 | mag_mask_for_loss = mask_for_loss[:,0,...] 
53 | est_mag, label_mag = torch.norm(est, dim=1), torch.norm(label, dim=1) 54 | 55 | if self.l_type == "L1" or self.l_type == "l1": 56 | loss_com = (torch.abs(est - label) * mask_for_loss).sum() / mask_for_loss.sum() 57 | loss_mag = (torch.abs(est_mag - label_mag) * mag_mask_for_loss).sum() / mag_mask_for_loss.sum() 58 | elif self.l_type == "L2" or self.l_type == "l2": 59 | loss_com = (torch.square(est - label) * mask_for_loss).sum() / mask_for_loss.sum() 60 | loss_mag = (torch.square(est_mag - label_mag) * mag_mask_for_loss).sum() / mag_mask_for_loss.sum() 61 | else: 62 | raise RuntimeError("only L1 and L2 are supported!") 63 | return self.alpha*loss_com + (1 - self.alpha)*loss_mag 64 | 65 | 66 | class StagewiseComMagEuclideanLoss(object): 67 | def __init__(self, 68 | prev_weight, 69 | curr_weight, 70 | alpha, 71 | l_type, 72 | ): 73 | self.prev_weight = prev_weight 74 | self.curr_weight = curr_weight 75 | self.alpha = alpha 76 | self.l_type = l_type 77 | 78 | def __call__(self, est_list, label, frame_list): 79 | alpha_list = [self.prev_weight for _ in range(len(est_list)-1)] 80 | alpha_list.append(self.curr_weight) 81 | mask_for_loss = [] 82 | utt_num = label.size()[0] 83 | with torch.no_grad(): 84 | for i in range(utt_num): 85 | tmp_mask = torch.ones((frame_list[i], label.size()[-2]), dtype=label.dtype) 86 | mask_for_loss.append(tmp_mask) 87 | mask_for_loss = torch.nn.utils.rnn.pad_sequence(mask_for_loss, batch_first=True).to(label.device) 88 | mask_for_loss = mask_for_loss.transpose(-2, -1).contiguous() 89 | com_mask_for_loss = torch.stack((mask_for_loss, mask_for_loss), dim=1) 90 | loss1, loss2 = 0., 0. 
91 | mag_label = torch.norm(label, dim=1) 92 | for i in range(len(est_list)): 93 | curr_esti = est_list[i] 94 | mag_esti = torch.norm(curr_esti, dim=1) 95 | if self.l_type == "L1" or self.l_type == "l1": 96 | loss1 = loss1 + alpha_list[i] * ( 97 | (torch.abs(curr_esti - label) * com_mask_for_loss).sum() / com_mask_for_loss.sum()) 98 | loss2 = loss2 + alpha_list[i] * ( 99 | (torch.abs(mag_esti - mag_label) * mask_for_loss).sum() / mask_for_loss.sum()) 100 | elif self.l_type == "L2" or self.l_type == "l2": 101 | loss1 = loss1 + alpha_list[i] * ( 102 | (torch.square(curr_esti - label) * com_mask_for_loss).sum() / com_mask_for_loss.sum()) 103 | loss2 = loss2 + alpha_list[i] * ( 104 | (torch.square(mag_esti - mag_label) * mask_for_loss).sum() / mask_for_loss.sum()) 105 | return self.alpha*loss1 + (1-self.alpha)*loss2 106 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import logging 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import numpy as np 8 | EPSILON = np.finfo(np.float32).eps 9 | 10 | def logger_print(log): 11 | logging.info(log) 12 | print(log) 13 | 14 | def numParams(net): 15 | num = 0 16 | for param in net.parameters(): 17 | if param.requires_grad: 18 | num += int(np.prod(param.size())) 19 | return num 20 | 21 | class ToTensor(object): 22 | def __call__(self, 23 | x, 24 | type="float"): 25 | if type == "float": 26 | return torch.FloatTensor(x) 27 | elif type == "int": 28 | return torch.IntTensor(x) 29 | 30 | 31 | def pad_to_longest(batch_data): 32 | """ 33 | pad the waves with the longest length among one batch chunk 34 | :param batch_data: 35 | :return: 36 | """ 37 | mix_wav_batch_list, target_wav_batch_list, wav_len_list = batch_data[0] 38 | mix_tensor, target_tensor = nn.utils.rnn.pad_sequence(mix_wav_batch_list, batch_first=True), \ 39 | 
nn.utils.rnn.pad_sequence(target_wav_batch_list, batch_first=True) # (B,L,M) 40 | return mix_tensor, target_tensor, wav_len_list 41 | 42 | 43 | class BatchInfo(object): 44 | def __init__(self, feats, labels, frame_mask_list): 45 | self.feats = feats 46 | self.labels = labels 47 | self.frame_mask_list = frame_mask_list 48 | 49 | 50 | def json_extraction(file_path, json_path, data_type): 51 | if not os.path.exists(json_path): 52 | os.makedirs(json_path) 53 | file_list = os.listdir(file_path) 54 | file_num = len(file_list) 55 | json_list = [] 56 | 57 | for i in range(file_num): 58 | file_name = file_list[i] 59 | file_name = os.path.splitext(file_name)[0] 60 | json_list.append(file_name) 61 | 62 | with open(os.path.join(json_path, "{}_files.json".format(data_type)), "w") as f: 63 | json.dump(json_list, f, indent=4) 64 | return os.path.join(json_path, "{}_files.json".format(data_type)) 65 | 66 | 67 | def complex_mul(inpt1, inpt2): 68 | """ 69 | inpt1: (B,2,...) or (...,2) 70 | inpt2: (B,2,...) or (...,2) 71 | """ 72 | if inpt1.shape[1] == 2: 73 | out_r = inpt1[:,0,...]*inpt2[:,0,...] - inpt1[:,-1,...]*inpt2[:,-1,...] 74 | out_i = inpt1[:,0,...]*inpt2[:,-1,...] + inpt1[:,-1,...]*inpt2[:,0,...] 75 | return torch.stack((out_r, out_i), dim=1) 76 | elif inpt1.shape[-1] == 2: 77 | out_r = inpt1[...,0]*inpt2[...,0] - inpt1[...,-1]*inpt2[...,-1] 78 | out_i = inpt1[...,0]*inpt2[...,-1] + inpt1[...,-1]*inpt2[...,0] 79 | return torch.stack((out_r, out_i), dim=-1) 80 | else: 81 | raise RuntimeError("Only supports two tensor formats") 82 | 83 | def complex_conj(inpt): 84 | """ 85 | inpt: (B,2,...) or (...,2) 86 | """ 87 | if inpt.shape[1] == 2: 88 | inpt_r, inpt_i = inpt[:,0,...], inpt[:,-1,...] 89 | return torch.stack((inpt_r, -inpt_i), dim=1) 90 | elif inpt.shape[-1] == 2: 91 | inpt_r, inpt_i = inpt[...,0], inpt[...,-1] 92 | return torch.stack((inpt_r, -inpt_i), dim=-1) 93 | 94 | def complex_div(inpt1, inpt2): 95 | """ 96 | inpt1: (B,2,...) or (...,2) 97 | inpt2: (B,2,...) 
or (...,2) 98 | """ 99 | if inpt1.shape[1] == 2: 100 | inpt1_r, inpt1_i = inpt1[:,0,...], inpt1[:,-1,...] 101 | inpt2_r, inpt2_i = inpt2[:,0,...], inpt2[:,-1,...] 102 | denom = torch.norm(inpt2, dim=1)**2.0 + EPSILON 103 | out_r = inpt1_r * inpt2_r + inpt1_i * inpt2_i 104 | out_i = inpt1_i * inpt2_r - inpt1_r * inpt2_i 105 | return torch.stack((out_r/denom, out_i/denom), dim=1) 106 | elif inpt1.shape[-1] == 2: 107 | inpt1_r, inpt1_i = inpt1[...,0], inpt1[...,-1] 108 | inpt2_r, inpt2_i = inpt2[...,0], inpt2[...,-1] 109 | denom = torch.norm(inpt2, dim=-1)**2.0 + EPSILON 110 | out_r = inpt1_r * inpt2_r + inpt1_i * inpt2_i 111 | out_i = inpt1_i * inpt2_r - inpt1_r * inpt2_i 112 | return torch.stack((out_r/denom, out_i/denom), dim=-1) 113 | 114 | 115 | class NormSwitch(nn.Module): 116 | def __init__(self, 117 | norm_type: str, 118 | format: str, 119 | num_features: int, 120 | affine: bool = True, 121 | ): 122 | super(NormSwitch, self).__init__() 123 | self.norm_type = norm_type 124 | self.format = format 125 | self.num_features = num_features 126 | self.affine = affine 127 | 128 | if norm_type == "BN": 129 | if format == "1D": 130 | self.norm = nn.BatchNorm1d(num_features, affine=True) 131 | else: 132 | self.norm = nn.BatchNorm2d(num_features, affine=True) 133 | elif norm_type == "IN": 134 | if format == "1D": 135 | self.norm = nn.InstanceNorm1d(num_features, affine) 136 | else: 137 | self.norm = nn.InstanceNorm2d(num_features, affine) 138 | elif norm_type == "cLN": 139 | if format == "1D": 140 | self.norm = CumulativeLayerNorm1d(num_features, affine) 141 | else: 142 | self.norm = CumulativeLayerNorm2d(num_features, affine) 143 | elif norm_type == "cIN": 144 | if format == "2D": 145 | self.norm = CumulativeLayerNorm2d(num_features, affine) 146 | elif norm_type == "iLN": 147 | if format == "1D": 148 | self.norm = InstantLayerNorm1d(num_features, affine) 149 | else: 150 | self.norm = InstantLayerNorm2d(num_features, affine) 151 | 152 | def forward(self, inpt): 153 | 
return self.norm(inpt) 154 | 155 | 156 | class CumulativeLayerNorm2d(nn.Module): 157 | def __init__(self, 158 | num_features, 159 | affine=True, 160 | eps=1e-5, 161 | ): 162 | super(CumulativeLayerNorm2d, self).__init__() 163 | self.num_features = num_features 164 | self.eps = eps 165 | self.affine = affine 166 | 167 | if affine: 168 | self.gain = nn.Parameter(torch.ones(1,num_features,1,1)) 169 | self.bias = nn.Parameter(torch.zeros(1,num_features,1,1)) 170 | else: 171 | self.gain = Variable(torch.ones(1,num_features,1,1), requires_grad=False) 172 | self.bias = Variable(torch.zeros(1,num_features,1,1), requires_grad=False) 173 | 174 | def forward(self, inpt): 175 | """ 176 | :param inpt: (B,C,T,F) 177 | :return: 178 | """ 179 | b_size, channel, seq_len, freq_num = inpt.shape 180 | step_sum = inpt.sum([1,3], keepdim=True) # (B,1,T,1) 181 | step_pow_sum = inpt.pow(2).sum([1,3], keepdim=True) # (B,1,T,1) 182 | cum_sum = torch.cumsum(step_sum, dim=-2) # (B,1,T,1) 183 | cum_pow_sum = torch.cumsum(step_pow_sum, dim=-2) # (B,1,T,1) 184 | 185 | entry_cnt = np.arange(channel*freq_num, channel*freq_num*(seq_len+1), channel*freq_num) 186 | entry_cnt = torch.from_numpy(entry_cnt).type(inpt.type()) 187 | entry_cnt = entry_cnt.view(1,1,seq_len,1).expand_as(cum_sum) 188 | 189 | cum_mean = cum_sum / entry_cnt 190 | cum_var = (cum_pow_sum - 2*cum_mean*cum_sum) / entry_cnt + cum_mean.pow(2) 191 | cum_std = (cum_var + self.eps).sqrt() 192 | 193 | x = (inpt - cum_mean) / cum_std 194 | return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type()) 195 | 196 | class CumulativeLayerNorm1d(nn.Module): 197 | def __init__(self, 198 | num_features, 199 | affine=True, 200 | eps=1e-5, 201 | ): 202 | super(CumulativeLayerNorm1d, self).__init__() 203 | self.num_features = num_features 204 | self.affine = affine 205 | self.eps = eps 206 | 207 | if affine: 208 | self.gain = nn.Parameter(torch.ones(1,num_features,1), requires_grad=True) 209 | self.bias = 
nn.Parameter(torch.zeros(1,num_features,1), requires_grad=True) 210 | else: 211 | self.gain = Variable(torch.ones(1, num_features, 1), requires_grad=False) 212 | self.bias = Variable(torch.zeros(1, num_features, 1), requires_gra=False) 213 | 214 | def forward(self, inpt): 215 | # inpt: (B,C,T) 216 | b_size, channel, seq_len = inpt.shape 217 | cum_sum = torch.cumsum(inpt.sum(1), dim=1) # (B,T) 218 | cum_power_sum = torch.cumsum(inpt.pow(2).sum(1), dim=1) # (B,T) 219 | 220 | entry_cnt = np.arange(channel, channel*(seq_len+1), channel) 221 | entry_cnt = torch.from_numpy(entry_cnt).type(inpt.type()) 222 | entry_cnt = entry_cnt.view(1, -1).expand_as(cum_sum) # (B,T) 223 | 224 | cum_mean = cum_sum / entry_cnt # (B,T) 225 | cum_var = (cum_power_sum - 2*cum_mean*cum_sum) / entry_cnt + cum_mean.pow(2) 226 | cum_std = (cum_var + self.eps).sqrt() 227 | 228 | x = (inpt - cum_mean.unsqueeze(dim=1).expand_as(inpt)) / cum_std.unsqueeze(dim=1).expand_as(inpt) 229 | return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type()) 230 | 231 | 232 | class CumulativeInstanceNorm2d(nn.Module): 233 | def __init__(self, 234 | num_features, 235 | affine=True, 236 | eps=1e-5, 237 | ): 238 | super(CumulativeInstanceNorm2d, self).__init__() 239 | self.num_features = num_features 240 | self.eps = eps 241 | self.affine = affine 242 | 243 | if affine: 244 | self.gain = nn.Parameter(torch.ones(1,num_features,1,1)) 245 | self.bias = nn.Parameter(torch.zeros(1,num_features,1,1)) 246 | else: 247 | self.gain = Variable(torch.ones(1,num_features,1,1), requires_grad=False) 248 | self.bias = Variable(torch.zeros(1,num_features,1,1), requires_grad=False) 249 | 250 | 251 | def forward(self, inpt): 252 | """ 253 | :param inpt: (B,C,T,F) 254 | :return: 255 | """ 256 | b_size, channel, seq_len, freq_num = inpt.shape 257 | step_sum = inpt.sum([3], keepdim=True) # (B,C,T,1) 258 | step_pow_sum = inpt.pow(2).sum([3], keepdim=True) # (B,C,T,1) 259 | cum_sum = torch.cumsum(step_sum, 
dim=-2) # (B,C,T,1) 260 | cum_pow_sum = torch.cumsum(step_pow_sum, dim=-2) # (B,C,T,1) 261 | 262 | entry_cnt = np.arange(freq_num, freq_num*(seq_len+1), freq_num) 263 | entry_cnt = torch.from_numpy(entry_cnt).type(inpt.type()) 264 | entry_cnt = entry_cnt.view(1,1,seq_len,1).expand_as(cum_sum) 265 | 266 | cum_mean = cum_sum / entry_cnt 267 | cum_var = (cum_pow_sum - 2*cum_mean*cum_sum) / entry_cnt + cum_mean.pow(2) 268 | cum_std = (cum_var + self.eps).sqrt() 269 | 270 | x = (inpt - cum_mean) / cum_std 271 | return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type()) 272 | 273 | 274 | class InstantLayerNorm1d(nn.Module): 275 | def __init__(self, 276 | num_features, 277 | affine=True, 278 | eps=1e-5, 279 | ): 280 | super(InstantLayerNorm1d, self).__init__() 281 | self.num_features = num_features 282 | self.affine = affine 283 | self.eps = eps 284 | 285 | if affine: 286 | self.gain = nn.Parameter(torch.ones(1,num_features,1), requires_grad=True) 287 | self.bias = nn.Parameter(torch.zeros(1,num_features,1), requires_grad=True) 288 | else: 289 | self.gain = Variable(torch.ones(1, num_features, 1), requires_grad=False) 290 | self.bias = Variable(torch.zeros(1, num_features, 1), requires_gra=False) 291 | 292 | def forward(self, inpt): 293 | # inpt: (B,C,T) 294 | b_size, channel, seq_len = inpt.shape 295 | ins_mean = torch.mean(inpt, dim=1, keepdim=True) # (B,1,T) 296 | ins_std = (torch.var(inpt, dim=1, keepdim=True) + self.eps).pow(0.5) # (B,1,T) 297 | x = (inpt - ins_mean) / ins_std 298 | return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type()) 299 | 300 | 301 | class InstantLayerNorm2d(nn.Module): 302 | def __init__(self, 303 | num_features, 304 | affine=True, 305 | eps=1e-5, 306 | ): 307 | super(InstantLayerNorm2d, self).__init__() 308 | self.num_features = num_features 309 | self.affine = affine 310 | self.eps = eps 311 | if affine: 312 | self.gain = nn.Parameter(torch.ones(1, num_features, 1, 1), 
requires_grad=True) 313 | self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1), requires_grad=True) 314 | else: 315 | self.gain = Variable(torch.ones(1, num_features, 1, 1), requires_grad=False) 316 | self.bias = Variable(torch.zeros(1, num_features, 1, 1), requires_grad=False) 317 | 318 | def forward(self, inpt): 319 | # inpt: (B,C,T,F) 320 | ins_mean = torch.mean(inpt, dim=[1,3], keepdim=True) # (B,C,T,1) 321 | ins_std = (torch.std(inpt, dim=[1,3], keepdim=True) + self.eps).pow(0.5) # (B,C,T,1) 322 | x = (inpt - ins_mean) / ins_std 323 | return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type()) 324 | 325 | 326 | 327 | def sisnr(est, label): 328 | label_power = np.sum(label**2.0) + 1e-8 329 | scale = np.sum(est*label) / label_power 330 | est_true = scale * label 331 | est_res = est - est_true 332 | true_power = np.sum(est_true**2.0, axis=0) + 1e-8 333 | res_power = np.sum(est_res**2.0, axis=0) + 1e-8 334 | sdr = 10*np.log10(true_power) - 10*np.log10(res_power) 335 | return sdr 336 | 337 | def cal_pesq(id, esti_utts, clean_utts, fs): 338 | clean_utt, esti_utt = clean_utts[id,:], esti_utts[id,:] 339 | from pypesq import pesq 340 | pesq_score = pesq(clean_utt, esti_utt, fs=fs) 341 | return pesq_score 342 | 343 | def cal_stoi(id, esti_utts, clean_utts, fs): 344 | clean_utt, esti_utt = clean_utts[id,:], esti_utts[id,:] 345 | from pystoi import stoi 346 | stoi_score = stoi(clean_utt, esti_utt, fs, extended=True) 347 | return 100*stoi_score 348 | 349 | def cal_sisnr(id, esti_utts, clean_utts, fs): 350 | clean_utt, esti_utt = clean_utts[id,:], esti_utts[id,:] 351 | sisnr_score = sisnr(esti_utt, clean_utt) 352 | return sisnr_score 353 | --------------------------------------------------------------------------------