├── LICENSE
├── README.md
├── configs
│   └── train_config.toml
├── dataset
│   └── data.py
├── main.py
├── nets
│   ├── EaBNet.py
│   ├── GeneralizedWF.py
│   ├── TaylorBeamformer.py
│   └── TaylorBeamformer_ori.py
├── solver.py
├── torch_complex
│   ├── __init__.py
│   ├── functional.py
│   ├── tensor.py
│   └── utils.py
└── utils
    ├── loss.py
    └── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 AndongLi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is the implementation of the paper "TaylorBeamformer: Learning All-Neural Beamformer for Multi-Channel Speech Enhancement from Taylor's Approximation Theory" [[arXiv]](https://arxiv.org/abs/2203.07195v2), which is in submission to Interspeech 2022.
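Training can be launched with the bundled configuration, e.g. `python main.py --C configs/train_config.toml` (this is the default config path in `main.py`). Note that the dataset paths inside `configs/train_config.toml` refer to the author's local directories and must be adapted to your own data layout before running.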
2 |
--------------------------------------------------------------------------------
/configs/train_config.toml:
--------------------------------------------------------------------------------
1 | # for loading and saving paths
2 | [path]
3 | data_type = "LibriSpeech"
4 | is_checkpoint = true
5 | is_resume_reload = false
6 | checkpoint_load_path = "CheckpointPath"
7 | checkpoint_load_filename = ""
8 | loss_save_path = "Loss"
9 | model_best_path = "BestModel"
10 | logging_path = "Logger"
11 |
12 |
13 | [path.train]
14 | mix_file_path = "/media/liandong/mc_dataset_create/mc_dataset/data/spk+noise_linear_6mics/pairs_40000/train/mix"
15 | bf_file_path = "/media/liandong/mc_dataset_create/mc_dataset/data/spk+noise_linear_6mics/pairs_40000/train/timvdr/filter_and_sum"
16 | target_file_path = "/media/liandong/mc_dataset_create/mc_dataset/data/spk+noise_linear_6mics/pairs_40000/train/spk"
17 |
18 | [path.val]
19 | mix_file_path = "/media/liandong/mc_dataset_create/mc_dataset/data/spk+noise_linear_6mics/pairs_4000/val/mix"
20 | bf_file_path = "/media/liandong/mc_dataset_create/mc_dataset/data/spk+noise_linear_6mics/pairs_4000/val/timvdr/filter_and_sum"
21 | target_file_path = "/media/liandong/mc_dataset_create/mc_dataset/data/spk+noise_linear_6mics/pairs_4000/val/spk"
22 |
23 | [gpu]
24 | gpu_ids = [0]
25 | # signal settings before sending into the network
26 | [signal]
27 | sr = 16000
28 | is_chunk = true
29 | chunk_length = 6.0
30 | win_size = 0.02
31 | win_shift = 0.01
32 | fft_num = 320
33 | is_variance_norm = true
34 | is_compress = true
35 | ref_mic = 0
36 |
37 | # chosen loss function
38 | [loss_function]
39 | path = "utils.loss"
40 | spatial_weight = 1.0
41 | spectral_weight = 1.0
42 | alpha = 0.5
43 | l_type = "L2"
44 | [loss_function.spatial]
45 | classname = "SpatialFilterLoss"
46 | [loss_function.spectral]
47 | classname = "ComMagEuclideanLoss"
48 |
49 | # chosen optimizer
50 | [optimizer]
51 | name = "adam"
52 | lr = 5e-4
53 | beta1 = 0.9
54 | beta2 = 0.999
55 | l2 = 1e-7
56 | gradient_norm = 5.0
57 | epochs = 60
58 | halve_lr = true
59 | early_stop = true
60 | halve_freq = 2
61 | early_stop_freq = 3
62 | print_freq = 10
63 | metric_options = ["ESTOI"] # choices: [NB-PESQ, ESTOI, SISNR]
64 |
65 | # reproducibility settings
66 | [reproducibility]
67 | seed = 1234
68 |
69 | # Dataset
70 | [dataset]
71 | [dataset.train]
72 | json_path = "/media/liandong/mc_dataset_create/mc_dataset/data/spk+noise_linear_6mics/pairs_40000/train/json"
73 | batch_size = 4
74 | is_check = true
75 | is_shuffle = true
76 |
77 | [dataset.val]
78 | json_path = "/media/liandong/mc_dataset_create/mc_dataset/data/spk+noise_linear_6mics/pairs_4000/val/json"
79 | batch_size = 4
80 | is_check = true
81 | is_shuffle = true
82 |
83 | [dataloader]
84 | [dataloader.train]
85 | num_workers = 6
86 | pin_memory = true
87 | drop_last = false
88 | shuffle = false
89 |
90 | [dataloader.val]
91 | num_workers = 6
92 | pin_memory = true
93 | drop_last = false
94 | shuffle = false
95 |
96 | # network configs
97 | [net]
98 | choice = "TaylorBeamformer"
99 | path = "nets.TaylorBeamformer"
100 | classname = "TaylorBeamformer"
101 |
102 | [net.TaylorBeamformer.args]
103 | k1 = [1, 3]
104 | k2 = [2, 3]
105 | c = 64 106 | embed_dim = 64 # ablation study 107 | fft_num = 320 108 | order_num = 3 # ablation study 109 | kd1 = 5 110 | cd1 = 64 111 | d_feat = 256 112 | dilations = [1,2,5,9] 113 | group_num = 2 114 | hid_node = 64 115 | M = 6 116 | rnn_type = "LSTM" 117 | intra_connect = "cat" 118 | inter_connect = "cat" 119 | out_type = "mapping" # ablation study 120 | bf_type = "embedding" # ablation study 121 | norm2d_type = "BN" # ablation study 122 | norm1d_type = "BN" 123 | is_total_separate = false # ablation study 124 | is_u2 = true # ablation study 125 | is_1dgate = true 126 | is_squeezed = false # ablation study 127 | is_causal = true 128 | is_param_share = false # ablation study 129 | 130 | [net.EaBNet.args] 131 | k1 = [2, 3] 132 | k2 = [1, 3] 133 | c = 64 134 | M = 6 135 | embed_dim = 64 136 | kd1 = 5 137 | cd1 = 64 138 | d_feat = 256 139 | p = 6 140 | q = 3 141 | is_causal = true 142 | is_u2 = true 143 | bf_type = "lstm" 144 | topo_type = "mimo" 145 | intra_connect = "cat" 146 | norm2d_type = "BN" 147 | norm1d_type = "BN" 148 | 149 | [save] 150 | loss_filename = "librispeech_taylorbeamformer_mic_{linear}_mid_target_{timvdr}_order{0}_param_{nonshared}_bf_{embedding64}_hidnode_{64}_{u2}_{risqueezed}_norm2d_{BN}_norm1d_{BN}_causal_loss.mat" 151 | best_model_filename = "librispeech_taylorbeamformer_mic_{linear}_mid_target_{timvdr}_order{0}_param_{nonshared}_bf_{embedding64}_hidnode_{64}_{u2}_{risqueezed}_norm2d_{BN}_norm1d_{BN}_causal_model.pth" 152 | checkpoint_filename = "librispeech_taylorbeamformer_mic_{linear}_mid_target_{timvdr}_order{0}_param_{nonshared}_bf_{embedding64}_hidnode_{64}_{u2}_{risqueezed}_norm2d_{BN}_norm1d_{BN}_causal.pth.tar" 153 | logger_filename = "librispeech_taylorbeamformer_mic_{linear}_mid_target_{timvdr}_order{0}_param_{nonshared}_bf_{embedding64}_{hidnode}_{64}_{u2}_{risqueezed}_norm2d_{BN}_norm1d_{BN}_causal.txt" 154 | #tensorboard_filename = "librispeech_taylorbeamformer_mic_linear_mid_target_timvdr_order0_param_nonshared_bf_embedding64_hidnode_64_u2_risqueezed_norm2d_BN_norm1d_BN_causal" 155 | 156 | 157 | #loss_filename = "librispeech_baseline_EaBNet_BN_causal_loss.mat" 158 | #best_model_filename = "librispeech_baseline_EaBNet_BN_causal_model.pth" 159 | #checkpoint_filename = "librispeech_baseline_EaBNet_BN_causal_model.pth.tar" 160 | -------------------------------------------------------------------------------- /dataset/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | from random import shuffle 5 | from torch.utils.data import Dataset, DataLoader 6 | import random 7 | import soundfile as sf 8 | import librosa as lib 9 | from utils.utils import BatchInfo, pad_to_longest, logger_print 10 | 11 | class InstanceDataset(Dataset): 12 | def __init__(self, 13 | mix_file_path, 14 | bf_file_path, 15 | target_file_path, 16 | mix_json_path, 17 | bf_json_path, 18 | target_json_path, 19 | batch_size, 20 | is_check, 21 | is_shuffle, 22 | is_variance_norm, 23 | is_chunk, 24 | chunk_length, 25 | sr, 26 | ): 27 | super(InstanceDataset, self).__init__() 28 | self.mix_file_path = mix_file_path 29 | self.bf_file_path = bf_file_path 30 | self.target_file_path = target_file_path 31 | self.mix_json_path = mix_json_path 32 | self.bf_json_path = bf_json_path 33 | self.target_json_path = target_json_path 34 | self.batch_size = batch_size 35 | self.is_check = is_check 36 | self.is_shuffle = is_shuffle 37 | self.is_variance_norm = is_variance_norm 38 | self.is_chunk = is_chunk 39 | 
self.chunk_length = chunk_length
40 |         self.sr = sr
41 |
42 |         with open(mix_json_path, "r") as f:
43 |             mix_json_list = json.load(f)
44 |         with open(bf_json_path, "r") as f:
45 |             bf_json_list = json.load(f)
46 |         with open(target_json_path, "r") as f:
47 |             target_json_list = json.load(f)
48 |
49 |         # sort
50 |         mix_json_list.sort()
51 |         target_json_list.sort()
52 |         bf_json_list.sort()
53 |
54 |         if is_check:
55 |             self.check_align(mix_json_list, target_json_list)
56 |             self.check_align(mix_json_list, bf_json_list)
57 |
58 |         if is_shuffle:
59 |             random.seed(1234)  # fixed for reproducibility
60 |             zipped_list = list(zip(mix_json_list, bf_json_list, target_json_list))
61 |             shuffle(zipped_list)
62 |             # the first option:
63 |             # mix_json_list, target_json_list = zip(*zipped_list)  # yields tuples rather than lists
64 |             # the second option:
65 |             json_list = list(map(list, zip(*zipped_list)))  # json_list: a list of three lists
66 |             [mix_json_list, bf_json_list, target_json_list] = json_list
67 |
68 |         mix_minibatch, bf_minibatch, target_minibatch = [], [], []
69 |         start = 0
70 |         while True:
71 |             end = min(len(mix_json_list), start+batch_size)
72 |             mix_minibatch.append(mix_json_list[start:end])
73 |             bf_minibatch.append(bf_json_list[start:end])
74 |             target_minibatch.append(target_json_list[start:end])
75 |             start = end
76 |             if end == len(mix_json_list):
77 |                 break
78 |         self.mix_minibatch, self.bf_minibatch, self.target_minibatch = mix_minibatch, bf_minibatch, target_minibatch
79 |         self.length = len(mix_minibatch)
80 |
81 |     def __len__(self):
82 |         return self.length
83 |
84 |     def __getitem__(self, index):
85 |         mix_minibatch_list = self.mix_minibatch[index]
86 |         bf_minibatch_list = self.bf_minibatch[index]
87 |         target_minibatch_list = self.target_minibatch[index]
88 |         mix_wav_list, bf_wav_list, target_wav_list, wav_len_list = [], [], [], []
89 |         for id in range(len(mix_minibatch_list)):
90 |             mix_filename = mix_minibatch_list[id]
91 |             bf_filename = bf_minibatch_list[id]
92 |             target_filename = target_minibatch_list[id]
93 |             extracted_filename_from_mix = "_".join(mix_filename.split("_")[:-1])
94 |             extracted_filename_from_target = "_".join(target_filename.split("_")[:-1])
95 |             assert extracted_filename_from_mix == extracted_filename_from_target
96 |             # read speech
97 |             mix_wav, mix_sr = sf.read(os.path.join(self.mix_file_path, "{}.wav".format(mix_filename)))  # (L,M)
98 |             bf_wav, bf_sr = sf.read(os.path.join(self.bf_file_path, "{}.wav".format(bf_filename)))
99 |             target_wav, tar_sr = sf.read(os.path.join(self.target_file_path, "{}.wav".format(target_filename)))
100 |             if mix_sr != self.sr or bf_sr != self.sr or tar_sr != self.sr:
101 |                 mix_wav, bf_wav, target_wav = lib.resample(mix_wav, mix_sr, self.sr), \
102 |                                               lib.resample(bf_wav, bf_sr, self.sr), \
103 |                                               lib.resample(target_wav, tar_sr, self.sr)
104 |             if self.is_variance_norm:
105 |                 ref_mic = np.mean(mix_wav, axis=-1)  # the mic-averaged signal serves as the reference for normalization
106 |                 c = np.sqrt(len(ref_mic) / np.sum(ref_mic ** 2.0))
107 |                 mix_wav, bf_wav, target_wav = mix_wav * c, bf_wav * c, target_wav * c
108 |             if self.is_chunk and (mix_wav.shape[0] > int(self.sr*self.chunk_length)):  # use mix_wav directly so chunking also works when variance-norm is disabled
109 |                 wav_start = random.randint(0, mix_wav.shape[0]-int(self.sr*self.chunk_length))
110 |                 mix_wav = mix_wav[wav_start:wav_start+int(self.sr*self.chunk_length), :]
111 |                 bf_wav = bf_wav[wav_start:wav_start+int(self.sr*self.chunk_length)]
112 |                 target_wav = target_wav[wav_start:wav_start+int(self.sr*self.chunk_length), :]
113 |             mix_wav_list.append(mix_wav)
114 |             bf_wav_list.append(bf_wav)
115 |
target_wav_list.append(target_wav) 116 | wav_len_list.append(mix_wav.shape[0]) 117 | return mix_wav_list, bf_wav_list, target_wav_list, wav_len_list 118 | 119 | @staticmethod 120 | def check_align(mix_list, target_list): 121 | logger_print("checking.................") 122 | is_ok = 1 123 | mix_error_list, target_error_list = [], [] 124 | for i in range(len(mix_list)): 125 | extracted_filename_from_mix = "_".join(mix_list[i].split("_")[:-1]) 126 | extracted_filename_from_target = "_".join(target_list[i].split("_")[:-1]) 127 | if extracted_filename_from_mix != extracted_filename_from_target: 128 | is_ok = 0 129 | mix_error_list.append(extracted_filename_from_mix) 130 | target_error_list.append(extracted_filename_from_target) 131 | if is_ok == 0: 132 | for i in range(min(len(mix_error_list), len(target_error_list))): 133 | print("mix_file_name:{}, target_file_name:{}".format(mix_error_list[i], 134 | target_error_list[i])) 135 | raise Exception("Datasets between mix and target are not aligned!") 136 | else: 137 | logger_print("checking finished..............") 138 | 139 | class InstanceDataloader(object): 140 | def __init__(self, 141 | data_set, 142 | num_workers, 143 | pin_memory, 144 | drop_last, 145 | shuffle, 146 | ): 147 | self.data_set = data_set 148 | self.num_workers = num_workers 149 | self.pin_memory = pin_memory 150 | self.drop_last = drop_last 151 | self.shuffle = shuffle 152 | 153 | self.data_loader = DataLoader(dataset=data_set, 154 | num_workers=num_workers, 155 | pin_memory=pin_memory, 156 | drop_last=drop_last, 157 | shuffle=shuffle, 158 | collate_fn=self.collate_fn, 159 | batch_size=1 160 | ) 161 | @staticmethod 162 | def collate_fn(batch): 163 | feats, bfs, labels, frame_mask_list = pad_to_longest(batch) 164 | return BatchInfo(feats, bfs, labels, frame_mask_list) 165 | 166 | def get_data_loader(self): 167 | return self.data_loader 168 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import importlib 4 | import logging 5 | import numpy as np 6 | import random 7 | import time 8 | import argparse 9 | from dataset.data import InstanceDataset, InstanceDataloader 10 | from solver import Solver 11 | import warnings 12 | import toml 13 | from utils.utils import json_extraction, numParams, logger_print 14 | warnings.filterwarnings('ignore') 15 | 16 | # fix random seed 17 | def setup_seed(seed): 18 | """ 19 | set up random seed 20 | :param seed: 21 | :return: 22 | """ 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | np.random.seed(seed) 26 | random.seed(seed) 27 | torch.backends.cudnn.deterministic = True 28 | 29 | 30 | def main(config): 31 | # define seeds 32 | setup_seed(config["reproducibility"]["seed"]) 33 | 34 | # set logger 35 | if not os.path.exists(config["path"]["logging_path"]): 36 | os.makedirs(config["path"]["logging_path"]) 37 | logging.basicConfig(filename=config["path"]["logging_path"] + "/" + config["save"]["logger_filename"], 38 | filemode='w', 39 | level=logging.INFO, 40 | format="%(message)s") 41 | start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 42 | logger_print(f"start logging time:\t{start_time}") 43 | 44 | # set gpus 45 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu_id) for gpu_id in config["gpu"]["gpu_ids"]]) 46 | logger_print(f"gpus: {os.environ['CUDA_VISIBLE_DEVICES']}") 47 | 48 | # set network 49 | net_choice = config["net"]["choice"] 50 | module = 
importlib.import_module(config["net"]["path"]) 51 | net_args = config["net"][net_choice]["args"] 52 | net_args.update({"is_compress": config["signal"]["is_compress"]}) 53 | net_args.update({"ref_mic": config["signal"]["ref_mic"]}) 54 | net = getattr(module, config["net"]["classname"])(**net_args) 55 | logger_print(f"The number of trainable parameters: {numParams(net)}") 56 | 57 | # paths generation 58 | if not os.path.exists(config["path"]["checkpoint_load_path"]): 59 | os.makedirs(config["path"]["checkpoint_load_path"]) 60 | if not os.path.exists(config["path"]["loss_save_path"]): 61 | os.makedirs(config["path"]["loss_save_path"]) 62 | if not os.path.exists(config["path"]["model_best_path"]): 63 | os.makedirs(config["path"]["model_best_path"]) 64 | 65 | # save filename 66 | save_name_dict = {} 67 | save_name_dict["loss_filename"] = config["save"]["loss_filename"] 68 | save_name_dict["best_model_filename"] = config["save"]["best_model_filename"] 69 | save_name_dict["checkpoint_filename"] = config["save"]["checkpoint_filename"] 70 | 71 | # determine file json 72 | train_mix_json = json_extraction(config["path"]["train"]["mix_file_path"], config["dataset"]["train"]["json_path"], "mix") 73 | train_bf_json = json_extraction(config["path"]["train"]["bf_file_path"], config["dataset"]["train"]["json_path"], "bf") 74 | train_target_json = json_extraction(config["path"]["train"]["target_file_path"], config["dataset"]["train"]["json_path"], "spk") 75 | val_mix_json = json_extraction(config["path"]["val"]["mix_file_path"], config["dataset"]["val"]["json_path"], "mix") 76 | val_bf_json = json_extraction(config["path"]["val"]["bf_file_path"], config["dataset"]["val"]["json_path"], "bf") 77 | val_target_json = json_extraction(config["path"]["val"]["target_file_path"], config["dataset"]["val"]["json_path"], "spk") 78 | 79 | # define train/validation 80 | train_dataset = InstanceDataset(mix_file_path=config["path"]["train"]["mix_file_path"], 81 | bf_file_path=config["path"]["train"]["bf_file_path"], 82 | target_file_path=config["path"]["train"]["target_file_path"], 83 | mix_json_path=train_mix_json, 84 | bf_json_path=train_bf_json, 85 | target_json_path=train_target_json, 86 | is_variance_norm=config["signal"]["is_variance_norm"], 87 | is_chunk=config["signal"]["is_chunk"], 88 | chunk_length=config["signal"]["chunk_length"], 89 | sr=config["signal"]["sr"], 90 | batch_size=config["dataset"]["train"]["batch_size"], 91 | is_check=config["dataset"]["train"]["is_check"], 92 | is_shuffle=config["dataset"]["train"]["is_shuffle"]) 93 | val_dataset = InstanceDataset(mix_file_path=config["path"]["val"]["mix_file_path"], 94 | bf_file_path=config["path"]["val"]["bf_file_path"], 95 | target_file_path=config["path"]["val"]["target_file_path"], 96 | mix_json_path=val_mix_json, 97 | bf_json_path=val_bf_json, 98 | target_json_path=val_target_json, 99 | is_variance_norm=config["signal"]["is_variance_norm"], 100 | is_chunk=config["signal"]["is_chunk"], 101 | chunk_length=config["signal"]["chunk_length"], 102 | sr=config["signal"]["sr"], 103 | batch_size=config["dataset"]["val"]["batch_size"], 104 | is_check=config["dataset"]["val"]["is_check"], 105 | is_shuffle=config["dataset"]["val"]["is_shuffle"]) 106 | train_dataloader = InstanceDataloader(train_dataset, 107 | **config["dataloader"]["train"]) 108 | val_dataloader = InstanceDataloader(val_dataset, 109 | **config["dataloader"]["val"]) 110 | 111 | # define optimizer 112 | if config["optimizer"]["name"] == "adam": 113 | optimizer = torch.optim.Adam( 114 | net.parameters(), 
115 |             lr=config["optimizer"]["lr"],
116 |             betas=(config["optimizer"]["beta1"], config["optimizer"]["beta2"]),
117 |             weight_decay=config["optimizer"]["l2"])
118 |
119 |     data = {'train_loader': train_dataloader, 'val_loader': val_dataloader}
120 |     solver = Solver(data, net, optimizer, save_name_dict, config)
121 |     solver.train()
122 |
123 | if __name__ == "__main__":
124 |     parser = argparse.ArgumentParser()
125 |     parser.add_argument("--C", "--config", type=str, required=False, default="configs/train_config.toml",
126 |                         help="path to the toml config file")
127 |     args = parser.parse_args()
128 |     config = toml.load(args.C)
129 |     print(config)
130 |     main(config)
131 |
--------------------------------------------------------------------------------
/nets/EaBNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | from torch import Tensor
5 |
6 |
7 | class EaBNet(nn.Module):
8 |     def __init__(self,
9 |                  k1: list = [2, 3],
10 |                  k2: list = [1, 3],
11 |                  c: int = 64,
12 |                  M: int = 9,
13 |                  embed_dim: int = 64,
14 |                  kd1: int = 5,
15 |                  cd1: int = 64,
16 |                  d_feat: int = 256,
17 |                  p: int = 6,
18 |                  q: int = 3,
19 |                  is_causal: bool = True,
20 |                  is_u2: bool = True,
21 |                  bf_type: str = "lstm",
22 |                  topo_type: str = "mimo",
23 |                  intra_connect: str = "cat",
24 |                  norm_type: str = "BN",
25 |                  ):
26 |         """
27 |         :param k1: kernel size in the 2-D GLU, (2, 3) by default
28 |         :param k2: kernel size in the UNet-block, (1, 3) by default
29 |         :param c: channel number in the 2-D Convs, 64 by default
30 |         :param M: mic number, 9 by default
31 |         :param embed_dim: embedded dimension, 64 by default
32 |         :param kd1: kernel size in the Squeezed-TCM (dilation-part), 5 by default
33 |         :param cd1: channel number in the Squeezed-TCM (dilation-part), 64 by default
34 |         :param d_feat: channel number in the Squeezed-TCM (pointwise-part), 256 by default
35 |         :param p: the number of Squeezed-TCMs within a group, 6 by default
36 |         :param q: group numbers, 3 by default
37 |         :param is_causal: causal flag, True by default
38 |         :param is_u2: whether U^{2} is set, True by default
39 |         :param bf_type: beamformer type, "lstm" by default
40 |         :param topo_type: topology type, "mimo" and "miso", "mimo" by default
41 |         :param intra_connect: intra connection type, "cat" by default
42 |         """
43 |         super(EaBNet, self).__init__()
44 |         self.k1 = tuple(k1)
45 |         self.k2 = tuple(k2)
46 |         self.c = c
47 |         self.M = M
48 |         self.embed_dim = embed_dim
49 |         self.kd1 = kd1
50 |         self.cd1 = cd1
51 |         self.d_feat = d_feat
52 |         self.p = p
53 |         self.q = q
54 |         self.is_causal = is_causal
55 |         self.is_u2 = is_u2
56 |         self.bf_type = bf_type
57 |         self.intra_connect = intra_connect
58 |         self.topo_type = topo_type
59 |         self.norm_type = norm_type
60 |
61 |         if is_u2:
62 |             self.en = U2Net_Encoder(M*2, tuple(k1), tuple(k2), c, intra_connect, norm_type)
63 |             self.de = U2Net_Decoder(embed_dim, c, tuple(k1), tuple(k2), intra_connect, norm_type)
64 |         else:
65 |             self.en = UNet_Encoder(M*2, tuple(k1), c, norm_type)
66 |             self.de = UNet_Decoder(embed_dim, tuple(k1), c, norm_type)
67 |
68 |         if topo_type == "mimo":
69 |             if bf_type == "lstm":
70 |                 self.bf_map = LSTM_BF(embed_dim, M)
71 |             elif bf_type == "cnn":
72 |                 self.bf_map = nn.Conv2d(embed_dim, M*2, (1,1), (1,1))  # pointwise
73 |         elif topo_type == "miso":
74 |             self.bf_map = nn.Conv2d(embed_dim, 2, (1,1), (1,1))  # pointwise
75 |
76 |         stcn_list = []
77 |         for _ in range(q):
78 |             stcn_list.append(SqueezedTCNGroup(kd1, cd1, d_feat, p, is_causal,
norm_type)) 79 | self.stcns = nn.ModuleList(stcn_list) 80 | 81 | def forward(self, inpt: Tensor) -> Tensor: 82 | """ 83 | :param inpt: (B, T, F, M, 2) -> (batchsize, seqlen, freqsize, mics, 2) 84 | :return: beamformed estimation: (B,T,F,2) 85 | """ 86 | if inpt.ndim == 4: 87 | inpt = inpt.unsqueeze(dim=-2) 88 | b_size, seq_len, freq_len, M, _ = inpt.shape 89 | x = inpt.transpose(-2, -1).contiguous() 90 | x = x.view(b_size, seq_len, freq_len, -1).permute(0,3,1,2) 91 | x, en_list = self.en(x) 92 | c = x.shape[1] 93 | x = x.transpose(-2, -1).contiguous().view(b_size, -1, seq_len) 94 | x_acc = Variable(torch.zeros(x.size()), requires_grad=True).to(x.device) 95 | for i in range(len(self.stcns)): 96 | x = self.stcns[i](x) 97 | x_acc = x_acc + x 98 | x = x_acc 99 | x = x.view(b_size, c, -1, seq_len).transpose(-2, -1).contiguous() 100 | x = self.de(x, en_list) 101 | if self.topo_type == "mimo": 102 | if self.bf_type == "lstm": 103 | bf_w = self.bf_map(x) # (B, T, F, M, 2) 104 | elif self.bf_type == "cnn": 105 | bf_w = self.bf_map(x) 106 | bf_w = bf_w.view(b_size, M, -1, seq_len, freq_len).permute(0,3,4,1,2) # (B,T,F,M,2) 107 | bf_w_r, bf_w_i = bf_w[...,0], -bf_w[...,-1] # conj 108 | esti_x_r, esti_x_i = (bf_w_r*inpt[...,0]-bf_w_i*inpt[...,-1]).sum(dim=-1), \ 109 | (bf_w_r*inpt[...,-1]+bf_w_i*inpt[...,0]).sum(dim=-1) 110 | return torch.stack((esti_x_r, esti_x_i), dim=-1) 111 | elif self.topo_type == "miso": 112 | bf_w = self.bf_map(x) # (B,2,T,F) 113 | bf_w = bf_w.permute(0,2,3,1) # (B,T,F,2) 114 | bf_w_r, bf_w_i = bf_w[...,0], -bf_w[...,-1] 115 | # mic-0 is selected as the target mic herein 116 | esti_x_r, esti_x_i = (bf_w_r*inpt[...,0,0]-bf_w_i*inpt[...,0,-1]).sum(dim=-1), \ 117 | (bf_w_r*inpt[...,0,-1]+bf_w_i*inpt[...,0,0]).sum(dim=-1) 118 | return torch.stack((esti_x_r, esti_x_i), dim=-1) 119 | 120 | class NormSwitch(nn.Module): 121 | def __init__(self, 122 | norm_type, 123 | format_type, 124 | feat_dim, 125 | ): 126 | super(NormSwitch, self).__init__() 127 | self.norm_type = norm_type 128 | self.format_type = format_type 129 | self.feat_dim = feat_dim 130 | 131 | if norm_type == "BN": 132 | if format_type == "1D": 133 | self.norm = nn.BatchNorm1d(feat_dim) 134 | elif format_type == "2D": 135 | self.norm = nn.BatchNorm2d(feat_dim) 136 | elif norm_type == "IN": 137 | if format_type == "1D": 138 | self.norm = nn.InstanceNorm1d(feat_dim, affine=True) 139 | elif format_type == "2D": 140 | self.norm = nn.InstanceNorm2d(feat_dim, affine=True) 141 | 142 | def forward(self, x): 143 | return self.norm(x) 144 | 145 | 146 | class U2Net_Encoder(nn.Module): 147 | def __init__(self, 148 | cin: int, 149 | k1: tuple, 150 | k2: tuple, 151 | c: int, 152 | intra_connect: str, 153 | norm_type: str, 154 | ): 155 | super(U2Net_Encoder, self).__init__() 156 | self.cin = cin 157 | self.k1 = k1 158 | self.k2 = k2 159 | self.c = c 160 | self.intra_connect = intra_connect 161 | self.norm_type = norm_type 162 | k_beg = (2, 5) 163 | c_end = 64 164 | meta_unet = [] 165 | meta_unet.append( 166 | En_unet_module(cin, c, k_beg, k2, intra_connect, norm_type, scale=4, is_deconv=False)) 167 | meta_unet.append( 168 | En_unet_module(c, c, k1, k2, intra_connect, norm_type, scale=3, is_deconv=False)) 169 | meta_unet.append( 170 | En_unet_module(c, c, k1, k2, intra_connect, norm_type, scale=2, is_deconv=False)) 171 | meta_unet.append( 172 | En_unet_module(c, c, k1, k2, intra_connect, norm_type, scale=1, is_deconv=False)) 173 | self.meta_unet_list = nn.ModuleList(meta_unet) 174 | self.last_conv = nn.Sequential( 175 | GateConv2d(c, 
c_end, k1, (1,2)), 176 | NormSwitch(norm_type, "2D", c_end), 177 | nn.PReLU(c_end) 178 | ) 179 | def forward(self, x: Tensor): 180 | en_list = [] 181 | for i in range(len(self.meta_unet_list)): 182 | x = self.meta_unet_list[i](x) 183 | en_list.append(x) 184 | x = self.last_conv(x) 185 | en_list.append(x) 186 | return x, en_list 187 | 188 | class UNet_Encoder(nn.Module): 189 | def __init__(self, 190 | cin: int, 191 | k1: tuple, 192 | c: int, 193 | norm_type: str,): 194 | super(UNet_Encoder, self).__init__() 195 | self.cin = cin 196 | self.k1 = k1 197 | self.c = c 198 | self.norm_type = norm_type 199 | k_beg = (2, 5) 200 | c_end = 64 201 | unet = [] 202 | unet.append(nn.Sequential( 203 | GateConv2d(cin, c, k_beg, (1,2)), 204 | NormSwitch(norm_type, "2D", c), 205 | nn.PReLU(c))) 206 | unet.append(nn.Sequential( 207 | GateConv2d(c, c, k1, (1,2)), 208 | NormSwitch(norm_type, "2D", c), 209 | nn.PReLU(c))) 210 | unet.append(nn.Sequential( 211 | GateConv2d(c, c, k1, (1,2)), 212 | NormSwitch(norm_type, "2D", c), 213 | nn.PReLU(c))) 214 | unet.append(nn.Sequential( 215 | GateConv2d(c, c, k1, (1,2)), 216 | NormSwitch(norm_type, "2D", c), 217 | nn.PReLU(c))) 218 | unet.append(nn.Sequential( 219 | GateConv2d(c, c_end, k1, (1,2)), 220 | NormSwitch(norm_type, "2D", c_end), 221 | nn.PReLU(64))) 222 | self.unet_list = nn.ModuleList(unet) 223 | 224 | def forward(self, x: Tensor): 225 | en_list = [] 226 | for i in range(len(self.unet_list)): 227 | x = self.unet_list[i](x) 228 | en_list.append(x) 229 | return x, en_list 230 | 231 | class U2Net_Decoder(nn.Module): 232 | def __init__(self, embed_dim, c, k1, k2, intra_connect, norm_type): 233 | super(U2Net_Decoder, self).__init__() 234 | self.embed_dim = embed_dim 235 | self.k1 = k1 236 | self.k2 = k2 237 | self.c = c 238 | self.intra_connect = intra_connect 239 | self.norm_type = norm_type 240 | c_beg = 64 241 | k_end = (2, 5) 242 | 243 | meta_unet = [] 244 | meta_unet.append( 245 | En_unet_module(c_beg*2, c, k1, k2, intra_connect, norm_type, scale=1, is_deconv=True) 246 | ) 247 | meta_unet.append( 248 | En_unet_module(c*2, c, k1, k2, intra_connect, norm_type, scale=2, is_deconv=True) 249 | ) 250 | meta_unet.append( 251 | En_unet_module(c*2, c, k1, k2, intra_connect, norm_type, scale=3, is_deconv=True) 252 | ) 253 | meta_unet.append( 254 | En_unet_module(c*2, c, k1, k2, intra_connect, norm_type, scale=4, is_deconv=True) 255 | ) 256 | self.meta_unet_list = nn.ModuleList(meta_unet) 257 | self.last_conv = nn.Sequential( 258 | GateConvTranspose2d(c*2, embed_dim, k_end, (1,2)), 259 | NormSwitch(norm_type, "2D", embed_dim), 260 | nn.PReLU(embed_dim) 261 | ) 262 | 263 | def forward(self, x: Tensor, en_list: list) -> Tensor: 264 | for i in range(len(self.meta_unet_list)): 265 | tmp = torch.cat((x, en_list[-(i+1)]), dim=1) 266 | x = self.meta_unet_list[i](tmp) 267 | x = torch.cat((x, en_list[0]), dim=1) 268 | x = self.last_conv(x) 269 | return x 270 | 271 | 272 | class UNet_Decoder(nn.Module): 273 | def __init__(self, 274 | embed_dim: int, 275 | k1: tuple, 276 | c: int, 277 | norm_type: str, 278 | ): 279 | super(UNet_Decoder, self).__init__() 280 | self.embed_dim = embed_dim 281 | self.k1 = k1 282 | self.c = c 283 | self.norm_type = norm_type 284 | c_beg = 64 # the channels of the last encoder and the first decoder are fixed at 64 by default 285 | k_end = (2, 5) 286 | unet = [] 287 | unet.append(nn.Sequential( 288 | GateConvTranspose2d(c_beg*2, c, k1, (1,2)), 289 | NormSwitch(norm_type, "2D", c), 290 | nn.PReLU(c) 291 | )) 292 | unet.append(nn.Sequential( 293 | 
GateConvTranspose2d(c*2, c, k1, (1,2)), 294 | NormSwitch(norm_type, "2D", c), 295 | nn.PReLU(c) 296 | )) 297 | unet.append(nn.Sequential( 298 | GateConvTranspose2d(c*2, c, k1, (1,2)), 299 | NormSwitch(norm_type, "2D", c), 300 | nn.PReLU(c) 301 | )) 302 | unet.append(nn.Sequential( 303 | GateConvTranspose2d(c*2, c, k1, (1,2)), 304 | NormSwitch(norm_type, "2D", c), 305 | nn.PReLU(c) 306 | )) 307 | unet.append(nn.Sequential( 308 | GateConvTranspose2d(c*2, embed_dim, k_end, (1,2)), 309 | NormSwitch(norm_type, "2D", embed_dim), 310 | nn.PReLU(embed_dim) 311 | )) 312 | self.unet_list = nn.ModuleList(unet) 313 | 314 | def forward(self, x: Tensor, en_list: list) -> Tensor: 315 | for i in range(len(self.unet_list)): 316 | tmp = torch.cat((x, en_list[-(i+1)]), dim=1) # skip connections 317 | x = self.unet_list[i](tmp) 318 | return x 319 | 320 | 321 | class En_unet_module(nn.Module): 322 | def __init__(self, 323 | cin: int, 324 | cout: int, 325 | k1: tuple, 326 | k2: tuple, 327 | intra_connect: str, 328 | norm_type: str, 329 | scale: int, 330 | is_deconv: bool, 331 | ): 332 | super(En_unet_module, self).__init__() 333 | self.k1 = k1 334 | self.k2 = k2 335 | self.cin = cin 336 | self.cout = cout 337 | self.intra_connect = intra_connect 338 | self.norm_type = norm_type 339 | self.scale = scale 340 | self.is_deconv = is_deconv 341 | 342 | in_conv_list = [] 343 | if not is_deconv: 344 | in_conv_list.append(GateConv2d(cin, cout, k1, (1,2))) 345 | else: 346 | in_conv_list.append(GateConvTranspose2d(cin, cout, k1, (1,2))) 347 | in_conv_list.append(NormSwitch(norm_type, "2D", cout)) 348 | in_conv_list.append(nn.PReLU(cout)) 349 | self.in_conv = nn.Sequential(*in_conv_list) 350 | 351 | enco_list, deco_list = [], [] 352 | for _ in range(scale): 353 | enco_list.append(Conv2dunit(k2, cout, norm_type)) 354 | for i in range(scale): 355 | if i == 0: 356 | deco_list.append(Deconv2dunit(k2, cout, "add", norm_type)) 357 | else: 358 | deco_list.append(Deconv2dunit(k2, cout, intra_connect, norm_type)) 359 | self.enco = nn.ModuleList(enco_list) 360 | self.deco = nn.ModuleList(deco_list) 361 | self.skip_connect = Skip_connect(intra_connect) 362 | 363 | def forward(self, x): 364 | x_resi = self.in_conv(x) 365 | x = x_resi 366 | x_list = [] 367 | for i in range(len(self.enco)): 368 | x = self.enco[i](x) 369 | x_list.append(x) 370 | 371 | for i in range(len(self.deco)): 372 | if i == 0: 373 | x = self.deco[i](x) 374 | else: 375 | x_con = self.skip_connect(x, x_list[-(i+1)]) 376 | x = self.deco[i](x_con) 377 | x_resi = x_resi + x 378 | del x_list 379 | return x_resi 380 | 381 | 382 | class Conv2dunit(nn.Module): 383 | def __init__(self, 384 | k: tuple, 385 | c: int, 386 | norm_type: str, 387 | ): 388 | super(Conv2dunit, self).__init__() 389 | self.k = k 390 | self.c = c 391 | self.norm_type = norm_type 392 | self.conv = nn.Sequential( 393 | nn.Conv2d(c, c, k, (1, 2)), 394 | NormSwitch(norm_type, "2D", c), 395 | nn.PReLU(c) 396 | ) 397 | def forward(self, x): 398 | return self.conv(x) 399 | 400 | 401 | class Deconv2dunit(nn.Module): 402 | def __init__(self, 403 | k: tuple, 404 | c: int, 405 | intra_connect: str, 406 | norm_type: str, 407 | ): 408 | super(Deconv2dunit, self).__init__() 409 | self.k, self.c = k, c 410 | self.intra_connect = intra_connect 411 | self.norm_type = norm_type 412 | deconv_list = [] 413 | if self.intra_connect == "add": 414 | deconv_list.append(nn.ConvTranspose2d(c, c, k, (1, 2))) 415 | elif self.intra_connect == "cat": 416 | deconv_list.append(nn.ConvTranspose2d(2*c, c, k, (1, 2))) 417 | 
deconv_list.append(NormSwitch(norm_type, "2D", c)) 418 | deconv_list.append(nn.PReLU(c)) 419 | self.deconv = nn.Sequential(*deconv_list) 420 | 421 | def forward(self, x: Tensor) -> Tensor: 422 | return self.deconv(x) 423 | 424 | 425 | class GateConv2d(nn.Module): 426 | def __init__(self, 427 | in_channels: int, 428 | out_channels: int, 429 | kernel_size: tuple, 430 | stride: tuple, 431 | ): 432 | super(GateConv2d, self).__init__() 433 | self.in_channels = in_channels 434 | self.out_channels = out_channels 435 | self.kernel_size = kernel_size 436 | self.stride = stride 437 | k_t = kernel_size[0] 438 | if k_t > 1: 439 | self.conv = nn.Sequential( 440 | nn.ConstantPad2d((0, 0, k_t-1, 0), value=0.), # for causal-setting 441 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size, stride=stride)) 442 | else: 443 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size, 444 | stride=stride) 445 | 446 | def forward(self, inputs: Tensor) -> Tensor: 447 | if inputs.ndim == 3: 448 | inputs = inputs.unsqueeze(dim=1) 449 | x = self.conv(inputs) 450 | outputs, gate = x.chunk(2, dim=1) 451 | return outputs * gate.sigmoid() 452 | 453 | 454 | class GateConvTranspose2d(nn.Module): 455 | def __init__(self, 456 | in_channels: int, 457 | out_channels: int, 458 | kernel_size: tuple, 459 | stride: tuple,): 460 | super(GateConvTranspose2d, self).__init__() 461 | self.in_channels = in_channels 462 | self.out_channels = out_channels 463 | self.kernel_size = kernel_size 464 | self.stride = stride 465 | 466 | k_t = kernel_size[0] 467 | if k_t > 1: 468 | self.conv = nn.Sequential( 469 | nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size, 470 | stride=stride), 471 | Chomp_T(k_t-1)) 472 | else: 473 | self.conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size, 474 | stride=stride) 475 | 476 | def forward(self, inputs: Tensor) -> Tensor: 477 | if inputs.ndim == 3: 478 | inputs = inputs.unsqueeze(dim=1) 479 | x = self.conv(inputs) 480 | outputs, gate = x.chunk(2, dim=1) 481 | return outputs * gate.sigmoid() 482 | 483 | 484 | class Skip_connect(nn.Module): 485 | def __init__(self, connect): 486 | super(Skip_connect, self).__init__() 487 | self.connect = connect 488 | 489 | def forward(self, x_main, x_aux): 490 | if self.connect == "add": 491 | x = x_main + x_aux 492 | elif self.connect == "cat": 493 | x = torch.cat((x_main, x_aux), dim=1) 494 | return x 495 | 496 | 497 | class SqueezedTCNGroup(nn.Module): 498 | def __init__(self, 499 | kd1: int, 500 | cd1: int, 501 | d_feat: int, 502 | p: int, 503 | is_causal: bool, 504 | norm_type: str, 505 | ): 506 | super(SqueezedTCNGroup, self).__init__() 507 | self.kd1 = kd1 508 | self.cd1 = cd1 509 | self.d_feat = d_feat 510 | self.p = p 511 | self.is_causal = is_causal 512 | self.norm_type = norm_type 513 | 514 | # Components 515 | self.tcm_list = nn.ModuleList([SqueezedTCM(kd1, cd1, 2**i, d_feat, is_causal, norm_type) for i in range(p)]) 516 | 517 | def forward(self, x): 518 | for i in range(self.p): 519 | x = self.tcm_list[i](x) 520 | return x 521 | 522 | 523 | class SqueezedTCM(nn.Module): 524 | def __init__(self, 525 | kd1: int, 526 | cd1: int, 527 | dilation: int, 528 | d_feat: int, 529 | is_causal: bool, 530 | norm_type: str, 531 | ): 532 | super(SqueezedTCM, self).__init__() 533 | self.kd1 = kd1 534 | self.cd1 = cd1 535 | self.dilation = dilation 536 | self.d_feat = d_feat 537 | self.is_causal = 
is_causal 538 | self.norm_type = norm_type 539 | 540 | self.in_conv = nn.Conv1d(d_feat, cd1, 1, bias=False) 541 | if is_causal: 542 | pad = ((kd1-1)*dilation, 0) 543 | else: 544 | pad = ((kd1-1)*dilation//2, (kd1-1)*dilation//2) 545 | self.left_conv = nn.Sequential( 546 | nn.PReLU(cd1), 547 | NormSwitch(norm_type, "1D", cd1), 548 | nn.ConstantPad1d(pad, value=0.), 549 | nn.Conv1d(cd1, cd1, kd1, dilation=dilation, bias=False) 550 | ) 551 | self.right_conv = nn.Sequential( 552 | nn.PReLU(cd1), 553 | NormSwitch(norm_type, "1D", cd1), 554 | nn.ConstantPad1d(pad, value=0.), 555 | nn.Conv1d(cd1, cd1, kernel_size=kd1, dilation=dilation, bias=False), 556 | nn.Sigmoid() 557 | ) 558 | self.out_conv = nn.Sequential( 559 | nn.PReLU(cd1), 560 | NormSwitch(norm_type, "1D", cd1), 561 | nn.Conv1d(cd1, d_feat, kernel_size=1, bias=False) 562 | ) 563 | def forward(self, x): 564 | resi = x 565 | x = self.in_conv(x) 566 | x = self.left_conv(x) * self.right_conv(x) 567 | x = self.out_conv(x) 568 | x = x + resi 569 | return x 570 | 571 | 572 | class LSTM_BF(nn.Module): 573 | def __init__(self, 574 | embed_dim: int, 575 | M: int, 576 | hid_node: int = 64): 577 | super(LSTM_BF, self).__init__() 578 | self.embed_dim = embed_dim 579 | self.M = M 580 | self.hid_node = hid_node 581 | # Components 582 | self.rnn1 = nn.LSTM(input_size=embed_dim, hidden_size=hid_node, batch_first=True) 583 | self.rnn2 = nn.LSTM(input_size=hid_node, hidden_size=hid_node, batch_first=True) 584 | self.w_dnn = nn.Sequential( 585 | nn.Linear(hid_node, hid_node), 586 | nn.ReLU(True), 587 | nn.Linear(hid_node, 2*M) 588 | ) 589 | self.norm = nn.LayerNorm([embed_dim]) 590 | 591 | def forward(self, embed_x: Tensor) -> Tensor: 592 | """ 593 | formulate the bf operation 594 | :param embed_x: (B, C, T, F) 595 | :return: (B, T, F, M, 2) 596 | """ 597 | # norm 598 | B, _, T, F = embed_x.shape 599 | x = self.norm(embed_x.permute(0,3,2,1).contiguous()) 600 | x = x.view(B*F, T, -1) 601 | x, _ = self.rnn1(x) 602 | x, _ = self.rnn2(x) 603 | x = x.view(B, F, T, -1).transpose(1, 2).contiguous() 604 | bf_w = self.w_dnn(x).view(B, T, F, self.M, 2) 605 | return bf_w 606 | 607 | 608 | class Chomp_T(nn.Module): 609 | def __init__(self, 610 | t): 611 | super(Chomp_T, self).__init__() 612 | self.t = t 613 | 614 | def forward(self, x): 615 | return x[:, :, :-self.t, :] 616 | 617 | 618 | def com_mag_mse_loss(esti, label, frame_list): 619 | mask_for_loss = [] 620 | utt_num = esti.size()[0] 621 | with torch.no_grad(): 622 | for i in range(utt_num): 623 | tmp_mask = torch.ones((frame_list[i], esti.size()[-1]), dtype=esti.dtype) 624 | mask_for_loss.append(tmp_mask) 625 | mask_for_loss = nn.utils.rnn.pad_sequence(mask_for_loss, batch_first=True).to(esti.device) 626 | com_mask_for_loss = torch.stack((mask_for_loss, mask_for_loss), dim=1) 627 | mag_esti, mag_label = torch.norm(esti, dim=1), torch.norm(label, dim=1) 628 | loss1 = (((mag_esti - mag_label) ** 2.0) * mask_for_loss).sum() / mask_for_loss.sum() 629 | loss2 = (((esti - label)**2.0)*com_mask_for_loss).sum() / com_mask_for_loss.sum() 630 | return 0.5*(loss1 + loss2) 631 | 632 | def numParams(net): 633 | import numpy as np 634 | num = 0 635 | for param in net.parameters(): 636 | if param.requires_grad: 637 | num += int(np.prod(param.size())) 638 | return num 639 | 640 | 641 | 642 | if __name__ == '__main__': 643 | net = EaBNet(k1=[2,3], 644 | k2=[1,3], 645 | c=64, 646 | M=6, 647 | embed_dim=64, 648 | kd1=5, 649 | cd1=64, 650 | d_feat=256, 651 | p=6, 652 | q=3, 653 | is_causal=True, 654 | is_u2=True, 655 | 
bf_type="lstm", 656 | topo_type="mimo", 657 | intra_connect="cat", 658 | norm_type="BN" 659 | ).cuda() 660 | net.eval() 661 | print("The number of trainable parameters is:{}".format(numParams(net))) 662 | x = torch.rand([2,101,161,6,2]).cuda() 663 | y = net(x) 664 | print(f"{x.shape}-{y.shape}") 665 | from ptflops.flops_counter import get_model_complexity_info 666 | get_model_complexity_info(net, (101, 161, 6, 2)) 667 | 668 | -------------------------------------------------------------------------------- /nets/GeneralizedWF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | from torch.autograd import Variable 5 | from torch_complex.tensor import ComplexTensor 6 | import torch_complex.functional as F 7 | from utils.utils import complex_mul, complex_conj, NormSwitch 8 | 9 | class GeneralizedMultichannelWienerFiter(nn.Module): 10 | def __init__(self, 11 | k1: list, 12 | k2: list, 13 | c: int, 14 | M: int, 15 | fft_num: int, 16 | hid_node: int, 17 | kd1: int, 18 | cd1: int, 19 | d_feat: int, 20 | group_num: int, 21 | is_gate: bool, 22 | dilations: list, 23 | is_causal: bool, 24 | is_u2: bool, 25 | rnn_type: str, 26 | norm1d_type: str, 27 | norm2d_type: str, 28 | intra_connect: str, 29 | inter_connect: str, 30 | out_type: str, 31 | ): 32 | super(GeneralizedMultichannelWienerFiter, self).__init__() 33 | self.k1 = tuple(k1) 34 | self.k2 = tuple(k2) 35 | self.c = c 36 | self.M = M 37 | self.fft_num = fft_num 38 | self.hid_node = hid_node 39 | self.kd1 = kd1 40 | self.cd1 = cd1 41 | self.d_feat = d_feat 42 | self.group_num = group_num 43 | self.is_gate = is_gate 44 | self.dilations = dilations 45 | self.is_causal = is_causal 46 | self.is_u2 = is_u2 47 | self.rnn_type = rnn_type 48 | self.norm1d_type = norm1d_type 49 | self.norm2d_type = norm2d_type 50 | self.intra_connect = intra_connect 51 | self.inter_connect = inter_connect 52 | self.out_type = out_type 53 | 54 | # Components 55 | # inv module 56 | self.inv_module = NeuralInvModule(M, hid_node, out_type, rnn_type) 57 | if is_u2: 58 | self.en = U2Net_Encoder(2*M, self.k1, self.k2, c, intra_connect, norm2d_type) 59 | self.de = U2Net_Decoder(c, self.k1, self.k2, fft_num, intra_connect, inter_connect, norm2d_type, 60 | out_type) 61 | else: 62 | self.en = UNet_Encoder(2*M, self.k1, c, norm2d_type) 63 | self.de = UNet_Decoder(c, self.k1, fft_num, inter_connect, norm2d_type, out_type) 64 | tcn_list = [] 65 | for i in range(group_num): 66 | tcn_list.append(TCMGroup(kd1, cd1, d_feat, is_gate, dilations, is_causal, norm1d_type)) 67 | self.tcns = nn.ModuleList(tcn_list) 68 | 69 | def forward(self, inpt): 70 | """ 71 | inpt: (B,T,F,M,2) 72 | """ 73 | inv_Phi_yy = self.inv_module(inpt) # (B,T,F,M,M,2) 74 | b_size, seq_len, freq_num, M, _ = inpt.shape 75 | inpt1 = inpt.view(b_size, seq_len, freq_num, -1).permute(0,3,1,2).contiguous() 76 | en_x, en_list = self.en(inpt1) 77 | en_x = en_x.transpose(-2, -1).contiguous().view(b_size, -1, seq_len) 78 | acc_x = Variable(torch.zeros_like(en_x), requires_grad=True).to(en_x.device) 79 | x = en_x 80 | for i in range(len(self.tcns)): 81 | x = self.tcns[i](x) 82 | acc_x = acc_x + x 83 | x = acc_x 84 | x = x.view(b_size, 64, 4, seq_len).transpose(-2, -1).contiguous() 85 | Vec_Ys = self.de(inpt, x, en_list) # (B,T,F,M,2) 86 | 87 | # derive wiener filter 88 | inpt_complex = ComplexTensor(inpt[...,0], inpt[...,-1]) # (B,T,F,M) 89 | inv_Phi_yy_complex = ComplexTensor(inv_Phi_yy[...,0], inv_Phi_yy[...,-1]) 90 | 
Vec_Ys_complex = ComplexTensor(Vec_Ys[...,0], Vec_Ys[...,-1])
91 |         mcwf_bf_complex = F.einsum("...mn,...n->...m", [inv_Phi_yy_complex, Vec_Ys_complex])  # (B,T,F,M), matrix-vector product
92 |         bf_x_complex = F.einsum("...m,...m->...", [mcwf_bf_complex.conj(), inpt_complex])  # inner product w^H y
93 |         bf_x = torch.stack((bf_x_complex.real, bf_x_complex.imag), dim=-1)  # (B,T,F,2)
94 |         return bf_x
95 |
96 |
97 | class NeuralInvModule(nn.Module):
98 |     def __init__(self,
99 |                  M: int,
100 |                  hid_node: int,
101 |                  out_type: str,
102 |                  rnn_type: str,
103 |                  ):
104 |         super(NeuralInvModule, self).__init__()
105 |         self.M = M
106 |         self.hid_node = hid_node
107 |         self.out_type = out_type
108 |         self.rnn_type = rnn_type
109 |
110 |         # Components
111 |         inpt_dim = 2*M*M
112 |         self.norm = nn.LayerNorm([inpt_dim])
113 |         self.rnn = getattr(nn, rnn_type)(input_size=inpt_dim, hidden_size=hid_node, num_layers=2, batch_first=True)  # input is (B*F, T, 2MM), so batch comes first
114 |         self.w_dnn = nn.Sequential(
115 |             nn.Linear(hid_node, hid_node),
116 |             nn.ReLU(True),
117 |             nn.Linear(hid_node, inpt_dim))
118 |
119 |     def forward(self, inpt):
120 |         """
121 |         inpt: (B,T,F,M,2)
122 |         return: (B,T,F,M,M,2)
123 |         """
124 |         b_size, seq_len, freq_num, M, _ = inpt.shape
125 |         inpt_complex = ComplexTensor(inpt[...,0], inpt[...,-1])  # (B,T,F,M)
126 |         inpt_cov = F.einsum("...m,...n->...mn", [inpt_complex.conj(), inpt_complex])  # (B,T,F,M,M)
127 |         inpt_cov = inpt_cov.view(b_size, seq_len, freq_num, -1)
128 |         inpt_cov = torch.cat((inpt_cov.real, inpt_cov.imag), dim=-1)  # (B,T,F,2MM)
129 |         inpt_cov = self.norm(inpt_cov)
130 |         inpt_cov = inpt_cov.transpose(1,2).contiguous().view(b_size*freq_num, seq_len, -1)
131 |         h, _ = self.rnn(inpt_cov)
132 |         inv_cov = self.w_dnn(h)  # (BF,T,2MM)
133 |         inv_cov = inv_cov.view(b_size, freq_num, seq_len, M, M, 2)
134 |         return inv_cov.transpose(1, 2).contiguous()
135 |
136 |
137 | class UNet_Encoder(nn.Module):
138 |     def __init__(self,
139 |                  cin: int,
140 |                  k1: tuple,
141 |                  c: int,
142 |                  norm2d_type: str,
143 |                  ):
144 |         super(UNet_Encoder, self).__init__()
145 |         self.cin = cin
146 |         self.k1 = k1
147 |         self.c = c
148 |         self.norm2d_type = norm2d_type
149 |         kernel_begin = (k1[0], 5)
150 |         stride = (1, 2)
151 |         c_final = 64
152 |         unet = []
153 |         unet.append(nn.Sequential(
154 |             GateConv2d(cin, c, kernel_begin, stride, padding=(0, 0, k1[0]-1, 0)),
155 |             NormSwitch(norm2d_type, "2D", c),
156 |             nn.PReLU(c)))
157 |         unet.append(nn.Sequential(
158 |             GateConv2d(c, c, k1, stride, padding=(0, 0, k1[0]-1, 0)),
159 |             NormSwitch(norm2d_type, "2D", c),
160 |             nn.PReLU(c)))
161 |         unet.append(nn.Sequential(
162 |             GateConv2d(c, c, k1, stride, padding=(0, 0, k1[0]-1, 0)),
163 |             NormSwitch(norm2d_type, "2D", c),
164 |             nn.PReLU(c)))
165 |         unet.append(nn.Sequential(
166 |             GateConv2d(c, c, k1, stride, padding=(0, 0, k1[0]-1, 0)),
167 |             NormSwitch(norm2d_type, "2D", c),
168 |             nn.PReLU(c)))
169 |         unet.append(nn.Sequential(
170 |             GateConv2d(c, c_final, k1, (1,2), padding=(0, 0, k1[0]-1, 0)),
171 |             NormSwitch(norm2d_type, "2D", c_final),
172 |             nn.PReLU(c_final)))
173 |         self.unet_list = nn.ModuleList(unet)
174 |
175 |     def forward(self, x: Tensor) -> tuple:
176 |         en_list = []
177 |         for i in range(len(self.unet_list)):
178 |             x = self.unet_list[i](x)
179 |             en_list.append(x)
180 |         return x, en_list
181 |
182 | class UNet_Decoder(nn.Module):
183 |     def __init__(self,
184 |                  c: int,
185 |                  k1: tuple,
186 |                  fft_num: int,
187 |                  inter_connect: str,
188 |                  norm2d_type: str,
189 |                  out_type: str,
190 |                  ):
191 |         super(UNet_Decoder, self).__init__()
192 |         self.k1 = k1
193 |         self.c = c
194 |         self.fft_num = fft_num
195 |         self.inter_connect = inter_connect
196 |
self.norm2d_type = norm2d_type 197 | self.out_type = out_type 198 | 199 | kernel_end = (k1[0], 5) 200 | stride = (1, 2) 201 | unet = [] 202 | if inter_connect == "add": 203 | inter_c = c 204 | c_begin = 64 205 | elif inter_connect == "cat": 206 | inter_c = c*2 207 | c_begin = 64*2 208 | else: 209 | raise RuntimeError("Skip connections only support add or concatenate operation") 210 | 211 | unet.append(nn.Sequential( 212 | GateConvTranspose2d(c_begin, c, k1, stride), 213 | NormSwitch(norm2d_type, "2D", c), 214 | nn.PReLU(c))) 215 | unet.append(nn.Sequential( 216 | GateConvTranspose2d(inter_c, c, k1, stride), 217 | NormSwitch(norm2d_type, "2D", c), 218 | nn.PReLU(c))) 219 | unet.append(nn.Sequential( 220 | GateConvTranspose2d(inter_c, c, k1, stride), 221 | NormSwitch(norm2d_type, "2D", c), 222 | nn.PReLU(c))) 223 | unet.append(nn.Sequential( 224 | GateConvTranspose2d(inter_c, c, k1, stride), 225 | NormSwitch(norm2d_type, "2D", c), 226 | nn.PReLU(c))) 227 | unet.append(nn.Sequential( 228 | GateConvTranspose2d(inter_c, c, kernel_end, stride), 229 | NormSwitch(norm2d_type, "2D", c), 230 | nn.PReLU(c))) 231 | self.unet_list = nn.ModuleList(unet) 232 | self.out_r = nn.Sequential( 233 | nn.Conv2d(c, 1, (1,1), (1,1)), 234 | nn.Linear(fft_num//2+1, fft_num//2+1)) 235 | self.out_i = nn.Sequential( 236 | nn.Conv2d(c, 1, (1,1), (1,1)), 237 | nn.Linear(fft_num//2+1, fft_num//2+1)) 238 | 239 | def forward(self, inpt: Tensor, x: Tensor, en_list: list): 240 | """ 241 | inpt: (B,T,F,M,2) 242 | return: (B,T,F,M,2) 243 | """ 244 | b_size, seq_len, freq_num, _, _ = inpt.shape 245 | if self.inter_connect == "add": 246 | for i in range(len(self.unet_list)): 247 | tmp = x + en_list[-(i + 1)] 248 | x = self.unet_list[i](tmp) 249 | elif self.inter_connect == "cat": 250 | for i in range(len(self.unet_list)): 251 | tmp = torch.cat((x, en_list[-(i + 1)]), dim=1) 252 | x = self.unet_list[i](tmp) 253 | else: 254 | raise Exception("only add and cat are supported") 255 | # output 256 | if self.out_type == "mask": 257 | gain = torch.stack((self.out_r(x).squeeze(dim=1), self.out_i(x).squeeze(dim=1)), dim=-1) # (B,T,F,2) 258 | ref_inpt = inpt[...,0,:] # (B,T,F,2) 259 | Yy = complex_mul(inpt, complex_conj(ref_inpt[...,None,:])) # (B,T,F,M,2) 260 | out = complex_mul(complex_conj(gain[...,None,:]), Yy) # (B,T,F,M,2) 261 | 262 | elif self.out_type == "mapping": 263 | map = torch.stack((self.out_r(x).squeeze(dim=1), self.out_i(x).squeeze(dim=1)), dim=-1) # (B,T,F,2) 264 | out = complex_mul(inpt, complex_conj(map[...,None,:])) # (B,T,F,M,2) 265 | else: 266 | raise Exception("only mask and mapping are supported") 267 | return out 268 | 269 | class U2Net_Encoder(nn.Module): 270 | def __init__(self, 271 | cin: int, 272 | k1: tuple, 273 | k2: tuple, 274 | c: int, 275 | intra_connect: str, 276 | norm2d_type: str, 277 | ): 278 | super(U2Net_Encoder, self).__init__() 279 | self.cin = cin 280 | self.k1 = k1 281 | self.k2 = k2 282 | self.c = c 283 | self.intra_connect = intra_connect 284 | self.norm2d_type = norm2d_type 285 | 286 | c_last = 64 287 | kernel_begin = (k1[0], 5) 288 | stride = (1, 2) 289 | meta_unet = [] 290 | meta_unet.append( 291 | En_unet_module(cin, c, kernel_begin, k2, intra_connect, norm2d_type, scale=4, de_flag=False)) 292 | meta_unet.append( 293 | En_unet_module(c, c, k1, k2, intra_connect, norm2d_type, scale=3, de_flag=False)) 294 | meta_unet.append( 295 | En_unet_module(c, c, k1, k2, intra_connect, norm2d_type, scale=2, de_flag=False)) 296 | meta_unet.append( 297 | En_unet_module(c, c, k1, k2, intra_connect, 
norm2d_type, scale=1, de_flag=False)) 298 | self.meta_unet_list = nn.ModuleList(meta_unet) 299 | self.last_conv = nn.Sequential( 300 | GateConv2d(c, c_last, k1, stride, (0, 0, k1[0]-1, 0)), 301 | NormSwitch(norm2d_type, "2D", c_last), 302 | nn.PReLU(c_last) 303 | ) 304 | 305 | def forward(self, x: Tensor) -> tuple: 306 | en_list = [] 307 | for i in range(len(self.meta_unet_list)): 308 | x = self.meta_unet_list[i](x) 309 | en_list.append(x) 310 | x = self.last_conv(x) 311 | en_list.append(x) 312 | return x, en_list 313 | 314 | class U2Net_Decoder(nn.Module): 315 | def __init__(self, 316 | c: int, 317 | k1: tuple, 318 | k2: tuple, 319 | fft_num: int, 320 | intra_connect: str, 321 | inter_connect: str, 322 | norm2d_type: str, 323 | out_type: str, 324 | ): 325 | super(U2Net_Decoder, self).__init__() 326 | self.c = c 327 | self.k1 = k1 328 | self.k2 = k2 329 | self.fft_num = fft_num 330 | self.intra_connect = intra_connect 331 | self.inter_connect = inter_connect 332 | self.norm2d_type = norm2d_type 333 | self.out_type = out_type 334 | 335 | kernel_end = (k1[0], 5) 336 | stride = (1, 2) 337 | meta_unet = [] 338 | if inter_connect == "add": 339 | inter_c = c 340 | c_begin = 64 341 | elif inter_connect == "cat": 342 | inter_c = c*2 343 | c_begin = 64*2 344 | else: 345 | raise Exception("Skip connections only support add or concatenate operation") 346 | meta_unet.append( 347 | En_unet_module(c_begin, c, k1, k2, intra_connect, norm2d_type, scale=1, de_flag=True)) 348 | meta_unet.append( 349 | En_unet_module(inter_c, c, k1, k2, intra_connect, norm2d_type, scale=2, de_flag=True)) 350 | meta_unet.append( 351 | En_unet_module(inter_c, c, k1, k2, intra_connect, norm2d_type, scale=3, de_flag=True)) 352 | meta_unet.append( 353 | En_unet_module(inter_c, c, k1, k2, intra_connect, norm2d_type, scale=4, de_flag=True)) 354 | meta_unet.append(nn.Sequential( 355 | GateConvTranspose2d(inter_c, c, kernel_end, stride), 356 | NormSwitch(norm2d_type, "2D", c), 357 | nn.PReLU(c))) 358 | self.meta_unet_list = nn.ModuleList(meta_unet) 359 | self.out_r = nn.Sequential( 360 | nn.Conv2d(c, 1, (1, 1), (1, 1)), 361 | nn.Linear(fft_num//2+1, fft_num//2+1)) 362 | self.out_i = nn.Sequential( 363 | nn.Conv2d(c, 1, (1, 1), (1, 1)), 364 | nn.Linear(fft_num//2+1, fft_num//2+1)) 365 | 366 | def forward(self, inpt: Tensor, x: Tensor, en_list: list): 367 | """ 368 | inpt: (B,T,F,M,2) 369 | return: (B,T,F,M,2) 370 | """ 371 | b_size, seq_len, freq_num, M, _ = inpt.shape 372 | if self.inter_connect == "add": 373 | for i in range(len(self.meta_unet_list)): 374 | tmp = x + en_list[-(i+1)] 375 | x = self.meta_unet_list[i](tmp) 376 | elif self.inter_connect == "cat": 377 | for i in range(len(self.meta_unet_list)): 378 | tmp = torch.cat((x, en_list[-(i+1)]), dim=1) 379 | x = self.meta_unet_list[i](tmp) 380 | else: 381 | raise Exception("only add and cat are supported") 382 | # output 383 | if self.out_type == "mask": 384 | gain = torch.stack((self.out_r(x).squeeze(dim=1), self.out_i(x).squeeze(dim=1)), dim=-1) # (B,T,F,2) 385 | ref_inpt = inpt[..., 0, :] # (B,T,F,2) 386 | Yy = complex_mul(inpt, complex_conj(ref_inpt[..., None, :])) # (B,T,F,M,2) 387 | out = complex_mul(complex_conj(gain[..., None, :]), Yy) # (B,T,F,M,2) 388 | 389 | elif self.out_type == "mapping": 390 | map = torch.stack((self.out_r(x).squeeze(dim=1), self.out_i(x).squeeze(dim=1)), dim=-1) # (B,T,F,2) 391 | out = complex_mul(inpt, complex_conj(map[..., None, :])) # (B,T,F,M,2) 392 | else: 393 | raise Exception("only mask and mapping are supported") 394 | return out 395 | 396 
| class En_unet_module(nn.Module): 397 | def __init__(self, 398 | cin: int, 399 | cout: int, 400 | k1: tuple, 401 | k2: tuple, 402 | intra_connect: str, 403 | norm2d_type: str, 404 | scale: int, 405 | de_flag: bool = False, 406 | ): 407 | super(En_unet_module, self).__init__() 408 | self.cin = cin 409 | self.cout = cout 410 | self.k1 = k1 411 | self.k2 = k2 412 | self.intra_connect = intra_connect 413 | self.norm2d_type = norm2d_type 414 | self.scale = scale 415 | self.de_flag = de_flag 416 | 417 | in_conv_list = [] 418 | if de_flag is False: 419 | in_conv_list.append(GateConv2d(cin, cout, k1, (1, 2), (0, 0, k1[0]-1, 0))) 420 | else: 421 | in_conv_list.append(GateConvTranspose2d(cin, cout, k1, (1, 2))) 422 | in_conv_list.append(NormSwitch(norm2d_type, "2D", cout)) 423 | in_conv_list.append(nn.PReLU(cout)) 424 | self.in_conv = nn.Sequential(*in_conv_list) 425 | 426 | enco_list, deco_list = [], [] 427 | for _ in range(scale): 428 | enco_list.append(Conv2dunit(k2, cout, norm2d_type)) 429 | for i in range(scale): 430 | if i == 0: 431 | deco_list.append(Deconv2dunit(k2, cout, "add", norm2d_type)) 432 | else: 433 | deco_list.append(Deconv2dunit(k2, cout, intra_connect, norm2d_type)) 434 | self.enco = nn.ModuleList(enco_list) 435 | self.deco = nn.ModuleList(deco_list) 436 | self.skip_connect = Skip_connect(intra_connect) 437 | 438 | 439 | def forward(self, inputs: Tensor) -> Tensor: 440 | x_resi = self.in_conv(inputs) 441 | x = x_resi 442 | x_list = [] 443 | for i in range(len(self.enco)): 444 | x = self.enco[i](x) 445 | x_list.append(x) 446 | 447 | for i in range(len(self.deco)): 448 | if i == 0: 449 | x = self.deco[i](x) 450 | else: 451 | x_con = self.skip_connect(x, x_list[-(i+1)]) 452 | x = self.deco[i](x_con) 453 | x_resi = x_resi + x 454 | del x_list 455 | return x_resi 456 | 457 | class Conv2dunit(nn.Module): 458 | def __init__(self, 459 | k: tuple, 460 | c: int, 461 | norm2d_type: str, 462 | ): 463 | super(Conv2dunit, self).__init__() 464 | self.k, self.c = k, c 465 | self.norm2d_type = norm2d_type 466 | k_t = k[0] 467 | stride = (1, 2) 468 | if k_t > 1: 469 | self.conv = nn.Sequential( 470 | nn.ConstantPad2d((0, 0, k_t-1, 0), value=0.), 471 | nn.Conv2d(c, c, k, stride), 472 | NormSwitch(norm2d_type, "2D", c), 473 | nn.PReLU(c) 474 | ) 475 | else: 476 | self.conv = nn.Sequential( 477 | nn.Conv2d(c, c, k, stride), 478 | NormSwitch(norm2d_type, "2D", c), 479 | nn.PReLU(c) 480 | ) 481 | 482 | def forward(self, inputs: Tensor) -> Tensor: 483 | return self.conv(inputs) 484 | 485 | class Deconv2dunit(nn.Module): 486 | def __init__(self, 487 | k: tuple, 488 | c: int, 489 | intra_connect: str, 490 | norm2d_type: str, 491 | ): 492 | super(Deconv2dunit, self).__init__() 493 | self.k, self.c = k, c 494 | self.intra_connect = intra_connect 495 | self.norm2d_type = norm2d_type 496 | k_t = k[0] 497 | stride = (1, 2) 498 | deconv_list = [] 499 | if self.intra_connect == "add": 500 | if k_t > 1: 501 | deconv_list.append(nn.ConvTranspose2d(c, c, k, stride)), 502 | deconv_list.append(Chomp_T(k_t-1)) 503 | else: 504 | deconv_list.append(nn.ConvTranspose2d(c, c, k, stride)) 505 | elif self.intra_connect == "cat": 506 | if k_t > 1: 507 | deconv_list.append(nn.ConvTranspose2d(2*c, c, k, stride)) 508 | deconv_list.append(Chomp_T(k_t-1)) 509 | else: 510 | deconv_list.append(nn.ConvTranspose2d(2*c, c, k, stride)) 511 | deconv_list.append(NormSwitch(norm2d_type, "2D", c)) 512 | deconv_list.append(nn.PReLU(c)) 513 | self.deconv = nn.Sequential(*deconv_list) 514 | 515 | def forward(self, inputs: Tensor) -> Tensor: 
516 | assert inputs.dim() == 4 517 | return self.deconv(inputs) 518 | 519 | class GateConv2d(nn.Module): 520 | def __init__(self, 521 | in_channels: int, 522 | out_channels: int, 523 | kernel_size: tuple, 524 | stride: tuple, 525 | padding: tuple, 526 | ): 527 | super(GateConv2d, self).__init__() 528 | self.in_channels = in_channels 529 | self.out_channels = out_channels 530 | self.kernel_size = kernel_size 531 | self.stride = stride 532 | self.padding = padding 533 | k_t = kernel_size[0] 534 | if k_t > 1: 535 | self.conv = nn.Sequential( 536 | nn.ConstantPad2d(padding, value=0.), 537 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size, stride=stride)) 538 | else: 539 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size, 540 | stride=stride) 541 | def forward(self, inputs: Tensor) -> Tensor: 542 | if inputs.dim() == 3: 543 | inputs = inputs.unsqueeze(dim=1) 544 | x = self.conv(inputs) 545 | outputs, gate = x.chunk(2, dim=1) 546 | return outputs * gate.sigmoid() 547 | 548 | class GateConvTranspose2d(nn.Module): 549 | def __init__(self, 550 | in_channels: int, 551 | out_channels: int, 552 | kernel_size: tuple, 553 | stride: tuple, 554 | ): 555 | super(GateConvTranspose2d, self).__init__() 556 | self.in_channels = in_channels 557 | self.out_channels = out_channels 558 | self.kernel_size = kernel_size 559 | self.stride = stride 560 | 561 | k_t = kernel_size[0] 562 | if k_t > 1: 563 | self.conv = nn.Sequential( 564 | nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size, 565 | stride=stride), 566 | Chomp_T(k_t-1)) 567 | else: 568 | self.conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size, 569 | stride=stride) 570 | 571 | def forward(self, inputs: Tensor) -> Tensor: 572 | assert inputs.dim() == 4 573 | x = self.conv(inputs) 574 | outputs, gate = x.chunk(2, dim=1) 575 | return outputs * gate.sigmoid() 576 | 577 | class Skip_connect(nn.Module): 578 | def __init__(self, 579 | connect): 580 | super(Skip_connect, self).__init__() 581 | self.connect = connect 582 | 583 | def forward(self, x_main, x_aux): 584 | if self.connect == "add": 585 | x = x_main + x_aux 586 | elif self.connect == "cat": 587 | x = torch.cat((x_main, x_aux), dim=1) 588 | return x 589 | 590 | class TCMGroup(nn.Module): 591 | def __init__(self, 592 | kd1: int, 593 | cd1: int, 594 | d_feat: int, 595 | is_gate: bool, 596 | dilations: list, 597 | is_causal: bool, 598 | norm1d_type: str, 599 | ): 600 | super(TCMGroup, self).__init__() 601 | self.kd1 = kd1 602 | self.cd1 = cd1 603 | self.d_feat = d_feat 604 | self.is_gate = is_gate 605 | self.dilations = dilations 606 | self.is_causal = is_causal 607 | self.norm1d_type = norm1d_type 608 | 609 | tcm_list = [] 610 | for i in range(len(dilations)): 611 | tcm_list.append(SqueezedTCM(kd1, cd1, dilation=dilations[i], d_feat=d_feat, is_gate=is_gate, 612 | is_causal=is_causal, norm1d_type=norm1d_type)) 613 | self.tcm_list = nn.ModuleList(tcm_list) 614 | 615 | def forward(self, inputs: Tensor) -> Tensor: 616 | x = inputs 617 | for i in range(len(self.dilations)): 618 | x = self.tcm_list[i](x) 619 | return x 620 | 621 | class SqueezedTCM(nn.Module): 622 | def __init__(self, 623 | kd1: int, 624 | cd1: int, 625 | dilation: int, 626 | d_feat: int, 627 | is_gate: bool, 628 | is_causal: bool, 629 | norm1d_type: str, 630 | ): 631 | super(SqueezedTCM, self).__init__() 632 | self.kd1 = kd1 633 | self.cd1 = cd1 634 | 
self.dilation = dilation 635 | self.d_feat = d_feat 636 | self.is_gate = is_gate 637 | self.is_causal = is_causal 638 | self.norm1d_type = norm1d_type 639 | 640 | self.in_conv = nn.Conv1d(d_feat, cd1, kernel_size=1, bias=False) 641 | if is_causal: 642 | pad = ((kd1-1)*dilation, 0) 643 | else: 644 | pad = ((kd1-1)*dilation//2, (kd1-1)*dilation//2) 645 | self.left_conv = nn.Sequential( 646 | nn.PReLU(cd1), 647 | NormSwitch(norm1d_type, "1D", cd1), 648 | nn.ConstantPad1d(pad, value=0.), 649 | nn.Conv1d(cd1, cd1, kernel_size=kd1, dilation=dilation, bias=False) 650 | ) 651 | if is_gate: 652 | self.right_conv = nn.Sequential( 653 | nn.PReLU(cd1), 654 | NormSwitch(norm1d_type, "1D", cd1), 655 | nn.ConstantPad1d(pad, value=0.), 656 | nn.Conv1d(cd1, cd1, kernel_size=kd1, dilation=dilation, bias=False), 657 | nn.Sigmoid() 658 | ) 659 | self.out_conv = nn.Sequential( 660 | nn.PReLU(cd1), 661 | NormSwitch(norm1d_type, "1D", cd1), 662 | nn.Conv1d(cd1, d_feat, kernel_size=1, bias=False) 663 | ) 664 | 665 | def forward(self, inputs: Tensor) -> Tensor: 666 | resi = inputs 667 | x = self.in_conv(inputs) 668 | if self.is_gate: 669 | x = self.left_conv(x) * self.right_conv(x) 670 | else: 671 | x = self.left_conv(x) 672 | x = self.out_conv(x) 673 | x = x + resi 674 | return x 675 | 676 | class Chomp_T(nn.Module): 677 | def __init__(self, 678 | t: int): 679 | super(Chomp_T, self).__init__() 680 | self.t = t 681 | 682 | def forward(self, x): 683 | return x[:, :, :-self.t, :] 684 | 685 | 686 | if __name__ == "__main__": 687 | net = GeneralizedMultichannelWienerFiter(k1=[2,3], 688 | k2=[1,3], 689 | c=64, 690 | M=6, 691 | fft_num=320, 692 | hid_node=64, 693 | kd1=5, 694 | cd1=64, 695 | d_feat=256, 696 | group_num=2, 697 | is_gate=True, 698 | dilations=[1,2,5,9], 699 | is_causal=True, 700 | is_u2=True, 701 | rnn_type="LSTM", 702 | norm1d_type="BN", 703 | norm2d_type="BN", 704 | intra_connect="cat", 705 | inter_connect="cat", 706 | out_type="mask", 707 | ).cuda() 708 | from utils.utils import numParams 709 | print(f"The number of trainable parameters:{numParams(net)}") 710 | import ptflops 711 | flops, macs = ptflops.get_model_complexity_info(net, (101,161,6,2)) 712 | x = torch.rand([2,51,161,6,2]).cuda() 713 | y = net(x) 714 | print(f"{x.shape}->{y.shape}") 715 | -------------------------------------------------------------------------------- /nets/TaylorBeamformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | from torch.autograd import Variable 5 | from torch_complex.tensor import ComplexTensor 6 | import torch_complex.functional as F 7 | from utils.utils import complex_mul, complex_conj, NormSwitch 8 | import math 9 | 10 | 11 | class TaylorBeamformer(nn.Module): 12 | def __init__(self, 13 | k1: list, 14 | k2: list, 15 | ref_mic: int, 16 | c: int, 17 | embed_dim: int, 18 | fft_num: int, 19 | order_num: int, 20 | kd1: int, 21 | cd1: int, 22 | d_feat: int, 23 | dilations: list, 24 | group_num: int, 25 | hid_node: int, 26 | M: int, 27 | rnn_type: str, 28 | intra_connect: str, 29 | inter_connect: str, 30 | out_type: str, # ["mask", "mapping"] 31 | bf_type: str, # ["embedding", "generalized", "mvdr"] 32 | norm2d_type: str, # ["BN", "IN"] 33 | norm1d_type: str, 34 | is_compress: bool, 35 | is_total_separate: bool, # whether the encoder in the spectral domain contains no spatial info 36 | is_u2: bool, 37 | is_1dgate: bool, 38 | is_squeezed: bool, 39 | is_causal: bool, 40 | is_param_share: bool, 41 | ): 42 | 
super(TaylorBeamformer, self).__init__()
43 |         self.k1 = tuple(k1)
44 |         self.k2 = tuple(k2)
45 |         self.ref_mic = ref_mic
46 |         self.c = c
47 |         self.embed_dim = embed_dim
48 |         self.fft_num = fft_num
49 |         self.order_num = order_num
50 |         self.kd1 = kd1
51 |         self.cd1 = cd1
52 |         self.d_feat = d_feat
53 |         self.dilations = dilations
54 |         self.group_num = group_num
55 |         self.hid_node = hid_node
56 |         self.M = M
57 |         self.rnn_type = rnn_type
58 |         self.intra_connect = intra_connect
59 |         self.inter_connect = inter_connect
60 |         self.out_type = out_type
61 |         self.bf_type = bf_type
62 |         self.norm2d_type = norm2d_type
63 |         self.norm1d_type = norm1d_type
64 |         self.is_compress = is_compress
65 |         self.is_total_separate = is_total_separate
66 |         self.is_u2 = is_u2
67 |         self.is_1dgate = is_1dgate
68 |         self.is_squeezed = is_squeezed
69 |         self.is_causal = is_causal
70 |         self.is_param_share = is_param_share
71 | 
72 |         #assert (out_type, bf_type) in [("mask", "mvdr"), ("mask", "generalized"), ("mapping", "embedding")]
73 |         # Components
74 |         self.zeroorderblock = ZeroOrderBlock(self.k1, self.k2, c, embed_dim, fft_num, kd1, cd1, d_feat, dilations,
75 |                                              group_num, hid_node, M, rnn_type, intra_connect, inter_connect, out_type,
76 |                                              bf_type, norm2d_type, norm1d_type, is_u2, is_1dgate, is_causal)
77 |         if order_num > 0:
78 |             if not is_total_separate:
79 |                 if is_u2:
80 |                     self.highorderen = U2Net_Encoder(2*M, self.k1, self.k2, c, intra_connect, norm2d_type)
81 |                 else:
82 |                     self.highorderen = UNet_Encoder(2*M, self.k1, c, norm2d_type)
83 |             else:
84 |                 if is_u2:
85 |                     self.highorderen = U2Net_Encoder(2, self.k1, self.k2, c, intra_connect, norm2d_type)
86 |                 else:
87 |                     self.highorderen = UNet_Encoder(2, self.k1, c, norm2d_type)
88 | 
89 |             highorderblock_list = []
90 |             if is_param_share:
91 |                 highorderblock_list.append(HighOrderBlock(kd1, cd1, d_feat, dilations, group_num, fft_num, is_1dgate,
92 |                                                           is_causal, is_squeezed, norm1d_type))
93 |             else:
94 |                 for i in range(order_num):
95 |                     highorderblock_list.append(HighOrderBlock(kd1, cd1, d_feat, dilations, group_num, fft_num, is_1dgate,
96 |                                                               is_causal, is_squeezed, norm1d_type))
97 |             self.highorderblock_list = nn.ModuleList(highorderblock_list)
98 | 
99 |     def forward(self, inpt):
100 |         """
101 |         inpt: (B,T,F,M,2)
102 |         return: spatial_x: (B,T,F,2), the zeroth-order beamformed spectrum, and out_term: (B,T,F,2)
103 |         """
104 |         if inpt.ndim == 4:
105 |             inpt = inpt.unsqueeze(dim=-2)
106 |         b_size, seq_len, freq_num, _, _ = inpt.shape
107 |         # zero order process
108 |         spatial_x = self.zeroorderblock(inpt)  # (B,T,F,2)
109 |         # taylor unfolding process
110 |         if self.is_compress:
111 |             inpt_mag, inpt_phase = torch.norm(inpt, dim=-1)**0.5, torch.atan2(inpt[...,-1], inpt[...,0])
112 |             inpt = torch.stack((inpt_mag*torch.cos(inpt_phase), inpt_mag*torch.sin(inpt_phase)), dim=-1)
113 |             spatial_mag, spatial_phase = (torch.norm(spatial_x, dim=-1)+1e-10)**0.5, \
114 |                                          torch.atan2(spatial_x[...,-1], spatial_x[...,0])
115 |             spatial_x = torch.stack((spatial_mag*torch.cos(spatial_phase), spatial_mag*torch.sin(spatial_phase)), dim=1)  # stacking along dim=1 yields (B,2,T,F)
116 |         else:
117 |             spatial_x = spatial_x.permute(0,3,1,2).contiguous()
118 |         out_term, pre_term = spatial_x, spatial_x  # (B,2,T,F)
119 |         # high order encoding
120 |         if self.order_num > 0:
121 |             if not self.is_total_separate:
122 |                 inpt = inpt.view(b_size, seq_len, freq_num, -1).permute(0,3,1,2).contiguous()  # (B,2M,T,F)
123 |             else:
124 |                 inpt = inpt[...,self.ref_mic,:].permute(0,3,1,2).contiguous()  # (B,2,T,F)
125 |             en_x, _ = self.highorderen(inpt)
126 |             en_x = en_x.transpose(-2, -1).contiguous().view(b_size, -1, seq_len)
127 | 
128 |         for order_id in range(self.order_num):
129 |             if self.is_param_share:
130 |                 update_term = self.highorderblock_list[0](en_x, pre_term) + order_id * pre_term
131 |             else:
132 |                 update_term = self.highorderblock_list[order_id](en_x, pre_term) + order_id * pre_term
133 |             pre_term = update_term
134 |             out_term = out_term + update_term / math.factorial(order_id+1)  # Taylor-style sum: the (order_id+1)-th order term is scaled by 1/(order_id+1)!
135 |         return spatial_x.permute(0,2,3,1), out_term.permute(0,2,3,1)
136 | 
137 | 
138 | class ZeroOrderBlock(nn.Module):
139 |     def __init__(self,
140 |                  k1: tuple,
141 |                  k2: tuple,
142 |                  c: int,
143 |                  embed_dim: int,
144 |                  fft_num: int,
145 |                  kd1: int,
146 |                  cd1: int,
147 |                  d_feat: int,
148 |                  dilations: list,
149 |                  group_num: int,
150 |                  hid_node: int,
151 |                  M: int,
152 |                  rnn_type: str,
153 |                  intra_connect: str,
154 |                  inter_connect: str,
155 |                  out_type: str,
156 |                  bf_type: str,
157 |                  norm2d_type: str,
158 |                  norm1d_type: str,
159 |                  is_u2: bool,
160 |                  is_1dgate: bool,
161 |                  is_causal: bool,
162 |                  ):
163 |         super(ZeroOrderBlock, self).__init__()
164 |         self.k1 = k1
165 |         self.k2 = k2
166 |         self.c = c
167 |         self.embed_dim = embed_dim
168 |         self.fft_num = fft_num
169 |         self.kd1 = kd1
170 |         self.cd1 = cd1
171 |         self.d_feat = d_feat
172 |         self.dilations = dilations
173 |         self.group_num = group_num
174 |         self.hid_node = hid_node
175 |         self.M = M
176 |         self.rnn_type = rnn_type
177 |         self.intra_connect = intra_connect
178 |         self.inter_connect = inter_connect
179 |         self.out_type = out_type
180 |         self.bf_type = bf_type
181 |         self.norm2d_type = norm2d_type
182 |         self.norm1d_type = norm1d_type
183 |         self.is_u2 = is_u2
184 |         self.is_1dgate = is_1dgate
185 |         self.is_causal = is_causal
186 |         # Components
187 |         if is_u2:
188 |             self.en = U2Net_Encoder(2*M, k1, k2, c, intra_connect, norm2d_type)
189 |             self.de = U2Net_Decoder(c, k1, k2, embed_dim, fft_num, intra_connect, inter_connect, out_type, norm2d_type)
190 |         else:
191 |             self.en = UNet_Encoder(2*M, k1, c, norm2d_type)
192 |             self.de = UNet_Decoder(c, k1, embed_dim, fft_num, inter_connect, out_type, norm2d_type)
193 |         tcns = []
194 |         for i in range(group_num):
195 |             tcns.append(TCMGroup(kd1, cd1, d_feat, is_1dgate, dilations, is_causal, norm1d_type))
196 |         self.tcns = nn.ModuleList(tcns)
197 |         self.bf_module = BeamformingModule(embed_dim, M, hid_node, out_type, bf_type, rnn_type)
198 | 
199 | 
200 |     def forward(self, inpt):
201 |         """
202 |         inpt: (B,T,F,M,2)
203 |         return: (B,T,F,2), the beamformed spectrum summed over the M microphones
204 |         """
205 |         b_size, seq_len, freq_num, channel_num, _ = inpt.shape
206 |         inpt_x = inpt.contiguous().view(b_size, seq_len, freq_num, -1).permute(0,3,1,2)
207 |         en_x, en_list = self.en(inpt_x)
208 |         x = en_x.transpose(-2, -1).contiguous().view(b_size, -1, seq_len)
209 |         x_acc = torch.zeros_like(x)  # plain zeros suffice; the deprecated Variable(..., requires_grad=True) version errors out on CPU (in-place op on a leaf that requires grad)
210 |         for i in range(self.group_num):
211 |             x = self.tcns[i](x)
212 |             x_acc += x
213 |         x = x_acc
214 |         x = x.view(b_size, -1, 4, seq_len).transpose(-2, -1).contiguous()  # 4 denotes the freq size of the last encoding layer
215 | 
216 |         if self.out_type == "mask":
217 |             est_s, est_n = self.de(inpt, x, en_list)
218 |             bf_weight = self.bf_module(est_s, est_n)
219 |         else:
220 |             embed_x = self.de(inpt, x, en_list)
221 |             bf_weight = self.bf_module(embed_x)
222 |         bf_x = torch.sum(complex_mul(complex_conj(bf_weight), inpt), dim=-2)
223 |         return bf_x
224 | 
225 | 
226 | class HighOrderBlock(nn.Module):
227 |     def __init__(self,
228 |                  kd1: int,
229 |                  cd1: int,
230 |                  d_feat: int,
231 |                  dilations: list,
232 |                  group_num: int,
233 |                  fft_num: int,
234 |                  is_1dgate: bool,
235 |                  is_causal: bool,
236
| is_squeezed: bool, 237 | norm1d_type: str, 238 | ): 239 | super(HighOrderBlock, self).__init__() 240 | self.kd1 = kd1 241 | self.cd1 = cd1 242 | self.d_feat = d_feat 243 | self.dilations = dilations 244 | self.group_num = group_num 245 | self.fft_num = fft_num 246 | self.is_1dgate = is_1dgate 247 | self.is_causal = is_causal 248 | self.is_squeezed = is_squeezed 249 | self.norm1d_type = norm1d_type 250 | 251 | in_feat = (fft_num//2+1)*2 + d_feat 252 | self.in_conv = nn.Conv1d(in_feat, d_feat, 1) 253 | if not is_squeezed: 254 | tcm_r_list, tcm_i_list = [], [] 255 | for i in range(group_num): 256 | tcm_r_list.append(TCMGroup(kd1, cd1, d_feat, is_1dgate, dilations, is_causal, norm1d_type)) 257 | tcm_i_list.append(TCMGroup(kd1, cd1, d_feat, is_1dgate, dilations, is_causal, norm1d_type)) 258 | self.tcms_r, self.tcms_i = nn.ModuleList(tcm_r_list), nn.ModuleList(tcm_i_list) 259 | else: 260 | tcm_list = [] 261 | for i in range(group_num): 262 | tcm_list.append(TCMGroup(kd1, cd1, d_feat, is_1dgate, dilations, is_causal, norm1d_type)) 263 | self.tcms = nn.ModuleList(tcm_list) 264 | self.real_resi, self.imag_resi = nn.Conv1d(d_feat, fft_num//2+1, 1), nn.Conv1d(d_feat, fft_num//2+1, 1) 265 | 266 | 267 | def forward(self, en_x: Tensor, pre_x: Tensor) -> Tensor: 268 | """ 269 | :param en_x: (B, C, T) 270 | :param pre_x: (B, 2, T, F) 271 | :return: (B, 2, T, F) 272 | """ 273 | assert en_x.ndim == 3 and pre_x.ndim == 4 274 | # fuse the features 275 | b_size, _, seq_len, freq_num = pre_x.shape 276 | x1 = pre_x.transpose(-2, -1).contiguous().view(b_size, -1, seq_len) 277 | x = torch.cat((en_x, x1), dim=1) 278 | # in conv 279 | x = self.in_conv(x) 280 | # STCMs 281 | if not self.is_squeezed: 282 | x_r, x_i = x, x 283 | for i in range(self.group_num): 284 | x_r, x_i = self.tcms_r[i](x_r), self.tcms_i[i](x_i) 285 | else: 286 | for i in range(self.group_num): 287 | x = self.tcms[i](x) 288 | x_r, x_i = x, x 289 | # generate real and imaginary parts 290 | x_r, x_i = self.real_resi(x_r).transpose(-2, -1), self.imag_resi(x_i).transpose(-2, -1) 291 | return torch.stack((x_r, x_i), dim=1).contiguous() 292 | 293 | 294 | class UNet_Encoder(nn.Module): 295 | def __init__(self, 296 | cin: int, 297 | k1: tuple, 298 | c: int, 299 | norm2d_type: str, 300 | ): 301 | super(UNet_Encoder, self).__init__() 302 | self.cin = cin 303 | self.k1 = k1 304 | self.c = c 305 | self.norm2d_type = norm2d_type 306 | kernel_begin = (k1[0], 5) 307 | stride = (1, 2) 308 | c_final = 64 309 | unet = [] 310 | unet.append(nn.Sequential( 311 | GateConv2d(cin, c, kernel_begin, stride, padding=(0, 0, k1[0]-1, 0)), 312 | NormSwitch(norm2d_type, "2D", c), 313 | nn.PReLU(c))) 314 | unet.append(nn.Sequential( 315 | GateConv2d(c, c, k1, stride, padding=(0, 0, k1[0]-1, 0)), 316 | NormSwitch(norm2d_type, "2D", c), 317 | nn.PReLU(c))) 318 | unet.append(nn.Sequential( 319 | GateConv2d(c, c, k1, stride, padding=(0, 0, k1[0]-1, 0)), 320 | NormSwitch(norm2d_type, "2D", c), 321 | nn.PReLU(c))) 322 | unet.append(nn.Sequential( 323 | GateConv2d(c, c, k1, stride, padding=(0, 0, k1[0]-1, 0)), 324 | NormSwitch(norm2d_type, "2D", c), 325 | nn.PReLU(c))) 326 | unet.append(nn.Sequential( 327 | GateConv2d(c, c_final, k1, (1,2), padding=(0, 0, k1[0]-1, 0)), 328 | NormSwitch(norm2d_type, "2D", c_final), 329 | nn.PReLU(c_final))) 330 | self.unet_list = nn.ModuleList(unet) 331 | 332 | def forward(self, x: Tensor) -> tuple: 333 | en_list = [] 334 | for i in range(len(self.unet_list)): 335 | x = self.unet_list[i](x) 336 | en_list.append(x) 337 | return x, en_list 338 | 339 | 
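# --- Editor's note (illustrative sketch, not part of the original file) --------
# UNet_Encoder above halves the frequency axis at each of its five stages.
# Assuming fft_num=320 (161 frequency bins), k1=(1,3) and a 6-mic input
# (cin=2*M=12), as in the shipped train config, the skip tensors shrink as
# F: 161 -> 79 -> 39 -> 19 -> 9 -> 4; that final 4 is the hard-coded frequency
# size that ZeroOrderBlock.forward reshapes with. A quick check in this
# module's namespace:
#
#   enc = UNet_Encoder(cin=12, k1=(1, 3), c=64, norm2d_type="BN")
#   x = torch.rand(2, 12, 31, 161)        # (B, 2*M, T, F)
#   _, skips = enc(x)
#   print([s.shape[-1] for s in skips])   # -> [79, 39, 19, 9, 4]
# -------------------------------------------------------------------------------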
340 | class UNet_Decoder(nn.Module): 341 | def __init__(self, 342 | c: int, 343 | k1: tuple, 344 | embed_dim: int, 345 | fft_num: int, 346 | inter_connect: str, 347 | out_type: str, 348 | norm2d_type: str, 349 | ): 350 | super(UNet_Decoder, self).__init__() 351 | self.k1 = k1 352 | self.c = c 353 | self.embed_dim = embed_dim 354 | self.fft_num = fft_num 355 | self.inter_connect = inter_connect 356 | self.out_type = out_type 357 | self.norm2d_type = norm2d_type 358 | 359 | kernel_end = (k1[0], 5) 360 | stride = (1, 2) 361 | unet = [] 362 | if inter_connect == "add": 363 | inter_c = c 364 | c_begin = 64 365 | elif inter_connect == "cat": 366 | inter_c = c * 2 367 | c_begin = 64 * 2 368 | else: 369 | raise RuntimeError("Skip connections only support add or concatenate operation") 370 | 371 | unet.append(nn.Sequential( 372 | GateConvTranspose2d(c_begin, c, k1, stride), 373 | NormSwitch(norm2d_type, "2D", c), 374 | nn.PReLU(c))) 375 | unet.append(nn.Sequential( 376 | GateConvTranspose2d(inter_c, c, k1, stride), 377 | NormSwitch(norm2d_type, "2D", c), 378 | nn.PReLU(c))) 379 | unet.append(nn.Sequential( 380 | GateConvTranspose2d(inter_c, c, k1, stride), 381 | NormSwitch(norm2d_type, "2D", c), 382 | nn.PReLU(c))) 383 | unet.append(nn.Sequential( 384 | GateConvTranspose2d(inter_c, c, k1, stride), 385 | NormSwitch(norm2d_type, "2D", c), 386 | nn.PReLU(c))) 387 | self.unet_list = nn.ModuleList(unet) 388 | if out_type == "mask": 389 | self.conv = nn.Sequential( 390 | GateConvTranspose2d(inter_c, c, kernel_end, stride), 391 | NormSwitch(norm2d_type, "2D", c), 392 | nn.PReLU(c) 393 | ) 394 | self.mask_s = nn.Sequential( 395 | nn.Conv2d(c, 2, (1, 1), (1, 1)), 396 | nn.Linear(fft_num//2+1, fft_num//2+1) 397 | ) 398 | self.mask_n = nn.Sequential( 399 | nn.Conv2d(c, 2, (1, 1), (1, 1)), 400 | nn.Linear(fft_num//2+1, fft_num//2+1) 401 | ) 402 | elif out_type == "mapping": 403 | self.embed = nn.Sequential( 404 | GateConvTranspose2d(inter_c, embed_dim, kernel_end, stride), 405 | nn.Linear(fft_num//2+1, fft_num//2+1) 406 | ) 407 | 408 | def forward(self, inpt: Tensor, x: Tensor, en_list: list): 409 | """ 410 | inpt: (B,T,F,M,2) 411 | return: (B,-1,T,F) 412 | """ 413 | b_size, seq_len, freq_num, _, _ = inpt.shape 414 | if self.inter_connect == "add": 415 | for i in range(len(self.unet_list)): 416 | tmp = x + en_list[-(i + 1)] 417 | x = self.unet_list[i](tmp) 418 | x = x + en_list[0] 419 | elif self.inter_connect == "cat": 420 | for i in range(len(self.unet_list)): 421 | tmp = torch.cat((x, en_list[-(i + 1)]), dim=1) 422 | x = self.unet_list[i](tmp) 423 | x = torch.cat((x, en_list[0]), dim=1) 424 | else: 425 | raise RuntimeError("only add and cat are supported") 426 | # output 427 | if self.out_type == "mask": 428 | x = self.conv(x) 429 | mask_s, mask_n = self.mask_s(x).permute(0,2,3,1).contiguous().unsqueeze(dim=-2), \ 430 | self.mask_n(x).permute(0,2,3,1).contiguous().unsqueeze(dim=-2) 431 | est_s, est_n = complex_mul(inpt, mask_s), complex_mul(inpt, mask_n) 432 | return est_s, est_n 433 | elif self.out_type == "mapping": 434 | out_x = self.embed(x).permute(0,2,3,1).contiguous() 435 | return out_x 436 | else: 437 | raise RuntimeError("only mask and mapping are supported") 438 | 439 | 440 | class U2Net_Encoder(nn.Module): 441 | def __init__(self, 442 | cin: int, 443 | k1: tuple, 444 | k2: tuple, 445 | c: int, 446 | intra_connect: str, 447 | norm2d_type: str, 448 | ): 449 | super(U2Net_Encoder, self).__init__() 450 | self.cin = cin 451 | self.k1 = k1 452 | self.k2 = k2 453 | self.c = c 454 | self.intra_connect = 
intra_connect 455 | self.norm2d_type = norm2d_type 456 | c_last = 64 457 | kernel_begin = (k1[0], 5) 458 | stride = (1, 2) 459 | meta_unet = [] 460 | meta_unet.append( 461 | En_unet_module(cin, c, kernel_begin, k2, intra_connect, norm2d_type, scale=4, de_flag=False)) 462 | meta_unet.append( 463 | En_unet_module(c, c, k1, k2, intra_connect, norm2d_type, scale=3, de_flag=False)) 464 | meta_unet.append( 465 | En_unet_module(c, c, k1, k2, intra_connect, norm2d_type, scale=2, de_flag=False)) 466 | meta_unet.append( 467 | En_unet_module(c, c, k1, k2, intra_connect, norm2d_type, scale=1, de_flag=False)) 468 | self.meta_unet_list = nn.ModuleList(meta_unet) 469 | self.last_conv = nn.Sequential( 470 | GateConv2d(c, c_last, k1, stride, (0, 0, k1[0]-1, 0)), 471 | NormSwitch(norm2d_type, "2D", c_last), 472 | nn.PReLU(c_last) 473 | ) 474 | 475 | def forward(self, x: Tensor) -> tuple: 476 | en_list = [] 477 | for i in range(len(self.meta_unet_list)): 478 | x = self.meta_unet_list[i](x) 479 | en_list.append(x) 480 | x = self.last_conv(x) 481 | en_list.append(x) 482 | return x, en_list 483 | 484 | 485 | class U2Net_Decoder(nn.Module): 486 | def __init__(self, 487 | c: int, 488 | k1: tuple, 489 | k2: tuple, 490 | embed_dim: int, 491 | fft_num: int, 492 | intra_connect: str, 493 | inter_connect: str, 494 | out_type: str, 495 | norm2d_type: str, 496 | ): 497 | super(U2Net_Decoder, self).__init__() 498 | self.c = c 499 | self.k1 = k1 500 | self.k2 = k2 501 | self.embed_dim = embed_dim 502 | self.fft_num = fft_num 503 | self.intra_connect = intra_connect 504 | self.inter_connect = inter_connect 505 | self.out_type = out_type 506 | self.norm2d_type = norm2d_type 507 | 508 | kernel_end = (k1[0], 5) 509 | stride = (1, 2) 510 | meta_unet = [] 511 | if inter_connect == "add": 512 | inter_c = c 513 | c_begin = 64 514 | elif inter_connect == "cat": 515 | inter_c = c*2 516 | c_begin = 64*2 517 | else: 518 | raise RuntimeError("Skip connections only support add or concatenate operation") 519 | meta_unet.append( 520 | En_unet_module(c_begin, c, k1, k2, intra_connect, norm2d_type, scale=1, de_flag=True)) 521 | meta_unet.append( 522 | En_unet_module(inter_c, c, k1, k2, intra_connect, norm2d_type, scale=2, de_flag=True)) 523 | meta_unet.append( 524 | En_unet_module(inter_c, c, k1, k2, intra_connect, norm2d_type, scale=3, de_flag=True)) 525 | meta_unet.append( 526 | En_unet_module(inter_c, c, k1, k2, intra_connect, norm2d_type, scale=4, de_flag=True)) 527 | self.meta_unet_list = nn.ModuleList(meta_unet) 528 | if out_type == "mask": 529 | self.conv = nn.Sequential( 530 | GateConvTranspose2d(inter_c, c, kernel_end, stride), 531 | NormSwitch(norm2d_type, "2D", c), 532 | nn.PReLU(c) 533 | ) 534 | self.mask_s = nn.Sequential( 535 | nn.Conv2d(c, 2, (1, 1), (1, 1)), 536 | nn.Linear(fft_num//2+1, fft_num//2+1) 537 | ) 538 | self.mask_n = nn.Sequential( 539 | nn.Conv2d(c, 2, (1, 1), (1, 1)), 540 | nn.Linear(fft_num//2+1, fft_num//2+1) 541 | ) 542 | elif out_type == "mapping": 543 | self.embed = nn.Sequential( 544 | GateConvTranspose2d(inter_c, embed_dim, kernel_end, stride), 545 | nn.Linear(fft_num//2+1, fft_num//2+1) 546 | ) 547 | 548 | def forward(self, inpt: Tensor, x: Tensor, en_list: list): 549 | """ 550 | inpt: (B,T,F,M,2) 551 | return: (B,T,F,M,2) or (B,T,F,K) 552 | """ 553 | b_size, seq_len, freq_num, _, _ = inpt.shape 554 | if self.inter_connect == "add": 555 | for i in range(len(self.meta_unet_list)): 556 | tmp = x + en_list[-(i+1)] 557 | x = self.meta_unet_list[i](tmp) 558 | x = x + en_list[0] 559 | elif self.inter_connect 
== "cat": 560 | for i in range(len(self.meta_unet_list)): 561 | tmp = torch.cat((x, en_list[-(i+1)]), dim=1) 562 | x = self.meta_unet_list[i](tmp) 563 | x = torch.cat((x, en_list[0]), dim=1) 564 | else: 565 | raise RuntimeError("only add and cat are supported") 566 | # output 567 | if self.out_type == "mask": 568 | x = self.conv(x) 569 | mask_s, mask_n = self.mask_s(x).permute(0, 2, 3, 1).contiguous().unsqueeze(dim=-2), \ 570 | self.mask_n(x).permute(0, 2, 3, 1).contiguous().unsqueeze(dim=-2) 571 | est_s, est_n = complex_mul(inpt, mask_s), complex_mul(inpt, mask_n) 572 | return est_s, est_n 573 | elif self.out_type == "mapping": 574 | out_x = self.embed(x).permute(0, 2, 3, 1).contiguous() 575 | return out_x 576 | else: 577 | raise RuntimeError("only mask and mapping are supported") 578 | 579 | 580 | class En_unet_module(nn.Module): 581 | def __init__(self, 582 | cin: int, 583 | cout: int, 584 | k1: tuple, 585 | k2: tuple, 586 | intra_connect: str, 587 | norm2d_type: str, 588 | scale: int, 589 | de_flag: bool = False, 590 | ): 591 | super(En_unet_module, self).__init__() 592 | self.cin = cin 593 | self.cout = cout 594 | self.k1 = k1 595 | self.k2 = k2 596 | self.intra_connect = intra_connect 597 | self.norm2d_type = norm2d_type 598 | self.scale = scale 599 | self.de_flag = de_flag 600 | 601 | in_conv_list = [] 602 | if de_flag is False: 603 | in_conv_list.append(GateConv2d(cin, cout, k1, (1, 2), (0, 0, k1[0]-1, 0))) 604 | else: 605 | in_conv_list.append(GateConvTranspose2d(cin, cout, k1, (1, 2))) 606 | in_conv_list.append(NormSwitch(norm2d_type, "2D", cout)) 607 | in_conv_list.append(nn.PReLU(cout)) 608 | self.in_conv = nn.Sequential(*in_conv_list) 609 | 610 | enco_list, deco_list = [], [] 611 | for _ in range(scale): 612 | enco_list.append(Conv2dunit(k2, cout, norm2d_type)) 613 | for i in range(scale): 614 | if i == 0: 615 | deco_list.append(Deconv2dunit(k2, cout, "add", norm2d_type)) 616 | else: 617 | deco_list.append(Deconv2dunit(k2, cout, intra_connect, norm2d_type)) 618 | self.enco = nn.ModuleList(enco_list) 619 | self.deco = nn.ModuleList(deco_list) 620 | self.skip_connect = Skip_connect(intra_connect) 621 | 622 | 623 | def forward(self, inputs: Tensor) -> Tensor: 624 | x_resi = self.in_conv(inputs) 625 | x = x_resi 626 | x_list = [] 627 | for i in range(len(self.enco)): 628 | x = self.enco[i](x) 629 | x_list.append(x) 630 | 631 | for i in range(len(self.deco)): 632 | if i == 0: 633 | x = self.deco[i](x) 634 | else: 635 | x_con = self.skip_connect(x, x_list[-(i+1)]) 636 | x = self.deco[i](x_con) 637 | x_resi = x_resi + x 638 | del x_list 639 | return x_resi 640 | 641 | 642 | class Conv2dunit(nn.Module): 643 | def __init__(self, 644 | k: tuple, 645 | c: int, 646 | norm2d_type: str, 647 | ): 648 | super(Conv2dunit, self).__init__() 649 | self.k, self.c = k, c 650 | self.norm2d_type = norm2d_type 651 | k_t = k[0] 652 | stride = (1, 2) 653 | if k_t > 1: 654 | self.conv = nn.Sequential( 655 | nn.ConstantPad2d((0, 0, k_t-1, 0), value=0.), 656 | nn.Conv2d(c, c, k, stride), 657 | NormSwitch(norm2d_type, "2D", c), 658 | nn.PReLU(c) 659 | ) 660 | else: 661 | self.conv = nn.Sequential( 662 | nn.Conv2d(c, c, k, stride), 663 | NormSwitch(norm2d_type, "2D", c), 664 | nn.PReLU(c) 665 | ) 666 | 667 | def forward(self, inputs: Tensor) -> Tensor: 668 | return self.conv(inputs) 669 | 670 | 671 | class Deconv2dunit(nn.Module): 672 | def __init__(self, 673 | k: tuple, 674 | c: int, 675 | intra_connect: str, 676 | norm2d_type: str, 677 | ): 678 | super(Deconv2dunit, self).__init__() 679 | self.k, self.c = 
k, c 680 | self.intra_connect = intra_connect 681 | self.norm2d_type = norm2d_type 682 | k_t = k[0] 683 | stride = (1, 2) 684 | deconv_list = [] 685 | if self.intra_connect == "add": 686 | if k_t > 1: 687 | deconv_list.append(nn.ConvTranspose2d(c, c, k, stride)), 688 | deconv_list.append(Chomp_T(k_t-1)) 689 | else: 690 | deconv_list.append(nn.ConvTranspose2d(c, c, k, stride)) 691 | elif self.intra_connect == "cat": 692 | if k_t > 1: 693 | deconv_list.append(nn.ConvTranspose2d(2*c, c, k, stride)) 694 | deconv_list.append(Chomp_T(k_t-1)) 695 | else: 696 | deconv_list.append(nn.ConvTranspose2d(2*c, c, k, stride)) 697 | deconv_list.append(NormSwitch(norm2d_type, "2D", c)) 698 | deconv_list.append(nn.PReLU(c)) 699 | self.deconv = nn.Sequential(*deconv_list) 700 | 701 | def forward(self, inputs: Tensor) -> Tensor: 702 | assert inputs.dim() == 4 703 | return self.deconv(inputs) 704 | 705 | 706 | class GateConv2d(nn.Module): 707 | def __init__(self, 708 | in_channels: int, 709 | out_channels: int, 710 | kernel_size: tuple, 711 | stride: tuple, 712 | padding: tuple, 713 | ): 714 | super(GateConv2d, self).__init__() 715 | self.in_channels = in_channels 716 | self.out_channels = out_channels 717 | self.kernel_size = kernel_size 718 | self.stride = stride 719 | self.padding = padding 720 | k_t = kernel_size[0] 721 | if k_t > 1: 722 | self.conv = nn.Sequential( 723 | nn.ConstantPad2d(padding, value=0.), 724 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size, stride=stride)) 725 | else: 726 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size, 727 | stride=stride) 728 | def forward(self, inputs: Tensor) -> Tensor: 729 | if inputs.dim() == 3: 730 | inputs = inputs.unsqueeze(dim=1) 731 | x = self.conv(inputs) 732 | outputs, gate = x.chunk(2, dim=1) 733 | return outputs * gate.sigmoid() 734 | 735 | 736 | class GateConvTranspose2d(nn.Module): 737 | def __init__(self, 738 | in_channels: int, 739 | out_channels: int, 740 | kernel_size: tuple, 741 | stride: tuple, 742 | ): 743 | super(GateConvTranspose2d, self).__init__() 744 | self.in_channels = in_channels 745 | self.out_channels = out_channels 746 | self.kernel_size = kernel_size 747 | self.stride = stride 748 | 749 | k_t = kernel_size[0] 750 | if k_t > 1: 751 | self.conv = nn.Sequential( 752 | nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size, 753 | stride=stride), 754 | Chomp_T(k_t-1)) 755 | else: 756 | self.conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels*2, kernel_size=kernel_size, 757 | stride=stride) 758 | 759 | def forward(self, inputs: Tensor) -> Tensor: 760 | assert inputs.dim() == 4 761 | x = self.conv(inputs) 762 | outputs, gate = x.chunk(2, dim=1) 763 | return outputs * gate.sigmoid() 764 | 765 | 766 | class Skip_connect(nn.Module): 767 | def __init__(self, 768 | connect): 769 | super(Skip_connect, self).__init__() 770 | self.connect = connect 771 | 772 | def forward(self, x_main, x_aux): 773 | if self.connect == "add": 774 | x = x_main + x_aux 775 | elif self.connect == "cat": 776 | x = torch.cat((x_main, x_aux), dim=1) 777 | return x 778 | 779 | 780 | class TCMGroup(nn.Module): 781 | def __init__(self, 782 | kd1: int, 783 | cd1: int, 784 | d_feat: int, 785 | is_gate: bool, 786 | dilations: list, 787 | is_causal: bool, 788 | norm1d_type: str, 789 | ): 790 | super(TCMGroup, self).__init__() 791 | self.kd1 = kd1 792 | self.cd1 = cd1 793 | self.d_feat = d_feat 794 | self.is_gate = 
is_gate 795 | self.dilations = dilations 796 | self.is_causal = is_causal 797 | self.norm1d_type = norm1d_type 798 | 799 | tcm_list = [] 800 | for i in range(len(dilations)): 801 | tcm_list.append(SqueezedTCM(kd1, cd1, dilation=dilations[i], d_feat=d_feat, is_gate=is_gate, 802 | is_causal=is_causal, norm1d_type=norm1d_type)) 803 | self.tcm_list = nn.ModuleList(tcm_list) 804 | 805 | def forward(self, inputs: Tensor) -> Tensor: 806 | x = inputs 807 | for i in range(len(self.dilations)): 808 | x = self.tcm_list[i](x) 809 | return x 810 | 811 | 812 | class SqueezedTCM(nn.Module): 813 | def __init__(self, 814 | kd1: int, 815 | cd1: int, 816 | dilation: int, 817 | d_feat: int, 818 | is_gate: bool, 819 | is_causal: bool, 820 | norm1d_type: str, 821 | ): 822 | super(SqueezedTCM, self).__init__() 823 | self.kd1 = kd1 824 | self.cd1 = cd1 825 | self.dilation = dilation 826 | self.d_feat = d_feat 827 | self.is_gate = is_gate 828 | self.is_causal = is_causal 829 | self.norm1d_type = norm1d_type 830 | 831 | self.in_conv = nn.Conv1d(d_feat, cd1, kernel_size=1, bias=False) 832 | if is_causal: 833 | pad = ((kd1-1)*dilation, 0) 834 | else: 835 | pad = ((kd1-1)*dilation//2, (kd1-1)*dilation//2) 836 | self.left_conv = nn.Sequential( 837 | nn.PReLU(cd1), 838 | NormSwitch(norm1d_type, "1D", cd1), 839 | nn.ConstantPad1d(pad, value=0.), 840 | nn.Conv1d(cd1, cd1, kernel_size=kd1, dilation=dilation, bias=False) 841 | ) 842 | if is_gate: 843 | self.right_conv = nn.Sequential( 844 | nn.PReLU(cd1), 845 | NormSwitch(norm1d_type, "1D", cd1), 846 | nn.ConstantPad1d(pad, value=0.), 847 | nn.Conv1d(cd1, cd1, kernel_size=kd1, dilation=dilation, bias=False), 848 | nn.Sigmoid() 849 | ) 850 | self.out_conv = nn.Sequential( 851 | nn.PReLU(cd1), 852 | NormSwitch(norm1d_type, "1D", cd1), 853 | nn.Conv1d(cd1, d_feat, kernel_size=1, bias=False) 854 | ) 855 | 856 | def forward(self, inputs: Tensor) -> Tensor: 857 | resi = inputs 858 | x = self.in_conv(inputs) 859 | if self.is_gate: 860 | x = self.left_conv(x) * self.right_conv(x) 861 | else: 862 | x = self.left_conv(x) 863 | x = self.out_conv(x) 864 | x = x + resi 865 | return x 866 | 867 | class BeamformingModule(nn.Module): 868 | def __init__(self, 869 | embed_dim: int, 870 | M: int, 871 | hid_node: int, 872 | out_type: str, 873 | bf_type: str, 874 | rnn_type: str, 875 | ): 876 | super(BeamformingModule, self).__init__() 877 | self.embed_dim = embed_dim 878 | self.M = M 879 | self.hid_node = hid_node 880 | self.out_type = out_type 881 | self.bf_type = bf_type 882 | self.rnn_type = rnn_type 883 | assert out_type in ["mask", "mapping"] 884 | assert bf_type in ["embedding", "generalized", "mvdr"] 885 | if out_type == "mask": 886 | inpt_dim = 2*2*M*M 887 | elif out_type == "mapping": 888 | inpt_dim = embed_dim 889 | else: 890 | raise RuntimeError("only mask and mapping are supported") 891 | 892 | if bf_type in ["embedding", "generalized"]: 893 | self.norm = nn.LayerNorm([inpt_dim]) 894 | self.rnn = getattr(nn, rnn_type)(input_size=inpt_dim, hidden_size=hid_node, num_layers=2, batch_first=True) 895 | self.w_dnn = nn.Sequential( 896 | nn.Linear(hid_node, hid_node), 897 | nn.ReLU(True), 898 | nn.Linear(hid_node, 2*M) 899 | ) 900 | elif bf_type == "mvdr": 901 | self.norm1 = nn.LayerNorm([inpt_dim//2]) 902 | self.norm2 = nn.LayerNorm([inpt_dim//2]) 903 | self.rnn1 = getattr(nn, rnn_type)(input_size=inpt_dim//2, hidden_size=hid_node, num_layers=2, batch_first=True) 904 | self.rnn2 = getattr(nn, rnn_type)(input_size=inpt_dim//2, hidden_size=hid_node, num_layers=2, batch_first=True) 905 | 
self.pca_dnn = nn.Sequential(
906 |                 nn.Linear(hid_node, hid_node),
907 |                 nn.ReLU(True),
908 |                 nn.Linear(hid_node, 2*M)
909 |             )
910 |             self.inverse_dnn = nn.Sequential(
911 |                 nn.Linear(hid_node, hid_node),
912 |                 nn.ReLU(True),
913 |                 nn.Linear(hid_node, 2*M*M)
914 |             )
915 | 
916 |     def forward(self, inpt1, inpt2=None):
917 |         if self.out_type == "mask":
918 |             est_s, est_n = inpt1, inpt2
919 |             complex_s = ComplexTensor(est_s[...,0], est_s[...,-1])  # (B,T,F,M)
920 |             complex_n = ComplexTensor(est_n[...,0], est_n[...,-1])  # (B,T,F,M)
921 |             cov_s = F.einsum("...m,...n->...mn", [complex_s.conj(), complex_s])  # (B,T,F,M,M)
922 |             cov_n = F.einsum("...m,...n->...mn", [complex_n.conj(), complex_n])  # (B,T,F,M,M)
923 |             b_size, seq_len, freq_num, M, _ = cov_s.shape
924 |             cov_s, cov_n = cov_s.view(b_size, seq_len, freq_num, -1), cov_n.view(b_size, seq_len, freq_num, -1)
925 |             cov_ss = torch.cat((cov_s.real, cov_s.imag), dim=-1).permute(0,3,1,2)  # (B,2*M*M,T,F)
926 |             cov_nn = torch.cat((cov_n.real, cov_n.imag), dim=-1).permute(0,3,1,2)  # (B,2*M*M,T,F)
927 |         else:
928 |             embed_x = inpt1.permute(0,3,1,2)  # (B,-1,T,F)
929 |             b_size, _, seq_len, freq_num = embed_x.shape
930 | 
931 |         if self.bf_type == "mvdr":
932 |             cov_ss, cov_nn = self.norm1(cov_ss.permute(0,3,2,1).contiguous()), \
933 |                              self.norm2(cov_nn.permute(0,3,2,1).contiguous())
934 |             cov_ss, cov_nn = cov_ss.view(b_size*freq_num, seq_len, -1), \
935 |                              cov_nn.view(b_size*freq_num, seq_len, -1)
936 |             # steer vector
937 |             h1, _ = self.rnn1(cov_ss)
938 |             steer_vec = self.pca_dnn(h1)
939 |             steer_vec = steer_vec.view(b_size, freq_num, seq_len, self.M, 2).transpose(1,2)  # (B,T,F,M,2)
940 |             # inverse rnn
941 |             h2, _ = self.rnn2(cov_nn)
942 |             inverse_phi = self.inverse_dnn(h2)
943 |             inverse_phi = inverse_phi.view(b_size, freq_num, seq_len, self.M, self.M, 2).transpose(1, 2)  # (B,T,F,M,M,2)
944 |             # mvdr solution: w = Phi_nn^{-1} v / (v^H Phi_nn^{-1} v), with v the steering vector
945 |             complex_steer_vec = ComplexTensor(steer_vec[...,0], steer_vec[...,-1])  # (B,T,F,M)
946 |             complex_inverse_phi = ComplexTensor(inverse_phi[...,0], inverse_phi[...,-1])  # (B,T,F,M,M)
947 |             numer = F.einsum("...mn,...n->...m", [complex_inverse_phi, complex_steer_vec])  # (B,T,F,M)
948 |             denom = F.einsum("...m,...m->...", [complex_steer_vec.conj(), numer])  # (B,T,F)
949 |             bf_weight = numer / denom.unsqueeze(dim=-1)
950 |             bf_weight = torch.stack((bf_weight.real, bf_weight.imag), dim=-1)  # (B,T,F,M,2)
951 |         elif self.bf_type == "generalized":
952 |             x = self.norm(torch.cat((cov_ss, cov_nn), dim=1).permute(0,3,2,1).contiguous())
953 |             x = x.view(b_size*freq_num, seq_len, -1)
954 |             h, _ = self.rnn(x)
955 |             bf_weight = self.w_dnn(h)
956 |             bf_weight = bf_weight.view(b_size, freq_num, seq_len, self.M, 2).transpose(1, 2)
957 |         elif self.bf_type == "embedding":
958 |             x = self.norm(embed_x.permute(0,3,2,1).contiguous())
959 |             x = x.view(b_size*freq_num, seq_len, -1)
960 |             h, _ = self.rnn(x)
961 |             bf_weight = self.w_dnn(h)
962 |             bf_weight = bf_weight.view(b_size, freq_num, seq_len, self.M, 2).transpose(1, 2)
963 |         else:
964 |             raise Exception("only mvdr, generalized, and embedding are supported")
965 |         return bf_weight
966 | 
967 | 
968 | class Chomp_T(nn.Module):
969 |     def __init__(self,
970 |                  t: int):
971 |         super(Chomp_T, self).__init__()
972 |         self.t = t
973 | 
974 |     def forward(self, x):
975 |         return x[:, :, :-self.t, :]
976 | 
977 | 
978 | 
979 | if __name__ == "__main__":
980 |     net = TaylorBeamformer(
981 |         k1=[1,3],
982 |         k2=[2,3],
983 |         ref_mic=0,
984 |         c=64,
985 |         embed_dim=64,
986 |         fft_num=320,
987 |         order_num=4,
988 |         kd1=5,
989 |         cd1=64,
990 |         d_feat=256,
991 |
dilations=[1,2,5,9], 992 | group_num=2, 993 | hid_node=64, 994 | M=6, 995 | rnn_type="LSTM", 996 | intra_connect="cat", 997 | inter_connect="cat", 998 | out_type="mapping", 999 | bf_type="embedding", 1000 | norm2d_type="BN", 1001 | norm1d_type="BN", 1002 | is_compress=True, 1003 | is_total_separate=False, 1004 | is_u2=True, 1005 | is_1dgate=True, 1006 | is_squeezed=True, 1007 | is_causal=True, 1008 | is_param_share=False 1009 | ).cuda() 1010 | x = torch.rand([3,31,161,6,2]).cuda() 1011 | from utils.utils import numParams 1012 | from ptflops.flops_counter import get_model_complexity_info 1013 | print("The number of parameters:{}".format(numParams(net))) 1014 | get_model_complexity_info(net, (101, 161, 6, 2)) 1015 | y1, y2 = net(x) 1016 | print("{}->{},{}".format(x.shape, y1.shape, y2.shape)) 1017 | -------------------------------------------------------------------------------- /nets/__pycache__/TaylorBeamformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Andong-Li-speech/TaylorBeamformer/fc3d0e79a19a0feb2f6faf37c538219d7ea78433/nets/__pycache__/TaylorBeamformer.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/gcrn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Andong-Li-speech/TaylorBeamformer/fc3d0e79a19a0feb2f6faf37c538219d7ea78433/nets/__pycache__/gcrn.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/lstm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Andong-Li-speech/TaylorBeamformer/fc3d0e79a19a0feb2f6faf37c538219d7ea78433/nets/__pycache__/lstm.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/tcn_sa.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Andong-Li-speech/TaylorBeamformer/fc3d0e79a19a0feb2f6faf37c538219d7ea78433/nets/__pycache__/tcn_sa.cpython-37.pyc -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import time 6 | import importlib 7 | from utils.utils import logger_print 8 | import hdf5storage 9 | from utils.utils import cal_pesq 10 | from utils.utils import cal_stoi 11 | from utils.utils import cal_sisnr 12 | train_epoch, val_epoch, val_metric_epoch = [], [], [] # for loss, loss and metric score 13 | # from torch.utils.tensorboard import SummaryWriter 14 | 15 | 16 | class Solver(object): 17 | def __init__(self, 18 | data, 19 | net, 20 | optimizer, 21 | save_name_dict, 22 | args, 23 | ): 24 | self.train_dataloader = data["train_loader"] 25 | self.val_dataloader = data["val_loader"] 26 | self.net = net 27 | # optimizer part 28 | self.optimizer = optimizer 29 | self.lr = args["optimizer"]["lr"] 30 | self.gradient_norm = args["optimizer"]["gradient_norm"] 31 | self.epochs = args["optimizer"]["epochs"] 32 | self.halve_lr = args["optimizer"]["halve_lr"] 33 | self.early_stop = args["optimizer"]["early_stop"] 34 | self.halve_freq = args["optimizer"]["halve_freq"] 35 | self.early_stop_freq = args["optimizer"]["early_stop_freq"] 36 | self.print_freq = 
args["optimizer"]["print_freq"] 37 | self.metric_options = args["optimizer"]["metric_options"] 38 | # loss part 39 | self.loss_path = args["loss_function"]["path"] 40 | self.spectral_loss = args["loss_function"]["spectral"]["classname"] 41 | self.spatial_weight = args["loss_function"]["spatial_weight"] 42 | self.spectral_weight = args["loss_function"]["spectral_weight"] 43 | self.alpha = args["loss_function"]["alpha"] 44 | self.l_type = args["loss_function"]["l_type"] 45 | # signal part 46 | self.sr = args["signal"]["sr"] 47 | self.win_size = args["signal"]["win_size"] 48 | self.win_shift = args["signal"]["win_shift"] 49 | self.fft_num = args["signal"]["fft_num"] 50 | self.is_compress = args["signal"]["is_compress"] 51 | self.ref_mic = args["signal"]["ref_mic"] 52 | # path part 53 | self.is_checkpoint = args["path"]["is_checkpoint"] 54 | self.is_resume_reload = args["path"]["is_resume_reload"] 55 | self.checkpoint_load_path = args["path"]["checkpoint_load_path"] 56 | self.checkpoint_load_filename = args["path"]["checkpoint_load_filename"] 57 | self.loss_save_path = args["path"]["loss_save_path"] 58 | self.model_best_path = args["path"]["model_best_path"] 59 | # sava name 60 | self.loss_save_filename = save_name_dict["loss_filename"] 61 | self.best_model_save_filename = save_name_dict["best_model_filename"] 62 | self.checkpoint_save_filename = save_name_dict["checkpoint_filename"] 63 | 64 | self.train_loss = torch.Tensor(self.epochs) 65 | self.val_loss = torch.Tensor(self.epochs) 66 | # set loss funcs 67 | loss_module = importlib.import_module(self.loss_path) 68 | self.spectral_loss = getattr(loss_module, self.spectral_loss)(self.alpha, self.l_type) 69 | self._reset() 70 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 71 | # summarywriter 72 | # self.tensorboard_path = "./" + args["path"]["logging_path"] + "/" + args["save"]["tensorboard_filename"] 73 | # if not os.path.exists(self.tensorboard_path): 74 | # os.makedirs(self.tensorboard_path) 75 | #self.writer = SummaryWriter(self.tensorboard_path, max_queue=5, flush_secs=30) 76 | 77 | def _reset(self): 78 | # Reset 79 | if self.is_resume_reload: 80 | checkpoint = torch.load(os.path.join(self.checkpoint_load_path, self.checkpoint_load_filename)) 81 | self.net.load_state_dict(checkpoint["model_state_dict"]) 82 | self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 83 | self.start_epoch = checkpoint["start_epoch"] 84 | self.prev_val_loss = checkpoint["val_loss"] # val loss 85 | self.prev_val_metric = checkpoint["val_metric"] 86 | self.best_val_metric = checkpoint["best_val_metric"] 87 | self.val_no_impv = checkpoint["val_no_impv"] 88 | self.halving = checkpoint["halving"] 89 | else: 90 | self.start_epoch = 0 91 | self.prev_val_loss = float("inf") 92 | self.prev_val_metric = -float("inf") 93 | self.best_val_metric = -float("inf") 94 | self.val_no_impv = 0 95 | self.halving = False 96 | 97 | def train(self): 98 | logger_print("Begin to train....") 99 | self.net.to(self.device) 100 | for epoch in range(self.start_epoch, self.epochs): 101 | begin_time = time.time() 102 | # training phase 103 | logger_print("-" * 90) 104 | start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 105 | logger_print(f"Epoch id:{int(epoch + 1)}, Training phase, Start time:{start_time}") 106 | self.net.train() 107 | train_avg_loss = self._run_one_epoch(epoch, val_opt=False) 108 | # self.writer.add_scalar(f"Loss/Training_Loss", train_avg_loss, epoch) 109 | end_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 
110 | logger_print(f"Epoch if:{int(epoch + 1)}, Training phase, End time:{end_time}, " 111 | f"Training loss:{train_avg_loss}") 112 | 113 | # Cross val 114 | start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 115 | logger_print(f"Epoch id:{int(epoch + 1)}, Validation phase, Start time:{start_time}") 116 | self.net.eval() # norm and dropout is off 117 | val_avg_loss, val_avg_metric = self._run_one_epoch(epoch, val_opt=True) 118 | # self.writer.add_scalar(f"Loss/Validation_Loss", val_avg_loss, epoch) 119 | # self.writer.add_scalar(f"Loss/Validation_Metric", val_avg_metric, epoch) 120 | end_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 121 | logger_print(f"Epoch if:{int(epoch + 1)}, Validation phase, End time:{end_time}, " 122 | f"Validation loss:{val_avg_loss}, Validation metric score:{val_avg_metric}") 123 | end_time = time.time() 124 | print(f"{end_time-begin_time}s in {epoch+1}th epoch") 125 | logger_print("-" * 90) 126 | 127 | # whether to save checkpoint at current epoch 128 | if self.is_checkpoint: 129 | cpk_dic = {} 130 | cpk_dic["model_state_dict"] = self.net.state_dict() 131 | cpk_dic["optimizer_state_dict"] = self.optimizer.state_dict() 132 | cpk_dic["train_loss"] = train_avg_loss 133 | cpk_dic["val_loss"] = val_avg_loss 134 | cpk_dic["val_metric"] = val_avg_metric 135 | cpk_dic["best_val_metric"] = self.best_val_metric 136 | cpk_dic["start_epoch"] = epoch+1 137 | cpk_dic["val_no_impv"] = self.val_no_impv 138 | cpk_dic["halving"] = self.halving 139 | torch.save(cpk_dic, os.path.join(self.checkpoint_load_path, "Epoch_{}_{}_{}".format(epoch+1, 140 | self.net.__class__.__name__, self.checkpoint_save_filename))) 141 | # record loss 142 | # self.train_loss[epoch] = train_avg_loss 143 | # self.val_loss[epoch] = val_avg_loss 144 | 145 | train_epoch.append(train_avg_loss) 146 | val_epoch.append(val_avg_loss) 147 | val_metric_epoch.append(val_avg_metric) 148 | 149 | # save loss 150 | loss = {} 151 | loss["train_loss"] = train_epoch 152 | loss["val_loss"] = val_epoch 153 | loss["val_metric"] = val_metric_epoch 154 | 155 | if not self.is_resume_reload: 156 | hdf5storage.savemat(os.path.join(self.loss_save_path, self.loss_save_filename), loss) 157 | else: 158 | hdf5storage.savemat(os.path.join(self.loss_save_path, "resume_cpk_{}".format(self.loss_save_filename)), 159 | loss) 160 | 161 | # lr halve and Early stop 162 | if self.halve_lr: 163 | if val_avg_metric <= self.prev_val_metric: 164 | self.val_no_impv += 1 165 | if self.val_no_impv == self.halve_freq: 166 | self.halving = True 167 | if (self.val_no_impv >= self.early_stop_freq) and self.early_stop: 168 | logger_print("No improvements and apply early-stopping") 169 | break 170 | else: 171 | self.val_no_impv = 0 172 | 173 | if self.halving: 174 | optim_state = self.optimizer.state_dict() 175 | optim_state["param_groups"][0]["lr"] = optim_state["param_groups"][0]["lr"] / 2.0 176 | self.optimizer.load_state_dict(optim_state) 177 | logger_print("Learning rate is adjusted to %5f" % (optim_state["param_groups"][0]["lr"])) 178 | self.halving = False 179 | self.prev_val_metric = val_avg_metric 180 | 181 | if val_avg_metric > self.best_val_metric: 182 | self.best_val_metric = val_avg_metric 183 | torch.save(self.net.state_dict(), os.path.join(self.model_best_path, self.best_model_save_filename)) 184 | logger_print(f"Find better model, saving to {self.best_model_save_filename}") 185 | else: 186 | logger_print("Did not find better model") 187 | 188 | 189 | @torch.no_grad() 190 | def _val_batch(self, batch_info): 191 | 
batch_mix_wav = batch_info.feats.to(self.device) # (B,L,M) 192 | batch_target_wav = batch_info.labels[...,self.ref_mic].to(self.device) # (B,L) 193 | batch_wav_len_list = batch_info.frame_mask_list 194 | 195 | # stft 196 | b_size, wav_len, channel_num = batch_mix_wav.shape 197 | batch_mix_wav = batch_mix_wav.transpose(-2, -1).contiguous().view(b_size*channel_num, wav_len) 198 | win_size, win_shift = int(self.sr*self.win_size), int(self.sr*self.win_shift) 199 | batch_mix_stft = torch.stft( 200 | batch_mix_wav, 201 | n_fft=self.fft_num, 202 | hop_length=win_shift, 203 | win_length=win_size, 204 | window=torch.hann_window(win_size).to(self.device)) # (BM,F,T,2) 205 | batch_target_stft = torch.stft( 206 | batch_target_wav, 207 | n_fft=self.fft_num, 208 | hop_length=win_shift, 209 | win_length=win_size, 210 | window=torch.hann_window(win_size).to(self.device)) # (B,F,T,2) 211 | batch_frame_list = [] 212 | for i in range(len(batch_wav_len_list)): 213 | curr_frame_num = (batch_wav_len_list[i]-win_size+win_size)//win_shift+1 # center case 214 | batch_frame_list.append(curr_frame_num) 215 | 216 | _, freq_num, seq_len, _ = batch_mix_stft.shape 217 | batch_mix_stft = batch_mix_stft.view(b_size, -1, freq_num, seq_len, 2) 218 | 219 | if self.is_compress: # here only apply to target and bf as feat-compression has been applied within the network 220 | # target 221 | batch_target_mag, batch_target_phase = torch.norm(batch_target_stft, dim=-1)**0.5, torch.atan2( 222 | batch_target_stft[..., -1], batch_target_stft[..., 0]) 223 | batch_target_stft = torch.stack((batch_target_mag * torch.cos(batch_target_phase), 224 | batch_target_mag * torch.sin(batch_target_phase)), dim=-1) 225 | 226 | # convert to formats: (B,T,F,M,2) for mix, (B,T,F,2) for target and bf 227 | batch_mix_stft = batch_mix_stft.permute(0, 3, 2, 1, 4) 228 | batch_target_stft = batch_target_stft.transpose(1, 2) 229 | # net predict 230 | _, batch_spec_est = self.net(batch_mix_stft) # (B,T,F,2), (B,T,F,2) 231 | 232 | # cal mse loss 233 | batch_mse_loss = self.spectral_loss(batch_spec_est, batch_target_stft, batch_frame_list) 234 | # cal metric loss 235 | batch_ref_mix_stft = batch_mix_stft[...,self.ref_mic,:].transpose(1,2) # (B,F,T,2) 236 | batch_spec_est = batch_spec_est.transpose(1,2) # (B,F,T,2) 237 | batch_target_stft = batch_target_stft.transpose(1,2) # (B,F,T,2) 238 | if self.is_compress: 239 | batch_spec_mag, batch_spec_phase = torch.norm(batch_spec_est, dim=-1)**2.0,\ 240 | torch.atan2(batch_spec_est[...,-1], batch_spec_est[...,0]) 241 | batch_spec_est = torch.stack((batch_spec_mag*torch.cos(batch_spec_phase), 242 | batch_spec_mag*torch.sin(batch_spec_phase)), dim=-1) 243 | batch_target_mag, batch_target_phase = torch.norm(batch_target_stft, dim=-1) ** 2.0, \ 244 | torch.atan2(batch_target_stft[...,-1], batch_target_stft[...,0]) 245 | batch_target_stft = torch.stack((batch_target_mag * torch.cos(batch_target_phase), 246 | batch_target_mag * torch.sin(batch_target_phase)), dim=-1) 247 | batch_mix_wav = torch.istft(batch_ref_mix_stft, 248 | n_fft=self.fft_num, 249 | hop_length=win_shift, 250 | win_length=win_size, 251 | window=torch.hann_window(win_size).to(self.device)).cpu().numpy() # (B,L) 252 | batch_est_wav = torch.istft(batch_spec_est, 253 | n_fft=self.fft_num, 254 | hop_length=win_shift, 255 | win_length=win_size, 256 | window=torch.hann_window(win_size).to(self.device)).cpu().numpy() # (B,L) 257 | batch_target_wav = torch.istft(batch_target_stft, 258 | n_fft=self.fft_num, 259 | hop_length=win_shift, 260 | win_length=win_size, 261 | 
window=torch.hann_window(win_size).to(self.device)).cpu().numpy()  # (B,L)
262 | 
263 |         loss_dict = {}
264 |         loss_dict["mse_loss"] = batch_mse_loss.item()
265 |         # create mask
266 |         mask_list = []
267 |         for id in range(b_size):
268 |             mask_list.append(torch.ones((batch_wav_len_list[id])))
269 |         wav_mask = torch.nn.utils.rnn.pad_sequence(mask_list, batch_first=True).numpy()[:, :batch_est_wav.shape[-1]]  # (B,L); kept in numpy and trimmed to the istft length, since the istft outputs above are already numpy arrays
270 |         batch_mix_wav, batch_target_wav, batch_est_wav = batch_mix_wav * wav_mask, \
271 |                                                           batch_target_wav * wav_mask, \
272 |                                                           batch_est_wav * wav_mask
273 |         if "SISNR" in self.metric_options:
274 |             unpro_score_list, pro_score_list = [], []
275 |             for id in range(batch_mix_wav.shape[0]):
276 |                 unpro_score_list.append(cal_sisnr(id, batch_mix_wav, batch_target_wav, self.sr))
277 |                 pro_score_list.append(cal_sisnr(id, batch_est_wav, batch_target_wav, self.sr))
278 |             unpro_score_list, pro_score_list = np.asarray(unpro_score_list), np.asarray(pro_score_list)
279 |             unpro_sisnr_mean_score, pro_sisnr_mean_score = np.mean(unpro_score_list), np.mean(pro_score_list)
280 |             loss_dict["unpro_metric"] = unpro_sisnr_mean_score
281 |             loss_dict["pro_metric"] = pro_sisnr_mean_score
282 |         if "NB-PESQ" in self.metric_options:
283 |             unpro_score_list, pro_score_list = [], []
284 |             for id in range(batch_mix_wav.shape[0]):
285 |                 unpro_score_list.append(cal_pesq(id, batch_mix_wav, batch_target_wav, self.sr))
286 |                 pro_score_list.append(cal_pesq(id, batch_est_wav, batch_target_wav, self.sr))
287 |             unpro_score_list, pro_score_list = np.asarray(unpro_score_list), \
288 |                                                np.asarray(pro_score_list)
289 |             unpro_pesq_mean_score, pro_pesq_mean_score = np.mean(unpro_score_list), np.mean(pro_score_list)
290 |             loss_dict["unpro_metric"] = unpro_pesq_mean_score
291 |             loss_dict["pro_metric"] = pro_pesq_mean_score
292 |         if "ESTOI" in self.metric_options:
293 |             unpro_score_list, pro_score_list = [], []
294 |             for id in range(batch_mix_wav.shape[0]):
295 |                 unpro_score_list.append(cal_stoi(id, batch_mix_wav, batch_target_wav, self.sr))
296 |                 pro_score_list.append(cal_stoi(id, batch_est_wav, batch_target_wav, self.sr))  # score the enhanced output, not the mixture
297 |             unpro_score_list, pro_score_list = np.asarray(unpro_score_list), \
298 |                                                np.asarray(pro_score_list)
299 |             unpro_estoi_mean_score, pro_estoi_mean_score = np.mean(unpro_score_list), np.mean(pro_score_list)
300 |             loss_dict["unpro_metric"] = unpro_estoi_mean_score
301 |             loss_dict["pro_metric"] = pro_estoi_mean_score
302 |         return loss_dict
303 | 
304 | 
305 |     def _train_batch(self, batch_info):
306 |         batch_mix_wav = batch_info.feats.to(self.device)  # (B,L,M)
307 |         batch_bf_wav = batch_info.bfs.to(self.device)  # (B,L)
308 |         batch_target_wav = batch_info.labels[..., self.ref_mic].to(self.device)  # (B,L), only ref-mic is selected
309 |         batch_wav_len_list = batch_info.frame_mask_list
310 | 
311 |         # stft
312 |         b_size, wav_len, channel_num = batch_mix_wav.shape
313 |         batch_mix_wav = batch_mix_wav.transpose(-2, -1).contiguous().view(b_size * channel_num, wav_len)
314 |         win_size, win_shift = int(self.sr * self.win_size), int(self.sr * self.win_shift)
315 |         batch_mix_stft = torch.stft(
316 |             batch_mix_wav,
317 |             n_fft=self.fft_num,
318 |             hop_length=win_shift,
319 |             win_length=win_size,
320 |             window=torch.hann_window(win_size).to(batch_mix_wav.device))  # (BM,F,T,2)
321 |         batch_bf_stft = torch.stft(
322 |             batch_bf_wav,
323 |             n_fft=self.fft_num,
324 |             hop_length=win_shift,
325 |             win_length=win_size,
326 |             window=torch.hann_window(win_size).to(batch_mix_wav.device))  # (B,F,T,2)
327 |         batch_target_stft = torch.stft(
328 |             batch_target_wav,
329 |             n_fft=self.fft_num,
330 |             hop_length=win_shift,
331 |             win_length=win_size,
332 |             window=torch.hann_window(win_size).to(batch_mix_wav.device))  # (B,F,T,2)
333 |         batch_frame_list = []
334 |         for i in range(len(batch_wav_len_list)):
335 |             curr_frame_num = (batch_wav_len_list[i] - win_size + win_size) // win_shift + 1  # center case, i.e. len // win_shift + 1
336 |             batch_frame_list.append(curr_frame_num)
337 | 
338 |         _, freq_num, seq_len, _ = batch_mix_stft.shape
339 |         batch_mix_stft = batch_mix_stft.view(b_size, -1, freq_num, seq_len, 2)
340 |         if self.is_compress:  # only the target and bf are compressed here, as feat-compression of the mixture is applied within the network
341 |             # target
342 |             batch_target_mag, batch_target_phase = torch.norm(batch_target_stft, dim=-1)**0.5, torch.atan2(
343 |                 batch_target_stft[..., -1], batch_target_stft[..., 0])
344 |             batch_target_stft = torch.stack((batch_target_mag * torch.cos(batch_target_phase),
345 |                                              batch_target_mag * torch.sin(batch_target_phase)), dim=-1)
346 |             # bf
347 |             batch_bf_mag, batch_bf_phase = torch.norm(batch_bf_stft, dim=-1)**0.5, torch.atan2(
348 |                 batch_bf_stft[..., -1], batch_bf_stft[..., 0])
349 |             batch_bf_stft = torch.stack((batch_bf_mag * torch.cos(batch_bf_phase),
350 |                                          batch_bf_mag * torch.sin(batch_bf_phase)), dim=-1)
351 | 
352 |         # convert to formats: (B,T,F,M,2) for mix, (B,T,F,2) for target and bf
353 |         batch_mix_stft = batch_mix_stft.permute(0, 3, 2, 1, 4)
354 |         batch_bf_stft = batch_bf_stft.transpose(1, 2)
355 |         batch_target_stft = batch_target_stft.transpose(1, 2)
356 | 
357 |         with torch.enable_grad():
358 |             batch_bf_est, batch_spec_est = self.net(batch_mix_stft)  # (B,T,F,2), (B,T,F,2)
359 | 
360 |         # beamforming loss
361 |         batch_spatial_loss = self.spectral_loss(batch_bf_est, batch_bf_stft, batch_frame_list)
362 |         # reconstruction loss
363 |         batch_spectral_loss = self.spectral_loss(batch_spec_est, batch_target_stft, batch_frame_list)
364 |         batch_loss = self.spatial_weight*batch_spatial_loss + self.spectral_weight*batch_spectral_loss
365 |         # params update
366 |         self.update_params(batch_loss)
367 |         loss_dict = {}
368 |         loss_dict["spatial_bf_loss"] = batch_spatial_loss.item()
369 |         loss_dict["spectral_loss"] = batch_spectral_loss.item()
370 |         return loss_dict
371 | 
372 | 
373 |     def _run_one_epoch(self, epoch, val_opt=False):
374 |         # training phase
375 |         if not val_opt:
376 |             data_loader = self.train_dataloader
377 |             total_bf_loss, total_sp_loss = 0., 0.
378 |             start_time = time.time()
379 |             for batch_id, batch_info in enumerate(data_loader.get_data_loader()):
380 |                 loss_dict = self._train_batch(batch_info)
381 |                 total_bf_loss += loss_dict["spatial_bf_loss"]
382 |                 total_sp_loss += loss_dict["spectral_loss"]
383 |                 if batch_id % self.print_freq == 0:
384 |                     logger_print(
385 |                         "Epoch:{:d}, Iter:{:d}, Average bf loss:{:.4f}, Average spectral loss:{:.4f}, Time: {:d}ms/batch".
386 |                         format(epoch+1, int(batch_id), total_bf_loss/(batch_id+1), total_sp_loss/(batch_id+1),
387 |                                int(1000*(time.time()-start_time)/(batch_id+1))))
388 |             return total_sp_loss / (batch_id+1)
389 |         else:  # validation phase
390 |             data_loader = self.val_dataloader
391 |             total_sp_loss, total_pro_metric_loss, total_unpro_metric_loss = 0., 0., 0.
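            # Editor's note (added comment, not in the original source): in the
            # validation loop below, "unpro_metric" scores the raw reference-mic
            # mixture against the target while "pro_metric" scores the enhanced
            # output, so the gap between the two running averages reflects the
            # improvement brought by the network.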
392 | start_time = time.time() 393 | for batch_id, batch_info in enumerate(data_loader.get_data_loader()): 394 | loss_dict = self._val_batch(batch_info) 395 | assert len(self.metric_options) == 1, "only one metric is supported in the val phase" 396 | total_sp_loss += loss_dict["mse_loss"] 397 | total_unpro_metric_loss += loss_dict["unpro_metric"] 398 | total_pro_metric_loss += loss_dict["pro_metric"] 399 | if batch_id % self.print_freq == 0: 400 | logger_print( 401 | "Epoch:{:d}, Iter:{:d}, Average spectral loss:{:.4f}, Average unpro metric score:{:.4f}, " 402 | "Average pro metric score:{:.4f}, Time: {:d}ms/batch". 403 | format(epoch+1, int(batch_id), total_sp_loss/(batch_id+1), total_unpro_metric_loss/(batch_id+1), 404 | total_pro_metric_loss/(batch_id+1), int(1000*(time.time()-start_time)/(batch_id+1)))) 405 | return total_sp_loss / (batch_id+1), total_pro_metric_loss / (batch_id+1) 406 | 407 | def update_params(self, loss): 408 | self.optimizer.zero_grad() 409 | loss.backward() 410 | if self.gradient_norm >= 0.0: 411 | nn.utils.clip_grad_norm_(self.net.parameters(), self.gradient_norm) 412 | has_nan_inf = 0 413 | for params in self.net.parameters(): 414 | if params.requires_grad and params.grad is not None: 415 | has_nan_inf += torch.sum(torch.isnan(params.grad)) 416 | has_nan_inf += torch.sum(torch.isinf(params.grad)) 417 | if has_nan_inf == 0: # skip the update step if any gradient is NaN/Inf 418 | self.optimizer.step() 419 | -------------------------------------------------------------------------------- /torch_complex/__init__.py: -------------------------------------------------------------------------------- 1 | from . import functional # noqa: F401 2 | from . import tensor # noqa: F401 3 | 4 | from .functional import * # noqa: F401, F403 5 | from .tensor import * # noqa: F401, F403 6 | -------------------------------------------------------------------------------- /torch_complex/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Andong-Li-speech/TaylorBeamformer/fc3d0e79a19a0feb2f6faf37c538219d7ea78433/torch_complex/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /torch_complex/__pycache__/functional.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Andong-Li-speech/TaylorBeamformer/fc3d0e79a19a0feb2f6faf37c538219d7ea78433/torch_complex/__pycache__/functional.cpython-37.pyc -------------------------------------------------------------------------------- /torch_complex/__pycache__/tensor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Andong-Li-speech/TaylorBeamformer/fc3d0e79a19a0feb2f6faf37c538219d7ea78433/torch_complex/__pycache__/tensor.cpython-37.pyc -------------------------------------------------------------------------------- /torch_complex/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Andong-Li-speech/TaylorBeamformer/fc3d0e79a19a0feb2f6faf37c538219d7ea78433/torch_complex/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /torch_complex/functional.py: -------------------------------------------------------------------------------- 1 | from distutils.version import LooseVersion 2 | import functools 3 | from typing import Sequence 4 | from typing
import Union 5 | 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | from torch_complex.tensor import ComplexTensor 10 | from torch_complex.utils import complex_matrix2real_matrix 11 | from torch_complex.utils import complex_vector2real_vector 12 | from torch_complex.utils import real_matrix2complex_matrix 13 | from torch_complex.utils import real_vector2complex_vector 14 | 15 | 16 | __all__ = [ 17 | "einsum", 18 | "cat", 19 | "stack", 20 | "pad", 21 | "squeeze", 22 | "reverse", 23 | "trace", 24 | "allclose", 25 | "matmul", 26 | "solve", 27 | ] 28 | 29 | 30 | def _fcomplex(func, nthargs=0): 31 | @functools.wraps(func) 32 | def wrapper(*args, **kwargs) -> Union[ComplexTensor, torch.Tensor]: 33 | signal = args[nthargs] 34 | if isinstance(signal, ComplexTensor): 35 | real_args = args[:nthargs] + (signal.real,) + args[nthargs + 1 :] 36 | imag_args = args[:nthargs] + (signal.imag,) + args[nthargs + 1 :] 37 | real = func(*real_args, **kwargs) 38 | imag = func(*imag_args, **kwargs) 39 | return ComplexTensor(real, imag) 40 | else: 41 | return func(*args, **kwargs) 42 | 43 | return wrapper 44 | 45 | 46 | def einsum(equation, *operands): 47 | """Einsum 48 | 49 | >>> import numpy 50 | >>> def get(*shape): 51 | ... real = numpy.random.rand(*shape) 52 | ... imag = numpy.random.rand(*shape) 53 | ... return real + 1j * imag 54 | >>> x = get(3, 4, 5) 55 | >>> y = get(3, 5, 6) 56 | >>> z = get(3, 6, 7) 57 | >>> test = einsum('aij,ajk,akl->ail', 58 | ... [ComplexTensor(x), ComplexTensor(y), ComplexTensor(z)]) 59 | >>> valid = numpy.einsum('aij,ajk,akl->ail', x, y, z) 60 | >>> numpy.testing.assert_allclose(test.numpy(), valid) 61 | >>> _ = einsum('aij->ai', ComplexTensor(x)) 62 | >>> _ = einsum('aij->ai', [ComplexTensor(x)]) 63 | 64 | """ 65 | if len(operands) == 1 and isinstance(operands[0], (tuple, list)): 66 | operands = operands[0] 67 | 68 | x = operands[0] 69 | if isinstance(x, ComplexTensor): 70 | real_operands = [[x.real]] 71 | imag_operands = [[x.imag]] 72 | else: 73 | real_operands = [[x]] 74 | imag_operands = [] 75 | 76 | for x in operands[1:]: 77 | if isinstance(x, ComplexTensor): 78 | real_operands, imag_operands = ( 79 | [ops + [x.real] for ops in real_operands] 80 | + [ops + [-x.imag] for ops in imag_operands], 81 | [ops + [x.imag] for ops in real_operands] 82 | + [ops + [x.real] for ops in imag_operands], 83 | ) 84 | else: 85 | real_operands = [ops + [x] for ops in real_operands] 86 | imag_operands = [ops + [x] for ops in imag_operands] 87 | 88 | real = sum([torch.einsum(equation, ops) for ops in real_operands]) 89 | imag = sum([torch.einsum(equation, ops) for ops in imag_operands]) 90 | return ComplexTensor(real, imag) 91 | 92 | 93 | def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs): 94 | """ 95 | cat(seq, dim=0, *, out=None) 96 | cat(seq, axis=0, *, out=None) 97 | """ 98 | reals = [v.real if isinstance(v, ComplexTensor) else v for v in seq] 99 | imags = [ 100 | v.imag if isinstance(v, ComplexTensor) else torch.zeros_like(v.real) 101 | for v in seq 102 | ] 103 | out = kwargs.pop("out", None) 104 | if out is not None: 105 | # the given out buffer is a ComplexTensor; split it into real/imag parts 106 | out_real = out.real 107 | out_imag = out.imag 108 | else: 109 | out_real = out_imag = None 110 | return ComplexTensor( 111 | torch.cat(reals, *args, out=out_real, **kwargs), 112 | torch.cat(imags, *args, out=out_imag, **kwargs), 113 | ) 114 | 115 | 116 | def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs): 117 | """ 118 | stack(tensors, dim=0, *, out=None) 119 | stack(tensors, axis=0, *,
out=None) 120 | 121 | """ 122 | reals = [v.real if isinstance(v, ComplexTensor) else v for v in seq] 123 | imags = [ 124 | v.imag if isinstance(v, ComplexTensor) else torch.zeros_like(v.real) 125 | for v in seq 126 | ] 127 | 128 | out = kwargs.pop("out", None) 129 | if out is not None: 130 | out_real = out.real 131 | out_imag = out.imag 132 | else: 133 | out_real = out_imag = None 134 | return ComplexTensor( 135 | torch.stack(reals, *args, out=out_real, **kwargs), 136 | torch.stack(imags, *args, out=out_imag, **kwargs), 137 | ) 138 | 139 | 140 | pad = _fcomplex(F.pad) 141 | squeeze = _fcomplex(torch.squeeze) 142 | 143 | 144 | @_fcomplex 145 | def reverse(tensor: torch.Tensor, dim=0) -> torch.Tensor: 146 | # https://discuss.pytorch.org/t/how-to-reverse-a-torch-tensor/382 147 | idx = [i for i in range(tensor.size(dim) - 1, -1, -1)] 148 | idx = torch.LongTensor(idx).to(tensor.device) 149 | inverted_tensor = tensor.index_select(dim, idx) 150 | return inverted_tensor 151 | 152 | 153 | @_fcomplex 154 | def signal_frame( 155 | signal: torch.Tensor, frame_length: int, frame_step: int, pad_value=0 156 | ) -> torch.Tensor: 157 | """Expands signal into frames of frame_length. 158 | 159 | Args: 160 | signal : (B * F, D, T) 161 | Returns: 162 | torch.Tensor: (B * F, D, T, W) 163 | """ 164 | signal = F.pad(signal, (0, frame_length - 1), "constant", pad_value) 165 | indices = sum( 166 | [ 167 | list(range(i, i + frame_length)) 168 | for i in range(0, signal.size(-1) - frame_length + 1, frame_step) 169 | ], 170 | [], 171 | ) 172 | 173 | signal = signal[..., indices].view(*signal.size()[:-1], -1, frame_length) 174 | return signal 175 | 176 | 177 | def trace(a: ComplexTensor) -> ComplexTensor: 178 | if LooseVersion(torch.__version__) >= LooseVersion("1.3"): 179 | datatype = torch.bool 180 | else: 181 | datatype = torch.uint8 182 | E = torch.eye(a.shape[-1], dtype=datatype).expand(*a.size()) 183 | if LooseVersion(torch.__version__) >= LooseVersion("1.1"): 184 | E = E.type(torch.bool) 185 | return a[E].view(*a.size()[:-1]).sum(-1) 186 | 187 | 188 | def allclose( 189 | a: Union[ComplexTensor, torch.Tensor], 190 | b: Union[ComplexTensor, torch.Tensor], 191 | rtol=1e-05, 192 | atol=1e-08, 193 | equal_nan=False, 194 | ) -> bool: 195 | if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor): 196 | return torch.allclose( 197 | a.real, b.real, rtol=rtol, atol=atol, equal_nan=equal_nan 198 | ) and torch.allclose(a.imag, b.imag, rtol=rtol, atol=atol, equal_nan=equal_nan) 199 | elif not isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor): 200 | return torch.allclose( 201 | a.real, b.real, rtol=rtol, atol=atol, equal_nan=equal_nan 202 | ) and torch.allclose( 203 | torch.zeros_like(b.imag), b.imag, rtol=rtol, atol=atol, equal_nan=equal_nan 204 | ) 205 | elif isinstance(a, ComplexTensor) and not isinstance(b, ComplexTensor): 206 | return torch.allclose( 207 | a.real, b, rtol=rtol, atol=atol, equal_nan=equal_nan 208 | ) and torch.allclose( 209 | a.imag, torch.zeros_like(a.imag), rtol=rtol, atol=atol, equal_nan=equal_nan 210 | ) 211 | else: 212 | return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) 213 | 214 | 215 | def matmul( 216 | a: Union[ComplexTensor, torch.Tensor], b: Union[ComplexTensor, torch.Tensor] 217 | ) -> ComplexTensor: 218 | if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor): 219 | return a @ b 220 | elif not isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor): 221 | o_real = torch.matmul(a, b.real) 222 | o_imag = torch.matmul(a, b.imag) 223 | elif 
isinstance(a, ComplexTensor) and not isinstance(b, ComplexTensor): 224 | return a @ b 225 | else: 226 | o_real = torch.matmul(a.real, b.real) 227 | o_imag = torch.zeros_like(o_real) 228 | return ComplexTensor(o_real, o_imag) 229 | 230 | 231 | def solve(b: ComplexTensor, a: ComplexTensor) -> ComplexTensor: 232 | """Solve ax = b""" 233 | a = complex_matrix2real_matrix(a) 234 | b = complex_vector2real_vector(b) 235 | x, LU = torch.solve(b, a) 236 | return real_vector2complex_vector(x), real_matrix2complex_matrix(LU) 237 | -------------------------------------------------------------------------------- /torch_complex/tensor.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from typing import Union, List 3 | 4 | import numpy 5 | import torch 6 | EPSILON = torch.finfo(torch.float32).eps 7 | 8 | __all__ = ["ComplexTensor"] 9 | 10 | 11 | class ComplexTensor: 12 | def __init__( 13 | self, real: Union[torch.Tensor, numpy.ndarray], imag=None, device=None 14 | ): 15 | if imag is None: 16 | if isinstance(real, numpy.ndarray): 17 | if real.dtype.kind == "c": 18 | imag = real.imag 19 | real = real.real 20 | else: 21 | imag = numpy.zeros_like(real) 22 | elif isinstance(real, ComplexTensor): 23 | imag = real.imag 24 | real = real.real 25 | else: 26 | imag = torch.zeros_like(real, device=device) 27 | 28 | if isinstance(real, numpy.ndarray): 29 | real = torch.as_tensor(real, device=device) 30 | else: 31 | real = real.to(device) 32 | if isinstance(imag, numpy.ndarray): 33 | imag = torch.as_tensor(imag, device=device) 34 | else: 35 | imag = imag.to(device) 36 | 37 | if not torch.is_tensor(real): 38 | raise TypeError( 39 | f"The first arg must be torch.Tensor " f"but got {type(real)}" 40 | ) 41 | 42 | if not torch.is_tensor(imag): 43 | raise TypeError( 44 | f"The second arg must be torch.Tensor " f"but got {type(imag)}" 45 | ) 46 | if not real.size() == imag.size(): 47 | raise ValueError( 48 | f"The two inputs must have same sizes: " 49 | f"{real.size()} != {imag.size()}" 50 | ) 51 | 52 | self.real = real 53 | self.imag = imag 54 | 55 | def __getitem__(self, item) -> "ComplexTensor": 56 | return ComplexTensor(self.real[item], self.imag[item]) 57 | 58 | def __setitem__( 59 | self, item, value: Union["ComplexTensor", torch.Tensor, numbers.Number] 60 | ): 61 | if isinstance(value, (ComplexTensor, complex)): 62 | self.real[item] = value.real 63 | self.imag[item] = value.imag 64 | else: 65 | self.real[item] = value 66 | self.imag[item] = 0 67 | 68 | def __mul__( 69 | self, other: Union["ComplexTensor", torch.Tensor, numbers.Number] 70 | ) -> "ComplexTensor": 71 | if isinstance(other, (ComplexTensor, complex)): 72 | return ComplexTensor( 73 | self.real * other.real - self.imag * other.imag, 74 | self.real * other.imag + self.imag * other.real, 75 | ) 76 | else: 77 | return ComplexTensor(self.real * other, self.imag * other) 78 | 79 | def __rmul__( 80 | self, other: Union["ComplexTensor", torch.Tensor, numbers.Number] 81 | ) -> "ComplexTensor": 82 | if isinstance(other, (ComplexTensor, complex)): 83 | return ComplexTensor( 84 | other.real * self.real - other.imag * self.imag, 85 | other.imag * self.real + other.real * self.imag, 86 | ) 87 | else: 88 | return ComplexTensor(other * self.real, other * self.imag) 89 | 90 | def __imul__(self, other): 91 | if isinstance(other, (ComplexTensor, numbers.Complex)): 92 | t = self * other 93 | self.real = t.real 94 | self.imag = t.imag 95 | else: 96 | self.real *= other 97 | self.imag *= other 98 | return self 99 | 100 | def
__truediv__(self, other) -> "ComplexTensor": 101 | if isinstance(other, (ComplexTensor, complex)): 102 | den = other.real ** 2 + other.imag ** 2 + EPSILON 103 | return ComplexTensor( 104 | (self.real * other.real + self.imag * other.imag) / den, 105 | (-self.real * other.imag + self.imag * other.real) / den, 106 | ) 107 | else: 108 | return ComplexTensor(self.real / other, self.imag / other) 109 | 110 | def __rtruediv__(self, other) -> "ComplexTensor": 111 | if isinstance(other, (ComplexTensor, complex)): 112 | den = self.real ** 2 + self.imag ** 2 + EPSILON # match the guard in __truediv__ 113 | return ComplexTensor( 114 | (other.real * self.real + other.imag * self.imag) / den, 115 | (-other.real * self.imag + other.imag * self.real) / den, 116 | ) 117 | else: 118 | den = self.real ** 2 + self.imag ** 2 + EPSILON 119 | return ComplexTensor(other * self.real / den, -other * self.imag / den) 120 | 121 | def __itruediv__(self, other) -> "ComplexTensor": 122 | if isinstance(other, (ComplexTensor, numbers.Complex)): 123 | t = self / other 124 | self.real = t.real 125 | self.imag = t.imag 126 | else: 127 | self.real /= other 128 | self.imag /= other 129 | return self 130 | 131 | def __add__(self, other) -> "ComplexTensor": 132 | if isinstance(other, (ComplexTensor, complex)): 133 | return ComplexTensor(self.real + other.real, self.imag + other.imag) 134 | else: 135 | return ComplexTensor(self.real + other, self.imag) 136 | 137 | def __radd__(self, other) -> "ComplexTensor": 138 | if isinstance(other, (ComplexTensor, complex)): 139 | return ComplexTensor(other.real + self.real, other.imag + self.imag) 140 | else: 141 | return ComplexTensor(other + self.real, self.imag) 142 | 143 | def __iadd__(self, other) -> "ComplexTensor": 144 | if isinstance(other, (ComplexTensor, complex)): 145 | self.real += other.real 146 | self.imag += other.imag 147 | else: 148 | self.real += other 149 | return self 150 | 151 | def __sub__(self, other) -> "ComplexTensor": 152 | if isinstance(other, (ComplexTensor, complex)): 153 | return ComplexTensor(self.real - other.real, self.imag - other.imag) 154 | else: 155 | return ComplexTensor(self.real - other, self.imag) 156 | 157 | def __rsub__(self, other) -> "ComplexTensor": 158 | if isinstance(other, (ComplexTensor, complex)): 159 | return ComplexTensor(other.real - self.real, other.imag - self.imag) 160 | else: 161 | return ComplexTensor(other - self.real, -self.imag) # other - (a+bi) = (other-a) - bi 162 | 163 | def __isub__(self, other) -> "ComplexTensor": 164 | if isinstance(other, (ComplexTensor, complex)): 165 | self.real -= other.real 166 | self.imag -= other.imag 167 | else: 168 | self.real -= other 169 | return self 170 | 171 | def __matmul__(self, other) -> "ComplexTensor": 172 | if isinstance(other, ComplexTensor): 173 | o_real = torch.matmul(self.real, other.real) - torch.matmul( 174 | self.imag, other.imag 175 | ) 176 | o_imag = torch.matmul(self.real, other.imag) + torch.matmul( 177 | self.imag, other.real 178 | ) 179 | else: 180 | o_real = torch.matmul(self.real, other) 181 | o_imag = torch.matmul(self.imag, other) 182 | return ComplexTensor(o_real, o_imag) 183 | 184 | def __rmatmul__(self, other) -> "ComplexTensor": 185 | if isinstance(other, ComplexTensor): 186 | o_real = torch.matmul(other.real, self.real) - torch.matmul( 187 | other.imag, self.imag 188 | ) 189 | o_imag = torch.matmul(other.real, self.imag) + torch.matmul( 190 | other.imag, self.real 191 | ) 192 | else: 193 | o_real = torch.matmul(other, self.real) 194 | o_imag = torch.matmul(other, self.imag) 195 | return ComplexTensor(o_real, o_imag) 196 | 197 | def __imatmul__(self,
other) -> "ComplexTensor": 198 | if isinstance(other, (ComplexTensor, numbers.Complex)): 199 | t = self @ other 200 | self.real = t.real 201 | self.imag = t.imag 202 | else: 203 | self.real @= other 204 | self.imag @= other 205 | return self 206 | 207 | def __neg__(self) -> "ComplexTensor": 208 | return ComplexTensor(-self.real, -self.imag) 209 | 210 | def __eq__(self, other) -> torch.Tensor: 211 | if isinstance(other, (ComplexTensor, complex)): 212 | return (self.real == other.real) * (self.imag == other.imag) # elementwise AND via "*" 213 | else: 214 | return (self.real == other) * (self.imag == 0) 215 | 216 | def __len__(self) -> int: 217 | return len(self.real) 218 | 219 | def __repr__(self) -> str: 220 | import textwrap 221 | 222 | return ( 223 | "ComplexTensor(" 224 | + "\n real=" 225 | + textwrap.indent(repr(self.real), " " * len(" real=")).lstrip(" ") 226 | + ",\n imag=" 227 | + textwrap.indent(repr(self.imag), " " * len(" imag=")).lstrip(" ") 228 | + ",\n)" 229 | ) 230 | 231 | def __abs__(self) -> torch.Tensor: 232 | return (self.real * self.real + self.imag * self.imag).sqrt() 233 | 234 | def __pow__(self, exponent) -> "ComplexTensor": 235 | if exponent == -2: 236 | return 1 / (self * self) 237 | if exponent == -1: 238 | return 1 / self 239 | if exponent == 0: 240 | return ComplexTensor(torch.ones_like(self.real)) 241 | if exponent == 1: 242 | return self.clone() 243 | if exponent == 2: 244 | return self * self 245 | 246 | _abs = self.abs().pow(exponent) 247 | _angle = exponent * self.angle() 248 | return ComplexTensor(_abs * torch.cos(_angle), _abs * torch.sin(_angle)) 249 | 250 | def __ipow__(self, exponent) -> "ComplexTensor": 251 | c = self ** exponent 252 | self.real = c.real 253 | self.imag = c.imag 254 | return self 255 | 256 | def abs(self) -> torch.Tensor: 257 | return (self.real * self.real + self.imag * self.imag).sqrt() 258 | 259 | def angle(self) -> torch.Tensor: 260 | return torch.atan2(self.imag, self.real) 261 | 262 | def backward(self) -> None: 263 | self.real.backward() 264 | self.imag.backward() 265 | 266 | def byte(self) -> "ComplexTensor": 267 | return ComplexTensor(self.real.byte(), self.imag.byte()) 268 | 269 | def clone(self) -> "ComplexTensor": 270 | return ComplexTensor(self.real.clone(), self.imag.clone()) 271 | 272 | def conj(self) -> "ComplexTensor": 273 | return ComplexTensor(self.real, -self.imag) 274 | 275 | def conj_(self) -> "ComplexTensor": 276 | self.imag.neg_() 277 | return self 278 | 279 | def contiguous(self) -> "ComplexTensor": 280 | return ComplexTensor(self.real.contiguous(), self.imag.contiguous()) 281 | 282 | def copy_(self) -> "ComplexTensor": 283 | self.real = self.real.copy_() 284 | self.imag = self.imag.copy_() 285 | return self 286 | 287 | def cpu(self) -> "ComplexTensor": 288 | return ComplexTensor(self.real.cpu(), self.imag.cpu()) 289 | 290 | def cuda(self) -> "ComplexTensor": 291 | return ComplexTensor(self.real.cuda(), self.imag.cuda()) 292 | 293 | def expand(self, *sizes): 294 | return ComplexTensor(self.real.expand(*sizes), self.imag.expand(*sizes)) 295 | 296 | def expand_as(self, *args, **kwargs): 297 | return ComplexTensor( 298 | self.real.expand_as(*args, **kwargs), self.imag.expand_as(*args, **kwargs) 299 | ) 300 | 301 | def detach(self) -> "ComplexTensor": 302 | return ComplexTensor(self.real.detach(), self.imag.detach()) 303 | 304 | def detach_(self) -> "ComplexTensor": 305 | self.real.detach_() 306 | self.imag.detach_() 307 | return self 308 | 309 | @property 310 | def device(self): 311 | assert self.real.device == self.imag.device 312 |
return self.real.device 313 | 314 | def diag(self) -> "ComplexTensor": 315 | return ComplexTensor(self.real.diag(), self.imag.diag()) 316 | 317 | def diagonal(self) -> "ComplexTensor": 318 | return ComplexTensor(self.real.diagonal(), self.imag.diagonal()) 319 | 320 | def dim(self) -> int: 321 | return self.real.dim() 322 | 323 | def double(self) -> "ComplexTensor": 324 | return ComplexTensor(self.real.double(), self.imag.double()) 325 | 326 | @property 327 | def dtype(self) -> torch.dtype: 328 | # Warning: Try to never use this dtype property. 329 | # It will break your code, when you change to the native 330 | # complex type. 331 | # Use instead directly `complex_tensor.real.dtype`. 332 | return self.real.dtype 333 | 334 | def is_floating_point(self): 335 | return False 336 | 337 | def is_complex(self): 338 | return True 339 | 340 | def eq(self, other) -> torch.Tensor: 341 | if isinstance(other, (ComplexTensor, complex)): 342 | return (self.real == other.real) * (self.imag == other.imag) 343 | else: 344 | return (self.real == other) * (self.imag == 0) 345 | 346 | def equal(self, other) -> bool: 347 | if isinstance(other, (ComplexTensor, complex)): 348 | return self.real.equal(other.real) and self.imag.equal(other.imag) 349 | else: 350 | return self.real.equal(other) and self.imag.equal(0) 351 | 352 | def float(self) -> "ComplexTensor": 353 | return ComplexTensor(self.real.float(), self.imag.float()) 354 | 355 | def fill(self, value) -> "ComplexTensor": 356 | if isinstance(value, complex): 357 | return ComplexTensor(self.real.fill(value.real), self.imag.fill(value.imag)) 358 | else: 359 | return ComplexTensor(self.real.fill(value), self.imag.fill(0)) 360 | 361 | def fill_(self, value) -> "ComplexTensor": 362 | if isinstance(value, complex): 363 | self.real.fill_(value.real) 364 | self.imag.fill_(value.imag) 365 | else: 366 | self.real.fill_(value) 367 | self.imag.fill_(0) 368 | return self 369 | 370 | def gather(self, dim, index) -> "ComplexTensor": 371 | return ComplexTensor(self.real.gather(dim, index), self.imag.gather(dim, index)) 372 | 373 | def get_device(self, *args, **kwargs): 374 | return self.real.get_device(*args, **kwargs) 375 | 376 | def half(self) -> "ComplexTensor": 377 | return ComplexTensor(self.real.half(), self.imag.half()) 378 | 379 | def index_add(self, dim, index, tensor) -> "ComplexTensor": 380 | return ComplexTensor( 381 | self.real.index_add(dim, index, tensor), 382 | self.imag.index_add(dim, index, tensor), 383 | ) 384 | 385 | def index_copy(self, dim, index, tensor) -> "ComplexTensor": 386 | return ComplexTensor( 387 | self.real.index_copy(dim, index, tensor), 388 | self.imag.index_copy(dim, index, tensor), 389 | ) 390 | 391 | def index_fill(self, dim, index, value) -> "ComplexTensor": 392 | return ComplexTensor( 393 | self.real.index_fill(dim, index, value), 394 | self.imag.index_fill(dim, index, value), 395 | ) 396 | 397 | def index_select(self, dim, index) -> "ComplexTensor": 398 | return ComplexTensor( 399 | self.real.index_select(dim, index), self.imag.index_select(dim, index) 400 | ) 401 | 402 | def inverse(self, ntry=5) -> "ComplexTensor": 403 | # m x n x n 404 | in_size = self.size() 405 | a = self.view(-1, self.size(-1), self.size(-1)) 406 | # see "The Matrix Cookbook" (http://www2.imm.dtu.dk/pubdb/p.php?3274) 407 | # "Section 4.3" 408 | for i in range(ntry): 409 | t = i * 0.1 410 | 411 | e = a.real + t * a.imag 412 | f = a.imag - t * a.real 413 | 414 | try: 415 | x = torch.matmul(f, e.inverse()) 416 | z = (e + torch.matmul(x, f)).inverse() 417 | except
Exception: 418 | if i == ntry - 1: 419 | raise 420 | continue 421 | 422 | if t != 0.0: 423 | eye = torch.eye( 424 | a.real.size(-1), dtype=a.real.dtype, device=a.real.device 425 | )[None] 426 | o_real = torch.matmul(z, (eye - t * x)) 427 | o_imag = -torch.matmul(z, (t * eye + x)) 428 | else: 429 | o_real = z 430 | o_imag = -torch.matmul(z, x) 431 | 432 | o = ComplexTensor(o_real, o_imag) 433 | return o.view(*in_size) 434 | 435 | def inverse2(self) -> "ComplexTensor": 436 | # To avoid cyclic import 437 | from torch_complex.utils import complex_matrix2real_matrix 438 | from torch_complex.utils import real_matrix2complex_matrix 439 | 440 | return real_matrix2complex_matrix(complex_matrix2real_matrix(self).inverse()) 441 | 442 | def item(self) -> numbers.Number: 443 | return self.real.item() + 1j * self.imag.item() 444 | 445 | def masked_fill(self, mask, value) -> "ComplexTensor": 446 | if isinstance(value, complex): 447 | return ComplexTensor( 448 | self.real.masked_fill(mask, value.real), 449 | self.imag.masked_fill(mask, value.imag), 450 | ) 451 | 452 | else: 453 | return ComplexTensor( 454 | self.real.masked_fill(mask, value), self.imag.masked_fill(mask, 0) 455 | ) 456 | 457 | def masked_fill_(self, mask, value) -> "ComplexTensor": 458 | if isinstance(value, complex): 459 | self.real.masked_fill_(mask, value.real) 460 | self.imag.masked_fill_(mask, value.imag) 461 | else: 462 | self.real.masked_fill_(mask, value) 463 | self.imag.masked_fill_(mask, 0) 464 | return self 465 | 466 | def mean(self, *args, **kwargs) -> "ComplexTensor": 467 | return ComplexTensor( 468 | self.real.mean(*args, **kwargs), self.imag.mean(*args, **kwargs) 469 | ) 470 | 471 | def neg(self) -> "ComplexTensor": 472 | return ComplexTensor(-self.real, -self.imag) 473 | 474 | def neg_(self) -> "ComplexTensor": 475 | self.real.neg_() 476 | self.imag.neg_() 477 | return self 478 | 479 | def nelement(self) -> int: 480 | return self.real.nelement() 481 | 482 | def numel(self) -> int: 483 | return self.real.numel() 484 | 485 | def new(self, *args, **kwargs) -> "ComplexTensor": 486 | return ComplexTensor( 487 | self.real.new(*args, **kwargs), self.imag.new(*args, **kwargs) 488 | ) 489 | 490 | def new_empty( 491 | self, size, dtype=None, device=None, requires_grad=False 492 | ) -> "ComplexTensor": 493 | real = self.real.new_empty( 494 | size, dtype=dtype, device=device, requires_grad=requires_grad 495 | ) 496 | imag = self.imag.new_empty( 497 | size, dtype=dtype, device=device, requires_grad=requires_grad 498 | ) 499 | return ComplexTensor(real, imag) 500 | 501 | def new_full( 502 | self, size, fill_value, dtype=None, device=None, requires_grad=False 503 | ) -> "ComplexTensor": 504 | if isinstance(fill_value, complex): 505 | real_value = fill_value.real 506 | imag_value = fill_value.imag 507 | else: 508 | real_value = fill_value 509 | imag_value = 0.0 510 | 511 | real = self.real.new_full( 512 | size, 513 | fill_value=real_value, 514 | dtype=dtype, 515 | device=device, 516 | requires_grad=requires_grad, 517 | ) 518 | imag = self.imag.new_full( 519 | size, 520 | fill_value=imag_value, 521 | dtype=dtype, 522 | device=device, 523 | requires_grad=requires_grad, 524 | ) 525 | return ComplexTensor(real, imag) 526 | 527 | def new_tensor( 528 | self, data, dtype=None, device=None, requires_grad=False 529 | ) -> "ComplexTensor": 530 | if isinstance(data, ComplexTensor): 531 | real = data.real 532 | imag = data.imag 533 | elif isinstance(data, numpy.ndarray): 534 | if data.dtype.kind == "c": 535 | real = data.real 536 | imag = data.imag 
537 | else: 538 | real = data 539 | imag = None 540 | else: 541 | real = data 542 | imag = None 543 | 544 | real = self.real.new_tensor( 545 | real, dtype=dtype, device=device, requires_grad=requires_grad 546 | ) 547 | if imag is None: 548 | imag = torch.zeros_like( 549 | real, dtype=dtype, device=device, requires_grad=requires_grad 550 | ) 551 | else: 552 | imag = self.imag.new_tensor( 553 | imag, dtype=dtype, device=device, requires_grad=requires_grad 554 | ) 555 | return ComplexTensor(real, imag) 556 | 557 | def numpy(self) -> numpy.ndarray: 558 | return self.real.numpy() + 1j * self.imag.numpy() 559 | 560 | def __array__(self): 561 | # https://numpy.org/devdocs/user/basics.dispatch.html 562 | return self.real.__array__() + 1j * self.imag.__array__() 563 | 564 | def permute(self, *dims) -> "ComplexTensor": 565 | return ComplexTensor(self.real.permute(*dims), self.imag.permute(*dims)) 566 | 567 | @property 568 | def T(self): 569 | return ComplexTensor(self.real.T, self.imag.T) 570 | 571 | def pow(self, exponent) -> "ComplexTensor": 572 | return self ** exponent 573 | 574 | def requires_grad_(self) -> "ComplexTensor": 575 | self.real.requires_grad_() 576 | self.imag.requires_grad_() 577 | return self 578 | 579 | @property 580 | def requires_grad(self): 581 | assert self.real.requires_grad == self.imag.requires_grad 582 | return self.real.requires_grad 583 | 584 | @requires_grad.setter 585 | def requires_grad(self, value): 586 | self.real.requires_grad = value 587 | self.imag.requires_grad = value 588 | 589 | def repeat(self, *sizes): 590 | return ComplexTensor(self.real.repeat(*sizes), self.imag.repeat(*sizes)) 591 | 592 | def reshape(self, *shape): 593 | return ComplexTensor(self.real.reshape(*shape), self.imag.reshape(*shape)) 594 | 595 | def retain_grad(self) -> "ComplexTensor": 596 | self.real.retain_grad() 597 | self.imag.retain_grad() 598 | return self 599 | 600 | def share_memory_(self) -> "ComplexTensor": 601 | self.real.share_memory_() 602 | self.imag.share_memory_() 603 | return self 604 | 605 | @property 606 | def shape(self) -> torch.Size: 607 | return self.real.shape 608 | 609 | def size(self, *args, **kwargs) -> torch.Size: 610 | return self.real.size(*args, **kwargs) 611 | 612 | def ndimension(self): 613 | return self.real.ndimension() 614 | 615 | @property 616 | def ndim(self): 617 | return self.real.ndim 618 | 619 | def sqrt(self) -> "ComplexTensor": 620 | return self ** 0.5 621 | 622 | def squeeze(self, dim=None) -> "ComplexTensor": 623 | if dim is None: 624 | return ComplexTensor(self.real.squeeze(), self.imag.squeeze()) 625 | else: 626 | return ComplexTensor(self.real.squeeze(dim), self.imag.squeeze(dim)) 627 | 628 | def sum(self, *args, **kwargs) -> "ComplexTensor": 629 | """ 630 | sum(self, dim, keepdim, *, dtype=None) 631 | sum(self, axis, keepdims, *, dtype=None) # numpy style 632 | 633 | Args: 634 | dim or axis: 635 | keepdim or keepdims: 636 | **kwargs: 637 | 638 | Returns: 639 | 640 | """ 641 | return ComplexTensor( 642 | self.real.sum(*args, **kwargs), self.imag.sum(*args, **kwargs) 643 | ) 644 | 645 | def take(self, indices) -> "ComplexTensor": 646 | return ComplexTensor(self.real.take(indices), self.imag.take(indices)) 647 | 648 | def to(self, *args, **kwargs) -> "ComplexTensor": 649 | return ComplexTensor( 650 | self.real.to(*args, **kwargs), self.imag.to(*args, **kwargs) 651 | ) 652 | 653 | def tolist(self) -> List[numbers.Number]: 654 | return [r + 1j * i for r, i in zip(self.real.tolist(), self.imag.tolist())] 655 | 656 | def transpose(self, dim0, dim1) 
-> "ComplexTensor": 657 | return ComplexTensor( 658 | self.real.transpose(dim0, dim1), self.imag.transpose(dim0, dim1) 659 | ) 660 | 661 | def transpose_(self, dim0, dim1) -> "ComplexTensor": 662 | self.real.transpose_(dim0, dim1) 663 | self.imag.transpose_(dim0, dim1) 664 | return self 665 | 666 | def type(self, *args, **kwargs) -> str: 667 | if len(args) == 0 and len(kwargs) == 0: 668 | return self.real.type() 669 | else: 670 | return ComplexTensor( 671 | self.real.type(*args, **kwargs), self.imag.type(*args, **kwargs) 672 | ) 673 | 674 | def unbind(self, dim=0) -> "ComplexTensor": 675 | return tuple( 676 | map( 677 | lambda x: ComplexTensor(*x), 678 | zip(self.real.unbind(dim=dim), self.imag.unbind(dim=dim)) 679 | ) 680 | ) 681 | 682 | def unfold(self, dim, size, step): 683 | return ComplexTensor( 684 | self.real.unfold(dim, size, step), self.imag.unfold(dim, size, step) 685 | ) 686 | 687 | def unsqueeze(self, dim) -> "ComplexTensor": 688 | return ComplexTensor(self.real.unsqueeze(dim), self.imag.unsqueeze(dim)) 689 | 690 | def unsqueeze_(self, dim) -> "ComplexTensor": 691 | self.real.unsqueeze_(dim) 692 | self.imag.unsqueeze_(dim) 693 | return self 694 | 695 | def view(self, *args, **kwargs) -> "ComplexTensor": 696 | return ComplexTensor( 697 | self.real.view(*args, **kwargs), self.imag.view(*args, **kwargs) 698 | ) 699 | 700 | def view_as(self, tensor): 701 | return self.view(tensor.size()) 702 | -------------------------------------------------------------------------------- /torch_complex/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_complex.tensor import ComplexTensor 4 | 5 | 6 | def complex_matrix2real_matrix(c: ComplexTensor) -> torch.Tensor: 7 | # NOTE(kamo): 8 | # Complex value can be expressed as follows 9 | # a + bi => a * x + b y 10 | # where 11 | # x = |1 0| y = |0 -1| 12 | # |0 1|, |1 0| 13 | # A complex matrix can be also expressed as 14 | # |A -B| 15 | # |B A| 16 | # and complex vector can be expressed as 17 | # |A| 18 | # |B| 19 | assert c.size(-2) == c.size(-1), c.size() 20 | # (∗, m, m) -> (*, 2m, 2m) 21 | return torch.cat( 22 | [torch.cat([c.real, -c.imag], dim=-1), torch.cat([c.imag, c.real], dim=-1)], 23 | dim=-2, 24 | ) 25 | 26 | 27 | def complex_vector2real_vector(c: ComplexTensor) -> torch.Tensor: 28 | # (∗, m, k) -> (*, 2m, k) 29 | return torch.cat([c.real, c.imag], dim=-2) 30 | 31 | 32 | def real_matrix2complex_matrix(c: torch.Tensor) -> ComplexTensor: 33 | assert c.size(-2) == c.size(-1), c.size() 34 | # (∗, 2m, 2m) -> (*, m, m) 35 | n = c.size(-1) 36 | assert n % 2 == 0, n 37 | real = c[..., : n // 2, : n // 2] 38 | imag = c[..., n // 2 :, : n // 2] 39 | return ComplexTensor(real, imag) 40 | 41 | 42 | def real_vector2complex_vector(c: torch.Tensor) -> ComplexTensor: 43 | # (∗, 2m, k) -> (*, m, k) 44 | n = c.size(-2) 45 | assert n % 2 == 0, n 46 | real = c[..., : n // 2, :] 47 | imag = c[..., n // 2 :, :] 48 | return ComplexTensor(real, imag) 49 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class SpatialFilterLoss(object): 4 | def __init__(self, alpha, l_type): 5 | self.alpha = alpha 6 | self.l_type = l_type 7 | 8 | def __call__(self, resi, frame_list): 9 | """ 10 | resi: (B,T,F,2), frame_list: list 11 | """ 12 | b_size, seq_len, freq_num, _ = resi.shape 13 | mask_for_loss = [] 14 | with torch.no_grad(): 15 | for i in 
range(b_size): 16 | tmp_mask = torch.ones((frame_list[i], freq_num, 2), dtype=resi.dtype) 17 | mask_for_loss.append(tmp_mask) 18 | mask_for_loss = torch.nn.utils.rnn.pad_sequence(mask_for_loss, batch_first=True).to(resi.device) 19 | mag_mask_for_loss = mask_for_loss[...,0] 20 | 21 | resi_mag = torch.norm(resi, dim=-1) 22 | if self.l_type == "L1" or self.l_type == "l1": 23 | loss_com = (torch.abs(resi) * mask_for_loss).sum() / mask_for_loss.sum() 24 | loss_mag = (torch.abs(resi_mag) * mag_mask_for_loss).sum() / mag_mask_for_loss.sum() 25 | elif self.l_type == "L2" or self.l_type == "l2": 26 | loss_com = (torch.square(resi) * mask_for_loss).sum() / mask_for_loss.sum() 27 | loss_mag = (torch.square(resi_mag) * mag_mask_for_loss).sum() / mag_mask_for_loss.sum() 28 | else: 29 | raise RuntimeError("only L1 and L2 are supported") 30 | return self.alpha * loss_com + (1 - self.alpha) * loss_mag 31 | 32 | 33 | class ComMagEuclideanLoss(object): 34 | def __init__(self, alpha, l_type): 35 | self.alpha = alpha 36 | self.l_type = l_type 37 | 38 | def __call__(self, est, label, frame_list): 39 | """ 40 | est: (B,T,F,2) 41 | label: (B,T,F,2) 42 | frame_list: list 43 | alpha: scalar (set in __init__) 44 | l_type: str, L1 or L2 (set in __init__) 45 | """ 46 | b_size, seq_len, freq_num, _ = est.shape 47 | mask_for_loss = [] 48 | with torch.no_grad(): 49 | for i in range(b_size): 50 | tmp_mask = torch.ones((frame_list[i], freq_num, 2), dtype=est.dtype) 51 | mask_for_loss.append(tmp_mask) 52 | mask_for_loss = torch.nn.utils.rnn.pad_sequence(mask_for_loss, batch_first=True).to(est.device) 53 | mag_mask_for_loss = mask_for_loss[...,0] 54 | est_mag, label_mag = torch.norm(est, dim=-1), torch.norm(label, dim=-1) 55 | 56 | if self.l_type == "L1" or self.l_type == "l1": 57 | loss_com = (torch.abs(est - label) * mask_for_loss).sum() / mask_for_loss.sum() 58 | loss_mag = (torch.abs(est_mag - label_mag) * mag_mask_for_loss).sum() / mag_mask_for_loss.sum() 59 | elif self.l_type == "L2" or self.l_type == "l2": 60 | loss_com = (torch.square(est - label) * mask_for_loss).sum() / mask_for_loss.sum() 61 | loss_mag = (torch.square(est_mag - label_mag) * mag_mask_for_loss).sum() / mag_mask_for_loss.sum() 62 | else: 63 | raise RuntimeError("only L1 and L2 are supported!") 64 | return self.alpha * loss_com + (1 - self.alpha) * loss_mag 65 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import logging 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | EPSILON = np.finfo(np.float32).eps 10 | 11 | def logger_print(log): 12 | logging.info(log) 13 | print(log) 14 | 15 | def numParams(net): 16 | num = 0 17 | for param in net.parameters(): 18 | if param.requires_grad: 19 | num += int(np.prod(param.size())) 20 | return num 21 | 22 | class ToTensor(object): 23 | def __call__(self, 24 | x, 25 | type="float"): 26 | if type == "float": 27 | return torch.FloatTensor(x) 28 | elif type == "int": 29 | return torch.IntTensor(x) 30 | 31 | def pad_to_longest(batch_data): 32 | """ 33 | pad the waves in one batch chunk to the longest length 34 | :param batch_data: 35 | :return: 36 | """ 37 | mix_wav_batch_list, bf_wav_batch_list, target_wav_batch_list, wav_len_list = batch_data[0] 38 | to_tensor = ToTensor() 39 | mix_wav_batch_list, bf_wav_batch_list, target_wav_batch_list = to_tensor(mix_wav_batch_list), \ 40 | to_tensor(bf_wav_batch_list), \ 41 |
to_tensor(target_wav_batch_list) 42 | mix_tensor, bf_tensor, target_tensor = nn.utils.rnn.pad_sequence(mix_wav_batch_list, batch_first=True), \ 43 | nn.utils.rnn.pad_sequence(bf_wav_batch_list, batch_first=True), \ 44 | nn.utils.rnn.pad_sequence(target_wav_batch_list, batch_first=True) # (B,L,M) 45 | return mix_tensor, bf_tensor, target_tensor, wav_len_list 46 | 47 | 48 | class BatchInfo(object): 49 | def __init__(self, feats, bfs, labels, frame_mask_list): 50 | self.feats = feats 51 | self.bfs = bfs 52 | self.labels = labels 53 | self.frame_mask_list = frame_mask_list 54 | 55 | 56 | def json_extraction(file_path, json_path, data_type): 57 | if not os.path.exists(json_path): 58 | os.makedirs(json_path) 59 | file_list = os.listdir(file_path) 60 | file_num = len(file_list) 61 | json_list = [] 62 | 63 | for i in range(file_num): 64 | file_name = file_list[i] 65 | file_name = os.path.splitext(file_name)[0] 66 | json_list.append(file_name) 67 | 68 | with open(os.path.join(json_path, "{}_files.json".format(data_type)), "w") as f: 69 | json.dump(json_list, f, indent=4) 70 | return os.path.join(json_path, "{}_files.json".format(data_type)) 71 | 72 | 73 | def complex_mul(inpt1, inpt2): 74 | """ 75 | inpt1: (B,2,...) or (...,2) 76 | inpt2: (B,2,...) or (...,2) 77 | """ 78 | if inpt1.shape[1] == 2: 79 | out_r = inpt1[:,0,...]*inpt2[:,0,...] - inpt1[:,-1,...]*inpt2[:,-1,...] 80 | out_i = inpt1[:,0,...]*inpt2[:,-1,...] + inpt1[:,-1,...]*inpt2[:,0,...] 81 | return torch.stack((out_r, out_i), dim=1) 82 | elif inpt1.shape[-1] == 2: 83 | out_r = inpt1[...,0]*inpt2[...,0] - inpt1[...,-1]*inpt2[...,-1] 84 | out_i = inpt1[...,0]*inpt2[...,-1] + inpt1[...,-1]*inpt2[...,0] 85 | return torch.stack((out_r, out_i), dim=-1) 86 | else: 87 | raise RuntimeError("Only supports two tensor formats") 88 | 89 | def complex_conj(inpt): 90 | """ 91 | inpt: (B,2,...) or (...,2) 92 | """ 93 | if inpt.shape[1] == 2: 94 | inpt_r, inpt_i = inpt[:,0,...], inpt[:,-1,...] 95 | return torch.stack((inpt_r, -inpt_i), dim=1) 96 | elif inpt.shape[-1] == 2: 97 | inpt_r, inpt_i = inpt[...,0], inpt[...,-1] 98 | return torch.stack((inpt_r, -inpt_i), dim=-1) 99 | 100 | def complex_div(inpt1, inpt2): 101 | """ 102 | inpt1: (B,2,...) or (...,2) 103 | inpt2: (B,2,...) or (...,2) 104 | """ 105 | if inpt1.shape[1] == 2: 106 | inpt1_r, inpt1_i = inpt1[:,0,...], inpt1[:,-1,...] 107 | inpt2_r, inpt2_i = inpt2[:,0,...], inpt2[:,-1,...] 
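# Both branches below implement standard complex division,
#   (a+bi) / (c+di) = ((a*c + b*d) + (b*c - a*d)i) / (c**2 + d**2),
# with EPSILON added to the squared-magnitude denominator for numerical safety.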
108 | denom = torch.norm(inpt2, dim=1)**2.0 + EPSILON 109 | out_r = inpt1_r * inpt2_r + inpt1_i * inpt2_i 110 | out_i = inpt1_i * inpt2_r - inpt1_r * inpt2_i 111 | return torch.stack((out_r/denom, out_i/denom), dim=1) 112 | elif inpt1.shape[-1] == 2: 113 | inpt1_r, inpt1_i = inpt1[...,0], inpt1[...,-1] 114 | inpt2_r, inpt2_i = inpt2[...,0], inpt2[...,-1] 115 | denom = torch.norm(inpt2, dim=-1)**2.0 + EPSILON 116 | out_r = inpt1_r * inpt2_r + inpt1_i * inpt2_i 117 | out_i = inpt1_i * inpt2_r - inpt1_r * inpt2_i 118 | return torch.stack((out_r/denom, out_i/denom), dim=-1) 119 | 120 | 121 | class NormSwitch(nn.Module): 122 | def __init__(self, 123 | norm_type: str, 124 | format: str, 125 | num_features: int, 126 | affine: bool = True, 127 | ): 128 | super(NormSwitch, self).__init__() 129 | self.norm_type = norm_type 130 | self.format = format 131 | self.num_features = num_features 132 | self.affine = affine 133 | 134 | if norm_type == "BN": 135 | if format == "1D": 136 | self.norm = nn.BatchNorm1d(num_features, affine=affine) 137 | else: 138 | self.norm = nn.BatchNorm2d(num_features, affine=affine) 139 | elif norm_type == "cLN": 140 | if format == "1D": 141 | self.norm = CumulativeLayerNorm1d(num_features, affine) 142 | else: 143 | self.norm = CumulativeLayerNorm2d(num_features, affine) 144 | elif norm_type == "cIN": 145 | if format == "2D": 146 | self.norm = CumulativeInstanceNorm2d(num_features, affine) # was CumulativeLayerNorm2d, a copy-paste slip 147 | 148 | def forward(self, inpt): 149 | return self.norm(inpt) 150 | 151 | class CumulativeLayerNorm2d(nn.Module): 152 | def __init__(self, 153 | num_features, 154 | affine=True, 155 | eps=1e-5, 156 | ): 157 | super(CumulativeLayerNorm2d, self).__init__() 158 | self.num_features = num_features 159 | self.eps = eps 160 | self.affine = affine 161 | 162 | if affine: 163 | self.gain = nn.Parameter(torch.ones(1,num_features,1,1)) 164 | self.bias = nn.Parameter(torch.zeros(1,num_features,1,1)) 165 | else: 166 | self.gain = Variable(torch.ones(1,num_features,1,1), requires_grad=False) 167 | self.bias = Variable(torch.zeros(1,num_features,1,1), requires_grad=False) 168 | 169 | def forward(self, inpt): 170 | """ 171 | :param inpt: (B,C,T,F) 172 | :return: 173 | """ 174 | b_size, channel, seq_len, freq_num = inpt.shape 175 | step_sum = inpt.sum([1,3], keepdim=True) # (B,1,T,1) 176 | step_pow_sum = inpt.pow(2).sum([1,3], keepdim=True) # (B,1,T,1) 177 | cum_sum = torch.cumsum(step_sum, dim=-2) # (B,1,T,1) 178 | cum_pow_sum = torch.cumsum(step_pow_sum, dim=-2) # (B,1,T,1) 179 | 180 | entry_cnt = np.arange(channel*freq_num, channel*freq_num*(seq_len+1), channel*freq_num) 181 | entry_cnt = torch.from_numpy(entry_cnt).type(inpt.type()) 182 | entry_cnt = entry_cnt.view(1,1,seq_len,1).expand_as(cum_sum) 183 | 184 | cum_mean = cum_sum / entry_cnt 185 | cum_var = (cum_pow_sum - 2*cum_mean*cum_sum) / entry_cnt + cum_mean.pow(2) 186 | cum_std = (cum_var + self.eps).sqrt() 187 | 188 | x = (inpt - cum_mean) / cum_std 189 | return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type()) 190 | 191 | 192 | class CumulativeInstanceNorm2d(nn.Module): 193 | def __init__(self, 194 | num_features, 195 | affine=True, 196 | eps=1e-5, 197 | ): 198 | super(CumulativeInstanceNorm2d, self).__init__() 199 | self.num_features = num_features 200 | self.eps = eps 201 | self.affine = affine 202 | 203 | if affine: 204 | self.gain = nn.Parameter(torch.ones(1,num_features,1,1)) 205 | self.bias = nn.Parameter(torch.zeros(1,num_features,1,1)) 206 | else: 207 | self.gain = Variable(torch.ones(1,num_features,1,1),
requires_grad=False) 208 | self.bias = Variable(torch.zeros(1,num_features,1,1), requires_grad=False) 209 | 210 | def forward(self, inpt): 211 | """ 212 | :param inpt: (B,C,T,F) 213 | :return: 214 | """ 215 | b_size, channel, seq_len, freq_num = inpt.shape 216 | step_sum = inpt.sum([3], keepdim=True) # (B,C,T,1) 217 | step_pow_sum = inpt.pow(2).sum([3], keepdim=True) # (B,C,T,1) 218 | cum_sum = torch.cumsum(step_sum, dim=-2) # (B,C,T,1) 219 | cum_pow_sum = torch.cumsum(step_pow_sum, dim=-2) # (B,C,T,1) 220 | 221 | entry_cnt = np.arange(freq_num, freq_num*(seq_len+1), freq_num) 222 | entry_cnt = torch.from_numpy(entry_cnt).type(inpt.type()) 223 | entry_cnt = entry_cnt.view(1,1,seq_len,1).expand_as(cum_sum) 224 | 225 | cum_mean = cum_sum / entry_cnt 226 | cum_var = (cum_pow_sum - 2*cum_mean*cum_sum) / entry_cnt + cum_mean.pow(2) 227 | cum_std = (cum_var + self.eps).sqrt() 228 | 229 | x = (inpt - cum_mean) / cum_std 230 | return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type()) 231 | 232 | 233 | class CumulativeLayerNorm1d(nn.Module): 234 | def __init__(self, 235 | num_features, 236 | affine=True, 237 | eps=1e-5, 238 | ): 239 | super(CumulativeLayerNorm1d, self).__init__() 240 | self.num_features = num_features 241 | self.affine = affine 242 | self.eps = eps 243 | 244 | if affine: 245 | self.gain = nn.Parameter(torch.ones(1,num_features,1), requires_grad=True) 246 | self.bias = nn.Parameter(torch.zeros(1,num_features,1), requires_grad=True) 247 | else: 248 | self.gain = Variable(torch.ones(1, num_features, 1), requires_grad=False) 249 | self.bias = Variable(torch.zeros(1, num_features, 1), requires_grad=False) 250 | 251 | def forward(self, inpt): 252 | # inpt: (B,C,T) 253 | b_size, channel, seq_len = inpt.shape 254 | cum_sum = torch.cumsum(inpt.sum(1), dim=1) # (B,T) 255 | cum_power_sum = torch.cumsum(inpt.pow(2).sum(1), dim=1) # (B,T) 256 | 257 | entry_cnt = np.arange(channel, channel*(seq_len+1), channel) 258 | entry_cnt = torch.from_numpy(entry_cnt).type(inpt.type()) 259 | entry_cnt = entry_cnt.view(1, -1).expand_as(cum_sum) # (B,T) 260 | 261 | cum_mean = cum_sum / entry_cnt # (B,T) 262 | cum_var = (cum_power_sum - 2*cum_mean*cum_sum) / entry_cnt + cum_mean.pow(2) 263 | cum_std = (cum_var + self.eps).sqrt() 264 | 265 | x = (inpt - cum_mean.unsqueeze(dim=1).expand_as(inpt)) / cum_std.unsqueeze(dim=1).expand_as(inpt) 266 | return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(x.type()) 267 | 268 | def sisnr(est, label): 269 | label_power = np.sum(label**2.0) + 1e-8 270 | scale = np.sum(est*label) / label_power 271 | est_true = scale * label 272 | est_res = est - est_true 273 | true_power = np.sum(est_true**2.0, axis=0) + 1e-8 274 | res_power = np.sum(est_res**2.0, axis=0) + 1e-8 275 | sdr = 10*np.log10(true_power) - 10*np.log10(res_power) 276 | return sdr 277 | 278 | def cal_pesq(id, esti_utts, clean_utts, fs): 279 | clean_utt, esti_utt = clean_utts[id,:], esti_utts[id,:] 280 | from pesq import pesq 281 | pesq_score = pesq(fs, clean_utt, esti_utt, "nb") 282 | return pesq_score 283 | 284 | def cal_stoi(id, esti_utts, clean_utts, fs): 285 | clean_utt, esti_utt = clean_utts[id,:], esti_utts[id,:] 286 | from pystoi import stoi 287 | stoi_score = stoi(clean_utt, esti_utt, fs, extended=True) 288 | return 100*stoi_score 289 | 290 | def cal_sisnr(id, esti_utts, clean_utts, fs): 291 | clean_utt, esti_utt = clean_utts[id,:], esti_utts[id,:] 292 | sisnr_score = sisnr(esti_utt, clean_utt) 293 | return sisnr_score
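# A minimal, illustrative smoke test for the cumulative norm and the metric
# helpers above; the shapes and values are assumptions for demonstration, not
# part of the training pipeline.
if __name__ == "__main__":
    x = torch.randn(2, 16, 50, 161)        # (B,C,T,F), the convention used in this file
    y = NormSwitch("cLN", "2D", 16)(x)     # cumulative LayerNorm preserves the shape
    assert y.shape == x.shape
    clean = np.random.randn(16000)             # stand-in for 1 s of clean speech at 16 kHz
    noisy = clean + 0.1 * np.random.randn(16000)
    print("SI-SNR (dB):", sisnr(noisy, clean))  # numpy-based SI-SNR defined above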
--------------------------------------------------------------------------------
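For illustration, the sketch below shows how the two Euclidean loss terms defined in utils/loss.py are combined in solver.py's _train_batch. It is not part of the repository; the shapes, frame counts, alpha, and the two branch weights are assumed values for demonstration only.

import torch
from utils.loss import ComMagEuclideanLoss

# Illustrative (B,T,F,2) spectra: B=2 utterances, T=601 frames, F=161 bins.
batch_bf_est = torch.randn(2, 601, 161, 2)       # beamformer-branch estimate
batch_bf_stft = torch.randn(2, 601, 161, 2)      # (compressed) bf target
batch_spec_est = torch.randn(2, 601, 161, 2)     # spectral-branch estimate
batch_target_stft = torch.randn(2, 601, 161, 2)  # (compressed) clean target
frame_list = [601, 550]  # valid frame counts per utterance, as in solver.py

loss_fn = ComMagEuclideanLoss(alpha=0.5, l_type="L2")
spatial_loss = loss_fn(batch_bf_est, batch_bf_stft, frame_list)
spectral_loss = loss_fn(batch_spec_est, batch_target_stft, frame_list)
batch_loss = 1.0 * spatial_loss + 1.0 * spectral_loss  # spatial/spectral weights
print(batch_loss.item())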