├── README.md
├── bss_test.py
├── config_WSJ0_SDNet.yaml
├── data
│   ├── Readme
│   ├── __init__.py
│   ├── dataloader.py
│   ├── dict.py
│   └── utils.py
├── jpg
│   ├── Readme
│   └── sdnet.jpeg
├── lr_scheduler.py
├── models
│   ├── Readme
│   ├── Schmidt_orth.py
│   ├── WaveLoss.py
│   ├── __init__.py
│   ├── attention.py
│   ├── beam.py
│   ├── focal_loss.py
│   ├── istft_irfft.py
│   ├── loss.py
│   ├── metrics.py
│   ├── rnn.py
│   ├── separation_dis.py
│   ├── separation_tasnet.py
│   └── seq2seq.py
├── predata_WSJ_lcx.py
├── run.sh
├── separation.py
├── test_WSJ0_SDNet.py
└── train_WSJ0_SDNet.py
/README.md:
--------------------------------------------------------------------------------
1 | # ICASSP 2021: SDNet: Speaker and Direction Inferred Dual-channel Speech Separation
2 |
3 | If you are interested in our work, or use this code or part of it, please cite us!
4 | Consider citing:
5 | ```bibtex
6 | @inproceedings{li2021speaker,
7 | title={Speaker and Direction Inferred Dual-Channel Speech Separation},
8 | author={Li, Chenxing and Xu, Jiaming and Mesgarani, Nima and Xu, Bo},
9 | booktitle={ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
10 | pages={5779--5783},
11 | year={2021},
12 | organization={IEEE}
13 | }
14 | ```
15 | For a more detailed description, please see the full paper via [this link](https://doi.org/10.1109/ICASSP39728.2021.9413818).
16 |
17 | # Requirements:
18 | PyTorch >= 1.1.0
19 | resampy
20 | soundfile
21 |
22 | # Model Descriptions:
23 | ![SDNet model architecture](jpg/sdnet.jpeg)
24 |
25 |
26 |
27 | # Data Preparation
28 |
29 | Please refer to predata_WSJ_lcx.py.
30 | A more detailed dataset preparation procedure will be added soon.
31 |
32 | # Train and Test
33 |
34 | For training:
35 | python train_WSJ0_SDNet.py
36 |
37 | For testing:
38 | python test_WSJ0_SDNet.py
39 |
40 | Please modify the model path in test_WSJ0_SDNet.py before running the test.
41 |
42 | # Contact
43 | If you have any questions, please contact:
44 | Email: lichenxing007@gmail.com
45 |
46 | # TODO
47 | 1. A brief implementation of SDNet.
48 | 2. Pretrained models.
49 | 3. Separated samples.
50 |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/bss_test.py:
--------------------------------------------------------------------------------
1 | #coding=utf8
2 | import numpy as np
3 | import os
4 | import soundfile as sf
5 | from separation import bss_eval_sources
6 |
7 | path='batch_output/'
8 | # path='/home/sw/Shin/Codes/DL4SS_Keras/TDAA_beta/batch_output2/'
9 | def cal_SDRi(src_ref, src_est, mix):
10 | """Calculate Source-to-Distortion Ratio improvement (SDRi).
11 | NOTE: bss_eval_sources is very very slow.
12 | Args:
13 | src_ref: numpy.ndarray, [C, T]
14 | src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
15 | mix: numpy.ndarray, [T]
16 | Returns:
17 | average_SDRi
18 | """
19 | src_anchor = np.stack([mix, mix], axis=0)
20 | if src_ref.shape[0]==1:
21 | src_anchor=src_anchor[0]
22 | sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
23 | sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
24 | avg_SDR = ((sdr[0]) + (sdr[1])) / 2
25 | avg_SDRi = ((sdr[0]-sdr0[0]) + (sdr[1]-sdr0[1])) / 2
26 | return avg_SDR, avg_SDRi
27 |
28 |
29 | def cal_SISNRi(src_ref, src_est, mix):
30 | """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi)
31 | Args:
32 | src_ref: numpy.ndarray, [C, T]
33 | src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
34 | mix: numpy.ndarray, [T]
35 | Returns:
36 | average_SISNRi
37 | """
38 |
39 | sisnr1 = cal_SISNR(src_ref[0], src_est[0])
40 | sisnr2 = cal_SISNR(src_ref[1], src_est[1])
41 | sisnr1b = cal_SISNR(src_ref[0], mix)
42 | sisnr2b = cal_SISNR(src_ref[1], mix)
43 | avg_SISNRi = ((sisnr1 - sisnr1b) + (sisnr2 - sisnr2b)) / 2
44 | return avg_SISNRi
45 |
46 | def cal_SISNRi_PIT(src_ref, src_est, mix):
47 | """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi) 2-mix
48 | Args:
49 | src_ref: numpy.ndarray, [C, T]
50 | src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
51 | mix: numpy.ndarray, [T]
52 | Returns:
53 | average_SISNRi
54 | """
55 |
56 | sisnr1_a = cal_SISNR(src_ref[0,:], src_est[0,:])
57 | sisnr2_a = cal_SISNR(src_ref[1,:], src_est[1,:])
58 |
59 | sisnr1_b = cal_SISNR(src_ref[0,:], src_est[1,:])
60 | sisnr2_b = cal_SISNR(src_ref[1,:], src_est[0,:])
61 |
62 | sisnr1_o = cal_SISNR(src_ref[0,:], mix)
63 | sisnr2_o = cal_SISNR(src_ref[1,:], mix)
64 |
65 | avg_SISNR = max((sisnr1_a+sisnr2_a)/2, (sisnr1_b+sisnr2_b)/2)
66 | avg_SISNRi = max( ((sisnr1_a - sisnr1_o) + (sisnr2_a - sisnr2_o)) / 2 ,((sisnr1_b - sisnr1_o) + (sisnr2_b - sisnr2_o)) / 2 )
67 | return avg_SISNR, avg_SISNRi
68 |
69 | def cal_SISNR(ref_sig, out_sig, eps=1e-8):
70 | """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
71 | Args:
72 | ref_sig: numpy.ndarray, [T]
73 | out_sig: numpy.ndarray, [T]
74 | Returns:
75 | SISNR
76 | """
77 | assert len(ref_sig) == len(out_sig)
78 | ref_sig = ref_sig - np.mean(ref_sig)
79 | out_sig = out_sig - np.mean(out_sig)
80 | ref_energy = np.sum(ref_sig ** 2) + eps
81 | proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy
82 | noise = out_sig - proj
83 | ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps)
84 | sisnr = 10 * np.log(ratio + eps) / np.log(10.0)
85 | return sisnr
86 |
87 | def cal(path,tmp=None):
88 | mix_number=len(set([l.split('_')[0] for l in os.listdir(path) if l[-3:]=='wav']))
89 | print('num of mixed :',mix_number)
90 | SDR_sum=np.array([])
91 | SDRi_sum=np.array([])
92 | for idx in range(mix_number):
93 | pre_speech_channel=[]
94 | aim_speech_channel=[]
95 | mix_speech=[]
96 | for l in sorted(os.listdir(path)):
97 | if l[-3:]!='wav':
98 | continue
99 | if l.split('_')[0]==str(idx):
100 | if 'True_mix' in l:
101 | mix_speech.append(sf.read(path+l)[0])
102 | if 'real' in l and 'noise' not in l:
103 | aim_speech_channel.append(sf.read(path+l)[0])
104 | if 'pre' in l:
105 | pre_speech_channel.append(sf.read(path+l)[0])
106 |
107 | assert len(aim_speech_channel)==len(pre_speech_channel)
108 | aim_speech_channel=np.array(aim_speech_channel)
109 | pre_speech_channel=np.array(pre_speech_channel)
110 | mix_speech=np.array(mix_speech)
111 | assert mix_speech.shape[0]==1
112 | mix_speech=mix_speech[0]
113 |
114 | result=bss_eval_sources(aim_speech_channel,pre_speech_channel)
115 | SDR_sum=np.append(SDR_sum,result[0])
116 |
117 | # result=bss_eval_sources(aim_speech_channel,aim_speech_channel)
118 | # result_sdri=cal_SDRi(aim_speech_channel,pre_speech_channel,mix_speech)
119 | # print 'SDRi:',result_sdri
120 | result=cal_SISNRi(aim_speech_channel,pre_speech_channel,mix_speech)
121 | print('SI-SNR',result)
122 | # for ii in range(aim_speech_channel.shape[0]):
123 | # result=cal_SISNRi(aim_speech_channel[ii],pre_speech_channel[ii],mix_speech[ii])
124 | # print('SI-SNR',result)
125 | # SDRi_sum=np.append(SDRi_sum,result_sdri)
126 |
127 | print('SDR_Aver for this batch:',SDR_sum.mean())
128 | # print 'SDRi_Aver for this batch:',SDRi_sum.mean()
129 | return SDR_sum.mean(),SDRi_sum.mean()
130 |
131 | # cal(path)
132 |
133 |
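A minimal sanity-check sketch for the SI-SNR helpers above (hypothetical, not part of the original script; uses synthetic two-speaker signals):

```python
# Hypothetical sanity check for cal_SISNR / cal_SISNRi_PIT using synthetic signals.
import numpy as np
from bss_test import cal_SISNR, cal_SISNRi_PIT

rng = np.random.RandomState(0)
ref = rng.randn(2, 8000)               # two reference sources, [C, T]
est = ref + 0.1 * rng.randn(2, 8000)   # slightly noisy estimates of each source
mix = ref.sum(axis=0)                  # single-channel mixture, [T]

print(cal_SISNR(ref[0], est[0]))       # SI-SNR for one source (should be high)
print(cal_SISNRi_PIT(ref, est, mix))   # (best-permutation SI-SNR, SI-SNRi)
```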
--------------------------------------------------------------------------------
/config_WSJ0_SDNet.yaml:
--------------------------------------------------------------------------------
1 | log: './log/'
2 | epoch: 300
3 | batch_size: 1
4 | param_init: 0.1
5 | optim: 'adam'
6 | loss: 'focal_loss'
7 | use_center_loss: 0
8 | learning_rate: 0.001
9 | max_grad_norm: 5
10 | learning_rate_decay: 0.5
11 |
12 | mask: 1
13 | schedule: 1
14 | bidirec: True
15 | start_decay_at: 5
16 | emb_size: 256
17 | encoder_hidden_size: 256
18 | decoder_hidden_size: 512
19 | num_layers: 4
20 | dropout: 0.5
21 | max_tgt_len: 5
22 | eval_interval: 1000
23 | save_interval: 1000
24 | max_generator_batches: 32
25 | metric: ['hamming_loss', 'micro_f1']
26 | shared_vocab: 0
27 | WFM: 1
28 | MLMSE: 0
29 | beam_size: 5
30 | tmp_score: 0
31 | top1: 0
32 | ct_recu: 0
33 |
34 | use_tas: 1
35 | all_soft: 0
36 |
37 | global_emb: 1
38 | global_hidden: 0
39 | SPK_EMB_SIZE: 256
40 | schmidt: 0
41 | unit_norm: 1
42 | reID: 0
43 | is_SelfTune : 0
44 | is_dis: 0
45 | speech_cnn_net: 0
46 | relitu: 0
47 | ALPHA: 0.5
48 | quchong_alpha: 1
49 |
50 | #Minimum number of mixed speakers for training
51 | MIN_MIX: 2
52 | #Maximum number of mixed speakers for training
53 | MAX_MIX: 2
54 | MODE: 1
55 | DATASET : 'WSJ0'
56 | is_ComlexMask: 1
57 | num_samples_one_epoch: 20000
58 |
59 | Ground_truth: 1
60 | Comm_with_Memory: 0
61 | HIDDEN_UNITS: 300
62 | NUM_LAYERS: 3
63 | EMBEDDING_SIZE: 50
64 |
65 | ATT_SIZE: 100
66 | AUGMENT_DATA: 0
67 | MAX_EPOCH: 600
68 | EPOCH_SIZE: 600
69 | FRAME_RATE: 8000
70 | FRAME_LENGTH: 256
71 | FRAME_SHIFT: 64
72 | SHUFFLE_BATCH: 1
73 | voice_dB: 2.5
74 | noise_dB: -10
75 | normalize: 1
76 | MIN_LEN: 24000
77 | MAX_LEN: 24000
78 | WINDOWS: FRAME_LENGTH
79 | START_EALY_STOP: 0
80 | IS_LOG_SPECTRAL : 0
81 | channel_first: 1
82 |
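As a rough sketch of how these settings reach the code, the YAML can be loaded into the attribute-style dict returned by `data/utils.read_config` (assuming a PyYAML version where `yaml.load` without an explicit `Loader` still works):

```python
# Sketch: load config_WSJ0_SDNet.yaml through the AttrDict helper in data/utils.py.
from data.utils import read_config

config = read_config('config_WSJ0_SDNet.yaml')
print(config.batch_size, config.FRAME_RATE, config.MAX_MIX)  # 1 8000 2
```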
--------------------------------------------------------------------------------
/data/Readme:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from models.attention import *
2 | from models.rnn import *
3 | from models.seq2seq import *
4 | from models.loss import *
5 | from models.beam import *
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as torch_data
3 | import os
4 | import data.utils
5 |
6 | class dataset(torch_data.Dataset):
7 |
8 | def __init__(self, src, tgt, raw_src, raw_tgt):
9 |
10 | self.src = src
11 | self.tgt = tgt
12 | self.raw_src = raw_src
13 | self.raw_tgt = raw_tgt
14 |
15 | def __getitem__(self, index):
16 | return self.src[index], self.tgt[index], \
17 | self.raw_src[index], self.raw_tgt[index]
18 |
19 | def __len__(self):
20 | return len(self.src)
21 |
22 |
23 | def load_dataset(path):
24 | pass
25 |
26 | def save_dataset(dataset, path):
27 | if not os.path.exists(path):
28 | os.mkdir(path)
29 |
30 |
31 | def padding(data):
32 | #data.sort(key=lambda x: len(x[0]), reverse=True)
33 | src, tgt, raw_src, raw_tgt = zip(*data)
34 |
35 | src_len = [len(s) for s in src]
36 | src_pad = torch.zeros(len(src), max(src_len)).long()
37 | for i, s in enumerate(src):
38 | end = src_len[i]
39 | src_pad[i, :end] = s[:end]
40 |
41 | tgt_len = [len(s) for s in tgt]
42 | tgt_pad = torch.zeros(len(tgt), max(tgt_len)).long()
43 | for i, s in enumerate(tgt):
44 | end = tgt_len[i]
45 | tgt_pad[i, :end] = s[:end]
46 | #tgt_len = [length-1 for length in tgt_len]
47 |
48 | #return src_pad.t(), src_len, tgt_pad.t(), tgt_len
49 | return raw_src, src_pad.t(), torch.LongTensor(src_len), \
50 | raw_tgt, tgt_pad.t(), torch.LongTensor(tgt_len)
51 |
52 |
53 | def get_loader(dataset, batch_size, shuffle, num_workers):
54 |
55 | data_loader = torch.utils.data.DataLoader(dataset=dataset,
56 | batch_size=batch_size,
57 | shuffle=shuffle,
58 | num_workers=num_workers,
59 | collate_fn=padding)
60 | return data_loader
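A small usage sketch of `dataset`, `padding`, and `get_loader` with toy tensors (hypothetical values, chosen only to show the padded shapes):

```python
# Hypothetical usage of the dataset/get_loader pair defined above.
import torch
from data.dataloader import dataset, get_loader

src = [torch.LongTensor([4, 5, 6]), torch.LongTensor([7, 8])]
tgt = [torch.LongTensor([1, 2]), torch.LongTensor([3])]
raw_src = [['a', 'b', 'c'], ['d', 'e']]
raw_tgt = [['x', 'y'], ['z']]

loader = get_loader(dataset(src, tgt, raw_src, raw_tgt),
                    batch_size=2, shuffle=False, num_workers=0)
raw_s, src_pad, src_len, raw_t, tgt_pad, tgt_len = next(iter(loader))
print(src_pad.shape, src_len.tolist())  # torch.Size([3, 2]) [3, 2]
```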
--------------------------------------------------------------------------------
/data/dict.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | PAD = 0
4 | UNK = 1
5 | BOS = 2
6 | EOS = 3
7 |
8 | PAD_WORD = ''
9 | UNK_WORD = ' '
10 | BOS_WORD = ''
11 | EOS_WORD = ''
12 | SPA_WORD = ' '
13 |
14 | def flatten(l):
15 | for el in l:
16 | if hasattr(el, "__iter__"):
17 | for sub in flatten(el):
18 | yield sub
19 | else:
20 | yield el
21 |
22 | class Dict(object):
23 | def __init__(self, data=None, lower=False):
24 | self.idxToLabel = {}
25 | self.labelToIdx = {}
26 | self.frequencies = {}
27 | self.lower = lower
28 | self.special = []
29 |
30 | if data is not None:
31 | if type(data) == str:
32 | self.loadFile(data)
33 | else:
34 | self.addSpecials(data)
35 |
36 | def size(self):
37 | return len(self.idxToLabel)
38 |
39 | # Load entries from a file.
40 | def loadFile(self, filename):
41 | for line in open(filename):
42 | fields = line.split()
43 | label = fields[0]
44 | idx = int(fields[1])
45 | self.add(label, idx)
46 |
47 | # Write entries to a file.
48 | def writeFile(self, filename):
49 | with open(filename, 'w') as file:
50 | for i in range(self.size()):
51 | label = self.idxToLabel[i]
52 | file.write('%s %d\n' % (label, i))
53 |
54 | file.close()
55 |
56 | def loadDict(self, idxToLabel):
57 | for i in range(len(idxToLabel)):
58 | label = idxToLabel[i]
59 | self.add(label, i)
60 |
61 | def lookup(self, key, default=None):
62 | key = key.lower() if self.lower else key
63 | try:
64 | return self.labelToIdx[key]
65 | except KeyError:
66 | return default
67 |
68 | def getLabel(self, idx, default=None):
69 | try:
70 | return self.idxToLabel[idx]
71 | except KeyError:
72 | return default
73 |
74 | # Mark this `label` and `idx` as special (i.e. will not be pruned).
75 | def addSpecial(self, label, idx=None):
76 | idx = self.add(label, idx)
77 | self.special += [idx]
78 |
79 | # Mark all labels in `labels` as specials (i.e. will not be pruned).
80 | def addSpecials(self, labels):
81 | for label in labels:
82 | self.addSpecial(label)
83 |
84 | # Add `label` in the dictionary. Use `idx` as its index if given.
85 | def add(self, label, idx=None):
86 | label = label.lower() if self.lower else label
87 | if idx is not None:
88 | self.idxToLabel[idx] = label
89 | self.labelToIdx[label] = idx
90 | else:
91 | if label in self.labelToIdx:
92 | idx = self.labelToIdx[label]
93 | else:
94 | idx = len(self.idxToLabel)
95 | self.idxToLabel[idx] = label
96 | self.labelToIdx[label] = idx
97 |
98 | if idx not in self.frequencies:
99 | self.frequencies[idx] = 1
100 | else:
101 | self.frequencies[idx] += 1
102 |
103 | return idx
104 |
105 | # Return a new dictionary with the `size` most frequent entries.
106 | def prune(self, size):
107 | if size >= self.size():
108 | return self
109 |
110 | # Only keep the `size` most frequent entries.
111 | freq = torch.Tensor(
112 | [self.frequencies[i] for i in range(len(self.frequencies))])
113 | _, idx = torch.sort(freq, 0, True)
114 |
115 | newDict = Dict()
116 | newDict.lower = self.lower
117 |
118 | # Add special entries in all cases.
119 | for i in self.special:
120 | newDict.addSpecial(self.idxToLabel[i])
121 |
122 | for i in idx[:size]:
123 | newDict.add(self.idxToLabel[i])
124 |
125 | return newDict
126 |
127 | # Convert `labels` to indices. Use `unkWord` if not found.
128 | # Optionally insert `bosWord` at the beginning and `eosWord` at the end.
129 | def convertToIdx(self, labels, unkWord, bosWord=None, eosWord=None):
130 | vec = []
131 |
132 | if bosWord is not None:
133 | vec += [self.lookup(bosWord)]
134 |
135 | unk = self.lookup(unkWord)
136 | vec += [self.lookup(label, default=unk) for label in labels]
137 |
138 | if eosWord is not None:
139 | vec += [self.lookup(eosWord)]
140 |
141 | vec = [x for x in flatten(vec)]
142 |
143 | return torch.LongTensor(vec)
144 |
145 | # Convert `idx` to labels. If index `stop` is reached, convert it and return.
146 | def convertToLabels(self, idx, stop):
147 | labels = []
148 |
149 | for i in idx:
150 | if i == stop:
151 | break
152 | labels += [self.getLabel(i)]
153 |
154 | return labels
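A brief, hypothetical example of the `Dict` vocabulary class above (the labels are made up for illustration):

```python
# Hypothetical usage of the Dict vocabulary class.
from data.dict import Dict

vocab = Dict(lower=True)
for label in ['unk', 'spk1', 'spk2', 'spk1']:
    vocab.add(label)

print(vocab.size())                                  # 3
idx = vocab.convertToIdx(['spk1', 'spk3'], 'unk')    # unknown 'spk3' falls back to 'unk'
print(idx.tolist())                                  # [1, 0]
print(vocab.convertToLabels(idx.tolist(), stop=-1))  # ['spk1', 'unk']
```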
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import os
3 | import csv
4 | import codecs
5 | import yaml
6 | import time
7 | import numpy as np
8 | import shutil
9 | import soundfile as sf
10 | import librosa
11 |
12 | from sklearn import metrics
13 |
14 |
15 | class AttrDict(dict):
16 | def __init__(self, *args, **kwargs):
17 | super(AttrDict, self).__init__(*args, **kwargs)
18 | self.__dict__ = self
19 |
20 |
21 | def read_config(path):
22 | return AttrDict(yaml.load(open(path, 'r')))
23 |
24 |
25 | def read_datas(filename, trans_to_num=False):
26 | lines = open(filename, 'r').readlines()
27 | lines = list(map(lambda x: x.split(), lines))
28 | if trans_to_num:
29 | lines = [list(map(int, line)) for line in lines]
30 | return lines
31 |
32 |
33 | def save_datas(data, filename, trans_to_str=False):
34 | if trans_to_str:
35 | data = [list(map(str, line)) for line in data]
36 | lines = list(map(lambda x: " ".join(x), data))
37 | with open(filename, 'w') as f:
38 | f.write("\n".join(lines))
39 |
40 |
41 | def logging(file):
42 | def write_log(s):
43 | print(s, '')
44 | with open(file, 'a') as f:
45 | f.write(s)
46 |
47 | return write_log
48 |
49 |
50 | def logging_csv(file):
51 | def write_csv(s):
52 | # with open(file, 'a', newline='') as f:
53 | with open(file, 'a') as f:
54 | writer = csv.writer(f)
55 | writer.writerow(s)
56 |
57 | return write_csv
58 |
59 |
60 | def format_time(t):
61 | return time.strftime("%Y-%m-%d-%H:%M:%S", t)
62 |
63 |
64 | def eval_metrics(reference, candidate, label_dict, log_path):
65 | ref_dir = log_path + 'reference/'
66 | cand_dir = log_path + 'candidate/'
67 | if not os.path.exists(ref_dir):
68 | os.mkdir(ref_dir)
69 | if not os.path.exists(cand_dir):
70 | os.mkdir(cand_dir)
71 | ref_file = ref_dir + 'reference'
72 | cand_file = cand_dir + 'candidate'
73 |
74 | for i in range(len(reference)):
75 | with codecs.open(ref_file + str(i), 'w', 'utf-8') as f:
76 | f.write("".join(reference[i]) + '\n')
77 | with codecs.open(cand_file + str(i), 'w', 'utf-8') as f:
78 | f.write("".join(candidate[i]) + '\n')
79 |
80 | def make_label(l, label_dict):
81 | length = len(label_dict)
82 | result = np.zeros(length)
83 | indices = [label_dict.get(label.strip().lower(), 0) for label in l]
84 | result[indices] = 1
85 | return result
86 |
87 | def prepare_label(y_list, y_pre_list, label_dict):
88 | reference = np.array([make_label(y, label_dict) for y in y_list])
89 | candidate = np.array([make_label(y_pre, label_dict) for y_pre in y_pre_list])
90 | return reference, candidate
91 |
92 | def get_metrics(y, y_pre):
93 | hamming_loss = metrics.hamming_loss(y, y_pre)
94 | macro_f1 = metrics.f1_score(y, y_pre, average='macro')
95 | macro_precision = metrics.precision_score(y, y_pre, average='macro')
96 | macro_recall = metrics.recall_score(y, y_pre, average='macro')
97 | micro_f1 = metrics.f1_score(y, y_pre, average='micro')
98 | micro_precision = metrics.precision_score(y, y_pre, average='micro')
99 | micro_recall = metrics.recall_score(y, y_pre, average='micro')
100 | return hamming_loss, macro_f1, macro_precision, macro_recall, micro_f1, micro_precision, micro_recall
101 |
102 | y, y_pre = prepare_label(reference, candidate, label_dict)
103 | hamming_loss, macro_f1, macro_precision, macro_recall, micro_f1, micro_precision, micro_recall = get_metrics(y,
104 | y_pre)
105 | return {'hamming_loss': hamming_loss,
106 | 'macro_f1': macro_f1,
107 | 'macro_precision': macro_precision,
108 | 'macro_recall': macro_recall,
109 | 'micro_f1': micro_f1,
110 | 'micro_precision': micro_precision,
111 | 'micro_recall': micro_recall}
112 |
113 |
114 | def bss_eval(config, predict_multi_map, y_multi_map, y_map_gtruth, train_data, dst='batch_output'):
115 | # dst='batch_output'
116 | if os.path.exists(dst):
117 | print(" \ncleanup: " + dst + "/")
118 | shutil.rmtree(dst)
119 | os.makedirs(dst)
120 |
121 | for sample_idx, each_sample in enumerate(train_data['multi_spk_wav_list']):
122 | for each_spk in each_sample.keys():
123 | this_spk = each_spk
124 | wav_genTrue = each_sample[this_spk]
125 | # min_len = 39936
126 | min_len = len(wav_genTrue)
127 | if config.FRAME_SHIFT == 64:
128 | min_len = len(wav_genTrue)
129 | sf.write(dst + '/{}_{}_realTrue.wav'.format(sample_idx, this_spk), wav_genTrue[:min_len],
130 | config.FRAME_RATE, )
131 |
132 | predict_multi_map_list = []
133 | pointer = 0
134 | for each_line in y_map_gtruth:
135 | predict_multi_map_list.append(predict_multi_map[pointer:(pointer + len(each_line))])
136 | pointer += len(each_line)
137 | assert len(predict_multi_map_list) == len(y_map_gtruth)
138 |
139 | # for each sample
140 | sample_idx = 0  # index of the current sample within this batch
141 | for each_y, each_pre, each_trueVector, spk_name in zip(y_multi_map, predict_multi_map_list, y_map_gtruth,
142 | train_data['aim_spkname']):
143 | _mix_spec = train_data['mix_phase'][sample_idx]
144 | feas_tgt = train_data['multi_spk_fea_list'][sample_idx]
145 | phase_mix = np.angle(_mix_spec)
146 | for idx, one_cha in enumerate(each_trueVector):
147 | this_spk = one_cha
148 | y_pre_map = each_pre[idx].data.cpu().numpy()
149 | _pred_spec = y_pre_map * np.exp(1j * phase_mix)
150 | wav_pre = librosa.core.spectrum.istft(np.transpose(_pred_spec), config.FRAME_SHIFT)
151 | min_len = len(wav_pre)
152 | sf.write(dst + '/{}_{}_pre.wav'.format(sample_idx, this_spk), wav_pre[:min_len], config.FRAME_RATE, )
153 |
154 | gen_true_spec = feas_tgt[this_spk] * np.exp(1j * phase_mix)
155 | wav_gen_True = librosa.core.spectrum.istft(np.transpose(gen_true_spec), config.FRAME_SHIFT)
156 | sf.write(dst + '/{}_{}_genTrue.wav'.format(sample_idx, this_spk), wav_gen_True[:min_len],
157 | config.FRAME_RATE, )
158 | sf.write(dst + '/{}_True_mix.wav'.format(sample_idx), train_data['mix_wav'][sample_idx][:min_len],
159 | config.FRAME_RATE, )
160 | sample_idx += 1
161 |
162 |
163 | def bss_eval2(config, predict_multi_map, y_multi_map, y_map_gtruth, train_data, dst='batch_output'):
164 | # dst='batch_output'
165 | if os.path.exists(dst):
166 | print(" \ncleanup: " + dst + "/")
167 | shutil.rmtree(dst)
168 | os.makedirs(dst)
169 |
170 | for sample_idx, each_sample in enumerate(train_data['multi_spk_wav_list']):
171 | for each_spk in each_sample.keys():
172 | this_spk = each_spk
173 | wav_genTrue = each_sample[this_spk]
174 | # min_len = 39936
175 | min_len = len(wav_genTrue)
176 | if config.FRAME_SHIFT == 64:
177 | min_len = len(wav_genTrue)
178 | sf.write(dst + '/{}_{}_realTrue.wav'.format(sample_idx, this_spk), wav_genTrue[:min_len],
179 | config.FRAME_RATE, )
180 |
181 | predict_multi_map_list = []
182 | pointer = 0
183 | for each_line in y_map_gtruth:
184 | predict_multi_map_list.append(predict_multi_map[pointer:(pointer + len(each_line))])
185 | pointer += len(each_line)
186 | assert len(predict_multi_map_list) == len(y_map_gtruth)
187 |
188 | # for each sample
189 | sample_idx = 0  # index of the current sample within this batch
190 | for each_y, each_pre, each_trueVector, spk_name in zip(y_multi_map, predict_multi_map_list, y_map_gtruth,
191 | train_data['aim_spkname']):
192 | _mix_spec = train_data['mix_phase'][sample_idx]
193 | feas_tgt = train_data['multi_spk_fea_list'][sample_idx]
194 | phase_mix = np.angle(_mix_spec)
195 | each_pre = each_pre[0]
196 | for idx, one_cha in enumerate(each_trueVector):
197 | this_spk = one_cha
198 | y_pre_map = each_pre[idx].data.cpu().numpy()
199 | _pred_spec = y_pre_map * np.exp(1j * phase_mix)
200 | wav_pre = librosa.core.spectrum.istft(np.transpose(_pred_spec), config.FRAME_SHIFT)
201 | min_len = len(wav_pre)
202 | sf.write(dst + '/{}_{}_pre.wav'.format(sample_idx, this_spk), wav_pre[:min_len], config.FRAME_RATE, )
203 |
204 | gen_true_spec = feas_tgt[this_spk] * np.exp(1j * phase_mix)
205 | wav_gen_True = librosa.core.spectrum.istft(np.transpose(gen_true_spec), config.FRAME_SHIFT)
206 | sf.write(dst + '/{}_{}_genTrue.wav'.format(sample_idx, this_spk), wav_gen_True[:min_len],
207 | config.FRAME_RATE, )
208 | sf.write(dst + '/{}_True_mix.wav'.format(sample_idx), train_data['mix_wav'][sample_idx][:min_len],
209 | config.FRAME_RATE, )
210 | sample_idx += 1
211 |
212 | def bss_eval_tas(config, predict_wav, y_multi_map, y_map_gtruth, train_data, dst='batch_output'):
213 | # dst='batch_output'
214 | if os.path.exists(dst):
215 | print(" \ncleanup: " + dst + "/")
216 | shutil.rmtree(dst)
217 | os.makedirs(dst)
218 |
219 | for sample_idx, each_sample in enumerate(train_data['multi_spk_wav_list']):
220 | for each_spk in each_sample.keys():
221 | this_spk = each_spk
222 | wav_genTrue = each_sample[this_spk]
223 | # min_len = 39936
224 | min_len = len(wav_genTrue)
225 | if config.FRAME_SHIFT == 64:
226 | min_len = len(wav_genTrue)
227 | sf.write(dst + '/{}_{}_realTrue.wav'.format(sample_idx, this_spk), wav_genTrue[:min_len],
228 | config.FRAME_RATE, )
229 |
230 | predict_multi_map_list = []
231 | pointer = 0
232 | # if len(predict_wav.shape)==3:
233 | # predict_wav=predict_wav.view(-1,predict_wav.shape[-1])
234 | for each_line in y_map_gtruth:
235 | predict_multi_map_list.append(predict_wav[pointer:(pointer + len(each_line))])
236 | pointer += len(each_line)
237 | assert len(predict_multi_map_list) == len(y_map_gtruth)
238 | predict_multi_map_list=[i for i in predict_wav.unsqueeze(1)]
239 |
240 | # for each sample
241 | sample_idx = 0  # index of the current sample within this batch
242 | for each_y, each_pre, each_trueVector, spk_name in zip(y_multi_map, predict_multi_map_list, y_map_gtruth,
243 | train_data['aim_spkname']):
244 | _mix_spec = train_data['mix_phase'][sample_idx]
245 | feas_tgt = train_data['multi_spk_fea_list'][sample_idx]
246 | phase_mix = np.angle(_mix_spec)
247 | each_pre = each_pre[0]
248 | for idx, one_cha in enumerate(each_trueVector):
249 | this_spk = one_cha
250 | y_pre_map = each_pre[idx].data.cpu().numpy()
251 | # _pred_spec = y_pre_map * np.exp(1j * phase_mix)
252 | # wav_pre = librosa.core.spectrum.istft(np.transpose(_pred_spec), config.FRAME_SHIFT)
253 | wav_pre = y_pre_map
254 | min_len = len(wav_pre)
255 | sf.write(dst + '/{}_{}_pre.wav'.format(sample_idx, this_spk), wav_pre[:min_len], config.FRAME_RATE, )
256 |
257 | gen_true_spec = feas_tgt[this_spk] * np.exp(1j * phase_mix)
258 | wav_gen_True = librosa.core.spectrum.istft(np.transpose(gen_true_spec), config.FRAME_SHIFT)
259 | sf.write(dst + '/{}_{}_genTrue.wav'.format(sample_idx, this_spk), wav_gen_True[:min_len],
260 | config.FRAME_RATE, )
261 | sf.write(dst + '/{}_True_mix.wav'.format(sample_idx), train_data['mix_wav'][sample_idx][:min_len],
262 | config.FRAME_RATE, )
263 | sample_idx += 1
264 |
--------------------------------------------------------------------------------
/jpg/Readme:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/jpg/sdnet.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aispeech-lab/SDNet/d057e6d2524b1487d65d4473499d50ef935a7beb/jpg/sdnet.jpeg
--------------------------------------------------------------------------------
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from bisect import bisect_right
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
6 | class _LRScheduler(object):
7 | def __init__(self, optimizer, last_epoch=-1):
8 | if not isinstance(optimizer, Optimizer):
9 | raise TypeError('{} is not an Optimizer'.format(
10 | type(optimizer).__name__))
11 | self.optimizer = optimizer
12 | if last_epoch == -1:
13 | for group in optimizer.param_groups:
14 | group.setdefault('initial_lr', group['lr'])
15 | else:
16 | for i, group in enumerate(optimizer.param_groups):
17 | if 'initial_lr' not in group:
18 | raise KeyError("param 'initial_lr' is not specified "
19 | "in param_groups[{}] when resuming an optimizer".format(i))
20 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
21 | self.step(last_epoch + 1)
22 | self.last_epoch = last_epoch
23 |
24 | def get_lr(self):
25 | raise NotImplementedError
26 |
27 | def step(self, epoch=None):
28 | if epoch is None:
29 | epoch = self.last_epoch + 1
30 | self.last_epoch = epoch
31 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
32 | param_group['lr'] = lr
33 |
34 |
35 | class LambdaLR(_LRScheduler):
36 | """Sets the learning rate of each parameter group to the initial lr
37 | times a given function. When last_epoch=-1, sets initial lr as lr.
38 | Args:
39 | optimizer (Optimizer): Wrapped optimizer.
40 | lr_lambda (function or list): A function which computes a multiplicative
41 | factor given an integer parameter epoch, or a list of such
42 | functions, one for each group in optimizer.param_groups.
43 | last_epoch (int): The index of last epoch. Default: -1.
44 | Example:
45 | >>> # Assuming optimizer has two groups.
46 | >>> lambda1 = lambda epoch: epoch // 30
47 | >>> lambda2 = lambda epoch: 0.95 ** epoch
48 | >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
49 | >>> for epoch in range(100):
50 | >>> scheduler.step()
51 | >>> train(...)
52 | >>> validate(...)
53 | """
54 |
55 | def __init__(self, optimizer, lr_lambda, last_epoch=-1):
56 | self.optimizer = optimizer
57 | if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
58 | self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
59 | else:
60 | if len(lr_lambda) != len(optimizer.param_groups):
61 | raise ValueError("Expected {} lr_lambdas, but got {}".format(
62 | len(optimizer.param_groups), len(lr_lambda)))
63 | self.lr_lambdas = list(lr_lambda)
64 | self.last_epoch = last_epoch
65 | super(LambdaLR, self).__init__(optimizer, last_epoch)
66 |
67 | def get_lr(self):
68 | return [base_lr * lmbda(self.last_epoch)
69 | for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
70 |
71 |
72 | class StepLR(_LRScheduler):
73 | """Sets the learning rate of each parameter group to the initial lr
74 | decayed by gamma every step_size epochs. When last_epoch=-1, sets
75 | initial lr as lr.
76 | Args:
77 | optimizer (Optimizer): Wrapped optimizer.
78 | step_size (int): Period of learning rate decay.
79 | gamma (float): Multiplicative factor of learning rate decay.
80 | Default: 0.1.
81 | last_epoch (int): The index of last epoch. Default: -1.
82 | Example:
83 | >>> # Assuming optimizer uses lr = 0.5 for all groups
84 | >>> # lr = 0.05 if epoch < 30
85 | >>> # lr = 0.005 if 30 <= epoch < 60
86 | >>> # lr = 0.0005 if 60 <= epoch < 90
87 | >>> # ...
88 | >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
89 | >>> for epoch in range(100):
90 | >>> scheduler.step()
91 | >>> train(...)
92 | >>> validate(...)
93 | """
94 |
95 | def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
96 | self.step_size = step_size
97 | self.gamma = gamma
98 | super(StepLR, self).__init__(optimizer, last_epoch)
99 |
100 | def get_lr(self):
101 | return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
102 | for base_lr in self.base_lrs]
103 |
104 |
105 | class MultiStepLR(_LRScheduler):
106 | """Set the learning rate of each parameter group to the initial lr decayed
107 | by gamma once the number of epoch reaches one of the milestones. When
108 | last_epoch=-1, sets initial lr as lr.
109 | Args:
110 | optimizer (Optimizer): Wrapped optimizer.
111 | milestones (list): List of epoch indices. Must be increasing.
112 | gamma (float): Multiplicative factor of learning rate decay.
113 | Default: 0.1.
114 | last_epoch (int): The index of last epoch. Default: -1.
115 | Example:
116 | >>> # Assuming optimizer uses lr = 0.5 for all groups
117 | >>> # lr = 0.05 if epoch < 30
118 | >>> # lr = 0.005 if 30 <= epoch < 80
119 | >>> # lr = 0.0005 if epoch >= 80
120 | >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
121 | >>> for epoch in range(100):
122 | >>> scheduler.step()
123 | >>> train(...)
124 | >>> validate(...)
125 | """
126 |
127 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
128 | if not list(milestones) == sorted(milestones):
129 | raise ValueError('Milestones should be a list of'
130 | ' increasing integers. Got {}', milestones)
131 | self.milestones = milestones
132 | self.gamma = gamma
133 | super(MultiStepLR, self).__init__(optimizer, last_epoch)
134 |
135 | def get_lr(self):
136 | return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
137 | for base_lr in self.base_lrs]
138 |
139 |
140 | class ExponentialLR(_LRScheduler):
141 | """Set the learning rate of each parameter group to the initial lr decayed
142 | by gamma every epoch. When last_epoch=-1, sets initial lr as lr.
143 | Args:
144 | optimizer (Optimizer): Wrapped optimizer.
145 | gamma (float): Multiplicative factor of learning rate decay.
146 | last_epoch (int): The index of last epoch. Default: -1.
147 | """
148 |
149 | def __init__(self, optimizer, gamma, last_epoch=-1):
150 | self.gamma = gamma
151 | super(ExponentialLR, self).__init__(optimizer, last_epoch)
152 |
153 | def get_lr(self):
154 | return [base_lr * self.gamma ** self.last_epoch
155 | for base_lr in self.base_lrs]
156 |
157 |
158 | class CosineAnnealingLR(_LRScheduler):
159 | """Set the learning rate of each parameter group using a cosine annealing
160 | schedule, where :math:`\eta_{max}` is set to the initial lr and
161 | :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
162 | .. math::
163 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
164 | \cos(\frac{T_{cur}}{T_{max}}\pi))
165 | When last_epoch=-1, sets initial lr as lr.
166 | It has been proposed in
167 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
168 | implements the cosine annealing part of SGDR, and not the restarts.
169 | Args:
170 | optimizer (Optimizer): Wrapped optimizer.
171 | T_max (int): Maximum number of iterations.
172 | eta_min (float): Minimum learning rate. Default: 0.
173 | last_epoch (int): The index of last epoch. Default: -1.
174 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
175 | https://arxiv.org/abs/1608.03983
176 | """
177 |
178 | def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
179 | self.T_max = T_max
180 | self.eta_min = eta_min
181 | super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
182 |
183 | def get_lr(self):
184 | return [self.eta_min + (base_lr - self.eta_min) *
185 | (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
186 | for base_lr in self.base_lrs]
187 |
188 |
189 | class ReduceLROnPlateau(object):
190 | """Reduce learning rate when a metric has stopped improving.
191 | Models often benefit from reducing the learning rate by a factor
192 | of 2-10 once learning stagnates. This scheduler reads a metrics
193 | quantity and if no improvement is seen for a 'patience' number
194 | of epochs, the learning rate is reduced.
195 | Args:
196 | optimizer (Optimizer): Wrapped optimizer.
197 | mode (str): One of `min`, `max`. In `min` mode, lr will
198 | be reduced when the quantity monitored has stopped
199 | decreasing; in `max` mode it will be reduced when the
200 | quantity monitored has stopped increasing. Default: 'min'.
201 | factor (float): Factor by which the learning rate will be
202 | reduced. new_lr = lr * factor. Default: 0.1.
203 | patience (int): Number of epochs with no improvement after
204 | which learning rate will be reduced. Default: 10.
205 | verbose (bool): If True, prints a message to stdout for
206 | each update. Default: False.
207 | threshold (float): Threshold for measuring the new optimum,
208 | to only focus on significant changes. Default: 1e-4.
209 | threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
210 | dynamic_threshold = best * ( 1 + threshold ) in 'max'
211 | mode or best * ( 1 - threshold ) in `min` mode.
212 | In `abs` mode, dynamic_threshold = best + threshold in
213 | `max` mode or best - threshold in `min` mode. Default: 'rel'.
214 | cooldown (int): Number of epochs to wait before resuming
215 | normal operation after lr has been reduced. Default: 0.
216 | min_lr (float or list): A scalar or a list of scalars. A
217 | lower bound on the learning rate of all param groups
218 | or each group respectively. Default: 0.
219 | eps (float): Minimal decay applied to lr. If the difference
220 | between new and old lr is smaller than eps, the update is
221 | ignored. Default: 1e-8.
222 | Example:
223 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
224 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
225 | >>> for epoch in range(10):
226 | >>> train(...)
227 | >>> val_loss = validate(...)
228 | >>> # Note that step should be called after validate()
229 | >>> scheduler.step(val_loss)
230 | """
231 |
232 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
233 | verbose=False, threshold=1e-4, threshold_mode='rel',
234 | cooldown=0, min_lr=0, eps=1e-8):
235 |
236 | if factor >= 1.0:
237 | raise ValueError('Factor should be < 1.0.')
238 | self.factor = factor
239 |
240 | if not isinstance(optimizer, Optimizer):
241 | raise TypeError('{} is not an Optimizer'.format(
242 | type(optimizer).__name__))
243 | self.optimizer = optimizer
244 |
245 | if isinstance(min_lr, list) or isinstance(min_lr, tuple):
246 | if len(min_lr) != len(optimizer.param_groups):
247 | raise ValueError("expected {} min_lrs, got {}".format(
248 | len(optimizer.param_groups), len(min_lr)))
249 | self.min_lrs = list(min_lr)
250 | else:
251 | self.min_lrs = [min_lr] * len(optimizer.param_groups)
252 |
253 | self.patience = patience
254 | self.verbose = verbose
255 | self.cooldown = cooldown
256 | self.cooldown_counter = 0
257 | self.mode = mode
258 | self.threshold = threshold
259 | self.threshold_mode = threshold_mode
260 | self.best = None
261 | self.num_bad_epochs = None
262 | self.mode_worse = None # the worse value for the chosen mode
263 | self.is_better = None
264 | self.eps = eps
265 | self.last_epoch = -1
266 | self._init_is_better(mode=mode, threshold=threshold,
267 | threshold_mode=threshold_mode)
268 | self._reset()
269 |
270 | def _reset(self):
271 | """Resets num_bad_epochs counter and cooldown counter."""
272 | self.best = self.mode_worse
273 | self.cooldown_counter = 0
274 | self.num_bad_epochs = 0
275 |
276 | def step(self, metrics, epoch=None):
277 | current = metrics
278 | if epoch is None:
279 | epoch = self.last_epoch = self.last_epoch + 1
280 | self.last_epoch = epoch
281 |
282 | if self.is_better(current, self.best):
283 | self.best = current
284 | self.num_bad_epochs = 0
285 | else:
286 | self.num_bad_epochs += 1
287 |
288 | if self.in_cooldown:
289 | self.cooldown_counter -= 1
290 | self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
291 |
292 | if self.num_bad_epochs > self.patience:
293 | self._reduce_lr(epoch)
294 | self.cooldown_counter = self.cooldown
295 | self.num_bad_epochs = 0
296 |
297 | def _reduce_lr(self, epoch):
298 | for i, param_group in enumerate(self.optimizer.param_groups):
299 | old_lr = float(param_group['lr'])
300 | new_lr = max(old_lr * self.factor, self.min_lrs[i])
301 | if old_lr - new_lr > self.eps:
302 | param_group['lr'] = new_lr
303 | if self.verbose:
304 | print('Epoch {:5d}: reducing learning rate'
305 | ' of group {} to {:.4e}.'.format(epoch, i, new_lr))
306 |
307 | @property
308 | def in_cooldown(self):
309 | return self.cooldown_counter > 0
310 |
311 | def _init_is_better(self, mode, threshold, threshold_mode):
312 | if mode not in {'min', 'max'}:
313 | raise ValueError('mode ' + mode + ' is unknown!')
314 | if threshold_mode not in {'rel', 'abs'}:
315 | raise ValueError('threshold mode ' + mode + ' is unknown!')
316 | if mode == 'min' and threshold_mode == 'rel':
317 | rel_epsilon = 1. - threshold
318 | self.is_better = lambda a, best: a < best * rel_epsilon
319 | self.mode_worse = float('Inf')
320 | elif mode == 'min' and threshold_mode == 'abs':
321 | self.is_better = lambda a, best: a < best - threshold
322 | self.mode_worse = float('Inf')
323 | elif mode == 'max' and threshold_mode == 'rel':
324 | rel_epsilon = threshold + 1.
325 | self.is_better = lambda a, best: a > best * rel_epsilon
326 | self.mode_worse = -float('Inf')
327 | else: # mode == 'max' and epsilon_mode == 'abs':
328 | self.is_better = lambda a, best: a > best + threshold
329 | self.mode_worse = -float('Inf')
330 |
--------------------------------------------------------------------------------
/models/Readme:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/Schmidt_orth.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import torch
3 |
4 |
5 | def schmidt(this_vec, vectors):
6 | # this_vec: [bs, hidden_emb]
7 | # vectors: a list whose elements each have the same shape as this_vec
8 | if len(vectors) == 0:
9 | return this_vec
10 | else:
11 | for vec in vectors:
12 | assert len(vec.size()) == len(this_vec.size()) == 2
13 | dot = torch.bmm(this_vec.unsqueeze(1), vec.unsqueeze(-1)).squeeze(-1)
14 | norm = torch.bmm(vec.unsqueeze(1), vec.unsqueeze(-1)).squeeze(-1)
15 | frac = dot / norm # bs,1
16 | this_vec = this_vec - (frac * vec)
17 | # print 'final_vec:',this_vec
18 | return this_vec
19 |
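A tiny, hypothetical check of the Gram-Schmidt step above: vectors orthogonalized one after another against the running list end up pairwise (near-)orthogonal for each batch element.

```python
# Hypothetical check of schmidt() with random tensors.
import torch
from models.Schmidt_orth import schmidt

torch.manual_seed(0)
vectors = []                                   # running list of orthogonalized vectors
for _ in range(3):
    v = schmidt(torch.randn(4, 256), vectors)  # [bs, hidden_emb]
    vectors.append(v)

# dot products between different vectors in the list should be close to 0
dot = torch.bmm(vectors[0].unsqueeze(1), vectors[2].unsqueeze(-1))
print(dot.abs().max())
```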
--------------------------------------------------------------------------------
/models/WaveLoss.py:
--------------------------------------------------------------------------------
1 | #coding=utf8
2 | import torch
3 | import torch.nn as nn
4 | import soundfile as sf
5 | import resampy
6 | # from prepare_data_wsj2 import linearspectrogram
7 | import librosa
8 | import numpy as np
9 | from models.istft_irfft import istft_irfft
10 |
11 | # Reference: https://github.com/jonlu0602/DeepDenoisingAutoencoder/blob/master/python/utils.py
12 | def linearspectrogram(y, dBscale = 1, normalize=1):
13 | fft_size = 256
14 | hop_size = 128
15 | ref_db = 20
16 | max_db = 100
17 | D = librosa.core.spectrum.stft(y, fft_size, hop_size) # F, T
18 | F, T = D.shape
19 | S = np.abs(D)
20 | if dBscale:
21 | S = librosa.amplitude_to_db(S)
22 | if normalize:
23 | # normalization
24 | S = np.clip((S - ref_db + max_db) / max_db, 1e-8, 1)
25 | return S, np.angle(D)
26 |
27 | def concatenateFeature(inputList, dim):
28 | out = inputList[0]
29 | for i in range(1, len(inputList)):
30 | out = torch.cat((out, inputList[i]), dim=dim)
31 | return out
32 |
33 | class WaveLoss(nn.Module):
34 | def __init__(self, dBscale = 1, denormalize=1, max_db=100, ref_db=20, nfft=256, hop_size=128):
35 | super(WaveLoss, self).__init__()
36 | self.dBscale = dBscale
37 | self.denormalize = denormalize
38 | self.max_db = max_db
39 | self.ref_db = ref_db
40 | self.nfft = nfft
41 | self.hop_size = hop_size
42 | self.mse_loss = nn.MSELoss()
43 |
44 | def genWav(self, S, phase):
45 | '''
46 | :param S: (B, F-1, T) to be padded with 0 in this function
47 | :param phase: (B, F, T)
48 | :return: (B, num_samples)
49 | '''
50 | if self.dBscale:
51 | if self.denormalize:
52 | # denormalization
53 | S = S * self.max_db - self.max_db + self.ref_db
54 | # to amplitude
55 | # https://github.com/pytorch/pytorch/issues/12426
56 | # RuntimeError: the derivative for pow is not implemented
57 | # S = torch.pow(10, S * 0.05)
58 | S = 10 ** (S * 0.05)
59 |
60 | # pad with 0
61 | B, F, T = S.shape
62 | pad = torch.zeros(B, 1, T).to(S.device)
63 | # note: the concatenated tensors must share the same dtype and device
64 | Sfull = concatenateFeature([S, pad], dim=-2)  # the network predicts one frequency bin fewer, so pad it back with zeros
65 |
66 | # deal with the complex
67 | Sfull_ = Sfull.data.cpu().numpy()
68 | phase_ = phase.data.cpu().numpy()
69 | Sfull_spec = Sfull_ * np.exp(1.0j * phase_)
70 | S_sign = np.sign(np.real(Sfull_spec))
71 | S_sign = torch.from_numpy(S_sign).to(S.device)
72 | Sfull_spec_imag = np.imag(Sfull_spec)
73 | Sfull_spec_imag = torch.from_numpy(Sfull_spec_imag).unsqueeze(-1).to(S.device)
74 | Sfull = torch.mul(Sfull, S_sign).unsqueeze(-1)
75 | # print(Sfull.shape)
76 | # print(Sfull_spec_imag.shape)
77 | stft_matrix = concatenateFeature([Sfull, Sfull_spec_imag], dim=-1) # (B, F, T, 2)
78 | # print(stft_matrix.shape)
79 |
80 | wav = istft_irfft(stft_matrix, hop_length=self.hop_size, win_length=self.nfft)
81 | return wav
82 |
83 | def forward(self, target_mag, target_phase, pred_mag, pred_phase):
84 | '''
85 | :param target_mag: (B, F-1, T)
86 | :param target_phase: (B, F, T)
87 | :param pred_mag: (B, F-1, T)
88 | :param pred_phase: (B, F, T)
89 | :return:
90 | '''
91 | target_wav = self.genWav(target_mag, target_phase)
92 | pred_wav = self.genWav(pred_mag, pred_phase)
93 |
94 | # target_wav_arr = target_wav.squeeze(0).cpu().data.numpy()
95 | # pred_wav_arr = pred_wav.squeeze(0).cpu().data.numpy()
96 | # print('target wav arr', target_wav_arr.shape)
97 | # sf.write('target.wav', target_wav_arr, 8000)
98 | # sf.write('pred.wav', pred_wav_arr, 8000)
99 |
100 | loss = self.mse_loss(target_wav, pred_wav)
101 | return loss
102 |
103 | if __name__ == '__main__':
104 | wav_f = 'test.wav'
105 | wav_fo = 'test_o.wav'
106 | def read_wav(f):
107 | wav, sr = sf.read(f)
108 | if len(wav.shape) > 1:
109 | wav = wav[:, 0]
110 | if sr != 8000:
111 | wav = resampy.resample(wav, sr, 8000)
112 | spec, phase = linearspectrogram(wav)
113 | return spec, phase, wav
114 | target_spec , target_phase, wav1 = read_wav(wav_f)
115 | pred_spec, pred_phase, wav2 = read_wav(wav_fo)
116 | # print('librosa,', pred_spec.shape)
117 |
118 | librosa_stft = librosa.stft(wav1, n_fft=256, hop_length=128, window='hann')
119 | # print('librosa stft', librosa_stft.shape)
120 | # print('librosa.stft', librosa_stft)
121 | _magnitude = np.abs(librosa_stft)
122 | # print('mag,', _magnitude.shape)
123 | wav_re_librosa = librosa.core.spectrum.istft(librosa_stft, hop_length=128)
124 | sf.write('wav_re_librosa.wav', wav_re_librosa, 8000)
125 |
126 | def clip_spec(spec, phase):
127 | spec_clip = spec[0:-1, :410] # (F-1, T)
128 | phase_clip = phase[:, :410]
129 | spec_tensor = torch.from_numpy(spec_clip)
130 | spec_tensor = spec_tensor.unsqueeze(0) # (B, T, F)
131 | # print(spec_tensor.shape)
132 | phase_tensor = torch.from_numpy(phase_clip).unsqueeze(0)
133 | # print(phase_tensor.shape)
134 | return spec_tensor, phase_tensor
135 | target_spec_tensor, target_phase_tensor = clip_spec(target_spec, target_phase)
136 | pred_spec_tensor, pred_phase_tensor = clip_spec(pred_spec, pred_phase)
137 |
138 | wav_loss = WaveLoss(dBscale=1, nfft=256, hop_size=128)
139 | loss = wav_loss(target_spec_tensor, target_phase_tensor, pred_spec_tensor, pred_phase_tensor)
140 | print('loss', loss.item())
141 |
142 | wav1 = torch.FloatTensor(wav1)
143 | torch_stft_matrix = torch.stft(wav1, n_fft=256, hop_length=128, window=torch.hann_window(256))
144 | torch_stft_matrix = torch_stft_matrix.unsqueeze(0)
145 | # print('torch stft', torch_stft_matrix.shape)
146 | # print(torch_stft_matrix[:,:,:,0])
147 | # print(torch_stft_matrix[:,:,:,1])
148 | wav_re = istft_irfft(torch_stft_matrix, hop_length=128, win_length=256)
149 | wav_re = wav_re.squeeze(0).cpu().data.numpy()
150 | # print('wav_re', wav_re.shape)
151 | sf.write('wav_re.wav', wav_re, 8000)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.attention import *
2 | from models.rnn import *
3 | from models.seq2seq import *
4 | from models.separation_dis import *
5 | from models.separation_tasnet import *
6 | from models.loss import *
7 | from models.beam import *
8 | from models.Schmidt_orth import *
9 | from models.metrics import *
10 | from models.focal_loss import *
11 | from models.WaveLoss import *
12 |
--------------------------------------------------------------------------------
/models/attention.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | from torch.nn.utils.rnn import pack_padded_sequence as pack
6 | from torch.nn.utils.rnn import pad_packed_sequence as unpack
7 | import data.dict as dict
8 |
9 |
10 | class global_attention(nn.Module):
11 |
12 | def __init__(self, hidden_size, activation=None):
13 | super(global_attention, self).__init__()
14 | self.linear_in = nn.Linear(hidden_size, hidden_size)
15 | self.linear_out = nn.Linear(2 * hidden_size, hidden_size,bias=1)
16 | self.softmax = nn.Softmax(dim=1)  # attention weights are normalized over the time dimension
17 | # self.batchnorm=nn.BatchNorm1d(hidden_size)
18 | self.tanh = nn.Tanh()
19 | self.activation = activation
20 |
21 | def forward(self, x, context):
22 | gamma_h = self.linear_in(x).unsqueeze(2)  # unsqueeze adds a singleton dimension (a cheap reshape)  # batch * size * 1
23 | if self.activation == 'tanh':
24 | gamma_h = self.tanh(gamma_h)
25 | weights = torch.bmm(context, gamma_h).squeeze(2) # batch * time
26 | weights = self.softmax(weights) # batch * time
27 | c_t = torch.bmm(weights.unsqueeze(1), context).squeeze(1) # batch * size
28 | output = self.linear_out(torch.cat([c_t, x], 1))  # an extra batchnorm could be added here
29 | # output = self.batchnorm(output)
30 | output = self.tanh(output)
31 | return output, weights
32 |
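A small, hypothetical forward pass through `global_attention`, purely as a shape check (random tensors):

```python
# Hypothetical shape check for global_attention.
import torch
from models.attention import global_attention

att = global_attention(hidden_size=256, activation='tanh')
x = torch.randn(4, 256)              # decoder state: batch * hidden
context = torch.randn(4, 10, 256)    # encoder outputs: batch * time * hidden
output, weights = att(x, context)
print(output.shape, weights.shape)   # torch.Size([4, 256]) torch.Size([4, 10])
```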
--------------------------------------------------------------------------------
/models/beam.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import torch
3 |
4 |
5 | # import data.dict_spk2idx[as dict
6 |
7 | class Beam(object):
8 | def __init__(self, size, dict_spk2idx, n_best=1, cuda=True):
9 | self.dict_spk2idx = dict_spk2idx
10 | self.size = size
11 | self.tt = torch.cuda if cuda else torch
12 |
13 | # The score for each translation on the beam.
14 | self.scores = self.tt.FloatTensor(size).zero_()
15 | self.allScores = []
16 |
17 | # The backpointers at each time-step.
18 | self.prevKs = []
19 |
20 | # The outputs at each time-step.
21 | self.nextYs = [self.tt.LongTensor(size)
22 | .fill_(dict_spk2idx[''])]
23 | self.nextYs[0][0] = dict_spk2idx['']
24 | # Has EOS topped the beam yet.
25 | self._eos = dict_spk2idx['']
26 | self.eosTop = False
27 |
28 | # The attentions (matrix) for each time.
29 | self.attn = []
30 |
31 | # The last hiddens(matrix) for each time.
32 | self.hiddens = []
33 |
34 | # The last hiddens(matrix) for each time.
35 | self.sch_hiddens = []
36 |
37 | # The last embs(matrix) for each time.
38 | self.embs = []
39 |
40 | # Time and k pair for finished.
41 | self.finished = []
42 | self.n_best = n_best
43 |
44 | def updates_sch_embeddings(self, hiddens_this_step):
45 | if len(self.sch_hiddens) == 0:
46 | self.sch_hiddens.append([hiddens_this_step])
47 | else:
48 | self.sch_hiddens.append(self.sch_hiddens[-1] + [hiddens_this_step])
49 |
50 | def getCurrentState(self):
51 | "Get the outputs for the current timestep."
52 | return self.nextYs[-1]
53 |
54 | def getCurrentOrigin(self):
55 | "Get the backpointers for the current timestep."
56 | return self.prevKs[-1]
57 |
58 | def advance(self, wordLk, attnOut, hidden, emb):
59 | """
60 | Given prob over words for every last beam `wordLk` and attention
61 | `attnOut`: Compute and update the beam search.
62 | Parameters:
63 | * `wordLk`- probs of advancing from the last step (K x words)
64 | * `attnOut`- attention at the last step
65 | Returns: True if beam search is complete.
66 | """
67 | numWords = wordLk.size(1)
68 |
69 | # Sum the previous scores.
70 | if len(self.prevKs) > 0:
71 | beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
72 | # Don't let EOS have children.
73 | for i in range(self.nextYs[-1].size(0)):
74 | if self.nextYs[-1][i] == self._eos:
75 | beamLk[i] = -1e20
76 | else:
77 | beamLk = wordLk[0]
78 | flatBeamLk = beamLk.view(-1)
79 | bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
80 |
81 | self.allScores.append(self.scores)
82 | self.scores = bestScores
83 |
84 | # bestScoresId is flattened beam x word array, so calculate which
85 | # word and beam each score came from
86 | prevK = bestScoresId // numWords  # floor division keeps the indices as a LongTensor
87 | self.prevKs.append(prevK)
88 | self.nextYs.append((bestScoresId - prevK * numWords))
89 | self.attn.append(attnOut.index_select(0, prevK))
90 | self.hiddens.append(hidden.index_select(0, prevK))
91 | self.updates_sch_embeddings(hidden.index_select(0, prevK))
92 | self.embs.append(emb.index_select(0, prevK))
93 |
94 | for i in range(self.nextYs[-1].size(0)):
95 | if self.nextYs[-1][i] == self._eos:
96 | s = self.scores[i]
97 | self.finished.append((s, len(self.nextYs) - 1, i))
98 |
99 | # End condition is when top-of-beam is '' and no global score.
100 | if self.nextYs[-1][0] == self.dict_spk2idx['']:
101 | # self.allScores.append(self.scores)
102 | self.eosTop = True
103 |
104 | def done(self):
105 | return self.eosTop and len(self.finished) >= self.n_best
106 |
107 | def beam_update(self, state, idx):
108 | positions = self.getCurrentOrigin()
109 | for e in state:
110 | a, br, d = e.size()
111 | e = e.view(a, self.size, br // self.size, d)
112 | sentStates = e[:, :, idx]
113 | sentStates.data.copy_(sentStates.data.index_select(1, positions))
114 |
115 | def beam_update_context(self, state, idx):
116 | positions = self.getCurrentOrigin()
117 | e = state.unsqueeze(0)
118 | a, br, len, d, = e.size()
119 | e = e.view(a, self.size, br // self.size, len, d)
120 | sentStates = e[:, :, idx]
121 | sentStates.data.copy_(sentStates.data.index_select(1, positions))
122 |
123 | def beam_update_hidden(self, state, idx):
124 | positions = self.getCurrentOrigin()
125 | e = state
126 | a, br, d = e.size()
127 | e = e.view(a, self.size, br // self.size, d)
128 | sentStates = e[:, :, idx]
129 | sentStates.data.copy_(sentStates.data.index_select(1, positions))
130 |
131 | def sortFinished(self, minimum=None):
132 | if minimum is not None:
133 | i = 0
134 | # Add from beam until we have minimum outputs.
135 | while len(self.finished) < minimum:
136 | s = self.scores[i]
137 | self.finished.append((s, len(self.nextYs) - 1, i))
138 |
139 | self.finished.sort(key=lambda a: -a[0])
140 | scores = [sc for sc, _, _ in self.finished]
141 | ks = [(t, k) for _, t, k in self.finished]
142 | return scores, ks
143 |
144 | def getHyp(self, timestep, k):
145 | """
146 | Walk back to construct the full hypothesis.
147 | """
148 | hyp, attn, hidden, emb = [], [], [], []
149 | for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
150 | hyp.append(self.nextYs[j + 1][k])
151 | attn.append(self.attn[j][k])
152 | hidden.append(self.hiddens[j][k])
153 | emb.append(self.embs[j][k])
154 | k = self.prevKs[j][k]
155 | return hyp[::-1], torch.stack(attn[::-1]), torch.stack(hidden[::-1]), torch.stack(emb[::-1])
156 |
--------------------------------------------------------------------------------
/models/focal_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class FocalLoss(nn.Module):
7 |
8 | def __init__(self, gamma=0, eps=1e-7):
9 | super(FocalLoss, self).__init__()
10 | self.gamma = gamma
11 | self.eps = eps
12 | self.ce = torch.nn.CrossEntropyLoss()
13 |
14 | def forward(self, input, target):
15 | logp = self.ce(input, target)
16 | p = torch.exp(-logp)
17 | loss = (1 - p) ** self.gamma * logp
18 | return loss.mean()
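A minimal, hypothetical call to `FocalLoss`; with gamma=0 it reduces exactly to the plain cross-entropy it wraps.

```python
# Hypothetical usage of FocalLoss. Note that this implementation applies the
# (1 - p)^gamma factor to the batch-mean cross-entropy rather than per sample.
import torch
from models.focal_loss import FocalLoss

torch.manual_seed(0)
logits = torch.randn(8, 10)              # batch of 8, 10 classes
targets = torch.randint(0, 10, (8,))

print(FocalLoss(gamma=0)(logits, targets))   # equals nn.CrossEntropyLoss()(logits, targets)
print(FocalLoss(gamma=2)(logits, targets))   # modulated (smaller) loss
```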
--------------------------------------------------------------------------------
/models/istft_irfft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import librosa
3 |
4 | # this is Keunwoo Choi's implementation of istft.
5 | # https://gist.github.com/keunwoochoi/2f349e72cc941f6f10d4adf9b0d3f37e#file-istft-torch-py
6 | def istft_irfft(stft_matrix, length=None, hop_length=None, win_length=None, window='hann',
7 | center=True, normalized=False, onesided=True):
8 | """stft_matrix = (batch, freq, time, complex)
9 |
10 | All based on librosa
11 | - http://librosa.github.io/librosa/_modules/librosa/core/spectrum.html#istft
12 | What's missing?
13 | - normalize by sum of squared window --> do we need it here?
14 | Actually the result is ok by simply dividing y by 2.
15 | """
16 | assert normalized == False
17 | assert onesided == True
18 | assert window == "hann"
19 | assert center == True
20 |
21 | device = stft_matrix.device
22 | n_fft = 2 * (stft_matrix.shape[-3] - 1)
23 |
24 | batch = stft_matrix.shape[0]
25 |
26 | # By default, use the entire frame
27 | if win_length is None:
28 | win_length = n_fft
29 |
30 | if hop_length is None:
31 | hop_length = int(win_length // 4)
32 |
33 | istft_window = torch.hann_window(n_fft).to(device).view(1, -1) # (batch, freq)
34 |
35 | n_frames = stft_matrix.shape[-2]
36 | expected_signal_len = n_fft + hop_length * (n_frames - 1)
37 |
38 | y = torch.zeros(batch, expected_signal_len, device=device)
39 | for i in range(n_frames):
40 | sample = i * hop_length
41 | spec = stft_matrix[:, :, i]
42 | iffted = torch.irfft(spec, signal_ndim=1, signal_sizes=(win_length,))
43 |
44 | ytmp = istft_window * iffted
45 | y[:, sample:(sample+n_fft)] += ytmp
46 |
47 | y = y[:, n_fft//2:]
48 |
49 | if length is not None:
50 | if y.shape[1] > length:
51 | y = y[:, :length]
52 | elif y.shape[1] < length:
53 | y = torch.cat([y, torch.zeros(y.shape[0], length - y.shape[1], device=y.device)], dim=1)
54 | coeff = n_fft/float(hop_length) / 2.0 # -> this might go wrong if the currently asserted values (especially `normalized`) change.
55 | return y / coeff
56 |
--------------------------------------------------------------------------------
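
A hedged round-trip sketch for `istft_irfft` above. It assumes an older PyTorch (the repo asks for >=1.1.0) in which `torch.stft` returns a real `[batch, freq, frames, 2]` tensor and `torch.irfft` still exists; as the docstring admits, the amplitude is only approximately recovered:

```python
import torch
from models.istft_irfft import istft_irfft

x = torch.randn(4, 16000)                       # a batch of 1-second, 16 kHz signals
n_fft, hop = 512, 128
window = torch.hann_window(n_fft)
spec = torch.stft(x, n_fft, hop_length=hop, win_length=n_fft, window=window, center=True)
y = istft_irfft(spec, length=x.shape[1], hop_length=hop, win_length=n_fft)
print(y.shape)                                  # torch.Size([4, 16000]); the scale is only approximate
```
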
/models/loss.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import data.dict as dict
6 | from torch.autograd import Variable
7 | import torch.nn.functional as F
8 | import models.focal_loss as focal_loss
9 | from itertools import permutations
10 |
11 |
12 | EPS = 1e-8
13 | def rank_feas(raw_tgt, feas_list, out_type='torch'):
14 | final_num = []
15 | for each_feas, each_line in zip(feas_list, raw_tgt):
16 | for spk in each_line:
17 | final_num.append(each_feas[spk])
18 |     # The goal: for every utterance in this batch (e.g. 1 spk, 3 spk, 2 spk) collect the features of each speaker, so here we end up with 6 speakers' features
19 | if out_type=='numpy':
20 | return np.array(final_num)
21 | else:
22 | return torch.from_numpy(np.array(final_num))
23 |
24 |
25 | def criterion(tgt_vocab_size, use_cuda, loss):
26 | weight = torch.ones(tgt_vocab_size)
27 | weight[dict.PAD] = 0
28 | if loss=='focal_loss':
29 | crit = focal_loss.FocalLoss(gamma=2)
30 | else:
31 | crit = nn.CrossEntropyLoss(weight, size_average=False)
32 | if use_cuda:
33 | crit.cuda()
34 | return crit
35 |
36 | def criterion_dir(tgt_vocab_size, use_cuda, loss):
37 | weight = torch.ones(tgt_vocab_size)
38 | weight[dict.PAD] = 0
39 | if loss=='focal_loss':
40 | crit = focal_loss.FocalLoss(gamma=2)
41 | else:
42 | crit = nn.CrossEntropyLoss(weight, size_average=False)
43 | if use_cuda:
44 | crit.cuda()
45 | return crit
46 |
47 |
48 | def memory_efficiency_cross_entropy_loss(hidden_outputs, decoder, targets, criterion, config):
49 | outputs = Variable(hidden_outputs.data, requires_grad=True, volatile=False)
50 | num_total, num_correct, loss = 0, 0, 0
51 |
52 | outputs_split = torch.split(outputs, config.max_generator_batches)
53 | targets_split = torch.split(targets, config.max_generator_batches)
54 | for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
55 | out_t = out_t.view(-1, out_t.size(2))
56 | scores_t = decoder.compute_score(out_t)
57 | loss_t = criterion(scores_t, targ_t.view(-1))
58 | pred_t = scores_t.max(1)[1]
59 | num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(dict.PAD).data).sum()
60 | num_total_t = targ_t.ne(dict.PAD).data.sum()
61 | num_correct += num_correct_t
62 | num_total += num_total_t
63 |         loss += loss_t.item()
64 | loss_t.div(num_total_t).backward()
65 |
66 | grad_output = outputs.grad.data
67 | hidden_outputs.backward(grad_output)
68 |
69 | return loss, num_total, num_correct, config.tgt_vocab, config.tgt_vocab
70 |
71 |
72 | def cross_entropy_loss(hidden_outputs, decoder, targets, criterion, config, sim_score=0):
73 | # hidden_outputs:[max_len,bs,512]
74 | batch_size= targets.size()[1]
75 | targets=targets.view(-1)
76 | outputs = hidden_outputs.view(-1, hidden_outputs.size(2))
77 | scores = decoder.compute_score(outputs)
78 | loss = criterion(scores, targets.view(-1)) + sim_score
79 | pred = scores.max(1)[1]
80 | num_correct = pred.data.eq(targets.data).masked_select(targets.ne(dict.PAD).data).sum()
81 | # num_correct = pred.data.eq(targets.data).masked_select(targets.ne(targets[-1]).data).sum()
82 | num_total = float(targets.ne(dict.PAD).data.sum())
83 | loss *= batch_size
84 | loss = loss.div(num_total)
85 | # loss = loss.data[0]
86 |
87 | return loss, num_total, num_correct
88 |
89 | def cross_entropy_loss_dir(hidden_outputs, decoder, targets, criterion, config, sim_score=0):
90 | batch_size= targets.size()[1]
91 | targets=targets.view(-1)
92 | outputs = hidden_outputs.view(-1, hidden_outputs.size(2))
93 | scores = decoder.compute_score_dir(outputs)
94 | #print("scores:",scores.size())
95 | #print("targets:",targets.view(-1))
96 | loss = criterion(scores, targets.view(-1)) + sim_score
97 | pred = scores.max(1)[1]
98 | num_correct = pred.data.eq(targets.data).masked_select(targets.ne(dict.PAD).data).sum()
99 | num_total = float(targets.ne(dict.PAD).data.sum())
100 | loss *= batch_size
101 | loss = loss.div(num_total)
102 | # loss = loss.data[0]
103 |
104 | return loss, num_total, num_correct
105 |
106 | def mmse_loss(hidden_outputs, decoder, targets, mse_loss, softmax):
107 | outputs = hidden_outputs.view(-1, hidden_outputs.size(2))
108 | scores = softmax(decoder.compute_score_dir(outputs))
109 | targets = targets.view(-1)
110 |     target_one_hot = torch.zeros(scores.size(0), scores.size(1), device=scores.device).scatter_(1, targets.view(-1, 1), 1)
111 |     #scores = linear(scores)
112 |     print("scores.size", scores.size())
113 |     print("targets.size", targets.size())
114 |     loss = mse_loss(scores.float(), target_one_hot.float())
115 |
116 | return loss
117 |
118 | def mmse_loss2(hidden_outputs, decoder, targets, mse_loss):
119 | print("hidden_outputs", hidden_outputs.size())
120 | outputs = hidden_outputs.view(-1, hidden_outputs.size(2))
121 | scores_1, scores = decoder.compute_score_dir(outputs)
122 | scores = F.sigmoid(scores)
123 |     print("scores.size", scores.size())
124 |     print("targets.size", targets.size())
125 | loss = mse_loss(scores.view(-1).float(), targets.view(-1).float()/20)
126 |
127 | return loss
128 |
129 | def ss_loss(config, x_input_map_multi, multi_mask, y_multi_map, loss_multi_func,wav_loss):
130 | predict_multi_map = multi_mask * x_input_map_multi
131 | # predict_multi_map=Variable(y_multi_map)
132 | y_multi_map = Variable(y_multi_map)
133 |
134 | loss_multi_speech = loss_multi_func(predict_multi_map, y_multi_map)
135 |
136 |     # Loss term that pushes the per-channel masks to sum to 1; it should encourage more differences between channels
137 | # y_sum_map=Variable(torch.ones(config.batch_size,config.mix_speech_len,config.speech_fre)).cuda()
138 | # predict_sum_map=torch.sum(multi_mask,1)
139 | # loss_multi_sum_speech=loss_multi_func(predict_sum_map,y_sum_map)
140 | print('loss 1 eval: ', loss_multi_speech.data.cpu().numpy())
141 | # print('losssum eval :',loss_multi_sum_speech.data.cpu().numpy()
142 | # loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech
143 | print('evaling multi-abs norm this eval batch:', torch.abs(y_multi_map - predict_multi_map).norm().data.cpu().numpy())
144 | # loss_multi_speech=loss_multi_speech+3*loss_multi_sum_speech
145 | print('loss for whole separation part:', loss_multi_speech.data.cpu().numpy())
146 | return loss_multi_speech
147 |
148 | def ss_tas_loss(config,predict_wav, y_multi_wav, mix_length,loss_multi_func):
149 | loss = cal_loss_with_order(y_multi_wav, predict_wav, mix_length)[0]
150 | #loss_mse = loss_multi_func(predict_wav, y_multi_wav)
151 | return loss
152 |
153 | def cal_loss_with_order(source, estimate_source, source_lengths):
154 | """
155 | Args:
156 | source: [B, C, T], B is batch size
157 | estimate_source: [B, C, T]
158 | source_lengths: [B]
159 | """
160 | # print('real Tas SNI:',source[:,:,16000:16005])
161 | # print('pre Tas SNI:',estimate_source[:,:,16000:16005])
162 | max_snr = cal_si_snr_with_order(source, estimate_source, source_lengths)
163 | loss = 0 - torch.mean(max_snr)
164 | return loss,
165 |
166 | def cal_loss_with_PIT(source, estimate_source, source_lengths):
167 | """
168 | Args:
169 | source: [B, C, T], B is batch size
170 | estimate_source: [B, C, T]
171 | source_lengths: [B]
172 | """
173 | max_snr, perms, max_snr_idx = cal_si_snr_with_pit(source,
174 | estimate_source,
175 | source_lengths)
176 | loss = 0 - torch.mean(max_snr)
177 | reorder_estimate_source = reorder_source(estimate_source, perms, max_snr_idx)
178 | return loss, max_snr, estimate_source, reorder_estimate_source
179 |
180 | def cal_si_snr_with_order(source, estimate_source, source_lengths):
181 | """Calculate SI-SNR with given order.
182 | Args:
183 | source: [B, C, T], B is batch size
184 | estimate_source: [B, C, T]
185 | source_lengths: [B], each item is between [0, T]
186 | """
187 | print("source.size()",source.size())
188 | print("estimate_source.size()",estimate_source.size())
189 | assert source.size() == estimate_source.size()
190 | B, C, T = source.size()
191 | # mask padding position along T
192 | mask = get_mask(source, source_lengths)
193 | estimate_source *= mask
194 |
195 | # Step 1. Zero-mean norm
196 | num_samples = source_lengths.view(-1, 1, 1).float() # [B, 1, 1]
197 | mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
198 | mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples
199 | zero_mean_target = source - mean_target
200 | zero_mean_estimate = estimate_source - mean_estimate
201 | # mask padding position along T
202 | zero_mean_target *= mask
203 | zero_mean_estimate *= mask
204 |
205 | # Step 2. SI-SNR with order
206 | # reshape to use broadcast
207 | s_target = zero_mean_target # [B, C, T]
208 | s_estimate = zero_mean_estimate # [B, C, T]
209 | # s_target = s / ||s||^2
210 | pair_wise_dot = torch.sum(s_estimate * s_target, dim=2, keepdim=True) # [B, C, 1]
211 | s_target_energy = torch.sum(s_target ** 2, dim=2, keepdim=True) + EPS # [B, C, 1]
212 | pair_wise_proj = pair_wise_dot * s_target / s_target_energy # [B, C, T]
213 | # e_noise = s' - s_target
214 | e_noise = s_estimate - pair_wise_proj # [B, C, T]
215 | # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
216 | pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=2) / (torch.sum(e_noise ** 2, dim=2) + EPS)
217 | pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS) # [B, C]
218 | print(pair_wise_si_snr)
219 |
220 | return torch.sum(pair_wise_si_snr,dim=1)/C
221 |
222 | def cal_si_snr_with_pit(source, estimate_source, source_lengths):
223 | """Calculate SI-SNR with PIT training.
224 | Args:
225 | source: [B, C, T], B is batch size
226 | estimate_source: [B, C, T]
227 | source_lengths: [B], each item is between [0, T]
228 | """
229 | assert source.size() == estimate_source.size()
230 | B, C, T = source.size()
231 | # mask padding position along T
232 | mask = get_mask(source, source_lengths)
233 | estimate_source *= mask
234 |
235 | # Step 1. Zero-mean norm
236 | num_samples = source_lengths.view(-1, 1, 1).float() # [B, 1, 1]
237 | mean_target = torch.sum(source, dim=2, keepdim=True) / num_samples
238 | mean_estimate = torch.sum(estimate_source, dim=2, keepdim=True) / num_samples
239 | zero_mean_target = source - mean_target
240 | zero_mean_estimate = estimate_source - mean_estimate
241 | # mask padding position along T
242 | zero_mean_target *= mask
243 | zero_mean_estimate *= mask
244 |
245 | # Step 2. SI-SNR with PIT
246 | # reshape to use broadcast
247 | s_target = torch.unsqueeze(zero_mean_target, dim=1) # [B, 1, C, T]
248 | s_estimate = torch.unsqueeze(zero_mean_estimate, dim=2) # [B, C, 1, T]
249 | # s_target = s / ||s||^2
250 | pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True) # [B, C, C, 1]
251 | s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + EPS # [B, 1, C, 1]
252 | pair_wise_proj = pair_wise_dot * s_target / s_target_energy # [B, C, C, T]
253 | # e_noise = s' - s_target
254 | e_noise = s_estimate - pair_wise_proj # [B, C, C, T]
255 | # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
256 | pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=3) / (torch.sum(e_noise ** 2, dim=3) + EPS)
257 | pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS) # [B, C, C]
258 |
259 | # Get max_snr of each utterance
260 | # permutations, [C!, C]
261 | perms = source.new_tensor(list(permutations(range(C))), dtype=torch.long)
262 | # one-hot, [C!, C, C]
263 | index = torch.unsqueeze(perms, 2)
264 | # perms_one_hot = source.new_zeros((*perms.size(), C)).scatter_(2, index, 1)
265 | perms_one_hot = source.new_zeros((perms.size()[0],perms.size()[1], C)).scatter_(2, index, 1)
266 | # [B, C!] <- [B, C, C] einsum [C!, C, C], SI-SNR sum of each permutation
267 | snr_set = torch.einsum('bij,pij->bp', [pair_wise_si_snr, perms_one_hot])
268 | max_snr_idx = torch.argmax(snr_set, dim=1) # [B]
269 | # max_snr = torch.gather(snr_set, 1, max_snr_idx.view(-1, 1)) # [B, 1]
270 | max_snr, _ = torch.max(snr_set, dim=1, keepdim=True)
271 | max_snr /= C
272 | return max_snr, perms, max_snr_idx
273 |
274 | def reorder_source(source, perms, max_snr_idx):
275 | """
276 | Args:
277 | source: [B, C, T]
278 | perms: [C!, C], permutations
279 | max_snr_idx: [B], each item is between [0, C!)
280 | Returns:
281 | reorder_source: [B, C, T]
282 | """
283 | # B, C, *_ = source.size()
284 | B, C, __ = source.size()
285 | # [B, C], permutation whose SI-SNR is max of each utterance
286 | # for each utterance, reorder estimate source according this permutation
287 | max_snr_perm = torch.index_select(perms, dim=0, index=max_snr_idx)
288 | # print('max_snr_perm', max_snr_perm)
289 | # maybe use torch.gather()/index_select()/scatter() to impl this?
290 | reorder_source = torch.zeros_like(source)
291 | for b in range(B):
292 | for c in range(C):
293 | reorder_source[b, c] = source[b, max_snr_perm[b][c]]
294 | return reorder_source
295 |
296 |
297 | def get_mask(source, source_lengths):
298 | """
299 | Args:
300 | source: [B, C, T]
301 | source_lengths: [B]
302 | Returns:
303 | mask: [B, 1, T]
304 | """
305 | B, _, T = source.size()
306 | mask = source.new_ones((B, 1, T))
307 | #mask = Variable(torch.ones((B, 1, T))).cuda()
308 | for i in range(B):
309 | #print("source_lengths[i]",source_lengths[i])
310 | mask[i, :, source_lengths[i]:] = 0
311 | return mask
312 | def ss_loss_MLMSE(config, x_input_map_multi, multi_mask, y_multi_map, loss_multi_func, Var):
313 | try:
314 |         if Var is None:
315 |             Var = Variable(torch.eye(config.speech_fre, config.speech_fre).cuda(), requires_grad=0) # initialized as the identity matrix
316 | print('Set Var to:', Var)
317 | except:
318 | pass
319 | assert Var.size() == (config.speech_fre, config.speech_fre)
320 |
321 |     predict_multi_map = torch.mean(multi_mask * x_input_map_multi, -2) # average over the time dimension
322 | # predict_multi_map=Variable(y_multi_map)
323 |     y_multi_map = torch.mean(Variable(y_multi_map), -2) # average over the time dimension
324 |
325 |     loss_vector = (y_multi_map - predict_multi_map).view(-1, config.speech_fre).unsqueeze(1) # should be bs*1*fre
326 |
327 | Var_inverse = torch.inverse(Var)
328 | Var_inverse = Var_inverse.unsqueeze(0).expand(loss_vector.size()[0], config.speech_fre,
329 |                                                   config.speech_fre) # expand to batch form
330 | loss_multi_speech = torch.bmm(torch.bmm(loss_vector, Var_inverse), loss_vector.transpose(1, 2))
331 | loss_multi_speech = torch.mean(loss_multi_speech, 0)
332 |
333 |     # Loss term that pushes the per-channel masks to sum to 1; it should encourage more differences between channels
334 | y_sum_map = Variable(torch.ones(config.batch_size, config.mix_speech_len, config.speech_fre)).cuda()
335 | predict_sum_map = torch.sum(multi_mask, 1)
336 | loss_multi_sum_speech = loss_multi_func(predict_sum_map, y_sum_map)
337 | print('loss 1 eval, losssum eval : ', loss_multi_speech.data.cpu().numpy(), loss_multi_sum_speech.data.cpu().numpy())
338 | # loss_multi_speech=loss_multi_speech+0.5*loss_multi_sum_speech
339 | print('evaling multi-abs norm this eval batch:', torch.abs(y_multi_map - predict_multi_map).norm().data.cpu().numpy())
340 | # loss_multi_speech=loss_multi_speech+3*loss_multi_sum_speech
341 | print('loss for whole separation part:', loss_multi_speech.data.cpu().numpy())
342 | # return F.relu(loss_multi_speech)
343 | return loss_multi_speech
344 |
345 |
346 | def dis_loss(config, top_k_num, dis_model, x_input_map_multi, multi_mask, y_multi_map, loss_multi_func):
347 | predict_multi_map = multi_mask * x_input_map_multi
348 | y_multi_map = Variable(y_multi_map).cuda()
349 | score_true = dis_model(y_multi_map)
350 | score_false = dis_model(predict_multi_map)
351 | acc_true = torch.sum(score_true > 0.5).data.cpu().numpy() / float(score_true.size()[0])
352 | acc_false = torch.sum(score_false < 0.5).data.cpu().numpy() / float(score_true.size()[0])
353 | acc_dis = (acc_false + acc_true) / 2
354 |     print('acc for dis:(true,false,aver)', acc_true, acc_false, acc_dis)
355 |
356 | loss_dis_true = loss_multi_func(score_true, Variable(torch.ones(config.batch_size * top_k_num, 1)).cuda())
357 | loss_dis_false = loss_multi_func(score_false, Variable(torch.zeros(config.batch_size * top_k_num, 1)).cuda())
358 | loss_dis = loss_dis_true + loss_dis_false
359 |     print('loss for dis:(true,false)', loss_dis_true.data.cpu().numpy(), loss_dis_false.data.cpu().numpy())
360 | return loss_dis
361 |
--------------------------------------------------------------------------------
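
A shape-level sanity check for the PIT SI-SNR loss defined in `models/loss.py` above, with random tensors and arbitrary sizes (assumed to be run from the repository root; importing the `models` package may pull in modules that expect a GPU):

```python
import torch
from models.loss import cal_loss_with_PIT

B, C, T = 2, 2, 16000
source = torch.randn(B, C, T)            # reference sources
estimate = torch.randn(B, C, T)          # network outputs
lengths = torch.tensor([T] * B)          # no padding in this toy case
loss, max_snr, est, reordered = cal_loss_with_PIT(source, estimate, lengths)
print(loss.item(), reordered.shape)      # scalar loss, torch.Size([2, 2, 16000])
```
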
/models/metrics.py:
--------------------------------------------------------------------------------
1 | #coding=utf8
2 | from __future__ import print_function
3 | from __future__ import division
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.nn import Parameter
8 | import math
9 |
10 |
11 | class ArcMarginProduct(nn.Module):
12 | r"""Implement of large margin arc distance: :
13 | Args:
14 | in_features: size of each input sample
15 | out_features: size of each output sample
16 | s: norm of input feature
17 | m: margin
18 |
19 | cos(theta + m)
20 | """
21 | def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
22 | super(ArcMarginProduct, self).__init__()
23 | self.in_features = in_features
24 | self.out_features = out_features
25 | self.s = s
26 | self.m = m
27 | self.weight = Parameter(torch.FloatTensor(out_features, in_features))
28 |         nn.init.xavier_uniform_(self.weight)
29 |
30 | self.easy_margin = easy_margin
31 | self.cos_m = math.cos(m)
32 | self.sin_m = math.sin(m)
33 | self.th = math.cos(math.pi - m)
34 | self.mm = math.sin(math.pi - m) * m
35 |
36 | def forward(self, input, label):
37 | # --------------------------- cos(theta) & phi(theta) ---------------------------
38 | cosine = F.linear(F.normalize(input), F.normalize(self.weight))
39 |         if label is None: # if no label is given we are in the test phase, so just return the cosine scores
40 | return cosine
41 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
42 | phi = cosine * self.cos_m - sine * self.sin_m
43 | if self.easy_margin:
44 | phi = torch.where(cosine > 0, phi, cosine)
45 | else:
46 | phi = torch.where(cosine > self.th, phi, cosine - self.mm)
47 | # --------------------------- convert label to one-hot ---------------------------
48 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
49 | one_hot = torch.zeros(cosine.size(), device='cuda')
50 | one_hot.scatter_(1, label.view(-1, 1).long(), 1)
51 | # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
52 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4
53 | output *= self.s
54 | # print(output)
55 |
56 | return output
57 |
58 |
59 | class AddMarginProduct(nn.Module):
60 | r"""Implement of large margin cosine distance: :
61 | Args:
62 | in_features: size of each input sample
63 | out_features: size of each output sample
64 | s: norm of input feature
65 | m: margin
66 | cos(theta) - m
67 | """
68 |
69 | def __init__(self, in_features, out_features, s=30.0, m=0.40):
70 | super(AddMarginProduct, self).__init__()
71 | self.in_features = in_features
72 | self.out_features = out_features
73 | self.s = s
74 | self.m = m
75 | self.weight = Parameter(torch.FloatTensor(out_features, in_features))
76 | nn.init.xavier_uniform_(self.weight)
77 |
78 | def forward(self, input, label):
79 | # --------------------------- cos(theta) & phi(theta) ---------------------------
80 | cosine = F.linear(F.normalize(input), F.normalize(self.weight))
81 | phi = cosine - self.m
82 | # --------------------------- convert label to one-hot ---------------------------
83 | one_hot = torch.zeros(cosine.size(), device='cuda')
84 | # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot
85 | one_hot.scatter_(1, label.view(-1, 1).long(), 1)
86 | # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
87 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4
88 | output *= self.s
89 | # print(output)
90 |
91 | return output
92 |
93 | def __repr__(self):
94 | return self.__class__.__name__ + '(' \
95 | + 'in_features=' + str(self.in_features) \
96 | + ', out_features=' + str(self.out_features) \
97 | + ', s=' + str(self.s) \
98 | + ', m=' + str(self.m) + ')'
99 |
100 |
101 | class SphereProduct(nn.Module):
102 | r"""Implement of large margin cosine distance: :
103 | Args:
104 | in_features: size of each input sample
105 | out_features: size of each output sample
106 | m: margin
107 | cos(m*theta)
108 | """
109 | def __init__(self, in_features, out_features, m=4):
110 | super(SphereProduct, self).__init__()
111 | self.in_features = in_features
112 | self.out_features = out_features
113 | self.m = m
114 | self.base = 1000.0
115 | self.gamma = 0.12
116 | self.power = 1
117 | self.LambdaMin = 5.0
118 | self.iter = 0
119 | self.weight = Parameter(torch.FloatTensor(out_features, in_features))
120 |         nn.init.xavier_uniform_(self.weight)
121 |
122 | # duplication formula
123 | self.mlambda = [
124 | lambda x: x ** 0,
125 | lambda x: x ** 1,
126 | lambda x: 2 * x ** 2 - 1,
127 | lambda x: 4 * x ** 3 - 3 * x,
128 | lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
129 | lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x
130 | ]
131 |
132 | def forward(self, input, label):
133 | # lambda = max(lambda_min,base*(1+gamma*iteration)^(-power))
134 | self.iter += 1
135 | self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power))
136 |
137 | # --------------------------- cos(theta) & phi(theta) ---------------------------
138 | cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
139 | cos_theta = cos_theta.clamp(-1, 1)
140 | cos_m_theta = self.mlambda[self.m](cos_theta)
141 | theta = cos_theta.data.acos()
142 | k = (self.m * theta / 3.14159265).floor()
143 | phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k
144 | NormOfFeature = torch.norm(input, 2, 1)
145 |
146 | # --------------------------- convert label to one-hot ---------------------------
147 | one_hot = torch.zeros(cos_theta.size())
148 | one_hot = one_hot.cuda() if cos_theta.is_cuda else one_hot
149 | one_hot.scatter_(1, label.view(-1, 1), 1)
150 |
151 | # --------------------------- Calculate output ---------------------------
152 | output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta
153 | output *= NormOfFeature.view(-1, 1)
154 |
155 | return output
156 |
157 | def __repr__(self):
158 | return self.__class__.__name__ + '(' \
159 | + 'in_features=' + str(self.in_features) \
160 | + ', out_features=' + str(self.out_features) \
161 | + ', m=' + str(self.m) + ')'
--------------------------------------------------------------------------------
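
A sketch of how the margin layers above are typically paired with cross entropy (assumed usage, not repository code; `ArcMarginProduct.forward` hard-codes its one-hot buffer on CUDA, so a GPU is required, and the feature/class sizes below are made up):

```python
import torch
import torch.nn as nn
from models.metrics import ArcMarginProduct

metric_fc = ArcMarginProduct(in_features=512, out_features=101, s=30.0, m=0.5).cuda()
embeddings = torch.randn(8, 512).cuda()        # e.g. decoder hidden states
labels = torch.randint(0, 101, (8,)).cuda()
logits = metric_fc(embeddings, labels)         # margin applied only to the target class
loss = nn.CrossEntropyLoss()(logits, labels)
scores = metric_fc(embeddings, None)           # label=None returns the plain cosine scores
```
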
/models/rnn.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from torch.nn.utils.rnn import pack_padded_sequence as pack
7 | from torch.nn.utils.rnn import pad_packed_sequence as unpack
8 | import data.dict as dict
9 | import models
10 |
11 | import numpy as np
12 |
13 |
14 | class StackedLSTM(nn.Module):
15 | def __init__(self, num_layers, input_size, hidden_size, dropout):
16 | super(StackedLSTM, self).__init__()
17 | self.dropout = nn.Dropout(dropout)
18 | self.num_layers = num_layers
19 | self.layers = nn.ModuleList()
20 |
21 | for i in range(num_layers):
22 | self.layers.append(nn.LSTMCell(input_size, hidden_size))
23 | input_size = hidden_size
24 |
25 | def forward(self, input, hidden):
26 | h_0, c_0 = hidden
27 | h_1, c_1 = [], []
28 | for i, layer in enumerate(self.layers):
29 | h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
30 | input = h_1_i
31 | if i + 1 != self.num_layers:
32 | input = self.dropout(input)
33 | h_1 += [h_1_i]
34 | c_1 += [c_1_i]
35 |
36 |         h_1 = torch.stack(h_1) # stack the per-layer LSTMCell outputs into a [num_layers, batch_size, hidden_size] tensor
37 | c_1 = torch.stack(c_1)
38 |
39 | return input, (h_1, c_1)
40 |
41 |
42 | class rnn_encoder(nn.Module):
43 |
44 | def __init__(self, config, input_emb_size):
45 | super(rnn_encoder, self).__init__()
46 | self.rnn = nn.LSTM(input_size=input_emb_size, hidden_size=config.encoder_hidden_size,
47 | num_layers=config.num_layers, dropout=config.dropout, bidirectional=config.bidirec)
48 | self.config = config
49 |
50 | def forward(self, input, lengths):
51 | input = input.transpose(0, 1)
52 |         embs = pack(input, list(map(int, lengths))) # batch is the second dimension here
53 | outputs, (h, c) = self.rnn(embs)
54 | outputs = unpack(outputs)[0]
55 | if not self.config.bidirec:
56 |             return outputs, (h, c) # h, c are from the last step, with size (num_layers * num_directions, batch, hidden_size)
57 | else:
58 | batch_size = h.size(1)
59 | h = h.transpose(0, 1).contiguous().view(batch_size, -1, 2 * self.config.encoder_hidden_size)
60 | c = c.transpose(0, 1).contiguous().view(batch_size, -1, 2 * self.config.encoder_hidden_size)
61 |             state = (h.transpose(0, 1), c.transpose(0, 1)) # each element has size (num_layers, batch, 2*hidden_size)
62 | return outputs, state
63 |
64 | class gated_rnn_encoder(nn.Module):
65 |
66 | def __init__(self, config, vocab_size, embedding=None):
67 | super(gated_rnn_encoder, self).__init__()
68 | if embedding is not None:
69 | self.embedding = embedding
70 | else:
71 | self.embedding = nn.Embedding(vocab_size, config.emb_size)
72 | self.rnn = nn.LSTM(input_size=config.emb_size, hidden_size=config.encoder_hidden_size,
73 | num_layers=config.num_layers, dropout=config.dropout)
74 | self.gated = nn.Sequential(nn.Linear(config.encoder_hidden_size, 1), nn.Sigmoid())
75 |
76 | def forward(self, input, lengths):
77 | embs = pack(self.embedding(input), lengths)
78 | outputs, state = self.rnn(embs)
79 | outputs = unpack(outputs)[0]
80 | p = self.gated(outputs)
81 | outputs = outputs * p
82 | return outputs, state
83 |
84 |
85 | class rnn_decoder(nn.Module):
86 |
87 | def __init__(self, config, vocab_size, dir_vocab_size, embedding=None, score_fn=None):
88 | super(rnn_decoder, self).__init__()
89 | if embedding is not None:
90 | self.embedding = embedding
91 | self.embedding_dir = embedding
92 | else:
93 | self.embedding = nn.Embedding(vocab_size, config.emb_size)
94 | self.embedding_dir = nn.Embedding(dir_vocab_size, config.emb_size)
95 | self.rnn = StackedLSTM(input_size=config.emb_size, hidden_size=config.decoder_hidden_size,
96 | num_layers=config.num_layers, dropout=config.dropout)
97 | self.rnn_dir = StackedLSTM(input_size=config.emb_size, hidden_size=config.decoder_hidden_size,
98 | num_layers=config.num_layers, dropout=config.dropout)
99 |
100 | self.score_fn = score_fn
101 | if self.score_fn.startswith('general'):
102 | self.linear = nn.Linear(config.decoder_hidden_size, config.emb_size)
103 | self.linear_dir = nn.Linear(config.decoder_hidden_size, config.emb_size)
104 | elif score_fn.startswith('concat'):
105 | self.linear_query = nn.Linear(config.decoder_hidden_size, config.decoder_hidden_size)
106 | self.linear_weight = nn.Linear(config.emb_size, config.decoder_hidden_size)
107 | self.linear_v = nn.Linear(config.decoder_hidden_size, 1)
108 | self.linear_query_dir = nn.Linear(config.decoder_hidden_size, config.decoder_hidden_size)
109 | self.linear_weight_dir = nn.Linear(config.emb_size, config.decoder_hidden_size)
110 | self.linear_v_dir = nn.Linear(config.decoder_hidden_size, 1)
111 | elif not self.score_fn.startswith('dot'):
112 | self.linear = nn.Linear(config.decoder_hidden_size, vocab_size)
113 | self.linear_dir = nn.Linear(config.decoder_hidden_size, dir_vocab_size)
114 | self.linear_output = nn.Linear(dir_vocab_size, 1)
115 |
116 | if hasattr(config, 'att_act'):
117 | activation = config.att_act
118 | print('use attention activation %s' % activation)
119 | else:
120 | activation = None
121 |
122 | self.attention = models.global_attention(config.decoder_hidden_size, activation)
123 | self.attention_dir = models.global_attention(config.decoder_hidden_size, activation)
124 | self.hidden_size = config.decoder_hidden_size
125 | self.dropout = nn.Dropout(config.dropout)
126 | self.config = config
127 |
128 | if self.config.global_emb:
129 | self.gated1 = nn.Linear(config.emb_size, config.emb_size)
130 | self.gated2 = nn.Linear(config.emb_size, config.emb_size)
131 | self.gated1_dir = nn.Linear(config.emb_size, config.emb_size)
132 | self.gated2_dir = nn.Linear(config.emb_size, config.emb_size)
133 |
134 | def forward(self, inputs, inputs_dir, init_state, contexts):
135 |
136 | outputs, outputs_dir, state, attns, global_embs = [], [], init_state, [], []
137 |
138 | ## speaker
139 | embs = self.embedding(inputs).split(1) # time_step [1,bs,embsize]
140 | max_time_step = len(embs)
141 |         emb = embs[0] # embedding of BOS at the first step.
142 | output, state_speaker = self.rnn(emb.squeeze(0), state)
143 | output, attn_weights = self.attention(output, contexts)
144 | output = self.dropout(output)
145 |         soft_score = F.softmax(self.linear(output)) # probability distribution at the first step, of size (bs, vocab)
146 |
147 | ## direction
148 | embs_dir = self.embedding_dir(inputs_dir).split(1) # time_step [1,bs,embsize]
149 |         emb_dir = embs_dir[0] # embedding of BOS at the first step.
150 | output_dir, state_dir = self.rnn_dir(emb_dir.squeeze(0), state)
151 | output_dir, attn_weights_dir = self.attention_dir(output_dir, contexts)
152 | output_dir = self.dropout(output_dir)
153 | #soft_score_dir_1 = F.sigmoid(self.linear_dir(output_dir))
154 | soft_score_dir = F.softmax(self.linear_dir(output_dir))
155 |
156 | attn_weights = attn_weights + attn_weights_dir
157 | outputs += [output]
158 | outputs_dir += [output_dir]
159 | attns += [attn_weights]
160 |
161 | batch_size = soft_score.size(0)
162 | a, b = self.embedding.weight.size()
163 | c, d = self.embedding_dir.weight.size()
164 |
165 | for i in range(max_time_step - 1):
166 | ## speaker
167 | emb1 = torch.bmm(soft_score.unsqueeze(1),
168 | self.embedding.weight.expand((batch_size, a, b)))
169 | emb2 = embs[i + 1]
170 | gamma = F.sigmoid(self.gated1(emb1.squeeze(1)) + self.gated2(emb2.squeeze(0)))
171 | emb = gamma * emb1.squeeze(1) + (1 - gamma) * emb2.squeeze(0)
172 | output, state_speaker = self.rnn(emb, state_speaker)
173 | output, attn_weights = self.attention(output, contexts)
174 | output = self.dropout(output)
175 | soft_score = F.softmax(self.linear(output))
176 |
177 | ## direction
178 | emb1_dir = torch.bmm(soft_score_dir.unsqueeze(1),
179 | self.embedding_dir.weight.expand((batch_size, c, d)))
180 | emb2_dir = embs_dir[i + 1]
181 | gamma_dir = F.sigmoid(self.gated1_dir(emb1_dir.squeeze(1)) + self.gated2_dir(emb2_dir.squeeze(0)))
182 | emb_dir = gamma_dir * emb1_dir.squeeze(1) + (1 - gamma_dir) * emb2_dir.squeeze(0)
183 | output_dir, state_dir = self.rnn_dir(emb_dir, state_dir)
184 | output_dir, attn_weights_dir = self.attention_dir(output_dir, contexts)
185 | output_dir = self.dropout(output_dir)
186 | #soft_score_dir_1 = F.sigmoid(self.linear_dir(output_dir))
187 | soft_score_dir = F.softmax(self.linear_dir(output_dir))
188 |
189 | attn_weights = attn_weights + attn_weights_dir
190 | emb = emb + emb_dir
191 | outputs += [output]
192 | outputs_dir += [output_dir]
193 | global_embs += [emb]
194 | attns += [attn_weights]
195 |
196 | outputs = torch.stack(outputs)
197 | outputs_dir = torch.stack(outputs_dir)
198 | global_embs = torch.stack(global_embs)
199 | attns = torch.stack(attns)
200 | return outputs, outputs_dir, state, global_embs
201 |
202 | def compute_score(self, hiddens):
203 | if self.score_fn.startswith('general'):
204 | if self.score_fn.endswith('not'):
205 | scores = torch.matmul(self.linear(hiddens), Variable(self.embedding.weight.t().data))
206 | else:
207 | scores = torch.matmul(self.linear(hiddens), self.embedding.weight.t())
208 | elif self.score_fn.startswith('concat'):
209 | if self.score_fn.endswith('not'):
210 | scores = self.linear_v(torch.tanh(self.linear_query(hiddens).unsqueeze(1) + \
211 | self.linear_weight(Variable(self.embedding.weight.data)).unsqueeze(
212 | 0))).squeeze(2)
213 | else:
214 | scores = self.linear_v(torch.tanh(self.linear_query(hiddens).unsqueeze(1) + \
215 | self.linear_weight(self.embedding.weight).unsqueeze(0))).squeeze(2)
216 | elif self.score_fn.startswith('dot'):
217 | if self.score_fn.endswith('not'):
218 | scores = torch.matmul(hiddens, Variable(self.embedding.weight.t().data))
219 | else:
220 | scores = torch.matmul(hiddens, self.embedding.weight.t())
221 | # elif self.score_fn.startswith('arc_margin'):
222 | # scores = self.linear(hiddens,targets)
223 | else:
224 | scores = self.linear(hiddens)
225 | return scores
226 |
227 | def compute_score_dir(self, hiddens):
228 | if self.score_fn.startswith('general'):
229 | if self.score_fn.endswith('not'):
230 | scores = torch.matmul(self.linear_dir(hiddens), Variable(self.embedding_dir.weight.t().data))
231 | else:
232 | scores = torch.matmul(self.linear_dir(hiddens), self.embedding_dir.weight.t())
233 | elif self.score_fn.startswith('concat'):
234 | if self.score_fn.endswith('not'):
235 | scores = self.linear_v_dir(torch.tanh(self.linear_query_dir(hiddens).unsqueeze(1) + \
236 | self.linear_weight_dir(Variable(self.embedding_dir.weight.data)).unsqueeze(
237 | 0))).squeeze(2)
238 | else:
239 | scores = self.linear_v_dir(torch.tanh(self.linear_query_dir(hiddens).unsqueeze(1) + \
240 | self.linear_weight_dir(self.embedding_dir.weight).unsqueeze(0))).squeeze(2)
241 | elif self.score_fn.startswith('dot'):
242 | if self.score_fn.endswith('not'):
243 | scores = torch.matmul(hiddens, Variable(self.embedding_dir.weight.t().data))
244 | else:
245 | scores = torch.matmul(hiddens, self.embedding_dir.weight.t())
246 | else:
247 | #scores_1 = F.sigmoid(self.linear_dir(hiddens))
248 | scores = self.linear_dir(hiddens)
249 | return scores
250 |
251 | def sample(self, input, init_state, contexts):
252 | inputs, outputs, sample_ids, state = [], [], [], init_state
253 | attns = []
254 | inputs += input
255 | max_time_step = self.config.max_tgt_len
256 | soft_score = None
257 | mask = None
258 | for i in range(max_time_step):
259 | output, state, attn_weights = self.sample_one(inputs[i], soft_score, state, contexts, mask)
260 | if self.config.global_emb:
261 | soft_score = F.softmax(output)
262 | predicted = output.max(1)[1]
263 | inputs += [predicted]
264 | sample_ids += [predicted]
265 | outputs += [output]
266 | attns += [attn_weights]
267 | if self.config.mask:
268 | if mask is None:
269 | mask = predicted.unsqueeze(1).long()
270 | else:
271 | mask = torch.cat((mask, predicted.unsqueeze(1)), 1)
272 |
273 | sample_ids = torch.stack(sample_ids)
274 | attns = torch.stack(attns)
275 | return sample_ids, (outputs, attns)
276 |
277 | def sample_one(self, input, input_dir, soft_score, soft_score_dir, state, state_dir, tmp_hiddens, tmp_hiddens_dir, contexts, mask,mask_dir):
278 | if self.config.global_emb:
279 | batch_size = contexts.size(0)
280 | a, b = self.embedding.weight.size()
281 | if soft_score is None:
282 | emb = self.embedding(input)
283 | else:
284 | emb1 = torch.bmm(soft_score.unsqueeze(1), self.embedding.weight.expand((batch_size, a, b)))
285 | emb2 = self.embedding(input)
286 | gamma = F.sigmoid(self.gated1(emb1.squeeze()) + self.gated2(emb2.squeeze()))
287 | emb = gamma * emb1.squeeze() + (1 - gamma) * emb2.squeeze()
288 |
289 | c, d = self.embedding_dir.weight.size()
290 | if soft_score_dir is None:
291 | emb_dir = self.embedding_dir(input_dir)
292 | else:
293 | emb1_dir = torch.bmm(soft_score_dir.unsqueeze(1), self.embedding_dir.weight.expand((batch_size, c, d)))
294 | emb2_dir = self.embedding_dir(input_dir)
295 | gamma_dir = F.sigmoid(self.gated1_dir(emb1_dir.squeeze()) + self.gated2_dir(emb2_dir.squeeze()))
296 | emb_dir = gamma_dir * emb1_dir.squeeze() + (1 - gamma_dir) * emb2_dir.squeeze()
297 | else:
298 | emb = self.embedding(input)
299 | emb_dir = self.embedding_dir(input_dir)
300 |
301 | output, state = self.rnn(emb, state)
302 | output_bk = output
303 | hidden, attn_weights = self.attention(output, contexts)
304 | if self.config.schmidt:
305 | hidden = models.schmidt(hidden, tmp_hiddens)
306 | output = self.compute_score(hidden)
307 | if self.config.mask:
308 | if mask is not None:
309 | output = output.scatter_(1, mask, -9999999999)
310 |
311 | output_dir, state_dir = self.rnn_dir(emb_dir, state_dir)
312 | output_dir_bk = output_dir
313 | hidden_dir, attn_weights_dir = self.attention_dir(output_dir, contexts)
314 | if self.config.schmidt:
315 | hidden_dir = models.schmidt(hidden_dir, tmp_hiddens_dir)
316 | output_dir = self.compute_score_dir(hidden_dir)
317 | if self.config.mask:
318 | if mask_dir is not None:
319 | output_dir = output_dir.scatter_(1, mask_dir, -9999999999)
320 |
321 | return output, output_dir, state, state_dir, attn_weights, attn_weights_dir, hidden, hidden_dir, emb, emb_dir, output_bk, output_dir_bk
322 |
--------------------------------------------------------------------------------
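
The decoder above mixes a "global" embedding (the expectation of the embedding table under the previous softmax) with the embedding of the previous token through a learned gate. A standalone sketch of that gating with made-up sizes, reusing the `gated1`/`gated2` naming:

```python
import torch
import torch.nn as nn

emb_size, batch = 128, 4
gated1 = nn.Linear(emb_size, emb_size)
gated2 = nn.Linear(emb_size, emb_size)
emb1 = torch.randn(batch, emb_size)       # soft_score @ embedding.weight (expected embedding)
emb2 = torch.randn(batch, emb_size)       # embedding of the previous ground-truth/predicted token
gamma = torch.sigmoid(gated1(emb1) + gated2(emb2))
emb = gamma * emb1 + (1 - gamma) * emb2   # gated mixture fed to the next decoder step
```
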
/models/separation_dis.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import sys
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import random
9 |
10 | np.random.seed(1) # set random seed
11 | torch.manual_seed(1)
12 | random.seed(1)
13 | torch.cuda.set_device(0)
14 | test_all_outputchannel = 0
15 |
16 |
17 | class ATTENTION(nn.Module):
18 | def __init__(self, hidden_size, query_size, align_hidden_size, mode='dot'):
19 | super(ATTENTION, self).__init__()
20 | # self.mix_emb_size=config.EMBEDDING_SIZE
21 | self.hidden_size = hidden_size
22 | self.query_size = query_size
23 |         # self.align_hidden_size=hidden_size # hidden size for the 'align' mode; for now keep it the same as before
24 |         self.align_hidden_size = align_hidden_size # hidden size for the 'align' mode; for now keep it the same as before
25 | self.mode = mode
26 | self.Linear_1 = nn.Linear(self.hidden_size, self.align_hidden_size, bias=False)
27 | # self.Linear_2=nn.Linear(hidden_sizedw,self.align_hidden_size,bias=False)
28 | self.Linear_2 = nn.Linear(self.query_size, self.align_hidden_size, bias=False)
29 | self.Linear_3 = nn.Linear(self.align_hidden_size, 1, bias=False)
30 |
31 | def forward(self, mix_hidden, query):
32 |         # todo: clean this up; we could also drop the memory entirely and do the attention directly | DONE
33 | BATCH_SIZE = mix_hidden.size()[0]
34 | assert query.size() == (BATCH_SIZE, self.query_size)
35 | assert mix_hidden.size()[-1] == self.hidden_size
36 | # mix_hidden:bs,max_len,fre,hidden_size query:bs,hidden_size
37 | if self.mode == 'dot':
38 | # mix_hidden=mix_hidden.view(-1,1,self.hidden_size)
39 | mix_shape = mix_hidden.size()
40 | mix_hidden = mix_hidden.view(BATCH_SIZE, -1, self.hidden_size)
41 | query = query.view(-1, self.hidden_size, 1)
42 | # print '\n\n',mix_hidden.requires_grad,query.requires_grad,'\n\n'
43 | dot = torch.baddbmm(Variable(torch.zeros(1, 1)), mix_hidden, query)
44 | energy = dot.view(BATCH_SIZE, mix_shape[1], mix_shape[2])
45 |             # TODO: consider whether this could be replaced with ReLU or similar
46 | mask = F.sigmoid(energy)
47 | return mask
48 |
49 | elif self.mode == 'align':
50 | # mix_hidden=Variable(mix_hidden)
51 | # query=Variable(query)
52 | mix_shape = mix_hidden.size()
53 | mix_hidden = mix_hidden.view(-1, self.hidden_size)
54 | mix_hidden = self.Linear_1(mix_hidden).view(BATCH_SIZE, -1, self.align_hidden_size)
55 | query = self.Linear_2(query).view(-1, 1, self.align_hidden_size) # bs,1,hidden
56 | sum = F.tanh(mix_hidden + query)
57 |             # TODO: start working from here
58 | energy = self.Linear_3(sum.view(-1, self.align_hidden_size)).view(BATCH_SIZE, mix_shape[1], mix_shape[2])
59 | mask = F.sigmoid(energy)
60 | return mask
61 | else:
62 |             print('No such attention method.')
64 | raise IndexError
65 |
66 |
67 | class MIX_SPEECH_CNN(nn.Module):
68 | def __init__(self, config, input_fre, mix_speech_len):
69 | super(MIX_SPEECH_CNN, self).__init__()
70 | self.input_fre = input_fre
71 | self.mix_speech_len = mix_speech_len
72 | self.config = config
73 |
74 | self.cnn1 = nn.Conv2d(1, 96, (1, 7), stride=1, padding=(0, 3), dilation=(1, 1))
75 | self.cnn2 = nn.Conv2d(96, 96, (7, 1), stride=1, padding=(3, 0), dilation=(1, 1))
76 | self.cnn3 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(2, 2), dilation=(1, 1))
77 | self.cnn4 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(4, 2), dilation=(2, 1))
78 | self.cnn5 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(8, 2), dilation=(4, 1))
79 |
80 | self.cnn6 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(16, 2), dilation=(8, 1))
81 | self.cnn7 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(32, 2), dilation=(16, 1))
82 | self.cnn8 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(64, 2), dilation=(32, 1))
83 | self.cnn9 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(2, 2), dilation=(1, 1))
84 | self.cnn10 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(4, 4), dilation=(2, 2))
85 |
86 | self.cnn11 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(8, 8), dilation=(4, 4))
87 | self.cnn12 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(16, 16), dilation=(8, 8))
88 | self.cnn13 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(32, 32), dilation=(16, 16))
89 | self.cnn14 = nn.Conv2d(96, 96, (5, 5), stride=1, padding=(64, 64), dilation=(32, 32))
90 | self.cnn15 = nn.Conv2d(96, 8, (1, 1), stride=1, padding=(0, 0), dilation=(1, 1))
91 | self.num_cnns = 15
92 | self.bn1 = nn.BatchNorm2d(96)
93 | self.bn2 = nn.BatchNorm2d(96)
94 | self.bn3 = nn.BatchNorm2d(96)
95 | self.bn4 = nn.BatchNorm2d(96)
96 | self.bn5 = nn.BatchNorm2d(96)
97 | self.bn6 = nn.BatchNorm2d(96)
98 | self.bn7 = nn.BatchNorm2d(96)
99 | self.bn8 = nn.BatchNorm2d(96)
100 | self.bn9 = nn.BatchNorm2d(96)
101 | self.bn10 = nn.BatchNorm2d(96)
102 | self.bn11 = nn.BatchNorm2d(96)
103 | self.bn12 = nn.BatchNorm2d(96)
104 | self.bn13 = nn.BatchNorm2d(96)
105 | self.bn14 = nn.BatchNorm2d(96)
106 | self.bn15 = nn.BatchNorm2d(8)
107 |
108 | def forward(self, x):
109 |         print('speech input size:', x.size())
111 | assert len(x.size()) == 3
112 | x = x.unsqueeze(1)
113 |         print('\nSpeech layer log:')
115 | x = x.contiguous()
116 | for idx in range(self.num_cnns):
117 | cnn_layer = eval('self.cnn{}'.format(idx + 1))
118 | bn_layer = eval('self.bn{}'.format(idx + 1))
119 | x = F.relu(cnn_layer(x))
120 | x = bn_layer(x)
121 |             print('speech shape after CNNs:', idx, '', x.size())
123 |
124 | out = x.transpose(1, 3).transpose(1, 2).contiguous()
125 |         print('speech output size:', out.size())
127 | return out, out
128 |
129 |
130 | class MIX_SPEECH(nn.Module):
131 | def __init__(self, config, input_fre, mix_speech_len):
132 | super(MIX_SPEECH, self).__init__()
133 | self.input_fre = input_fre
134 | self.mix_speech_len = mix_speech_len
135 | self.layer = nn.LSTM(
136 | input_size=input_fre,
137 | hidden_size=config.HIDDEN_UNITS,
138 | num_layers=4,
139 | batch_first=True,
140 | bidirectional=True
141 | )
142 | # self.batchnorm = nn.BatchNorm1d(self.input_fre*config.EMBEDDING_SIZE)
143 | self.Linear = nn.Linear(2 * config.HIDDEN_UNITS, self.input_fre * config.EMBEDDING_SIZE,bias=1)
144 | self.config = config
145 |
146 | def forward(self, x):
147 | x, hidden = self.layer(x)
148 | batch_size = x.size()[0]
149 | x = x.contiguous()
150 | xx = x
151 | x = x.view(batch_size * self.mix_speech_len, -1)
152 | # out=F.tanh(self.Linear(x))
153 | out = self.Linear(x)
154 | # out = self.batchnorm(out)
155 | out = F.tanh(out)
156 | # out=F.relu(out)
157 | out = out.view(batch_size, self.mix_speech_len, self.input_fre, -1)
158 | # print 'Mix speech output shape:',out.size()
159 | return out, xx
160 |
161 |
162 | class Discriminator(nn.Module):
163 | def __init__(self):
164 | super(Discriminator, self).__init__()
165 | self.cnn = nn.Conv2d(1, 64, (3, 3), stride=(2, 2), )
166 | self.cnn1 = nn.Conv2d(64, 64, (3, 3), stride=(2, 2), )
167 | self.cnn2 = nn.Conv2d(64, 64, (3, 3), stride=(2, 2), )
168 | # self.final=nn.Linear(36480,1)
169 | self.final = nn.Linear(73920, 1)
170 |
171 | def forward(self, spec):
172 | bs, topk, len, fre = spec.size()
173 | spec = spec.view(bs * topk, 1, len, fre)
174 | spec = F.relu(self.cnn(spec))
175 | spec = F.relu(self.cnn1(spec))
176 | spec = F.relu(self.cnn2(spec))
177 | spec = spec.view(bs * topk, -1)
178 |         print('size spec:', spec.size())
180 | score = F.sigmoid(self.final(spec))
181 |         print('size spec:', score.size())
183 | return score
184 |
185 |
186 | class SPEECH_EMBEDDING(nn.Module):
187 | def __init__(self, num_labels, embedding_size, max_num_channel):
188 | super(SPEECH_EMBEDDING, self).__init__()
189 | self.num_all = num_labels
190 | self.emb_size = embedding_size
191 | self.max_num_out = max_num_channel
192 | # self.layer=nn.Embedding(num_labels,embedding_size,padding_idx=-1)
193 | self.layer = nn.Embedding(num_labels, embedding_size)
194 |
195 | def forward(self, input, mask_idx):
196 | aim_matrix = torch.from_numpy(np.array(mask_idx))
197 |         all = self.layer(Variable(aim_matrix)) # bs * num_labels (max number of mixed speakers) x embedding size
198 | out = all
199 | return out
200 |
201 |
202 | class ADDJUST(nn.Module):
203 |     # This module handles the perturbation associated with the target speaker and applies an offset adjustment
204 | def __init__(self, config, hidden_units, embedding_size):
205 | super(ADDJUST, self).__init__()
206 | self.config = config
207 | self.hidden_units = hidden_units
208 | self.emb_size = embedding_size
209 | self.layer = nn.Linear(hidden_units + embedding_size, embedding_size, bias=False)
210 |
211 | def forward(self, input_hidden, prob_emb):
212 | top_k_num = prob_emb.size()[1]
213 | x = torch.mean(input_hidden, 1).view(self.config.batch_size, 1, self.hidden_units).expand(
214 | self.config.batch_size, top_k_num, self.hidden_units)
215 | can = torch.cat([x, prob_emb], dim=2)
216 |         all = self.layer(can) # bs * num_labels (max number of mixed speakers) x embedding size
217 | out = all
218 | return out
219 |
220 |
221 | class SS(nn.Module):
222 | def __init__(self, config, speech_fre, mix_speech_len, num_labels):
223 | super(SS, self).__init__()
224 | self.config = config
225 | self.speech_fre = speech_fre
226 | self.mix_speech_len = mix_speech_len
227 | self.num_labels = num_labels
228 |         print('Begin to build the main model for the speech separation part.')
230 | if config.speech_cnn_net:
231 | self.mix_hidden_layer_3d = MIX_SPEECH_CNN(config, speech_fre, mix_speech_len)
232 | else:
233 | self.mix_hidden_layer_3d = MIX_SPEECH(config, speech_fre, mix_speech_len)
234 | # att_layer=ATTENTION(config.EMBEDDING_SIZE,'dot')
235 | self.att_speech_layer = ATTENTION(config.EMBEDDING_SIZE, config.SPK_EMB_SIZE, config.ATT_SIZE, 'align')
236 | if self.config.is_SelfTune:
237 | self.adjust_layer = ADDJUST(config, 2 * config.HIDDEN_UNITS, config.SPK_EMB_SIZE)
238 |             print('Adopt adjust layer.')
240 |
241 | def forward(self, mix_feas, hidden_outputs, targets, dict_spk2idx=None):
242 | '''
243 |         :param targets: targets has size (topk, bs); note it must be transposed later
244 |         123 324 345
245 |         323 E E
246 |         With a layout like this, aim_list is found by transposing, flattening, and keeping the entries that are not E: [0 1 2 4]
247 |
248 | '''
249 |
250 | config = self.config
251 |         top_k_max, batch_size = targets.size() # top_k_max is the maximum number of speakers and should stay consistent with MAX_MIX
252 | # assert top_k_max==config.MAX_MIX
253 | aim_list = (targets.transpose(0, 1).contiguous().view(-1) != dict_spk2idx['']).nonzero().squeeze()
254 | aim_list = aim_list.data.cpu().numpy()
255 |
256 | mix_speech_hidden, mix_tmp_hidden = self.mix_hidden_layer_3d(mix_feas)
257 |         mix_speech_multiEmbs = torch.transpose(hidden_outputs, 0, 1).contiguous() # bs * num_labels (max number of mixed speakers) x embedding size
258 |         mix_speech_multiEmbs = mix_speech_multiEmbs.view(-1, config.SPK_EMB_SIZE) # bs * num_labels (max number of mixed speakers) x embedding size
259 | # assert mix_speech_multiEmbs.size()[0]==targets.shape
260 | mix_speech_multiEmbs = mix_speech_multiEmbs[aim_list] # aim_num,embs
261 | # mix_speech_multiEmbs=mix_speech_multiEmbs[0] # aim_num,embs
262 | # print mix_speech_multiEmbs.shape
263 | if self.config.is_SelfTune:
264 |             # TODO: this part is probably also problematic; selfTune is not used for now
265 | mix_adjust = self.adjust_layer(mix_tmp_hidden, mix_speech_multiEmbs)
266 | mix_speech_multiEmbs = mix_adjust + mix_speech_multiEmbs
267 | mix_speech_hidden_5d = mix_speech_hidden.view(batch_size, 1, self.mix_speech_len, self.speech_fre,
268 | config.EMBEDDING_SIZE)
269 | mix_speech_hidden_5d = mix_speech_hidden_5d.expand(batch_size, top_k_max, self.mix_speech_len, self.speech_fre,
270 | config.EMBEDDING_SIZE).contiguous()
271 | mix_speech_hidden_5d_last = mix_speech_hidden_5d.view(-1, self.mix_speech_len, self.speech_fre,
272 | config.EMBEDDING_SIZE)
273 | mix_speech_hidden_5d_last = mix_speech_hidden_5d_last[aim_list]
274 | att_multi_speech = self.att_speech_layer(mix_speech_hidden_5d_last,
275 | mix_speech_multiEmbs.view(-1, config.SPK_EMB_SIZE))
276 |         att_multi_speech = att_multi_speech.view(-1, self.mix_speech_len, self.speech_fre) # roughly (bs*num_labels, len, fre)
277 | multi_mask = att_multi_speech
278 | assert multi_mask.shape[0] == len(aim_list)
279 | return multi_mask
280 |
281 |
282 | def top_k_mask(batch_pro, alpha, top_k):
283 |     'batch_pro is a bs*n matrix of probability distributions, e.g. 2x3, where each row is one distribution.\
284 |     alpha is a threshold; only entries above it may be picked (it matches the ACC alpha used in multi-label speech separation);\
285 |     top_k is the maximum number of candidate targets to output.\
286 |     The output is a bs*n mask of float type.'
287 | size = batch_pro.size()
288 | final = torch.zeros(size)
289 |     sort_result, sort_index = torch.sort(batch_pro, 1, True) # sort first
290 |     sort_index = sort_index[:, :top_k] # pick the top_k ids of each row
291 | sort_result = torch.sum(sort_result > alpha, 1)
292 | for line_idx in range(size[0]):
293 | line_top_k = sort_index[line_idx][:int(sort_result[line_idx].data.cpu().numpy())]
294 | line_top_k = line_top_k.data.cpu().numpy()
295 | for i in line_top_k:
296 | final[line_idx, i] = 1
297 | return final
298 |
299 |
300 |
--------------------------------------------------------------------------------
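
A tiny demo of `top_k_mask` above: per row it keeps at most `top_k` entries whose probability exceeds `alpha`. (Note that importing `models.separation_dis` calls `torch.cuda.set_device(0)` at module load, so a GPU build of PyTorch is assumed.)

```python
import torch
from models.separation_dis import top_k_mask

probs = torch.tensor([[0.70, 0.20, 0.10],
                      [0.40, 0.35, 0.25]])
print(top_k_mask(probs, alpha=0.3, top_k=2))
# tensor([[1., 0., 0.],    only 0.70 exceeds alpha in row 0
#         [1., 1., 0.]])   both 0.40 and 0.35 exceed alpha in row 1
```
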
/models/separation_tasnet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 |
7 |
8 | EPS = 1e-8
9 |
10 | def gcd(a, b):
11 | a, b = (a, b) if a >= b else (b, a)
12 | while b:
13 | a, b = b, a % b
14 | return a
15 |
16 | def overlap_and_add(signal, frame_step):
17 | """Reconstructs a signal from a framed representation.
18 |
19 | Adds potentially overlapping frames of a signal with shape
20 | `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
21 | The resulting tensor has shape `[..., output_size]` where
22 |
23 | output_size = (frames - 1) * frame_step + frame_length
24 |
25 | Args:
26 | signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
27 | frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
28 |
29 | Returns:
30 | A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
31 | output_size = (frames - 1) * frame_step + frame_length
32 |
33 | Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
34 | """
35 | outer_dimensions = signal.size()[:-2]
36 | frames, frame_length = signal.size()[-2:]
37 |
38 | subframe_length = gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
39 | subframe_step = frame_step // subframe_length
40 | subframes_per_frame = frame_length // subframe_length
41 | output_size = frame_step * (frames - 1) + frame_length
42 | output_subframes = output_size // subframe_length
43 |
44 | # subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
45 | subframe_signal = signal.view(outer_dimensions[0],outer_dimensions[1], -1, subframe_length)
46 |
47 | frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)
48 | frame = signal.new_tensor(frame).long() # signal may in GPU or CPU
49 | frame = frame.contiguous().view(-1)
50 |
51 | # result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
52 | result = signal.new_zeros(outer_dimensions[0],outer_dimensions[1], output_subframes, subframe_length)
53 | result.index_add_(-2, frame, subframe_signal)
54 | # result = result.view(*outer_dimensions, -1)
55 | result = result.view(outer_dimensions[0],outer_dimensions[1], -1)
56 | return result
57 |
58 |
59 | def remove_pad(inputs, inputs_lengths):
60 | """
61 | Args:
62 | inputs: torch.Tensor, [B, C, T] or [B, T], B is batch size
63 | inputs_lengths: torch.Tensor, [B]
64 | Returns:
65 | results: a list containing B items, each item is [C, T], T varies
66 | """
67 | results = []
68 | dim = inputs.dim()
69 | if dim == 3:
70 | C = inputs.size(1)
71 | for input, length in zip(inputs, inputs_lengths):
72 | if dim == 3: # [B, C, T]
73 | results.append(input[:,:length].view(C, -1).cpu().numpy())
74 | elif dim == 2: # [B, T]
75 | results.append(input[:length].view(-1).cpu().numpy())
76 | return results
77 |
78 | class ConvTasNet(nn.Module):
79 | def __init__(self, N=256, L=40, B=256, H=512, P=3, X=8, R=4, C=2, norm_type="gLN", causal=False,
80 | mask_nonlinear='sigmoid'):
81 | """
82 | Args:
83 | N: Number of filters in autoencoder
84 | L: Length of the filters (in samples)
85 | B: Number of channels in bottleneck 1 * 1-conv block
86 | H: Number of channels in convolutional blocks
87 | P: Kernel size in convolutional blocks
88 | X: Number of convolutional blocks in each repeat
89 | R: Number of repeats
90 | C: Number of speakers
91 | norm_type: BN, gLN, cLN
92 | causal: causal or non-causal
93 | mask_nonlinear: use which non-linear function to generate mask
94 | """
95 | super(ConvTasNet, self).__init__()
96 | # Hyper-parameter
97 | self.N, self.L, self.B, self.H, self.P, self.X, self.R, self.C = N, L, B, H, P, X, R, C
98 | self.norm_type = norm_type
99 | self.causal = causal
100 | self.mask_nonlinear = mask_nonlinear
101 | # Components
102 | self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear)
103 | # init
104 | #for p in self.parameters():
105 | # if p.dim() > 1:
106 | # nn.init.xavier_normal_(p)
107 |
108 | def forward(self, mixture, hidden_outputs):
109 | """
110 | Args:
111 | mixture: [M, T], M is batch size, T is #samples
112 | Returns:
113 | est_source: [M, C, T]
114 | """
115 | #mixture_w = self.encoder(mixture)
116 | est_mask = self.separator(mixture, hidden_outputs)
117 | #est_source = self.decoder(mixture_w, est_mask)
118 |
119 | # T changed after conv1d in encoder, fix it here
120 | #T_origin = mixture.size(-1)
121 | #T_conv = est_source.size(-1)
122 | #est_source = F.pad(est_source, (0, T_origin - T_conv))
123 | return est_mask
124 |
125 | @classmethod
126 | def load_model(cls, path):
127 | # Load to CPU
128 | package = torch.load(path, map_location=lambda storage, loc: storage)
129 | model = cls.load_model_from_package(package)
130 | return model
131 |
132 | @classmethod
133 | def load_model_from_package(cls, package):
134 | model = cls(package['N'], package['L'], package['B'], package['H'],
135 | package['P'], package['X'], package['R'], package['C'],
136 | norm_type=package['norm_type'], causal=package['causal'],
137 | mask_nonlinear=package['mask_nonlinear'])
138 | model.load_state_dict(package['state_dict'])
139 | return model
140 |
141 | @staticmethod
142 | def serialize(model, optimizer, epoch, tr_loss=None, cv_loss=None):
143 | package = {
144 | # hyper-parameter
145 | 'N': model.N, 'L': model.L, 'B': model.B, 'H': model.H,
146 | 'P': model.P, 'X': model.X, 'R': model.R, 'C': model.C,
147 | 'norm_type': model.norm_type, 'causal': model.causal,
148 | 'mask_nonlinear': model.mask_nonlinear,
149 | # state
150 | 'state_dict': model.state_dict(),
151 | 'optim_dict': optimizer.state_dict(),
152 | 'epoch': epoch
153 | }
154 | if tr_loss is not None:
155 | package['tr_loss'] = tr_loss
156 | package['cv_loss'] = cv_loss
157 | return package
158 |
159 |
160 | class TasNetEncoder(nn.Module):
161 | """Estimation of the nonnegative mixture weight by a 1-D conv layer.
162 | """
163 | def __init__(self, L=40, N=256, Ch = 32):
164 | super(TasNetEncoder, self).__init__()
165 | # Hyper-parameter
166 | self.L, self.N = L, N
167 | self.Ch = Ch
168 | # Components
169 | # 50% overlap
170 |
171 | self.conv1d_c2 = nn.Conv1d(2, N, kernel_size=L, stride=L // 2, bias=False)
172 |
173 | def forward(self, mixture):
174 | """
175 | Args:
176 | mixture: [M, T, C], M is batch size, T is #samples, C is channel number
177 | Returns:
178 | mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
179 | """
180 | mixture =mixture.transpose(1,2)
181 | mixture_w = F.relu(self.conv1d_c2(mixture))
182 | return mixture_w
183 |
184 |
185 | class TasNetDecoder(nn.Module):
186 | def __init__(self, N=256, L=40):
187 | super(TasNetDecoder, self).__init__()
188 | # Hyper-parameter
189 | self.N, self.L = N, L
190 | # Components
191 | self.basis_signals = nn.Linear(N, L, bias=False)
192 |
193 | self.conv1dTranspose = nn.ConvTranspose1d(N, 1, L, stride=L//2)
194 |
195 | def forward(self, mixture_w, est_mask):
196 | """
197 | Args:
198 | mixture_w: [M, N, K]
199 | est_mask: [M, C, N, K]
200 | Returns:
201 | est_source: [M, C, T]
202 | """
203 | # D = W * M
204 | source_w = torch.unsqueeze(mixture_w, 1) * est_mask # [M, C, N, K]
205 | #source_w = torch.transpose(source_w, 2, 3) # [M, C, K, N]
206 | M, C, N, K = source_w.size()
207 | source_w = source_w.view(-1, N, K).contiguous()
208 | est_source = self.conv1dTranspose(source_w).view(M, C, -1).contiguous()
209 | # S = DV
210 | #est_source = self.basis_signals(source_w) # [M, C, K, L]
211 | #est_source = overlap_and_add(est_source, self.L//2) # M x C x T
212 | return est_source
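
# A hedged round-trip sketch (shapes only): with the encoder's K = 2T/L - 1
# frames, the transposed convolution above maps the masked representation back
# to T samples, since (K - 1) * L/2 + L = T for the default L = 40:
#
#   dec = TasNetDecoder(N=256, L=40)
#   mixture_w  = torch.randn(4, 256, 1599)         # [M, N, K]
#   est_mask   = torch.rand(4, 2, 256, 1599)       # [M, C, N, K]
#   est_source = dec(mixture_w, est_mask)          # [M, C, T], T = 32000
#   assert est_source.shape == (4, 2, 32000)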
213 |
214 | class TemporalConvNet(nn.Module):
215 | def __init__(self, N, B, H, P, X, R, C, norm_type="gLN", causal=False,
216 | mask_nonlinear='relu'):
217 | """
218 | Args:
219 | N: Number of filters in autoencoder
220 | B: Number of channels in bottleneck 1 * 1-conv block
221 | H: Number of channels in convolutional blocks
222 | P: Kernel size in convolutional blocks
223 | X: Number of convolutional blocks in each repeat
224 | R: Number of repeats
225 | C: Number of speakers
226 | norm_type: BN, gLN, cLN
227 | causal: causal or non-causal
228 | mask_nonlinear: use which non-linear function to generate mask
229 | """
230 | super(TemporalConvNet, self).__init__()
231 | # Hyper-parameter
232 | self.C = C
233 | self.B = B
234 | self.mask_nonlinear = mask_nonlinear
235 | # Components
236 | # [M, N, K] -> [M, N, K]
237 | layer_norm = ChannelwiseLayerNorm(N)
238 | # [M, N, K] -> [M, B, K]
239 | bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
240 | # [M, B, K] -> [M, B, K]
241 | repeats = []
242 | for r in range(R):
243 | blocks = []
244 | for x in range(X):
245 | dilation = 2**x
246 | padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
247 | blocks += [TemporalBlock(B, H, P, stride=1,
248 | padding=padding,
249 | dilation=dilation,
250 | norm_type=norm_type,
251 | causal=causal)]
252 | repeats += [nn.Sequential(*blocks)]
253 | temporal_conv_net = nn.Sequential(*repeats)
254 | # [M, B, K] -> [M, C*N, K]
255 | # self.mask_conv1x1 = nn.Conv1d(B, C*N, 1, bias=False)
256 | self.mask_conv1x1 = nn.Conv1d(B, N, 1, bias=False)
257 |         # N (256 here) should stay consistent with SPK_EMB_SIZE in the config
258 | # Put together
259 | self.network = nn.Sequential(layer_norm,
260 | bottleneck_conv1x1,
261 | temporal_conv_net,)
262 | # mask_conv1x1)
263 |
264 | def forward(self, mixture_w, hidden_outputs):
265 | """
266 | Keep this API same with TasNet
267 | Args:
268 | mixture_w: [M, N, K], M is batch size
269 | hidden_outputs: [M, C, D]
270 | returns:
271 | est_mask: [M, C, N, K]
272 | """
273 | B = self.B
274 | M, N, K = mixture_w.size()
275 | _, C, D = hidden_outputs.size()
276 | assert M==_
277 | original_sep= self.network(mixture_w).unsqueeze(1).expand(M,self.C,B,K) # [M, N, K] -> [M, C, B, K]
278 | hidden_outputs=hidden_outputs.unsqueeze(-1).expand(M,self.C, D, K)# [M,C,D,K]
279 | #original_sep=torch.cat((original_sep,hidden_outputs),dim=2).view(-1,B+D,K) #[M*C,(B+D),K]
280 | original_sep = (original_sep * hidden_outputs).contiguous().view(-1,B,K)
281 | score = self.mask_conv1x1(original_sep) # -> [M*C,N, K]
282 |
283 | score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K]
284 | if self.mask_nonlinear == 'softmax':
285 | est_mask = F.softmax(score, dim=1)
286 | elif self.mask_nonlinear == 'relu':
287 | est_mask = F.relu(score)
288 | elif self.mask_nonlinear == 'sigmoid':
289 |             est_mask = torch.sigmoid(score)
290 | else:
291 | raise ValueError("Unsupported mask non-linear function")
292 | return est_mask
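
# A hedged shape walk-through of the embedding fusion above: the [M, C, D]
# speaker/direction embeddings are broadcast over time and multiplied into the
# bottleneck features, so D must equal the bottleneck width B (SPK_EMB_SIZE in
# the config). Example sizes only:
#
#   sep = TemporalConvNet(N=256, B=256, H=512, P=3, X=8, R=4, C=2)
#   mixture_w = torch.randn(4, 256, 1599)          # [M, N, K]
#   spk_emb   = torch.randn(4, 2, 256)             # [M, C, D], D == B
#   est_mask  = sep(mixture_w, spk_emb)            # [M, C, N, K]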
293 |
294 |
295 | class TemporalBlock(nn.Module):
296 | def __init__(self, in_channels, out_channels, kernel_size,
297 | stride, padding, dilation, norm_type="gLN", causal=False):
298 | super(TemporalBlock, self).__init__()
299 | # [M, B, K] -> [M, H, K]
300 | conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
301 | prelu = nn.PReLU()
302 | norm = chose_norm(norm_type, out_channels)
303 | # [M, H, K] -> [M, B, K]
304 | dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size,
305 | stride, padding, dilation, norm_type,
306 | causal)
307 | # Put together
308 | self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)
309 |
310 | def forward(self, x):
311 | """
312 | Args:
313 | x: [M, B, K]
314 | Returns:
315 | [M, B, K]
316 | """
317 | residual = x
318 | out = self.net(x)
319 |         # TODO: this works fine when P = 3, but padding may be needed when P = 2
320 |         return out + residual  # empirically, omitting F.relu here seems better than applying it
321 | # return F.relu(out + residual)
322 |
323 |
324 | class DepthwiseSeparableConv(nn.Module):
325 | def __init__(self, in_channels, out_channels, kernel_size,
326 | stride, padding, dilation, norm_type="gLN", causal=False):
327 | super(DepthwiseSeparableConv, self).__init__()
328 | # Use `groups` option to implement depthwise convolution
329 | # [M, H, K] -> [M, H, K]
330 | depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size,
331 | stride=stride, padding=padding,
332 | dilation=dilation, groups=in_channels,
333 | bias=False)
334 | if causal:
335 | chomp = Chomp1d(padding)
336 | prelu = nn.PReLU()
337 | norm = chose_norm(norm_type, in_channels)
338 | # [M, H, K] -> [M, B, K]
339 | pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
340 | # Put together
341 | if causal:
342 | self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm,
343 | pointwise_conv)
344 | else:
345 | self.net = nn.Sequential(depthwise_conv, prelu, norm,
346 | pointwise_conv)
347 |
348 | def forward(self, x):
349 | """
350 | Args:
351 | x: [M, H, K]
352 | Returns:
353 | result: [M, B, K]
354 | """
355 | return self.net(x)
356 |
357 |
358 | class Chomp1d(nn.Module):
359 | """To ensure the output length is the same as the input.
360 | """
361 | def __init__(self, chomp_size):
362 | super(Chomp1d, self).__init__()
363 | self.chomp_size = chomp_size
364 |
365 | def forward(self, x):
366 | """
367 | Args:
368 | x: [M, H, Kpad]
369 | Returns:
370 | [M, H, K]
371 | """
372 | return x[:, :, :-self.chomp_size].contiguous()
373 |
374 |
375 | def chose_norm(norm_type, channel_size):
376 |     """The input of the normalization layer will be (M, C, K), where M is batch size,
377 | C is channel size and K is sequence length.
378 | """
379 | if norm_type == "gLN":
380 | return GlobalLayerNorm(channel_size)
381 | elif norm_type == "cLN":
382 | return ChannelwiseLayerNorm(channel_size)
383 | else: # norm_type == "BN":
384 |         # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statistics
385 | # along M and K, so this BN usage is right.
386 | return nn.BatchNorm1d(channel_size)
387 |
388 |
389 | # TODO: use nn.LayerNorm to implement cLN and speed it up
390 | class ChannelwiseLayerNorm(nn.Module):
391 | """Channel-wise Layer Normalization (cLN)"""
392 | def __init__(self, channel_size):
393 | super(ChannelwiseLayerNorm, self).__init__()
394 | self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
395 | self.beta = nn.Parameter(torch.Tensor(1, channel_size,1 )) # [1, N, 1]
396 | self.reset_parameters()
397 |
398 | def reset_parameters(self):
399 | self.gamma.data.fill_(1)
400 | self.beta.data.zero_()
401 |
402 | def forward(self, y):
403 | """
404 | Args:
405 | y: [M, N, K], M is batch size, N is channel size, K is length
406 | Returns:
407 | cLN_y: [M, N, K]
408 | """
409 | mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K]
410 | var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K]
411 | cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
412 | return cLN_y
413 |
414 |
415 | class GlobalLayerNorm(nn.Module):
416 | """Global Layer Normalization (gLN)"""
417 | def __init__(self, channel_size):
418 | super(GlobalLayerNorm, self).__init__()
419 | self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1]
420 | self.beta = nn.Parameter(torch.Tensor(1, channel_size,1 )) # [1, N, 1]
421 | self.reset_parameters()
422 |
423 | def reset_parameters(self):
424 | self.gamma.data.fill_(1)
425 | self.beta.data.zero_()
426 |
427 | def forward(self, y):
428 | """
429 | Args:
430 | y: [M, N, K], M is batch size, N is channel size, K is length
431 | Returns:
432 | gLN_y: [M, N, K]
433 | """
434 |         # TODO: torch >= 1.0 supports a list of dims in torch.mean()
435 | mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) #[M, 1, 1]
436 | var = (torch.pow(y-mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
437 | gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
438 | return gLN_y
439 |
440 |
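# NOTE: the smoke test below still references Encoder/Decoder classes and a
# single-argument TemporalConvNet.forward that do not match the classes defined
# in this file (TasNetEncoder/TasNetDecoder, and a TemporalConvNet.forward that
# also expects hidden_outputs), so it is kept for reference rather than run as-is.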
441 | if __name__ == "__main__":
442 | torch.manual_seed(123)
443 | M, N, L, T = 2, 3, 4, 12
444 | K = 2*T//L-1
445 | B, H, P, X, R, C, norm_type, causal = 2, 3, 3, 3, 2, 2, "gLN", False
446 | mixture = torch.randint(3, (M, T))
447 | # test Encoder
448 | encoder = Encoder(L, N)
449 | encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
450 | mixture_w = encoder(mixture)
451 | print('mixture', mixture)
452 | print('U', encoder.conv1d_U.weight)
453 | print('mixture_w', mixture_w)
454 | print('mixture_w size', mixture_w.size())
455 |
456 | # test TemporalConvNet
457 | separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type=norm_type, causal=causal)
458 | est_mask = separator(mixture_w)
459 | print('est_mask', est_mask)
460 | print('model', separator)
461 |
462 | # test Decoder
463 | decoder = Decoder(N, L)
464 | est_mask = torch.randint(2, (B, K, C, N))
465 | est_source = decoder(mixture_w, est_mask)
466 | print('est_source', est_source)
467 |
468 | # test Conv-TasNet
469 | conv_tasnet = ConvTasNet(N, L, B, H, P, X, R, C, norm_type=norm_type)
470 | est_source = conv_tasnet(mixture)
471 | print('est_source', est_source)
472 | print('est_source size', est_source.size())
473 |
474 |
--------------------------------------------------------------------------------
/models/seq2seq.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | # import data.dict as dict
7 | import models
8 | # from figure_hot import relitu_line
9 |
10 | import numpy as np
11 |
12 |
13 | class seq2seq(nn.Module):
14 |
15 | def __init__(self, config, input_emb_size, mix_speech_len, tgt_vocab_size, tgt_dir_vocab_size, use_cuda, pretrain=None, score_fn=''):
16 | super(seq2seq, self).__init__()
17 | if pretrain is not None:
18 | src_embedding = pretrain['src_emb']
19 | tgt_embedding = pretrain['tgt_emb']
20 | else:
21 | src_embedding = None
22 | tgt_embedding = None
23 |
24 | self.encoder = models.rnn_encoder(config, input_emb_size)
25 | ## TasNet Encoder
26 | self.TasNetEncoder = models.TasNetEncoder()
27 | self.TasNetDecoder = models.TasNetDecoder()
28 |
29 | if config.shared_vocab == False:
30 | self.decoder = models.rnn_decoder(config, tgt_vocab_size, tgt_dir_vocab_size, embedding=tgt_embedding, score_fn=score_fn)
31 | else:
32 | self.decoder = models.rnn_decoder(config, tgt_vocab_size, tgt_dir_vocab_size, embedding=self.encoder.embedding, score_fn=score_fn)
33 | self.use_cuda = use_cuda
34 | self.tgt_vocab_size = tgt_vocab_size
35 | self.tgt_dir_vocab_size = tgt_dir_vocab_size
36 | self.config = config
37 | self.criterion = models.criterion(tgt_vocab_size, use_cuda, config.loss)
38 | self.criterion_dir = models.criterion_dir(tgt_dir_vocab_size, use_cuda, config.loss)
39 | self.loss_for_ss = nn.MSELoss()
40 | self.loss_for_dir = nn.MSELoss()
41 |         self.log_softmax = nn.LogSoftmax(dim=-1)
42 |         self.softmax = nn.Softmax(dim=-1)
43 | self.linear_output = nn.Linear(tgt_dir_vocab_size, 1)
44 | self.wav_loss = models.WaveLoss(dBscale=1, nfft=config.FRAME_LENGTH, hop_size=config.FRAME_SHIFT)
45 |
46 | speech_fre = input_emb_size
47 | num_labels = tgt_vocab_size
48 | if config.use_tas:
49 | self.ss_model = models.ConvTasNet()
50 | else:
51 | self.ss_model = models.SS(config, speech_fre, mix_speech_len, num_labels)
52 |
53 | def compute_loss(self, hidden_outputs, hidden_outputs_dir, targets, targets_dir, memory_efficiency):
54 | if memory_efficiency:
55 | return models.memory_efficiency_cross_entropy_loss(hidden_outputs, self.decoder, targets, self.criterion, self.config)
56 | else:
57 | sgm_loss_speaker, num_total, num_correct = models.cross_entropy_loss(hidden_outputs, self.decoder, targets, self.criterion, self.config)
58 | #sgm_loss_direction = models.mmse_loss2(hidden_outputs_dir, self.decoder, targets_dir, self.loss_for_dir)
59 | sgm_loss_direction, num_total_dir, num_correct_dir = models.cross_entropy_loss_dir(hidden_outputs_dir, self.decoder, targets_dir, self.criterion_dir, self.config)
60 | print("sgm_loss_speaker:", sgm_loss_speaker)
61 | print("sgm_loss_direction:", sgm_loss_direction)
62 | sgm_loss = sgm_loss_speaker + sgm_loss_direction
63 | return sgm_loss, num_total, num_correct, num_total_dir, num_correct_dir
64 |
65 | def separation_loss(self, x_input_map_multi, masks, y_multi_map, Var='NoItem'):
66 | if not self.config.MLMSE:
67 | return models.ss_loss(self.config, x_input_map_multi, masks, y_multi_map, self.loss_for_ss,self.wav_loss)
68 | else:
69 | return models.ss_loss_MLMSE(self.config, x_input_map_multi, masks, y_multi_map, self.loss_for_ss, Var)
70 |
71 | def separation_tas_loss(self,predict_wav, y_multi_wav,mix_lengths):
72 | return models.ss_tas_loss(self.config, predict_wav, y_multi_wav, mix_lengths,self.loss_for_ss)
73 |
74 | def update_var(self, x_input_map_multi, multi_masks, y_multi_map):
75 |         predict_multi_map = torch.mean(multi_masks * x_input_map_multi, -2)  # average over the time dimension
76 |         y_multi_map = torch.mean(Variable(y_multi_map), -2)  # average over the time dimension
77 |         loss_vector = (y_multi_map - predict_multi_map).view(-1, self.config.speech_fre).unsqueeze(-1)  # shape: [bs, fre, 1]
78 |         Var = torch.bmm(loss_vector, loss_vector.transpose(1, 2))
79 |         Var = torch.mean(Var, 0)  # average over the batch dimension
80 | return Var.detach()
81 |
82 | def forward(self, src, tgt, tgt_dir):
83 | #lengths, indices = torch.sort(src_len.squeeze(0), dim=0, descending=True)
84 | # src = torch.index_select(src, dim=0, index=indices)
85 | # tgt = torch.index_select(tgt, dim=0, index=indices)
86 |
87 | mix_wav = src.transpose(0,1) # [batch, sample, channel]
88 | mix = self.TasNetEncoder(mix_wav) # [batch, featuremap, timeStep]
89 | mix_infer = mix.transpose(1,2) # [batch, timeStep, featuremap]
90 | _, lengths, _ = mix_infer.size()
91 |         # 4 is the number of GPUs the batch is split across
92 |         lengths = Variable(torch.LongTensor(self.config.batch_size // 4).zero_() + lengths).unsqueeze(0).cuda()
93 | lengths, indices = torch.sort(lengths.squeeze(0), dim=0, descending=True)
94 |
95 | contexts, state = self.encoder(mix_infer, lengths.data.tolist()) # context [max_len,batch_size,hidden_size×2]
96 | outputs, outputs_dir, final_state, global_embs = self.decoder(tgt[:-1], tgt_dir[:-1], state, contexts.transpose(0, 1))
97 |
98 | if self.config.use_tas:
99 | predicted_maps = self.ss_model(mix, global_embs.transpose(0,1))
100 |
101 | predicted_signal = self.TasNetDecoder(mix, predicted_maps) # [batch, spkN, timeStep]
102 |
103 | return outputs, outputs_dir, tgt[1:], tgt_dir[1:], predicted_signal.transpose(0,1)
104 |
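    # Data flow of forward() above (shapes per the inline comments):
    #   src [T, M, channels] -> TasNetEncoder -> mix [M, N, K]
    #   mix transposed to [M, K, N] -> rnn encoder -> contexts, state
    #   rnn decoder (teacher-forced with tgt / tgt_dir) -> speaker logits,
    #       direction logits and the global speaker/direction embeddings
    #   ss_model(mix, embeddings) -> masks -> TasNetDecoder -> separated waveforms
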
105 | def sample(self, src, src_len):
106 | # src=src.squeeze()
107 | if self.use_cuda:
108 | src = src.cuda()
109 | src_len = src_len.cuda()
110 |
111 | lengths, indices = torch.sort(src_len, dim=0, descending=True)
112 | _, ind = torch.sort(indices)
113 | src = Variable(torch.index_select(src, dim=1, index=indices), volatile=True)
114 | bos = Variable(torch.ones(src.size(1)).long().fill_(dict.BOS), volatile=True)
115 |
116 | if self.use_cuda:
117 | bos = bos.cuda()
118 |
119 | contexts, state = self.encoder(src, lengths.tolist())
120 | sample_ids, final_outputs = self.decoder.sample([bos], state, contexts.transpose(0, 1))
121 | _, attns_weight = final_outputs
122 | alignments = attns_weight.max(2)[1]
123 | sample_ids = torch.index_select(sample_ids.data, dim=1, index=ind)
124 | alignments = torch.index_select(alignments.data, dim=1, index=ind)
125 | # targets = tgt[1:]
126 |
127 | return sample_ids.t(), alignments.t()
128 |
129 | def beam_sample(self, src, dict_spk2idx, dict_dir2idx, beam_size=1):
130 |
131 | mix_wav = src.transpose(0,1) # [batch, sample]
132 | mix = self.TasNetEncoder(mix_wav) # [batch, featuremap, timeStep]
133 |
134 | mix_infer = mix.transpose(1,2) # [batch, timeStep, featuremap]
135 | batch_size, lengths, _ = mix_infer.size()
136 | lengths = Variable(torch.LongTensor(self.config.batch_size).zero_() + lengths).unsqueeze(0).cuda()
137 | lengths, indices = torch.sort(lengths.squeeze(0), dim=0, descending=True)
138 |
139 | contexts, encState = self.encoder(mix_infer, lengths.data.tolist()) # context [max_len,batch_size,hidden_size×2]
140 |
141 | # (1b) Initialize for the decoder.
142 | def var(a):
143 | return Variable(a, volatile=True)
144 |
145 | def rvar(a):
146 | return var(a.repeat(1, beam_size, 1))
147 |
148 | def bottle(m):
149 | return m.view(batch_size * beam_size, -1)
150 |
151 | def unbottle(m):
152 | return m.view(beam_size, batch_size, -1)
153 |
154 | # Repeat everything beam_size times.
155 | contexts = rvar(contexts.data).transpose(0, 1)
156 | decState = (rvar(encState[0].data), rvar(encState[1].data))
157 | decState_dir = (rvar(encState[0].data), rvar(encState[1].data))
158 | # decState.repeat_beam_size_times(beam_size)
159 | beam = [models.Beam(beam_size, dict_spk2idx, n_best=1,
160 | cuda=self.use_cuda) for __ in range(batch_size)]
161 |
162 | beam_dir = [models.Beam(beam_size, dict_dir2idx, n_best=1,
163 | cuda=self.use_cuda) for __ in range(batch_size)]
164 | # (2) run the decoder to generate sentences, using beam search.
165 |
166 | mask = None
167 | mask_dir = None
168 | soft_score = None
169 | tmp_hiddens = []
170 | tmp_soft_score = []
171 |
172 | soft_score_dir = None
173 | tmp_hiddens_dir = []
174 | tmp_soft_score_dir = []
175 | output_list = []
176 | output_dir_list = []
177 | predicted_list = []
178 | predicted_dir_list = []
179 | output_bk_list = []
180 | output_bk_dir_list = []
181 | hidden_list = []
182 | hidden_dir_list = []
183 | emb_list = []
184 | emb_dir_list = []
185 | for i in range(self.config.max_tgt_len):
186 |
187 | if all((b.done() for b in beam)):
188 | break
189 | if all((b_dir.done() for b_dir in beam_dir)):
190 | break
191 |
192 | # Construct batch x beam_size nxt words.
193 | # Get all the pending current beam words and arrange for forward.
194 | inp = var(torch.stack([b.getCurrentState() for b in beam]).t().contiguous().view(-1))
195 | inp_dir = var(torch.stack([b_dir.getCurrentState() for b_dir in beam_dir]).t().contiguous().view(-1))
196 |
197 | # Run one step.
198 | output, output_dir, decState, decState_dir, attn_weights, attn_weights_dir, hidden, hidden_dir, emb, emb_dir, output_bk, output_bk_dir = self.decoder.sample_one(inp, inp_dir, soft_score, soft_score_dir, decState, decState_dir, tmp_hiddens, tmp_hiddens_dir,
199 | contexts, mask, mask_dir)
200 |             soft_score = F.softmax(output, dim=-1)
201 |             soft_score_dir = F.softmax(output_dir, dim=-1)
202 |
203 | predicted = output.max(1)[1]
204 | predicted_dir = output_dir.max(1)[1]
205 | if self.config.mask:
206 | if mask is None:
207 | mask = predicted.unsqueeze(1).long()
208 | mask_dir = predicted_dir.unsqueeze(1).long()
209 | else:
210 | mask = torch.cat((mask, predicted.unsqueeze(1)), 1)
211 | mask_dir = torch.cat((mask_dir, predicted_dir.unsqueeze(1)), 1)
212 | # decOut: beam x rnn_size
213 |
214 | # (b) Compute a vector of batch*beam word scores.
215 |
216 | output_list.append(output[0])
217 | output_dir_list.append(output_dir[0])
218 |
219 | output = unbottle(self.log_softmax(output))
220 |             output_dir = unbottle(torch.sigmoid(output_dir))
221 |
222 | attn = unbottle(attn_weights)
223 | hidden = unbottle(hidden)
224 | emb = unbottle(emb)
225 | attn_dir = unbottle(attn_weights_dir)
226 | hidden_dir = unbottle(hidden_dir)
227 | emb_dir = unbottle(emb_dir)
228 | # beam x tgt_vocab
229 |
230 | output_bk_list.append(output_bk[0])
231 | output_bk_dir_list.append(output_bk_dir[0])
232 | hidden_list.append(hidden[0])
233 | hidden_dir_list.append(hidden_dir[0])
234 | emb_list.append(emb[0])
235 | emb_dir_list.append(emb_dir[0])
236 |
237 | predicted_list.append(predicted)
238 | predicted_dir_list.append(predicted_dir)
239 |
240 | # (c) Advance each beam.
241 | # update state
242 |
243 | for j, b in enumerate(beam):
244 | b.advance(output.data[:, j], attn.data[:, j], hidden.data[:, j], emb.data[:, j])
245 |                 b.beam_update(decState, j)  # updates decState in place by direct assignment instead of returning it
246 | if self.config.ct_recu:
247 |                     b.beam_update_context(contexts, j)  # likewise updates contexts in place by direct assignment
248 | for i, a in enumerate(beam_dir):
249 | a.advance(output_dir.data[:, i], attn_dir.data[:, i], hidden_dir.data[:, i], emb_dir.data[:, i])
250 |                 a.beam_update(decState_dir, i)  # updates decState_dir in place by direct assignment instead of returning it
251 | if self.config.ct_recu:
252 |                     a.beam_update_context(contexts, i)  # likewise updates contexts in place by direct assignment
253 | # print "beam after decState:",decState[0].data.cpu().numpy().mean()
254 |
255 | # (3) Package everything up.
256 | allHyps,allHyps_dir, allScores, allAttn, allHiddens, allEmbs = [],[], [], [], [], []
257 |
258 | ind = range(batch_size)
259 | for j in ind:
260 | b = beam[j]
261 | c = beam_dir[j]
262 | n_best = 1
263 | scores, ks = b.sortFinished(minimum=n_best)
264 | hyps, hyps_dir, attn, hiddens, embs = [], [], [], [], []
265 | for i, (times, k) in enumerate(ks[:n_best]):
266 | hyp, att, hidden, emb = b.getHyp(times, k)
267 | hyp_dir, att_dir, hidden_dir, emb_dir = c.getHyp(times, k)
268 | if self.config.relitu:
269 | relitu_line(626, 1, att[0].cpu().numpy())
270 | relitu_line(626, 1, att[1].cpu().numpy())
271 | hyps.append(hyp)
272 | attn.append(att.max(1)[1])
273 | hiddens.append(hidden+hidden_dir)
274 | embs.append(emb+emb_dir)
275 | hyps_dir.append(hyp_dir)
276 | allHyps.append(hyps[0])
277 | allHyps_dir.append(hyps_dir[0])
278 | allScores.append(scores[0])
279 | allAttn.append(attn[0])
280 | allHiddens.append(hiddens[0])
281 | allEmbs.append(embs[0])
282 |
283 | ss_embs = Variable(torch.stack(allEmbs, 0).transpose(0, 1)) # to [decLen, bs, dim]
284 | if self.config.use_tas:
285 | predicted_maps = self.ss_model(mix, ss_embs[1:].transpose(0,1))
286 |
287 | predicted_signal = self.TasNetDecoder(mix, predicted_maps) # [batch, spkN, timeStep]
288 | return allHyps, allHyps_dir, allAttn, allHiddens, predicted_signal.transpose(0,1), output_list, output_dir_list, output_bk_list, output_bk_dir_list, hidden_list, hidden_dir_list, emb_list, emb_dir_list
289 |
--------------------------------------------------------------------------------
/predata_WSJ_lcx.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import os
3 | import numpy as np
4 | import random
5 | import re
6 | import soundfile as sf
7 | import resampy
8 | import librosa
9 | import argparse
10 | import data.utils as utils
11 | import models
12 | from scipy.io import wavfile
13 | import scipy.signal
14 | # Add the config.
15 | parser = argparse.ArgumentParser(description='predata scripts.')
16 | parser.add_argument('-config', default='config_WSJ0_Tasnet.yaml', type=str, help="config file")
17 | opt = parser.parse_args()
18 | config = utils.read_config(opt.config)
19 |
20 | channel_first = config.channel_first
21 | np.random.seed(1)
22 | random.seed(1)
23 |
24 | data_path = '/mnt/lustre/xushuang2/lcx/data/amcc-data/2channel'
25 |
26 | def pad_list(xs, pad_value):
27 | n_batch = len(xs)
28 | max_len = max(x.size(0) for x in xs)
29 | pad = xs[0].new(n_batch, max_len, * xs[0].size()[1:]).fill_(pad_value)
30 | for i in range(n_batch):
31 | pad[i, :xs[i].size(0)] = xs[i]
32 | return pad
33 |
34 | def get_energy_order(multi_spk_fea_list):
35 | order=[]
36 | for one_line in multi_spk_fea_list:
37 | dd=sorted(one_line.items(),key= lambda d:d[1].sum(),reverse=True)
38 | dd=[d[0] for d in dd]
39 | order.append(dd)
40 | return order
41 |
42 | def get_spk_order(dir_tgt,raw_tgt):
43 | raw_tgt_dir=[]
44 | i=0
45 | for sample in raw_tgt:
46 | dd = [dir_tgt[i][spk] for spk in sample]
47 | raw_tgt_dir.append(dd)
48 | i=i+1
49 | return raw_tgt_dir
50 |
51 |
52 | def _collate_fn(mix_data,source_data,raw_tgt=None):
53 | """
54 | Args:
55 |         mix_data: list of B mixture waveforms; source_data: list of B dicts mapping speaker to source waveform; raw_tgt: optional speaker order
56 |     Returns:
57 |         mixtures: B x T (x channels), np.ndarray
58 |         ilens: B, np.ndarray of mixture lengths
59 |         sources: B x C x T, np.ndarray
60 | """
61 | mixtures, sources = mix_data,source_data
62 |     if raw_tgt is None:  # if no speaker order is given, default to sorted speaker keys
63 | raw_tgt = [sorted(spk.keys()) for spk in source_data]
64 | # sources= models.rank_feas(raw_tgt, source_data,out_type='numpy')
65 | sources=[]
66 | for each_feas, each_line in zip(source_data, raw_tgt):
67 | sources.append(np.stack([each_feas[spk] for spk in each_line]))
68 | sources=np.array(sources)
69 | mixtures=np.array(mixtures)
70 | # get batch of lengths of input sequences
71 | ilens = np.array([mix.shape[0] for mix in mixtures])
72 |
73 | # perform padding and convert to tensor
74 | pad_value = 0
75 | # mixtures_pad = pad_list([mix.float() for mix in mixtures], pad_value)
76 | ilens = ilens
77 | # sources_pad = pad_list([torch.from_numpy(s).float() for s in sources], pad_value)
78 | # N x T x C -> N x C x T
79 | # sources_pad = sources_pad.permute((0, 2, 1)).contiguous()
80 | return mixtures, ilens, sources
81 | # return mixtures_pad, ilens, sources_pad
82 |
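# A hedged example of the collate contract above (dummy data; shapes assume the
# dual-channel mixtures and single-channel per-speaker sources built below):
#
#   T = 32000
#   mix_data    = [np.zeros((T, 2)) for _ in range(2)]                 # B mixtures
#   source_data = [{'spkA': np.zeros(T), 'spkB': np.zeros(T)} for _ in range(2)]
#   mixtures, ilens, sources = _collate_fn(mix_data, source_data)
#   # mixtures: (B, T, 2), ilens: (B,) filled with T, sources: (B, C, T)
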
83 | def prepare_data(mode, train_or_test, min=None, max=None):
84 |
85 | if min:
86 | config.MIN_MIX = min
87 | if max:
88 | config.MAX_MIX = max
89 |
90 | mix_speechs = []
91 | aim_fea = []
92 | aim_spkid = []
93 | aim_spkname = []
94 | query = []
95 | multi_spk_fea_list = []
96 | multi_spk_wav_list = []
97 | direction = []
98 |
99 | if config.MODE == 1:
100 | if config.DATASET == 'WSJ0':
101 | spk_file_tr = open('/mnt/lustre/xushuang2/lcx/data/amcc-data/2channel/wav_spk.txt','r')
102 | all_spk_train = [i.replace("\n","") for i in spk_file_tr]
103 | all_spk_train = sorted(all_spk_train)
104 | print(all_spk_train)
105 |
106 | spk_file_tt = open('/mnt/lustre/xushuang2/lcx/data/amcc-data/2channel/test/wav_spk.txt','r')
107 | all_spk_test = [i.replace("\n","") for i in spk_file_tt]
108 | all_spk_test = sorted(all_spk_test)
109 | print(all_spk_test)
110 | all_spk = all_spk_train + all_spk_test
111 | print(all_spk)
112 |
113 | all_dir = [i for i in range(1,20)]
114 |             dicDirFile = open('/mnt/lustre/xushuang2/lcx/data/amcc-data/2channel/wav_dirLabel2.txt', 'r')  # open the direction-label file
115 | dirDict = {}
116 | while True:
117 | line = dicDirFile.readline()
118 | if line == '':
119 | break
120 | index = line.find(' ')
121 | key = line[:index]
122 | #print(key)
123 | value = line[index:]
124 | dirDict[key] = value.replace("\n","").replace(" ","")
125 | dicDirFile.close()
126 |
127 | spk_samples_list = {}
128 | batch_idx = 0
129 | list_path = '/mnt/lustre/xushuang2/lcx/data/create-speaker-mixtures/'
130 | all_samples_list = {}
131 | sample_idx = {}
132 | number_samples = {}
133 | batch_mix = {}
134 |             mix_number_list = list(range(config.MIN_MIX, config.MAX_MIX + 1))  # a list, so exhausted mix counts can be removed below
135 | number_samples_all = 0
136 | for mix_k in mix_number_list:
137 | if train_or_test == 'train':
138 | aim_list_path = list_path + 'mix_{}_spk_tr.txt'.format(mix_k)
139 | if train_or_test == 'valid':
140 | aim_list_path = list_path + 'mix_{}_spk_cv.txt'.format(mix_k)
141 | if train_or_test == 'test':
142 | aim_list_path = list_path + 'mix_{}_spk_tt.txt'.format(mix_k)
143 | config.batch_size = 1
144 |
145 | all_samples_list[mix_k] = open(aim_list_path).readlines() # [:31]
146 | number_samples[mix_k] = len(all_samples_list[mix_k])
147 |                 batch_mix[mix_k] = len(all_samples_list[mix_k]) // config.batch_size
148 | number_samples_all += len(all_samples_list[mix_k])
149 |
150 | sample_idx[mix_k] = 0
151 |
152 | if train_or_test == 'train' and config.SHUFFLE_BATCH:
153 | random.shuffle(all_samples_list[mix_k])
154 | print('shuffle success!', all_samples_list[mix_k][0])
155 |
156 |             batch_total = number_samples_all // config.batch_size
157 |
158 | mix_k = random.sample(mix_number_list, 1)[0]
159 | # while True:
160 | for ___ in range(number_samples_all):
161 | if ___ == number_samples_all - 1:
162 | print('ends here.___')
163 | yield False
164 | mix_len = 0
165 | if sample_idx[mix_k] >= batch_mix[mix_k] * config.batch_size:
166 | mix_number_list.remove(mix_k)
167 | try:
168 | mix_k = random.sample(mix_number_list, 1)[0]
169 | except ValueError:
170 | print('seems there gets all over.')
171 | if len(mix_number_list) == 0:
172 | print('all mix number is over~!')
173 | yield False
174 |
175 | batch_idx = 0
176 |                     mix_speechs = []  # reset as a list; mixture waveforms are appended below
177 | mix_feas = []
178 | mix_phase = []
179 | aim_fea = []
180 | aim_spkid = []
181 | aim_spkname = []
182 | query = []
183 | multi_spk_fea_list = []
184 | multi_spk_order_list=[]
185 | multi_spk_wav_list = []
186 | continue
187 |
188 | all_over = 1
189 | for kkkkk in mix_number_list:
190 |                 if not sample_idx[kkkkk] >= batch_mix[kkkkk] * config.batch_size:
191 | all_over = 0
192 | break
193 | if all_over:
194 | print('all mix number is over~!')
195 | yield False
196 |
197 | # mix_k=random.sample(mix_number_list,1)[0]
198 | if train_or_test == 'train':
199 | aim_spk_k = random.sample(all_spk_train, mix_k)
200 | elif train_or_test == 'test':
201 | aim_spk_k = random.sample(all_spk_test, mix_k)
202 |
203 | aim_spk_k = re.findall('/([0-9][0-9].)/', all_samples_list[mix_k][sample_idx[mix_k]])
204 | aim_spk_db_k = [float(dd) for dd in re.findall(' (.*?) ', all_samples_list[mix_k][sample_idx[mix_k]])]
205 | aim_spk_samplename_k = re.findall('/(.{8})\.wav ', all_samples_list[mix_k][sample_idx[mix_k]])
206 | assert len(aim_spk_k) == mix_k == len(aim_spk_db_k) == len(aim_spk_samplename_k)
207 |
208 | multi_fea_dict_this_sample = {}
209 | multi_wav_dict_this_sample = {}
210 | multi_name_list_this_sample = []
211 | multi_db_dict_this_sample = {}
212 | direction_sample = {}
213 | for k, spk in enumerate(aim_spk_k):
214 |
215 | sample_name = aim_spk_samplename_k[k]
216 | if aim_spk_db_k[k] ==0:
217 | aim_spk_db_k[k] = int(aim_spk_db_k[k])
218 | if train_or_test != 'test':
219 | spk_speech_path = data_path + '/' + 'train' + '/' + sample_name + '_' +str(aim_spk_db_k[k])+ '_simu_nore.wav'
220 | else:
221 | spk_speech_path = data_path + '/' + 'test' + '/' + sample_name + '_' +str(aim_spk_db_k[k])+ '_simu_nore.wav'
222 |
223 | signal, rate = sf.read(spk_speech_path)
224 |
225 | wav_name = sample_name+ '_' +str(aim_spk_db_k[k])+ '_simu_nore.wav'
226 | direction_sample[spk] = dirDict[wav_name]
227 | if rate != config.FRAME_RATE:
228 | print("config.FRAME_RATE",config.FRAME_RATE)
229 | signal = signal.transpose()
230 | signal = resampy.resample(signal, rate, config.FRAME_RATE, filter='kaiser_best')
231 | signal = signal.transpose()
232 |
233 | if signal.shape[0] > config.MAX_LEN:
234 | signal = signal[:config.MAX_LEN,:]
235 |
236 | if signal.shape[0] > mix_len:
237 | mix_len = signal.shape[0]
238 |
239 | signal -= np.mean(signal)
240 | signal /= np.max(np.abs(signal))
241 |
242 | if signal.shape[0] < config.MAX_LEN:
243 | signal = np.r_[signal, np.zeros((config.MAX_LEN - signal.shape[0],signal.shape[1]))]
244 |
245 | if k == 0:
246 | ratio = 10 ** (aim_spk_db_k[k] / 20.0)
247 | signal = ratio * signal
248 | aim_spkname.append(aim_spk_k[0])
249 | aim_spk_speech = signal
250 | aim_spkid.append(aim_spkname)
251 | wav_mix = signal
252 | signal_c0 = signal[:,0]
253 | a,b,frq = scipy.signal.stft(signal_c0, fs=8000, nfft=config.FRAME_LENGTH, noverlap=config.FRAME_SHIFT)
254 | aim_fea_clean = np.transpose(np.abs(frq))
255 | aim_fea.append(aim_fea_clean)
256 | multi_fea_dict_this_sample[spk] = aim_fea_clean
257 | multi_wav_dict_this_sample[spk] = signal[:,0]
258 |
259 | else:
260 | ratio = 10 ** (aim_spk_db_k[k] / 20.0)
261 | signal = ratio * signal
262 | wav_mix = wav_mix + signal
263 | a,b,frq = scipy.signal.stft(signal[:,0], fs=8000, nfft=config.FRAME_LENGTH, noverlap=config.FRAME_SHIFT)
264 | some_fea_clean = np.transpose(np.abs(frq))
265 | multi_fea_dict_this_sample[spk] = some_fea_clean
266 | multi_wav_dict_this_sample[spk] = signal[:,0]
267 |
268 | multi_spk_fea_list.append(multi_fea_dict_this_sample)
269 | multi_spk_wav_list.append(multi_wav_dict_this_sample)
270 |
271 | mix_speechs.append(wav_mix)
272 | direction.append(direction_sample)
273 | batch_idx += 1
274 |
275 | if batch_idx == config.batch_size:
276 | mix_k = random.sample(mix_number_list, 1)[0]
277 | aim_fea = np.array(aim_fea)
278 | query = np.array(query)
279 | print('spk_list_from_this_gen:{}'.format(aim_spkname))
280 | print('aim spk list:', [one.keys() for one in multi_spk_fea_list])
281 | batch_ordre=get_energy_order(multi_spk_wav_list)
282 | direction = get_spk_order(direction, batch_ordre)
283 | if mode == 'global':
284 | all_spk = sorted(all_spk)
285 | all_spk = sorted(all_spk_train)
286 |                         all_spk.insert(0, '')  # add two placeholder tokens to mark sequence start and end
287 | all_spk.append('')
288 | all_dir = sorted(all_dir)
289 | all_dir.insert(0, '')
290 | all_dir.append('')
291 | all_spk_test = sorted(all_spk_test)
292 | dict_spk_to_idx = {spk: idx for idx, spk in enumerate(all_spk)}
293 | dict_idx_to_spk = {idx: spk for idx, spk in enumerate(all_spk)}
294 | dict_dir_to_idx = {dire: idx for idx, dire in enumerate(all_dir)}
295 | dict_idx_to_dir = {idx: dire for idx, dire in enumerate(all_dir)}
296 | yield {'all_spk': all_spk,
297 | 'dict_spk_to_idx': dict_spk_to_idx,
298 | 'dict_idx_to_spk': dict_idx_to_spk,
299 | 'all_dir': all_dir,
300 | 'dict_dir_to_idx': dict_dir_to_idx,
301 | 'dict_idx_to_dir': dict_idx_to_dir,
302 | 'num_fre': aim_fea.shape[2],
303 | 'num_frames': aim_fea.shape[1],
304 | 'total_spk_num': len(all_spk),
305 | 'total_batch_num': batch_total
306 | }
307 | elif mode == 'once':
308 | yield {'mix_wav': mix_speechs,
309 | 'aim_fea': aim_fea,
310 | 'aim_spkname': aim_spkname,
311 | 'direction': direction,
312 | 'query': query,
313 | 'num_all_spk': len(all_spk),
314 | 'multi_spk_fea_list': multi_spk_fea_list,
315 | 'multi_spk_wav_list': multi_spk_wav_list,
316 | 'batch_order': batch_ordre,
317 | 'batch_total': batch_total,
318 | 'tas_zip': _collate_fn(mix_speechs,multi_spk_wav_list,batch_ordre)
319 | }
320 | elif mode == 'tasnet':
321 | yield _collate_fn(mix_speechs,multi_spk_wav_list)
322 |
323 | batch_idx = 0
324 | mix_speechs = []
325 | aim_fea = []
326 | aim_spkid = []
327 | aim_spkname = []
328 | query = []
329 | multi_spk_fea_list = []
330 | multi_spk_wav_list = []
331 | direction = []
332 | sample_idx[mix_k] += 1
333 |
334 | else:
335 | raise ValueError('No such dataset:{} for Speech.'.format(config.DATASET))
336 | pass
337 |
338 | else:
339 | raise ValueError('No such Model:{}'.format(config.MODE))
340 |
341 | if __name__ == '__main__':
342 | train_len=[]
343 | train_data_gen = prepare_data('once', 'train')
344 | while True:
345 |         next(train_data_gen)
346 | pass
347 | print(np.array(train_len).mean())
348 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | #TRAIN
4 | python train_WSJ0_SDNet.py
5 |
6 | #TEST
7 | python -u test_WSJ0_SDNet.py
8 |
9 |
--------------------------------------------------------------------------------
/separation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | Source separation algorithms attempt to extract recordings of individual
4 | sources from a recording of a mixture of sources. Evaluation methods for
5 | source separation compare the extracted sources against the reference sources and
6 | attempt to measure the perceptual quality of the separation.
7 |
8 | See also the bss_eval MATLAB toolbox:
9 | http://bass-db.gforge.inria.fr/bss_eval/
10 |
11 | Conventions
12 | -----------
13 |
14 | An audio signal is expected to be in the format of a 1-dimensional array where
15 | the entries are the samples of the audio signal. When providing a group of
16 | estimated or reference sources, they should be provided in a 2-dimensional
17 | array, where the first dimension corresponds to the source number and the
18 | second corresponds to the samples.
19 |
20 | Metrics
21 | -------
22 |
23 | * :func:`mir_eval.separation.bss_eval_sources`: Computes the bss_eval_sources
24 | metrics from bss_eval, which optionally optimally match the estimated sources
25 | to the reference sources and measure the distortion and artifacts present in
26 | the estimated sources as well as the interference between them.
27 |
28 | * :func:`mir_eval.separation.bss_eval_sources_framewise`: Computes the
29 | bss_eval_sources metrics on a frame-by-frame basis.
30 |
31 | * :func:`mir_eval.separation.bss_eval_images`: Computes the bss_eval_images
32 | metrics from bss_eval, which includes the metrics in
33 | :func:`mir_eval.separation.bss_eval_sources` plus the image to spatial
34 | distortion ratio.
35 |
36 | * :func:`mir_eval.separation.bss_eval_images_framewise`: Computes the
37 | bss_eval_images metrics on a frame-by-frame basis.
38 |
39 | References
40 | ----------
41 | .. [#vincent2006performance] Emmanuel Vincent, Rémi Gribonval, and Cédric
42 | Févotte, "Performance measurement in blind audio source separation," IEEE
43 | Trans. on Audio, Speech and Language Processing, 14(4):1462-1469, 2006.
44 |
45 |
46 | '''
47 |
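# A minimal usage sketch of the conventions above (illustrative arrays only):
#
#   ref = np.random.randn(2, 8000)              # (nsrc, nsampl) reference sources
#   est = ref + 0.1 * np.random.randn(2, 8000)  # (nsrc, nsampl) estimated sources
#   sdr, sir, sar, perm = bss_eval_sources(ref, est)
#   # sdr / sir / sar: per-source ratios in dB; perm: best estimate-to-reference ordering
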
48 | import numpy as np
49 | import scipy.fftpack
50 | from scipy.linalg import toeplitz
51 | from scipy.signal import fftconvolve
52 | import collections
53 | import itertools
54 | import warnings
55 |
56 | # The maximum allowable number of sources (prevents insane computational load)
57 | MAX_SOURCES = 100
58 |
59 |
60 | def validate(reference_sources, estimated_sources):
61 | """Checks that the input data to a metric are valid, and throws helpful
62 | errors if not.
63 |
64 | Parameters
65 | ----------
66 | reference_sources : np.ndarray, shape=(nsrc, nsampl)
67 | matrix containing true sources
68 | estimated_sources : np.ndarray, shape=(nsrc, nsampl)
69 | matrix containing estimated sources
70 |
71 | """
72 |
73 | if reference_sources.shape != estimated_sources.shape:
74 | raise ValueError('The shape of estimated sources and the true '
75 | 'sources should match. reference_sources.shape '
76 | '= {}, estimated_sources.shape '
77 | '= {}'.format(reference_sources.shape,
78 | estimated_sources.shape))
79 |
80 | if reference_sources.ndim > 3 or estimated_sources.ndim > 3:
81 |         raise ValueError('The number of dimensions is too high (must be at '
82 |                          'most 3). reference_sources.ndim = {}, '
83 | 'estimated_sources.ndim '
84 | '= {}'.format(reference_sources.ndim,
85 | estimated_sources.ndim))
86 |
87 | if reference_sources.size == 0:
88 | warnings.warn("reference_sources is empty, should be of size "
89 | "(nsrc, nsample). sdr, sir, sar, and perm will all "
90 | "be empty np.ndarrays")
91 | elif _any_source_silent(reference_sources):
92 | raise ValueError('All the reference sources should be non-silent (not '
93 | 'all-zeros), but at least one of the reference '
94 | 'sources is all 0s, which introduces ambiguity to the'
95 | ' evaluation. (Otherwise we can add infinitely many '
96 | 'all-zero sources.)')
97 |
98 | if estimated_sources.size == 0:
99 | warnings.warn("estimated_sources is empty, should be of size "
100 | "(nsrc, nsample). sdr, sir, sar, and perm will all "
101 | "be empty np.ndarrays")
102 | elif _any_source_silent(estimated_sources):
103 | raise ValueError('All the estimated sources should be non-silent (not '
104 | 'all-zeros), but at least one of the estimated '
105 | 'sources is all 0s. Since we require each reference '
106 | 'source to be non-silent, having a silent estimated '
107 | 'source will result in an underdetermined system.')
108 |
109 | if (estimated_sources.shape[0] > MAX_SOURCES or
110 | reference_sources.shape[0] > MAX_SOURCES):
111 | raise ValueError('The supplied matrices should be of shape (nsrc,'
112 | ' nsampl) but reference_sources.shape[0] = {} and '
113 | 'estimated_sources.shape[0] = {} which is greater '
114 | 'than mir_eval.separation.MAX_SOURCES = {}. To '
115 | 'override this check, set '
116 | 'mir_eval.separation.MAX_SOURCES to a '
117 | 'larger value.'.format(reference_sources.shape[0],
118 | estimated_sources.shape[0],
119 | MAX_SOURCES))
120 |
121 |
122 | def _any_source_silent(sources):
123 |     """Returns True if any source (any slice along the first dimension) of sources is silent."""
124 | return np.any(np.all(np.sum(
125 | sources, axis=tuple(range(2, sources.ndim))) == 0, axis=1))
126 |
127 |
128 | def bss_eval_sources(reference_sources, estimated_sources,
129 | compute_permutation=True):
130 | """
131 | Ordering and measurement of the separation quality for estimated source
132 | signals in terms of filtered true source, interference and artifacts.
133 |
134 | The decomposition allows a time-invariant filter distortion of length
135 | 512, as described in Section III.B of [#vincent2006performance]_.
136 |
137 | Passing ``False`` for ``compute_permutation`` will improve the computation
138 | performance of the evaluation; however, it is not always appropriate and
139 | is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_sources.
140 |
141 | Examples
142 | --------
143 | >>> # reference_sources[n] should be an ndarray of samples of the
144 | >>> # n'th reference source
145 | >>> # estimated_sources[n] should be the same for the n'th estimated
146 | >>> # source
147 | >>> (sdr, sir, sar,
148 | ... perm) = mir_eval.separation.bss_eval_sources(reference_sources,
149 | ... estimated_sources)
150 |
151 | Parameters
152 | ----------
153 | reference_sources : np.ndarray, shape=(nsrc, nsampl)
154 | matrix containing true sources (must have same shape as
155 | estimated_sources)
156 | estimated_sources : np.ndarray, shape=(nsrc, nsampl)
157 | matrix containing estimated sources (must have same shape as
158 | reference_sources)
159 | compute_permutation : bool, optional
160 | compute permutation of estimate/source combinations (True by default)
161 |
162 | Returns
163 | -------
164 | sdr : np.ndarray, shape=(nsrc,)
165 | vector of Signal to Distortion Ratios (SDR)
166 | sir : np.ndarray, shape=(nsrc,)
167 | vector of Source to Interference Ratios (SIR)
168 | sar : np.ndarray, shape=(nsrc,)
169 | vector of Sources to Artifacts Ratios (SAR)
170 | perm : np.ndarray, shape=(nsrc,)
171 | vector containing the best ordering of estimated sources in
172 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to
173 | true source number ``j``). Note: ``perm`` will be ``[0, 1, ...,
174 | nsrc-1]`` if ``compute_permutation`` is ``False``.
175 |
176 | References
177 | ----------
178 | .. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau
179 | Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik
180 | Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign
181 | (2007-2010): Achievements and remaining challenges", Signal Processing,
182 | 92, pp. 1928-1936, 2012.
183 |
184 | """
185 |
186 | # make sure the input is of shape (nsrc, nsampl)
187 | if estimated_sources.ndim == 1:
188 | estimated_sources = estimated_sources[np.newaxis, :]
189 | if reference_sources.ndim == 1:
190 | reference_sources = reference_sources[np.newaxis, :]
191 |
192 | validate(reference_sources, estimated_sources)
193 | # If empty matrices were supplied, return empty lists (special case)
194 | if reference_sources.size == 0 or estimated_sources.size == 0:
195 | return np.array([]), np.array([]), np.array([]), np.array([])
196 |
197 | nsrc = estimated_sources.shape[0]
198 |
199 | # does user desire permutations?
200 | if compute_permutation:
201 | # compute criteria for all possible pair matches
202 | sdr = np.empty((nsrc, nsrc))
203 | sir = np.empty((nsrc, nsrc))
204 | sar = np.empty((nsrc, nsrc))
205 | for jest in range(nsrc):
206 | for jtrue in range(nsrc):
207 | s_true, e_spat, e_interf, e_artif = \
208 | _bss_decomp_mtifilt(reference_sources,
209 | estimated_sources[jest],
210 | jtrue, 512)
211 | sdr[jest, jtrue], sir[jest, jtrue], sar[jest, jtrue] = \
212 | _bss_source_crit(s_true, e_spat, e_interf, e_artif)
213 |
214 | # select the best ordering
215 | perms = list(itertools.permutations(list(range(nsrc))))
216 | mean_sir = np.empty(len(perms))
217 | dum = np.arange(nsrc)
218 | for (i, perm) in enumerate(perms):
219 | mean_sir[i] = np.mean(sir[perm, dum])
220 | popt = perms[np.argmax(mean_sir)]
221 | idx = (popt, dum)
222 | return (sdr[idx], sir[idx], sar[idx], np.asarray(popt))
223 | else:
224 | # compute criteria for only the simple correspondence
225 | # (estimate 1 is estimate corresponding to reference source 1, etc.)
226 | sdr = np.empty(nsrc)
227 | sir = np.empty(nsrc)
228 | sar = np.empty(nsrc)
229 | for j in range(nsrc):
230 | s_true, e_spat, e_interf, e_artif = \
231 | _bss_decomp_mtifilt(reference_sources,
232 | estimated_sources[j],
233 | j, 512)
234 | sdr[j], sir[j], sar[j] = \
235 | _bss_source_crit(s_true, e_spat, e_interf, e_artif)
236 |
237 | # return the default permutation for compatibility
238 | popt = np.arange(nsrc)
239 | return (sdr, sir, sar, popt)
240 |
241 |
242 | def bss_eval_sources_framewise(reference_sources, estimated_sources,
243 | window=30 * 44100, hop=15 * 44100,
244 | compute_permutation=False):
245 | """Framewise computation of bss_eval_sources
246 |
247 | Please be aware that this function does not compute permutations (by
248 | default) on the possible relations between reference_sources and
249 | estimated_sources due to the dangers of a changing permutation. Therefore
250 | (by default), it assumes that ``reference_sources[i]`` corresponds to
251 | ``estimated_sources[i]``. To enable computing permutations please set
252 | ``compute_permutation`` to be ``True`` and check that the returned ``perm``
253 | is identical for all windows.
254 |
255 | NOTE: if ``reference_sources`` and ``estimated_sources`` would be evaluated
256 | using only a single window or are shorter than the window length, the
257 | result of :func:`mir_eval.separation.bss_eval_sources` called on
258 | ``reference_sources`` and ``estimated_sources`` (with the
259 | ``compute_permutation`` parameter passed to
260 | :func:`mir_eval.separation.bss_eval_sources`) is returned.
261 |
262 | Examples
263 | --------
264 | >>> # reference_sources[n] should be an ndarray of samples of the
265 | >>> # n'th reference source
266 | >>> # estimated_sources[n] should be the same for the n'th estimated
267 | >>> # source
268 | >>> (sdr, sir, sar,
269 | ... perm) = mir_eval.separation.bss_eval_sources_framewise(
270 |     ...     reference_sources,
271 | ... estimated_sources)
272 |
273 | Parameters
274 | ----------
275 | reference_sources : np.ndarray, shape=(nsrc, nsampl)
276 | matrix containing true sources (must have the same shape as
277 | ``estimated_sources``)
278 | estimated_sources : np.ndarray, shape=(nsrc, nsampl)
279 | matrix containing estimated sources (must have the same shape as
280 | ``reference_sources``)
281 | window : int, optional
282 | Window length for framewise evaluation (default value is 30s at a
283 | sample rate of 44.1kHz)
284 | hop : int, optional
285 | Hop size for framewise evaluation (default value is 15s at a
286 | sample rate of 44.1kHz)
287 | compute_permutation : bool, optional
288 | compute permutation of estimate/source combinations for all windows
289 | (False by default)
290 |
291 | Returns
292 | -------
293 | sdr : np.ndarray, shape=(nsrc, nframes)
294 | vector of Signal to Distortion Ratios (SDR)
295 | sir : np.ndarray, shape=(nsrc, nframes)
296 | vector of Source to Interference Ratios (SIR)
297 | sar : np.ndarray, shape=(nsrc, nframes)
298 | vector of Sources to Artifacts Ratios (SAR)
299 | perm : np.ndarray, shape=(nsrc, nframes)
300 | vector containing the best ordering of estimated sources in
301 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to
302 | true source number ``j``). Note: ``perm`` will be ``range(nsrc)`` for
303 | all windows if ``compute_permutation`` is ``False``
304 |
305 | """
306 |
307 | # make sure the input is of shape (nsrc, nsampl)
308 | if estimated_sources.ndim == 1:
309 | estimated_sources = estimated_sources[np.newaxis, :]
310 | if reference_sources.ndim == 1:
311 | reference_sources = reference_sources[np.newaxis, :]
312 |
313 | validate(reference_sources, estimated_sources)
314 | # If empty matrices were supplied, return empty lists (special case)
315 | if reference_sources.size == 0 or estimated_sources.size == 0:
316 | return np.array([]), np.array([]), np.array([]), np.array([])
317 |
318 | nsrc = reference_sources.shape[0]
319 |
320 | nwin = int(
321 | np.floor((reference_sources.shape[1] - window + hop) / hop)
322 | )
323 | # if fewer than 2 windows would be evaluated, return the sources result
324 | if nwin < 2:
325 | result = bss_eval_sources(reference_sources,
326 | estimated_sources,
327 | compute_permutation)
328 | return [np.expand_dims(score, -1) for score in result]
329 |
330 | # compute the criteria across all windows
331 | sdr = np.empty((nsrc, nwin))
332 | sir = np.empty((nsrc, nwin))
333 | sar = np.empty((nsrc, nwin))
334 | perm = np.empty((nsrc, nwin))
335 |
336 | # k iterates across all the windows
337 | for k in range(nwin):
338 | win_slice = slice(k * hop, k * hop + window)
339 | ref_slice = reference_sources[:, win_slice]
340 | est_slice = estimated_sources[:, win_slice]
341 | # check for a silent frame
342 | if (not _any_source_silent(ref_slice) and
343 | not _any_source_silent(est_slice)):
344 | sdr[:, k], sir[:, k], sar[:, k], perm[:, k] = bss_eval_sources(
345 | ref_slice, est_slice, compute_permutation
346 | )
347 | else:
348 | # if we have a silent frame set results as np.nan
349 | sdr[:, k] = sir[:, k] = sar[:, k] = perm[:, k] = np.nan
350 |
351 | return sdr, sir, sar, perm
352 |
353 |
354 | def bss_eval_images(reference_sources, estimated_sources,
355 | compute_permutation=True):
356 | """Implementation of the bss_eval_images function from the
357 | BSS_EVAL Matlab toolbox.
358 |
359 | Ordering and measurement of the separation quality for estimated source
360 | signals in terms of filtered true source, interference and artifacts.
361 | This method also provides the ISR measure.
362 |
363 | The decomposition allows a time-invariant filter distortion of length
364 | 512, as described in Section III.B of [#vincent2006performance]_.
365 |
366 | Passing ``False`` for ``compute_permutation`` will improve the computation
367 | performance of the evaluation; however, it is not always appropriate and
368 | is not the way that the BSS_EVAL Matlab toolbox computes bss_eval_images.
369 |
370 | Examples
371 | --------
372 | >>> # reference_sources[n] should be an ndarray of samples of the
373 | >>> # n'th reference source
374 | >>> # estimated_sources[n] should be the same for the n'th estimated
375 | >>> # source
376 | >>> (sdr, isr, sir, sar,
377 | ... perm) = mir_eval.separation.bss_eval_images(reference_sources,
378 | ... estimated_sources)
379 |
380 | Parameters
381 | ----------
382 | reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
383 | matrix containing true sources
384 | estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
385 | matrix containing estimated sources
386 | compute_permutation : bool, optional
387 | compute permutation of estimate/source combinations (True by default)
388 |
389 | Returns
390 | -------
391 | sdr : np.ndarray, shape=(nsrc,)
392 | vector of Signal to Distortion Ratios (SDR)
393 | isr : np.ndarray, shape=(nsrc,)
394 | vector of source Image to Spatial distortion Ratios (ISR)
395 | sir : np.ndarray, shape=(nsrc,)
396 | vector of Source to Interference Ratios (SIR)
397 | sar : np.ndarray, shape=(nsrc,)
398 | vector of Sources to Artifacts Ratios (SAR)
399 | perm : np.ndarray, shape=(nsrc,)
400 | vector containing the best ordering of estimated sources in
401 | the mean SIR sense (estimated source number ``perm[j]`` corresponds to
402 | true source number ``j``). Note: ``perm`` will be ``(1,2,...,nsrc)``
403 | if ``compute_permutation`` is ``False``.
404 |
405 | References
406 | ----------
407 | .. [#] Emmanuel Vincent, Shoko Araki, Fabian J. Theis, Guido Nolte, Pau
408 | Bofill, Hiroshi Sawada, Alexey Ozerov, B. Vikrham Gowreesunker, Dominik
409 | Lutter and Ngoc Q.K. Duong, "The Signal Separation Evaluation Campaign
410 | (2007-2010): Achievements and remaining challenges", Signal Processing,
411 | 92, pp. 1928-1936, 2012.
412 |
413 | """
414 |
415 | # make sure the input has 3 dimensions
416 | # assuming input is in shape (nsampl) or (nsrc, nsampl)
417 | estimated_sources = np.atleast_3d(estimated_sources)
418 | reference_sources = np.atleast_3d(reference_sources)
419 | # we will ensure input doesn't have more than 3 dimensions in validate
420 |
421 | validate(reference_sources, estimated_sources)
422 | # If empty matrices were supplied, return empty lists (special case)
423 | if reference_sources.size == 0 or estimated_sources.size == 0:
424 | return np.array([]), np.array([]), np.array([]), \
425 | np.array([]), np.array([])
426 |
427 | # determine size parameters
428 | nsrc = estimated_sources.shape[0]
429 | nsampl = estimated_sources.shape[1]
430 | nchan = estimated_sources.shape[2]
431 |
432 | # does the user desire permutation?
433 | if compute_permutation:
434 | # compute criteria for all possible pair matches
435 | sdr = np.empty((nsrc, nsrc))
436 | isr = np.empty((nsrc, nsrc))
437 | sir = np.empty((nsrc, nsrc))
438 | sar = np.empty((nsrc, nsrc))
439 | for jest in range(nsrc):
440 | for jtrue in range(nsrc):
441 | s_true, e_spat, e_interf, e_artif = \
442 | _bss_decomp_mtifilt_images(
443 | reference_sources,
444 | np.reshape(
445 | estimated_sources[jest],
446 | (nsampl, nchan),
447 | order='F'
448 | ),
449 | jtrue,
450 | 512
451 | )
452 | sdr[jest, jtrue], isr[jest, jtrue], \
453 | sir[jest, jtrue], sar[jest, jtrue] = \
454 | _bss_image_crit(s_true, e_spat, e_interf, e_artif)
455 |
456 | # select the best ordering
457 | perms = list(itertools.permutations(range(nsrc)))
458 | mean_sir = np.empty(len(perms))
459 | dum = np.arange(nsrc)
460 | for (i, perm) in enumerate(perms):
461 | mean_sir[i] = np.mean(sir[perm, dum])
462 | popt = perms[np.argmax(mean_sir)]
463 | idx = (popt, dum)
464 | return (sdr[idx], isr[idx], sir[idx], sar[idx], np.asarray(popt))
465 | else:
466 | # compute criteria for only the simple correspondence
467 | # (estimate 1 is estimate corresponding to reference source 1, etc.)
468 | sdr = np.empty(nsrc)
469 | isr = np.empty(nsrc)
470 | sir = np.empty(nsrc)
471 | sar = np.empty(nsrc)
472 |         Gj = [0] * nsrc  # prepare G matrices, initialised to zeros
473 | G = np.zeros(1)
474 | for j in range(nsrc):
475 | # save G matrix to avoid recomputing it every call
476 | s_true, e_spat, e_interf, e_artif, Gj_temp, G = \
477 | _bss_decomp_mtifilt_images(reference_sources,
478 | np.reshape(estimated_sources[j],
479 | (nsampl, nchan),
480 | order='F'),
481 | j, 512, Gj[j], G)
482 | Gj[j] = Gj_temp
483 | sdr[j], isr[j], sir[j], sar[j] = \
484 | _bss_image_crit(s_true, e_spat, e_interf, e_artif)
485 |
486 | # return the default permutation for compatibility
487 | popt = np.arange(nsrc)
488 | return (sdr, isr, sir, sar, popt)
489 |
490 |
491 | def bss_eval_images_framewise(reference_sources, estimated_sources,
492 | window=30 * 44100, hop=15 * 44100,
493 | compute_permutation=False):
494 | """Framewise computation of bss_eval_images
495 |
496 | Please be aware that this function does not compute permutations (by
497 | default) on the possible relations between ``reference_sources`` and
498 | ``estimated_sources`` due to the dangers of a changing permutation.
499 | Therefore (by default), it assumes that ``reference_sources[i]``
500 | corresponds to ``estimated_sources[i]``. To enable computing permutations
501 | please set ``compute_permutation`` to be ``True`` and check that the
502 | returned ``perm`` is identical for all windows.
503 |
504 | NOTE: if ``reference_sources`` and ``estimated_sources`` would be evaluated
505 | using only a single window or are shorter than the window length, the
506 | result of ``bss_eval_sources`` called on ``reference_sources`` and
507 | ``estimated_sources`` (with the ``compute_permutation`` parameter passed to
508 | ``bss_eval_images``) is returned
509 |
510 | Examples
511 | --------
512 | >>> # reference_sources[n] should be an ndarray of samples of the
513 | >>> # n'th reference source
514 | >>> # estimated_sources[n] should be the same for the n'th estimated
515 | >>> # source
516 | >>> (sdr, isr, sir, sar,
517 | ... perm) = mir_eval.separation.bss_eval_images_framewise(
518 |     ...     reference_sources,
519 |     ...     estimated_sources,
520 |     ...     window,
521 |     ...     hop)
522 |
523 | Parameters
524 | ----------
525 | reference_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
526 | matrix containing true sources (must have the same shape as
527 | ``estimated_sources``)
528 | estimated_sources : np.ndarray, shape=(nsrc, nsampl, nchan)
529 | matrix containing estimated sources (must have the same shape as
530 | ``reference_sources``)
531 | window : int
532 | Window length for framewise evaluation
533 | hop : int
534 | Hop size for framewise evaluation
535 | compute_permutation : bool, optional
536 | compute permutation of estimate/source combinations for all windows
537 | (False by default)
538 |
539 | Returns
540 | -------
541 | sdr : np.ndarray, shape=(nsrc, nframes)
542 | vector of Signal to Distortion Ratios (SDR)
543 | isr : np.ndarray, shape=(nsrc, nframes)
544 | vector of source Image to Spatial distortion Ratios (ISR)
545 | sir : np.ndarray, shape=(nsrc, nframes)
546 | vector of Source to Interference Ratios (SIR)
547 | sar : np.ndarray, shape=(nsrc, nframes)
548 | vector of Sources to Artifacts Ratios (SAR)
549 | perm : np.ndarray, shape=(nsrc, nframes)
550 | vector containing the best ordering of estimated sources in
551 | the mean SIR sense (estimated source number perm[j] corresponds to
552 | true source number j)
553 | Note: perm will be range(nsrc) for all windows if compute_permutation
554 | is False
555 |
556 | """
557 |
558 | # make sure the input has 3 dimensions
559 | # assuming input is in shape (nsampl) or (nsrc, nsampl)
560 | estimated_sources = np.atleast_3d(estimated_sources)
561 | reference_sources = np.atleast_3d(reference_sources)
562 | # we will ensure input doesn't have more than 3 dimensions in validate
563 |
564 | validate(reference_sources, estimated_sources)
565 | # If empty matrices were supplied, return empty arrays (special case)
566 | if reference_sources.size == 0 or estimated_sources.size == 0:
567 | return (np.array([]), np.array([]), np.array([]), np.array([]), np.array([]))
568 |
569 | nsrc = reference_sources.shape[0]
570 |
571 | nwin = int(
572 | np.floor((reference_sources.shape[1] - window + hop) / hop)
573 | )
574 | # if fewer than 2 windows would be evaluated, return the images result
575 | if nwin < 2:
576 | result = bss_eval_images(reference_sources,
577 | estimated_sources,
578 | compute_permutation)
579 | return [np.expand_dims(score, -1) for score in result]
580 |
581 | # compute the criteria across all windows
582 | sdr = np.empty((nsrc, nwin))
583 | isr = np.empty((nsrc, nwin))
584 | sir = np.empty((nsrc, nwin))
585 | sar = np.empty((nsrc, nwin))
586 | perm = np.empty((nsrc, nwin))
587 |
588 | # k iterates across all the windows
589 | for k in range(nwin):
590 | win_slice = slice(k * hop, k * hop + window)
591 | ref_slice = reference_sources[:, win_slice, :]
592 | est_slice = estimated_sources[:, win_slice, :]
593 | # check for a silent frame
594 | if (not _any_source_silent(ref_slice) and
595 | not _any_source_silent(est_slice)):
596 | sdr[:, k], isr[:, k], sir[:, k], sar[:, k], perm[:, k] = \
597 | bss_eval_images(
598 | ref_slice, est_slice, compute_permutation
599 | )
600 | else:
601 | # if we have a silent frame set results as np.nan
602 | sdr[:, k] = isr[:, k] = sir[:, k] = sar[:, k] = perm[:, k] = np.nan
603 |
604 | return sdr, isr, sir, sar, perm
605 |
606 |
607 | def _bss_decomp_mtifilt(reference_sources, estimated_source, j, flen):
608 | """Decomposition of an estimated source image into four components
609 | representing respectively the true source image, spatial (or filtering)
610 | distortion, interference and artifacts, derived from the true source
611 | images using multichannel time-invariant filters.
612 | """
613 | nsampl = estimated_source.size
614 | # decomposition
615 | # true source image
616 | s_true = np.hstack((reference_sources[j], np.zeros(flen - 1)))
617 | # spatial (or filtering) distortion
618 | e_spat = _project(reference_sources[j, np.newaxis, :], estimated_source,
619 | flen) - s_true
620 | # interference
621 | e_interf = _project(reference_sources,
622 | estimated_source, flen) - s_true - e_spat
623 | # artifacts
624 | e_artif = -s_true - e_spat - e_interf
625 | e_artif[:nsampl] += estimated_source
626 | return (s_true, e_spat, e_interf, e_artif)
627 |
628 |
629 | def _bss_decomp_mtifilt_images(reference_sources, estimated_source, j, flen,
630 | Gj=None, G=None):
631 | """Decomposition of an estimated source image into four components
632 | representing respectively the true source image, spatial (or filtering)
633 | distortion, interference and artifacts, derived from the true source
634 | images using multichannel time-invariant filters.
635 | Adapted version to work with multichannel sources.
636 | Improved performance can be gained by passing Gj and G parameters initially
637 | as all zeros. These parameters store the results from the computation of
638 | the G matrix in _project_images and then return them for subsequent calls
639 | to this function. This only works when not computing permutations.
640 | """
641 | nsampl = np.shape(estimated_source)[0]
642 | nchan = np.shape(estimated_source)[1]
643 | # are we saving the Gj and G parameters?
644 | saveg = Gj is not None and G is not None
645 | # decomposition
646 | # true source image
647 | s_true = np.hstack((np.reshape(reference_sources[j],
648 | (nsampl, nchan),
649 | order="F").transpose(),
650 | np.zeros((nchan, flen - 1))))
651 | # spatial (or filtering) distortion
652 | if saveg:
653 | e_spat, Gj = _project_images(reference_sources[j, np.newaxis, :],
654 | estimated_source, flen, Gj)
655 | else:
656 | e_spat = _project_images(reference_sources[j, np.newaxis, :],
657 | estimated_source, flen)
658 | e_spat = e_spat - s_true
659 | # interference
660 | if saveg:
661 | e_interf, G = _project_images(reference_sources,
662 | estimated_source, flen, G)
663 | else:
664 | e_interf = _project_images(reference_sources,
665 | estimated_source, flen)
666 | e_interf = e_interf - s_true - e_spat
667 | # artifacts
668 | e_artif = -s_true - e_spat - e_interf
669 | e_artif[:, :nsampl] += estimated_source.transpose()
670 | # return Gj and G only if they were passed in
671 | if saveg:
672 | return (s_true, e_spat, e_interf, e_artif, Gj, G)
673 | else:
674 | return (s_true, e_spat, e_interf, e_artif)
675 |
676 |
677 | def _project(reference_sources, estimated_source, flen):
678 | """Least-squares projection of estimated source on the subspace spanned by
679 | delayed versions of reference sources, with delays between 0 and flen-1
680 | """
681 | nsrc = reference_sources.shape[0]
682 | nsampl = reference_sources.shape[1]
683 |
684 | # computing coefficients of least squares problem via FFT ##
685 | # zero padding and FFT of input data
686 | reference_sources = np.hstack((reference_sources,
687 | np.zeros((nsrc, flen - 1))))
688 | estimated_source = np.hstack((estimated_source, np.zeros(flen - 1)))
689 | n_fft = int(2 ** np.ceil(np.log2(nsampl + flen - 1.)))
690 | sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1)
691 | sef = scipy.fftpack.fft(estimated_source, n=n_fft)
692 | # inner products between delayed versions of reference_sources
693 | G = np.zeros((nsrc * flen, nsrc * flen))
694 | for i in range(nsrc):
695 | for j in range(nsrc):
696 | ssf = sf[i] * np.conj(sf[j])
697 | ssf = np.real(scipy.fftpack.ifft(ssf))
698 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])),
699 | r=ssf[:flen])
700 | G[i * flen: (i + 1) * flen, j * flen: (j + 1) * flen] = ss
701 | G[j * flen: (j + 1) * flen, i * flen: (i + 1) * flen] = ss.T
702 | # inner products between estimated_source and delayed versions of
703 | # reference_sources
704 | D = np.zeros(nsrc * flen)
705 | for i in range(nsrc):
706 | ssef = sf[i] * np.conj(sef)
707 | ssef = np.real(scipy.fftpack.ifft(ssef))
708 | D[i * flen: (i + 1) * flen] = np.hstack((ssef[0], ssef[-1:-flen:-1]))
709 |
710 | # Computing projection
711 | # Distortion filters
712 | try:
713 | C = np.linalg.solve(G, D).reshape(flen, nsrc, order='F')
714 | except np.linalg.LinAlgError:
715 | C = np.linalg.lstsq(G, D)[0].reshape(flen, nsrc, order='F')
716 | # Filtering
717 | sproj = np.zeros(nsampl + flen - 1)
718 | for i in range(nsrc):
719 | sproj += fftconvolve(C[:, i], reference_sources[i])[:nsampl + flen - 1]
720 | return sproj
721 |
722 |
723 | def _project_images(reference_sources, estimated_source, flen, G=None):
724 | """Least-squares projection of estimated source on the subspace spanned by
725 | delayed versions of reference sources, with delays between 0 and flen-1.
726 | Passing G as all zeros will populate the G matrix and return it so it can
727 | be passed into the next call to avoid recomputing G (this only works
728 | if not computing permutations).
729 | """
730 | nsrc = reference_sources.shape[0]
731 | nsampl = reference_sources.shape[1]
732 | nchan = reference_sources.shape[2]
733 | reference_sources = np.reshape(np.transpose(reference_sources, (2, 0, 1)),
734 | (nchan * nsrc, nsampl), order='F')
735 |
736 | # computing coefficients of least squares problem via FFT ##
737 | # zero padding and FFT of input data
738 | reference_sources = np.hstack((reference_sources,
739 | np.zeros((nchan * nsrc, flen - 1))))
740 | estimated_source = \
741 | np.hstack((estimated_source.transpose(), np.zeros((nchan, flen - 1))))
742 | n_fft = int(2 ** np.ceil(np.log2(nsampl + flen - 1.)))
743 | sf = scipy.fftpack.fft(reference_sources, n=n_fft, axis=1)
744 | sef = scipy.fftpack.fft(estimated_source, n=n_fft)
745 |
746 | # inner products between delayed versions of reference_sources
747 | if G is None:
748 | saveg = False
749 | G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen))
750 | for i in range(nchan * nsrc):
751 | for j in range(i + 1):
752 | ssf = sf[i] * np.conj(sf[j])
753 | ssf = np.real(scipy.fftpack.ifft(ssf))
754 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])),
755 | r=ssf[:flen])
756 | G[i * flen: (i + 1) * flen, j * flen: (j + 1) * flen] = ss
757 | G[j * flen: (j + 1) * flen, i * flen: (i + 1) * flen] = ss.T
758 | else: # avoid recomputing G (only works if no permutation is desired)
759 | saveg = True # return G
760 | if np.all(G == 0): # only compute G if passed as 0
761 | G = np.zeros((nchan * nsrc * flen, nchan * nsrc * flen))
762 | for i in range(nchan * nsrc):
763 | for j in range(i + 1):
764 | ssf = sf[i] * np.conj(sf[j])
765 | ssf = np.real(scipy.fftpack.ifft(ssf))
766 | ss = toeplitz(np.hstack((ssf[0], ssf[-1:-flen:-1])),
767 | r=ssf[:flen])
768 | G[i * flen: (i + 1) * flen, j * flen: (j + 1) * flen] = ss
769 | G[j * flen: (j + 1) * flen, i * flen: (i + 1) * flen] = ss.T
770 |
771 | # inner products between estimated_source and delayed versions of
772 | # reference_sources
773 | D = np.zeros((nchan * nsrc * flen, nchan))
774 | for k in range(nchan * nsrc):
775 | for i in range(nchan):
776 | ssef = sf[k] * np.conj(sef[i])
777 | ssef = np.real(scipy.fftpack.ifft(ssef))
778 | D[k * flen: (k + 1) * flen, i] = \
779 | np.hstack((ssef[0], ssef[-1:-flen:-1])).transpose()
780 |
781 | # Computing projection
782 | # Distortion filters
783 | try:
784 | C = np.linalg.solve(G, D).reshape(flen, nchan * nsrc, nchan, order='F')
785 | except np.linalg.LinAlgError:
786 | C = np.linalg.lstsq(G, D)[0].reshape(flen, nchan * nsrc, nchan,
787 | order='F')
788 | # Filtering
789 | sproj = np.zeros((nchan, nsampl + flen - 1))
790 | for k in range(nchan * nsrc):
791 | for i in range(nchan):
792 | sproj[i] += fftconvolve(C[:, k, i].transpose(),
793 | reference_sources[k])[:nsampl + flen - 1]
794 | # return G only if it was passed in
795 | if saveg:
796 | return sproj, G
797 | else:
798 | return sproj
799 |
800 |
801 | def _bss_source_crit(s_true, e_spat, e_interf, e_artif):
802 | """Measurement of the separation quality for a given source in terms of
803 | filtered true source, interference and artifacts.
804 | """
805 | # energy ratios
806 | s_filt = s_true + e_spat
807 | sdr = _safe_db(np.sum(s_filt ** 2), np.sum((e_interf + e_artif) ** 2))
808 | sir = _safe_db(np.sum(s_filt ** 2), np.sum(e_interf ** 2))
809 | sar = _safe_db(np.sum((s_filt + e_interf) ** 2), np.sum(e_artif ** 2))
810 | return (sdr, sir, sar)
811 |
812 |
813 | def _bss_image_crit(s_true, e_spat, e_interf, e_artif):
814 | """Measurement of the separation quality for a given image in terms of
815 | filtered true source, spatial error, interference and artifacts.
816 | """
817 | # energy ratios
818 | sdr = _safe_db(np.sum(s_true ** 2), np.sum((e_spat + e_interf + e_artif) ** 2))
819 | isr = _safe_db(np.sum(s_true ** 2), np.sum(e_spat ** 2))
820 | sir = _safe_db(np.sum((s_true + e_spat) ** 2), np.sum(e_interf ** 2))
821 | sar = _safe_db(np.sum((s_true + e_spat + e_interf) ** 2), np.sum(e_artif ** 2))
822 | return (sdr, isr, sir, sar)
823 |
824 |
825 | def _safe_db(num, den):
826 | """Properly handle the potential +Inf db SIR, instead of raising a
827 | RuntimeWarning. Only denominator is checked because the numerator can never
828 | be 0.
829 | """
830 | if den == 0:
831 | return np.inf
832 | return 10 * np.log10(num / den)
833 |
834 |
835 | def evaluate(reference_sources, estimated_sources, **kwargs):
836 | """Compute all metrics for the given reference and estimated signals.
837 |
838 | NOTE: This will always compute :func:`mir_eval.separation.bss_eval_images`
839 | for any valid input and will additionally compute
840 | :func:`mir_eval.separation.bss_eval_sources` for valid input with fewer
841 | than 3 dimensions.
842 |
843 | Examples
844 | --------
845 | >>> # reference_sources[n] should be an ndarray of samples of the
846 | >>> # n'th reference source
847 | >>> # estimated_sources[n] should be the same for the n'th estimated source
848 | >>> scores = mir_eval.separation.evaluate(reference_sources,
849 | ... estimated_sources)
850 |
851 | Parameters
852 | ----------
853 | reference_sources : np.ndarray, shape=(nsrc, nsampl[, nchan])
854 | matrix containing true sources
855 | estimated_sources : np.ndarray, shape=(nsrc, nsampl[, nchan])
856 | matrix containing estimated sources
857 | kwargs
858 | Additional keyword arguments which will be passed to the
859 | appropriate metric or preprocessing functions.
860 |
861 | Returns
862 | -------
863 | scores : dict
864 | Dictionary of scores, where the key is the metric name (str) and
865 | the value is the (float) score achieved.
866 |
867 | """
868 | # Compute all the metrics
869 | scores = collections.OrderedDict()
870 |
871 | sdr, isr, sir, sar, perm = util.filter_kwargs(
872 | bss_eval_images,
873 | reference_sources,
874 | estimated_sources,
875 | **kwargs
876 | )
877 | scores['Images - Source to Distortion'] = sdr.tolist()
878 | scores['Images - Image to Spatial'] = isr.tolist()
879 | scores['Images - Source to Interference'] = sir.tolist()
880 | scores['Images - Source to Artifact'] = sar.tolist()
881 | scores['Images - Source permutation'] = perm.tolist()
882 |
883 | sdr, isr, sir, sar, perm = util.filter_kwargs(
884 | bss_eval_images_framewise,
885 | reference_sources,
886 | estimated_sources,
887 | **kwargs
888 | )
889 | scores['Images Frames - Source to Distortion'] = sdr.tolist()
890 | scores['Images Frames - Image to Spatial'] = isr.tolist()
891 | scores['Images Frames - Source to Interference'] = sir.tolist()
892 | scores['Images Frames - Source to Artifact'] = sar.tolist()
893 | scores['Images Frames - Source permutation'] = perm.tolist()
894 |
895 | # Verify we can compute sources on this input
896 | if reference_sources.ndim < 3 and estimated_sources.ndim < 3:
897 | sdr, sir, sar, perm = util.filter_kwargs(
898 | bss_eval_sources_framewise,
899 | reference_sources,
900 | estimated_sources,
901 | **kwargs
902 | )
903 | scores['Sources Frames - Source to Distortion'] = sdr.tolist()
904 | scores['Sources Frames - Source to Interference'] = sir.tolist()
905 | scores['Sources Frames - Source to Artifact'] = sar.tolist()
906 | scores['Sources Frames - Source permutation'] = perm.tolist()
907 |
908 | sdr, sir, sar, perm = util.filter_kwargs(
909 | bss_eval_sources,
910 | reference_sources,
911 | estimated_sources,
912 | **kwargs
913 | )
914 | scores['Sources - Source to Distortion'] = sdr.tolist()
915 | scores['Sources - Source to Interference'] = sir.tolist()
916 | scores['Sources - Source to Artifact'] = sar.tolist()
917 | scores['Sources - Source permutation'] = perm.tolist()
918 |
919 | return scores
920 |
--------------------------------------------------------------------------------
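
The routines in separation.py above share one pattern: decompose each estimated source into a filtered true source plus interference and artifact terms, express the energy ratios in dB via `_safe_db`, and, when permutations are computed, keep the ordering of estimates with the highest mean SIR. The sketch below illustrates only that last selection step with toy numbers; `fake_sir` and the local `safe_db` helper are stand-ins for this illustration, not code from the repository.

```python
# Illustrative sketch of the permutation search used by bss_eval_sources /
# bss_eval_images: pick the ordering of estimates that maximises mean SIR.
# All values here are toy inputs; nothing below comes from the repository.
import itertools
import numpy as np

def safe_db(num, den):
    # energy ratio in dB, guarding against a zero denominator (cf. _safe_db)
    if den == 0:
        return np.inf
    return 10 * np.log10(num / den)

nsrc = 3
rng = np.random.default_rng(0)
fake_sir = rng.uniform(0, 20, size=(nsrc, nsrc))  # sir[j_est, j_true], toy numbers

perms = list(itertools.permutations(range(nsrc)))
dum = np.arange(nsrc)
mean_sir = np.array([np.mean(fake_sir[perm, dum]) for perm in perms])
popt = perms[np.argmax(mean_sir)]  # best estimate-to-reference ordering
print('best permutation:', popt)
print('example energy ratio in dB:', safe_db(4.0, 1.0))
```
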
/test_WSJ0_SDNet.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import os
3 | import argparse
4 | import time
5 | import json
6 | import collections
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.autograd import Variable
11 | import torch.utils.data
12 | import numpy as np
13 |
14 | import models
15 | import data.utils as utils
16 | from optims import Optim
17 | import lr_scheduler as L
18 | from predata_WSJ_lcx import prepare_data
19 | import bss_test
20 | from models.loss import ss_tas_loss
21 | from scipy.io import wavfile
22 | # config
23 | parser = argparse.ArgumentParser(description='test_WSJ0_SDNet.py')
24 |
25 | parser.add_argument('-config', default='config_WSJ0_SDNet.yaml', type=str, help="config file")
26 | parser.add_argument('-gpus', default=[3], nargs='+', type=int, help="Use CUDA on the listed devices.")
27 | parser.add_argument('-restore', default='', type=str, help="restore checkpoint")
28 | parser.add_argument('-seed', type=int, default=1234, help="Random seed")
29 | parser.add_argument('-model', default='seq2seq', type=str, help="Model selection")
30 | parser.add_argument('-score', default='', type=str, help="score_fn")
31 | parser.add_argument('-notrain', default=True, type=bool, help="train or not")
32 | parser.add_argument('-log', default='', type=str, help="log directory")
33 | parser.add_argument('-memory', default=False, type=bool, help="memory efficiency")
34 | parser.add_argument('-score_fc', default='', type=str, help="score function type")
35 |
36 | opt = parser.parse_args()
37 | config = utils.read_config(opt.config)
38 | torch.manual_seed(opt.seed)
39 |
40 | # checkpoint
41 | if opt.restore:
42 | print('loading checkpoint...\n', opt.restore)
43 | checkpoints = torch.load(opt.restore,map_location={'cuda:2':'cuda:0'})
44 |
45 | # cuda
46 | use_cuda = torch.cuda.is_available() and len(opt.gpus) > 0
47 | use_cuda = True
48 | if use_cuda:
49 | torch.cuda.set_device(opt.gpus[0])
50 | torch.cuda.manual_seed(opt.seed)
51 | print(use_cuda)
52 |
53 | # load the global statistic of the data
54 | print('loading data...\n')
55 | start_time = time.time()
56 |
57 | spk_global_gen = prepare_data(mode='global', train_or_test='train') # read the global statistics of the dataset
58 | global_para = next(spk_global_gen)
59 | print(global_para)
60 |
61 | spk_all_list = global_para['all_spk'] # list of all speakers
62 | dict_spk2idx = global_para['dict_spk_to_idx']
63 | dict_idx2spk = global_para['dict_idx_to_spk']
64 | direction_all_list = global_para['all_dir']
65 | dict_dir2idx = global_para['dict_dir_to_idx']
66 | dict_idx2dir = global_para['dict_idx_to_dir']
67 | speech_fre = global_para['num_fre'] # total number of frequency bins
68 | total_frames = global_para['num_frames'] # speech length in frames
69 | spk_num_total = global_para['total_spk_num'] # total number of speakers
70 | batch_total = global_para['total_batch_num'] # number of batches per epoch
71 |
72 | print(dict_idx2spk)
73 | print(dict_idx2dir)
74 |
75 | config.speech_fre = speech_fre
76 | mix_speech_len = total_frames
77 | config.mix_speech_len = total_frames
78 | num_labels = len(spk_all_list)
79 | num_dir_labels = len(direction_all_list)
80 |
81 | del spk_global_gen
82 | print('loading the global setting cost: %.3f' % (time.time() - start_time))
83 | print("num_dir_labels", num_dir_labels)
84 | print("num_labels", num_labels)
85 | # model
86 | print('building model...\n')
87 | model = getattr(models, opt.model)(config, 256, mix_speech_len, num_labels, num_dir_labels, use_cuda, None, opt.score_fc)
88 |
89 | if opt.restore:
90 | model.load_state_dict(checkpoints['model'])
91 | if use_cuda:
92 | model.cuda()
93 | if len(opt.gpus) > 1:
94 | model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
95 |
96 | # optimizer
97 | if 0 and opt.restore:
98 | optim = checkpoints['optim']
99 | else:
100 | optim = Optim(config.optim, config.learning_rate, config.max_grad_norm,
101 | lr_decay=config.learning_rate_decay, start_decay_at=config.start_decay_at)
102 |
103 | optim.set_parameters(model.parameters())
104 |
105 | if config.schedule:
106 | # scheduler = L.CosineAnnealingLR(optim.optimizer, T_max=config.epoch)
107 | scheduler = L.StepLR(optim.optimizer, step_size=20, gamma=0.2)
108 |
109 | # total number of parameters
110 | param_count = 0
111 | for param in model.parameters():
112 | param_count += param.view(-1).size()[0]
113 |
114 | # logging module
115 | if not os.path.exists(config.log):
116 | os.mkdir(config.log)
117 | if opt.log == '':
118 | log_path = config.log + utils.format_time(time.localtime()) + '/'
119 | else:
120 | log_path = config.log + opt.log + '/'
121 | if not os.path.exists(log_path):
122 | os.mkdir(log_path)
123 | print('log_path:',log_path)
124 |
125 | logging = utils.logging(log_path + 'log.txt') # standalone logging helper: prints and also writes to the log file
126 | logging_csv = utils.logging_csv(log_path + 'record.csv')
127 | for k, v in config.items():
128 | logging("%s:\t%s\n" % (str(k), str(v)))
129 | logging("\n")
130 | logging(repr(model) + "\n\n")
131 |
132 | logging('total number of parameters: %d\n\n' % param_count)
133 | logging('score function is %s\n\n' % opt.score)
134 |
135 | if opt.restore:
136 | updates = checkpoints['updates']
137 | else:
138 | updates = 0
139 |
140 | total_loss, start_time = 0, time.time()
141 | total_loss_sgm, total_loss_ss = 0, 0
142 | report_total, report_correct = 0, 0
143 | report_vocab, report_tot_vocab = 0, 0
144 | scores = [[] for metric in config.metric]
145 | scores = collections.OrderedDict(zip(config.metric, scores))
146 | best_SDR = 0.0
147 | e=0
148 | loss_last_epoch = 1000000.0
149 |
150 | def eval(epoch):
151 | # config.batch_size=1
152 | model.eval()
153 | # print('\n\nPlease set batch_size to 1 in the config when testing!')
154 | reference, candidate, source, alignments = [], [], [], []
155 | e = epoch
156 | test_or_valid = 'test'
157 | #test_or_valid = 'valid'
158 | print('Test or valid:', test_or_valid)
159 | eval_data_gen = prepare_data('once', test_or_valid, config.MIN_MIX, config.MAX_MIX)
160 | SDR_SUM = np.array([])
161 | SDRi_SUM = np.array([])
162 | SISNR_SUM = np.array([])
163 | SISNRI_SUM = np.array([])
164 | SS_SUM = np.array([])
165 | batch_idx = 0
166 | global best_SDR
167 | f = open('./results/spk2.txt', 'a')
168 | f_dir = open('./results/dir2.txt', 'a')
169 | f_bk = open('./results/spk_bk.txt', 'a')
170 | f_bk_dir = open('./results/dir_bk.txt', 'a')
171 | f_emb = open('./results/spk_emb.txt', 'a')
172 | f_emb_dir = open('./results/dir_emb.txt', 'a')
173 | f_hidden = open('./results/spk_hidden.txt', 'a')
174 | f_hidden_dir = open('./results/dir_hidden.txt', 'a')
175 | while True:
176 | print('-' * 30)
177 | eval_data = next(eval_data_gen)
178 | if eval_data == False:
179 | print('SDR_aver_eval_epoch:', SDR_SUM.mean())
180 | print('SDRi_aver_eval_epoch:', SDRi_SUM.mean())
181 | print('SISNR_aver_eval_epoch:', SISNR_SUM.mean())
182 | print('SISNRI_aver_eval_epoch:', SISNRI_SUM.mean())
183 | print('SS_aver_eval_epoch:', SS_SUM.mean())
184 | break # the generator for this epoch is exhausted, move on to the next epoch
185 |
186 | raw_tgt= eval_data['batch_order']
187 |
188 | padded_mixture, mixture_lengths, padded_source = eval_data['tas_zip']
189 | padded_mixture=torch.from_numpy(padded_mixture).float()
190 | mixture_lengths=torch.from_numpy(mixture_lengths)
191 | padded_source=torch.from_numpy(padded_source).float()
192 |
193 | padded_mixture = padded_mixture.cuda().transpose(0,1)
194 | mixture_lengths = mixture_lengths.cuda()
195 | padded_source = padded_source.cuda()
196 |
197 | top_k = len(raw_tgt[0])
198 | tgt = Variable(torch.ones(top_k + 2, config.batch_size))
199 | src_len = Variable(torch.LongTensor(config.batch_size).zero_() + mix_speech_len).unsqueeze(0)
200 | tgt_len = Variable(torch.LongTensor([len(one_spk) for one_spk in eval_data['multi_spk_fea_list']])).unsqueeze(0)
201 |
202 | if use_cuda:
203 | tgt = tgt.cuda()
204 | src_len = src_len.cuda()
205 | tgt_len = tgt_len.cuda()
206 |
207 | if 1 and len(opt.gpus) > 1:
208 | samples, samples_dir, alignment, hiddens, predicted_masks, output_list, output_dir_list, output_bk_list, output_dir_bk_list, hidden_list, hidden_dir_list, emb_list, emb_dir_list = model.module.beam_sample(padded_mixture, dict_spk2idx, dict_dir2idx, config.beam_size)
209 | else:
210 | samples, samples_dir, alignment, hiddens, predicted_masks, output_list, output_dir_list, output_bk_list, output_dir_bk_list, hidden_list, hidden_dir_list, emb_list, emb_dir_list = model.beam_sample(padded_mixture, dict_spk2idx,dict_dir2idx, config.beam_size)
211 |
212 | predicted_masks = predicted_masks.transpose(0,1)
213 | predicted_masks = predicted_masks[:,0:top_k,:]
214 | mixture = torch.chunk(padded_mixture, 2, dim=-1)
215 | padded_mixture_c0 = mixture[0].squeeze()
216 |
217 | padded_source1= padded_source.data.cpu()
218 | predicted_masks1 = predicted_masks.data.cpu()
219 |
220 | padded_source= padded_source.squeeze().data.cpu().numpy()
221 | padded_mixture = padded_mixture.squeeze().data.cpu().numpy()
222 | predicted_masks = predicted_masks.squeeze().data.cpu().numpy()
223 | padded_mixture_c0 = padded_mixture_c0.squeeze().data.cpu().numpy()
224 | mixture_lengths = mixture_lengths.cpu()
225 |
226 | predicted_masks = predicted_masks - np.mean(predicted_masks)
227 | predicted_masks /= np.max(np.abs(predicted_masks))
228 |
229 | # '''''
230 | if batch_idx <= (3000 / config.batch_size): # only the first batches count towards the SDR
231 |
232 | sisnr, sisnri = bss_test.cal_SISNRi_PIT(padded_source, predicted_masks,padded_mixture_c0)
233 | sdr, sdri = bss_test.cal_SDRi(padded_source,predicted_masks, padded_mixture_c0)
234 | loss = ss_tas_loss(config,predicted_masks1, padded_source1, mixture_lengths, True)
235 | loss = loss.numpy()
236 | try:
237 | #SDR_SUM,SDRi_SUM = np.append(SDR_SUM, bss_test.cal('batch_output1/'))
238 | SDR_SUM = np.append(SDR_SUM, sdr)
239 | SDRi_SUM = np.append(SDRi_SUM, sdri)
240 |
241 | SISNR_SUM = np.append(SISNR_SUM, sisnr)
242 | SISNRI_SUM = np.append(SISNRI_SUM, sisnri)
243 | SS_SUM = np.append(SS_SUM, loss)
244 | except AssertionError as wrong_info:
245 | print('Errors in calculating the SDR', wrong_info)
246 | print('SDR_aver_now:', SDR_SUM.mean())
247 | print('SDRi_aver_now:', SDRi_SUM.mean())
248 | print('SISNR_aver_now:', SISNR_SUM.mean())
249 | print('SISNRI_aver_now:', SISNRI_SUM.mean())
250 | print('SS_aver_now:', SS_SUM.mean())
251 |
252 | elif batch_idx == (3000 / config.batch_size) + 1 and SDR_SUM.mean() > best_SDR: # only record the best SDR once.
253 | print('Best SDR from {}---->{}'.format(best_SDR, SDR_SUM.mean()))
254 | best_SDR = SDR_SUM.mean()
255 |
256 | # '''
257 | candidate += [convertToLabels(dict_idx2spk, s, dict_spk2idx['']) for s in samples]
258 | # source += raw_src
259 | reference += raw_tgt
260 | print('samples:', samples)
261 | print('can:{}, \nref:{}'.format(candidate[-1 * config.batch_size:], reference[-1 * config.batch_size:]))
262 | alignments += [align for align in alignment]
263 | batch_idx += 1
264 | f.close()
265 | f_dir.close()
266 | score = {}
267 | result = utils.eval_metrics(reference, candidate, dict_spk2idx, log_path)
268 | logging_csv([e, updates, result['hamming_loss'], \
269 | result['micro_f1'], result['micro_precision'], result['micro_recall']])
270 | print('hamming_loss: %.8f | micro_f1: %.4f'
271 | % (result['hamming_loss'], result['micro_f1']))
272 | score['hamming_loss'] = result['hamming_loss']
273 | score['micro_f1'] = result['micro_f1']
274 | return score
275 |
276 |
277 | # Convert `idx` to labels. If index `stop` is reached, convert it and return.
278 | def convertToLabels(dict, idx, stop):
279 | labels = []
280 |
281 | for i in idx:
282 | i = int(i)
283 | if i == stop:
284 | break
285 | labels += [dict[i]]
286 |
287 | return labels
288 |
289 |
290 | def save_model(path):
291 | global updates
292 | model_state_dict = model.module.state_dict() if len(opt.gpus) > 1 else model.state_dict()
293 | checkpoints = {
294 | 'model': model_state_dict,
295 | 'config': config,
296 | 'optim': optim,
297 | 'updates': updates}
298 |
299 | torch.save(checkpoints, path)
300 |
301 |
302 | def main():
303 |
304 | eval(1)
305 | for metric in config.metric:
306 | logging("Best %s score: %.2f\n" % (metric, max(scores[metric])))
307 |
308 |
309 | if __name__ == '__main__':
310 | main()
311 |
--------------------------------------------------------------------------------
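
The test script above and the training script below both restore state through an `opt.restore` checkpoint whose dictionary is written by their `save_model` helpers (keys `model`, `config`, `optim`, `updates`). The following is a minimal, self-contained sketch of that round-trip; the `nn.Linear` module and the `demo_ckpt.pt` filename are placeholders for this illustration, not the SDNet model or a path used by the repository.

```python
# Minimal sketch of the checkpoint round-trip used by save_model()/opt.restore.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                      # placeholder stand-in model
checkpoints = {'model': model.state_dict(),  # same keys as save_model()
               'config': {'note': 'placeholder config'},
               'optim': None,
               'updates': 123}
torch.save(checkpoints, 'demo_ckpt.pt')      # placeholder filename

restored = torch.load('demo_ckpt.pt', map_location='cpu')
model.load_state_dict(restored['model'])     # mirrors model.load_state_dict(checkpoints['model'])
print('restored after', restored['updates'], 'updates')
```
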
/train_WSJ0_SDNet.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import os
3 | import argparse
4 | import time
5 | import json
6 | import collections
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.autograd import Variable
11 | import torch.utils.data
12 | import numpy as np
13 |
14 | import models
15 | import data.utils as utils
16 | from optims import Optim
17 | import lr_scheduler as L
18 | from predata_WSJ_lcx import prepare_data
19 | import bss_test
20 | # config
21 | parser = argparse.ArgumentParser(description='train_WSJ0_SDNet.py')
22 |
23 | parser.add_argument('-config', default='config_WSJ0_SDNet.yaml', type=str, help="config file")
24 | parser.add_argument('-gpus', default=[0], nargs='+', type=int, help="Use CUDA on the listed devices.")
25 | parser.add_argument('-restore', default='', type=str, help="restore checkpoint")
26 | parser.add_argument('-seed', type=int, default=1234, help="Random seed")
27 | parser.add_argument('-model', default='seq2seq', type=str, help="Model selection")
28 | parser.add_argument('-score', default='', type=str, help="score_fn")
29 | parser.add_argument('-notrain', default=False, type=bool, help="train or not")
30 | parser.add_argument('-log', default='', type=str, help="log directory")
31 | parser.add_argument('-memory', default=False, type=bool, help="memory efficiency")
32 | parser.add_argument('-score_fc', default='', type=str, help="score function type")
33 |
34 | opt = parser.parse_args()
35 | config = utils.read_config(opt.config)
36 | torch.manual_seed(opt.seed)
37 |
38 | # checkpoint
39 | if opt.restore:
40 | print('loading checkpoint...\n', opt.restore)
41 | checkpoints = torch.load(opt.restore,map_location={'cuda:2':'cuda:0'})
42 |
43 | # cuda
44 | use_cuda = torch.cuda.is_available() and len(opt.gpus) > 0
45 | use_cuda = True
46 | if use_cuda:
47 | torch.cuda.set_device(opt.gpus[0])
48 | torch.cuda.manual_seed(opt.seed)
49 | print(use_cuda)
50 |
51 | # load the global statistic of the data
52 | print('loading data...\n')
53 | start_time = time.time()
54 |
55 | spk_global_gen = prepare_data(mode='global', train_or_test='train') # read the global statistics of the dataset
56 | global_para = next(spk_global_gen)
57 | print(global_para)
58 |
59 | spk_all_list = global_para['all_spk'] # list of all speakers
60 | dict_spk2idx = global_para['dict_spk_to_idx']
61 | dict_idx2spk = global_para['dict_idx_to_spk']
62 | direction_all_list = global_para['all_dir']
63 | dict_dir2idx = global_para['dict_dir_to_idx']
64 | dict_idx2dir = global_para['dict_idx_to_dir']
65 | speech_fre = global_para['num_fre'] # total number of frequency bins
66 | total_frames = global_para['num_frames'] # speech length in frames
67 | spk_num_total = global_para['total_spk_num'] # total number of speakers
68 | batch_total = global_para['total_batch_num'] # number of batches per epoch
69 |
70 | print(dict_idx2spk)
71 | print(dict_idx2dir)
72 |
73 | config.speech_fre = speech_fre
74 | mix_speech_len = total_frames
75 | config.mix_speech_len = total_frames
76 | num_labels = len(spk_all_list)
77 | num_dir_labels = len(direction_all_list)
78 |
79 | del spk_global_gen
80 | print('loading the global setting cost: %.3f' % (time.time() - start_time))
81 | print("num_dir_labels", num_dir_labels)
82 | print("num_labels", num_labels)
83 | # model
84 | print('building model...\n')
85 | model = getattr(models, opt.model)(config, 256, mix_speech_len, num_labels, num_dir_labels, use_cuda, None, opt.score_fc)
86 |
87 | if opt.restore:
88 | model.load_state_dict(checkpoints['model'])
89 | if use_cuda:
90 | model.cuda()
91 | if len(opt.gpus) > 1:
92 | model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
93 |
94 | # optimizer
95 | if 0 and opt.restore:
96 | optim = checkpoints['optim']
97 | else:
98 | optim = Optim(config.optim, config.learning_rate, config.max_grad_norm,
99 | lr_decay=config.learning_rate_decay, start_decay_at=config.start_decay_at)
100 |
101 | optim.set_parameters(model.parameters())
102 |
103 | if config.schedule:
104 | # scheduler = L.CosineAnnealingLR(optim.optimizer, T_max=config.epoch)
105 | scheduler = L.StepLR(optim.optimizer, step_size=20, gamma=0.2)
106 |
107 | # total number of parameters
108 | param_count = 0
109 | for param in model.parameters():
110 | param_count += param.view(-1).size()[0]
111 |
112 | # logging module
113 | if not os.path.exists(config.log):
114 | os.mkdir(config.log)
115 | if opt.log == '':
116 | log_path = config.log + utils.format_time(time.localtime()) + '/'
117 | else:
118 | log_path = config.log + opt.log + '/'
119 | if not os.path.exists(log_path):
120 | os.mkdir(log_path)
121 | print('log_path:',log_path)
122 |
123 | logging = utils.logging(log_path + 'log.txt') # standalone logging helper: prints and also writes to the log file
124 | logging_csv = utils.logging_csv(log_path + 'record.csv')
125 | for k, v in config.items():
126 | logging("%s:\t%s\n" % (str(k), str(v)))
127 | logging("\n")
128 | logging(repr(model) + "\n\n")
129 |
130 | logging('total number of parameters: %d\n\n' % param_count)
131 | logging('score function is %s\n\n' % opt.score)
132 |
133 | if opt.restore:
134 | updates = checkpoints['updates']
135 | else:
136 | updates = 0
137 |
138 | total_loss, start_time = 0, time.time()
139 | total_loss_sgm, total_loss_ss = 0, 0
140 | report_total, report_correct = 0, 0
141 | report_vocab, report_tot_vocab = 0, 0
142 | scores = [[] for metric in config.metric]
143 | scores = collections.OrderedDict(zip(config.metric, scores))
144 | best_SDR = 0.0
145 | e=0
146 | loss_last_epoch = 1000000.0
147 |
148 | def train(epoch):
149 | global e, updates, total_loss, start_time, report_total,report_correct, total_loss_sgm, total_loss_ss,loss_last_epoch
150 | e = epoch
151 | model.train()
152 | SDR_SUM = np.array([])
153 | SDRi_SUM = np.array([])
154 | total_loss_final = 0
155 |
156 | if config.schedule and scheduler.get_lr()[0]>5e-5:
157 | scheduler.step()
158 | print("Decaying learning rate to %g" % scheduler.get_lr()[0])
159 |
160 | if opt.model == 'gated':
161 | model.current_epoch = epoch
162 |
163 | train_data_gen = prepare_data('once', 'train')
164 | while True:
165 | train_data = next(train_data_gen)
166 | if train_data == False:
167 | print("ss loss (SISNR) in trainset:", total_loss_final)
168 | break
169 |
170 | raw_tgt = train_data['batch_order']
171 | raw_tgt_dir = train_data['direction']
172 |
173 | padded_mixture, mixture_lengths, padded_source = train_data['tas_zip']
174 | padded_mixture = Variable(torch.from_numpy(padded_mixture).float())
175 | mixture_lengths = torch.from_numpy(mixture_lengths)
176 | padded_source = torch.from_numpy(padded_source).float()
177 |
178 |
179 | padded_mixture = padded_mixture.cuda().transpose(0,1)
180 | mixture_lengths = mixture_lengths.cuda()
181 | padded_source = padded_source.cuda()
182 |
183 | # make sure the following tensors are all LongTensors
184 | tgt_max_len = config.MAX_MIX + 2 # with bos and eos.
185 | tgt = Variable(torch.from_numpy(np.array([[0] + [dict_spk2idx[spk] for spk in spks] + (tgt_max_len - len(spks) - 1) * [dict_spk2idx['']] for
186 | spks in raw_tgt], dtype=np.int64))).transpose(0, 1) # convert speaker labels to indices, then add BOS/EOS padding
187 | tgt_dir = Variable(torch.from_numpy(np.array([[0] + [dict_dir2idx[int(dire)] for dire in directs] + (tgt_max_len - len(directs) - 1) * [dict_dir2idx['']] for
188 | directs in raw_tgt_dir], dtype=np.int64))).transpose(0, 1) # convert direction labels to indices, then add BOS/EOS padding
189 | src_len = Variable(torch.LongTensor(config.batch_size).zero_() + mix_speech_len).unsqueeze(0)
190 | tgt_len = Variable( torch.LongTensor([len(one_spk) for one_spk in train_data['multi_spk_fea_list']])).unsqueeze(0)
191 | if use_cuda:
192 | tgt = tgt.cuda()
193 | tgt_dir = tgt_dir.cuda()
194 | src_len = src_len.cuda()
195 | tgt_len = tgt_len.cuda()
196 |
197 | model.zero_grad()
198 |
199 | # aim_list holds the positions (indices) that correspond to real speakers
200 | aim_list = (tgt[1:-1].transpose(0, 1).contiguous().view(-1) != dict_spk2idx['']).nonzero().squeeze()
201 | aim_list = aim_list.data.cpu().numpy()
202 |
203 | outputs, outputs_dir, targets, targets_dir, multi_mask = model(padded_mixture, tgt, tgt_dir) # outputs are the hidden states before the final classification layer and can be used directly
204 | multi_mask = multi_mask.transpose(0,1)
205 |
206 | if 1 and len(opt.gpus) > 1:
207 | sgm_loss, num_total, num_correct, num_total_dir, num_correct_dir = model.module.compute_loss(outputs, outputs_dir, targets, targets_dir, opt.memory)
208 | else:
209 | sgm_loss, num_total, num_correct, num_total_dir, num_correct_dir = model.compute_loss(outputs, outputs_dir, targets, targets_dir, opt.memory)
210 | print('loss for SGM,this batch:', sgm_loss.item())
211 | print("num_total",num_total)
212 | print("num_correct",num_correct)
213 |
214 | if config.use_tas:
215 | if 1 and len(opt.gpus) > 1:
216 | ss_loss = model.module.separation_tas_loss(multi_mask, padded_source, mixture_lengths)
217 | else:
218 | ss_loss = model.separation_tas_loss(multi_mask, padded_source, mixture_lengths)
219 |
220 | print('loss for SS,this batch:', ss_loss.item())
221 |
222 | loss = 4*sgm_loss + ss_loss
223 |
224 | loss.backward()
225 | total_loss_sgm += sgm_loss.item()
226 | total_loss_ss += ss_loss.item()
227 | total_loss_final += ss_loss.item()
228 |
229 | report_correct += num_correct.item()
230 | report_total += num_total
231 | optim.step()
232 | updates += 1
233 |
234 | if updates % 30 == 0:
235 | logging(
236 | "time: %6.3f, epoch: %3d, updates: %8d, train loss this batch: %6.3f,sgm loss: %6.6f,ss loss: %6.6f,label acc: %6.6f\n"
237 | % (time.time() - start_time, epoch, updates, loss / num_total, total_loss_sgm / 30.0,
238 | total_loss_ss / 30.0, report_correct/report_total))
239 | total_loss_sgm, total_loss_ss = 0, 0
240 |
241 | if total_loss_final < loss_last_epoch:
242 | loss_last_epoch = total_loss_final
243 | save_model(log_path + 'DSNet_{}_{}.pt'.format(epoch,total_loss_final))
244 |
245 | # Convert `idx` to labels. If index `stop` is reached, convert it and return.
246 | def convertToLabels(dict, idx, stop):
247 | labels = []
248 |
249 | for i in idx:
250 | i = int(i)
251 | if i == stop:
252 | break
253 | labels += [dict[i]]
254 |
255 | return labels
256 |
257 |
258 | def save_model(path):
259 | global updates
260 | model_state_dict = model.module.state_dict() if len(opt.gpus) > 1 else model.state_dict()
261 | checkpoints = {
262 | 'model': model_state_dict,
263 | 'config': config,
264 | 'optim': optim,
265 | 'updates': updates}
266 |
267 | torch.save(checkpoints, path)
268 |
269 |
270 | def main():
271 | for i in range(1, config.epoch + 1):
272 | train(i)
273 | for metric in config.metric:
274 | logging("Best %s score: %.2f\n" % (metric, max(scores[metric])))
275 |
276 |
277 | if __name__ == '__main__':
278 | main()
279 |
--------------------------------------------------------------------------------
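
train_WSJ0_SDNet.py turns each batch's speaker labels into a fixed-length LongTensor: a BOS index (0), the speaker indices, then EOS/padding up to `MAX_MIX + 2`, finally transposed to (target_len, batch). The sketch below reproduces that construction with a toy `dict_spk2idx` and made-up speaker names (`spk_a`, `spk_b`, `spk_c`); in the real script the mapping and the label lists come from `prepare_data`.

```python
# Illustrative sketch of the speaker-target tensor built in train_WSJ0_SDNet.py.
# The dictionary and speaker lists are toy stand-ins, not repository data.
import numpy as np
import torch

dict_spk2idx = {'': 1, 'spk_a': 2, 'spk_b': 3, 'spk_c': 4}  # toy mapping; '' acts as EOS/pad
raw_tgt = [['spk_a', 'spk_b'], ['spk_c', 'spk_a']]          # toy batch of speaker labels
MAX_MIX = 2
tgt_max_len = MAX_MIX + 2  # with BOS and EOS

rows = [[0] + [dict_spk2idx[spk] for spk in spks]
        + (tgt_max_len - len(spks) - 1) * [dict_spk2idx['']]
        for spks in raw_tgt]
tgt = torch.from_numpy(np.array(rows, dtype=np.int64)).transpose(0, 1)
print(tgt)  # shape (tgt_max_len, batch) = (4, 2), a LongTensor
```
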