├── README.md ├── bss_test.py ├── config_WSJ0_SDNet.yaml ├── data ├── Readme ├── __init__.py ├── dataloader.py ├── dict.py └── utils.py ├── jpg ├── Readme └── sdnet.jpeg ├── lr_scheduler.py ├── models ├── Readme ├── Schmidt_orth.py ├── WaveLoss.py ├── __init__.py ├── attention.py ├── beam.py ├── focal_loss.py ├── istft_irfft.py ├── loss.py ├── metrics.py ├── rnn.py ├── separation_dis.py ├── separation_tasnet.py └── seq2seq.py ├── predata_WSJ_lcx.py ├── run.sh ├── separation.py ├── test_WSJ0_SDNet.py └── train_WSJ0_SDNet.py /README.md: -------------------------------------------------------------------------------- 1 | # ICASSP 2021: SDNet: Speaker and Direction Inferred Dual-channel Speech Separation 2 | 3 | If you are interested in our work, or use this code or any part of it, please cite us! 4 | Consider citing: 5 | ```bash 6 | @inproceedings{li2021speaker, 7 | title={Speaker and Direction Inferred Dual-Channel Speech Separation}, 8 | author={Li, Chenxing and Xu, Jiaming and Mesgarani, Nima and Xu, Bo}, 9 | booktitle={ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 10 | pages={5779--5783}, 11 | year={2021}, 12 | organization={IEEE} 13 | } 14 | ``` 15 | For a more detailed description, please see the full paper at [this link](https://doi.org/10.1109/ICASSP39728.2021.9413818). 16 | 17 | # Requirements: 18 | Pytorch>=1.1.0
19 | resampy
20 | soundfile
21 | 22 | # Model Description: 23 | ![](https://github.com/aispeech-lab/SDNet/blob/main/jpg/sdnet.jpeg) 24 | 25 | 26 | 27 | # Data Preparation 28 | 29 | Please refer to predata_WSJ_lcx.py for data preparation. 30 | A more detailed description of the dataset preparation procedure will be added soon. 31 | 32 | # Train and Test 33 | 34 | For training:
35 | python train_WSJ0_SDNet.py
36 | 37 | For testing:
38 | python test_WSJ0_SDNet.py
39 | 40 | Please modify the model path in test_WSJ0_SDNet.py before testing. 41 | 42 | # Contact 43 | If you have any questions, please contact:
44 | Email:lichenxing007@gmail.com 45 | 46 | # TODO 47 | 1. A brief implemention of SDNet 48 | 2. pretrained models. 49 | 3. separated samples. 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /bss_test.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import numpy as np 3 | import os 4 | import soundfile as sf 5 | from separation import bss_eval_sources 6 | 7 | path='batch_output/' 8 | # path='/home/sw/Shin/Codes/DL4SS_Keras/TDAA_beta/batch_output2/' 9 | def cal_SDRi(src_ref, src_est, mix): 10 | """Calculate Source-to-Distortion Ratio improvement (SDRi). 11 | NOTE: bss_eval_sources is very very slow. 12 | Args: 13 | src_ref: numpy.ndarray, [C, T] 14 | src_est: numpy.ndarray, [C, T], reordered by best PIT permutation 15 | mix: numpy.ndarray, [T] 16 | Returns: 17 | average_SDRi 18 | """ 19 | src_anchor = np.stack([mix, mix], axis=0) 20 | if src_ref.shape[0]==1: 21 | src_anchor=src_anchor[0] 22 | sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est) 23 | sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor) 24 | avg_SDR = ((sdr[0]) + (sdr[1])) / 2 25 | avg_SDRi = ((sdr[0]-sdr0[0]) + (sdr[1]-sdr0[1])) / 2 26 | return avg_SDR, avg_SDRi 27 | 28 | 29 | def cal_SISNRi(src_ref, src_est, mix): 30 | """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi) 31 | Args: 32 | src_ref: numpy.ndarray, [C, T] 33 | src_est: numpy.ndarray, [C, T], reordered by best PIT permutation 34 | mix: numpy.ndarray, [T] 35 | Returns: 36 | average_SISNRi 37 | """ 38 | 39 | sisnr1 = cal_SISNR(src_ref[0], src_est[0]) 40 | sisnr2 = cal_SISNR(src_ref[1], src_est[1]) 41 | sisnr1b = cal_SISNR(src_ref[0], mix) 42 | sisnr2b = cal_SISNR(src_ref[1], mix) 43 | avg_SISNRi = ((sisnr1 - sisnr1b) + (sisnr2 - sisnr2b)) / 2 44 | return avg_SISNRi 45 | 46 | def cal_SISNRi_PIT(src_ref, src_est, mix): 47 | """Calculate Scale-Invariant Source-to-Noise Ratio improvement 
(SI-SNRi) 2-mix 48 | Args: 49 | src_ref: numpy.ndarray, [C, T] 50 | src_est: numpy.ndarray, [C, T], reordered by best PIT permutation 51 | mix: numpy.ndarray, [T] 52 | Returns: 53 | average_SISNRi 54 | """ 55 | 56 | sisnr1_a = cal_SISNR(src_ref[0,:], src_est[0,:]) 57 | sisnr2_a = cal_SISNR(src_ref[1,:], src_est[1,:]) 58 | 59 | sisnr1_b = cal_SISNR(src_ref[0,:], src_est[1,:]) 60 | sisnr2_b = cal_SISNR(src_ref[1,:], src_est[0,:]) 61 | 62 | sisnr1_o = cal_SISNR(src_ref[0,:], mix) 63 | sisnr2_o = cal_SISNR(src_ref[1,:], mix) 64 | 65 | avg_SISNR = max((sisnr1_a+sisnr2_a)/2, (sisnr1_b+sisnr2_b)/2) 66 | avg_SISNRi = max( ((sisnr1_a - sisnr1_o) + (sisnr2_a - sisnr2_o)) / 2 ,((sisnr1_b - sisnr1_o) + (sisnr2_b - sisnr2_o)) / 2 ) 67 | return avg_SISNR, avg_SISNRi 68 | 69 | def cal_SISNR(ref_sig, out_sig, eps=1e-8): 70 | """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR) 71 | Args: 72 | ref_sig: numpy.ndarray, [T] 73 | out_sig: numpy.ndarray, [T] 74 | Returns: 75 | SISNR 76 | """ 77 | assert len(ref_sig) == len(out_sig) 78 | ref_sig = ref_sig - np.mean(ref_sig) 79 | out_sig = out_sig - np.mean(out_sig) 80 | ref_energy = np.sum(ref_sig ** 2) + eps 81 | proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy 82 | noise = out_sig - proj 83 | ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps) 84 | sisnr = 10 * np.log(ratio + eps) / np.log(10.0) 85 | return sisnr 86 | 87 | def cal(path,tmp=None): 88 | mix_number=len(set([l.split('_')[0] for l in os.listdir(path) if l[-3:]=='wav'])) 89 | print('num of mixed :',mix_number) 90 | SDR_sum=np.array([]) 91 | SDRi_sum=np.array([]) 92 | for idx in range(mix_number): 93 | pre_speech_channel=[] 94 | aim_speech_channel=[] 95 | mix_speech=[] 96 | for l in sorted(os.listdir(path)): 97 | if l[-3:]!='wav': 98 | continue 99 | if l.split('_')[0]==str(idx): 100 | if 'True_mix' in l: 101 | mix_speech.append(sf.read(path+l)[0]) 102 | if 'real' in l and 'noise' not in l: 103 | aim_speech_channel.append(sf.read(path+l)[0]) 104 | if 'pre' in 
l: 105 | pre_speech_channel.append(sf.read(path+l)[0]) 106 | 107 | assert len(aim_speech_channel)==len(pre_speech_channel) 108 | aim_speech_channel=np.array(aim_speech_channel) 109 | pre_speech_channel=np.array(pre_speech_channel) 110 | mix_speech=np.array(mix_speech) 111 | assert mix_speech.shape[0]==1 112 | mix_speech=mix_speech[0] 113 | 114 | result=bss_eval_sources(aim_speech_channel,pre_speech_channel) 115 | SDR_sum=np.append(SDR_sum,result[0]) 116 | 117 | # result=bss_eval_sources(aim_speech_channel,aim_speech_channel) 118 | # result_sdri=cal_SDRi(aim_speech_channel,pre_speech_channel,mix_speech) 119 | # print 'SDRi:',result_sdri 120 | result=cal_SISNRi(aim_speech_channel,pre_speech_channel,mix_speech) 121 | print('SI-SNR',result) 122 | # for ii in range(aim_speech_channel.shape[0]): 123 | # result=cal_SISNRi(aim_speech_channel[ii],pre_speech_channel[ii],mix_speech[ii]) 124 | # print('SI-SNR',result) 125 | # SDRi_sum=np.append(SDRi_sum,result_sdri) 126 | 127 | print('SDR_Aver for this batch:',SDR_sum.mean()) 128 | # print 'SDRi_Aver for this batch:',SDRi_sum.mean() 129 | return SDR_sum.mean(),SDRi_sum.mean() 130 | 131 | # cal(path) 132 | 133 | -------------------------------------------------------------------------------- /config_WSJ0_SDNet.yaml: -------------------------------------------------------------------------------- 1 | log: './log/' 2 | epoch: 300 3 | batch_size: 1 4 | param_init: 0.1 5 | optim: 'adam' 6 | loss: 'focal_loss' 7 | use_center_loss: 0 8 | learning_rate: 0.001 9 | max_grad_norm: 5 10 | learning_rate_decay: 0.5 11 | 12 | mask: 1 13 | schedule: 1 14 | bidirec: True 15 | start_decay_at: 5 16 | emb_size: 256 17 | encoder_hidden_size: 256 18 | decoder_hidden_size: 512 19 | num_layers: 4 20 | dropout: 0.5 21 | max_tgt_len: 5 22 | eval_interval: 1000 23 | save_interval: 1000 24 | max_generator_batches: 32 25 | metric: ['hamming_loss', 'micro_f1'] 26 | shared_vocab: 0 27 | WFM: 1 28 | MLMSE: 0 29 | beam_size: 5 30 | tmp_score: 0 31 | top1: 0 
32 | ct_recu: 0 33 | 34 | use_tas: 1 35 | all_soft: 0 36 | 37 | global_emb: 1 38 | global_hidden: 0 39 | SPK_EMB_SIZE: 256 40 | schmidt: 0 41 | unit_norm: 1 42 | reID: 0 43 | is_SelfTune : 0 44 | is_dis: 0 45 | speech_cnn_net: 0 46 | relitu: 0 47 | ALPHA: 0.5 48 | quchong_alpha: 1 49 | 50 | #Minimum number of mixed speakers for training 51 | MIN_MIX: 2 52 | #Maximum number of mixed speakers for training 53 | MAX_MIX: 2 54 | MODE: 1 55 | DATASET : 'WSJ0' 56 | is_ComlexMask: 1 57 | num_samples_one_epoch: 20000 58 | 59 | Ground_truth: 1 60 | Comm_with_Memory: 0 61 | HIDDEN_UNITS: 300 62 | NUM_LAYERS: 3 63 | EMBEDDING_SIZE: 50 64 | 65 | ATT_SIZE: 100 66 | AUGMENT_DATA: 0 67 | MAX_EPOCH: 600 68 | EPOCH_SIZE: 600 69 | FRAME_RATE: 8000 70 | FRAME_LENGTH: 256 71 | FRAME_SHIFT: 64 72 | SHUFFLE_BATCH: 1 73 | voice_dB: 2.5 74 | noise_dB: -10 75 | normalize: 1 76 | MIN_LEN: 24000 77 | MAX_LEN: 24000 78 | WINDOWS: FRAME_LENGTH 79 | START_EALY_STOP: 0 80 | IS_LOG_SPECTRAL : 0 81 | channel_first: 1 82 | -------------------------------------------------------------------------------- /data/Readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from models.attention import * 2 | from models.rnn import * 3 | from models.seq2seq import * 4 | from models.loss import * 5 | from models.beam import * -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as torch_data 3 | import os 4 | import data.utils 5 | 6 | class dataset(torch_data.Dataset): 7 | 8 | def __init__(self, src, tgt, raw_src, raw_tgt): 9 | 10 | self.src = src 11 | self.tgt = tgt 12 | self.raw_src = raw_src 13 | self.raw_tgt = 
raw_tgt 14 | 15 | def __getitem__(self, index): 16 | return self.src[index], self.tgt[index], \ 17 | self.raw_src[index], self.raw_tgt[index] 18 | 19 | def __len__(self): 20 | return len(self.src) 21 | 22 | 23 | def load_dataset(path): 24 | pass 25 | 26 | def save_dataset(dataset, path): 27 | if not os.path.exists(path): 28 | os.mkdir(path) 29 | 30 | 31 | def padding(data): 32 | #data.sort(key=lambda x: len(x[0]), reverse=True) 33 | src, tgt, raw_src, raw_tgt = zip(*data) 34 | 35 | src_len = [len(s) for s in src] 36 | src_pad = torch.zeros(len(src), max(src_len)).long() 37 | for i, s in enumerate(src): 38 | end = src_len[i] 39 | src_pad[i, :end] = s[:end] 40 | 41 | tgt_len = [len(s) for s in tgt] 42 | tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long() 43 | for i, s in enumerate(tgt): 44 | end = tgt_len[i] 45 | tgt_pad[i, :end] = s[:end] 46 | #tgt_len = [length-1 for length in tgt_len] 47 | 48 | #return src_pad.t(), src_len, tgt_pad.t(), tgt_len 49 | return raw_src, src_pad.t(), torch.LongTensor(src_len), \ 50 | raw_tgt, tgt_pad.t(), torch.LongTensor(tgt_len) 51 | 52 | 53 | def get_loader(dataset, batch_size, shuffle, num_workers): 54 | 55 | data_loader = torch.utils.data.DataLoader(dataset=dataset, 56 | batch_size=batch_size, 57 | shuffle=shuffle, 58 | num_workers=num_workers, 59 | collate_fn=padding) 60 | return data_loader -------------------------------------------------------------------------------- /data/dict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | PAD = 0 4 | UNK = 1 5 | BOS = 2 6 | EOS = 3 7 | 8 | PAD_WORD = '' 9 | UNK_WORD = ' ' 10 | BOS_WORD = '' 11 | EOS_WORD = '' 12 | SPA_WORD = ' ' 13 | 14 | def flatten(l): 15 | for el in l: 16 | if hasattr(el, "__iter__"): 17 | for sub in flatten(el): 18 | yield sub 19 | else: 20 | yield el 21 | 22 | class Dict(object): 23 | def __init__(self, data=None, lower=False): 24 | self.idxToLabel = {} 25 | self.labelToIdx = {} 26 | self.frequencies = {} 
27 | self.lower = lower 28 | self.special = [] 29 | 30 | if data is not None: 31 | if type(data) == str: 32 | self.loadFile(data) 33 | else: 34 | self.addSpecials(data) 35 | 36 | def size(self): 37 | return len(self.idxToLabel) 38 | 39 | # Load entries from a file. 40 | def loadFile(self, filename): 41 | for line in open(filename): 42 | fields = line.split() 43 | label = fields[0] 44 | idx = int(fields[1]) 45 | self.add(label, idx) 46 | 47 | # Write entries to a file. 48 | def writeFile(self, filename): 49 | with open(filename, 'w') as file: 50 | for i in range(self.size()): 51 | label = self.idxToLabel[i] 52 | file.write('%s %d\n' % (label, i)) 53 | 54 | file.close() 55 | 56 | def loadDict(self, idxToLabel): 57 | for i in range(len(idxToLabel)): 58 | label = idxToLabel[i] 59 | self.add(label, i) 60 | 61 | def lookup(self, key, default=None): 62 | key = key.lower() if self.lower else key 63 | try: 64 | return self.labelToIdx[key] 65 | except KeyError: 66 | return default 67 | 68 | def getLabel(self, idx, default=None): 69 | try: 70 | return self.idxToLabel[idx] 71 | except KeyError: 72 | return default 73 | 74 | # Mark this `label` and `idx` as special (i.e. will not be pruned). 75 | def addSpecial(self, label, idx=None): 76 | idx = self.add(label, idx) 77 | self.special += [idx] 78 | 79 | # Mark all labels in `labels` as specials (i.e. will not be pruned). 80 | def addSpecials(self, labels): 81 | for label in labels: 82 | self.addSpecial(label) 83 | 84 | # Add `label` in the dictionary. Use `idx` as its index if given. 
85 | def add(self, label, idx=None): 86 | label = label.lower() if self.lower else label 87 | if idx is not None: 88 | self.idxToLabel[idx] = label 89 | self.labelToIdx[label] = idx 90 | else: 91 | if label in self.labelToIdx: 92 | idx = self.labelToIdx[label] 93 | else: 94 | idx = len(self.idxToLabel) 95 | self.idxToLabel[idx] = label 96 | self.labelToIdx[label] = idx 97 | 98 | if idx not in self.frequencies: 99 | self.frequencies[idx] = 1 100 | else: 101 | self.frequencies[idx] += 1 102 | 103 | return idx 104 | 105 | # Return a new dictionary with the `size` most frequent entries. 106 | def prune(self, size): 107 | if size >= self.size(): 108 | return self 109 | 110 | # Only keep the `size` most frequent entries. 111 | freq = torch.Tensor( 112 | [self.frequencies[i] for i in range(len(self.frequencies))]) 113 | _, idx = torch.sort(freq, 0, True) 114 | 115 | newDict = Dict() 116 | newDict.lower = self.lower 117 | 118 | # Add special entries in all cases. 119 | for i in self.special: 120 | newDict.addSpecial(self.idxToLabel[i]) 121 | 122 | for i in idx[:size]: 123 | newDict.add(self.idxToLabel[i]) 124 | 125 | return newDict 126 | 127 | # Convert `labels` to indices. Use `unkWord` if not found. 128 | # Optionally insert `bosWord` at the beginning and `eosWord` at the . 129 | def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None): 130 | vec = [] 131 | 132 | if bosWord is not None: 133 | vec += [self.lookup(bosWord)] 134 | 135 | unk = self.lookup(unkWord) 136 | vec += [self.lookup(label, default=unk) for label in labels] 137 | 138 | if eosWord is not None: 139 | vec += [self.lookup(eosWord)] 140 | 141 | vec = [x for x in flatten(vec)] 142 | 143 | return torch.LongTensor(vec) 144 | 145 | # Convert `idx` to labels. If index `stop` is reached, convert it and return. 
146 | def convertToLabels(self, idx, stop): 147 | labels = [] 148 | 149 | for i in idx: 150 | if i == stop: 151 | break 152 | labels += [self.getLabel(i)] 153 | 154 | return labels -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | import os 3 | import csv 4 | import codecs 5 | import yaml 6 | import time 7 | import numpy as np 8 | import shutil 9 | import soundfile as sf 10 | import librosa 11 | 12 | from sklearn import metrics 13 | 14 | 15 | class AttrDict(dict): 16 | def __init__(self, *args, **kwargs): 17 | super(AttrDict, self).__init__(*args, **kwargs) 18 | self.__dict__ = self 19 | 20 | 21 | def read_config(path): 22 | return AttrDict(yaml.load(open(path, 'r'))) 23 | 24 | 25 | def read_datas(filename, trans_to_num=False): 26 | lines = open(filename, 'r').readlines() 27 | lines = list(map(lambda x: x.split(), lines)) 28 | if trans_to_num: 29 | lines = [list(map(int, line)) for line in lines] 30 | return lines 31 | 32 | 33 | def save_datas(data, filename, trans_to_str=False): 34 | if trans_to_str: 35 | data = [list(map(str, line)) for line in data] 36 | lines = list(map(lambda x: " ".join(x), data)) 37 | with open(filename, 'w') as f: 38 | f.write("\n".join(lines)) 39 | 40 | 41 | def logging(file): 42 | def write_log(s): 43 | print(s, '') 44 | with open(file, 'a') as f: 45 | f.write(s) 46 | 47 | return write_log 48 | 49 | 50 | def logging_csv(file): 51 | def write_csv(s): 52 | # with open(file, 'a', newline='') as f: 53 | with open(file, 'a') as f: 54 | writer = csv.writer(f) 55 | writer.writerow(s) 56 | 57 | return write_csv 58 | 59 | 60 | def format_time(t): 61 | return time.strftime("%Y-%m-%d-%H:%M:%S", t) 62 | 63 | 64 | def eval_metrics(reference, candidate, label_dict, log_path): 65 | ref_dir = log_path + 'reference/' 66 | cand_dir = log_path + 'candidate/' 67 | if not os.path.exists(ref_dir): 68 | 
os.mkdir(ref_dir) 69 | if not os.path.exists(cand_dir): 70 | os.mkdir(cand_dir) 71 | ref_file = ref_dir + 'reference' 72 | cand_file = cand_dir + 'candidate' 73 | 74 | for i in range(len(reference)): 75 | with codecs.open(ref_file + str(i), 'w', 'utf-8') as f: 76 | f.write("".join(reference[i]) + '\n') 77 | with codecs.open(cand_file + str(i), 'w', 'utf-8') as f: 78 | f.write("".join(candidate[i]) + '\n') 79 | 80 | def make_label(l, label_dict): 81 | length = len(label_dict) 82 | result = np.zeros(length) 83 | indices = [label_dict.get(label.strip().lower(), 0) for label in l] 84 | result[indices] = 1 85 | return result 86 | 87 | def prepare_label(y_list, y_pre_list, label_dict): 88 | reference = np.array([make_label(y, label_dict) for y in y_list]) 89 | candidate = np.array([make_label(y_pre, label_dict) for y_pre in y_pre_list]) 90 | return reference, candidate 91 | 92 | def get_metrics(y, y_pre): 93 | hamming_loss = metrics.hamming_loss(y, y_pre) 94 | macro_f1 = metrics.f1_score(y, y_pre, average='macro') 95 | macro_precision = metrics.precision_score(y, y_pre, average='macro') 96 | macro_recall = metrics.recall_score(y, y_pre, average='macro') 97 | micro_f1 = metrics.f1_score(y, y_pre, average='micro') 98 | micro_precision = metrics.precision_score(y, y_pre, average='micro') 99 | micro_recall = metrics.recall_score(y, y_pre, average='micro') 100 | return hamming_loss, macro_f1, macro_precision, macro_recall, micro_f1, micro_precision, micro_recall 101 | 102 | y, y_pre = prepare_label(reference, candidate, label_dict) 103 | hamming_loss, macro_f1, macro_precision, macro_recall, micro_f1, micro_precision, micro_recall = get_metrics(y, 104 | y_pre) 105 | return {'hamming_loss': hamming_loss, 106 | 'macro_f1': macro_f1, 107 | 'macro_precision': macro_precision, 108 | 'macro_recall': macro_recall, 109 | 'micro_f1': micro_f1, 110 | 'micro_precision': micro_precision, 111 | 'micro_recall': micro_recall} 112 | 113 | 114 | def bss_eval(config, predict_multi_map, 
y_multi_map, y_map_gtruth, train_data, dst='batch_output'): 115 | # dst='batch_output' 116 | if os.path.exists(dst): 117 | print(" \ncleanup: " + dst + "/") 118 | shutil.rmtree(dst) 119 | os.makedirs(dst) 120 | 121 | for sample_idx, each_sample in enumerate(train_data['multi_spk_wav_list']): 122 | for each_spk in each_sample.keys(): 123 | this_spk = each_spk 124 | wav_genTrue = each_sample[this_spk] 125 | # min_len = 39936 126 | min_len = len(wav_genTrue) 127 | if config.FRAME_SHIFT == 64: 128 | min_len = len(wav_genTrue) 129 | sf.write(dst + '/{}_{}_realTrue.wav'.format(sample_idx, this_spk), wav_genTrue[:min_len], 130 | config.FRAME_RATE, ) 131 | 132 | predict_multi_map_list = [] 133 | pointer = 0 134 | for each_line in y_map_gtruth: 135 | predict_multi_map_list.append(predict_multi_map[pointer:(pointer + len(each_line))]) 136 | pointer += len(each_line) 137 | assert len(predict_multi_map_list) == len(y_map_gtruth) 138 | 139 | # 对于每个sample 140 | sample_idx = 0 # 代表一个batch里的依次第几个 141 | for each_y, each_pre, each_trueVector, spk_name in zip(y_multi_map, predict_multi_map_list, y_map_gtruth, 142 | train_data['aim_spkname']): 143 | _mix_spec = train_data['mix_phase'][sample_idx] 144 | feas_tgt = train_data['multi_spk_fea_list'][sample_idx] 145 | phase_mix = np.angle(_mix_spec) 146 | for idx, one_cha in enumerate(each_trueVector): 147 | this_spk = one_cha 148 | y_pre_map = each_pre[idx].data.cpu().numpy() 149 | _pred_spec = y_pre_map * np.exp(1j * phase_mix) 150 | wav_pre = librosa.core.spectrum.istft(np.transpose(_pred_spec), config.FRAME_SHIFT) 151 | min_len = len(wav_pre) 152 | sf.write(dst + '/{}_{}_pre.wav'.format(sample_idx, this_spk), wav_pre[:min_len], config.FRAME_RATE, ) 153 | 154 | gen_true_spec = feas_tgt[this_spk] * np.exp(1j * phase_mix) 155 | wav_gen_True = librosa.core.spectrum.istft(np.transpose(gen_true_spec), config.FRAME_SHIFT) 156 | sf.write(dst + '/{}_{}_genTrue.wav'.format(sample_idx, this_spk), wav_gen_True[:min_len], 157 | config.FRAME_RATE, ) 
158 | sf.write(dst + '/{}_True_mix.wav'.format(sample_idx), train_data['mix_wav'][sample_idx][:min_len], 159 | config.FRAME_RATE, ) 160 | sample_idx += 1 161 | 162 | 163 | def bss_eval2(config, predict_multi_map, y_multi_map, y_map_gtruth, train_data, dst='batch_output'): 164 | # dst='batch_output' 165 | if os.path.exists(dst): 166 | print(" \ncleanup: " + dst + "/") 167 | shutil.rmtree(dst) 168 | os.makedirs(dst) 169 | 170 | for sample_idx, each_sample in enumerate(train_data['multi_spk_wav_list']): 171 | for each_spk in each_sample.keys(): 172 | this_spk = each_spk 173 | wav_genTrue = each_sample[this_spk] 174 | # min_len = 39936 175 | min_len = len(wav_genTrue) 176 | if config.FRAME_SHIFT == 64: 177 | min_len = len(wav_genTrue) 178 | sf.write(dst + '/{}_{}_realTrue.wav'.format(sample_idx, this_spk), wav_genTrue[:min_len], 179 | config.FRAME_RATE, ) 180 | 181 | predict_multi_map_list = [] 182 | pointer = 0 183 | for each_line in y_map_gtruth: 184 | predict_multi_map_list.append(predict_multi_map[pointer:(pointer + len(each_line))]) 185 | pointer += len(each_line) 186 | assert len(predict_multi_map_list) == len(y_map_gtruth) 187 | 188 | # 对于每个sample 189 | sample_idx = 0 # 代表一个batch里的依次第几个 190 | for each_y, each_pre, each_trueVector, spk_name in zip(y_multi_map, predict_multi_map_list, y_map_gtruth, 191 | train_data['aim_spkname']): 192 | _mix_spec = train_data['mix_phase'][sample_idx] 193 | feas_tgt = train_data['multi_spk_fea_list'][sample_idx] 194 | phase_mix = np.angle(_mix_spec) 195 | each_pre = each_pre[0] 196 | for idx, one_cha in enumerate(each_trueVector): 197 | this_spk = one_cha 198 | y_pre_map = each_pre[idx].data.cpu().numpy() 199 | _pred_spec = y_pre_map * np.exp(1j * phase_mix) 200 | wav_pre = librosa.core.spectrum.istft(np.transpose(_pred_spec), config.FRAME_SHIFT) 201 | min_len = len(wav_pre) 202 | sf.write(dst + '/{}_{}_pre.wav'.format(sample_idx, this_spk), wav_pre[:min_len], config.FRAME_RATE, ) 203 | 204 | gen_true_spec = feas_tgt[this_spk] * 
np.exp(1j * phase_mix) 205 | wav_gen_True = librosa.core.spectrum.istft(np.transpose(gen_true_spec), config.FRAME_SHIFT) 206 | sf.write(dst + '/{}_{}_genTrue.wav'.format(sample_idx, this_spk), wav_gen_True[:min_len], 207 | config.FRAME_RATE, ) 208 | sf.write(dst + '/{}_True_mix.wav'.format(sample_idx), train_data['mix_wav'][sample_idx][:min_len], 209 | config.FRAME_RATE, ) 210 | sample_idx += 1 211 | 212 | def bss_eval_tas(config, predict_wav, y_multi_map, y_map_gtruth, train_data, dst='batch_output'): 213 | # dst='batch_output' 214 | if os.path.exists(dst): 215 | print(" \ncleanup: " + dst + "/") 216 | shutil.rmtree(dst) 217 | os.makedirs(dst) 218 | 219 | for sample_idx, each_sample in enumerate(train_data['multi_spk_wav_list']): 220 | for each_spk in each_sample.keys(): 221 | this_spk = each_spk 222 | wav_genTrue = each_sample[this_spk] 223 | # min_len = 39936 224 | min_len = len(wav_genTrue) 225 | if config.FRAME_SHIFT == 64: 226 | min_len = len(wav_genTrue) 227 | sf.write(dst + '/{}_{}_realTrue.wav'.format(sample_idx, this_spk), wav_genTrue[:min_len], 228 | config.FRAME_RATE, ) 229 | 230 | predict_multi_map_list = [] 231 | pointer = 0 232 | # if len(predict_wav.shape)==3: 233 | # predict_wav=predict_wav.view(-1,predict_wav.shape[-1]) 234 | for each_line in y_map_gtruth: 235 | predict_multi_map_list.append(predict_wav[pointer:(pointer + len(each_line))]) 236 | pointer += len(each_line) 237 | assert len(predict_multi_map_list) == len(y_map_gtruth) 238 | predict_multi_map_list=[i for i in predict_wav.unsqueeze(1)] 239 | 240 | # 对于每个sample 241 | sample_idx = 0 # 代表一个batch里的依次第几个 242 | for each_y, each_pre, each_trueVector, spk_name in zip(y_multi_map, predict_multi_map_list, y_map_gtruth, 243 | train_data['aim_spkname']): 244 | _mix_spec = train_data['mix_phase'][sample_idx] 245 | feas_tgt = train_data['multi_spk_fea_list'][sample_idx] 246 | phase_mix = np.angle(_mix_spec) 247 | each_pre = each_pre[0] 248 | for idx, one_cha in enumerate(each_trueVector): 249 | 
this_spk = one_cha 250 | y_pre_map = each_pre[idx].data.cpu().numpy() 251 | # _pred_spec = y_pre_map * np.exp(1j * phase_mix) 252 | # wav_pre = librosa.core.spectrum.istft(np.transpose(_pred_spec), config.FRAME_SHIFT) 253 | wav_pre = y_pre_map 254 | min_len = len(wav_pre) 255 | sf.write(dst + '/{}_{}_pre.wav'.format(sample_idx, this_spk), wav_pre[:min_len], config.FRAME_RATE, ) 256 | 257 | gen_true_spec = feas_tgt[this_spk] * np.exp(1j * phase_mix) 258 | wav_gen_True = librosa.core.spectrum.istft(np.transpose(gen_true_spec), config.FRAME_SHIFT) 259 | sf.write(dst + '/{}_{}_genTrue.wav'.format(sample_idx, this_spk), wav_gen_True[:min_len], 260 | config.FRAME_RATE, ) 261 | sf.write(dst + '/{}_True_mix.wav'.format(sample_idx), train_data['mix_wav'][sample_idx][:min_len], 262 | config.FRAME_RATE, ) 263 | sample_idx += 1 264 | -------------------------------------------------------------------------------- /jpg/Readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /jpg/sdnet.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aispeech-lab/SDNet/d057e6d2524b1487d65d4473499d50ef935a7beb/jpg/sdnet.jpeg -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from bisect import bisect_right 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class _LRScheduler(object): 7 | def __init__(self, optimizer, last_epoch=-1): 8 | if not isinstance(optimizer, Optimizer): 9 | raise TypeError('{} is not an Optimizer'.format( 10 | type(optimizer).__name__)) 11 | self.optimizer = optimizer 12 | if last_epoch == -1: 13 | for group in optimizer.param_groups: 14 | group.setdefault('initial_lr', group['lr']) 15 | else: 16 | for i, group in 
enumerate(optimizer.param_groups): 17 | if 'initial_lr' not in group: 18 | raise KeyError("param 'initial_lr' is not specified " 19 | "in param_groups[{}] when resuming an optimizer".format(i)) 20 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 21 | self.step(last_epoch + 1) 22 | self.last_epoch = last_epoch 23 | 24 | def get_lr(self): 25 | raise NotImplementedError 26 | 27 | def step(self, epoch=None): 28 | if epoch is None: 29 | epoch = self.last_epoch + 1 30 | self.last_epoch = epoch 31 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 32 | param_group['lr'] = lr 33 | 34 | 35 | class LambdaLR(_LRScheduler): 36 | """Sets the learning rate of each parameter group to the initial lr 37 | times a given function. When last_epoch=-1, sets initial lr as lr. 38 | Args: 39 | optimizer (Optimizer): Wrapped optimizer. 40 | lr_lambda (function or list): A function which computes a multiplicative 41 | factor given an integer parameter epoch, or a list of such 42 | functions, one for each group in optimizer.param_groups. 43 | last_epoch (int): The index of last epoch. Default: -1. 44 | Example: 45 | >>> # Assuming optimizer has two groups. 46 | >>> lambda1 = lambda epoch: epoch // 30 47 | >>> lambda2 = lambda epoch: 0.95 ** epoch 48 | >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) 49 | >>> for epoch in range(100): 50 | >>> scheduler.step() 51 | >>> train(...) 52 | >>> validate(...) 
53 | """ 54 | 55 | def __init__(self, optimizer, lr_lambda, last_epoch=-1): 56 | self.optimizer = optimizer 57 | if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): 58 | self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) 59 | else: 60 | if len(lr_lambda) != len(optimizer.param_groups): 61 | raise ValueError("Expected {} lr_lambdas, but got {}".format( 62 | len(optimizer.param_groups), len(lr_lambda))) 63 | self.lr_lambdas = list(lr_lambda) 64 | self.last_epoch = last_epoch 65 | super(LambdaLR, self).__init__(optimizer, last_epoch) 66 | 67 | def get_lr(self): 68 | return [base_lr * lmbda(self.last_epoch) 69 | for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] 70 | 71 | 72 | class StepLR(_LRScheduler): 73 | """Sets the learning rate of each parameter group to the initial lr 74 | decayed by gamma every step_size epochs. When last_epoch=-1, sets 75 | initial lr as lr. 76 | Args: 77 | optimizer (Optimizer): Wrapped optimizer. 78 | step_size (int): Period of learning rate decay. 79 | gamma (float): Multiplicative factor of learning rate decay. 80 | Default: 0.1. 81 | last_epoch (int): The index of last epoch. Default: -1. 82 | Example: 83 | >>> # Assuming optimizer uses lr = 0.5 for all groups 84 | >>> # lr = 0.05 if epoch < 30 85 | >>> # lr = 0.005 if 30 <= epoch < 60 86 | >>> # lr = 0.0005 if 60 <= epoch < 90 87 | >>> # ... 88 | >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) 89 | >>> for epoch in range(100): 90 | >>> scheduler.step() 91 | >>> train(...) 92 | >>> validate(...) 
93 | """ 94 | 95 | def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1): 96 | self.step_size = step_size 97 | self.gamma = gamma 98 | super(StepLR, self).__init__(optimizer, last_epoch) 99 | 100 | def get_lr(self): 101 | return [base_lr * self.gamma ** (self.last_epoch // self.step_size) 102 | for base_lr in self.base_lrs] 103 | 104 | 105 | class MultiStepLR(_LRScheduler): 106 | """Set the learning rate of each parameter group to the initial lr decayed 107 | by gamma once the number of epoch reaches one of the milestones. When 108 | last_epoch=-1, sets initial lr as lr. 109 | Args: 110 | optimizer (Optimizer): Wrapped optimizer. 111 | milestones (list): List of epoch indices. Must be increasing. 112 | gamma (float): Multiplicative factor of learning rate decay. 113 | Default: 0.1. 114 | last_epoch (int): The index of last epoch. Default: -1. 115 | Example: 116 | >>> # Assuming optimizer uses lr = 0.5 for all groups 117 | >>> # lr = 0.05 if epoch < 30 118 | >>> # lr = 0.005 if 30 <= epoch < 80 119 | >>> # lr = 0.0005 if epoch >= 80 120 | >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) 121 | >>> for epoch in range(100): 122 | >>> scheduler.step() 123 | >>> train(...) 124 | >>> validate(...) 125 | """ 126 | 127 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): 128 | if not list(milestones) == sorted(milestones): 129 | raise ValueError('Milestones should be a list of' 130 | ' increasing integers. Got {}', milestones) 131 | self.milestones = milestones 132 | self.gamma = gamma 133 | super(MultiStepLR, self).__init__(optimizer, last_epoch) 134 | 135 | def get_lr(self): 136 | return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch) 137 | for base_lr in self.base_lrs] 138 | 139 | 140 | class ExponentialLR(_LRScheduler): 141 | """Set the learning rate of each parameter group to the initial lr decayed 142 | by gamma every epoch. When last_epoch=-1, sets initial lr as lr. 
143 | Args: 144 | optimizer (Optimizer): Wrapped optimizer. 145 | gamma (float): Multiplicative factor of learning rate decay. 146 | last_epoch (int): The index of last epoch. Default: -1. 147 | """ 148 | 149 | def __init__(self, optimizer, gamma, last_epoch=-1): 150 | self.gamma = gamma 151 | super(ExponentialLR, self).__init__(optimizer, last_epoch) 152 | 153 | def get_lr(self): 154 | return [base_lr * self.gamma ** self.last_epoch 155 | for base_lr in self.base_lrs] 156 | 157 | 158 | class CosineAnnealingLR(_LRScheduler): 159 | """Set the learning rate of each parameter group using a cosine annealing 160 | schedule, where :math:`\eta_{max}` is set to the initial lr and 161 | :math:`T_{cur}` is the number of epochs since the last restart in SGDR: 162 | .. math:: 163 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 164 | \cos(\frac{T_{cur}}{T_{max}}\pi)) 165 | When last_epoch=-1, sets initial lr as lr. 166 | It has been proposed in 167 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only 168 | implements the cosine annealing part of SGDR, and not the restarts. 169 | Args: 170 | optimizer (Optimizer): Wrapped optimizer. 171 | T_max (int): Maximum number of iterations. 172 | eta_min (float): Minimum learning rate. Default: 0. 173 | last_epoch (int): The index of last epoch. Default: -1. 174 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 175 | https://arxiv.org/abs/1608.03983 176 | """ 177 | 178 | def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1): 179 | self.T_max = T_max 180 | self.eta_min = eta_min 181 | super(CosineAnnealingLR, self).__init__(optimizer, last_epoch) 182 | 183 | def get_lr(self): 184 | return [self.eta_min + (base_lr - self.eta_min) * 185 | (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 186 | for base_lr in self.base_lrs] 187 | 188 | 189 | class ReduceLROnPlateau(object): 190 | """Reduce learning rate when a metric has stopped improving. 
191 | Models often benefit from reducing the learning rate by a factor 192 | of 2-10 once learning stagnates. This scheduler reads a metrics 193 | quantity and if no improvement is seen for a 'patience' number 194 | of epochs, the learning rate is reduced. 195 | Args: 196 | optimizer (Optimizer): Wrapped optimizer. 197 | mode (str): One of `min`, `max`. In `min` mode, lr will 198 | be reduced when the quantity monitored has stopped 199 | decreasing; in `max` mode it will be reduced when the 200 | quantity monitored has stopped increasing. Default: 'min'. 201 | factor (float): Factor by which the learning rate will be 202 | reduced. new_lr = lr * factor. Default: 0.1. 203 | patience (int): Number of epochs with no improvement after 204 | which learning rate will be reduced. Default: 10. 205 | verbose (bool): If True, prints a message to stdout for 206 | each update. Default: False. 207 | threshold (float): Threshold for measuring the new optimum, 208 | to only focus on significant changes. Default: 1e-4. 209 | threshold_mode (str): One of `rel`, `abs`. In `rel` mode, 210 | dynamic_threshold = best * ( 1 + threshold ) in 'max' 211 | mode or best * ( 1 - threshold ) in `min` mode. 212 | In `abs` mode, dynamic_threshold = best + threshold in 213 | `max` mode or best - threshold in `min` mode. Default: 'rel'. 214 | cooldown (int): Number of epochs to wait before resuming 215 | normal operation after lr has been reduced. Default: 0. 216 | min_lr (float or list): A scalar or a list of scalars. A 217 | lower bound on the learning rate of all param groups 218 | or each group respectively. Default: 0. 219 | eps (float): Minimal decay applied to lr. If the difference 220 | between new and old lr is smaller than eps, the update is 221 | ignored. Default: 1e-8. 222 | Example: 223 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 224 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min') 225 | >>> for epoch in range(10): 226 | >>> train(...) 
227 | >>> val_loss = validate(...) 228 | >>> # Note that step should be called after validate() 229 | >>> scheduler.step(val_loss) 230 | """ 231 | 232 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, 233 | verbose=False, threshold=1e-4, threshold_mode='rel', 234 | cooldown=0, min_lr=0, eps=1e-8): 235 | 236 | if factor >= 1.0: 237 | raise ValueError('Factor should be < 1.0.') 238 | self.factor = factor 239 | 240 | if not isinstance(optimizer, Optimizer): 241 | raise TypeError('{} is not an Optimizer'.format( 242 | type(optimizer).__name__)) 243 | self.optimizer = optimizer 244 | 245 | if isinstance(min_lr, list) or isinstance(min_lr, tuple): 246 | if len(min_lr) != len(optimizer.param_groups): 247 | raise ValueError("expected {} min_lrs, got {}".format( 248 | len(optimizer.param_groups), len(min_lr))) 249 | self.min_lrs = list(min_lr) 250 | else: 251 | self.min_lrs = [min_lr] * len(optimizer.param_groups) 252 | 253 | self.patience = patience 254 | self.verbose = verbose 255 | self.cooldown = cooldown 256 | self.cooldown_counter = 0 257 | self.mode = mode 258 | self.threshold = threshold 259 | self.threshold_mode = threshold_mode 260 | self.best = None 261 | self.num_bad_epochs = None 262 | self.mode_worse = None # the worse value for the chosen mode 263 | self.is_better = None 264 | self.eps = eps 265 | self.last_epoch = -1 266 | self._init_is_better(mode=mode, threshold=threshold, 267 | threshold_mode=threshold_mode) 268 | self._reset() 269 | 270 | def _reset(self): 271 | """Resets num_bad_epochs counter and cooldown counter.""" 272 | self.best = self.mode_worse 273 | self.cooldown_counter = 0 274 | self.num_bad_epochs = 0 275 | 276 | def step(self, metrics, epoch=None): 277 | current = metrics 278 | if epoch is None: 279 | epoch = self.last_epoch = self.last_epoch + 1 280 | self.last_epoch = epoch 281 | 282 | if self.is_better(current, self.best): 283 | self.best = current 284 | self.num_bad_epochs = 0 285 | else: 286 | self.num_bad_epochs += 1 
287 | 288 | if self.in_cooldown: 289 | self.cooldown_counter -= 1 290 | self.num_bad_epochs = 0 # ignore any bad epochs in cooldown 291 | 292 | if self.num_bad_epochs > self.patience: 293 | self._reduce_lr(epoch) 294 | self.cooldown_counter = self.cooldown 295 | self.num_bad_epochs = 0 296 | 297 | def _reduce_lr(self, epoch): 298 | for i, param_group in enumerate(self.optimizer.param_groups): 299 | old_lr = float(param_group['lr']) 300 | new_lr = max(old_lr * self.factor, self.min_lrs[i]) 301 | if old_lr - new_lr > self.eps: 302 | param_group['lr'] = new_lr 303 | if self.verbose: 304 | print('Epoch {:5d}: reducing learning rate' 305 | ' of group {} to {:.4e}.'.format(epoch, i, new_lr)) 306 | 307 | @property 308 | def in_cooldown(self): 309 | return self.cooldown_counter > 0 310 | 311 | def _init_is_better(self, mode, threshold, threshold_mode): 312 | if mode not in {'min', 'max'}: 313 | raise ValueError('mode ' + mode + ' is unknown!') 314 | if threshold_mode not in {'rel', 'abs'}: 315 | raise ValueError('threshold mode ' + mode + ' is unknown!') 316 | if mode == 'min' and threshold_mode == 'rel': 317 | rel_epsilon = 1. - threshold 318 | self.is_better = lambda a, best: a < best * rel_epsilon 319 | self.mode_worse = float('Inf') 320 | elif mode == 'min' and threshold_mode == 'abs': 321 | self.is_better = lambda a, best: a < best - threshold 322 | self.mode_worse = float('Inf') 323 | elif mode == 'max' and threshold_mode == 'rel': 324 | rel_epsilon = threshold + 1. 
325 | self.is_better = lambda a, best: a > best * rel_epsilon 326 | self.mode_worse = -float('Inf') 327 | else: # mode == 'max' and epsilon_mode == 'abs': 328 | self.is_better = lambda a, best: a > best + threshold 329 | self.mode_worse = -float('Inf') 330 | -------------------------------------------------------------------------------- /models/Readme: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/Schmidt_orth.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | import torch 3 | 4 | 5 | def schmidt(this_vec, vectors): 6 | # this_vector是[bs,hidden_emb] 7 | # vectors是个列表,每一个里面应该是this_vector这么大的东西 8 | if len(vectors) == 0: 9 | return this_vec 10 | else: 11 | for vec in vectors: 12 | assert len(vec.size()) == len(this_vec.size()) == 2 13 | dot = torch.bmm(this_vec.unsqueeze(1), vec.unsqueeze(-1)).squeeze(-1) 14 | norm = torch.bmm(vec.unsqueeze(1), vec.unsqueeze(-1)).squeeze(-1) 15 | frac = dot / norm # bs,1 16 | this_vec = this_vec - (frac * vec) 17 | # print 'final_vec:',this_vec 18 | return this_vec 19 | -------------------------------------------------------------------------------- /models/WaveLoss.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | import torch 3 | import torch.nn as nn 4 | import soundfile as sf 5 | import resampy 6 | # from prepare_data_wsj2 import linearspectrogram 7 | import librosa 8 | import numpy as np 9 | from models.istft_irfft import istft_irfft 10 | 11 | # 参考https://github.com/jonlu0602/DeepDenoisingAutoencoder/blob/master/python/utils.py 12 | def linearspectrogram(y, dBscale = 1, normalize=1): 13 | fft_size = 256 14 | hop_size = 128 15 | ref_db = 20 16 | max_db = 100 17 | D = librosa.core.spectrum.stft(y, fft_size, hop_size) # F, T 18 | F, T = D.shape 19 | S = np.abs(D) 20 | if dBscale: 21 | S = 
librosa.amplitude_to_db(S) 22 | if normalize: 23 | # normalization 24 | S = np.clip((S - ref_db + max_db) / max_db, 1e-8, 1) 25 | return S, np.angle(D) 26 | 27 | def concatenateFeature(inputList, dim): 28 | out = inputList[0] 29 | for i in range(1, len(inputList)): 30 | out = torch.cat((out, inputList[i]), dim=dim) 31 | return out 32 | 33 | class WaveLoss(nn.Module): 34 | def __init__(self, dBscale = 1, denormalize=1, max_db=100, ref_db=20, nfft=256, hop_size=128): 35 | super(WaveLoss, self).__init__() 36 | self.dBscale = dBscale 37 | self.denormalize = denormalize 38 | self.max_db = max_db 39 | self.ref_db = ref_db 40 | self.nfft = nfft 41 | self.hop_size = hop_size 42 | self.mse_loss = nn.MSELoss() 43 | 44 | def genWav(self, S, phase): 45 | ''' 46 | :param S: (B, F-1, T) to be padded with 0 in this function 47 | :param phase: (B, F, T) 48 | :return: (B, num_samples) 49 | ''' 50 | if self.dBscale: 51 | if self.denormalize: 52 | # denormalization 53 | S = S * self.max_db - self.max_db + self.ref_db 54 | # to amplitude 55 | # https://github.com/pytorch/pytorch/issues/12426 56 | # RuntimeError: the derivative for pow is not implemented 57 | # S = torch.pow(10, S * 0.05) 58 | S = 10 ** (S * 0.05) 59 | 60 | # pad with 0 61 | B, F, T = S.shape 62 | pad = torch.zeros(B, 1, T).to(S.device) 63 | # 注意tensor要同一类型 64 | Sfull = concatenateFeature([S, pad], dim=-2) # 由于前面预测少了一个维度的频率,所以这里补0 65 | 66 | # deal with the complex 67 | Sfull_ = Sfull.data.cpu().numpy() 68 | phase_ = phase.data.cpu().numpy() 69 | Sfull_spec = Sfull_ * np.exp(1.0j * phase_) 70 | S_sign = np.sign(np.real(Sfull_spec)) 71 | S_sign = torch.from_numpy(S_sign).to(S.device) 72 | Sfull_spec_imag = np.imag(Sfull_spec) 73 | Sfull_spec_imag = torch.from_numpy(Sfull_spec_imag).unsqueeze(-1).to(S.device) 74 | Sfull = torch.mul(Sfull, S_sign).unsqueeze(-1) 75 | # print(Sfull.shape) 76 | # print(Sfull_spec_imag.shape) 77 | stft_matrix = concatenateFeature([Sfull, Sfull_spec_imag], dim=-1) # (B, F, T, 2) 78 | # 
print(stft_matrix.shape) 79 | 80 | wav = istft_irfft(stft_matrix, hop_length=self.hop_size, win_length=self.nfft) 81 | return wav 82 | 83 | def forward(self, target_mag, target_phase, pred_mag, pred_phase): 84 | ''' 85 | :param target_mag: (B, F-1, T) 86 | :param target_phase: (B, F, T) 87 | :param pred_mag: (B, F-1, T) 88 | :param pred_phase: (B, F, T) 89 | :return: 90 | ''' 91 | target_wav = self.genWav(target_mag, target_phase) 92 | pred_wav = self.genWav(pred_mag, pred_phase) 93 | 94 | # target_wav_arr = target_wav.squeeze(0).cpu().data.numpy() 95 | # pred_wav_arr = pred_wav.squeeze(0).cpu().data.numpy() 96 | # print('target wav arr', target_wav_arr.shape) 97 | # sf.write('target.wav', target_wav_arr, 8000) 98 | # sf.write('pred.wav', pred_wav_arr, 8000) 99 | 100 | loss = self.mse_loss(target_wav, pred_wav) 101 | return loss 102 | 103 | if __name__ == '__main__': 104 | wav_f = 'test.wav' 105 | wav_fo = 'test_o.wav' 106 | def read_wav(f): 107 | wav, sr = sf.read(f) 108 | if len(wav.shape) > 1: 109 | wav = wav[:, 0] 110 | if sr != 8000: 111 | wav = resampy.resample(wav, sr, 8000) 112 | spec, phase = linearspectrogram(wav) 113 | return spec, phase, wav 114 | target_spec , target_phase, wav1 = read_wav(wav_f) 115 | pred_spec, pred_phase, wav2 = read_wav(wav_fo) 116 | # print('librosa,', pred_spec.shape) 117 | 118 | librosa_stft = librosa.stft(wav1, n_fft=256, hop_length=128, window='hann') 119 | # print('librosa stft', librosa_stft.shape) 120 | # print('librosa.stft', librosa_stft) 121 | _magnitude = np.abs(librosa_stft) 122 | # print('mag,', _magnitude.shape) 123 | wav_re_librosa = librosa.core.spectrum.istft(librosa_stft, hop_length=128) 124 | sf.write('wav_re_librosa.wav', wav_re_librosa, 8000) 125 | 126 | def clip_spec(spec, phase): 127 | spec_clip = spec[0:-1, :410] # (F-1, T) 128 | phase_clip = phase[:, :410] 129 | spec_tensor = torch.from_numpy(spec_clip) 130 | spec_tensor = spec_tensor.unsqueeze(0) # (B, T, F) 131 | # print(spec_tensor.shape) 132 | 
phase_tensor = torch.from_numpy(phase_clip).unsqueeze(0) 133 | # print(phase_tensor.shape) 134 | return spec_tensor, phase_tensor 135 | target_spec_tensor, target_phase_tensor = clip_spec(target_spec, target_phase) 136 | pred_spec_tensor, pred_phase_tensor = clip_spec(pred_spec, pred_phase) 137 | 138 | wav_loss = WaveLoss(dBscale=1, nfft=256, hop_size=128) 139 | loss = wav_loss(target_spec_tensor, target_phase_tensor, pred_spec_tensor, pred_phase_tensor) 140 | print('loss', loss.item()) 141 | 142 | wav1 = torch.FloatTensor(wav1) 143 | torch_stft_matrix = torch.stft(wav1, n_fft=256, hop_length=128, window=torch.hann_window(256)) 144 | torch_stft_matrix = torch_stft_matrix.unsqueeze(0) 145 | # print('torch stft', torch_stft_matrix.shape) 146 | # print(torch_stft_matrix[:,:,:,0]) 147 | # print(torch_stft_matrix[:,:,:,1]) 148 | wav_re = istft_irfft(torch_stft_matrix, hop_length=128, win_length=256) 149 | wav_re = wav_re.squeeze(0).cpu().data.numpy() 150 | # print('wav_re', wav_re.shape) 151 | sf.write('wav_re.wav', wav_re, 8000) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.attention import * 2 | from models.rnn import * 3 | from models.seq2seq import * 4 | from models.separation_dis import * 5 | from models.separation_tasnet import * 6 | from models.loss import * 7 | from models.beam import * 8 | from models.Schmidt_orth import * 9 | from models.metrics import * 10 | from models.focal_loss import * 11 | from models.WaveLoss import * 12 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | from torch.nn.utils.rnn import pack_padded_sequence as pack 6 | from torch.nn.utils.rnn import 
pad_packed_sequence as unpack 7 | import data.dict as dict 8 | 9 | 10 | class global_attention(nn.Module): 11 | 12 | def __init__(self, hidden_size, activation=None): 13 | super(global_attention, self).__init__() 14 | self.linear_in = nn.Linear(hidden_size, hidden_size) 15 | self.linear_out = nn.Linear(2 * hidden_size, hidden_size,bias=1) 16 | self.softmax = nn.Softmax() 17 | # self.batchnorm=nn.BatchNorm1d(hidden_size) 18 | self.tanh = nn.Tanh() 19 | self.activation = activation 20 | 21 | def forward(self, x, context): 22 | gamma_h = self.linear_in(x).unsqueeze(2) # unsequeee这个函数相当于直接reshape多出来一维度,值得学习。 # batch * size * 1 23 | if self.activation == 'tanh': 24 | gamma_h = self.tanh(gamma_h) 25 | weights = torch.bmm(context, gamma_h).squeeze(2) # batch * time 26 | weights = self.softmax(weights) # batch * time 27 | c_t = torch.bmm(weights.unsqueeze(1), context).squeeze(1) # batch * size 28 | output = self.linear_out(torch.cat([c_t, x], 1)) #添加额外的batchnorm 29 | # output = self.batchnorm(output) 30 | output = self.tanh(output) 31 | return output, weights 32 | -------------------------------------------------------------------------------- /models/beam.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | import torch 3 | 4 | 5 | # import data.dict_spk2idx[as dict 6 | 7 | class Beam(object): 8 | def __init__(self, size, dict_spk2idx, n_best=1, cuda=True): 9 | self.dict_spk2idx = dict_spk2idx 10 | self.size = size 11 | self.tt = torch.cuda if cuda else torch 12 | 13 | # The score for each translation on the beam. 14 | self.scores = self.tt.FloatTensor(size).zero_() 15 | self.allScores = [] 16 | 17 | # The backpointers at each time-step. 18 | self.prevKs = [] 19 | 20 | # The outputs at each time-step. 21 | self.nextYs = [self.tt.LongTensor(size) 22 | .fill_(dict_spk2idx[''])] 23 | self.nextYs[0][0] = dict_spk2idx[''] 24 | # Has EOS topped the beam yet. 
25 | self._eos = dict_spk2idx[''] 26 | self.eosTop = False 27 | 28 | # The attentions (matrix) for each time. 29 | self.attn = [] 30 | 31 | # The last hiddens(matrix) for each time. 32 | self.hiddens = [] 33 | 34 | # The last hiddens(matrix) for each time. 35 | self.sch_hiddens = [] 36 | 37 | # The last embs(matrix) for each time. 38 | self.embs = [] 39 | 40 | # Time and k pair for finished. 41 | self.finished = [] 42 | self.n_best = n_best 43 | 44 | def updates_sch_embeddings(self, hiddens_this_step): 45 | if len(self.sch_hiddens) == 0: 46 | self.sch_hiddens.append([hiddens_this_step]) 47 | else: 48 | self.sch_hiddens.append(self.sch_hiddens[-1] + [hiddens_this_step]) 49 | 50 | def getCurrentState(self): 51 | "Get the outputs for the current timestep." 52 | return self.nextYs[-1] 53 | 54 | def getCurrentOrigin(self): 55 | "Get the backpointers for the current timestep." 56 | return self.prevKs[-1] 57 | 58 | def advance(self, wordLk, attnOut, hidden, emb): 59 | """ 60 | Given prob over words for every last beam `wordLk` and attention 61 | `attnOut`: Compute and update the beam search. 62 | Parameters: 63 | * `wordLk`- probs of advancing from the last step (K x words) 64 | * `attnOut`- attention at the last step 65 | Returns: True if beam search is complete. 66 | """ 67 | numWords = wordLk.size(1) 68 | 69 | # Sum the previous scores. 70 | if len(self.prevKs) > 0: 71 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk) 72 | # Don't let EOS have children. 
73 | for i in range(self.nextYs[-1].size(0)): 74 | if self.nextYs[-1][i] == self._eos: 75 | beamLk[i] = -1e20 76 | else: 77 | beamLk = wordLk[0] 78 | flatBeamLk = beamLk.view(-1) 79 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True) 80 | 81 | self.allScores.append(self.scores) 82 | self.scores = bestScores 83 | 84 | # bestScoresId is flattened beam x word array, so calculate which 85 | # word and beam each score came from 86 | prevK = bestScoresId / numWords 87 | self.prevKs.append(prevK) 88 | self.nextYs.append((bestScoresId - prevK * numWords)) 89 | self.attn.append(attnOut.index_select(0, prevK)) 90 | self.hiddens.append(hidden.index_select(0, prevK)) 91 | self.updates_sch_embeddings(hidden.index_select(0, prevK)) 92 | self.embs.append(emb.index_select(0, prevK)) 93 | 94 | for i in range(self.nextYs[-1].size(0)): 95 | if self.nextYs[-1][i] == self._eos: 96 | s = self.scores[i] 97 | self.finished.append((s, len(self.nextYs) - 1, i)) 98 | 99 | # End condition is when top-of-beam is '' and no global score. 
100 | if self.nextYs[-1][0] == self.dict_spk2idx['']: 101 | # self.allScores.append(self.scores) 102 | self.eosTop = True 103 | 104 | def done(self): 105 | return self.eosTop and len(self.finished) >= self.n_best 106 | 107 | def beam_update(self, state, idx): 108 | positions = self.getCurrentOrigin() 109 | for e in state: 110 | a, br, d = e.size() 111 | e = e.view(a, self.size, br // self.size, d) 112 | sentStates = e[:, :, idx] 113 | sentStates.data.copy_(sentStates.data.index_select(1, positions)) 114 | 115 | def beam_update_context(self, state, idx): 116 | positions = self.getCurrentOrigin() 117 | e = state.unsqueeze(0) 118 | a, br, len, d, = e.size() 119 | e = e.view(a, self.size, br // self.size, len, d) 120 | sentStates = e[:, :, idx] 121 | sentStates.data.copy_(sentStates.data.index_select(1, positions)) 122 | 123 | def beam_update_hidden(self, state, idx): 124 | positions = self.getCurrentOrigin() 125 | e = state 126 | a, br, d = e.size() 127 | e = e.view(a, self.size, br // self.size, d) 128 | sentStates = e[:, :, idx] 129 | sentStates.data.copy_(sentStates.data.index_select(1, positions)) 130 | 131 | def sortFinished(self, minimum=None): 132 | if minimum is not None: 133 | i = 0 134 | # Add from beam until we have minimum outputs. 135 | while len(self.finished) < minimum: 136 | s = self.scores[i] 137 | self.finished.append((s, len(self.nextYs) - 1, i)) 138 | 139 | self.finished.sort(key=lambda a: -a[0]) 140 | scores = [sc for sc, _, _ in self.finished] 141 | ks = [(t, k) for _, t, k in self.finished] 142 | return scores, ks 143 | 144 | def getHyp(self, timestep, k): 145 | """ 146 | Walk back to construct the full hypothesis. 
147 | """ 148 | hyp, attn, hidden, emb = [], [], [], [] 149 | for j in range(len(self.prevKs[:timestep]) - 1, -1, -1): 150 | hyp.append(self.nextYs[j + 1][k]) 151 | attn.append(self.attn[j][k]) 152 | hidden.append(self.hiddens[j][k]) 153 | emb.append(self.embs[j][k]) 154 | k = self.prevKs[j][k] 155 | return hyp[::-1], torch.stack(attn[::-1]), torch.stack(hidden[::-1]), torch.stack(emb[::-1]) 156 | -------------------------------------------------------------------------------- /models/focal_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class FocalLoss(nn.Module): 7 | 8 | def __init__(self, gamma=0, eps=1e-7): 9 | super(FocalLoss, self).__init__() 10 | self.gamma = gamma 11 | self.eps = eps 12 | self.ce = torch.nn.CrossEntropyLoss() 13 | 14 | def forward(self, input, target): 15 | logp = self.ce(input, target) 16 | p = torch.exp(-logp) 17 | loss = (1 - p) ** self.gamma * logp 18 | return loss.mean() -------------------------------------------------------------------------------- /models/istft_irfft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import librosa 3 | 4 | # this is Keunwoo Choi's implementation of istft. 5 | # https://gist.github.com/keunwoochoi/2f349e72cc941f6f10d4adf9b0d3f37e#file-istft-torch-py 6 | def istft_irfft(stft_matrix, length=None, hop_length=None, win_length=None, window='hann', 7 | center=True, normalized=False, onesided=True): 8 | """stft_matrix = (batch, freq, time, complex) 9 | 10 | All based on librosa 11 | - http://librosa.github.io/librosa/_modules/librosa/core/spectrum.html#istft 12 | What's missing? 13 | - normalize by sum of squared window --> do we need it here? 14 | Actually the result is ok by simply dividing y by 2. 
15 | """ 16 | assert normalized == False 17 | assert onesided == True 18 | assert window == "hann" 19 | assert center == True 20 | 21 | device = stft_matrix.device 22 | n_fft = 2 * (stft_matrix.shape[-3] - 1) 23 | 24 | batch = stft_matrix.shape[0] 25 | 26 | # By default, use the entire frame 27 | if win_length is None: 28 | win_length = n_fft 29 | 30 | if hop_length is None: 31 | hop_length = int(win_length // 4) 32 | 33 | istft_window = torch.hann_window(n_fft).to(device).view(1, -1) # (batch, freq) 34 | 35 | n_frames = stft_matrix.shape[-2] 36 | expected_signal_len = n_fft + hop_length * (n_frames - 1) 37 | 38 | y = torch.zeros(batch, expected_signal_len, device=device) 39 | for i in range(n_frames): 40 | sample = i * hop_length 41 | spec = stft_matrix[:, :, i] 42 | iffted = torch.irfft(spec, signal_ndim=1, signal_sizes=(win_length,)) 43 | 44 | ytmp = istft_window * iffted 45 | y[:, sample:(sample+n_fft)] += ytmp 46 | 47 | y = y[:, n_fft//2:] 48 | 49 | if length is not None: 50 | if y.shape[1] > length: 51 | y = y[:, :length] 52 | elif y.shape[1] < length: 53 | y = torch.cat(y[:, :length], torch.zeros(y.shape[0], length - y.shape[1], device=y.device)) 54 | coeff = n_fft/float(hop_length) / 2.0 # -> this might go wrong if curretnly asserted values (especially, `normalized`) changes. 
55 | return y / coeff 56 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import data.dict as dict 6 | from torch.autograd import Variable 7 | import torch.nn.functional as F 8 | import models.focal_loss as focal_loss 9 | from itertools import permutations 10 | 11 | 12 | EPS = 1e-8 13 | def rank_feas(raw_tgt, feas_list, out_type='torch'): 14 | final_num = [] 15 | for each_feas, each_line in zip(feas_list, raw_tgt): 16 | for spk in each_line: 17 | final_num.append(each_feas[spk]) 18 | # 目标就是这个batch里一共有多少条比如 1spk 3spk 2spk,最后就是6个spk的特征 19 | if out_type=='numpy': 20 | return np.array(final_num) 21 | else: 22 | return torch.from_numpy(np.array(final_num)) 23 | 24 | 25 | def criterion(tgt_vocab_size, use_cuda, loss): 26 | weight = torch.ones(tgt_vocab_size) 27 | weight[dict.PAD] = 0 28 | if loss=='focal_loss': 29 | crit = focal_loss.FocalLoss(gamma=2) 30 | else: 31 | crit = nn.CrossEntropyLoss(weight, size_average=False) 32 | if use_cuda: 33 | crit.cuda() 34 | return crit 35 | 36 | def criterion_dir(tgt_vocab_size, use_cuda, loss): 37 | weight = torch.ones(tgt_vocab_size) 38 | weight[dict.PAD] = 0 39 | if loss=='focal_loss': 40 | crit = focal_loss.FocalLoss(gamma=2) 41 | else: 42 | crit = nn.CrossEntropyLoss(weight, size_average=False) 43 | if use_cuda: 44 | crit.cuda() 45 | return crit 46 | 47 | 48 | def memory_efficiency_cross_entropy_loss(hidden_outputs, decoder, targets, criterion, config): 49 | outputs = Variable(hidden_outputs.data, requires_grad=True, volatile=False) 50 | num_total, num_correct, loss = 0, 0, 0 51 | 52 | outputs_split = torch.split(outputs, config.max_generator_batches) 53 | targets_split = torch.split(targets, config.max_generator_batches) 54 | for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)): 55 | out_t = 
out_t.view(-1, out_t.size(2)) 56 | scores_t = decoder.compute_score(out_t) 57 | loss_t = criterion(scores_t, targ_t.view(-1)) 58 | pred_t = scores_t.max(1)[1] 59 | num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(dict.PAD).data).sum() 60 | num_total_t = targ_t.ne(dict.PAD).data.sum() 61 | num_correct += num_correct_t 62 | num_total += num_total_t 63 | loss += loss_t.data[0] 64 | loss_t.div(num_total_t).backward() 65 | 66 | grad_output = outputs.grad.data 67 | hidden_outputs.backward(grad_output) 68 | 69 | return loss, num_total, num_correct, config.tgt_vocab, config.tgt_vocab 70 | 71 | 72 | def cross_entropy_loss(hidden_outputs, decoder, targets, criterion, config, sim_score=0): 73 | # hidden_outputs:[max_len,bs,512] 74 | batch_size= targets.size()[1] 75 | targets=targets.view(-1) 76 | outputs = hidden_outputs.view(-1, hidden_outputs.size(2)) 77 | scores = decoder.compute_score(outputs) 78 | loss = criterion(scores, targets.view(-1)) + sim_score 79 | pred = scores.max(1)[1] 80 | num_correct = pred.data.eq(targets.data).masked_select(targets.ne(dict.PAD).data).sum() 81 | # num_correct = pred.data.eq(targets.data).masked_select(targets.ne(targets[-1]).data).sum() 82 | num_total = float(targets.ne(dict.PAD).data.sum()) 83 | loss *= batch_size 84 | loss = loss.div(num_total) 85 | # loss = loss.data[0] 86 | 87 | return loss, num_total, num_correct 88 | 89 | def cross_entropy_loss_dir(hidden_outputs, decoder, targets, criterion, config, sim_score=0): 90 | batch_size= targets.size()[1] 91 | targets=targets.view(-1) 92 | outputs = hidden_outputs.view(-1, hidden_outputs.size(2)) 93 | scores = decoder.compute_score_dir(outputs) 94 | #print("scores:",scores.size()) 95 | #print("targets:",targets.view(-1)) 96 | loss = criterion(scores, targets.view(-1)) + sim_score 97 | pred = scores.max(1)[1] 98 | num_correct = pred.data.eq(targets.data).masked_select(targets.ne(dict.PAD).data).sum() 99 | num_total = float(targets.ne(dict.PAD).data.sum()) 100 | loss *= 
batch_size 101 | loss = loss.div(num_total) 102 | # loss = loss.data[0] 103 | 104 | return loss, num_total, num_correct 105 | 106 | def mmse_loss(hidden_outputs, decoder, targets, mse_loss, softmax): 107 | outputs = hidden_outputs.view(-1, hidden_outputs.size(2)) 108 | scores = softmax(decoder.compute_score_dir(outputs)) 109 | targets = targets.view(-1) 110 | target_one_hot = torch.zeros(scores.size(0), dir_vocab_size).scatter_(1, targets, 1) 111 | #scores = linear(scores) 112 | print("score.size",scores) 113 | print("targets.size",targets) 114 | loss = mse_loss(scores.float(), targets.float()) 115 | 116 | return loss 117 | 118 | def mmse_loss2(hidden_outputs, decoder, targets, mse_loss): 119 | print("hidden_outputs", hidden_outputs.size()) 120 | outputs = hidden_outputs.view(-1, hidden_outputs.size(2)) 121 | scores_1, scores = decoder.compute_score_dir(outputs) 122 | scores = F.sigmoid(scores) 123 | print("score.size",scores) 124 | print("targets.size",targets) 125 | loss = mse_loss(scores.view(-1).float(), targets.view(-1).float()/20) 126 | 127 | return loss 128 | 129 | def ss_loss(config, x_input_map_multi, multi_mask, y_multi_map, loss_multi_func,wav_loss): 130 | predict_multi_map = multi_mask * x_input_map_multi 131 | # predict_multi_map=Variable(y_multi_map) 132 | y_multi_map = Variable(y_multi_map) 133 | 134 | loss_multi_speech = loss_multi_func(predict_multi_map, y_multi_map) 135 | 136 | # 各通道和为1的loss部分,应该可以更多的带来差异 137 | # y_sum_map=Variable(torch.ones(config.batch_size,config.mix_speech_len,config.speech_fre)).cuda() 138 | # predict_sum_map=torch.sum(multi_mask,1) 139 | # loss_multi_sum_speech=loss_multi_func(predict_sum_map,y_sum_map) 140 | print('loss 1 eval: ', loss_multi_speech.data.cpu().numpy()) 141 | # print('losssum eval :',loss_multi_sum_speech.data.cpu().numpy() 142 | # loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech 143 | print('evaling multi-abs norm this eval batch:', torch.abs(y_multi_map - 
predict_multi_map).norm().data.cpu().numpy()) 144 | # loss_multi_speech=loss_multi_speech+3*loss_multi_sum_speech 145 | print('loss for whole separation part:', loss_multi_speech.data.cpu().numpy()) 146 | return loss_multi_speech 147 | 148 | def ss_tas_loss(config,predict_wav, y_multi_wav, mix_length,loss_multi_func): 149 | loss = cal_loss_with_order(y_multi_wav, predict_wav, mix_length)[0] 150 | #loss_mse = loss_multi_func(predict_wav, y_multi_wav) 151 | return loss 152 | 153 | def cal_loss_with_order(source, estimate_source, source_lengths): 154 | """ 155 | Args: 156 | source: [B, C, T], B is batch size 157 | estimate_source: [B, C, T] 158 | source_lengths: [B] 159 | """ 160 | # print('real Tas SNI:',source[:,:,16000:16005]) 161 | # print('pre Tas SNI:',estimate_source[:,:,16000:16005]) 162 | max_snr = cal_si_snr_with_order(source, estimate_source, source_lengths) 163 | loss = 0 - torch.mean(max_snr) 164 | return loss, 165 | 166 | def cal_loss_with_PIT(source, estimate_source, source_lengths): 167 | """ 168 | Args: 169 | source: [B, C, T], B is batch size 170 | estimate_source: [B, C, T] 171 | source_lengths: [B] 172 | """ 173 | max_snr, perms, max_snr_idx = cal_si_snr_with_pit(source, 174 | estimate_source, 175 | source_lengths) 176 | loss = 0 - torch.mean(max_snr) 177 | reorder_estimate_source = reorder_source(estimate_source, perms, max_snr_idx) 178 | return loss, max_snr, estimate_source, reorder_estimate_source 179 | 180 | def cal_si_snr_with_order(source, estimate_source, source_lengths): 181 | """Calculate SI-SNR with given order. 
182 | Args: 183 | source: [B, C, T], B is batch size 184 | estimate_source: [B, C, T] 185 | source_lengths: [B], each item is between [0, T] 186 | """ 187 | print("source.size()",source.size()) 188 | print("estimate_source.size()",estimate_source.size()) 189 | assert source.size() == estimate_source.size() 190 | B, C, T = source.size() 191 | # mask padding position along T 192 | mask = get_mask(source, source_lengths) 193 | estimate_source *= mask 194 | 195 | # Step 1. Zero-mean norm 196 | num_samples = source_lengths.view(-1, 1, 1).float() # [B, 1, 1] 197 | mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples 198 | mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples 199 | zero_mean_target = source - mean_target 200 | zero_mean_estimate = estimate_source - mean_estimate 201 | # mask padding position along T 202 | zero_mean_target *= mask 203 | zero_mean_estimate *= mask 204 | 205 | # Step 2. SI-SNR with order 206 | # reshape to use broadcast 207 | s_target = zero_mean_target # [B, C, T] 208 | s_estimate = zero_mean_estimate # [B, C, T] 209 | # s_target = s / ||s||^2 210 | pair_wise_dot = torch.sum(s_estimate * s_target, dim=2, keepdim=True) # [B, C, 1] 211 | s_target_energy = torch.sum(s_target ** 2, dim=2, keepdim=True) + EPS # [B, C, 1] 212 | pair_wise_proj = pair_wise_dot * s_target / s_target_energy # [B, C, T] 213 | # e_noise = s' - s_target 214 | e_noise = s_estimate - pair_wise_proj # [B, C, T] 215 | # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2) 216 | pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=2) / (torch.sum(e_noise ** 2, dim=2) + EPS) 217 | pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS) # [B, C] 218 | print(pair_wise_si_snr) 219 | 220 | return torch.sum(pair_wise_si_snr,dim=1)/C 221 | 222 | def cal_si_snr_with_pit(source, estimate_source, source_lengths): 223 | """Calculate SI-SNR with PIT training. 
224 | Args: 225 | source: [B, C, T], B is batch size 226 | estimate_source: [B, C, T] 227 | source_lengths: [B], each item is between [0, T] 228 | """ 229 | assert source.size() == estimate_source.size() 230 | B, C, T = source.size() 231 | # mask padding position along T 232 | mask = get_mask(source, source_lengths) 233 | estimate_source *= mask 234 | 235 | # Step 1. Zero-mean norm 236 | num_samples = source_lengths.view(-1, 1, 1).float() # [B, 1, 1] 237 | mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples 238 | mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples 239 | zero_mean_target = source - mean_target 240 | zero_mean_estimate = estimate_source - mean_estimate 241 | # mask padding position along T 242 | zero_mean_target *= mask 243 | zero_mean_estimate *= mask 244 | 245 | # Step 2. SI-SNR with PIT 246 | # reshape to use broadcast 247 | s_target = torch.unsqueeze(zero_mean_target, dim=1) # [B, 1, C, T] 248 | s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2) # [B, C, 1, T] 249 | # s_target = s / ||s||^2 250 | pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True) # [B, C, C, 1] 251 | s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + EPS # [B, 1, C, 1] 252 | pair_wise_proj = pair_wise_dot * s_target / s_target_energy # [B, C, C, T] 253 | # e_noise = s' - s_target 254 | e_noise = s_estimate - pair_wise_proj # [B, C, C, T] 255 | # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2) 256 | pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS) 257 | pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS) # [B, C, C] 258 | 259 | # Get max_snr of each utterance 260 | # permutations, [C!, C] 261 | perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long) 262 | # one-hot, [C!, C, C] 263 | index = torch.unsqueeze(perms, 2) 264 | # perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1) 265 | perms_one_hot = 
source.new_zeros((perms.size()[0],perms.size()[1], C)).scatter_(2, index, 1) 266 | # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation 267 | snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot]) 268 | max_snr_idx = torch.argmax(snr_set, dim=1) # [B] 269 | # max_snr = torch.gather(snr_set, 1, max_snr_idx.view(-1, 1)) # [B, 1] 270 | max_snr, _ = torch.max(snr_set, dim=1, keepdim=True) 271 | max_snr /= C 272 | return max_snr, perms, max_snr_idx 273 | 274 | def reorder_source(source, perms, max_snr_idx): 275 | """ 276 | Args: 277 | source: [B, C, T] 278 | perms: [C!, C], permutations 279 | max_snr_idx: [B], each item is between [0, C!) 280 | Returns: 281 | reorder_source: [B, C, T] 282 | """ 283 | # B, C, *_ = source.size() 284 | B, C, __ = source.size() 285 | # [B, C], permutation whose SI-SNR is max of each utterance 286 | # for each utterance, reorder estimate source according this permutation 287 | max_snr_perm = torch.index_select(perms, dim=0, index=max_snr_idx) 288 | # print('max_snr_perm', max_snr_perm) 289 | # maybe use torch.gather()/index_select()/scatter() to impl this? 
290 | reorder_source = torch.zeros_like(source) 291 | for b in range(B): 292 | for c in range(C): 293 | reorder_source[b, c] = source[b, max_snr_perm[b][c]] 294 | return reorder_source 295 | 296 | 297 | def get_mask(source, source_lengths): 298 | """ 299 | Args: 300 | source: [B, C, T] 301 | source_lengths: [B] 302 | Returns: 303 | mask: [B, 1, T] 304 | """ 305 | B, _, T = source.size() 306 | mask = source.new_ones((B, 1, T)) 307 | #mask = Variable(torch.ones((B, 1, T))).cuda() 308 | for i in range(B): 309 | #print("source_lengths[i]",source_lengths[i]) 310 | mask[i, :, source_lengths[i]:] = 0 311 | return mask 312 | def ss_loss_MLMSE(config, x_input_map_multi, multi_mask, y_multi_map, loss_multi_func, Var): 313 | try: 314 | if Var == None: 315 | Var = Variable(torch.eye(config.speech_fre, config.speech_fre).cuda(), requires_grad=0) # 初始化的是单位矩阵 316 | print('Set Var to:', Var) 317 | except: 318 | pass 319 | assert Var.size() == (config.speech_fre, config.speech_fre) 320 | 321 | predict_multi_map = torch.mean(multi_mask * x_input_map_multi, -2) # 在时间维度上平均 322 | # predict_multi_map=Variable(y_multi_map) 323 | y_multi_map = torch.mean(Variable(y_multi_map), -2) # 在时间维度上平均 324 | 325 | loss_vector = (y_multi_map - predict_multi_map).view(-1, config.speech_fre).unsqueeze(1) # 应该是bs*1*fre 326 | 327 | Var_inverse = torch.inverse(Var) 328 | Var_inverse = Var_inverse.unsqueeze(0).expand(loss_vector.size()[0], config.speech_fre, 329 | config.speech_fre) # 扩展成batch的形式 330 | loss_multi_speech = torch.bmm(torch.bmm(loss_vector, Var_inverse), loss_vector.transpose(1, 2)) 331 | loss_multi_speech = torch.mean(loss_multi_speech, 0) 332 | 333 | # 各通道和为1的loss部分,应该可以更多的带来差异 334 | y_sum_map = Variable(torch.ones(config.batch_size, config.mix_speech_len, config.speech_fre)).cuda() 335 | predict_sum_map = torch.sum(multi_mask, 1) 336 | loss_multi_sum_speech = loss_multi_func(predict_sum_map, y_sum_map) 337 | print('loss 1 eval, losssum eval : ', loss_multi_speech.data.cpu().numpy(), 
loss_multi_sum_speech.data.cpu().numpy()) 338 | # loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech 339 | print('evaling multi-abs norm this eval batch:', torch.abs(y_multi_map - predict_multi_map).norm().data.cpu().numpy()) 340 | # loss_multi_speech=loss_multi_speech+3*loss_multi_sum_speech 341 | print('loss for whole separation part:', loss_multi_speech.data.cpu().numpy()) 342 | # return F.relu(loss_multi_speech) 343 | return loss_multi_speech 344 | 345 | 346 | def dis_loss(config, top_k_num, dis_model, x_input_map_multi, multi_mask, y_multi_map, loss_multi_func): 347 | predict_multi_map = multi_mask * x_input_map_multi 348 | y_multi_map = Variable(y_multi_map).cuda() 349 | score_true = dis_model(y_multi_map) 350 | score_false = dis_model(predict_multi_map) 351 | acc_true = torch.sum(score_true > 0.5).data.cpu().numpy() / float(score_true.size()[0]) 352 | acc_false = torch.sum(score_false < 0.5).data.cpu().numpy() / float(score_true.size()[0]) 353 | acc_dis = (acc_false + acc_true) / 2 354 | print('acc for dis:(ture,false,aver)', acc_true, acc_false, acc_dis) 355 | 356 | loss_dis_true = loss_multi_func(score_true, Variable(torch.ones(config.batch_size * top_k_num, 1)).cuda()) 357 | loss_dis_false = loss_multi_func(score_false, Variable(torch.zeros(config.batch_size * top_k_num, 1)).cuda()) 358 | loss_dis = loss_dis_true + loss_dis_false 359 | print('loss for dis:(ture,false)', loss_dis_true.data.cpu().numpy(), loss_dis_false.data.cpu().numpy()) 360 | return loss_dis 361 | -------------------------------------------------------------------------------- /models/metrics.py: -------------------------------------------------------------------------------- 1 | #coding=utf8 2 | from __future__ import print_function 3 | from __future__ import division 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn import Parameter 8 | import math 9 | 10 | 11 | class ArcMarginProduct(nn.Module): 12 | r"""Implement of large 
margin arc distance: : 13 | Args: 14 | in_features: size of each input sample 15 | out_features: size of each output sample 16 | s: norm of input feature 17 | m: margin 18 | 19 | cos(theta + m) 20 | """ 21 | def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False): 22 | super(ArcMarginProduct, self).__init__() 23 | self.in_features = in_features 24 | self.out_features = out_features 25 | self.s = s 26 | self.m = m 27 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 28 | nn.init.xavier_uniform(self.weight) 29 | 30 | self.easy_margin = easy_margin 31 | self.cos_m = math.cos(m) 32 | self.sin_m = math.sin(m) 33 | self.th = math.cos(math.pi - m) 34 | self.mm = math.sin(math.pi - m) * m 35 | 36 | def forward(self, input, label): 37 | # --------------------------- cos(theta) & phi(theta) --------------------------- 38 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 39 | if label is None: #如果没给label,则是要测试阶段,直接输出cosine的就可以了 40 | return cosine 41 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) 42 | phi = cosine * self.cos_m - sine * self.sin_m 43 | if self.easy_margin: 44 | phi = torch.where(cosine > 0, phi, cosine) 45 | else: 46 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 47 | # --------------------------- convert label to one-hot --------------------------- 48 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 49 | one_hot = torch.zeros(cosine.size(), device='cuda') 50 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 51 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 52 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 53 | output *= self.s 54 | # print(output) 55 | 56 | return output 57 | 58 | 59 | class AddMarginProduct(nn.Module): 60 | r"""Implement of large margin cosine distance: : 61 | Args: 62 | in_features: size of each input sample 63 | 
out_features: size of each output sample 64 | s: norm of input feature 65 | m: margin 66 | cos(theta) - m 67 | """ 68 | 69 | def __init__(self, in_features, out_features, s=30.0, m=0.40): 70 | super(AddMarginProduct, self).__init__() 71 | self.in_features = in_features 72 | self.out_features = out_features 73 | self.s = s 74 | self.m = m 75 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 76 | nn.init.xavier_uniform_(self.weight) 77 | 78 | def forward(self, input, label): 79 | # --------------------------- cos(theta) & phi(theta) --------------------------- 80 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 81 | phi = cosine - self.m 82 | # --------------------------- convert label to one-hot --------------------------- 83 | one_hot = torch.zeros(cosine.size(), device='cuda') 84 | # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot 85 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 86 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 87 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 88 | output *= self.s 89 | # print(output) 90 | 91 | return output 92 | 93 | def __repr__(self): 94 | return self.__class__.__name__ + '(' \ 95 | + 'in_features=' + str(self.in_features) \ 96 | + ', out_features=' + str(self.out_features) \ 97 | + ', s=' + str(self.s) \ 98 | + ', m=' + str(self.m) + ')' 99 | 100 | 101 | class SphereProduct(nn.Module): 102 | r"""Implement of large margin cosine distance: : 103 | Args: 104 | in_features: size of each input sample 105 | out_features: size of each output sample 106 | m: margin 107 | cos(m*theta) 108 | """ 109 | def __init__(self, in_features, out_features, m=4): 110 | super(SphereProduct, self).__init__() 111 | self.in_features = in_features 112 | self.out_features = out_features 113 | self.m = m 114 | self.base = 1000.0 115 | self.gamma = 0.12 116 | self.power = 1 117 | self.LambdaMin = 5.0 
118 | self.iter = 0 119 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 120 | nn.init.xavier_uniform(self.weight) 121 | 122 | # duplication formula 123 | self.mlambda = [ 124 | lambda x: x ** 0, 125 | lambda x: x ** 1, 126 | lambda x: 2 * x ** 2 - 1, 127 | lambda x: 4 * x ** 3 - 3 * x, 128 | lambda x: 8 * x ** 4 - 8 * x ** 2 + 1, 129 | lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x 130 | ] 131 | 132 | def forward(self, input, label): 133 | # lambda = max(lambda_min,base*(1+gamma*iteration)^(-power)) 134 | self.iter += 1 135 | self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power)) 136 | 137 | # --------------------------- cos(theta) & phi(theta) --------------------------- 138 | cos_theta = F.linear(F.normalize(input), F.normalize(self.weight)) 139 | cos_theta = cos_theta.clamp(-1, 1) 140 | cos_m_theta = self.mlambda[self.m](cos_theta) 141 | theta = cos_theta.data.acos() 142 | k = (self.m * theta / 3.14159265).floor() 143 | phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k 144 | NormOfFeature = torch.norm(input, 2, 1) 145 | 146 | # --------------------------- convert label to one-hot --------------------------- 147 | one_hot = torch.zeros(cos_theta.size()) 148 | one_hot = one_hot.cuda() if cos_theta.is_cuda else one_hot 149 | one_hot.scatter_(1, label.view(-1, 1), 1) 150 | 151 | # --------------------------- Calculate output --------------------------- 152 | output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta 153 | output *= NormOfFeature.view(-1, 1) 154 | 155 | return output 156 | 157 | def __repr__(self): 158 | return self.__class__.__name__ + '(' \ 159 | + 'in_features=' + str(self.in_features) \ 160 | + ', out_features=' + str(self.out_features) \ 161 | + ', m=' + str(self.m) + ')' -------------------------------------------------------------------------------- /models/rnn.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | 
import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from torch.nn.utils.rnn import pack_padded_sequence as pack 7 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 8 | import data.dict as dict 9 | import models 10 | 11 | import numpy as np 12 | 13 | 14 | class StackedLSTM(nn.Module): 15 | def __init__(self, num_layers, input_size, hidden_size, dropout): 16 | super(StackedLSTM, self).__init__() 17 | self.dropout = nn.Dropout(dropout) 18 | self.num_layers = num_layers 19 | self.layers = nn.ModuleList() 20 | 21 | for i in range(num_layers): 22 | self.layers.append(nn.LSTMCell(input_size, hidden_size)) 23 | input_size = hidden_size 24 | 25 | def forward(self, input, hidden): 26 | h_0, c_0 = hidden 27 | h_1, c_1 = [], [] 28 | for i, layer in enumerate(self.layers): 29 | h_1_i, c_1_i = layer(input, (h_0[i], c_0[i])) 30 | input = h_1_i 31 | if i + 1 != self.num_layers: 32 | input = self.dropout(input) 33 | h_1 += [h_1_i] 34 | c_1 += [c_1_i] 35 | 36 | h_1 = torch.stack(h_1) # 把多层的LSTMCell模型的输出给组织起来了,的到了[num_layers,batch_size,hidden_size]的东西 37 | c_1 = torch.stack(c_1) 38 | 39 | return input, (h_1, c_1) 40 | 41 | 42 | class rnn_encoder(nn.Module): 43 | 44 | def __init__(self, config, input_emb_size): 45 | super(rnn_encoder, self).__init__() 46 | self.rnn = nn.LSTM(input_size=input_emb_size, hidden_size=config.encoder_hidden_size, 47 | num_layers=config.num_layers, dropout=config.dropout, bidirectional=config.bidirec) 48 | self.config = config 49 | 50 | def forward(self, input, lengths): 51 | input = input.transpose(0, 1) 52 | embs = pack(input, list(map(int, lengths))) # 这里batch是第二个维度 53 | outputs, (h, c) = self.rnn(embs) 54 | outputs = unpack(outputs)[0] 55 | if not self.config.bidirec: 56 | return outputs, (h, c) # h,c是最后一个step的,大小是(num_layers * num_directions, batch, hidden_size) 57 | else: 58 | batch_size = h.size(1) 59 | h = h.transpose(0, 1).contiguous().view(batch_size, -1, 2 * 
self.config.encoder_hidden_size) 60 | c = c.transpose(0, 1).contiguous().view(batch_size, -1, 2 * self.config.encoder_hidden_size) 61 | state = (h.transpose(0, 1), c.transpose(0, 1)) # 每一个元素是 (num_layers,batch,2*hidden_size)这么大。 62 | return outputs, state 63 | 64 | class gated_rnn_encoder(nn.Module): 65 | 66 | def __init__(self, config, vocab_size, embedding=None): 67 | super(gated_rnn_encoder, self).__init__() 68 | if embedding is not None: 69 | self.embedding = embedding 70 | else: 71 | self.embedding = nn.Embedding(vocab_size, config.emb_size) 72 | self.rnn = nn.LSTM(input_size=config.emb_size, hidden_size=config.encoder_hidden_size, 73 | num_layers=config.num_layers, dropout=config.dropout) 74 | self.gated = nn.Sequential(nn.Linear(config.encoder_hidden_size, 1), nn.Sigmoid()) 75 | 76 | def forward(self, input, lengths): 77 | embs = pack(self.embedding(input), lengths) 78 | outputs, state = self.rnn(embs) 79 | outputs = unpack(outputs)[0] 80 | p = self.gated(outputs) 81 | outputs = outputs * p 82 | return outputs, state 83 | 84 | 85 | class rnn_decoder(nn.Module): 86 | 87 | def __init__(self, config, vocab_size, dir_vocab_size, embedding=None, score_fn=None): 88 | super(rnn_decoder, self).__init__() 89 | if embedding is not None: 90 | self.embedding = embedding 91 | self.embedding_dir = embedding 92 | else: 93 | self.embedding = nn.Embedding(vocab_size, config.emb_size) 94 | self.embedding_dir = nn.Embedding(dir_vocab_size, config.emb_size) 95 | self.rnn = StackedLSTM(input_size=config.emb_size, hidden_size=config.decoder_hidden_size, 96 | num_layers=config.num_layers, dropout=config.dropout) 97 | self.rnn_dir = StackedLSTM(input_size=config.emb_size, hidden_size=config.decoder_hidden_size, 98 | num_layers=config.num_layers, dropout=config.dropout) 99 | 100 | self.score_fn = score_fn 101 | if self.score_fn.startswith('general'): 102 | self.linear = nn.Linear(config.decoder_hidden_size, config.emb_size) 103 | self.linear_dir = 
nn.Linear(config.decoder_hidden_size, config.emb_size) 104 | elif score_fn.startswith('concat'): 105 | self.linear_query = nn.Linear(config.decoder_hidden_size, config.decoder_hidden_size) 106 | self.linear_weight = nn.Linear(config.emb_size, config.decoder_hidden_size) 107 | self.linear_v = nn.Linear(config.decoder_hidden_size, 1) 108 | self.linear_query_dir = nn.Linear(config.decoder_hidden_size, config.decoder_hidden_size) 109 | self.linear_weight_dir = nn.Linear(config.emb_size, config.decoder_hidden_size) 110 | self.linear_v_dir = nn.Linear(config.decoder_hidden_size, 1) 111 | elif not self.score_fn.startswith('dot'): 112 | self.linear = nn.Linear(config.decoder_hidden_size, vocab_size) 113 | self.linear_dir = nn.Linear(config.decoder_hidden_size, dir_vocab_size) 114 | self.linear_output = nn.Linear(dir_vocab_size, 1) 115 | 116 | if hasattr(config, 'att_act'): 117 | activation = config.att_act 118 | print('use attention activation %s' % activation) 119 | else: 120 | activation = None 121 | 122 | self.attention = models.global_attention(config.decoder_hidden_size, activation) 123 | self.attention_dir = models.global_attention(config.decoder_hidden_size, activation) 124 | self.hidden_size = config.decoder_hidden_size 125 | self.dropout = nn.Dropout(config.dropout) 126 | self.config = config 127 | 128 | if self.config.global_emb: 129 | self.gated1 = nn.Linear(config.emb_size, config.emb_size) 130 | self.gated2 = nn.Linear(config.emb_size, config.emb_size) 131 | self.gated1_dir = nn.Linear(config.emb_size, config.emb_size) 132 | self.gated2_dir = nn.Linear(config.emb_size, config.emb_size) 133 | 134 | def forward(self, inputs, inputs_dir, init_state, contexts): 135 | 136 | outputs, outputs_dir, state, attns, global_embs = [], [], init_state, [], [] 137 | 138 | ## speaker 139 | embs = self.embedding(inputs).split(1) # time_step [1,bs,embsize] 140 | max_time_step = len(embs) 141 | emb = embs[0] # 第一步BOS的embedding. 
142 | output, state_speaker = self.rnn(emb.squeeze(0), state) 143 | output, attn_weights = self.attention(output, contexts) 144 | output = self.dropout(output) 145 | soft_score = F.softmax(self.linear(output)) # 第一步的概率分布也就是 bs,vocal这么大 146 | 147 | ## direction 148 | embs_dir = self.embedding_dir(inputs_dir).split(1) # time_step [1,bs,embsize] 149 | emb_dir = embs_dir[0] # 第一步BOS的embedding. 150 | output_dir, state_dir = self.rnn_dir(emb_dir.squeeze(0), state) 151 | output_dir, attn_weights_dir = self.attention_dir(output_dir, contexts) 152 | output_dir = self.dropout(output_dir) 153 | #soft_score_dir_1 = F.sigmoid(self.linear_dir(output_dir)) 154 | soft_score_dir = F.softmax(self.linear_dir(output_dir)) 155 | 156 | attn_weights = attn_weights + attn_weights_dir 157 | outputs += [output] 158 | outputs_dir += [output_dir] 159 | attns += [attn_weights] 160 | 161 | batch_size = soft_score.size(0) 162 | a, b = self.embedding.weight.size() 163 | c, d = self.embedding_dir.weight.size() 164 | 165 | for i in range(max_time_step - 1): 166 | ## speaker 167 | emb1 = torch.bmm(soft_score.unsqueeze(1), 168 | self.embedding.weight.expand((batch_size, a, b))) 169 | emb2 = embs[i + 1] 170 | gamma = F.sigmoid(self.gated1(emb1.squeeze(1)) + self.gated2(emb2.squeeze(0))) 171 | emb = gamma * emb1.squeeze(1) + (1 - gamma) * emb2.squeeze(0) 172 | output, state_speaker = self.rnn(emb, state_speaker) 173 | output, attn_weights = self.attention(output, contexts) 174 | output = self.dropout(output) 175 | soft_score = F.softmax(self.linear(output)) 176 | 177 | ## direction 178 | emb1_dir = torch.bmm(soft_score_dir.unsqueeze(1), 179 | self.embedding_dir.weight.expand((batch_size, c, d))) 180 | emb2_dir = embs_dir[i + 1] 181 | gamma_dir = F.sigmoid(self.gated1_dir(emb1_dir.squeeze(1)) + self.gated2_dir(emb2_dir.squeeze(0))) 182 | emb_dir = gamma_dir * emb1_dir.squeeze(1) + (1 - gamma_dir) * emb2_dir.squeeze(0) 183 | output_dir, state_dir = self.rnn_dir(emb_dir, state_dir) 184 | output_dir, 
attn_weights_dir = self.attention_dir(output_dir, contexts) 185 | output_dir = self.dropout(output_dir) 186 | #soft_score_dir_1 = F.sigmoid(self.linear_dir(output_dir)) 187 | soft_score_dir = F.softmax(self.linear_dir(output_dir)) 188 | 189 | attn_weights = attn_weights + attn_weights_dir 190 | emb = emb + emb_dir 191 | outputs += [output] 192 | outputs_dir += [output_dir] 193 | global_embs += [emb] 194 | attns += [attn_weights] 195 | 196 | outputs = torch.stack(outputs) 197 | outputs_dir = torch.stack(outputs_dir) 198 | global_embs = torch.stack(global_embs) 199 | attns = torch.stack(attns) 200 | return outputs, outputs_dir, state, global_embs 201 | 202 | def compute_score(self, hiddens): 203 | if self.score_fn.startswith('general'): 204 | if self.score_fn.endswith('not'): 205 | scores = torch.matmul(self.linear(hiddens), Variable(self.embedding.weight.t().data)) 206 | else: 207 | scores = torch.matmul(self.linear(hiddens), self.embedding.weight.t()) 208 | elif self.score_fn.startswith('concat'): 209 | if self.score_fn.endswith('not'): 210 | scores = self.linear_v(torch.tanh(self.linear_query(hiddens).unsqueeze(1) + \ 211 | self.linear_weight(Variable(self.embedding.weight.data)).unsqueeze( 212 | 0))).squeeze(2) 213 | else: 214 | scores = self.linear_v(torch.tanh(self.linear_query(hiddens).unsqueeze(1) + \ 215 | self.linear_weight(self.embedding.weight).unsqueeze(0))).squeeze(2) 216 | elif self.score_fn.startswith('dot'): 217 | if self.score_fn.endswith('not'): 218 | scores = torch.matmul(hiddens, Variable(self.embedding.weight.t().data)) 219 | else: 220 | scores = torch.matmul(hiddens, self.embedding.weight.t()) 221 | # elif self.score_fn.startswith('arc_margin'): 222 | # scores = self.linear(hiddens,targets) 223 | else: 224 | scores = self.linear(hiddens) 225 | return scores 226 | 227 | def compute_score_dir(self, hiddens): 228 | if self.score_fn.startswith('general'): 229 | if self.score_fn.endswith('not'): 230 | scores = torch.matmul(self.linear_dir(hiddens), 
Variable(self.embedding_dir.weight.t().data)) 231 | else: 232 | scores = torch.matmul(self.linear_dir(hiddens), self.embedding_dir.weight.t()) 233 | elif self.score_fn.startswith('concat'): 234 | if self.score_fn.endswith('not'): 235 | scores = self.linear_v_dir(torch.tanh(self.linear_query_dir(hiddens).unsqueeze(1) + \ 236 | self.linear_weight_dir(Variable(self.embedding_dir.weight.data)).unsqueeze( 237 | 0))).squeeze(2) 238 | else: 239 | scores = self.linear_v_dir(torch.tanh(self.linear_query_dir(hiddens).unsqueeze(1) + \ 240 | self.linear_weight_dir(self.embedding_dir.weight).unsqueeze(0))).squeeze(2) 241 | elif self.score_fn.startswith('dot'): 242 | if self.score_fn.endswith('not'): 243 | scores = torch.matmul(hiddens, Variable(self.embedding_dir.weight.t().data)) 244 | else: 245 | scores = torch.matmul(hiddens, self.embedding_dir.weight.t()) 246 | else: 247 | #scores_1 = F.sigmoid(self.linear_dir(hiddens)) 248 | scores = self.linear_dir(hiddens) 249 | return scores 250 | 251 | def sample(self, input, init_state, contexts): 252 | inputs, outputs, sample_ids, state = [], [], [], init_state 253 | attns = [] 254 | inputs += input 255 | max_time_step = self.config.max_tgt_len 256 | soft_score = None 257 | mask = None 258 | for i in range(max_time_step): 259 | output, state, attn_weights = self.sample_one(inputs[i], soft_score, state, contexts, mask) 260 | if self.config.global_emb: 261 | soft_score = F.softmax(output) 262 | predicted = output.max(1)[1] 263 | inputs += [predicted] 264 | sample_ids += [predicted] 265 | outputs += [output] 266 | attns += [attn_weights] 267 | if self.config.mask: 268 | if mask is None: 269 | mask = predicted.unsqueeze(1).long() 270 | else: 271 | mask = torch.cat((mask, predicted.unsqueeze(1)), 1) 272 | 273 | sample_ids = torch.stack(sample_ids) 274 | attns = torch.stack(attns) 275 | return sample_ids, (outputs, attns) 276 | 277 | def sample_one(self, input, input_dir, soft_score, soft_score_dir, state, state_dir, tmp_hiddens, 
tmp_hiddens_dir, contexts, mask,mask_dir): 278 | if self.config.global_emb: 279 | batch_size = contexts.size(0) 280 | a, b = self.embedding.weight.size() 281 | if soft_score is None: 282 | emb = self.embedding(input) 283 | else: 284 | emb1 = torch.bmm(soft_score.unsqueeze(1), self.embedding.weight.expand((batch_size, a, b))) 285 | emb2 = self.embedding(input) 286 | gamma = F.sigmoid(self.gated1(emb1.squeeze()) + self.gated2(emb2.squeeze())) 287 | emb = gamma * emb1.squeeze() + (1 - gamma) * emb2.squeeze() 288 | 289 | c, d = self.embedding_dir.weight.size() 290 | if soft_score_dir is None: 291 | emb_dir = self.embedding_dir(input_dir) 292 | else: 293 | emb1_dir = torch.bmm(soft_score_dir.unsqueeze(1), self.embedding_dir.weight.expand((batch_size, c, d))) 294 | emb2_dir = self.embedding_dir(input_dir) 295 | gamma_dir = F.sigmoid(self.gated1_dir(emb1_dir.squeeze()) + self.gated2_dir(emb2_dir.squeeze())) 296 | emb_dir = gamma_dir * emb1_dir.squeeze() + (1 - gamma_dir) * emb2_dir.squeeze() 297 | else: 298 | emb = self.embedding(input) 299 | emb_dir = self.embedding_dir(input_dir) 300 | 301 | output, state = self.rnn(emb, state) 302 | output_bk = output 303 | hidden, attn_weights = self.attention(output, contexts) 304 | if self.config.schmidt: 305 | hidden = models.schmidt(hidden, tmp_hiddens) 306 | output = self.compute_score(hidden) 307 | if self.config.mask: 308 | if mask is not None: 309 | output = output.scatter_(1, mask, -9999999999) 310 | 311 | output_dir, state_dir = self.rnn_dir(emb_dir, state_dir) 312 | output_dir_bk = output_dir 313 | hidden_dir, attn_weights_dir = self.attention_dir(output_dir, contexts) 314 | if self.config.schmidt: 315 | hidden_dir = models.schmidt(hidden_dir, tmp_hiddens_dir) 316 | output_dir = self.compute_score_dir(hidden_dir) 317 | if self.config.mask: 318 | if mask_dir is not None: 319 | output_dir = output_dir.scatter_(1, mask_dir, -9999999999) 320 | 321 | return output, output_dir, state, state_dir, attn_weights, attn_weights_dir, 
hidden, hidden_dir, emb, emb_dir, output_bk, output_dir_bk 322 | -------------------------------------------------------------------------------- /models/separation_dis.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | import sys 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import random 9 | 10 | np.random.seed(1) # 设定种子 11 | torch.manual_seed(1) 12 | random.seed(1) 13 | torch.cuda.set_device(0) 14 | test_all_outputchannel = 0 15 | 16 | 17 | class ATTENTION(nn.Module): 18 | def __init__(self, hidden_size, query_size, align_hidden_size, mode='dot'): 19 | super(ATTENTION, self).__init__() 20 | # self.mix_emb_size=config.EMBEDDING_SIZE 21 | self.hidden_size = hidden_size 22 | self.query_size = query_size 23 | # self.align_hidden_size=hidden_size #align模式下的隐层大小,暂时取跟原来一致的 24 | self.align_hidden_size = align_hidden_size # align模式下的隐层大小,暂时取跟原来一致的 25 | self.mode = mode 26 | self.Linear_1 = nn.Linear(self.hidden_size, self.align_hidden_size, bias=False) 27 | # self.Linear_2=nn.Linear(hidden_sizedw,self.align_hidden_size,bias=False) 28 | self.Linear_2 = nn.Linear(self.query_size, self.align_hidden_size, bias=False) 29 | self.Linear_3 = nn.Linear(self.align_hidden_size, 1, bias=False) 30 | 31 | def forward(self, mix_hidden, query): 32 | # todo:这个要弄好,其实也可以直接抛弃memory来进行attention | DONE 33 | BATCH_SIZE = mix_hidden.size()[0] 34 | assert query.size() == (BATCH_SIZE, self.query_size) 35 | assert mix_hidden.size()[-1] == self.hidden_size 36 | # mix_hidden:bs,max_len,fre,hidden_size query:bs,hidden_size 37 | if self.mode == 'dot': 38 | # mix_hidden=mix_hidden.view(-1,1,self.hidden_size) 39 | mix_shape = mix_hidden.size() 40 | mix_hidden = mix_hidden.view(BATCH_SIZE, -1, self.hidden_size) 41 | query = query.view(-1, self.hidden_size, 1) 42 | # print '\n\n',mix_hidden.requires_grad,query.requires_grad,'\n\n' 43 | dot = 
torch.baddbmm(Variable(torch.zeros(1, 1)), mix_hidden, query) 44 | energy = dot.view(BATCH_SIZE, mix_shape[1], mix_shape[2]) 45 | # TODO: 这里可以想想是不是能换成Relu之类的 46 | mask = F.sigmoid(energy) 47 | return mask 48 | 49 | elif self.mode == 'align': 50 | # mix_hidden=Variable(mix_hidden) 51 | # query=Variable(query) 52 | mix_shape = mix_hidden.size() 53 | mix_hidden = mix_hidden.view(-1, self.hidden_size) 54 | mix_hidden = self.Linear_1(mix_hidden).view(BATCH_SIZE, -1, self.align_hidden_size) 55 | query = self.Linear_2(query).view(-1, 1, self.align_hidden_size) # bs,1,hidden 56 | sum = F.tanh(mix_hidden + query) 57 | # TODO:从这里开始做起 58 | energy = self.Linear_3(sum.view(-1, self.align_hidden_size)).view(BATCH_SIZE, mix_shape[1], mix_shape[2]) 59 | mask = F.sigmoid(energy) 60 | return mask 61 | else: 62 | print 63 | 'NO this attention methods.' 64 | raise IndexError 65 | 66 | 67 | class MIX_SPEECH_CNN(nn.Module): 68 | def __init__(self, config, input_fre, mix_speech_len): 69 | super(MIX_SPEECH_CNN, self).__init__() 70 | self.input_fre = input_fre 71 | self.mix_speech_len = mix_speech_len 72 | self.config = config 73 | 74 | self.cnn1 = nn.Conv2d(1, 96, (1, 7), stride=1, padding=(0, 3), dilation=(1, 1)) 75 | self.cnn2 = nn.Conv2d(96, 96, (7, 1), stride=1, padding=(3, 0), dilation=(1, 1)) 76 | self.cnn3 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(2, 2), dilation=(1, 1)) 77 | self.cnn4 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(4, 2), dilation=(2, 1)) 78 | self.cnn5 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(8, 2), dilation=(4, 1)) 79 | 80 | self.cnn6 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(16, 2), dilation=(8, 1)) 81 | self.cnn7 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(32, 2), dilation=(16, 1)) 82 | self.cnn8 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(64, 2), dilation=(32, 1)) 83 | self.cnn9 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(2, 2), dilation=(1, 1)) 84 | self.cnn10 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(4, 4), dilation=(2, 2)) 85 | 
86 | self.cnn11 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(8, 8), dilation=(4, 4)) 87 | self.cnn12 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(16, 16), dilation=(8, 8)) 88 | self.cnn13 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(32, 32), dilation=(16, 16)) 89 | self.cnn14 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(64, 64), dilation=(32, 32)) 90 | self.cnn15 = nn.Conv2d(96, 8, (1, 1), stride=1, padding=(0, 0), dilation=(1, 1)) 91 | self.num_cnns = 15 92 | self.bn1 = nn.BatchNorm2d(96) 93 | self.bn2 = nn.BatchNorm2d(96) 94 | self.bn3 = nn.BatchNorm2d(96) 95 | self.bn4 = nn.BatchNorm2d(96) 96 | self.bn5 = nn.BatchNorm2d(96) 97 | self.bn6 = nn.BatchNorm2d(96) 98 | self.bn7 = nn.BatchNorm2d(96) 99 | self.bn8 = nn.BatchNorm2d(96) 100 | self.bn9 = nn.BatchNorm2d(96) 101 | self.bn10 = nn.BatchNorm2d(96) 102 | self.bn11 = nn.BatchNorm2d(96) 103 | self.bn12 = nn.BatchNorm2d(96) 104 | self.bn13 = nn.BatchNorm2d(96) 105 | self.bn14 = nn.BatchNorm2d(96) 106 | self.bn15 = nn.BatchNorm2d(8) 107 | 108 | def forward(self, x): 109 | print 110 | 'speech input size:', x.size() 111 | assert len(x.size()) == 3 112 | x = x.unsqueeze(1) 113 | print 114 | '\nSpeech layer log:' 115 | x = x.contiguous() 116 | for idx in range(self.num_cnns): 117 | cnn_layer = eval('self.cnn{}'.format(idx + 1)) 118 | bn_layer = eval('self.bn{}'.format(idx + 1)) 119 | x = F.relu(cnn_layer(x)) 120 | x = bn_layer(x) 121 | print 122 | 'speech shape after CNNs:', idx, '', x.size() 123 | 124 | out = x.transpose(1, 3).transpose(1, 2).contiguous() 125 | print 126 | 'speech output size:', out.size() 127 | return out, out 128 | 129 | 130 | class MIX_SPEECH(nn.Module): 131 | def __init__(self, config, input_fre, mix_speech_len): 132 | super(MIX_SPEECH, self).__init__() 133 | self.input_fre = input_fre 134 | self.mix_speech_len = mix_speech_len 135 | self.layer = nn.LSTM( 136 | input_size=input_fre, 137 | hidden_size=config.HIDDEN_UNITS, 138 | num_layers=4, 139 | batch_first=True, 140 | bidirectional=True 
141 | ) 142 | # self.batchnorm = nn.BatchNorm1d(self.input_fre*config.EMBEDDING_SIZE) 143 | self.Linear = nn.Linear(2 * config.HIDDEN_UNITS, self.input_fre * config.EMBEDDING_SIZE,bias=1) 144 | self.config = config 145 | 146 | def forward(self, x): 147 | x, hidden = self.layer(x) 148 | batch_size = x.size()[0] 149 | x = x.contiguous() 150 | xx = x 151 | x = x.view(batch_size * self.mix_speech_len, -1) 152 | # out=F.tanh(self.Linear(x)) 153 | out = self.Linear(x) 154 | # out = self.batchnorm(out) 155 | out = F.tanh(out) 156 | # out=F.relu(out) 157 | out = out.view(batch_size, self.mix_speech_len, self.input_fre, -1) 158 | # print 'Mix speech output shape:',out.size() 159 | return out, xx 160 | 161 | 162 | class Discriminator(nn.Module): 163 | def __init__(self): 164 | super(Discriminator, self).__init__() 165 | self.cnn = nn.Conv2d(1, 64, (3, 3), stride=(2, 2), ) 166 | self.cnn1 = nn.Conv2d(64, 64, (3, 3), stride=(2, 2), ) 167 | self.cnn2 = nn.Conv2d(64, 64, (3, 3), stride=(2, 2), ) 168 | # self.final=nn.Linear(36480,1) 169 | self.final = nn.Linear(73920, 1) 170 | 171 | def forward(self, spec): 172 | bs, topk, len, fre = spec.size() 173 | spec = spec.view(bs * topk, 1, len, fre) 174 | spec = F.relu(self.cnn(spec)) 175 | spec = F.relu(self.cnn1(spec)) 176 | spec = F.relu(self.cnn2(spec)) 177 | spec = spec.view(bs * topk, -1) 178 | print 179 | 'size spec:', spec.size() 180 | score = F.sigmoid(self.final(spec)) 181 | print 182 | 'size spec:', score.size() 183 | return score 184 | 185 | 186 | class SPEECH_EMBEDDING(nn.Module): 187 | def __init__(self, num_labels, embedding_size, max_num_channel): 188 | super(SPEECH_EMBEDDING, self).__init__() 189 | self.num_all = num_labels 190 | self.emb_size = embedding_size 191 | self.max_num_out = max_num_channel 192 | # self.layer=nn.Embedding(num_labels,embedding_size,padding_idx=-1) 193 | self.layer = nn.Embedding(num_labels, embedding_size) 194 | 195 | def forward(self, input, mask_idx): 196 | aim_matrix = 
torch.from_numpy(np.array(mask_idx)) 197 | all = self.layer(Variable(aim_matrix)) # bs*num_labels(最多混合人个数)×Embedding的大小 198 | out = all 199 | return out 200 | 201 | 202 | class ADDJUST(nn.Module): 203 | # 这个模块是负责处理目标人的对应扰动的,进行一些偏移的调整 204 | def __init__(self, config, hidden_units, embedding_size): 205 | super(ADDJUST, self).__init__() 206 | self.config = config 207 | self.hidden_units = hidden_units 208 | self.emb_size = embedding_size 209 | self.layer = nn.Linear(hidden_units + embedding_size, embedding_size, bias=False) 210 | 211 | def forward(self, input_hidden, prob_emb): 212 | top_k_num = prob_emb.size()[1] 213 | x = torch.mean(input_hidden, 1).view(self.config.batch_size, 1, self.hidden_units).expand( 214 | self.config.batch_size, top_k_num, self.hidden_units) 215 | can = torch.cat([x, prob_emb], dim=2) 216 | all = self.layer(can) # bs*num_labels(最多混合人个数)×Embedding的大小 217 | out = all 218 | return out 219 | 220 | 221 | class SS(nn.Module): 222 | def __init__(self, config, speech_fre, mix_speech_len, num_labels): 223 | super(SS, self).__init__() 224 | self.config = config 225 | self.speech_fre = speech_fre 226 | self.mix_speech_len = mix_speech_len 227 | self.num_labels = num_labels 228 | print 229 | 'Begin to build the maim model for speech speration part.' 230 | if config.speech_cnn_net: 231 | self.mix_hidden_layer_3d = MIX_SPEECH_CNN(config, speech_fre, mix_speech_len) 232 | else: 233 | self.mix_hidden_layer_3d = MIX_SPEECH(config, speech_fre, mix_speech_len) 234 | # att_layer=ATTENTION(config.EMBEDDING_SIZE,'dot') 235 | self.att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, config.SPK_EMB_SIZE, config.ATT_SIZE, 'align') 236 | if self.config.is_SelfTune: 237 | self.adjust_layer = ADDJUST(config, 2 * config.HIDDEN_UNITS, config.SPK_EMB_SIZE) 238 | print 239 | 'Adopt adjust layer.' 
240 | 241 | def forward(self, mix_feas, hidden_outputs, targets, dict_spk2idx=None): 242 | ''' 243 | :param targets:这个targets的大小是:topk,bs 注意后面要transpose 244 | 123 324 345 245 | 323 E E 246 | 这种样子的,所以要去找aim_list 应该找到的结果是,先transpose之后,然后flatten,然后取不是E的:[0 1 2 4 ] 247 | 248 | ''' 249 | 250 | config = self.config 251 | top_k_max, batch_size = targets.size() # 这个top_k_max其实就是最多有几个说话人,应该是跟Max_MIX是保持一样的 252 | # assert top_k_max==config.MAX_MIX 253 | aim_list = (targets.transpose(0, 1).contiguous().view(-1) != dict_spk2idx['']).nonzero().squeeze() 254 | aim_list = aim_list.data.cpu().numpy() 255 | 256 | mix_speech_hidden, mix_tmp_hidden = self.mix_hidden_layer_3d(mix_feas) 257 | mix_speech_multiEmbs = torch.transpose(hidden_outputs, 0, 1).contiguous() # bs*num_labels(最多混合人个数)×Embedding的大小 258 | mix_speech_multiEmbs = mix_speech_multiEmbs.view(-1, config.SPK_EMB_SIZE) # bs*num_labels(最多混合人个数)×Embedding的大小 259 | # assert mix_speech_multiEmbs.size()[0]==targets.shape 260 | mix_speech_multiEmbs = mix_speech_multiEmbs[aim_list] # aim_num,embs 261 | # mix_speech_multiEmbs=mix_speech_multiEmbs[0] # aim_num,embs 262 | # print mix_speech_multiEmbs.shape 263 | if self.config.is_SelfTune: 264 | # TODO: 这里应该也是有问题的,暂时不用selfTune 265 | mix_adjust = self.adjust_layer(mix_tmp_hidden, mix_speech_multiEmbs) 266 | mix_speech_multiEmbs = mix_adjust + mix_speech_multiEmbs 267 | mix_speech_hidden_5d = mix_speech_hidden.view(batch_size, 1, self.mix_speech_len, self.speech_fre, 268 | config.EMBEDDING_SIZE) 269 | mix_speech_hidden_5d = mix_speech_hidden_5d.expand(batch_size, top_k_max, self.mix_speech_len, self.speech_fre, 270 | config.EMBEDDING_SIZE).contiguous() 271 | mix_speech_hidden_5d_last = mix_speech_hidden_5d.view(-1, self.mix_speech_len, self.speech_fre, 272 | config.EMBEDDING_SIZE) 273 | mix_speech_hidden_5d_last = mix_speech_hidden_5d_last[aim_list] 274 | att_multi_speech = self.att_speech_layer(mix_speech_hidden_5d_last, 275 | mix_speech_multiEmbs.view(-1, config.SPK_EMB_SIZE)) 276 | 
att_multi_speech = att_multi_speech.view(-1, self.mix_speech_len, self.speech_fre) # bs,num_labels,len,fre这个东西 277 | multi_mask = att_multi_speech 278 | assert multi_mask.shape[0] == len(aim_list) 279 | return multi_mask 280 | 281 | 282 | def top_k_mask(batch_pro, alpha, top_k): 283 | 'batch_pro是 bs*n的概率分布,例如2×3的,每一行是一个概率分布\ 284 | alpha是阈值,大于它的才可以取,可以跟Multi-label语音分离的ACC的alpha对应;\ 285 | top_k是最多输出几个候选目标\ 286 | 输出是与bs*n的一个mask,float型的' 287 | size = batch_pro.size() 288 | final = torch.zeros(size) 289 | sort_result, sort_index = torch.sort(batch_pro, 1, True) # 先排个序 290 | sort_index = sort_index[:, :top_k] # 选出每行的top_k的id 291 | sort_result = torch.sum(sort_result > alpha, 1) 292 | for line_idx in range(size[0]): 293 | line_top_k = sort_index[line_idx][:int(sort_result[line_idx].data.cpu().numpy())] 294 | line_top_k = line_top_k.data.cpu().numpy() 295 | for i in line_top_k: 296 | final[line_idx, i] = 1 297 | return final 298 | 299 | 300 | -------------------------------------------------------------------------------- /models/separation_tasnet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | 8 | EPS = 1e-8 9 | 10 | def gcd(a, b): 11 | a, b = (a, b) if a >= b else (b, a) 12 | while b: 13 | a, b = b, a % b 14 | return a 15 | 16 | def overlap_and_add(signal, frame_step): 17 | """Reconstructs a signal from a framed representation. 18 | 19 | Adds potentially overlapping frames of a signal with shape 20 | `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`. 21 | The resulting tensor has shape `[..., output_size]` where 22 | 23 | output_size = (frames - 1) * frame_step + frame_length 24 | 25 | Args: 26 | signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2. 27 | frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length. 
28 | 29 | Returns: 30 | A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions. 31 | output_size = (frames - 1) * frame_step + frame_length 32 | 33 | Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py 34 | """ 35 | outer_dimensions = signal.size()[:-2] 36 | frames, frame_length = signal.size()[-2:] 37 | 38 | subframe_length = gcd(frame_length, frame_step) # gcd=Greatest Common Divisor 39 | subframe_step = frame_step // subframe_length 40 | subframes_per_frame = frame_length // subframe_length 41 | output_size = frame_step * (frames - 1) + frame_length 42 | output_subframes = output_size // subframe_length 43 | 44 | # subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) 45 | subframe_signal = signal.view(outer_dimensions[0],outer_dimensions[1], -1, subframe_length) 46 | 47 | frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step) 48 | frame = signal.new_tensor(frame).long() # signal may in GPU or CPU 49 | frame = frame.contiguous().view(-1) 50 | 51 | # result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) 52 | result = signal.new_zeros(outer_dimensions[0],outer_dimensions[1], output_subframes, subframe_length) 53 | result.index_add_(-2, frame, subframe_signal) 54 | # result = result.view(*outer_dimensions, -1) 55 | result = result.view(outer_dimensions[0],outer_dimensions[1], -1) 56 | return result 57 | 58 | 59 | def remove_pad(inputs, inputs_lengths): 60 | """ 61 | Args: 62 | inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size 63 | inputs_lengths: torch.Tensor, [B] 64 | Returns: 65 | results: a list containing B items, each item is [C, T], T varies 66 | """ 67 | results = [] 68 | dim = inputs.dim() 69 | if dim == 3: 70 | C = inputs.size(1) 71 | for input, length in zip(inputs, inputs_lengths): 72 | if dim == 3: # [B, C, T] 73 | 
results.append(input[:,:length].view(C, -1).cpu().numpy()) 74 | elif dim == 2: # [B, T] 75 | results.append(input[:length].view(-1).cpu().numpy()) 76 | return results 77 | 78 | class ConvTasNet(nn.Module): 79 | def __init__(self, N=256, L=40, B=256, H=512, P=3, X=8, R=4, C=2, norm_type="gLN", causal=False, 80 | mask_nonlinear='sigmoid'): 81 | """ 82 | Args: 83 | N: Number of filters in autoencoder 84 | L: Length of the filters (in samples) 85 | B: Number of channels in bottleneck 1 * 1-conv block 86 | H: Number of channels in convolutional blocks 87 | P: Kernel size in convolutional blocks 88 | X: Number of convolutional blocks in each repeat 89 | R: Number of repeats 90 | C: Number of speakers 91 | norm_type: BN, gLN, cLN 92 | causal: causal or non-causal 93 | mask_nonlinear: use which non-linear function to generate mask 94 | """ 95 | super(ConvTasNet, self).__init__() 96 | # Hyper-parameter 97 | self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C 98 | self.norm_type = norm_type 99 | self.causal = causal 100 | self.mask_nonlinear = mask_nonlinear 101 | # Components 102 | self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear) 103 | # init 104 | #for p in self.parameters(): 105 | # if p.dim() > 1: 106 | # nn.init.xavier_normal_(p) 107 | 108 | def forward(self, mixture, hidden_outputs): 109 | """ 110 | Args: 111 | mixture: [M, T], M is batch size, T is #samples 112 | Returns: 113 | est_source: [M, C, T] 114 | """ 115 | #mixture_w = self.encoder(mixture) 116 | est_mask = self.separator(mixture, hidden_outputs) 117 | #est_source = self.decoder(mixture_w, est_mask) 118 | 119 | # T changed after conv1d in encoder, fix it here 120 | #T_origin = mixture.size(-1) 121 | #T_conv = est_source.size(-1) 122 | #est_source = F.pad(est_source, (0, T_origin - T_conv)) 123 | return est_mask 124 | 125 | @classmethod 126 | def load_model(cls, path): 127 | # Load to CPU 128 | package = torch.load(path, 
map_location=lambda storage, loc: storage) 129 | model = cls.load_model_from_package(package) 130 | return model 131 | 132 | @classmethod 133 | def load_model_from_package(cls, package): 134 | model = cls(package['N'], package['L'], package['B'], package['H'], 135 | package['P'], package['X'], package['R'], package['C'], 136 | norm_type=package['norm_type'], causal=package['causal'], 137 | mask_nonlinear=package['mask_nonlinear']) 138 | model.load_state_dict(package['state_dict']) 139 | return model 140 | 141 | @staticmethod 142 | def serialize(model, optimizer, epoch, tr_loss=None, cv_loss=None): 143 | package = { 144 | # hyper-parameter 145 | 'N': model.N, 'L': model.L, 'B': model.B, 'H': model.H, 146 | 'P': model.P, 'X': model.X, 'R': model.R, 'C': model.C, 147 | 'norm_type': model.norm_type, 'causal': model.causal, 148 | 'mask_nonlinear': model.mask_nonlinear, 149 | # state 150 | 'state_dict': model.state_dict(), 151 | 'optim_dict': optimizer.state_dict(), 152 | 'epoch': epoch 153 | } 154 | if tr_loss is not None: 155 | package['tr_loss'] = tr_loss 156 | package['cv_loss'] = cv_loss 157 | return package 158 | 159 | 160 | class TasNetEncoder(nn.Module): 161 | """Estimation of the nonnegative mixture weight by a 1-D conv layer. 
162 | """ 163 | def __init__(self, L=40, N=256, Ch = 32): 164 | super(TasNetEncoder, self).__init__() 165 | # Hyper-parameter 166 | self.L, self.N = L, N 167 | self.Ch = Ch 168 | # Components 169 | # 50% overlap 170 | 171 | self.conv1d_c2 = nn.Conv1d(2, N, kernel_size=L, stride=L // 2, bias=False) 172 | 173 | def forward(self, mixture): 174 | """ 175 | Args: 176 | mixture: [M, T, C], M is batch size, T is #samples, C is channel number 177 | Returns: 178 | mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1 179 | """ 180 | mixture =mixture.transpose(1,2) 181 | mixture_w = F.relu(self.conv1d_c2(mixture)) 182 | return mixture_w 183 | 184 | 185 | class TasNetDecoder(nn.Module): 186 | def __init__(self, N=256, L=40): 187 | super(TasNetDecoder, self).__init__() 188 | # Hyper-parameter 189 | self.N, self.L = N, L 190 | # Components 191 | self.basis_signals = nn.Linear(N, L, bias=False) 192 | 193 | self.conv1dTranspose = nn.ConvTranspose1d(N, 1, L, stride=L//2) 194 | 195 | def forward(self, mixture_w, est_mask): 196 | """ 197 | Args: 198 | mixture_w: [M, N, K] 199 | est_mask: [M, C, N, K] 200 | Returns: 201 | est_source: [M, C, T] 202 | """ 203 | # D = W * M 204 | source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K] 205 | #source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N] 206 | M, C, N, K = source_w.size() 207 | source_w = source_w.view(-1, N, K).contiguous() 208 | est_source = self.conv1dTranspose(source_w).view(M, C, -1).contiguous() 209 | # S = DV 210 | #est_source = self.basis_signals(source_w) # [M, C, K, L] 211 | #est_source = overlap_and_add(est_source, self.L//2) # M x C x T 212 | return est_source 213 | 214 | class TemporalConvNet(nn.Module): 215 | def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, 216 | mask_nonlinear='relu'): 217 | """ 218 | Args: 219 | N: Number of filters in autoencoder 220 | B: Number of channels in bottleneck 1 * 1-conv block 221 | H: Number of channels in convolutional blocks 222 | P: Kernel size 
in convolutional blocks 223 | X: Number of convolutional blocks in each repeat 224 | R: Number of repeats 225 | C: Number of speakers 226 | norm_type: BN, gLN, cLN 227 | causal: causal or non-causal 228 | mask_nonlinear: use which non-linear function to generate mask 229 | """ 230 | super(TemporalConvNet, self).__init__() 231 | # Hyper-parameter 232 | self.C = C 233 | self.B = B 234 | self.mask_nonlinear = mask_nonlinear 235 | # Components 236 | # [M, N, K] -> [M, N, K] 237 | layer_norm = ChannelwiseLayerNorm(N) 238 | # [M, N, K] -> [M, B, K] 239 | bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) 240 | # [M, B, K] -> [M, B, K] 241 | repeats = [] 242 | for r in range(R): 243 | blocks = [] 244 | for x in range(X): 245 | dilation = 2**x 246 | padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 247 | blocks += [TemporalBlock(B, H, P, stride=1, 248 | padding=padding, 249 | dilation=dilation, 250 | norm_type=norm_type, 251 | causal=causal)] 252 | repeats += [nn.Sequential(*blocks)] 253 | temporal_conv_net = nn.Sequential(*repeats) 254 | # [M, B, K] -> [M, C*N, K] 255 | # self.mask_conv1x1 = nn.Conv1d(B, C*N, 1, bias=False) 256 | self.mask_conv1x1 = nn.Conv1d(B, N, 1, bias=False) 257 | # 256 should keep consisten with SPK_EMB_SIZE in config 258 | # Put together 259 | self.network = nn.Sequential(layer_norm, 260 | bottleneck_conv1x1, 261 | temporal_conv_net,) 262 | # mask_conv1x1) 263 | 264 | def forward(self, mixture_w, hidden_outputs): 265 | """ 266 | Keep this API same with TasNet 267 | Args: 268 | mixture_w: [M, N, K], M is batch size 269 | hidden_outputs: [M, C, D] 270 | returns: 271 | est_mask: [M, C, N, K] 272 | """ 273 | B = self.B 274 | M, N, K = mixture_w.size() 275 | _, C, D = hidden_outputs.size() 276 | assert M==_ 277 | original_sep= self.network(mixture_w).unsqueeze(1).expand(M,self.C,B,K) # [M, N, K] -> [M, C, B, K] 278 | hidden_outputs=hidden_outputs.unsqueeze(-1).expand(M,self.C, D, K)# [M,C,D,K] 279 | 
#original_sep=torch.cat((original_sep,hidden_outputs),dim=2).view(-1,B+D,K) #[M*C,(B+D),K] 280 | original_sep = (original_sep * hidden_outputs).contiguous().view(-1,B,K) 281 | score = self.mask_conv1x1(original_sep) # -> [M*C,N, K] 282 | 283 | score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K] 284 | if self.mask_nonlinear == 'softmax': 285 | est_mask = F.softmax(score, dim=1) 286 | elif self.mask_nonlinear == 'relu': 287 | est_mask = F.relu(score) 288 | elif self.mask_nonlinear == 'sigmoid': 289 | est_mask = F.sigmoid(score) 290 | else: 291 | raise ValueError("Unsupported mask non-linear function") 292 | return est_mask 293 | 294 | 295 | class TemporalBlock(nn.Module): 296 | def __init__(self, in_channels, out_channels, kernel_size, 297 | stride, padding, dilation, norm_type="gLN", causal=False): 298 | super(TemporalBlock, self).__init__() 299 | # [M, B, K] -> [M, H, K] 300 | conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) 301 | prelu = nn.PReLU() 302 | norm = chose_norm(norm_type, out_channels) 303 | # [M, H, K] -> [M, B, K] 304 | dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size, 305 | stride, padding, dilation, norm_type, 306 | causal) 307 | # Put together 308 | self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) 309 | 310 | def forward(self, x): 311 | """ 312 | Args: 313 | x: [M, B, K] 314 | Returns: 315 | [M, B, K] 316 | """ 317 | residual = x 318 | out = self.net(x) 319 | # TODO: when P = 3 here works fine, but when P = 2 maybe need to pad? 
320 | return out + residual # look like w/o F.relu is better than w/ F.relu 321 | # return F.relu(out + residual) 322 | 323 | 324 | class DepthwiseSeparableConv(nn.Module): 325 | def __init__(self, in_channels, out_channels, kernel_size, 326 | stride, padding, dilation, norm_type="gLN", causal=False): 327 | super(DepthwiseSeparableConv, self).__init__() 328 | # Use `groups` option to implement depthwise convolution 329 | # [M, H, K] -> [M, H, K] 330 | depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size, 331 | stride=stride, padding=padding, 332 | dilation=dilation, groups=in_channels, 333 | bias=False) 334 | if causal: 335 | chomp = Chomp1d(padding) 336 | prelu = nn.PReLU() 337 | norm = chose_norm(norm_type, in_channels) 338 | # [M, H, K] -> [M, B, K] 339 | pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) 340 | # Put together 341 | if causal: 342 | self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, 343 | pointwise_conv) 344 | else: 345 | self.net = nn.Sequential(depthwise_conv, prelu, norm, 346 | pointwise_conv) 347 | 348 | def forward(self, x): 349 | """ 350 | Args: 351 | x: [M, H, K] 352 | Returns: 353 | result: [M, B, K] 354 | """ 355 | return self.net(x) 356 | 357 | 358 | class Chomp1d(nn.Module): 359 | """To ensure the output length is the same as the input. 360 | """ 361 | def __init__(self, chomp_size): 362 | super(Chomp1d, self).__init__() 363 | self.chomp_size = chomp_size 364 | 365 | def forward(self, x): 366 | """ 367 | Args: 368 | x: [M, H, Kpad] 369 | Returns: 370 | [M, H, K] 371 | """ 372 | return x[:, :, :-self.chomp_size].contiguous() 373 | 374 | 375 | def chose_norm(norm_type, channel_size): 376 | """The input of normlization will be (M, C, K), where M is batch size, 377 | C is channel size and K is sequence length. 
378 | """ 379 | if norm_type == "gLN": 380 | return GlobalLayerNorm(channel_size) 381 | elif norm_type == "cLN": 382 | return ChannelwiseLayerNorm(channel_size) 383 | else: # norm_type == "BN": 384 | # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics 385 | # along M and K, so this BN usage is right. 386 | return nn.BatchNorm1d(channel_size) 387 | 388 | 389 | # TODO: Use nn.LayerNorm to impl cLN to speed up 390 | class ChannelwiseLayerNorm(nn.Module): 391 | """Channel-wise Layer Normalization (cLN)""" 392 | def __init__(self, channel_size): 393 | super(ChannelwiseLayerNorm, self).__init__() 394 | self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] 395 | self.beta = nn.Parameter(torch.Tensor(1, channel_size,1 )) # [1, N, 1] 396 | self.reset_parameters() 397 | 398 | def reset_parameters(self): 399 | self.gamma.data.fill_(1) 400 | self.beta.data.zero_() 401 | 402 | def forward(self, y): 403 | """ 404 | Args: 405 | y: [M, N, K], M is batch size, N is channel size, K is length 406 | Returns: 407 | cLN_y: [M, N, K] 408 | """ 409 | mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K] 410 | var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K] 411 | cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta 412 | return cLN_y 413 | 414 | 415 | class GlobalLayerNorm(nn.Module): 416 | """Global Layer Normalization (gLN)""" 417 | def __init__(self, channel_size): 418 | super(GlobalLayerNorm, self).__init__() 419 | self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] 420 | self.beta = nn.Parameter(torch.Tensor(1, channel_size,1 )) # [1, N, 1] 421 | self.reset_parameters() 422 | 423 | def reset_parameters(self): 424 | self.gamma.data.fill_(1) 425 | self.beta.data.zero_() 426 | 427 | def forward(self, y): 428 | """ 429 | Args: 430 | y: [M, N, K], M is batch size, N is channel size, K is length 431 | Returns: 432 | gLN_y: [M, N, K] 433 | """ 434 | # TODO: in torch 1.0, torch.mean() support dim list 
435 | mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) #[M, 1, 1] 436 | var = (torch.pow(y-mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) 437 | gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta 438 | return gLN_y 439 | 440 | 441 | if __name__ == "__main__": 442 | torch.manual_seed(123) 443 | M, N, L, T = 2, 3, 4, 12 444 | K = 2*T//L-1 445 | B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False 446 | mixture = torch.randint(3, (M, T)) 447 | # test Encoder 448 | encoder = Encoder(L, N) 449 | encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size()) 450 | mixture_w = encoder(mixture) 451 | print('mixture', mixture) 452 | print('U', encoder.conv1d_U.weight) 453 | print('mixture_w', mixture_w) 454 | print('mixture_w size', mixture_w.size()) 455 | 456 | # test TemporalConvNet 457 | separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal) 458 | est_mask = separator(mixture_w) 459 | print('est_mask', est_mask) 460 | print('model', separator) 461 | 462 | # test Decoder 463 | decoder = Decoder(N, L) 464 | est_mask = torch.randint(2, (B, K, C, N)) 465 | est_source = decoder(mixture_w, est_mask) 466 | print('est_source', est_source) 467 | 468 | # test Conv-TasNet 469 | conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type) 470 | est_source = conv_tasnet(mixture) 471 | print('est_source', est_source) 472 | print('est_source size', est_source.size()) 473 | 474 | -------------------------------------------------------------------------------- /models/seq2seq.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | # import data.dict as dict 7 | import models 8 | # from figure_hot import relitu_line 9 | 10 | import numpy as np 11 | 12 | 13 | class seq2seq(nn.Module): 14 | 15 | def 
__init__(self, config, input_emb_size, mix_speech_len, tgt_vocab_size, tgt_dir_vocab_size, use_cuda, pretrain=None, score_fn=''): 16 | super(seq2seq, self).__init__() 17 | if pretrain is not None: 18 | src_embedding = pretrain['src_emb'] 19 | tgt_embedding = pretrain['tgt_emb'] 20 | else: 21 | src_embedding = None 22 | tgt_embedding = None 23 | 24 | self.encoder = models.rnn_encoder(config, input_emb_size) 25 | ## TasNet Encoder 26 | self.TasNetEncoder = models.TasNetEncoder() 27 | self.TasNetDecoder = models.TasNetDecoder() 28 | 29 | if config.shared_vocab == False: 30 | self.decoder = models.rnn_decoder(config, tgt_vocab_size, tgt_dir_vocab_size, embedding=tgt_embedding, score_fn=score_fn) 31 | else: 32 | self.decoder = models.rnn_decoder(config, tgt_vocab_size, tgt_dir_vocab_size, embedding=self.encoder.embedding, score_fn=score_fn) 33 | self.use_cuda = use_cuda 34 | self.tgt_vocab_size = tgt_vocab_size 35 | self.tgt_dir_vocab_size = tgt_dir_vocab_size 36 | self.config = config 37 | self.criterion = models.criterion(tgt_vocab_size, use_cuda, config.loss) 38 | self.criterion_dir = models.criterion_dir(tgt_dir_vocab_size, use_cuda, config.loss) 39 | self.loss_for_ss = nn.MSELoss() 40 | self.loss_for_dir = nn.MSELoss() 41 | self.log_softmax = nn.LogSoftmax() 42 | self.softmax = nn.Softmax() 43 | self.linear_output = nn.Linear(tgt_dir_vocab_size, 1) 44 | self.wav_loss = models.WaveLoss(dBscale=1, nfft=config.FRAME_LENGTH, hop_size=config.FRAME_SHIFT) 45 | 46 | speech_fre = input_emb_size 47 | num_labels = tgt_vocab_size 48 | if config.use_tas: 49 | self.ss_model = models.ConvTasNet() 50 | else: 51 | self.ss_model = models.SS(config, speech_fre, mix_speech_len, num_labels) 52 | 53 | def compute_loss(self, hidden_outputs, hidden_outputs_dir, targets, targets_dir, memory_efficiency): 54 | if memory_efficiency: 55 | return models.memory_efficiency_cross_entropy_loss(hidden_outputs, self.decoder, targets, self.criterion, self.config) 56 | else: 57 | sgm_loss_speaker, 
num_total, num_correct = models.cross_entropy_loss(hidden_outputs, self.decoder, targets, self.criterion, self.config) 58 | #sgm_loss_direction = models.mmse_loss2(hidden_outputs_dir, self.decoder, targets_dir, self.loss_for_dir) 59 | sgm_loss_direction, num_total_dir, num_correct_dir = models.cross_entropy_loss_dir(hidden_outputs_dir, self.decoder, targets_dir, self.criterion_dir, self.config) 60 | print("sgm_loss_speaker:", sgm_loss_speaker) 61 | print("sgm_loss_direction:", sgm_loss_direction) 62 | sgm_loss = sgm_loss_speaker + sgm_loss_direction 63 | return sgm_loss, num_total, num_correct, num_total_dir, num_correct_dir 64 | 65 | def separation_loss(self, x_input_map_multi, masks, y_multi_map, Var='NoItem'): 66 | if not self.config.MLMSE: 67 | return models.ss_loss(self.config, x_input_map_multi, masks, y_multi_map, self.loss_for_ss,self.wav_loss) 68 | else: 69 | return models.ss_loss_MLMSE(self.config, x_input_map_multi, masks, y_multi_map, self.loss_for_ss, Var) 70 | 71 | def separation_tas_loss(self,predict_wav, y_multi_wav,mix_lengths): 72 | return models.ss_tas_loss(self.config, predict_wav, y_multi_wav, mix_lengths,self.loss_for_ss) 73 | 74 | def update_var(self, x_input_map_multi, multi_masks, y_multi_map): 75 | predict_multi_map = torch.mean(multi_masks * x_input_map_multi, -2) # 在时间维度上平均 76 | y_multi_map = torch.mean(Variable(y_multi_map), -2) # 在时间维度上平均 77 | loss_vector = (y_multi_map - predict_multi_map).view(-1, self.config.speech_fre).unsqueeze(-1) # 应该是bs*1*fre 78 | Var = torch.bmm(loss_vector, loss_vector.transpose(1, 2)) 79 | Var = torch.mean(Var, 0) # 在batch的维度上平均 80 | return Var.detach() 81 | 82 | def forward(self, src, tgt, tgt_dir): 83 | #lengths, indices = torch.sort(src_len.squeeze(0), dim=0, descending=True) 84 | # src = torch.index_select(src, dim=0, index=indices) 85 | # tgt = torch.index_select(tgt, dim=0, index=indices) 86 | 87 | mix_wav = src.transpose(0,1) # [batch, sample, channel] 88 | mix = self.TasNetEncoder(mix_wav) # [batch, 
featuremap, timeStep] 89 | mix_infer = mix.transpose(1,2) # [batch, timeStep, featuremap] 90 | _, lengths, _ = mix_infer.size() 91 | # 4 equals to the number of GPU 92 | lengths = Variable(torch.LongTensor(self.config.batch_size/4).zero_() + lengths).unsqueeze(0).cuda() 93 | lengths, indices = torch.sort(lengths.squeeze(0), dim=0, descending=True) 94 | 95 | contexts, state = self.encoder(mix_infer, lengths.data.tolist()) # context [max_len,batch_size,hidden_size×2] 96 | outputs, outputs_dir, final_state, global_embs = self.decoder(tgt[:-1], tgt_dir[:-1], state, contexts.transpose(0, 1)) 97 | 98 | if self.config.use_tas: 99 | predicted_maps = self.ss_model(mix, global_embs.transpose(0,1)) 100 | 101 | predicted_signal = self.TasNetDecoder(mix, predicted_maps) # [batch, spkN, timeStep] 102 | 103 | return outputs, outputs_dir, tgt[1:], tgt_dir[1:], predicted_signal.transpose(0,1) 104 | 105 | def sample(self, src, src_len): 106 | # src=src.squeeze() 107 | if self.use_cuda: 108 | src = src.cuda() 109 | src_len = src_len.cuda() 110 | 111 | lengths, indices = torch.sort(src_len, dim=0, descending=True) 112 | _, ind = torch.sort(indices) 113 | src = Variable(torch.index_select(src, dim=1, index=indices), volatile=True) 114 | bos = Variable(torch.ones(src.size(1)).long().fill_(dict.BOS), volatile=True) 115 | 116 | if self.use_cuda: 117 | bos = bos.cuda() 118 | 119 | contexts, state = self.encoder(src, lengths.tolist()) 120 | sample_ids, final_outputs = self.decoder.sample([bos], state, contexts.transpose(0, 1)) 121 | _, attns_weight = final_outputs 122 | alignments = attns_weight.max(2)[1] 123 | sample_ids = torch.index_select(sample_ids.data, dim=1, index=ind) 124 | alignments = torch.index_select(alignments.data, dim=1, index=ind) 125 | # targets = tgt[1:] 126 | 127 | return sample_ids.t(), alignments.t() 128 | 129 | def beam_sample(self, src, dict_spk2idx, dict_dir2idx, beam_size=1): 130 | 131 | mix_wav = src.transpose(0,1) # [batch, sample] 132 | mix = 
self.TasNetEncoder(mix_wav) # [batch, featuremap, timeStep] 133 | 134 | mix_infer = mix.transpose(1,2) # [batch, timeStep, featuremap] 135 | batch_size, lengths, _ = mix_infer.size() 136 | lengths = Variable(torch.LongTensor(self.config.batch_size).zero_() + lengths).unsqueeze(0).cuda() 137 | lengths, indices = torch.sort(lengths.squeeze(0), dim=0, descending=True) 138 | 139 | contexts, encState = self.encoder(mix_infer, lengths.data.tolist()) # context [max_len,batch_size,hidden_size×2] 140 | 141 | # (1b) Initialize for the decoder. 142 | def var(a): 143 | return Variable(a, volatile=True) 144 | 145 | def rvar(a): 146 | return var(a.repeat(1, beam_size, 1)) 147 | 148 | def bottle(m): 149 | return m.view(batch_size * beam_size, -1) 150 | 151 | def unbottle(m): 152 | return m.view(beam_size, batch_size, -1) 153 | 154 | # Repeat everything beam_size times. 155 | contexts = rvar(contexts.data).transpose(0, 1) 156 | decState = (rvar(encState[0].data), rvar(encState[1].data)) 157 | decState_dir = (rvar(encState[0].data), rvar(encState[1].data)) 158 | # decState.repeat_beam_size_times(beam_size) 159 | beam = [models.Beam(beam_size, dict_spk2idx, n_best=1, 160 | cuda=self.use_cuda) for __ in range(batch_size)] 161 | 162 | beam_dir = [models.Beam(beam_size, dict_dir2idx, n_best=1, 163 | cuda=self.use_cuda) for __ in range(batch_size)] 164 | # (2) run the decoder to generate sentences, using beam search. 
165 | 166 | mask = None 167 | mask_dir = None 168 | soft_score = None 169 | tmp_hiddens = [] 170 | tmp_soft_score = [] 171 | 172 | soft_score_dir = None 173 | tmp_hiddens_dir = [] 174 | tmp_soft_score_dir = [] 175 | output_list = [] 176 | output_dir_list = [] 177 | predicted_list = [] 178 | predicted_dir_list = [] 179 | output_bk_list = [] 180 | output_bk_dir_list = [] 181 | hidden_list = [] 182 | hidden_dir_list = [] 183 | emb_list = [] 184 | emb_dir_list = [] 185 | for i in range(self.config.max_tgt_len): 186 | 187 | if all((b.done() for b in beam)): 188 | break 189 | if all((b_dir.done() for b_dir in beam_dir)): 190 | break 191 | 192 | # Construct batch x beam_size nxt words. 193 | # Get all the pending current beam words and arrange for forward. 194 | inp = var(torch.stack([b.getCurrentState() for b in beam]).t().contiguous().view(-1)) 195 | inp_dir = var(torch.stack([b_dir.getCurrentState() for b_dir in beam_dir]).t().contiguous().view(-1)) 196 | 197 | # Run one step. 198 | output, output_dir, decState, decState_dir, attn_weights, attn_weights_dir, hidden, hidden_dir, emb, emb_dir, output_bk, output_bk_dir = self.decoder.sample_one(inp, inp_dir, soft_score, soft_score_dir, decState, decState_dir, tmp_hiddens, tmp_hiddens_dir, 199 | contexts, mask, mask_dir) 200 | soft_score = F.softmax(output) 201 | soft_score_dir = F.softmax(output_dir) 202 | 203 | predicted = output.max(1)[1] 204 | predicted_dir = output_dir.max(1)[1] 205 | if self.config.mask: 206 | if mask is None: 207 | mask = predicted.unsqueeze(1).long() 208 | mask_dir = predicted_dir.unsqueeze(1).long() 209 | else: 210 | mask = torch.cat((mask, predicted.unsqueeze(1)), 1) 211 | mask_dir = torch.cat((mask_dir, predicted_dir.unsqueeze(1)), 1) 212 | # decOut: beam x rnn_size 213 | 214 | # (b) Compute a vector of batch*beam word scores. 
215 | 216 | output_list.append(output[0]) 217 | output_dir_list.append(output_dir[0]) 218 | 219 | output = unbottle(self.log_softmax(output)) 220 | output_dir = unbottle(F.sigmoid(output_dir)) 221 | 222 | attn = unbottle(attn_weights) 223 | hidden = unbottle(hidden) 224 | emb = unbottle(emb) 225 | attn_dir = unbottle(attn_weights_dir) 226 | hidden_dir = unbottle(hidden_dir) 227 | emb_dir = unbottle(emb_dir) 228 | # beam x tgt_vocab 229 | 230 | output_bk_list.append(output_bk[0]) 231 | output_bk_dir_list.append(output_bk_dir[0]) 232 | hidden_list.append(hidden[0]) 233 | hidden_dir_list.append(hidden_dir[0]) 234 | emb_list.append(emb[0]) 235 | emb_dir_list.append(emb_dir[0]) 236 | 237 | predicted_list.append(predicted) 238 | predicted_dir_list.append(predicted_dir) 239 | 240 | # (c) Advance each beam. 241 | # update state 242 | 243 | for j, b in enumerate(beam): 244 | b.advance(output.data[:, j], attn.data[:, j], hidden.data[:, j], emb.data[:, j]) 245 | b.beam_update(decState, j) # 这个函数更新了原来的decState,只不过不是用return,是直接赋值! 246 | if self.config.ct_recu: 247 | b.beam_update_context(contexts, j) # 这个函数更新了原来的decState,只不过不是用return,是直接赋值! 248 | for i, a in enumerate(beam_dir): 249 | a.advance(output_dir.data[:, i], attn_dir.data[:, i], hidden_dir.data[:, i], emb_dir.data[:, i]) 250 | a.beam_update(decState_dir, i) # 这个函数更新了原来的decState,只不过不是用return,是直接赋值! 251 | if self.config.ct_recu: 252 | a.beam_update_context(contexts, i) # 这个函数更新了原来的decState,只不过不是用return,是直接赋值! 253 | # print "beam after decState:",decState[0].data.cpu().numpy().mean() 254 | 255 | # (3) Package everything up. 
256 | allHyps,allHyps_dir, allScores, allAttn, allHiddens, allEmbs = [],[], [], [], [], [] 257 | 258 | ind = range(batch_size) 259 | for j in ind: 260 | b = beam[j] 261 | c = beam_dir[j] 262 | n_best = 1 263 | scores, ks = b.sortFinished(minimum=n_best) 264 | hyps, hyps_dir, attn, hiddens, embs = [], [], [], [], [] 265 | for i, (times, k) in enumerate(ks[:n_best]): 266 | hyp, att, hidden, emb = b.getHyp(times, k) 267 | hyp_dir, att_dir, hidden_dir, emb_dir = c.getHyp(times, k) 268 | if self.config.relitu: 269 | relitu_line(626, 1, att[0].cpu().numpy()) 270 | relitu_line(626, 1, att[1].cpu().numpy()) 271 | hyps.append(hyp) 272 | attn.append(att.max(1)[1]) 273 | hiddens.append(hidden+hidden_dir) 274 | embs.append(emb+emb_dir) 275 | hyps_dir.append(hyp_dir) 276 | allHyps.append(hyps[0]) 277 | allHyps_dir.append(hyps_dir[0]) 278 | allScores.append(scores[0]) 279 | allAttn.append(attn[0]) 280 | allHiddens.append(hiddens[0]) 281 | allEmbs.append(embs[0]) 282 | 283 | ss_embs = Variable(torch.stack(allEmbs, 0).transpose(0, 1)) # to [decLen, bs, dim] 284 | if self.config.use_tas: 285 | predicted_maps = self.ss_model(mix, ss_embs[1:].transpose(0,1)) 286 | 287 | predicted_signal = self.TasNetDecoder(mix, predicted_maps) # [batch, spkN, timeStep] 288 | return allHyps, allHyps_dir, allAttn, allHiddens, predicted_signal.transpose(0,1), output_list, output_dir_list, output_bk_list, output_bk_dir_list, hidden_list, hidden_dir_list, emb_list, emb_dir_list 289 | -------------------------------------------------------------------------------- /predata_WSJ_lcx.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | import os 3 | import numpy as np 4 | import random 5 | import re 6 | import soundfile as sf 7 | import resampy 8 | import librosa 9 | import argparse 10 | import data.utils as utils 11 | import models 12 | from scipy.io import wavfile 13 | import scipy.signal 14 | # Add the config. 
15 | parser = argparse.ArgumentParser(description='predata scripts.') 16 | parser.add_argument('-config', default='config_WSJ0_Tasnet.yaml', type=str, help="config file") 17 | opt = parser.parse_args() 18 | config = utils.read_config(opt.config) 19 | 20 | channel_first = config.channel_first 21 | np.random.seed(1) 22 | random.seed(1) 23 | 24 | data_path = '/mnt/lustre/xushuang2/lcx/data/amcc-data/2channel' 25 | 26 | def pad_list(xs, pad_value): 27 | n_batch = len(xs) 28 | max_len = max(x.size(0) for x in xs) 29 | pad = xs[0].new(n_batch, max_len, * xs[0].size()[1:]).fill_(pad_value) 30 | for i in range(n_batch): 31 | pad[i, :xs[i].size(0)] = xs[i] 32 | return pad 33 | 34 | def get_energy_order(multi_spk_fea_list): 35 | order=[] 36 | for one_line in multi_spk_fea_list: 37 | dd=sorted(one_line.items(),key= lambda d:d[1].sum(),reverse=True) 38 | dd=[d[0] for d in dd] 39 | order.append(dd) 40 | return order 41 | 42 | def get_spk_order(dir_tgt,raw_tgt): 43 | raw_tgt_dir=[] 44 | i=0 45 | for sample in raw_tgt: 46 | dd = [dir_tgt[i][spk] for spk in sample] 47 | raw_tgt_dir.append(dd) 48 | i=i+1 49 | return raw_tgt_dir 50 | 51 | 52 | def _collate_fn(mix_data,source_data,raw_tgt=None): 53 | """ 54 | Args: 55 | batch: list, len(batch) = 1. 
See AudioDataset.__getitem__() 56 | Returns: 57 | mixtures_pad: B x T, torch.Tensor 58 | ilens : B, torch.Tentor 59 | sources_pad: B x C x T, torch.Tensor 60 | """ 61 | mixtures, sources = mix_data,source_data 62 | if raw_tgt is None: #如果没有给定顺序 63 | raw_tgt = [sorted(spk.keys()) for spk in source_data] 64 | # sources= models.rank_feas(raw_tgt, source_data,out_type='numpy') 65 | sources=[] 66 | for each_feas, each_line in zip(source_data, raw_tgt): 67 | sources.append(np.stack([each_feas[spk] for spk in each_line])) 68 | sources=np.array(sources) 69 | mixtures=np.array(mixtures) 70 | # get batch of lengths of input sequences 71 | ilens = np.array([mix.shape[0] for mix in mixtures]) 72 | 73 | # perform padding and convert to tensor 74 | pad_value = 0 75 | # mixtures_pad = pad_list([mix.float() for mix in mixtures], pad_value) 76 | ilens = ilens 77 | # sources_pad = pad_list([torch.from_numpy(s).float() for s in sources], pad_value) 78 | # N x T x C -> N x C x T 79 | # sources_pad = sources_pad.permute((0, 2, 1)).contiguous() 80 | return mixtures, ilens, sources 81 | # return mixtures_pad, ilens, sources_pad 82 | 83 | def prepare_data(mode, train_or_test, min=None, max=None): 84 | 85 | if min: 86 | config.MIN_MIX = min 87 | if max: 88 | config.MAX_MIX = max 89 | 90 | mix_speechs = [] 91 | aim_fea = [] 92 | aim_spkid = [] 93 | aim_spkname = [] 94 | query = [] 95 | multi_spk_fea_list = [] 96 | multi_spk_wav_list = [] 97 | direction = [] 98 | 99 | if config.MODE == 1: 100 | if config.DATASET == 'WSJ0': 101 | spk_file_tr = open('/mnt/lustre/xushuang2/lcx/data/amcc-data/2channel/wav_spk.txt','r') 102 | all_spk_train = [i.replace("\n","") for i in spk_file_tr] 103 | all_spk_train = sorted(all_spk_train) 104 | print(all_spk_train) 105 | 106 | spk_file_tt = open('/mnt/lustre/xushuang2/lcx/data/amcc-data/2channel/test/wav_spk.txt','r') 107 | all_spk_test = [i.replace("\n","") for i in spk_file_tt] 108 | all_spk_test = sorted(all_spk_test) 109 | print(all_spk_test) 110 | 
all_spk = all_spk_train + all_spk_test 111 | print(all_spk) 112 | 113 | all_dir = [i for i in range(1,20)] 114 | dicDirFile = open('/mnt/lustre/xushuang2/lcx/data/amcc-data/2channel/wav_dirLabel2.txt', 'r')#打开数据 115 | dirDict = {} 116 | while True: 117 | line = dicDirFile.readline() 118 | if line == '': 119 | break 120 | index = line.find(' ') 121 | key = line[:index] 122 | #print(key) 123 | value = line[index:] 124 | dirDict[key] = value.replace("\n","").replace(" ","") 125 | dicDirFile.close() 126 | 127 | spk_samples_list = {} 128 | batch_idx = 0 129 | list_path = '/mnt/lustre/xushuang2/lcx/data/create-speaker-mixtures/' 130 | all_samples_list = {} 131 | sample_idx = {} 132 | number_samples = {} 133 | batch_mix = {} 134 | mix_number_list = range(config.MIN_MIX, config.MAX_MIX + 1) 135 | number_samples_all = 0 136 | for mix_k in mix_number_list: 137 | if train_or_test == 'train': 138 | aim_list_path = list_path + 'mix_{}_spk_tr.txt'.format(mix_k) 139 | if train_or_test == 'valid': 140 | aim_list_path = list_path + 'mix_{}_spk_cv.txt'.format(mix_k) 141 | if train_or_test == 'test': 142 | aim_list_path = list_path + 'mix_{}_spk_tt.txt'.format(mix_k) 143 | config.batch_size = 1 144 | 145 | all_samples_list[mix_k] = open(aim_list_path).readlines() # [:31] 146 | number_samples[mix_k] = len(all_samples_list[mix_k]) 147 | batch_mix[mix_k] = len(all_samples_list[mix_k]) / config.batch_size 148 | number_samples_all += len(all_samples_list[mix_k]) 149 | 150 | sample_idx[mix_k] = 0 151 | 152 | if train_or_test == 'train' and config.SHUFFLE_BATCH: 153 | random.shuffle(all_samples_list[mix_k]) 154 | print('shuffle success!', all_samples_list[mix_k][0]) 155 | 156 | batch_total = number_samples_all / config.batch_size 157 | 158 | mix_k = random.sample(mix_number_list, 1)[0] 159 | # while True: 160 | for ___ in range(number_samples_all): 161 | if ___ == number_samples_all - 1: 162 | print('ends here.___') 163 | yield False 164 | mix_len = 0 165 | if sample_idx[mix_k] >= 
batch_mix[mix_k] * config.batch_size: 166 | mix_number_list.remove(mix_k) 167 | try: 168 | mix_k = random.sample(mix_number_list, 1)[0] 169 | except ValueError: 170 | print('seems there gets all over.') 171 | if len(mix_number_list) == 0: 172 | print('all mix number is over~!') 173 | yield False 174 | 175 | batch_idx = 0 176 | mix_speechs = np.zeros((config.batch_size, config.MAX_LEN)) 177 | mix_feas = [] 178 | mix_phase = [] 179 | aim_fea = [] 180 | aim_spkid = [] 181 | aim_spkname = [] 182 | query = [] 183 | multi_spk_fea_list = [] 184 | multi_spk_order_list=[] 185 | multi_spk_wav_list = [] 186 | continue 187 | 188 | all_over = 1 189 | for kkkkk in mix_number_list: 190 | if not sample_idx[kkkkk] >= batch_mix[mix_k] * config.batch_size: 191 | all_over = 0 192 | break 193 | if all_over: 194 | print('all mix number is over~!') 195 | yield False 196 | 197 | # mix_k=random.sample(mix_number_list,1)[0] 198 | if train_or_test == 'train': 199 | aim_spk_k = random.sample(all_spk_train, mix_k) 200 | elif train_or_test == 'test': 201 | aim_spk_k = random.sample(all_spk_test, mix_k) 202 | 203 | aim_spk_k = re.findall('/([0-9][0-9].)/', all_samples_list[mix_k][sample_idx[mix_k]]) 204 | aim_spk_db_k = [float(dd) for dd in re.findall(' (.*?) 
', all_samples_list[mix_k][sample_idx[mix_k]])] 205 | aim_spk_samplename_k = re.findall('/(.{8})\.wav ', all_samples_list[mix_k][sample_idx[mix_k]]) 206 | assert len(aim_spk_k) == mix_k == len(aim_spk_db_k) == len(aim_spk_samplename_k) 207 | 208 | multi_fea_dict_this_sample = {} 209 | multi_wav_dict_this_sample = {} 210 | multi_name_list_this_sample = [] 211 | multi_db_dict_this_sample = {} 212 | direction_sample = {} 213 | for k, spk in enumerate(aim_spk_k): 214 | 215 | sample_name = aim_spk_samplename_k[k] 216 | if aim_spk_db_k[k] ==0: 217 | aim_spk_db_k[k] = int(aim_spk_db_k[k]) 218 | if train_or_test != 'test': 219 | spk_speech_path = data_path + '/' + 'train' + '/' + sample_name + '_' +str(aim_spk_db_k[k])+ '_simu_nore.wav' 220 | else: 221 | spk_speech_path = data_path + '/' + 'test' + '/' + sample_name + '_' +str(aim_spk_db_k[k])+ '_simu_nore.wav' 222 | 223 | signal, rate = sf.read(spk_speech_path) 224 | 225 | wav_name = sample_name+ '_' +str(aim_spk_db_k[k])+ '_simu_nore.wav' 226 | direction_sample[spk] = dirDict[wav_name] 227 | if rate != config.FRAME_RATE: 228 | print("config.FRAME_RATE",config.FRAME_RATE) 229 | signal = signal.transpose() 230 | signal = resampy.resample(signal, rate, config.FRAME_RATE, filter='kaiser_best') 231 | signal = signal.transpose() 232 | 233 | if signal.shape[0] > config.MAX_LEN: 234 | signal = signal[:config.MAX_LEN,:] 235 | 236 | if signal.shape[0] > mix_len: 237 | mix_len = signal.shape[0] 238 | 239 | signal -= np.mean(signal) 240 | signal /= np.max(np.abs(signal)) 241 | 242 | if signal.shape[0] < config.MAX_LEN: 243 | signal = np.r_[signal, np.zeros((config.MAX_LEN - signal.shape[0],signal.shape[1]))] 244 | 245 | if k == 0: 246 | ratio = 10 ** (aim_spk_db_k[k] / 20.0) 247 | signal = ratio * signal 248 | aim_spkname.append(aim_spk_k[0]) 249 | aim_spk_speech = signal 250 | aim_spkid.append(aim_spkname) 251 | wav_mix = signal 252 | signal_c0 = signal[:,0] 253 | a,b,frq = scipy.signal.stft(signal_c0, fs=8000, 
nfft=config.FRAME_LENGTH, noverlap=config.FRAME_SHIFT) 254 | aim_fea_clean = np.transpose(np.abs(frq)) 255 | aim_fea.append(aim_fea_clean) 256 | multi_fea_dict_this_sample[spk] = aim_fea_clean 257 | multi_wav_dict_this_sample[spk] = signal[:,0] 258 | 259 | else: 260 | ratio = 10 ** (aim_spk_db_k[k] / 20.0) 261 | signal = ratio * signal 262 | wav_mix = wav_mix + signal 263 | a,b,frq = scipy.signal.stft(signal[:,0], fs=8000, nfft=config.FRAME_LENGTH, noverlap=config.FRAME_SHIFT) 264 | some_fea_clean = np.transpose(np.abs(frq)) 265 | multi_fea_dict_this_sample[spk] = some_fea_clean 266 | multi_wav_dict_this_sample[spk] = signal[:,0] 267 | 268 | multi_spk_fea_list.append(multi_fea_dict_this_sample) 269 | multi_spk_wav_list.append(multi_wav_dict_this_sample) 270 | 271 | mix_speechs.append(wav_mix) 272 | direction.append(direction_sample) 273 | batch_idx += 1 274 | 275 | if batch_idx == config.batch_size: 276 | mix_k = random.sample(mix_number_list, 1)[0] 277 | aim_fea = np.array(aim_fea) 278 | query = np.array(query) 279 | print('spk_list_from_this_gen:{}'.format(aim_spkname)) 280 | print('aim spk list:', [one.keys() for one in multi_spk_fea_list]) 281 | batch_ordre=get_energy_order(multi_spk_wav_list) 282 | direction = get_spk_order(direction, batch_ordre) 283 | if mode == 'global': 284 | all_spk = sorted(all_spk) 285 | all_spk = sorted(all_spk_train) 286 | all_spk.insert(0, '') # 添加两个结构符号,来标识开始或结束。 287 | all_spk.append('') 288 | all_dir = sorted(all_dir) 289 | all_dir.insert(0, '') 290 | all_dir.append('') 291 | all_spk_test = sorted(all_spk_test) 292 | dict_spk_to_idx = {spk: idx for idx, spk in enumerate(all_spk)} 293 | dict_idx_to_spk = {idx: spk for idx, spk in enumerate(all_spk)} 294 | dict_dir_to_idx = {dire: idx for idx, dire in enumerate(all_dir)} 295 | dict_idx_to_dir = {idx: dire for idx, dire in enumerate(all_dir)} 296 | yield {'all_spk': all_spk, 297 | 'dict_spk_to_idx': dict_spk_to_idx, 298 | 'dict_idx_to_spk': dict_idx_to_spk, 299 | 'all_dir': all_dir, 
300 | 'dict_dir_to_idx': dict_dir_to_idx, 301 | 'dict_idx_to_dir': dict_idx_to_dir, 302 | 'num_fre': aim_fea.shape[2], 303 | 'num_frames': aim_fea.shape[1], 304 | 'total_spk_num': len(all_spk), 305 | 'total_batch_num': batch_total 306 | } 307 | elif mode == 'once': 308 | yield {'mix_wav': mix_speechs, 309 | 'aim_fea': aim_fea, 310 | 'aim_spkname': aim_spkname, 311 | 'direction': direction, 312 | 'query': query, 313 | 'num_all_spk': len(all_spk), 314 | 'multi_spk_fea_list': multi_spk_fea_list, 315 | 'multi_spk_wav_list': multi_spk_wav_list, 316 | 'batch_order': batch_ordre, 317 | 'batch_total': batch_total, 318 | 'tas_zip': _collate_fn(mix_speechs,multi_spk_wav_list,batch_ordre) 319 | } 320 | elif mode == 'tasnet': 321 | yield _collate_fn(mix_speechs,multi_spk_wav_list) 322 | 323 | batch_idx = 0 324 | mix_speechs = [] 325 | aim_fea = [] 326 | aim_spkid = [] 327 | aim_spkname = [] 328 | query = [] 329 | multi_spk_fea_list = [] 330 | multi_spk_wav_list = [] 331 | direction = [] 332 | sample_idx[mix_k] += 1 333 | 334 | else: 335 | raise ValueError('No such dataset:{} for Speech.'.format(config.DATASET)) 336 | pass 337 | 338 | else: 339 | raise ValueError('No such Model:{}'.format(config.MODE)) 340 | 341 | if __name__ == '__main__': 342 | train_len=[] 343 | train_data_gen = prepare_data('once', 'train') 344 | while True: 345 | train_data_gen.next() 346 | pass 347 | print(np.array(train_len).mean()) 348 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | #TRAIN 4 | python train_WSJ0_SDNet.py 5 | 6 | #TEST 7 | python -u test_WSJ0_SDNet.py 8 | 9 | -------------------------------------------------------------------------------- /separation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Source separation algorithms attempt to extract recordings of individual 4 | 
sources from a recording of a mixture of sources. Evaluation methods for 5 | source separation compare the extracted sources with reference sources and 6 | attempt to measure the perceptual quality of the separation. 7 | 8 | See also the bss_eval MATLAB toolbox: 9 | http://bass-db.gforge.inria.fr/bss_eval/ 10 | 11 | Conventions 12 | ----------- 13 | 14 | An audio signal is expected to be in the format of a 1-dimensional array where 15 | the entries are the samples of the audio signal. When providing a group of 16 | estimated or reference sources, they should be provided in a 2-dimensional 17 | array, where the first dimension corresponds to the source number and the 18 | second corresponds to the samples. 19 | 20 | Metrics 21 | ------- 22 | 23 | * :func:`mir_eval.separation.bss_eval_sources`: Computes the bss_eval_sources 24 | metrics from bss_eval. It optionally matches the estimated sources 25 | to the reference sources using the best permutation, and measures the distortion and artifacts present in 26 | the estimated sources as well as the interference between them. 27 | 28 | * :func:`mir_eval.separation.bss_eval_sources_framewise`: Computes the 29 | bss_eval_sources metrics on a frame-by-frame basis. 30 | 31 | * :func:`mir_eval.separation.bss_eval_images`: Computes the bss_eval_images 32 | metrics from bss_eval, which include the metrics in 33 | :func:`mir_eval.separation.bss_eval_sources` plus the source image to spatial 34 | distortion ratio (ISR). 35 | 36 | * :func:`mir_eval.separation.bss_eval_images_framewise`: Computes the 37 | bss_eval_images metrics on a frame-by-frame basis. 38 | 39 | References 40 | ---------- 41 | .. [#vincent2006performance] Emmanuel Vincent, Rémi Gribonval, and Cédric 42 | Févotte, "Performance measurement in blind audio source separation," IEEE 43 | Trans. on Audio, Speech and Language Processing, 14(4):1462-1469, 2006.
44 | 45 | 46 | ''' 47 | 48 | import numpy as np 49 | import scipy.fftpack 50 | from scipy.linalg import toeplitz 51 | from scipy.signal import fftconvolve 52 | import collections 53 | import itertools 54 | import warnings 55 | 56 | # The maximum allowable number of sources (prevents insane computational load) 57 | MAX_SOURCES = 100 58 | 59 | 60 | def validate(reference_sources, estimated_sources): 61 | """Checks that the input data to a metric are valid, and throws helpful 62 | errors if not. 63 | 64 | Parameters 65 | ---------- 66 | reference_sources : np.ndarray, shape=(nsrc, nsampl) 67 | matrix containing true sources 68 | estimated_sources : np.ndarray, shape=(nsrc, nsampl) 69 | matrix containing estimated sources 70 | 71 | """ 72 | 73 | if reference_sources.shape != estimated_sources.shape: 74 | raise ValueError('The shape of estimated sources and the true ' 75 | 'sources should match. reference_sources.shape ' 76 | '= {}, estimated_sources.shape ' 77 | '= {}'.format(reference_sources.shape, 78 | estimated_sources.shape)) 79 | 80 | if reference_sources.ndim > 3 or estimated_sources.ndim > 3: 81 | raise ValueError('The number of dimensions is too high (must be less ' 82 | 'than 3). reference_sources.ndim = {}, ' 83 | 'estimated_sources.ndim ' 84 | '= {}'.format(reference_sources.ndim, 85 | estimated_sources.ndim)) 86 | 87 | if reference_sources.size == 0: 88 | warnings.warn("reference_sources is empty, should be of size " 89 | "(nsrc, nsample). sdr, sir, sar, and perm will all " 90 | "be empty np.ndarrays") 91 | elif _any_source_silent(reference_sources): 92 | raise ValueError('All the reference sources should be non-silent (not ' 93 | 'all-zeros), but at least one of the reference ' 94 | 'sources is all 0s, which introduces ambiguity to the' 95 | ' evaluation. (Otherwise we can add infinitely many ' 96 | 'all-zero sources.)') 97 | 98 | if estimated_sources.size == 0: 99 | warnings.warn("estimated_sources is empty, should be of size " 100 | "(nsrc, nsample). 
sdr, sir, sar, and perm will all " 101 | "be empty np.ndarrays") 102 | elif _any_source_silent(estimated_sources): 103 | raise ValueError('All the estimated sources should be non-silent (not ' 104 | 'all-zeros), but at least one of the estimated ' 105 | 'sources is all 0s. Since we require each reference ' 106 | 'source to be non-silent, having a silent estimated ' 107 | 'source will result in an underdetermined system.') 108 | 109 | if (estimated_sources.shape[0] > MAX_SOURCES or 110 | reference_sources.shape[0] > MAX_SOURCES): 111 | raise ValueError('The supplied matrices should be of shape (nsrc,' 112 | ' nsampl) but reference_sources.shape[0] = {} and ' 113 | 'estimated_sources.shape[0] = {} which is greater ' 114 | 'than mir_eval.separation.MAX_SOURCES = {}. To ' 115 | 'override this check, set ' 116 | 'mir_eval.separation.MAX_SOURCES to a ' 117 | 'larger value.'.format(reference_sources.shape[0], 118 | estimated_sources.shape[0], 119 | MAX_SOURCES)) 120 | 121 | 122 | def _any_source_silent(sources): 123 | """Returns true if the parameter sources has any silent first dimensions""" 124 | return np.any(np.all(np.sum( 125 | sources, axis=tuple(range(2, sources.ndim))) == 0, axis=1)) 126 | 127 | 128 | def bss_eval_sources(reference_sources, estimated_sources, 129 | compute_permutation=True): 130 | """ 131 | Ordering and measurement of the separation quality for estimated source 132 | signals in terms of filtered true source, interference and artifacts. 133 | 134 | The decomposition allows a time-invariant filter distortion of length 135 | 512, as described in Section III.B of [#vincent2006performance]_. 136 | 137 | Passing ``False`` for ``compute_permutation`` will improve the computation 138 | performance of the evaluation; however, it is not always appropriate and 139 | is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_sources. 
140 | 141 | Examples 142 | -------- 143 | >>> # reference_sources[n] should be an ndarray of samples of the 144 | >>> # n'th reference source 145 | >>> # estimated_sources[n] should be the same for the n'th estimated 146 | >>> # source 147 | >>> (sdr, sir, sar, 148 | ... perm) = mir_eval.separation.bss_eval_sources(reference_sources, 149 | ... estimated_sources) 150 | 151 | Parameters 152 | ---------- 153 | reference_sources : np.ndarray, shape=(nsrc, nsampl) 154 | matrix containing true sources (must have same shape as 155 | estimated_sources) 156 | estimated_sources : np.ndarray, shape=(nsrc, nsampl) 157 | matrix containing estimated sources (must have same shape as 158 | reference_sources) 159 | compute_permutation : bool, optional 160 | compute permutation of estimate/source combinations (True by default) 161 | 162 | Returns 163 | ------- 164 | sdr : np.ndarray, shape=(nsrc,) 165 | vector of Signal to Distortion Ratios (SDR) 166 | sir : np.ndarray, shape=(nsrc,) 167 | vector of Source to Interference Ratios (SIR) 168 | sar : np.ndarray, shape=(nsrc,) 169 | vector of Sources to Artifacts Ratios (SAR) 170 | perm : np.ndarray, shape=(nsrc,) 171 | vector containing the best ordering of estimated sources in 172 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to 173 | true source number ``j``). Note: ``perm`` will be ``[0, 1, ..., 174 | nsrc-1]`` if ``compute_permutation`` is ``False``. 175 | 176 | References 177 | ---------- 178 | .. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau 179 | Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik 180 | Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign 181 | (2007-2010): Achievements and remaining challenges", Signal Processing, 182 | 92, pp. 1928-1936, 2012. 
183 | 184 | """ 185 | 186 | # make sure the input is of shape (nsrc, nsampl) 187 | if estimated_sources.ndim == 1: 188 | estimated_sources = estimated_sources[np.newaxis, :] 189 | if reference_sources.ndim == 1: 190 | reference_sources = reference_sources[np.newaxis, :] 191 | 192 | validate(reference_sources, estimated_sources) 193 | # If empty matrices were supplied, return empty lists (special case) 194 | if reference_sources.size == 0 or estimated_sources.size == 0: 195 | return np.array([]), np.array([]), np.array([]), np.array([]) 196 | 197 | nsrc = estimated_sources.shape[0] 198 | 199 | # does user desire permutations? 200 | if compute_permutation: 201 | # compute criteria for all possible pair matches 202 | sdr = np.empty((nsrc, nsrc)) 203 | sir = np.empty((nsrc, nsrc)) 204 | sar = np.empty((nsrc, nsrc)) 205 | for jest in range(nsrc): 206 | for jtrue in range(nsrc): 207 | s_true, e_spat, e_interf, e_artif = \ 208 | _bss_decomp_mtifilt(reference_sources, 209 | estimated_sources[jest], 210 | jtrue, 512) 211 | sdr[jest, jtrue], sir[jest, jtrue], sar[jest, jtrue] = \ 212 | _bss_source_crit(s_true, e_spat, e_interf, e_artif) 213 | 214 | # select the best ordering 215 | perms = list(itertools.permutations(list(range(nsrc)))) 216 | mean_sir = np.empty(len(perms)) 217 | dum = np.arange(nsrc) 218 | for (i, perm) in enumerate(perms): 219 | mean_sir[i] = np.mean(sir[perm, dum]) 220 | popt = perms[np.argmax(mean_sir)] 221 | idx = (popt, dum) 222 | return (sdr[idx], sir[idx], sar[idx], np.asarray(popt)) 223 | else: 224 | # compute criteria for only the simple correspondence 225 | # (estimate 1 is estimate corresponding to reference source 1, etc.) 
226 | sdr = np.empty(nsrc) 227 | sir = np.empty(nsrc) 228 | sar = np.empty(nsrc) 229 | for j in range(nsrc): 230 | s_true, e_spat, e_interf, e_artif = \ 231 | _bss_decomp_mtifilt(reference_sources, 232 | estimated_sources[j], 233 | j, 512) 234 | sdr[j], sir[j], sar[j] = \ 235 | _bss_source_crit(s_true, e_spat, e_interf, e_artif) 236 | 237 | # return the default permutation for compatibility 238 | popt = np.arange(nsrc) 239 | return (sdr, sir, sar, popt) 240 | 241 | 242 | def bss_eval_sources_framewise(reference_sources, estimated_sources, 243 | window=30 * 44100, hop=15 * 44100, 244 | compute_permutation=False): 245 | """Framewise computation of bss_eval_sources 246 | 247 | Please be aware that this function does not compute permutations (by 248 | default) on the possible relations between reference_sources and 249 | estimated_sources due to the dangers of a changing permutation. Therefore 250 | (by default), it assumes that ``reference_sources[i]`` corresponds to 251 | ``estimated_sources[i]``. To enable computing permutations please set 252 | ``compute_permutation`` to be ``True`` and check that the returned ``perm`` 253 | is identical for all windows. 254 | 255 | NOTE: if ``reference_sources`` and ``estimated_sources`` would be evaluated 256 | using only a single window or are shorter than the window length, the 257 | result of :func:`mir_eval.separation.bss_eval_sources` called on 258 | ``reference_sources`` and ``estimated_sources`` (with the 259 | ``compute_permutation`` parameter passed to 260 | :func:`mir_eval.separation.bss_eval_sources`) is returned. 261 | 262 | Examples 263 | -------- 264 | >>> # reference_sources[n] should be an ndarray of samples of the 265 | >>> # n'th reference source 266 | >>> # estimated_sources[n] should be the same for the n'th estimated 267 | >>> # source 268 | >>> (sdr, sir, sar, 269 | ... perm) = mir_eval.separation.bss_eval_sources_framewise( 270 | reference_sources, 271 | ... 
estimated_sources) 272 | 273 | Parameters 274 | ---------- 275 | reference_sources : np.ndarray, shape=(nsrc, nsampl) 276 | matrix containing true sources (must have the same shape as 277 | ``estimated_sources``) 278 | estimated_sources : np.ndarray, shape=(nsrc, nsampl) 279 | matrix containing estimated sources (must have the same shape as 280 | ``reference_sources``) 281 | window : int, optional 282 | Window length for framewise evaluation (default value is 30s at a 283 | sample rate of 44.1kHz) 284 | hop : int, optional 285 | Hop size for framewise evaluation (default value is 15s at a 286 | sample rate of 44.1kHz) 287 | compute_permutation : bool, optional 288 | compute permutation of estimate/source combinations for all windows 289 | (False by default) 290 | 291 | Returns 292 | ------- 293 | sdr : np.ndarray, shape=(nsrc, nframes) 294 | vector of Signal to Distortion Ratios (SDR) 295 | sir : np.ndarray, shape=(nsrc, nframes) 296 | vector of Source to Interference Ratios (SIR) 297 | sar : np.ndarray, shape=(nsrc, nframes) 298 | vector of Sources to Artifacts Ratios (SAR) 299 | perm : np.ndarray, shape=(nsrc, nframes) 300 | vector containing the best ordering of estimated sources in 301 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to 302 | true source number ``j``). 
Note: ``perm`` will be ``range(nsrc)`` for 303 | all windows if ``compute_permutation`` is ``False`` 304 | 305 | """ 306 | 307 | # make sure the input is of shape (nsrc, nsampl) 308 | if estimated_sources.ndim == 1: 309 | estimated_sources = estimated_sources[np.newaxis, :] 310 | if reference_sources.ndim == 1: 311 | reference_sources = reference_sources[np.newaxis, :] 312 | 313 | validate(reference_sources, estimated_sources) 314 | # If empty matrices were supplied, return empty lists (special case) 315 | if reference_sources.size == 0 or estimated_sources.size == 0: 316 | return np.array([]), np.array([]), np.array([]), np.array([]) 317 | 318 | nsrc = reference_sources.shape[0] 319 | 320 | nwin = int( 321 | np.floor((reference_sources.shape[1] - window + hop) / hop) 322 | ) 323 | # if fewer than 2 windows would be evaluated, return the sources result 324 | if nwin < 2: 325 | result = bss_eval_sources(reference_sources, 326 | estimated_sources, 327 | compute_permutation) 328 | return [np.expand_dims(score, -1) for score in result] 329 | 330 | # compute the criteria across all windows 331 | sdr = np.empty((nsrc, nwin)) 332 | sir = np.empty((nsrc, nwin)) 333 | sar = np.empty((nsrc, nwin)) 334 | perm = np.empty((nsrc, nwin)) 335 | 336 | # k iterates across all the windows 337 | for k in range(nwin): 338 | win_slice = slice(k * hop, k * hop + window) 339 | ref_slice = reference_sources[:, win_slice] 340 | est_slice = estimated_sources[:, win_slice] 341 | # check for a silent frame 342 | if (not _any_source_silent(ref_slice) and 343 | not _any_source_silent(est_slice)): 344 | sdr[:, k], sir[:, k], sar[:, k], perm[:, k] = bss_eval_sources( 345 | ref_slice, est_slice, compute_permutation 346 | ) 347 | else: 348 | # if we have a silent frame set results as np.nan 349 | sdr[:, k] = sir[:, k] = sar[:, k] = perm[:, k] = np.nan 350 | 351 | return sdr, sir, sar, perm 352 | 353 | 354 | def bss_eval_images(reference_sources, estimated_sources, 355 | compute_permutation=True): 
356 | """Implementation of the bss_eval_images function from the 357 | BSS_EVAL Matlab toolbox. 358 | 359 | Ordering and measurement of the separation quality for estimated source 360 | signals in terms of filtered true source, interference and artifacts. 361 | This method also provides the ISR measure. 362 | 363 | The decomposition allows a time-invariant filter distortion of length 364 | 512, as described in Section III.B of [#vincent2006performance]_. 365 | 366 | Passing ``False`` for ``compute_permutation`` will improve the computation 367 | performance of the evaluation; however, it is not always appropriate and 368 | is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_images. 369 | 370 | Examples 371 | -------- 372 | >>> # reference_sources[n] should be an ndarray of samples of the 373 | >>> # n'th reference source 374 | >>> # estimated_sources[n] should be the same for the n'th estimated 375 | >>> # source 376 | >>> (sdr, isr, sir, sar, 377 | ... perm) = mir_eval.separation.bss_eval_images(reference_sources, 378 | ... 
estimated_sources) 379 | 380 | Parameters 381 | ---------- 382 | reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan) 383 | matrix containing true sources 384 | estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan) 385 | matrix containing estimated sources 386 | compute_permutation : bool, optional 387 | compute permutation of estimate/source combinations (True by default) 388 | 389 | Returns 390 | ------- 391 | sdr : np.ndarray, shape=(nsrc,) 392 | vector of Signal to Distortion Ratios (SDR) 393 | isr : np.ndarray, shape=(nsrc,) 394 | vector of source Image to Spatial distortion Ratios (ISR) 395 | sir : np.ndarray, shape=(nsrc,) 396 | vector of Source to Interference Ratios (SIR) 397 | sar : np.ndarray, shape=(nsrc,) 398 | vector of Sources to Artifacts Ratios (SAR) 399 | perm : np.ndarray, shape=(nsrc,) 400 | vector containing the best ordering of estimated sources in 401 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to 402 | true source number ``j``). Note: ``perm`` will be ``(1,2,...,nsrc)`` 403 | if ``compute_permutation`` is ``False``. 404 | 405 | References 406 | ---------- 407 | .. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau 408 | Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik 409 | Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign 410 | (2007-2010): Achievements and remaining challenges", Signal Processing, 411 | 92, pp. 1928-1936, 2012. 
412 | 413 | """ 414 | 415 | # make sure the input has 3 dimensions 416 | # assuming input is in shape (nsampl) or (nsrc, nsampl) 417 | estimated_sources = np.atleast_3d(estimated_sources) 418 | reference_sources = np.atleast_3d(reference_sources) 419 | # we will ensure input doesn't have more than 3 dimensions in validate 420 | 421 | validate(reference_sources, estimated_sources) 422 | # If empty matrices were supplied, return empty lists (special case) 423 | if reference_sources.size == 0 or estimated_sources.size == 0: 424 | return np.array([]), np.array([]), np.array([]), \ 425 | np.array([]), np.array([]) 426 | 427 | # determine size parameters 428 | nsrc = estimated_sources.shape[0] 429 | nsampl = estimated_sources.shape[1] 430 | nchan = estimated_sources.shape[2] 431 | 432 | # does the user desire permutation? 433 | if compute_permutation: 434 | # compute criteria for all possible pair matches 435 | sdr = np.empty((nsrc, nsrc)) 436 | isr = np.empty((nsrc, nsrc)) 437 | sir = np.empty((nsrc, nsrc)) 438 | sar = np.empty((nsrc, nsrc)) 439 | for jest in range(nsrc): 440 | for jtrue in range(nsrc): 441 | s_true, e_spat, e_interf, e_artif = \ 442 | _bss_decomp_mtifilt_images( 443 | reference_sources, 444 | np.reshape( 445 | estimated_sources[jest], 446 | (nsampl, nchan), 447 | order='F' 448 | ), 449 | jtrue, 450 | 512 451 | ) 452 | sdr[jest, jtrue], isr[jest, jtrue], \ 453 | sir[jest, jtrue], sar[jest, jtrue] = \ 454 | _bss_image_crit(s_true, e_spat, e_interf, e_artif) 455 | 456 | # select the best ordering 457 | perms = list(itertools.permutations(range(nsrc))) 458 | mean_sir = np.empty(len(perms)) 459 | dum = np.arange(nsrc) 460 | for (i, perm) in enumerate(perms): 461 | mean_sir[i] = np.mean(sir[perm, dum]) 462 | popt = perms[np.argmax(mean_sir)] 463 | idx = (popt, dum) 464 | return (sdr[idx], isr[idx], sir[idx], sar[idx], np.asarray(popt)) 465 | else: 466 | # compute criteria for only the simple correspondence 467 | # (estimate 1 is estimate corresponding to 
reference source 1, etc.) 468 | sdr = np.empty(nsrc) 469 | isr = np.empty(nsrc) 470 | sir = np.empty(nsrc) 471 | sar = np.empty(nsrc) 472 | Gj = [0] * nsrc # prepare G matrics with zeroes 473 | G = np.zeros(1) 474 | for j in range(nsrc): 475 | # save G matrix to avoid recomputing it every call 476 | s_true, e_spat, e_interf, e_artif, Gj_temp, G = \ 477 | _bss_decomp_mtifilt_images(reference_sources, 478 | np.reshape(estimated_sources[j], 479 | (nsampl, nchan), 480 | order='F'), 481 | j, 512, Gj[j], G) 482 | Gj[j] = Gj_temp 483 | sdr[j], isr[j], sir[j], sar[j] = \ 484 | _bss_image_crit(s_true, e_spat, e_interf, e_artif) 485 | 486 | # return the default permutation for compatibility 487 | popt = np.arange(nsrc) 488 | return (sdr, isr, sir, sar, popt) 489 | 490 | 491 | def bss_eval_images_framewise(reference_sources, estimated_sources, 492 | window=30 * 44100, hop=15 * 44100, 493 | compute_permutation=False): 494 | """Framewise computation of bss_eval_images 495 | 496 | Please be aware that this function does not compute permutations (by 497 | default) on the possible relations between ``reference_sources`` and 498 | ``estimated_sources`` due to the dangers of a changing permutation. 499 | Therefore (by default), it assumes that ``reference_sources[i]`` 500 | corresponds to ``estimated_sources[i]``. To enable computing permutations 501 | please set ``compute_permutation`` to be ``True`` and check that the 502 | returned ``perm`` is identical for all windows. 
503 | 504 | NOTE: if ``reference_sources`` and ``estimated_sources`` would be evaluated 505 | using only a single window or are shorter than the window length, the 506 | result of ``bss_eval_sources`` called on ``reference_sources`` and 507 | ``estimated_sources`` (with the ``compute_permutation`` parameter passed to 508 | ``bss_eval_images``) is returned 509 | 510 | Examples 511 | -------- 512 | >>> # reference_sources[n] should be an ndarray of samples of the 513 | >>> # n'th reference source 514 | >>> # estimated_sources[n] should be the same for the n'th estimated 515 | >>> # source 516 | >>> (sdr, isr, sir, sar, 517 | ... perm) = mir_eval.separation.bss_eval_images_framewise( 518 | reference_sources, 519 | ... estimated_sources, 520 | window, 521 | .... hop) 522 | 523 | Parameters 524 | ---------- 525 | reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan) 526 | matrix containing true sources (must have the same shape as 527 | ``estimated_sources``) 528 | estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan) 529 | matrix containing estimated sources (must have the same shape as 530 | ``reference_sources``) 531 | window : int 532 | Window length for framewise evaluation 533 | hop : int 534 | Hop size for framewise evaluation 535 | compute_permutation : bool, optional 536 | compute permutation of estimate/source combinations for all windows 537 | (False by default) 538 | 539 | Returns 540 | ------- 541 | sdr : np.ndarray, shape=(nsrc, nframes) 542 | vector of Signal to Distortion Ratios (SDR) 543 | isr : np.ndarray, shape=(nsrc, nframes) 544 | vector of source Image to Spatial distortion Ratios (ISR) 545 | sir : np.ndarray, shape=(nsrc, nframes) 546 | vector of Source to Interference Ratios (SIR) 547 | sar : np.ndarray, shape=(nsrc, nframes) 548 | vector of Sources to Artifacts Ratios (SAR) 549 | perm : np.ndarray, shape=(nsrc, nframes) 550 | vector containing the best ordering of estimated sources in 551 | the mean SIR sense (estimated source number 
perm[j] corresponds to 552 | true source number j) 553 | Note: perm will be range(nsrc) for all windows if compute_permutation 554 | is False 555 | 556 | """ 557 | 558 | # make sure the input has 3 dimensions 559 | # assuming input is in shape (nsampl) or (nsrc, nsampl) 560 | estimated_sources = np.atleast_3d(estimated_sources) 561 | reference_sources = np.atleast_3d(reference_sources) 562 | # we will ensure input doesn't have more than 3 dimensions in validate 563 | 564 | validate(reference_sources, estimated_sources) 565 | # If empty matrices were supplied, return empty lists (special case) 566 | if reference_sources.size == 0 or estimated_sources.size == 0: 567 | return np.array([]), np.array([]), np.array([]), np.array([]) 568 | 569 | nsrc = reference_sources.shape[0] 570 | 571 | nwin = int( 572 | np.floor((reference_sources.shape[1] - window + hop) / hop) 573 | ) 574 | # if fewer than 2 windows would be evaluated, return the images result 575 | if nwin < 2: 576 | result = bss_eval_images(reference_sources, 577 | estimated_sources, 578 | compute_permutation) 579 | return [np.expand_dims(score, -1) for score in result] 580 | 581 | # compute the criteria across all windows 582 | sdr = np.empty((nsrc, nwin)) 583 | isr = np.empty((nsrc, nwin)) 584 | sir = np.empty((nsrc, nwin)) 585 | sar = np.empty((nsrc, nwin)) 586 | perm = np.empty((nsrc, nwin)) 587 | 588 | # k iterates across all the windows 589 | for k in range(nwin): 590 | win_slice = slice(k * hop, k * hop + window) 591 | ref_slice = reference_sources[:, win_slice, :] 592 | est_slice = estimated_sources[:, win_slice, :] 593 | # check for a silent frame 594 | if (not _any_source_silent(ref_slice) and 595 | not _any_source_silent(est_slice)): 596 | sdr[:, k], isr[:, k], sir[:, k], sar[:, k], perm[:, k] = \ 597 | bss_eval_images( 598 | ref_slice, est_slice, compute_permutation 599 | ) 600 | else: 601 | # if we have a silent frame set results as np.nan 602 | sdr[:, k] = sir[:, k] = sar[:, k] = perm[:, k] = np.nan 
603 | 604 | return sdr, isr, sir, sar, perm 605 | 606 | 607 | def _bss_decomp_mtifilt(reference_sources, estimated_source, j, flen): 608 | """Decomposition of an estimated source image into four components 609 | representing respectively the true source image, spatial (or filtering) 610 | distortion, interference and artifacts, derived from the true source 611 | images using multichannel time-invariant filters. 612 | """ 613 | nsampl = estimated_source.size 614 | # decomposition 615 | # true source image 616 | s_true = np.hstack((reference_sources[j], np.zeros(flen - 1))) 617 | # spatial (or filtering) distortion 618 | e_spat = _project(reference_sources[j, np.newaxis, :], estimated_source, 619 | flen) - s_true 620 | # interference 621 | e_interf = _project(reference_sources, 622 | estimated_source, flen) - s_true - e_spat 623 | # artifacts 624 | e_artif = -s_true - e_spat - e_interf 625 | e_artif[:nsampl] += estimated_source 626 | return (s_true, e_spat, e_interf, e_artif) 627 | 628 | 629 | def _bss_decomp_mtifilt_images(reference_sources, estimated_source, j, flen, 630 | Gj=None, G=None): 631 | """Decomposition of an estimated source image into four components 632 | representing respectively the true source image, spatial (or filtering) 633 | distortion, interference and artifacts, derived from the true source 634 | images using multichannel time-invariant filters. 635 | Adapted version to work with multichannel sources. 636 | Improved performance can be gained by passing Gj and G parameters initially 637 | as all zeros. These parameters store the results from the computation of 638 | the G matrix in _project_images and then return them for subsequent calls 639 | to this function. This only works when not computing permuations. 640 | """ 641 | nsampl = np.shape(estimated_source)[0] 642 | nchan = np.shape(estimated_source)[1] 643 | # are we saving the Gj and G parameters? 
644 | saveg = Gj is not None and G is not None 645 | # decomposition 646 | # true source image 647 | s_true = np.hstack((np.reshape(reference_sources[j], 648 | (nsampl, nchan), 649 | order="F").transpose(), 650 | np.zeros((nchan, flen - 1)))) 651 | # spatial (or filtering) distortion 652 | if saveg: 653 | e_spat, Gj = _project_images(reference_sources[j, np.newaxis, :], 654 | estimated_source, flen, Gj) 655 | else: 656 | e_spat = _project_images(reference_sources[j, np.newaxis, :], 657 | estimated_source, flen) 658 | e_spat = e_spat - s_true 659 | # interference 660 | if saveg: 661 | e_interf, G = _project_images(reference_sources, 662 | estimated_source, flen, G) 663 | else: 664 | e_interf = _project_images(reference_sources, 665 | estimated_source, flen) 666 | e_interf = e_interf - s_true - e_spat 667 | # artifacts 668 | e_artif = -s_true - e_spat - e_interf 669 | e_artif[:, :nsampl] += estimated_source.transpose() 670 | # return Gj and G only if they were passed in 671 | if saveg: 672 | return (s_true, e_spat, e_interf, e_artif, Gj, G) 673 | else: 674 | return (s_true, e_spat, e_interf, e_artif) 675 | 676 | 677 | def _project(reference_sources, estimated_source, flen): 678 | """Least-squares projection of estimated source on the subspace spanned by 679 | delayed versions of reference sources, with delays between 0 and flen-1 680 | """ 681 | nsrc = reference_sources.shape[0] 682 | nsampl = reference_sources.shape[1] 683 | 684 | # computing coefficients of least squares problem via FFT ## 685 | # zero padding and FFT of input data 686 | reference_sources = np.hstack((reference_sources, 687 | np.zeros((nsrc, flen - 1)))) 688 | estimated_source = np.hstack((estimated_source, np.zeros(flen - 1))) 689 | n_fft = int(2 ** np.ceil(np.log2(nsampl + flen - 1.))) 690 | sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1) 691 | sef = scipy.fftpack.fft(estimated_source, n=n_fft) 692 | # inner products between delayed versions of reference_sources 693 | G = 
np.zeros((nsrc * flen, nsrc * flen)) 694 | for i in range(nsrc): 695 | for j in range(nsrc): 696 | ssf = sf[i] * np.conj(sf[j]) 697 | ssf = np.real(scipy.fftpack.ifft(ssf)) 698 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), 699 | r=ssf[:flen]) 700 | G[i * flen: (i + 1) * flen, j * flen: (j + 1) * flen] = ss 701 | G[j * flen: (j + 1) * flen, i * flen: (i + 1) * flen] = ss.T 702 | # inner products between estimated_source and delayed versions of 703 | # reference_sources 704 | D = np.zeros(nsrc * flen) 705 | for i in range(nsrc): 706 | ssef = sf[i] * np.conj(sef) 707 | ssef = np.real(scipy.fftpack.ifft(ssef)) 708 | D[i * flen: (i + 1) * flen] = np.hstack((ssef[0], ssef[-1:-flen:-1])) 709 | 710 | # Computing projection 711 | # Distortion filters 712 | try: 713 | C = np.linalg.solve(G, D).reshape(flen, nsrc, order='F') 714 | except np.linalg.linalg.LinAlgError: 715 | C = np.linalg.lstsq(G, D)[0].reshape(flen, nsrc, order='F') 716 | # Filtering 717 | sproj = np.zeros(nsampl + flen - 1) 718 | for i in range(nsrc): 719 | sproj += fftconvolve(C[:, i], reference_sources[i])[:nsampl + flen - 1] 720 | return sproj 721 | 722 | 723 | def _project_images(reference_sources, estimated_source, flen, G=None): 724 | """Least-squares projection of estimated source on the subspace spanned by 725 | delayed versions of reference sources, with delays between 0 and flen-1. 726 | Passing G as all zeros will populate the G matrix and return it so it can 727 | be passed into the next call to avoid recomputing G (this will only works 728 | if not computing permutations). 
729 | """ 730 | nsrc = reference_sources.shape[0] 731 | nsampl = reference_sources.shape[1] 732 | nchan = reference_sources.shape[2] 733 | reference_sources = np.reshape(np.transpose(reference_sources, (2, 0, 1)), 734 | (nchan * nsrc, nsampl), order='F') 735 | 736 | # computing coefficients of least squares problem via FFT ## 737 | # zero padding and FFT of input data 738 | reference_sources = np.hstack((reference_sources, 739 | np.zeros((nchan * nsrc, flen - 1)))) 740 | estimated_source = \ 741 | np.hstack((estimated_source.transpose(), np.zeros((nchan, flen - 1)))) 742 | n_fft = int(2 ** np.ceil(np.log2(nsampl + flen - 1.))) 743 | sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1) 744 | sef = scipy.fftpack.fft(estimated_source, n=n_fft) 745 | 746 | # inner products between delayed versions of reference_sources 747 | if G is None: 748 | saveg = False 749 | G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen)) 750 | for i in range(nchan * nsrc): 751 | for j in range(i + 1): 752 | ssf = sf[i] * np.conj(sf[j]) 753 | ssf = np.real(scipy.fftpack.ifft(ssf)) 754 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), 755 | r=ssf[:flen]) 756 | G[i * flen: (i + 1) * flen, j * flen: (j + 1) * flen] = ss 757 | G[j * flen: (j + 1) * flen, i * flen: (i + 1) * flen] = ss.T 758 | else: # avoid recomputing G (only works if no permutation is desired) 759 | saveg = True # return G 760 | if np.all(G == 0): # only compute G if passed as 0 761 | G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen)) 762 | for i in range(nchan * nsrc): 763 | for j in range(i + 1): 764 | ssf = sf[i] * np.conj(sf[j]) 765 | ssf = np.real(scipy.fftpack.ifft(ssf)) 766 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])), 767 | r=ssf[:flen]) 768 | G[i * flen: (i + 1) * flen, j * flen: (j + 1) * flen] = ss 769 | G[j * flen: (j + 1) * flen, i * flen: (i + 1) * flen] = ss.T 770 | 771 | # inner products between estimated_source and delayed versions of 772 | # reference_sources 773 | D = 
np.zeros((nchan * nsrc * flen, nchan)) 774 | for k in range(nchan * nsrc): 775 | for i in range(nchan): 776 | ssef = sf[k] * np.conj(sef[i]) 777 | ssef = np.real(scipy.fftpack.ifft(ssef)) 778 | D[k * flen: (k + 1) * flen, i] = \ 779 | np.hstack((ssef[0], ssef[-1:-flen:-1])).transpose() 780 | 781 | # Computing projection 782 | # Distortion filters 783 | try: 784 | C = np.linalg.solve(G, D).reshape(flen, nchan * nsrc, nchan, order='F') 785 | except np.linalg.linalg.LinAlgError: 786 | C = np.linalg.lstsq(G, D)[0].reshape(flen, nchan * nsrc, nchan, 787 | order='F') 788 | # Filtering 789 | sproj = np.zeros((nchan, nsampl + flen - 1)) 790 | for k in range(nchan * nsrc): 791 | for i in range(nchan): 792 | sproj[i] += fftconvolve(C[:, k, i].transpose(), 793 | reference_sources[k])[:nsampl + flen - 1] 794 | # return G only if it was passed in 795 | if saveg: 796 | return sproj, G 797 | else: 798 | return sproj 799 | 800 | 801 | def _bss_source_crit(s_true, e_spat, e_interf, e_artif): 802 | """Measurement of the separation quality for a given source in terms of 803 | filtered true source, interference and artifacts. 804 | """ 805 | # energy ratios 806 | s_filt = s_true + e_spat 807 | sdr = _safe_db(np.sum(s_filt ** 2), np.sum((e_interf + e_artif) ** 2)) 808 | sir = _safe_db(np.sum(s_filt ** 2), np.sum(e_interf ** 2)) 809 | sar = _safe_db(np.sum((s_filt + e_interf) ** 2), np.sum(e_artif ** 2)) 810 | return (sdr, sir, sar) 811 | 812 | 813 | def _bss_image_crit(s_true, e_spat, e_interf, e_artif): 814 | """Measurement of the separation quality for a given image in terms of 815 | filtered true source, spatial error, interference and artifacts. 
816 | """ 817 | # energy ratios 818 | sdr = _safe_db(np.sum(s_true ** 2), np.sum((e_spat + e_interf + e_artif) ** 2)) 819 | isr = _safe_db(np.sum(s_true ** 2), np.sum(e_spat ** 2)) 820 | sir = _safe_db(np.sum((s_true + e_spat) ** 2), np.sum(e_interf ** 2)) 821 | sar = _safe_db(np.sum((s_true + e_spat + e_interf) ** 2), np.sum(e_artif ** 2)) 822 | return (sdr, isr, sir, sar) 823 | 824 | 825 | def _safe_db(num, den): 826 | """Properly handle the potential +Inf db SIR, instead of raising a 827 | RuntimeWarning. Only denominator is checked because the numerator can never 828 | be 0. 829 | """ 830 | if den == 0: 831 | return np.Inf 832 | return 10 * np.log10(num / den) 833 | 834 | 835 | def evaluate(reference_sources, estimated_sources, **kwargs): 836 | """Compute all metrics for the given reference and estimated signals. 837 | 838 | NOTE: This will always compute :func:`mir_eval.separation.bss_eval_images` 839 | for any valid input and will additionally compute 840 | :func:`mir_eval.separation.bss_eval_sources` for valid input with fewer 841 | than 3 dimensions. 842 | 843 | Examples 844 | -------- 845 | >>> # reference_sources[n] should be an ndarray of samples of the 846 | >>> # n'th reference source 847 | >>> # estimated_sources[n] should be the same for the n'th estimated source 848 | >>> scores = mir_eval.separation.evaluate(reference_sources, 849 | ... estimated_sources) 850 | 851 | Parameters 852 | ---------- 853 | reference_sources : np.ndarray, shape=(nsrc, nsampl[, nchan]) 854 | matrix containing true sources 855 | estimated_sources : np.ndarray, shape=(nsrc, nsampl[, nchan]) 856 | matrix containing estimated sources 857 | kwargs 858 | Additional keyword arguments which will be passed to the 859 | appropriate metric or preprocessing functions. 860 | 861 | Returns 862 | ------- 863 | scores : dict 864 | Dictionary of scores, where the key is the metric name (str) and 865 | the value is the (float) score achieved. 
866 | 867 | """ 868 | # Compute all the metrics 869 | scores = collections.OrderedDict() 870 | 871 | sdr, isr, sir, sar, perm = util.filter_kwargs( 872 | bss_eval_images, 873 | reference_sources, 874 | estimated_sources, 875 | **kwargs 876 | ) 877 | scores['Images - Source to Distortion'] = sdr.tolist() 878 | scores['Images - Image to Spatial'] = isr.tolist() 879 | scores['Images - Source to Interference'] = sir.tolist() 880 | scores['Images - Source to Artifact'] = sar.tolist() 881 | scores['Images - Source permutation'] = perm.tolist() 882 | 883 | sdr, isr, sir, sar, perm = util.filter_kwargs( 884 | bss_eval_images_framewise, 885 | reference_sources, 886 | estimated_sources, 887 | **kwargs 888 | ) 889 | scores['Images Frames - Source to Distortion'] = sdr.tolist() 890 | scores['Images Frames - Image to Spatial'] = isr.tolist() 891 | scores['Images Frames - Source to Interference'] = sir.tolist() 892 | scores['Images Frames - Source to Artifact'] = sar.tolist() 893 | scores['Images Frames - Source permutation'] = perm.tolist() 894 | 895 | # Verify we can compute sources on this input 896 | if reference_sources.ndim < 3 and estimated_sources.ndim < 3: 897 | sdr, sir, sar, perm = util.filter_kwargs( 898 | bss_eval_sources_framewise, 899 | reference_sources, 900 | estimated_sources, 901 | **kwargs 902 | ) 903 | scores['Sources Frames - Source to Distortion'] = sdr.tolist() 904 | scores['Sources Frames - Source to Interference'] = sir.tolist() 905 | scores['Sources Frames - Source to Artifact'] = sar.tolist() 906 | scores['Sources Frames - Source permutation'] = perm.tolist() 907 | 908 | sdr, sir, sar, perm = util.filter_kwargs( 909 | bss_eval_sources, 910 | reference_sources, 911 | estimated_sources, 912 | **kwargs 913 | ) 914 | scores['Sources - Source to Distortion'] = sdr.tolist() 915 | scores['Sources - Source to Interference'] = sir.tolist() 916 | scores['Sources - Source to Artifact'] = sar.tolist() 917 | scores['Sources - Source permutation'] = perm.tolist() 
918 | 919 | return scores 920 | -------------------------------------------------------------------------------- /test_WSJ0_SDNet.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | import os 3 | import argparse 4 | import time 5 | import json 6 | import collections 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | import torch.utils.data 12 | import numpy as np 13 | 14 | import models 15 | import data.utils as utils 16 | from optims import Optim 17 | import lr_scheduler as L 18 | from predata_WSJ_lcx import prepare_data 19 | import bss_test 20 | from models.loss import ss_tas_loss 21 | from scipy.io import wavfile 22 | # config 23 | parser = argparse.ArgumentParser(description='train_WSJ_tasnet.py') 24 | 25 | parser.add_argument('-config', default='config_WSJ0_SDNet.yaml', type=str, help="config file") 26 | parser.add_argument('-gpus', default=[3], nargs='+', type=int, help="Use CUDA on the listed devices.") 27 | parser.add_argument('-restore', default='', type=str, help="restore checkpoint") 28 | parser.add_argument('-seed', type=int, default=1234, help="Random seed") 29 | parser.add_argument('-model', default='seq2seq', type=str, help="Model selection") 30 | parser.add_argument('-score', default='', type=str, help="score_fn") 31 | parser.add_argument('-notrain', default=True, type=bool, help="train or not") 32 | parser.add_argument('-log', default='', type=str, help="log directory") 33 | parser.add_argument('-memory', default=False, type=bool, help="memory efficiency") 34 | parser.add_argument('-score_fc', default='', type=str, help="memory efficiency") 35 | 36 | opt = parser.parse_args() 37 | config = utils.read_config(opt.config) 38 | torch.manual_seed(opt.seed) 39 | 40 | # checkpoint 41 | if opt.restore: 42 | print('loading checkpoint...\n', opt.restore) 43 | checkpoints = torch.load(opt.restore,map_location={'cuda:2':'cuda:0'}) 44 | 45 | # cuda 46 | use_cuda = 
torch.cuda.is_available() and len(opt.gpus) > 0 47 | use_cuda = True 48 | if use_cuda: 49 | torch.cuda.set_device(opt.gpus[0]) 50 | torch.cuda.manual_seed(opt.seed) 51 | print(use_cuda) 52 | 53 | # load the global statistic of the data 54 | print('loading data...\n') 55 | start_time = time.time() 56 | 57 | spk_global_gen = prepare_data(mode='global', train_or_test='train') # 数据中的一些统计参数的读取 58 | global_para = next(spk_global_gen) 59 | print(global_para) 60 | 61 | spk_all_list = global_para['all_spk'] # 所有说话人的列表 62 | dict_spk2idx = global_para['dict_spk_to_idx'] 63 | dict_idx2spk = global_para['dict_idx_to_spk'] 64 | direction_all_list = global_para['all_dir'] 65 | dict_dir2idx = global_para['dict_dir_to_idx'] 66 | dict_idx2dir = global_para['dict_idx_to_dir'] 67 | speech_fre = global_para['num_fre'] # 语音频率总数 68 | total_frames = global_para['num_frames'] # 语音长度 69 | spk_num_total = global_para['total_spk_num'] # 总计说话人数目 70 | batch_total = global_para['total_batch_num'] # 一个epoch里多少个batch 71 | 72 | print(dict_idx2spk) 73 | print(dict_idx2dir) 74 | 75 | config.speech_fre = speech_fre 76 | mix_speech_len = total_frames 77 | config.mix_speech_len = total_frames 78 | num_labels = len(spk_all_list) 79 | num_dir_labels = len(direction_all_list) 80 | 81 | del spk_global_gen 82 | print('loading the global setting cost: %.3f' % (time.time() - start_time)) 83 | print("num_dir_labels", num_dir_labels) 84 | print("num_labels", num_labels) 85 | # model 86 | print('building model...\n') 87 | model = getattr(models, opt.model)(config, 256, mix_speech_len, num_labels, num_dir_labels, use_cuda, None, opt.score_fc) 88 | 89 | if opt.restore: 90 | model.load_state_dict(checkpoints['model']) 91 | if use_cuda: 92 | model.cuda() 93 | if len(opt.gpus) > 1: 94 | model = nn.DataParallel(model, device_ids=opt.gpus, dim=1) 95 | 96 | # optimizer 97 | if 0 and opt.restore: 98 | optim = checkpoints['optim'] 99 | else: 100 | optim = Optim(config.optim, config.learning_rate, config.max_grad_norm, 101 
| lr_decay=config.learning_rate_decay, start_decay_at=config.start_decay_at) 102 | 103 | optim.set_parameters(model.parameters()) 104 | 105 | if config.schedule: 106 | # scheduler = L.CosineAnnealingLR(optim.optimizer, T_max=config.epoch) 107 | scheduler = L.StepLR(optim.optimizer, step_size=20, gamma=0.2) 108 | 109 | # total number of parameters 110 | param_count = 0 111 | for param in model.parameters(): 112 | param_count += param.view(-1).size()[0] 113 | 114 | # logging modeule 115 | if not os.path.exists(config.log): 116 | os.mkdir(config.log) 117 | if opt.log == '': 118 | log_path = config.log + utils.format_time(time.localtime()) + '/' 119 | else: 120 | log_path = config.log + opt.log + '/' 121 | if not os.path.exists(log_path): 122 | os.mkdir(log_path) 123 | print('log_path:',log_path) 124 | 125 | logging = utils.logging(log_path + 'log.txt') # 单独写一个logging的函数,直接调用,既print,又记录到Log文件里。 126 | logging_csv = utils.logging_csv(log_path + 'record.csv') 127 | for k, v in config.items(): 128 | logging("%s:\t%s\n" % (str(k), str(v))) 129 | logging("\n") 130 | logging(repr(model) + "\n\n") 131 | 132 | logging('total number of parameters: %d\n\n' % param_count) 133 | logging('score function is %s\n\n' % opt.score) 134 | 135 | if opt.restore: 136 | updates = checkpoints['updates'] 137 | else: 138 | updates = 0 139 | 140 | total_loss, start_time = 0, time.time() 141 | total_loss_sgm, total_loss_ss = 0, 0 142 | report_total, report_correct = 0, 0 143 | report_vocab, report_tot_vocab = 0, 0 144 | scores = [[] for metric in config.metric] 145 | scores = collections.OrderedDict(zip(config.metric, scores)) 146 | best_SDR = 0.0 147 | e=0 148 | loss_last_epoch = 1000000.0 149 | 150 | def eval(epoch): 151 | # config.batch_size=1 152 | model.eval() 153 | # print '\n\n测试的时候请设置config里的batch_size为1!!!please set the batch_size as 1' 154 | reference, candidate, source, alignments = [], [], [], [] 155 | e = epoch 156 | test_or_valid = 'test' 157 | #test_or_valid = 'valid' 158 | 
print('Test or valid:', test_or_valid) 159 | eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX, config.MAX_MIX) 160 | SDR_SUM = np.array([]) 161 | SDRi_SUM = np.array([]) 162 | SISNR_SUM = np.array([]) 163 | SISNRI_SUM = np.array([]) 164 | SS_SUM = np.array([]) 165 | batch_idx = 0 166 | global best_SDR, Var 167 | f = open('./results/spk2.txt', 'a') 168 | f_dir = open('./results/dir2.txt', 'a') 169 | f_bk = open('./results/spk_bk.txt', 'a') 170 | f_bk_dir = open('./results/dir_bk.txt', 'a') 171 | f_emb = open('./results/spk_emb.txt', 'a') 172 | f_emb_dir = open('./results/dir_emb.txt', 'a') 173 | f_hidden = open('./results/spk_hidden.txt', 'a') 174 | f_hidden_dir = open('./results/dir_hidden.txt', 'a') 175 | while True: 176 | print('-' * 30) 177 | eval_data =next(eval_data_gen) 178 | if eval_data == False: 179 | print('SDR_aver_eval_epoch:', SDR_SUM.mean()) 180 | print('SDRi_aver_eval_epoch:', SDRi_SUM.mean()) 181 | print('SISNR_aver_eval_epoch:', SISNR_SUM.mean()) 182 | print('SISNRI_aver_eval_epoch:', SISNRI_SUM.mean()) 183 | print('SS_aver_eval_epoch:', SS_SUM.mean()) 184 | break # 如果这个epoch的生成器没有数据了,直接进入下一个epoch 185 | 186 | raw_tgt= eval_data['batch_order'] 187 | 188 | padded_mixture, mixture_lengths, padded_source = eval_data['tas_zip'] 189 | padded_mixture=torch.from_numpy(padded_mixture).float() 190 | mixture_lengths=torch.from_numpy(mixture_lengths) 191 | padded_source=torch.from_numpy(padded_source).float() 192 | 193 | padded_mixture = padded_mixture.cuda().transpose(0,1) 194 | mixture_lengths = mixture_lengths.cuda() 195 | padded_source = padded_source.cuda() 196 | 197 | top_k = len(raw_tgt[0]) 198 | tgt = Variable(torch.ones(top_k + 2, config.batch_size)) 199 | src_len = Variable(torch.LongTensor(config.batch_size).zero_() + mix_speech_len).unsqueeze(0) 200 | tgt_len = Variable(torch.LongTensor([len(one_spk) for one_spk in eval_data['multi_spk_fea_list']])).unsqueeze(0) 201 | 202 | if use_cuda: 203 | tgt = tgt.cuda() 204 | src_len = 
src_len.cuda() 205 | tgt_len = tgt_len.cuda() 206 | 207 | if 1 and len(opt.gpus) > 1: 208 | samples, samples_dir, alignment, hiddens, predicted_masks, output_list, output_dir_list, output_bk_list, output_dir_bk_list, hidden_list, hidden_dir_list, emb_list, emb_dir_list = model.module.beam_sample(padded_mixture, dict_spk2idx, dict_dir2idx, config.beam_size) 209 | else: 210 | samples, samples_dir, alignment, hiddens, predicted_masks, output_list, output_dir_list, output_bk_list, output_dir_bk_list, hidden_list, hidden_dir_list, emb_list, emb_dir_list = model.beam_sample(padded_mixture, dict_spk2idx,dict_dir2idx, config.beam_size) 211 | 212 | predicted_masks = predicted_masks.transpose(0,1) 213 | predicted_masks = predicted_masks[:,0:top_k,:] 214 | mixture = torch.chunk(padded_mixture, 2, dim=-1) 215 | padded_mixture_c0 = mixture[0].squeeze() 216 | 217 | padded_source1= padded_source.data.cpu() 218 | predicted_masks1 = predicted_masks.data.cpu() 219 | 220 | padded_source= padded_source.squeeze().data.cpu().numpy() 221 | padded_mixture = padded_mixture.squeeze().data.cpu().numpy() 222 | predicted_masks = predicted_masks.squeeze().data.cpu().numpy() 223 | padded_mixture_c0 = padded_mixture_c0.squeeze().data.cpu().numpy() 224 | mixture_lengths = mixture_lengths.cpu() 225 | 226 | predicted_masks = predicted_masks - np.mean(predicted_masks) 227 | predicted_masks /= np.max(np.abs(predicted_masks)) 228 | 229 | # ''''' 230 | if batch_idx <= (3000 / config.batch_size): # only the former batches counts the SDR 231 | 232 | sisnr, sisnri = bss_test.cal_SISNRi_PIT(padded_source, predicted_masks,padded_mixture_c0) 233 | sdr, sdri = bss_test.cal_SDRi(padded_source,predicted_masks, padded_mixture_c0) 234 | loss = ss_tas_loss(config,predicted_masks1, padded_source1, mixture_lengths, True) 235 | loss = loss.numpy() 236 | try: 237 | #SDR_SUM,SDRi_SUM = np.append(SDR_SUM, bss_test.cal('batch_output1/')) 238 | SDR_SUM = np.append(SDR_SUM, sdr) 239 | SDRi_SUM = np.append(SDRi_SUM, sdri) 
240 | 241 | SISNR_SUM = np.append(SISNR_SUM, sisnr) 242 | SISNRI_SUM = np.append(SISNRI_SUM, sisnri) 243 | SS_SUM = np.append(SS_SUM, loss) 244 | except:# AssertionError,wrong_info: 245 | print('Errors in calculating the SDR',wrong_info) 246 | print('SDR_aver_now:', SDR_SUM.mean()) 247 | print('SDRi_aver_now:', SDRi_SUM.mean()) 248 | print('SISNR_aver_now:', SISNR_SUM.mean()) 249 | print('SISNRI_aver_now:', SISNRI_SUM.mean()) 250 | print('SS_aver_now:', SS_SUM.mean()) 251 | 252 | elif batch_idx == (3000 / config.batch_size) + 1 and SDR_SUM.mean() > best_SDR: # only record the best SDR once. 253 | print('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean())) 254 | best_SDR = SDR_SUM.mean() 255 | 256 | # ''' 257 | candidate += [convertToLabels(dict_idx2spk, s, dict_spk2idx['']) for s in samples] 258 | # source += raw_src 259 | reference += raw_tgt 260 | print('samples:', samples) 261 | print('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:], reference[-1 * config.batch_size:])) 262 | alignments += [align for align in alignment] 263 | batch_idx += 1 264 | f.close() 265 | f_dir.close() 266 | score = {} 267 | result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path) 268 | logging_csv([e, updates, result['hamming_loss'], \ 269 | result['micro_f1'], result['micro_precision'], result['micro_recall']]) 270 | print('hamming_loss: %.8f | micro_f1: %.4f' 271 | % (result['hamming_loss'], result['micro_f1'])) 272 | score['hamming_loss'] = result['hamming_loss'] 273 | score['micro_f1'] = result['micro_f1'] 274 | return score 275 | 276 | 277 | # Convert `idx` to labels. If index `stop` is reached, convert it and return. 
def convertToLabels(dict, idx, stop):
    """Convert the index sequence *idx* into labels, stopping before *stop*.

    Args:
        dict: mapping from integer index to label. The name shadows the
            builtin; it is kept so existing callers are unaffected.
        idx: iterable of indices; each element is passed through ``int()``.
        stop: index that ends the conversion. It is NOT included in the
            result.

    Returns:
        list of labels for the indices preceding the first ``stop``.
    """
    labels = []
    for i in idx:
        i = int(i)
        if i == stop:
            break
        labels.append(dict[i])
    return labels


def save_model(path):
    """Save model weights, config, optimizer and update counter to *path*."""
    # Unwrap nn.DataParallel so the checkpoint loads into a bare model.
    model_state_dict = model.module.state_dict() if len(opt.gpus) > 1 else model.state_dict()
    checkpoints = {
        'model': model_state_dict,
        'config': config,
        'optim': optim,
        'updates': updates}

    torch.save(checkpoints, path)


def main():
    """Run one evaluation pass and log the best recorded score per metric."""
    eval(1)
    for metric in config.metric:
        # Nothing visible in eval() appends to `scores`, so a list may be
        # empty; max() of an empty sequence raises ValueError.
        if scores[metric]:
            logging("Best %s score: %.2f\n" % (metric, max(scores[metric])))


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
# /train_WSJ0_SDNet.py:
# --------------------------------------------------------------------------------
# coding=utf8
import os
import argparse
import time
import json
import collections

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data
import numpy as np

import models
import data.utils as utils
from optims import Optim
import lr_scheduler as L
from predata_WSJ_lcx import prepare_data
import bss_test
# config
parser = argparse.ArgumentParser(description='train_WSJ_tasnet.py')

parser.add_argument('-config', default='config_WSJ0_SDNet.yaml', type=str, help="config file")
parser.add_argument('-gpus', default=[0], nargs='+', type=int, help="Use CUDA on the listed devices.")
parser.add_argument('-restore', default='', type=str, help="restore checkpoint")
parser.add_argument('-seed', type=int, default=1234, help="Random seed")
parser.add_argument('-model', default='seq2seq', type=str, help="Model selection")
parser.add_argument('-score', default='', type=str, help="score_fn")
def _str2bool(value):
    """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0').

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``-memory False`` used to switch the option ON.
    An empty string still maps to False, as ``bool('')`` did.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser.add_argument('-notrain', default=False, type=_str2bool, help="train or not")
parser.add_argument('-log', default='', type=str, help="log directory")
parser.add_argument('-memory', default=False, type=_str2bool, help="memory efficiency")
# NOTE(review): this help text looks copy-pasted from -memory; confirm intent.
parser.add_argument('-score_fc', default='', type=str, help="memory efficiency")

opt = parser.parse_args()
config = utils.read_config(opt.config)
torch.manual_seed(opt.seed)

# checkpoint
if opt.restore:
    print('loading checkpoint...\n', opt.restore)
    checkpoints = torch.load(opt.restore, map_location={'cuda:2': 'cuda:0'})

# cuda
use_cuda = torch.cuda.is_available() and len(opt.gpus) > 0
# NOTE: CUDA is forced on regardless of availability; the training loop also
# calls .cuda() unconditionally, so this script requires a GPU as written.
use_cuda = True
if use_cuda:
    torch.cuda.set_device(opt.gpus[0])
    torch.cuda.manual_seed(opt.seed)
print(use_cuda)

# load the global statistics of the data
print('loading data...\n')
start_time = time.time()

spk_global_gen = prepare_data(mode='global', train_or_test='train')  # reads dataset-wide statistics
global_para = next(spk_global_gen)
print(global_para)

spk_all_list = global_para['all_spk']  # list of all speakers
dict_spk2idx = global_para['dict_spk_to_idx']
dict_idx2spk = global_para['dict_idx_to_spk']
direction_all_list = global_para['all_dir']
dict_dir2idx = global_para['dict_dir_to_idx']
dict_idx2dir = global_para['dict_idx_to_dir']
speech_fre = global_para['num_fre']  # number of frequency bins
total_frames = global_para['num_frames']  # speech length (frames)
spk_num_total = global_para['total_spk_num']  # total number of speakers
batch_total = global_para['total_batch_num']  # number of batches per epoch

print(dict_idx2spk)
print(dict_idx2dir)

config.speech_fre = speech_fre
mix_speech_len = total_frames
config.mix_speech_len = total_frames
num_labels = len(spk_all_list)
num_dir_labels = len(direction_all_list)

del spk_global_gen
print('loading the global setting cost: %.3f' % (time.time() - start_time))
print("num_dir_labels", num_dir_labels)
print("num_labels", num_labels)

# model
print('building model...\n')
model = getattr(models, opt.model)(config, 256, mix_speech_len, num_labels, num_dir_labels, use_cuda, None, opt.score_fc)

if opt.restore:
    model.load_state_dict(checkpoints['model'])
if use_cuda:
    model.cuda()
if len(opt.gpus) > 1:
    model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)

# optimizer (restoring the optimizer state is deliberately disabled by `0 and`)
if 0 and opt.restore:
    optim = checkpoints['optim']
else:
    optim = Optim(config.optim, config.learning_rate, config.max_grad_norm,
                  lr_decay=config.learning_rate_decay, start_decay_at=config.start_decay_at)

optim.set_parameters(model.parameters())

if config.schedule:
    # scheduler = L.CosineAnnealingLR(optim.optimizer, T_max=config.epoch)
    scheduler = L.StepLR(optim.optimizer, step_size=20, gamma=0.2)

# total number of parameters
param_count = sum(param.numel() for param in model.parameters())

# logging module
os.makedirs(config.log, exist_ok=True)
if opt.log == '':
    log_path = config.log + utils.format_time(time.localtime()) + '/'
else:
    log_path = config.log + opt.log + '/'
os.makedirs(log_path, exist_ok=True)
print('log_path:', log_path)

# utils.logging returns a function that both prints and writes to the log file.
logging = utils.logging(log_path + 'log.txt')
logging_csv = utils.logging_csv(log_path + 'record.csv')
for k, v in config.items():
    logging("%s:\t%s\n" % (str(k), str(v)))
logging("\n")
logging(repr(model) + "\n\n")

logging('total number of parameters: %d\n\n' % param_count)
logging('score function is %s\n\n' % opt.score)

if opt.restore:
    updates = checkpoints['updates']
else:
    updates = 0

total_loss, start_time = 0, time.time()
total_loss_sgm, total_loss_ss = 0, 0
report_total, report_correct = 0, 0
report_vocab, report_tot_vocab = 0, 0
scores = collections.OrderedDict((metric, []) for metric in config.metric)
best_SDR = 0.0
e = 0
loss_last_epoch = 1000000.0


def _core_model():
    """Return the underlying model, unwrapping nn.DataParallel when used."""
    return model.module if len(opt.gpus) > 1 else model


def train(epoch):
    """Run one training epoch over the 'train' split.

    Each batch is optimized on ``4 * SGM loss + separation loss``. The
    separation loss is summed over the epoch; when that total is lower than
    the best epoch total so far, a checkpoint is written to
    ``log_path + 'DSNet_<epoch>_<total>.pt'``.
    """
    global e, updates, total_loss, start_time, report_total, report_correct, \
        total_loss_sgm, total_loss_ss, loss_last_epoch
    e = epoch
    model.train()
    total_loss_final = 0

    if config.schedule and scheduler.get_lr()[0] > 5e-5:
        scheduler.step()
        print("Decaying learning rate to %g" % scheduler.get_lr()[0])

    if opt.model == 'gated':
        model.current_epoch = epoch

    core = _core_model()
    train_data_gen = prepare_data('once', 'train')
    while True:
        train_data = next(train_data_gen)
        if train_data == False:  # the generator yields False when the epoch's data is exhausted
            print("ss loss (SISNR) in trainset:", total_loss_final)
            break

        raw_tgt = train_data['batch_order']
        raw_tgt_dir = train_data['direction']

        padded_mixture, mixture_lengths, padded_source = train_data['tas_zip']
        # .cuda() is unconditional here, as in the rest of this script.
        padded_mixture = torch.from_numpy(padded_mixture).float().cuda().transpose(0, 1)
        mixture_lengths = torch.from_numpy(mixture_lengths).cuda()
        padded_source = torch.from_numpy(padded_source).float().cuda()

        # Targets must be LongTensors: index rows [0, labels..., padding],
        # transposed to (time, batch). np.int64 replaces the removed np.int alias.
        # NOTE(review): the padding/end key is the empty string here; verify it
        # matches the end token defined by prepare_data's dictionaries.
        tgt_max_len = config.MAX_MIX + 2  # with bos and eos.
        tgt = torch.from_numpy(np.array(
            [[0] + [dict_spk2idx[spk] for spk in spks]
             + (tgt_max_len - len(spks) - 1) * [dict_spk2idx['']]
             for spks in raw_tgt], dtype=np.int64)).transpose(0, 1)
        tgt_dir = torch.from_numpy(np.array(
            [[0] + [dict_dir2idx[int(dire)] for dire in directs]
             + (tgt_max_len - len(directs) - 1) * [dict_dir2idx['']]
             for directs in raw_tgt_dir], dtype=np.int64)).transpose(0, 1)
        if use_cuda:
            tgt = tgt.cuda()
            tgt_dir = tgt_dir.cuda()

        model.zero_grad()

        # `outputs` are the hidden outputs before the final classification layer.
        outputs, outputs_dir, targets, targets_dir, multi_mask = model(padded_mixture, tgt, tgt_dir)
        multi_mask = multi_mask.transpose(0, 1)

        sgm_loss, num_total, num_correct, num_total_dir, num_correct_dir = core.compute_loss(
            outputs, outputs_dir, targets, targets_dir, opt.memory)
        print('loss for SGM,this batch:', sgm_loss.item())
        print("num_total", num_total)
        print("num_correct", num_correct)

        # NOTE: ss_loss is only bound when config.use_tas is true; the lines
        # below require it.
        if config.use_tas:
            ss_loss = core.separation_tas_loss(multi_mask, padded_source, mixture_lengths)

        print('loss for SS,this batch:', ss_loss.item())

        loss = 4 * sgm_loss + ss_loss

        loss.backward()
        total_loss_sgm += sgm_loss.item()
        total_loss_ss += ss_loss.item()
        total_loss_final += ss_loss.item()

        report_correct += num_correct.item()
        report_total += num_total
        optim.step()
        updates += 1

        if updates % 30 == 0:
            logging(
                "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f,label acc: %6.6f\n"
                % (time.time() - start_time, epoch, updates, loss / num_total, total_loss_sgm / 30.0,
                   total_loss_ss / 30.0, report_correct / report_total))
            total_loss_sgm, total_loss_ss = 0, 0

    # Checkpoint once per epoch whenever the epoch's summed SS loss improves.
    if total_loss_final < loss_last_epoch:
        loss_last_epoch = total_loss_final
        save_model(log_path + 'DSNet_{}_{}.pt'.format(epoch, total_loss_final))


# Convert `idx` to labels, stopping (exclusively) at the first index equal to `stop`.
def convertToLabels(dict, idx, stop):
    """Return labels for the indices in *idx* that precede the first *stop*.

    *dict* maps int index -> label (name shadows the builtin; kept for
    backward compatibility). ``stop`` itself is not converted.
    """
    labels = []
    for i in idx:
        i = int(i)
        if i == stop:
            break
        labels.append(dict[i])
    return labels


def save_model(path):
    """Save model weights, config, optimizer and update counter to *path*."""
    checkpoints = {
        'model': _core_model().state_dict(),
        'config': config,
        'optim': optim,
        'updates': updates}

    torch.save(checkpoints, path)


def main():
    """Train for ``config.epoch`` epochs, then log the best recorded scores."""
    for i in range(1, config.epoch + 1):
        train(i)
    for metric in config.metric:
        # train() never appends to `scores`; max() of an empty list would raise
        # ValueError after the final epoch.
        if scores[metric]:
            logging("Best %s score: %.2f\n" % (metric, max(scores[metric])))


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------