├── local ├── __init__.py ├── wsj0-train-spkrinfo.txt ├── makelists.py ├── reconstruct_spectrogram.py ├── prepare_spknet_data.py ├── gen_tfreords.py ├── utils.py └── convert_to_records.py ├── model ├── __init__.py ├── spknet.py └── blstm.py ├── .gitignore ├── io_funcs ├── __init__.py ├── wave_io.py ├── tfrecords_io_test.py ├── tfrecords_io.py ├── signal_processing.py └── kaldi_io.py ├── matlab ├── bss_eval_sources.m ├── spk2gender ├── writekaldifeatures.m ├── calc_gender_sdr.m ├── eval_sdr.m ├── extract_czt_fft_feats.m ├── enframe.m ├── create_wav_2speakers.m └── create_wav_3speakers.m ├── path.sh ├── README.md ├── run.sh └── run_lstm.py /local/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /io_funcs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /matlab/bss_eval_sources.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/snsun/pit-speech-separation/HEAD/matlab/bss_eval_sources.m -------------------------------------------------------------------------------- /matlab/spk2gender: -------------------------------------------------------------------------------- 1 | 050 0 2 | 051 1 3 | 052 1 4 | 053 0 5 | 22g 1 6 | 22h 1 7 | 420 0 8 | 421 0 9 | 422 1 10 | 423 1 11 | 440 1 12 | 441 0 13 | 442 1 14 | 443 1 15 | 444 0 16 | 445 0 17 | 446 1 18 | 447 1 19 | -------------------------------------------------------------------------------- /path.sh: -------------------------------------------------------------------------------- 1 | export KALDI_ROOT=/home/disk1/wangqing/kaldi/kaldi-trunk/ 2 | [ -f $KALDI_ROOT/tools/env.sh ] && . 
$KALDI_ROOT/tools/env.sh
3 | export PATH=$PWD/utils/:$KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin/:$KALDI_ROOT/src/kwsbin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/:$KALDI_ROOT/src/nnet3bin/:$PWD:$PATH
4 | export LC_ALL=C
5 |
--------------------------------------------------------------------------------
/matlab/writekaldifeatures.m:
--------------------------------------------------------------------------------
1 | function [fid] = writekaldifeatures(fid, utt_id, data)
2 |
3 | % WRITEKALDIFEATURES Writes one utterance's features in Kaldi text format
4 | %
5 | % writekaldifeatures(fid, utt_id, data)
6 | %
7 | % Inputs:
8 | % fid: handle of the output Kaldi text-format feature file, opened by the
9 | %      caller (one call per utterance appends to it)
10 | % utt_id: utterance id written in front of the feature matrix
11 | % data: feature matrix (num_frames x dim)
12 | % Note: the resulting text-format file can be converted to binary .ark/.scp
13 | % with Kaldi's copy-feats
14 | %
15 | % If you use this software in a publication, please cite
16 | % Emmanuel Vincent and Shinji Watanabe, Kaldi to Matlab conversion tools,
17 | % http://kaldi-to-matlab.gforge.inria.fr/, 2014.
18 | %
19 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
20 | % Copyright 2014 Emmanuel Vincent (Inria) and Shinji Watanabe (MERL)
21 | % This software is distributed under the terms of the GNU Public License
22 | % version 3 (http://www.gnu.org/licenses/gpl.txt)
23 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
24 |
25 |
26 |
27 | feature=data;
28 | fprintf(fid,'%s [\n ', utt_id);
29 | nfram=size(feature,1);
30 | for t=1:nfram,
31 |     fprintf(fid,' %.7g', feature(t, :));
32 |     fprintf(fid,' \n ');
33 | end
34 | fprintf(fid,' ]\n');
35 |
36 | return
--------------------------------------------------------------------------------
/matlab/calc_gender_sdr.m:
--------------------------------------------------------------------------------
1 | model_name='ZoomFFT_BLSTM_3_496_10_26_def.mat'
2 | load(['sdr_' model_name], 'SDR', 'SAR', 'SIR', 'lists');
3 | fprintf('The mean SDR is %f\n', mean(mean(SDR)))
4 |
5 | % Calculate the SDR for the different gender combinations
6 | [spk, gender] = textread('spk2gender', '%s%d');
7 | cmm = 1;
8 | cmf = 1;
9 | cff = 1;
10 | for i = 1:size(SDR, 1)
11 |     mix_name = lists(i+2).name;  % i+2 skips the '.' and '..' entries of dir()
12 |     spk1 = mix_name(1:3);
13 |     tmp = regexp(mix_name, '_');
14 |     spk2 = mix_name(tmp(2)+1:tmp(2)+3);
15 |     for j = 1:length(spk)
16 |         if strcmp(spk1, spk{j})
17 |             break
18 |         end
19 |     end
20 |     for k = 1:length(spk)
21 |         if strcmp(spk2, spk{k})
22 |             break
23 |         end
24 |     end
25 |
26 |     if gender(k) == 0 && gender(j) == 0
27 |         SDR_FF(cff,:) = SDR(i, :);
28 |         lists_FF{cff} = lists(i+2).name;
29 |         cff = cff + 1;
30 |
31 |     elseif gender(k) == 1 && gender(j) == 1
32 |         SDR_MM(cmm,:) = SDR(i, :);
33 |         lists_MM{cmm} = lists(i+2).name;
34 |         cmm = cmm + 1;
35 |     else
36 |         SDR_MF(cmf,:) = SDR(i, :);
37 |         lists_MF{cmf} = lists(i+2).name;
38 |         cmf = cmf + 1;
39 |     end
40 | end
41 | fprintf('The mean SDR for Male & Female is : %f\n', mean(mean(SDR_MF)));
42 | fprintf('The mean SDR for Female & Female is : %f\n', mean(mean(SDR_FF)));
43 | fprintf('The mean SDR for Male & Male is : %f\n', mean(mean(SDR_MM)));
44 |
--------------------------------------------------------------------------------
/local/wsj0-train-spkrinfo.txt:
--------------------------------------------------------------------------------
1 | 001 M
2 | 002 F
3 | 00a F
4 | 00b M
5 | 00c M
6 | 00d M
7 | 00f F
8 | 010 M
9 | 011 F
10 | 012 M
11 | 013 M
12 | 014 F
13 | 015 M
14 | 016 F
15 | 017 F
16 | 018 F
17 | 019 F
18 | 01l M
19 | 01a F
20 | 01b F
21 | 01c F
22 | 01d F
23 | 01e M
24 | 01f F
25 | 01g M
26 | 01h F
27 | 01i M
28 | 01j F
29 | 01k F
30 | 01m F
31 | 01n F
32 | 01o F
33 | 01p F
34 | 01q F
35 | 01r M
36 | 01s M
37 | 01t M
38 | 01u F
39 | 01v F
40 | 01w M
41 | 01x F
42 | 01y M
43 | 01z M
44 | 020 M
45 | 021 M
46 | 022 F
47 | 023 F
48 | 024 M
49 | 025 M
50 | 026 M
51 | 027 F
52 | 028 F
53 | 029 M
54 | 02a F
55 | 02b M
56 | 02c F
57 | 02d F
58 | 02e F
59 | 02f F
60 | 050 F
61 | 051 M
62 | 052 M
63 | 053 F
64 | 200 M
65 | 201 M
66 | 202 F
67 | 203 F
68 | 204 F
69 | 205 F
70 | 206 F
71 | 207 M
72 | 208 M
73 | 209 F
74 | 20a F
75 | 20b F
76 | 20c M
77 | 20d F
78 | 20e F
79 | 20f M
80 | 20g M
81 | 20h F
82 | 20i M
83 | 20j M
84 | 20k M
85 | 20l M
86 | 20m M
87 | 20n M
88 | 20o M
89 | 20p F
90 | 20q M
91 | 20r M
92 | 20s M
93 | 20t F
94 | 20u M
95 | 20v M
96 | 22g M
97 | 22h M
98 | 400 M
99 | 401 F
100 | 403 M
101 | 404 F
102 | 405 M
103 | 406 M
104 | 407 F
105 | 408 M
106 | 409 F
107 | 40a M
108 | 40b M
109 | 40c M
110 | 40d F
111 | 40e F
112 | 40f M
113 | 40g F
114 | 40h F
115 | 40i M
116 | 40j M
117 | 40k M
118 | 40l F
119 | 40m F
120 | 40n M
121 | 40o F
122 | 40p F
123 | 420 F
124 | 421 F
125 | 422 M
126 | 423 M
127 | 430 F
128 | 431 M
129 | 432 F
130 | 050 F
131 | 051 M
132 | 052 M
133 | 053 F
134 | 22g M
135 | 22h M
136 | 423 M
137 | 440 M
138 | 441 F
139 | 442 F
140 | 443 M
141 | 444 F
142 | 445 F
143 | 446 M
144 | 447 M
145 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This project implements PIT (Permutation Invariant Training) for two-speaker speech separation.
2 |
3 | We use TensorFlow (1.0) LSTM/BLSTM models to do PIT.
4 |
5 | Reference:
6 |
7 | Kolbæk, M., Yu, D., Tan, Z.-H., & Jensen, J. (2017). Multi-talker Speech Separation and Tracing with Permutation Invariant Training of Deep Recurrent Neural Networks. Retrieved from http://arxiv.org/abs/1703.06284
8 |
9 | # How to prepare data
10 | ## Generate the mixed speech and the corresponding target speech files.
11 |
12 | If you have the WSJ0 data, you can use this code http://www.merl.com/demos/deep-clustering/create-speaker-mixtures.zip to create the mixed speech.
13 |
14 | You can also use your own data.
15 |
16 | ## Extract FFT spectrum feats for every utterance.
17 |
18 | For every utterance, you need to extract the mixed-speech, speaker1 and speaker2 feature matrices and use the function make_sequence_example_two_labels(inputs, inputs_cmvn, labels1, labels2) from 'io_funcs/tfrecords_io.py' to generate TensorFlow examples (a minimal sketch is given at the end of this README).
19 |
20 | inputs: the mixed speech feats matrix with shape (num_frames, dim)
21 |
22 | inputs_cmvn: the mixed speech feats matrix after mean and variance normalization. This is probably not necessary; you can
23 | use the same data as inputs.
24 |
25 | labels1, labels2: speaker1's and speaker2's feats as the targets.
26 |
27 |
28 | ## Generate the tfrecord file lists for the training, cv and test sets.
29 |
30 | Make a dir named lists. Put the paths of all training tfrecord files into 'lists/tr.lst', and do the same for 'lists/cv.lst' and 'lists/tt.lst'. (Step 1 of run.sh rebuilds such lists as 'lists/{tr,cv,tt}_tf.lst' with a find command; those are the names run_lstm.py expects.)
31 |
32 | ## Run run.sh
33 |
34 | Once you have prepared all the data list files for tr, cv and tt (test), you can run 'run.sh' starting from the RNN training step. Make sure you give it the right lists dir.
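
## Example: writing one utterance to a tfrecords file

A minimal sketch of the per-utterance loop described above, not one of the project's scripts: only make_sequence_example_two_labels (whose signature is quoted above) and tf.python_io.TFRecordWriter come from this repo; the random matrices stand in for your real feature matrices, the output path is an arbitrary example, and we assume the helper returns a tf.train.SequenceExample like the other make_sequence_example helpers in this repo.

```python
import numpy as np
import tensorflow as tf
from io_funcs.tfrecords_io import make_sequence_example_two_labels

num_frames, dim = 300, 129                  # illustrative sizes only
inputs = np.random.rand(num_frames, dim)    # mixed speech feats (num_frames, dim)
inputs_cmvn = inputs                        # or a mean/variance normalized copy
labels1 = np.random.rand(num_frames, dim)   # speaker1 target feats
labels2 = np.random.rand(num_frames, dim)   # speaker2 target feats

# one tfrecords file per utterance; collect the file paths into lists/*.lst
with tf.python_io.TFRecordWriter('data/tfrecords/tr_psm/utt0001.tfrecords') as writer:
    ex = make_sequence_example_two_labels(inputs, inputs_cmvn, labels1, labels2)
    writer.write(ex.SerializeToString())
```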
35 |
36 |
37 |
--------------------------------------------------------------------------------
/local/makelists.py:
--------------------------------------------------------------------------------
1 | # This python script is used to generate a file list for {tr, tt, cv}.
2 | # In order to transform the Kaldi feats to tfrecords, we need a list to
3 | # specify the input and target kaldi files. Every list has the following form:
4 | # utt_id inputs_ark target1_ark target2_ark
5 | # Usage:
6 | # python makelists.py feats_dir mode list_dir
7 | # feats_dir: kaldi feats.ark, scp dir
8 | # mode: tr, cv, or tt
9 | # list_dir: where to store the generated list
10 |
11 | import sys
12 | import os
13 |
14 | usage = '''
15 | Usage:
16 | python makelists.py feats_dir mode list_dir
17 | feats_dir: kaldi feats.ark, scp dir
18 | mode: tr, cv, or tt
19 | list_dir: where to store the generated list
20 | '''
21 | if len(sys.argv) != 4:
22 |     print usage
23 |     sys.exit(1)
24 |
25 | feats_dir = sys.argv[1]
26 | mode = sys.argv[2]
27 | list_dir = sys.argv[3]
28 | if not os.path.exists(list_dir):
29 |     os.makedirs(list_dir)
30 |
31 | inputscp = feats_dir + '/' + mode + '_inputs/feats.scp'
32 | outputscp = feats_dir + '/' + mode + '_labels/feats.scp'
33 | lst = list_dir + '/' + mode + '_feats_mapping.lst'
34 | fid1 = open(inputscp, 'r')
35 | lines1 = fid1.readlines()
36 | fid2 = open(outputscp, 'r')
37 | lines2 = fid2.readlines()
38 |
39 | fid1.close()
40 | fid2.close()
41 |
42 | fid3 = open(lst, 'w')
43 |
44 | dict1 = {}
45 | dict2 = {}
46 | for line in lines2:
47 |     l = line.rstrip('\n')
48 |     strs = l.split(' ')
49 |     dict1[strs[0]] = strs[1]  # map target utt id -> ark location
50 |
51 | for line in lines1:
52 |     line = line.rstrip('\n')
53 |     strs = line.split(' ')
54 |     utt = strs[0]
55 |     cont = strs[1]
56 |     (names, ext) = os.path.splitext(utt)
57 |
58 |     name1 = names + '_1.wav'
59 |     name2 = names + '_2.wav'
60 |     fid3.write(utt + ' ' + cont + ' ' + dict1[name1] + ' ' + dict1[name2] + '\n')
61 |
62 | fid3.close()
63 |
64 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
/io_funcs/wave_io.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #by wujian@2017.4.15
3 |
4 | import wave
5 | import numpy as np
6 |
7 |
8 | class WaveWrapper(object):
9 |
10 |     def __init__(self, path, time_wnd=25, time_off=10):
11 |         wave_src = wave.open(path, "rb")
12 |         para_src = wave_src.getparams()
13 |         self.rate = int(para_src[2])
14 |         self.cur_size = 0
15 |         self.tot_size = int(para_src[3])
16 |         # default 400 160
17 |         self.wnd_size = int(self.rate * 0.001 * time_wnd)
18 |         self.wnd_rate = int(self.rate * 0.001 * time_off)
19 |         self.ham = np.hamming(self.wnd_size+1)
20 |         self.ham = np.sqrt(self.ham[0:self.wnd_size])
21 |         self.ham = self.ham / np.sqrt(np.sum(np.square(self.ham[range(0, self.wnd_size, self.wnd_rate)])))
22 |         self.data = np.fromstring(wave_src.readframes(wave_src.getnframes()), dtype=np.int16)
23 |         self.upper_bound = np.max(np.abs(self.data))
24 |
25 |     def get_frames_num(self):
26 |         return int((self.tot_size - self.wnd_size) / self.wnd_rate + 1)
27 |
28 |     def get_wnd_size(self):
29 |         return self.wnd_size
30 |
31 |     def get_wnd_rate(self):
32 |         return self.wnd_rate
33 |
34 |     def get_sample_rate(self):
35 |         return self.rate
36 |
37 |     def get_upper_bound(self):
38 |         return self.upper_bound
39 |
40 |     def next_frame_phase(self, fft_len=512, pre_em=True):
41 |         while self.cur_size + self.wnd_size <= self.tot_size:
42 |             value = np.zeros(fft_len)
43 |             value[: self.wnd_size] = np.array(self.data[self.cur_size: \
44 |                 self.cur_size + self.wnd_size], dtype=np.float)
45 |             if pre_em:
46 |                 value -= np.sum(value) / self.wnd_size
47 |                 value[1: ] -= value[: -1] * 0.97
48 |                 value[0] -= 0.97 * value[0]
49 |             value[: self.wnd_size] *= self.ham
50 |             angle = np.angle(np.fft.rfft(value))
51 |             yield np.cos(angle) + np.sin(angle) * 1.0j
52 |             self.cur_size += self.wnd_rate
53 |
--------------------------------------------------------------------------------
/local/reconstruct_spectrogram.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #by wujian@2017.4.15
3 |
4 | """transform a spectrogram back to a waveform"""
5 |
6 | import sys
7 | import wave
8 | import numpy as np
9 | sys.path.append('./')
10 | from io_funcs.kaldi_io import ArkReader
11 | from io_funcs import wave_io
12 |
13 |
14 | if len(sys.argv) != 4:
15 |     print "format error: %s [scp] [origin-wave] [reconst-wave]" % sys.argv[0]
16 |     sys.exit(1)
17 | wnd_len = 32  # ms
18 | wnd_shift = 16  # ms
19 | fft_len = 256  # for 8k sample rate
20 | pre_em = False
21 | WAVE_WRAPPER = wave_io.WaveWrapper(sys.argv[2], time_wnd=wnd_len, time_off=wnd_shift)
22 | WAVE_RECONST = wave.open(sys.argv[3], "wb")
23 |
24 | WND_SIZE = WAVE_WRAPPER.get_wnd_size()
25 | WND_RATE = WAVE_WRAPPER.get_wnd_rate()
26 |
27 | REAL_IFFT = np.fft.irfft
28 |
29 | HAM_WND = np.hamming(WND_SIZE+1)  # simulate the matlab hamming(N, 'periodic')
30 | HAM_WND = np.sqrt(HAM_WND[0:-1])
31 | stride = range(0, WND_SIZE, WND_RATE)
32 | HAM_WND = HAM_WND/np.sqrt(np.sum(HAM_WND[stride]*HAM_WND[stride]))  # normalize the window
33 | ark_name = sys.argv[1]
34 | kaldi_reader = ArkReader(ark_name)
35 | looped = False
36 |
37 | _, SPECT_ENHANCE, looped = kaldi_reader.read_next_utt()
38 | SPECT_ROWS, SPECT_COLS = SPECT_ENHANCE.shape
39 | assert WAVE_WRAPPER.get_frames_num() == SPECT_ROWS
40 | INDEX = 0
41 | SPECT = np.zeros(SPECT_COLS)
42 | RECONST_POOL = np.zeros((SPECT_ROWS - 1) * WND_RATE + WND_SIZE)
43 | for phase in WAVE_WRAPPER.next_frame_phase(fft_len=fft_len, pre_em=pre_em):
44 |     # exclude energy
45 |     #SPECT[1: ] = np.sqrt(np.exp(SPECT_ENHANCE[INDEX][1: ]))
46 |     SPECT = SPECT_ENHANCE[INDEX]
47 |     RECONST_POOL[INDEX * WND_RATE: INDEX * WND_RATE + WND_SIZE] += \
48 |         REAL_IFFT(SPECT * phase)[: WND_SIZE] * HAM_WND
49 |     INDEX += 1
50 | # remove pre-emphasis
51 | if pre_em:
52 |     for x in range(1, RECONST_POOL.size):
53 |         RECONST_POOL[x] += 0.97 * RECONST_POOL[x - 1]
54 | RECONST_POOL = RECONST_POOL / np.max(np.abs(RECONST_POOL)) * WAVE_WRAPPER.get_upper_bound()
55 |
56 | WAVE_RECONST.setnchannels(1)
57 | WAVE_RECONST.setnframes(RECONST_POOL.size)
58 | WAVE_RECONST.setsampwidth(2)
59 | WAVE_RECONST.setframerate(WAVE_WRAPPER.get_sample_rate())
60 | WAVE_RECONST.writeframes(np.array(RECONST_POOL, dtype=np.int16).tostring())
61 | WAVE_RECONST.close()
--------------------------------------------------------------------------------
/matlab/eval_sdr.m:
--------------------------------------------------------------------------------
1 | mixed_wav_dir = '/home/disk2/snsun/workspace/separation//data/wav/wav8k/min/tt/mix/';
2 | spk1_dir = '/home/disk2/snsun/workspace/separation/data/wav/wav8k/min/tt/s1/';
3 | spk2_dir = '/home/disk2/snsun/workspace/separation/data/wav/wav8k/min/tt//s2/';
4 | model_name='StandPsmPIT_BLSTM_3_400_def';
5 | rec_wav_dir = ['../data/separated/' model_name '/'];
6 | lists = dir(spk2_dir);
7 | len = length(lists) - 2;
8 | SDR = zeros(len, 2);
9 | SIR = SDR;
10 | SAR = SDR;
11 | SDR_Mix = SDR;
12 | SIR_Mix = SDR;
13 | SAR_Mix = SDR;
14 | for i = 3:len+2
15 |     name = lists(i).name;
16 |     part_name = name(1:end-4);
17 |     rec_wav1 = audioread([rec_wav_dir part_name '_1.wav']);
18 |     rec_wav2 = audioread([rec_wav_dir part_name '_2.wav']);
19 |     rec_wav = [rec_wav1,rec_wav2];
20 |
21 |     ori_wav1 = audioread([spk1_dir part_name '.wav']);
22 |     ori_wav2 = audioread([spk2_dir part_name '.wav']);
23 |     ori_wav = [ori_wav1, ori_wav2];
24 |
25 |     mix_wav1 = audioread([mixed_wav_dir part_name '.wav']);
26 |     mix_wav = [mix_wav1, mix_wav1];
27 |
28 |     min_len = min(size(ori_wav, 1), size(rec_wav, 1));
29 |     rec_wav = rec_wav(1:min_len, :);
30 |     ori_wav = ori_wav(1:min_len, :);
31 |     mix_wav = mix_wav(1:min_len, :);
32 |     [SDR(i-2, :),SIR(i-2, :),SAR(i-2, :),perm]=bss_eval_sources(rec_wav',ori_wav');
33 |     if mod(i, 200) == 0
34 |         i
35 |     end
36 | end
37 | fprintf('The mean SDR is %f\n', mean(mean(SDR)))
38 | save(['sdr_' model_name], 'SDR', 'SAR', 'SIR', 'lists');
39 |
40 | % Calculate the SDR for the different gender combinations
41 | [spk, gender] = textread('spk2gender', '%s%d');
42 | cmm = 1;
43 | cmf = 1;
44 | cff = 1;
45 | for i = 1:size(SDR, 1)
46 |     mix_name = lists(i+2).name;  % i+2 skips the '.' and '..' entries of dir()
47 |     spk1 = mix_name(1:3);
48 |     tmp = regexp(mix_name, '_');
49 |     spk2 = mix_name(tmp(2)+1:tmp(2)+3);
50 |     for j = 1:length(spk)
51 |         if strcmp(spk1, spk{j})
52 |             break
53 |         end
54 |     end
55 |     for k = 1:length(spk)
56 |         if strcmp(spk2, spk{k})
57 |             break
58 |         end
59 |     end
60 |
61 |     if gender(k) == 0 && gender(j) == 0
62 |         SDR_FF(cff,:) = SDR(i, :);
63 |         lists_FF{cff} = lists(i+2).name;
64 |         cff = cff + 1;
65 |
66 |     elseif gender(k) == 1 && gender(j) == 1
67 |         SDR_MM(cmm,:) = SDR(i, :);
68 |         lists_MM{cmm} = lists(i+2).name;
69 |         cmm = cmm + 1;
70 |     else
71 |         SDR_MF(cmf, :) = SDR(i, :);
72 |         lists_MF{cmf} = lists(i+2).name;
73 |         cmf = cmf + 1;
74 |     end
75 | end
76 | fprintf('The mean SDR for Male & Female is : %f\n', mean(mean(SDR_MF)));
77 | fprintf('The mean SDR for Female & Female is : %f\n', mean(mean(SDR_FF)));
78 | fprintf('The mean SDR for Male & Male is : %f\n', mean(mean(SDR_MM)));
79 |
80 |
--------------------------------------------------------------------------------
/matlab/extract_czt_fft_feats.m:
--------------------------------------------------------------------------------
1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
2 | %% This script is used to extract CZT+FFT features.
3 | %% 1. We use the 128-point CZT over 50-1000 Hz as additional features
4 | %%    to improve the frequency resolution: the 256-point FFT at 8 kHz gives 31.25 Hz per bin, while the CZT gives (1000-50)/128 ~= 7.4 Hz per bin.
5 | %% 2. Our final feature is 128 + 129 = 257 dim;
6 | %% 3. Note: I commented out the part that extracts FFT features
7 | %%    for the targets because we already have those features. If you don't
8 | %%    have the target features, please uncomment that part.
9 |
10 | mode1 = {'tr'}; % in order to run in parallel, extract 'tr', 'cv' and 'tt' separately
11 | mode_len = length(mode1);
12 | data_dir = '/home/disk2/snsun/workspace/separation/data/wav/wav8k/min/'; %CHANGE THE DIR TO YOUR DATA
13 | feats_dir = '../data/feats/50_1000_128_zoomfft/feats_8k_czt_psm2/'; %CHANGE THE DIR TO WHERE YOU WANT TO STORE THE FEATURES
14 | for idx=1:mode_len
15 |     mode = mode1{idx};
16 |     input_dir = [data_dir mode '/'];
17 |     mix_dir = [input_dir 'mix/'];
18 |     s1_dir = [input_dir 's1/'];
19 |     s2_dir = [input_dir 's2/'];
20 |
21 |     output_feats_dir = [feats_dir mode '_inputs/'];
22 |     output_labels_dir = [feats_dir mode '_labels/'];
23 |     mkdir(output_feats_dir);
24 |     mkdir(output_labels_dir);
25 |
26 |
27 |     fid_feats = fopen([output_feats_dir 'feats.txt'], 'w');
28 |     fid_labels = fopen([output_labels_dir 'feats.txt'], 'w');
29 |
30 |     % FFT and CZT configuration
31 |     fs = 8000;
32 |     fft_len = 256;
33 |     dim = 129;
34 |     frame_len = 256;
35 |     frame_shift = 128;
36 |
37 |     f1 = 50; % in Hz, CZT start freq
38 |     f2 = 1000; % in Hz, CZT end freq
39 |     M = 128; % CZT points
40 |     w = exp(-j*2*pi*(f2-f1)/(M*fs)); % for CZT
41 |     a = exp(j*2*pi*f1/fs); % for CZT
42 |
43 |     Win = sqrt(hamming(fft_len,'periodic'));
44 |     Win = Win/sqrt(sum(Win(1:frame_shift:fft_len).^2));
45 |     lists = dir(mix_dir);
46 |     for i = 3:length(lists)
47 |         utt_id = lists(i).name;
48 |         filename = [mix_dir utt_id];
49 |         wav = audioread(filename);
50 |         frames = enframe(wav, Win, frame_shift);
51 |         Xn = fft(frames, fft_len, 2);
52 |         Y = abs(Xn(:, 1:dim));
53 |
54 |         %CZT
55 |
56 |         Y_c = abs(czt(frames', M, w, a));
57 |         Y_c = Y_c';
58 |         feats = [Y_c, Y];
59 |
60 |         writekaldifeatures(fid_feats, utt_id, feats);
61 |         filename1 = [s1_dir utt_id];
62 |         wav1 = audioread(filename1);
63 |         frames = enframe(wav1, Win, frame_shift);
64 |         X = fft(frames, fft_len, 2);
65 |         Y = abs(X(:, 1:dim));
66 |         theta = angle(X(:, 1:dim)./Xn(:, 1:dim));
67 |         Y = Y.*cos(theta);
68 |         writekaldifeatures(fid_labels, [utt_id(1:end-4) '_1.wav'], Y);
69 |         filename2 = [s2_dir utt_id];
70 |         wav1 = audioread(filename2);
71 |         frames = enframe(wav1, Win, frame_shift);
72 |         X = fft(frames, fft_len, 2);
73 |         Y = abs(X(:, 1:dim));
74 |         theta = angle(X(:, 1:dim)./Xn(:, 1:dim));
75 |         Y = Y.*cos(theta);
76 |         writekaldifeatures(fid_labels, [utt_id(1:end-4) '_2.wav'], Y);
77 |         if mod(i, 100) == 0
78 |             i
79 |         end
80 |     end
81 |     fclose(fid_feats);
82 |     fclose(fid_labels);
83 |
84 | end
85 |
--------------------------------------------------------------------------------
/io_funcs/tfrecords_io_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2017 Ke Wang Xiaomi
5 |
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | import argparse
11 | import os.path
12 | import sys
13 | import time
14 |
15 | import numpy as np
16 | import tensorflow as tf
17 |
18 | from tfrecords_io import get_padded_batch
19 |
20 | tf.logging.set_verbosity(tf.logging.INFO)
21 |
22 |
23 | class TfrecordsIoTest(tf.test.TestCase):
24 |
25 |     def testReadTfrecords(self):
26 |         tfrecords_lst = "../list/train_8k.lst"
27 |         with tf.Graph().as_default():
28 |             mixed, inputs, labels1, labels2, lengths = get_padded_batch(
29 |                 tfrecords_lst, FLAGS.batch_size, FLAGS.input_dim,
30 |                 FLAGS.output_dim, num_enqueuing_threads=FLAGS.num_threads,
31 |                 num_epochs=FLAGS.num_epochs)
32 |
33 |             init = tf.group(tf.global_variables_initializer(),
34 |                             tf.local_variables_initializer())
35 |
36 |             sess = tf.Session()
37 |
38 |             sess.run(init)
39 |
40 |             coord = tf.train.Coordinator()
41 |             threads = tf.train.start_queue_runners(sess=sess, coord=coord)
42 |
43 |             try:
44 |                 time_start = time.time()
45 |                 while not coord.should_stop():
46 |                     # Print an overview fairly often.
47 |                     tr_inputs, tr_labels, tr_lengths = sess.run([
48 |                         inputs, labels1, lengths])
49 |                     tf.logging.info('inputs shape : ' + str(tr_inputs.shape))
50 |                     tf.logging.info('labels shape : ' + str(tr_labels.shape))
51 |                     tf.logging.info('actual lengths : ' + str(tr_lengths))
52 |             except tf.errors.OutOfRangeError:
53 |                 tf.logging.info('Done training -- epoch limit reached')
54 |             finally:
55 |                 # When done, ask the threads to stop.
56 |                 coord.request_stop()
57 |
58 |             # Wait for threads to finish.
59 |             coord.join(threads)
60 |             sess.close()
61 |
62 |
63 | if __name__ == '__main__':
64 |     parser = argparse.ArgumentParser()
65 |     parser.add_argument(
66 |         '--batch_size',
67 |         type=int,
68 |         default=1,
69 |         help='Mini-batch size.'
70 |     )
71 |     parser.add_argument(
72 |         '--input_dim',
73 |         type=int,
74 |         default=145,
75 |         help='The dimension of inputs.'
76 |     )
77 |     parser.add_argument(
78 |         '--output_dim',
79 |         type=int,
80 |         default=51,
81 |         help='The dimension of outputs.'
82 |     )
83 |     parser.add_argument(
84 |         '--num_threads',
85 |         type=int,
86 |         default=8,
87 |         help='The num of threads to read tfrecords files.'
88 |     )
89 |     parser.add_argument(
90 |         '--num_epochs',
91 |         type=int,
92 |         default=1,
93 |         help='The num of epochs to read tfrecords files.'
94 |     )
95 |     parser.add_argument(
96 |         '--data_dir',
97 |         type=str,
98 |         default='data/tfrecords/',
99 |         help='Directory of train, val and test data.'
100 |     )
101 |     parser.add_argument(
102 |         '--config_dir',
103 |         type=str,
104 |         default='list/',
105 |         help='Directory to load train, val and test lists.'
106 |     )
107 |     FLAGS, unparsed = parser.parse_known_args()
108 |     tf.test.main()
109 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Author: Sining Sun (Northwestern Polytechnical University, China)
3 | # This recipe is used to do NN-PIT (LSTM, DNN or BLSTM)
4 |
5 |
6 |
7 |
8 | step=0
9 |
10 | lists_dir=./lists/  # lists_dir is used to store some necessary file lists
11 | mkdir -p $lists_dir
12 | num_threads=12
13 |
14 | tfrecords_dir=data/tfrecords/
15 | gpu_id='0'
16 | TF_CPP_MIN_LOG_LEVEL=1
17 | rnn_num_layers=3
18 | tr_batch_size=32
19 |
20 | tt_batch_size=1
21 | input_size=129
22 | output_size=129
23 |
24 | rnn_size=496
25 | keep_prob=0.8
26 | learning_rate=0.0005
27 | halving_factor=0.7
28 | decode=0
29 | model_type=BLSTM
30 |
31 | prefix=StandPsmPIT
32 | assignment=def
33 | name=${prefix}_${model_type}_${rnn_num_layers}_${rnn_size}_ReLU
34 | save_dir=exp/$name/
35 | data_dir=data/separated/${name}_${assignment}/
36 | resume_training=false
37 |
38 | # note: we want to use gender information, but we didn't use it in this version.
39 | # but when we prepared our data, we stored the gender information (maybe useful in the future).
40 | # wsj-train-spkrinfo.txt: https://catalog.ldc.upenn.edu/docs/LDC93S6A/wsj0-train-spkrinfo.txt
41 |
42 | # tfrecords are stored in data/tfrecords/{tr, cv, tt}_psm/
43 |
44 | if [ $step -le 0 ]; then
45 |     for x in tr cv tt; do
46 |         python -u local/gen_tfreords.py --gender_list local/wsj0-train-spkrinfo.txt data/wav/wav8k/min/$x/ lists/${x}_wav.lst data/tfrecords/${x}_psm/ &
47 |
48 |     done
49 |     wait
50 | fi
51 | #####################################################################################################
52 | # NOTE for STEP 1:                                                                                ###
53 | # 1. Make sure that you configure the RNN/data_dir/model_dir/ correctly                           ###
54 | #####################################################################################################
55 |
56 | if [ $step -le 1 ]; then
57 |
58 |     echo "Start Training the RNN (LSTM or BLSTM) model."
59 |     decode=0
60 |     batch_size=25
61 |     # Here, we make the tfrecords list files for the tr, cv and tt data.
62 |     # Make sure you have generated tfrecords files in $tfrecords_dir/{tr, cv, tt}_psm/
63 |     # The list file names must be tr_tf.lst, cv_tf.lst and tt_tf.lst. We fixed them in run_lstm.py
64 |     for x in tr tt cv; do
65 |         find $tfrecords_dir/${x}_psm/ -iname "*.tfrecords" > $lists_dir/${x}_tf.lst
66 |     done
67 |
68 |     tr_cmd="python -u run_lstm.py \
69 |         --lists_dir=$lists_dir --rnn_num_layers=$rnn_num_layers --batch_size=$batch_size --rnn_size=$rnn_size \
70 |         --decode=$decode --learning_rate=$learning_rate --save_dir=$save_dir --data_dir=$data_dir --keep_prob=$keep_prob \
71 |         --input_size=$input_size --output_size=$output_size --assign=$assignment --resume_training=$resume_training \
72 |         --model_type=$model_type --halving_factor=$halving_factor "
73 |
74 |     echo $tr_cmd
75 |     CUDA_VISIBLE_DEVICES=$gpu_id TF_CPP_MIN_LOG_LEVEL=$TF_CPP_MIN_LOG_LEVEL $tr_cmd
76 | fi
77 | #####################################################################################################
78 | # NOTE for STEP 2:                                                                                ###
79 | # 1. Make sure that you configure the RNN/data_dir/model_dir/ correctly                           ###
80 | #####################################################################################################
81 |
82 | if [ $step -le 2 ]; then
83 |
84 |     echo "Start Decoding."
85 | decode=1 86 | batch_size=30 87 | tr_cmd="python -u run_lstm.py --lists_dir=$lists_dir --rnn_num_layers=$rnn_num_layers --batch_size=$batch_size --rnn_size=$rnn_size \ 88 | --decode=$decode --learning_rate=$learning_rate --save_dir=$save_dir --data_dir=$data_dir --keep_prob=$keep_prob \ 89 | --input_size=$input_size --output_size=$output_size --assign=$assignment --resume_training=$resume_training \ 90 | --model_type=$model_type --czt_dim=128" 91 | 92 | echo $tr_cmd 93 | CUDA_VISIBLE_DEVICES=$gpu_id TF_CPP_MIN_LOG_LEVEL=$TF_CPP_MIN_LOG_LEVEL $tr_cmd 94 | fi 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /local/prepare_spknet_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2017 Sining Sun 5 | 6 | """Converts data to TFRecords file format with Example protos.""" 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import argparse 12 | import os 13 | import struct 14 | import sys 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | 19 | sys.path.append('./') 20 | from io_funcs.tfrecords_io import get_padded_batch_v2 21 | 22 | tf.logging.set_verbosity(tf.logging.INFO) 23 | 24 | 25 | def make_sequence_example(inputs,labels): 26 | """Returns a SequenceExample for the given inputs and labels(optional). 27 | """ 28 | input_features = [ 29 | tf.train.Feature(float_list=tf.train.FloatList(value=input_)) 30 | for input_ in inputs] 31 | if labels is not None : 32 | label_features = [ 33 | tf.train.Feature(float_list=tf.train.FloatList(value=label)) 34 | for label in labels] 35 | 36 | feature_list = { 37 | 'inputs': tf.train.FeatureList(feature=input_features), 38 | 'labels': tf.train.FeatureList(feature=label_features) 39 | } 40 | else: 41 | feature_list = { 42 | 'inputs': tf.train.FeatureList(feature=input_features) 43 | } 44 | feature_lists = tf.train.FeatureLists(feature_list=feature_list) 45 | return tf.train.SequenceExample(feature_lists=feature_lists) 46 | 47 | 48 | def main(_): 49 | fid = open(FLAGS.input_list, 'r') 50 | lines = fid.readlines() 51 | fid.close() 52 | 53 | fid = open(FLAGS.spk_list, 'r') 54 | spkers = fid.readlines() 55 | fid.close() 56 | num_spkers = len(spkers) 57 | spker_dict={} 58 | i=0 59 | for spker in spkers: 60 | spker_dict[spker.strip('\n')] = i 61 | i = i + 1 62 | 63 | file_list = [line.strip('\n') for line in lines] 64 | _, _, label1, label2, length = get_padded_batch_v2( 65 | file_list, 1, 257, 129, 1, 1, False) 66 | sess = tf.Session() 67 | sess.run([tf.global_variables_initializer(),tf.local_variables_initializer()]) 68 | coord = tf.train.Coordinator() 69 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 70 | for filename in file_list: 71 | if coord.should_stop(): 72 | break 73 | tmpname = filename.split('/')[-1] 74 | file_name = os.path.splitext(tmpname)[0] 75 | spker1 = file_name.split('_')[0][0:3] 76 | spker2 = file_name.split('_')[2][0:3] 77 | target = np.zeros([1, num_spkers]) 78 | target[0, spker_dict[spker1]] = 1 79 | target[0, spker_dict[spker2]] = 1 80 | feats_spk1, feats_spk2 = sess.run([label1, label2]) 81 | feats1 = np.concatenate((feats_spk1[0, :, :], feats_spk2[0, :, :]), axis=1) 82 | feats2 = np.concatenate((feats_spk2[0, :, :], feats_spk1[0, :, :]), axis=1) 83 | name1 = FLAGS.output_dir + '/'+file_name+'_1.tfrecords' 84 | name2 = FLAGS.output_dir + '/'+file_name+'_2.tfrecords' 
85 |         with tf.python_io.TFRecordWriter(name1) as writer:
86 |             ex = make_sequence_example(feats1, target)
87 |             writer.write(ex.SerializeToString())
88 |         with tf.python_io.TFRecordWriter(name2) as writer:
89 |             ex = make_sequence_example(feats2, target)
90 |             writer.write(ex.SerializeToString())
91 |
92 |
93 | if __name__ == '__main__':
94 |     parser = argparse.ArgumentParser()
95 |     parser.add_argument(
96 |         '--input_list',
97 |         type=str,
98 |         default='',
99 |         help='The original tfrecords file list'
100 |     )
101 |     parser.add_argument(
102 |         '--output_dir',
103 |         type=str,
104 |         default='',
105 |         help='The output data dir'
106 |     )
107 |
108 |     parser.add_argument(
109 |         '--spk_list',
110 |         type=str,
111 |         default='',
112 |         help='The speaker id list'
113 |     )
114 |
115 |     FLAGS, unparsed = parser.parse_known_args()
116 |
117 |     if not os.path.exists(FLAGS.output_dir):
118 |         os.makedirs(FLAGS.output_dir)
119 |
120 |     tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
121 |
122 |
--------------------------------------------------------------------------------
/local/gen_tfreords.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 | sys.path.append('.')
4 |
5 | import multiprocessing
6 | from io_funcs.signal_processing import audiowrite, stft, audioread
7 | from local.utils import mkdir_p
8 | import tensorflow as tf
9 | import numpy as np
10 | parser = argparse.ArgumentParser(description='Generate TFRecords files')
11 | parser.add_argument('wavdir',
12 |                     help='The parent dir of mix/s1/s2')
13 | parser.add_argument('namelist',
14 |                     help='wav files list, one wav name per line')
15 |
16 | parser.add_argument('tfdir',
17 |                     help='TFRecords files dir')
18 | parser.add_argument('--gender_list', '-g', default='', type=str,
19 |                     help='The speakers gender list')
20 |
21 | """
22 | This file is used to generate tfrecords for gender-sensitive PIT
23 | speech separation. Every tfrecords file contains:
24 |     inputs: [mix_speech_abs, mix_speech_phase], shape: T*(fft_len*2)
25 |     labels: [spker1_speech_abs, spker2_speech_abs], shape: T*(fft_len*2)
26 |     gender: [spker1_gender, spker2_gender], shape: 1*2
27 | """
28 |
29 | args = parser.parse_args()
30 |
31 | wavdir = args.wavdir
32 | tfdir = args.tfdir
33 | namelist = args.namelist
34 | mkdir_p(tfdir)
35 | if args.gender_list != '':
36 |     apply_gender_info = True
37 |     gender_dict = {}
38 |     fid = open(args.gender_list, 'r')
39 |     lines = fid.readlines()
40 |     fid.close()
41 |     for line in lines:
42 |         spk = line.strip('\n').split(' ')[0]
43 |         gender = line.strip('\n').split(' ')[1]
44 |         if gender.lower() == 'm':
45 |             gender_dict[spk] = 1
46 |         else:
47 |             gender_dict[spk] = 0
48 |
49 | def make_sequence_example(inputs, labels, genders):
50 |     """Returns a SequenceExample for the given inputs, labels and genders.
51 |     Args:
52 |         inputs: A list of input vectors. Each input vector is a list of floats.
53 |         labels: A list of label vectors. Each label vector is a list of floats.
54 |         genders: A 1*2 vector: [0, 1], [1, 0], [1, 1] or [0, 0]
55 |     Returns:
56 |         A tf.train.SequenceExample containing the inputs, labels and genders.
57 | """ 58 | input_features = [ 59 | tf.train.Feature(float_list=tf.train.FloatList(value=input_)) 60 | for input_ in inputs] 61 | label_features = [ 62 | tf.train.Feature(float_list=tf.train.FloatList(value=label)) 63 | for label in labels] 64 | gender_features = [ tf.train.Feature(float_list=tf.train.FloatList(value=genders)) ] 65 | feature_list = { 66 | 'inputs': tf.train.FeatureList(feature=input_features), 67 | 'labels': tf.train.FeatureList(feature=label_features), 68 | 'genders': tf.train.FeatureList(feature=gender_features) 69 | } 70 | feature_lists = tf.train.FeatureLists(feature_list=feature_list) 71 | return tf.train.SequenceExample(feature_lists=feature_lists) 72 | 73 | 74 | 75 | def gen_feats(wav_name): 76 | mix_wav_name = wavdir + '/mix/'+ wav_name 77 | s1_wav_name = wavdir + '/s1/' + wav_name 78 | s2_wav_name = wavdir + '/s2/' + wav_name 79 | 80 | mix_wav = audioread(mix_wav_name, offset=0.0, duration=None, sample_rate=8000) 81 | s1_wav = audioread(s1_wav_name, offset=0.0, duration=None, sample_rate=8000) 82 | s2_wav = audioread(s2_wav_name, offset=0.0, duration=None, sample_rate=8000) 83 | 84 | mix_stft = stft(mix_wav, time_dim=0, size=256, shift=128) 85 | s1_stft = stft(s1_wav, time_dim=0, size=256, shift=128) 86 | s2_stft = stft(s2_wav, time_dim=0, size=256, shift=128) 87 | 88 | s1_gender = gender_dict[wav_name.split('_')[0][0:3]] 89 | s2_gender = gender_dict[wav_name.split('_')[2][0:3]] 90 | part_name = os.path.splitext(wav_name)[0] 91 | tfrecords_name = tfdir + '/' + part_name + '.tfrecords' 92 | print(tfrecords_name) 93 | with tf.python_io.TFRecordWriter(tfrecords_name) as writer: 94 | tf.logging.info( 95 | "Writing utterance %s" %tfrecords_name) 96 | mix_abs = np.abs(mix_stft) 97 | mix_angle = np.angle(mix_stft); 98 | s1_abs = np.abs(s1_stft); 99 | s1_angle = np.angle(s1_stft) 100 | s2_abs = np.abs(s2_stft); 101 | s2_angle = np.angle(s2_stft) 102 | inputs = np.concatenate((mix_abs, mix_angle), axis=1) 103 | labels = np.concatenate((s1_abs*np.cos(mix_angle-s1_angle), s2_abs*np.cos(mix_angle-s2_angle)), axis = 1) 104 | gender = [s1_gender, s2_gender] 105 | ex = make_sequence_example(inputs, labels, gender) 106 | writer.write(ex.SerializeToString()) 107 | 108 | 109 | 110 | pool = multiprocessing.Pool(8) 111 | workers= [] 112 | fid = open(namelist, 'r') 113 | lines = fid.readlines() 114 | fid.close() 115 | for name in lines: 116 | name = name.strip('\n') 117 | workers.append(pool.apply_async(gen_feats, (name))) 118 | #gen_feats(name) 119 | pool.close() 120 | pool.join() 121 | -------------------------------------------------------------------------------- /local/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2017 Sining Sun 5 | 6 | from __future__ import absolute_import 7 | 8 | import sys, os, time 9 | import pprint 10 | import numpy as np 11 | 12 | import tensorflow as tf 13 | import tensorflow.contrib.slim as slim 14 | 15 | pp = pprint.PrettyPrinter() 16 | 17 | def show_all_variables(): 18 | model_vars = tf.trainable_variables() 19 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 20 | sys.stdout.flush() 21 | 22 | 23 | def mkdir_p(path): 24 | """ Creates a path recursively without throwing an error if it already exists 25 | 26 | :param path: path to create 27 | :return: None 28 | """ 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | 32 | 33 | """ 34 | From http://wiki.scipy.org/Cookbook/SegmentAxis 35 | """ 36 | 37 | 38 | def 
38 | def segment_axis(a, length, overlap=0, axis=None, end='cut', endvalue=0):
39 |     """Generate a new array that chops the given array along the given axis into overlapping frames.
40 |
41 |     example:
42 |     >>> segment_axis(np.arange(10), 4, 2)
43 |     array([[0, 1, 2, 3],
44 |            [2, 3, 4, 5],
45 |            [4, 5, 6, 7],
46 |            [6, 7, 8, 9]])
47 |
48 |     arguments:
49 |     a        The array to segment
50 |     length   The length of each frame
51 |     overlap  The number of array elements by which the frames should overlap
52 |     axis     The axis to operate on; if None, act on the flattened array
53 |     end      What to do with the last frame, if the array is not evenly
54 |              divisible into pieces. Options are:
55 |
56 |              'cut'   Simply discard the extra values
57 |              'wrap'  Copy values from the beginning of the array
58 |              'pad'   Pad with a constant value
59 |
60 |     endvalue The value to use for end='pad'
61 |
62 |     The array is not copied unless necessary (either because it is
63 |     unevenly strided and being flattened or because end is set to
64 |     'pad' or 'wrap').
65 |     """
66 |
67 |     if axis is None:
68 |         a = np.ravel(a)  # may copy
69 |         axis = 0
70 |
71 |     l = a.shape[axis]
72 |
73 |     if overlap >= length: raise ValueError(
74 |         "frames cannot overlap by more than 100%")
75 |     if overlap < 0 or length <= 0: raise ValueError(
76 |         "overlap must be nonnegative and length must be positive")
77 |
78 |     if l < length or (l - length) % (length - overlap):
79 |         if l > length:
80 |             roundup = length + (1 + (l - length) // (length - overlap)) * (
81 |                 length - overlap)
82 |             rounddown = length + ((l - length) // (length - overlap)) * (
83 |                 length - overlap)
84 |         else:
85 |             roundup = length
86 |             rounddown = 0
87 |         assert rounddown < l < roundup
88 |         assert roundup == rounddown + (length - overlap) or (
89 |             roundup == length and rounddown == 0)
90 |         a = a.swapaxes(-1, axis)
91 |
92 |         if end == 'cut':
93 |             a = a[..., :rounddown]
94 |         elif end in ['pad', 'wrap']:  # copying will be necessary
95 |             s = list(a.shape)
96 |             s[-1] = roundup
97 |             b = np.empty(s, dtype=a.dtype)
98 |             b[..., :l] = a
99 |             if end == 'pad':
100 |                 b[..., l:] = endvalue
101 |             elif end == 'wrap':
102 |                 b[..., l:] = a[..., :roundup - l]
103 |             a = b
104 |
105 |         a = a.swapaxes(-1, axis)
106 |
107 |     l = a.shape[axis]
108 |     if l == 0: raise ValueError(
109 |         "Not enough data points to segment array in 'cut' mode; "
110 |         "try 'pad' or 'wrap'")
111 |     assert l >= length
112 |     assert (l - length) % (length - overlap) == 0
113 |     n = 1 + (l - length) // (length - overlap)
114 |     s = a.strides[axis]
115 |     newshape = a.shape[:axis] + (n, length) + a.shape[axis + 1:]
116 |     newstrides = a.strides[:axis] + ((length - overlap) * s, s) + a.strides[
117 |         axis + 1:]
118 |
119 |     if not a.flags.contiguous:
120 |         a = a.copy()
121 |         newstrides = a.strides[:axis] + ((length - overlap) * s, s) + a.strides[
122 |             axis + 1:]
123 |         return np.ndarray.__new__(np.ndarray, strides=newstrides,
124 |                                   shape=newshape, buffer=a, dtype=a.dtype)
125 |
126 |     try:
127 |         return np.ndarray.__new__(np.ndarray, strides=newstrides,
128 |                                   shape=newshape, buffer=a, dtype=a.dtype)
129 |     except (TypeError, ValueError):  # was "except TypeError or ValueError", which only caught TypeError
130 |         warnings.warn("Problem with ndarray creation forces copy.")
131 |         a = a.copy()
132 |         # Shape doesn't change but strides does
133 |         newstrides = a.strides[:axis] + ((length - overlap) * s, s) + a.strides[
134 |             axis + 1:]
135 |         return np.ndarray.__new__(np.ndarray, strides=newstrides,
136 |                                   shape=newshape, buffer=a, dtype=a.dtype)
137 |
138 |
139 |
--------------------------------------------------------------------------------
/io_funcs/tfrecords_io.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2017 Sining Sun (NPU)
5 |
6 |
7 | """Utility functions for working with tf.train.SequenceExamples."""
8 |
9 | import tensorflow as tf
10 |
11 |
12 |
13 | def get_padded_batch(file_list, batch_size, input_size, output_size,
14 |                      num_enqueuing_threads=4, num_epochs=1, shuffle=True):
15 |     """Reads batches of SequenceExamples from TFRecords and pads them.
16 |     Can deal with variable length SequenceExamples by padding each batch to the
17 |     length of the longest sequence with zeros.
18 |     Args:
19 |         file_list: A list of paths to TFRecord files containing SequenceExamples.
20 |         batch_size: The number of SequenceExamples to include in each batch.
21 |         input_size: The size of each input vector. The returned batch of inputs
22 |             will have a shape [batch_size, num_steps, input_size].
23 |         output_size: The size of each label vector.
24 |         num_enqueuing_threads: The number of threads to use for enqueuing
25 |             SequenceExamples.
26 |     Returns:
27 |         inputs: A tensor of shape [batch_size, num_steps, input_size] of float32s.
28 |         labels: A tensor of shape [batch_size, num_steps, output_size] of float32s.
29 |         genders: A tensor of shape [batch_size, 1, 2] of float32s.
30 |         lengths: A tensor of shape [batch_size] of int32s. The lengths of each
31 |             SequenceExample before padding.
32 |     """
33 |     file_queue = tf.train.string_input_producer(
34 |         file_list, num_epochs=num_epochs, shuffle=shuffle)
35 |     reader = tf.TFRecordReader()
36 |     _, serialized_example = reader.read(file_queue)
37 |
38 |     sequence_features = {
39 |         'inputs': tf.FixedLenSequenceFeature(shape=[input_size],
40 |                                              dtype=tf.float32),
41 |         'labels': tf.FixedLenSequenceFeature(shape=[output_size],
42 |                                              dtype=tf.float32),
43 |         'genders': tf.FixedLenSequenceFeature(shape=[2], dtype=tf.float32)}
44 |
45 |     _, sequence = tf.parse_single_sequence_example(
46 |         serialized_example, sequence_features=sequence_features)
47 |
48 |     length = tf.shape(sequence['inputs'])[0]
49 |
50 |     capacity = 1000 + (num_enqueuing_threads + 1) * batch_size
51 |     queue = tf.PaddingFIFOQueue(
52 |         capacity=capacity,
53 |         dtypes=[tf.float32, tf.float32, tf.float32, tf.int32],
54 |         shapes=[(None, input_size), (None, output_size), (1, 2), ()])
55 |
56 |     enqueue_ops = [queue.enqueue([sequence['inputs'],
57 |                                   sequence['labels'],
58 |                                   sequence['genders'],
59 |                                   length])] * num_enqueuing_threads
60 |
61 |     tf.train.add_queue_runner(tf.train.QueueRunner(queue, enqueue_ops))
62 |     return queue.dequeue_many(batch_size)
63 |
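# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original file). The
# tfrecords path, batch size and feature sizes below are assumptions: with
# local/gen_tfreords.py's 256-point STFT, the stored inputs and labels are
# both 258-dimensional (129 magnitude + 129 phase bins, resp. two 129-dim
# phase-sensitive targets). The driver pattern mirrors tfrecords_io_test.py:
#
#   inputs, labels, genders, lengths = get_padded_batch(
#       ['data/tfrecords/tr_psm/utt0001.tfrecords'],  # hypothetical file list
#       batch_size=4, input_size=258, output_size=258,
#       num_enqueuing_threads=4, num_epochs=1)
#   with tf.Session() as sess:
#       sess.run([tf.global_variables_initializer(),
#                 tf.local_variables_initializer()])
#       coord = tf.train.Coordinator()
#       threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#       try:
#           while not coord.should_stop():
#               batch = sess.run([inputs, labels, genders, lengths])
#       except tf.errors.OutOfRangeError:
#           pass  # epoch limit reached
#       finally:
#           coord.request_stop()
#       coord.join(threads)
# ---------------------------------------------------------------------------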
79 | """ 80 | file_queue = tf.train.string_input_producer( 81 | file_list, num_epochs=num_epochs, shuffle=shuffle) 82 | reader = tf.TFRecordReader() 83 | _, serialized_example = reader.read(file_queue) 84 | 85 | 86 | sequence_features = { 87 | 'inputs': tf.FixedLenSequenceFeature(shape=[input_size],dtype=tf.float32), 88 | 'inputs_cmvn': tf.FixedLenSequenceFeature(shape=[input_size],dtype=tf.float32), 89 | 'labels1': tf.FixedLenSequenceFeature(shape=[output_size],dtype=tf.float32), 90 | 'labels2': tf.FixedLenSequenceFeature(shape=[output_size],dtype=tf.float32), 91 | } 92 | 93 | _, sequence = tf.parse_single_sequence_example( 94 | serialized_example, sequence_features=sequence_features) 95 | 96 | length = tf.shape(sequence['inputs'])[0] 97 | 98 | capacity = 1000 + (num_enqueuing_threads + 1) * batch_size 99 | queue = tf.PaddingFIFOQueue( 100 | capacity=capacity, 101 | dtypes=[tf.float32, tf.float32,tf.float32, tf.float32, tf.int32], 102 | shapes=[(None, input_size),(None, input_size),(None, output_size), (None, output_size), ()]) 103 | 104 | enqueue_ops = [queue.enqueue([sequence['inputs'], 105 | sequence['inputs_cmvn'], 106 | sequence['labels1'], 107 | sequence['labels2'], 108 | length])] * num_enqueuing_threads 109 | 110 | tf.train.add_queue_runner(tf.train.QueueRunner(queue, enqueue_ops)) 111 | return queue.dequeue_many(batch_size) 112 | -------------------------------------------------------------------------------- /matlab/enframe.m: -------------------------------------------------------------------------------- 1 | function [f,t,w]=enframe(x,win,hop,m,fs) 2 | %ENFRAME split signal up into (overlapping) frames: one per row. [F,T]=(X,WIN,HOP) 3 | % 4 | % Usage: (1) f=enframe(x,n) % split into frames of length n 5 | % (2) f=enframe(x,hamming(n,'periodic'),n/4) % use a 75% overlapped Hamming window of length n 6 | % (3) calculate spectrogram in units of power per Hz 7 | % 8 | % W=hamming(NW); % analysis window (NW = fft length) 9 | % P=enframe(S,W,HOP,'sdp',FS); % computer first half of PSD (HOP = frame increment in samples) 10 | % 11 | % (3) frequency domain frame-based processing: 12 | % 13 | % S=...; % input signal 14 | % OV=2; % overlap factor of 2 (4 is also often used) 15 | % NW=160; % DFT window length 16 | % W=sqrt(hamming(NW,'periodic')); % omit sqrt if OV=4 17 | % [F,T,WS]=enframe(S,W,1/OV,'fa'); % do STFT: one row per time frame, +ve frequencies only 18 | % ... process frames ... 
19 | %             X=overlapadd(irfft(F,NW,2),WS,HOP); % reconstitute the time waveform with scaled window (omit "X=" to plot waveform)
20 | %
21 | %  Inputs:   x    input signal
22 | %          win    window or window length in samples
23 | %          hop    frame increment or hop in samples or fraction of window [window length]
24 | %            m    mode input:
25 | %                  'z'  zero pad to fill up final frame
26 | %                  'r'  reflect last few samples for final frame
27 | %                  'A'  calculate the t output as the centre of mass
28 | %                  'E'  calculate the t output as the centre of energy
29 | %                  'f'  perform a 1-sided dft on each frame (like rfft)
30 | %                  'F'  perform a 2-sided dft on each frame using fft
31 | %                  'p'  calculate the 1-sided power/energy spectrum of each frame
32 | %                  'P'  calculate the 2-sided power/energy spectrum of each frame
33 | %                  'a'  scale window to give unity gain with overlap-add
34 | %                  's'  scale window so that power is preserved: sum(mean(enframe(x,win,hop,'sp'),1))=mean(x.^2)
35 | %                  'S'  scale window so that total energy is preserved: sum(sum(enframe(x,win,hop,'Sp')))=sum(x.^2)
36 | %                  'd'  make options 's' and 'S' give power/energy per Hz: sum(mean(enframe(x,win,hop,'sp'),1))*fs/length(win)=mean(x.^2)
37 | %           fs    sample frequency (only needed for 'd' option) [1]
38 | %
39 | % Outputs:   f    enframed data - one frame per row
40 | %            t    fractional time in samples at the centre of each frame
41 | %                 with the first sample being 1.
42 | %            w    window function used
43 | %
44 | % By default, the number of frames will be rounded down to the nearest
45 | % integer and the last few samples of x() will be ignored unless its length
46 | % is lw more than a multiple of hop. If the 'z' or 'r' options are given,
47 | % the number of frames will instead be rounded up and no samples will be ignored.
48 | %
49 |
50 | % Bugs/Suggestions:
51 | %  (1) Possible additional mode options:
52 | %        'u'  modify window for first and last few frames to ensure WOLA
53 | %        'a'  normalize window to give a mean of unity after overlaps
54 | %        'e'  normalize window to give an energy of unity after overlaps
55 | %        'wm' use Hamming window
56 | %        'wn' use Hanning window
57 | %        'x'  include all frames that include any of the x samples
58 |
59 | %      Copyright (C) Mike Brookes 1997-2014
60 | %      Version: $Id: enframe.m 9529 2017-02-25 19:08:56Z dmb $
61 | %
62 | %   VOICEBOX is a MATLAB toolbox for speech processing.
63 | %   Home page: http://www.ee.ic.ac.uk/hp/staff/dmb/voicebox/voicebox.html
64 | %
65 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
66 | %   This program is free software; you can redistribute it and/or modify
67 | %   it under the terms of the GNU General Public License as published by
68 | %   the Free Software Foundation; either version 2 of the License, or
69 | %   (at your option) any later version.
70 | %
71 | %   This program is distributed in the hope that it will be useful,
72 | %   but WITHOUT ANY WARRANTY; without even the implied warranty of
73 | %   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
74 | %   GNU General Public License for more details.
75 | %
76 | %   You can obtain a copy of the GNU General Public License from
77 | %   http://www.gnu.org/copyleft/gpl.html or by writing to
78 | %   Free Software Foundation, Inc.,675 Mass Ave, Cambridge, MA 02139, USA.
79 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 80 | 81 | nx=length(x(:)); 82 | if nargin<2 || isempty(win) 83 | win=nx; 84 | end 85 | if nargin<4 || isempty(m) 86 | m=''; 87 | end 88 | nwin=length(win); 89 | if nwin == 1 90 | lw = win; 91 | w = ones(1,lw); 92 | else 93 | lw = nwin; 94 | w = win(:).'; 95 | end 96 | if (nargin < 3) || isempty(hop) 97 | hop = lw; % if no hop given, make non-overlapping 98 | elseif hop<1 99 | hop=lw*hop; 100 | end 101 | if any(m=='a') 102 | w=w*sqrt(hop/sum(w.^2)); % scale to give unity gain for overlap-add 103 | elseif any(m=='s') 104 | w=w/sqrt(w*w'*lw); 105 | elseif any(m=='S') 106 | w=w/sqrt(w*w'*lw/hop); 107 | end 108 | if any(m=='d') % scale to give power/energy densities 109 | if nargin<5 || isempty(fs) 110 | w=w*sqrt(lw); 111 | else 112 | w=w*sqrt(lw/fs); 113 | end 114 | end 115 | nli=nx-lw+hop; 116 | nf = max(fix(nli/hop),0); % number of full frames 117 | na=nli-hop*nf+(nf==0)*(lw-hop); % number of samples left over 118 | fx=nargin>3 && (any(m=='z') || any(m=='r')) && na>0; % need an extra row 119 | f=zeros(nf+fx,lw); 120 | indf= hop*(0:(nf-1)).'; 121 | inds = (1:lw); 122 | if fx 123 | f(1:nf,:) = x(indf(:,ones(1,lw))+inds(ones(nf,1),:)); 124 | if any(m=='r') 125 | ix=1+mod(nf*hop:nf*hop+lw-1,2*nx); 126 | f(nf+1,:)=x(ix+(ix>nx).*(2*nx+1-2*ix)); 127 | else 128 | f(nf+1,1:nx-nf*hop)=x(1+nf*hop:nx); 129 | end 130 | nf=size(f,1); 131 | else 132 | f(:) = x(indf(:,ones(1,lw))+inds(ones(nf,1),:)); 133 | end 134 | if (nwin > 1) % if we have a non-unity window 135 | f = f .* w(ones(nf,1),:); 136 | end 137 | if any(lower(m)=='p') % 'pP' = calculate the power spectrum 138 | f=fft(f,[],2); 139 | f=real(f.*conj(f)); 140 | if any(m=='p') 141 | imx=fix((lw+1)/2); % highest replicated frequency 142 | f(:,2:imx)=f(:,2:imx)+f(:,lw:-1:lw-imx+2); 143 | f=f(:,1:fix(lw/2)+1); 144 | end 145 | elseif any(lower(m)=='f') % 'fF' = take the DFT 146 | f=fft(f,[],2); 147 | if any(m=='f') 148 | f=f(:,1:fix(lw/2)+1); 149 | end 150 | end 151 | if nargout>1 152 | if any(m=='E') 153 | t0=sum((1:lw).*w.^2)/sum(w.^2); 154 | elseif any(m=='A') 155 | t0=sum((1:lw).*w)/sum(w); 156 | else 157 | t0=(1+lw)/2; 158 | end 159 | t=t0+hop*(0:(nf-1)).'; 160 | end 161 | 162 | 163 | -------------------------------------------------------------------------------- /local/convert_to_records.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2017 Sining Sun 5 | 6 | """Converts data to TFRecords file format with Example protos.""" 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import argparse 12 | import os 13 | import struct 14 | import sys 15 | import multiprocessing 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | sys.path.append('./') 21 | from io_funcs.tfrecords_io import make_sequence_example_two_labels 22 | 23 | tf.logging.set_verbosity(tf.logging.INFO) 24 | 25 | def convert_cmvn_to_numpy(inputs_cmvn, labels_cmvn): 26 | if FLAGS.labels_cmvn !='': 27 | """Convert global binary ark cmvn to numpy format.""" 28 | tf.logging.info("Convert %s and %s to numpy format" % ( 29 | inputs_cmvn, labels_cmvn)) 30 | inputs_filename = FLAGS.inputs_cmvn 31 | labels_filename = FLAGS.labels_cmvn 32 | 33 | inputs = read_binary_file(inputs_filename, 0) 34 | labels = read_binary_file(labels_filename, 0) 35 | 36 | inputs_frame = inputs[0][-1] 37 | labels_frame = labels[0][-1] 
38 | 39 | #assert inputs_frame == labels_frame 40 | 41 | cmvn_inputs = np.hsplit(inputs, [inputs.shape[1]-1])[0] 42 | cmvn_labels = np.hsplit(labels, [labels.shape[1]-1])[0] 43 | 44 | mean_inputs = cmvn_inputs[0] / inputs_frame 45 | stddev_inputs = np.sqrt(cmvn_inputs[1] / inputs_frame - mean_inputs ** 2) 46 | mean_labels = cmvn_labels[0] / labels_frame 47 | stddev_labels = np.sqrt(cmvn_labels[1] / labels_frame - mean_labels ** 2) 48 | 49 | cmvn_name = os.path.join(FLAGS.output_dir, "train_cmvn.npz") 50 | np.savez(cmvn_name, 51 | mean_inputs=mean_inputs, 52 | stddev_inputs=stddev_inputs, 53 | mean_labels=mean_labels, 54 | stddev_labels=stddev_labels) 55 | 56 | tf.logging.info("Write to %s" % cmvn_name) 57 | else : 58 | """Convert global binary ark cmvn to numpy format.""" 59 | tf.logging.info("Convert %s to numpy format" % ( 60 | inputs_cmvn)) 61 | inputs_filename = FLAGS.inputs_cmvn 62 | 63 | inputs = read_binary_file(inputs_filename, 0) 64 | 65 | inputs_frame = inputs[0][-1] 66 | 67 | 68 | cmvn_inputs = np.hsplit(inputs, [inputs.shape[1]-1])[0] 69 | 70 | mean_inputs = cmvn_inputs[0] / inputs_frame 71 | stddev_inputs = np.sqrt(cmvn_inputs[1] / inputs_frame - mean_inputs ** 2) 72 | 73 | cmvn_name = os.path.join(FLAGS.output_dir, "train_cmvn.npz") 74 | np.savez(cmvn_name, 75 | mean_inputs=mean_inputs, 76 | stddev_inputs=stddev_inputs) 77 | 78 | tf.logging.info("Write to %s" % cmvn_name) 79 | 80 | 81 | def read_binary_file(filename, offset=0): 82 | """Read data from matlab binary file (row, col and matrix). 83 | 84 | Returns: 85 | A numpy matrix containing data of the given binary file. 86 | """ 87 | read_buffer = open(filename, 'rb') 88 | read_buffer.seek(int(offset), 0) 89 | header = struct.unpack(' 38 | status = mkdir([output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/s2/']); %#ok 39 | status = mkdir([output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/mix/']); %#ok 40 | % status = mkdir([output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/s1/']); %#ok 41 | % status = mkdir([output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/s2/']); %#ok 42 | % status = mkdir([output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/mix/']); 43 | 44 | TaskFile = ['mix_2_spk_' data_type{i_type} '.txt']; 45 | fid=fopen(TaskFile,'r'); 46 | C=textscan(fid,'%s %f %s %f'); 47 | 48 | Source1File = ['mix_2_spk_' min_max{i_mm} '_' data_type{i_type} '_1']; 49 | Source2File = ['mix_2_spk_' min_max{i_mm} '_' data_type{i_type} '_2']; 50 | MixFile = ['mix_2_spk_' min_max{i_mm} '_' data_type{i_type} '_mix']; 51 | fid_s1 = fopen(Source1File,'w'); 52 | fid_s2 = fopen(Source2File,'w'); 53 | fid_m = fopen(MixFile,'w'); 54 | 55 | num_files = length(C{1}); 56 | fs8k=8000; 57 | 58 | % scaling_16k = zeros(num_files,2); 59 | scaling_8k = zeros(num_files,2); 60 | % scaling16bit_16k = zeros(num_files,1); 61 | scaling16bit_8k = zeros(num_files,1); 62 | fprintf(1,'%s\n',[min_max{i_mm} '_' data_type{i_type}]); 63 | for i = 1:num_files 64 | [inwav1_dir,invwav1_name,inwav1_ext] = fileparts(C{1}{i}); 65 | [inwav2_dir,invwav2_name,inwav2_ext] = fileparts(C{3}{i}); 66 | fprintf(fid_s1,'%s\n',C{1}{i}); 67 | fprintf(fid_s2,'%s\n',C{3}{i}); 68 | inwav1_snr = C{2}(i); 69 | inwav2_snr = C{4}(i); 70 | mix_name = [invwav1_name,'_',num2str(inwav1_snr),'_',invwav2_name,'_',num2str(inwav2_snr)]; 71 | fprintf(fid_m,'%s\n',mix_name); 72 | 73 | % get input wavs 74 | [s1, fs] = wavread( C{1}{i}); 75 | s2 = wavread(C{3}{i}); 76 | 77 | % resample, normalize 8 kHz file, save scaling factor 78 | s1_8k=resample(s1,fs8k,fs); 79 | 
[s1_8k,lev1]=activlev(s1_8k,fs8k,'n'); % y_norm = y /sqrt(lev); 80 | s2_8k=resample(s2,fs8k,fs); 81 | [s2_8k,lev2]=activlev(s2_8k,fs8k,'n'); 82 | 83 | weight_1=10^(inwav1_snr/20); 84 | weight_2=10^(inwav2_snr/20); 85 | 86 | s1_8k = weight_1 * s1_8k; 87 | s2_8k = weight_2 * s2_8k; 88 | 89 | switch min_max{i_mm} 90 | case 'max' 91 | mix_8k_length = max(length(s1_8k),length(s2_8k)); 92 | s1_8k = cat(1,s1_8k,zeros(mix_8k_length - length(s1_8k),1)); 93 | s2_8k = cat(1,s2_8k,zeros(mix_8k_length - length(s2_8k),1)); 94 | case 'min' 95 | mix_8k_length = min(length(s1_8k),length(s2_8k)); 96 | s1_8k = s1_8k(1:mix_8k_length); 97 | s2_8k = s2_8k(1:mix_8k_length); 98 | end 99 | mix_8k = s1_8k + s2_8k; 100 | 101 | max_amp_8k = max(cat(1,abs(mix_8k(:)),abs(s1_8k(:)),abs(s2_8k(:)))); 102 | mix_scaling_8k = 1/max_amp_8k*0.9; 103 | s1_8k = mix_scaling_8k * s1_8k; 104 | s2_8k = mix_scaling_8k * s2_8k; 105 | mix_8k = mix_scaling_8k * mix_8k; 106 | 107 | % apply same gain to 16 kHz file 108 | %s1_8k=resample(s1,fs8k,fs); 109 | %[s1_8k,lev1]=activlev(s1_8k,fs8k,'n'); % y_norm = y /sqrt(lev); 110 | %s2_8k=resample(s2,fs8k,fs); 111 | %[s2_8k,lev2]=activlev(s2_8k,fs8k,'n'); 112 | 113 | %s1_16k = weight_1 * s1 / sqrt(lev1); 114 | %s2_16k = weight_2 * s2 / sqrt(lev2); 115 | 116 | %switch min_max{i_mm} 117 | % case 'max' 118 | % mix_16k_length = max(length(s1_16k),length(s2_16k)); 119 | % s1_16k = cat(1,s1_16k,zeros(mix_16k_length - length(s1_16k),1)); 120 | % s2_16k = cat(1,s2_16k,zeros(mix_16k_length - length(s2_16k),1)); 121 | % case 'min' 122 | % mix_16k_length = min(length(s1_16k),length(s2_16k)); 123 | % s1_16k = s1_16k(1:mix_16k_length); 124 | % s2_16k = s2_16k(1:mix_16k_length); 125 | %end 126 | %mix_16k = s1_16k + s2_16k; 127 | % 128 | %max_amp_16k = max(cat(1,abs(mix_16k(:)),abs(s1_16k(:)),abs(s2_16k(:)))); 129 | %mix_scaling_16k = 1/max_amp_16k*0.9; 130 | %s1_16k = mix_scaling_16k * s1_16k; 131 | %s2_16k = mix_scaling_16k * s2_16k; 132 | %mix_16k = mix_scaling_16k * mix_16k; 133 | 134 | % save 8 kHz and 16 kHz mixtures, as well as 135 | % necessary scaling factors 136 | 137 | %scaling_16k(i,1) = weight_1 * mix_scaling_16k/ sqrt(lev1); 138 | %scaling_16k(i,2) = weight_2 * mix_scaling_16k/ sqrt(lev2); 139 | scaling_8k(i,1) = weight_1 * mix_scaling_8k/ sqrt(lev1); 140 | scaling_8k(i,2) = weight_2 * mix_scaling_8k/ sqrt(lev2); 141 | 142 | % scaling16bit_16k(i) = mix_scaling_16k; 143 | scaling16bit_8k(i) = mix_scaling_8k; 144 | 145 | wavwrite(s1_8k,fs8k,[output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/s1/' mix_name '.wav']); 146 | % wavwrite(s1_16k,fs,[output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/s1/' mix_name '.wav']); 147 | wavwrite(s2_8k,fs8k,[output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/s2/' mix_name '.wav']); 148 | % wavwrite(s2_16k,fs,[output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/s2/' mix_name '.wav']); 149 | wavwrite(mix_8k,fs8k,[output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/mix/' mix_name '.wav']); 150 | % wavwrite(mix_16k,fs,[output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/mix/' mix_name '.wav']); 151 | 152 | if mod(i,10)==0 153 | fprintf(1,'.'); 154 | if mod(i,200)==0 155 | fprintf(1,'\n'); 156 | end 157 | end 158 | 159 | end 160 | save([output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/scaling.mat'],'scaling_8k','scaling16bit_8k'); 161 | % save([output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/scaling.mat'],'scaling_16k','scaling16bit_16k'); 162 | 163 | fclose(fid); 164 | fclose(fid_s1); 165 | fclose(fid_s2); 166 | fclose(fid_m); 167 
| end 168 | end 169 | -------------------------------------------------------------------------------- /model/spknet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | """ 6 | Build the LSTM(BLSTM) neural networks for speaker recognition. 7 | 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import sys 15 | import time 16 | 17 | import tensorflow as tf 18 | from tensorflow.contrib.rnn.python.ops import rnn 19 | import numpy as np 20 | 21 | class LSTM(object): 22 | """Build a BLSTM or LSTM model for speaker recognition. 23 | If you use this module to train your model, make sure that 24 | you prepare the data in the right format! 25 | 26 | Attributes: 27 | config: Used to configure our model 28 | config.input_size: feature (input) size 29 | config.output_size: the final (output) layer size 30 | config.rnn_size: the number of RNN cells 31 | config.batch_size: the batch size for training 32 | config.rnn_num_layers: the number of RNN layers 33 | config.keep_prob: the dropout keep probability 34 | inputs: [A,B], a T*(2D) matrix 35 | labels: "two hot" target label 36 | lengths: the length of every utterance 37 | infer: bool, False for training, True for testing 38 | """ 39 | 40 | def __init__(self, config, inputs, labels, lengths, infer=False): 41 | self._inputs = inputs 42 | self._labels = labels 43 | self._lengths = lengths 44 | self._model_type = config.model_type 45 | if infer: # at inference time we prefer to run one utterance at a time 46 | config.batch_size = 1 47 | outputs = self._inputs 48 | ## The first layer is a feed-forward layer that transforms 49 | ## the input to the right size before feeding it into the RNN 50 | 51 | with tf.variable_scope('forward1'): 52 | outputs = tf.reshape(outputs, [-1, config.input_size]) 53 | outputs = tf.layers.dense(outputs, units=config.rnn_size, 54 | activation=tf.nn.tanh, 55 | reuse=tf.get_variable_scope().reuse) 56 | outputs = tf.reshape( 57 | outputs, [config.batch_size,-1, config.rnn_size]) 58 | 59 | ## Configure the LSTM or BLSTM model 60 | ## For BLSTM, we use BasicLSTMCell. For LSTM, we use LSTMCell. 61 | ## You can change them and test the performance...
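# Shape note (illustrative): at this point `outputs` is [batch_size, time, rnn_size]. stack_bidirectional_dynamic_rnn concatenates the forward and backward states, so the BLSTM branch below yields [batch_size, time, 2*rnn_size], which is why 'forward2' uses in_size = 2*config.rnn_size for BLSTM.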
62 | 63 | if config.model_type.lower() == 'blstm': 64 | with tf.variable_scope('blstm'): 65 | def blstm_cell(): 66 | cell = tf.contrib.rnn.BasicLSTMCell(config.rnn_size) 67 | if not infer and config.keep_prob < 1.0: 68 | cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=config.keep_prob) 69 | return cell 70 | # Build a fresh cell for every layer and direction; reusing one cell object breaks variable scoping. 71 | lstm_fw_cell = _unpack_cell(tf.contrib.rnn.MultiRNNCell([blstm_cell() for _ in range(config.rnn_num_layers)])) 72 | lstm_bw_cell = _unpack_cell(tf.contrib.rnn.MultiRNNCell([blstm_cell() for _ in range(config.rnn_num_layers)])) 73 | result = rnn.stack_bidirectional_dynamic_rnn( 74 | cells_fw = lstm_fw_cell, 75 | cells_bw = lstm_bw_cell, 76 | inputs=outputs, 77 | dtype=tf.float32, 78 | sequence_length=self._lengths) 79 | outputs, fw_final_states, bw_final_states = result 80 | if config.model_type.lower() == 'lstm': 81 | with tf.variable_scope('lstm'): 82 | def lstm_cell(): 83 | return tf.contrib.rnn.LSTMCell( 84 | config.rnn_size, forget_bias=1.0, use_peepholes=True, 85 | initializer=tf.contrib.layers.xavier_initializer(), 86 | state_is_tuple=True, activation=tf.tanh) 87 | attn_cell = lstm_cell 88 | if not infer and config.keep_prob < 1.0: 89 | def attn_cell(): 90 | return tf.contrib.rnn.DropoutWrapper(lstm_cell(), output_keep_prob=config.keep_prob) 91 | cell = tf.contrib.rnn.MultiRNNCell( 92 | [attn_cell() for _ in range(config.rnn_num_layers)], 93 | state_is_tuple=True) 94 | self._initial_state = cell.zero_state(config.batch_size, tf.float32) 95 | state = self.initial_state 96 | outputs, state = tf.nn.dynamic_rnn( 97 | cell, outputs, 98 | dtype=tf.float32, 99 | sequence_length=self._lengths, 100 | initial_state=self.initial_state) 101 | self._final_state = state 102 | 103 | ## Feed forward layer. Transform the RNN output to the right output size 104 | 105 | with tf.variable_scope('forward2'): 106 | if config.embedding_option == 0: # no embedding, frame by frame 107 | if self._model_type.lower() == 'blstm': 108 | outputs = tf.reshape(outputs, [-1, 2*config.rnn_size]) 109 | in_size=2*config.rnn_size 110 | else: 111 | outputs = tf.reshape(outputs, [-1, config.rnn_size]) 112 | in_size = config.rnn_size 113 | 114 | else: 115 | if self._model_type.lower() == 'blstm': 116 | outputs = tf.reshape(outputs, [config.batch_size,-1, 2*config.rnn_size]) 117 | in_size=2*config.rnn_size 118 | else: 119 | outputs = tf.reshape(outputs, [config.batch_size,-1, config.rnn_size]) 120 | in_size = config.rnn_size 121 | 122 | if config.embedding_option == 1: # last-frame embedding 123 | #http://sqrtf.com/fetch-rnn-encoder-last-output-using-tf-gather_nd/ 124 | ind = tf.subtract(self._lengths, tf.constant(1)) 125 | batch_range = tf.range(config.batch_size) 126 | indices = tf.stack([batch_range, ind], axis=1) 127 | 128 | outputs = tf.gather_nd(outputs, indices) 129 | self._labels = tf.reduce_mean(self._labels, 1) 130 | elif config.embedding_option == 2: # mean pooling 131 | outputs = tf.reduce_mean(outputs,1) 132 | self._labels = tf.reduce_mean(self._labels, 1) 133 | out_size = config.output_size 134 | weights1 = tf.get_variable('weights1', [in_size, out_size], 135 | initializer=tf.random_normal_initializer(stddev=0.01)) 136 | biases1 = tf.get_variable('biases1', [out_size], 137 | initializer=tf.constant_initializer(0.0)) 138 | outputs = tf.matmul(outputs, weights1) + biases1 139 | if config.embedding_option == 0: 140 | outputs = tf.reshape(outputs, [config.batch_size, -1, out_size]) 141 | self._outputs = tf.nn.sigmoid(outputs) 142 | # Ability to save the model 143 | self.saver =
tf.train.Saver(tf.trainable_variables(), max_to_keep=30) 144 | 145 | if infer: return 146 | 147 | 148 | # Compute loss (cross entropy) 149 | self._loss = tf.losses.sigmoid_cross_entropy(self._labels, outputs) 150 | if tf.get_variable_scope().reuse: return 151 | 152 | self._lr = tf.Variable(0.0, trainable=False) 153 | tvars = tf.trainable_variables() 154 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars), 155 | config.max_grad_norm) 156 | optimizer = tf.train.AdamOptimizer(self.lr) 157 | #optimizer = tf.train.GradientDescentOptimizer(self.lr) 158 | self._train_op = optimizer.apply_gradients(zip(grads, tvars)) 159 | 160 | self._new_lr = tf.placeholder( 161 | tf.float32, shape=[], name='new_learning_rate') 162 | self._lr_update = tf.assign(self._lr, self._new_lr) 163 | 164 | def assign_lr(self, session, lr_value): 165 | session.run(self._lr_update, feed_dict={self._new_lr: lr_value}) 166 | 167 | @property 168 | def inputs(self): 169 | return self._inputs 170 | 171 | @property 172 | def labels(self): 173 | return self._labels 174 | 175 | @property 176 | def initial_state(self): 177 | return self._initial_state 178 | 179 | @property 180 | def final_state(self): 181 | return self._final_state 182 | 183 | @property 184 | def lr(self): 185 | return self._lr 186 | 187 | @property 188 | def loss(self): 189 | return self._loss 190 | 191 | @property 192 | def train_op(self): 193 | return self._train_op 194 | @property 195 | def outputs(self): 196 | return self._outputs 197 | 198 | 199 | @staticmethod 200 | def _weight_and_bias(in_size, out_size): 201 | # Create variable named "weights". 202 | weights = tf.get_variable('weights', [in_size, out_size], 203 | initializer=tf.random_normal_initializer(stddev=0.01)) 204 | # Create variable named "biases". 205 | biases = tf.get_variable('biases', [out_size], 206 | initializer=tf.constant_initializer(0.0)) 207 | return weights, biases 208 | def _unpack_cell(cell): 209 | if isinstance(cell,tf.contrib.rnn.MultiRNNCell): 210 | return cell._cells 211 | else: 212 | return [cell] 213 | -------------------------------------------------------------------------------- /io_funcs/signal_processing.py: -------------------------------------------------------------------------------- 1 | import string 2 | import threading, sys 3 | 4 | import librosa 5 | import numpy as np 6 | import scipy 7 | from numpy.fft import rfft, irfft 8 | from scipy import signal 9 | from scipy.io.wavfile import write as wav_write 10 | sys.path.append('.') 11 | from local.utils import segment_axis 12 | 13 | 14 | def _samples_to_stft_frames(samples, size, shift): 15 | """ 16 | Calculates STFT frames from samples in time domain. 17 | :param samples: Number of samples in time domain. 18 | :param size: FFT size. 19 | :param shift: Hop in samples. 20 | :return: Number of STFT frames. 21 | """ 22 | 23 | return np.ceil((float(samples) - size + shift) / shift).astype(np.int) 24 | 25 | 26 | def _stft_frames_to_samples(frames, size, shift): 27 | """ 28 | Calculates samples in time domain from STFT frames 29 | :param frames: Number of STFT frames. 30 | :param size: FFT size. 31 | :param shift: Hop in samples. 32 | :return: Number of samples in time domain. 33 | """ 34 | return frames * shift + size - shift 35 | 36 | 37 | def _biorthogonal_window_loopy(analysis_window, shift): 38 | """ 39 | This version of the synthesis calculation is as close as possible to the 40 | Matlab implementation in terms of variable names. 41 | 42 | The results are equal.
43 | 44 | The implementation follows equation A.92 in 45 | Krueger, A. Modellbasierte Merkmalsverbesserung zur robusten automatischen 46 | Spracherkennung in Gegenwart von Nachhall und Hintergrundstoerungen 47 | Paderborn, Universitaet Paderborn, Diss., 2011 48 | """ 49 | fft_size = len(analysis_window) 50 | assert np.mod(fft_size, shift) == 0 51 | number_of_shifts = len(analysis_window) // shift 52 | 53 | sum_of_squares = np.zeros(shift) 54 | for synthesis_index in range(0, shift): 55 | for sample_index in range(0, number_of_shifts + 1): 56 | analysis_index = synthesis_index + sample_index * shift 57 | 58 | if analysis_index + 1 < fft_size: 59 | sum_of_squares[synthesis_index] \ 60 | += analysis_window[analysis_index] ** 2 61 | 62 | sum_of_squares = np.kron(np.ones(number_of_shifts), sum_of_squares) 63 | synthesis_window = analysis_window / sum_of_squares / fft_size 64 | return synthesis_window 65 | 66 | 67 | def audioread(path, offset=0.0, duration=None, sample_rate=16000): 68 | """ 69 | Reads a wav file, converts it to 32 bit float values and reshapes according 70 | to the number of channels. 71 | This is now a wrapper around librosa with our common defaults. 72 | 73 | :param path: Absolute or relative file path to audio file. 74 | :type: String. 75 | :param offset: Begin of loaded audio. 76 | :type: Scalar in seconds. 77 | :param duration: Duration of loaded audio. 78 | :type: Scalar in seconds. 79 | :param sample_rate: Sample rate of audio 80 | :type: scalar in number of samples per second 81 | :return: The read audio signal as a numpy array. 82 | 83 | .. admonition:: Example 84 | Only path provided: 85 | 86 | >>> path = '/net/speechdb/timit/pcm/train/dr1/fcjf0/sa1.wav' 87 | >>> signal = audioread(path) 88 | 89 | If you load audio examples from a very long recording, you can provide a 90 | start position and a duration in seconds. 91 | 92 | >>> path = '/net/speechdb/timit/pcm/train/dr1/fcjf0/sa1.wav' 93 | >>> signal = audioread(path, offset=0, duration=1) 94 | """ 95 | signal = librosa.load(path, 96 | sr=sample_rate, 97 | mono=False, 98 | offset=offset, 99 | duration=duration) 100 | return signal[0] 101 | 102 | 103 | def stft(time_signal, time_dim=None, size=1024, shift=256, 104 | window=signal.blackman, fading=True, window_length=None): 105 | """ 106 | Calculates the short time Fourier transformation of a multi channel multi 107 | speaker time signal. It is able to add additional zeros for fade-in and 108 | fade out and should yield an STFT signal which allows perfect 109 | reconstruction. 110 | 111 | :param time_signal: multi channel time signal. 112 | :param time_dim: Scalar dim of time. 113 | Default: None means the biggest dimension 114 | :param size: Scalar FFT-size. 115 | :param shift: Scalar FFT-shift. Typically shift is a fraction of size. 116 | :param window: Window function handle. 117 | :param fading: Pads the signal with zeros for better reconstruction. 118 | :param window_length: Sometimes one desires to use a shorter window than 119 | the fft size. In that case, the window is padded with zeros. 120 | The default is to use the fft-size as a window size. 121 | :return: Single channel complex STFT signal 122 | with dimensions frames times size/2+1. 123 | """ 124 | if time_dim is None: 125 | time_dim = np.argmax(time_signal.shape) 126 | 127 | # Pad with zeros to have enough samples for the window function to fade.
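# Worked example (assuming the defaults size=1024, shift=256): the fading pad below prepends and appends size - shift = 768 zeros, and the frame count then follows _samples_to_stft_frames, i.e. ceil((samples - size + shift) / shift).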
128 | if fading: 129 | pad = [(0, 0)] * time_signal.ndim 130 | pad[time_dim] = [size - shift, size - shift] 131 | time_signal = np.pad(time_signal, pad, mode='constant') 132 | 133 | # Pad with trailing zeros, to have an integral number of frames. 134 | frames = _samples_to_stft_frames(time_signal.shape[time_dim], size, shift) 135 | samples = _stft_frames_to_samples(frames, size, shift) 136 | pad = [(0, 0)] * time_signal.ndim 137 | pad[time_dim] = [0, samples - time_signal.shape[time_dim]] 138 | time_signal = np.pad(time_signal, pad, mode='constant') 139 | 140 | if window_length is None: 141 | window = window(size) 142 | else: 143 | window = window(window_length) 144 | window = np.pad(window, (0, size - window_length), mode='constant') 145 | 146 | time_signal_seg = segment_axis(time_signal, size, 147 | size - shift, axis=time_dim) 148 | 149 | letters = string.ascii_lowercase 150 | mapping = letters[:time_signal_seg.ndim] + ',' + letters[time_dim + 1] \ 151 | + '->' + letters[:time_signal_seg.ndim] 152 | 153 | return rfft(np.einsum(mapping, time_signal_seg, window), 154 | axis=time_dim + 1) 155 | 156 | 157 | def istft(stft_signal, size=1024, shift=256, 158 | window=signal.blackman, fading=True, window_length=None): 159 | """ 160 | Calculates the inverse short time Fourier transform to exactly reconstruct 161 | the time signal. 162 | 163 | :param stft_signal: Single channel complex STFT signal 164 | with dimensions frames times size/2+1. 165 | :param size: Scalar FFT-size. 166 | :param shift: Scalar FFT-shift. Typically shift is a fraction of size. 167 | :param window: Window function handle. 168 | :param fading: Removes the additional padding, if done during STFT. 169 | :param window_length: Sometimes one desires to use a shorter window than 170 | the fft size. In that case, the window is padded with zeros. 171 | The default is to use the fft-size as a window size. 172 | :return: Single channel time signal. 173 | 174 | """ 175 | assert stft_signal.shape[1] == size // 2 + 1 176 | 177 | if window_length is None: 178 | window = window(size) 179 | else: 180 | window = window(window_length) 181 | window = np.pad(window, (0, size - window_length), mode='constant') 182 | 183 | window = _biorthogonal_window_loopy(window, shift) 184 | 185 | # Why? Line created by Hai; Lukas does not know why it exists. 186 | window *= size 187 | 188 | time_signal = scipy.zeros(stft_signal.shape[0] * shift + size - shift) 189 | 190 | for j, i in enumerate(range(0, len(time_signal) - size + shift, shift)): 191 | time_signal[i:i + size] += window * np.real(irfft(stft_signal[j])) 192 | 193 | # Compensate fade-in and fade-out 194 | if fading: 195 | time_signal = time_signal[ 196 | size - shift:len(time_signal) - (size - shift)] 197 | 198 | return time_signal 199 | 200 | 201 | def audiowrite(data, path, samplerate=16000, normalize=False, threaded=True): 202 | """ Write the audio data ``data`` to the wav file ``path`` 203 | 204 | The file can be written in a threaded mode. In this case, the writing 205 | process will be started in a separate thread. Consequently, the file will 206 | not be written when this function exits. 207 | 208 | :param data: A numpy array with the audio data 209 | :param path: The wav file the data should be written to 210 | :param samplerate: Samplerate of the audio data 211 | :param normalize: Normalize the audio first so that the values are within 212 | the range of [INTMIN, INTMAX], i.e.
no clipping occurs 213 | :param threaded: If true, the write process will be started as a separate 214 | thread 215 | :return: The number of clipped samples 216 | """ 217 | data = data.copy() 218 | int16_max = np.iinfo(np.int16).max 219 | int16_min = np.iinfo(np.int16).min 220 | 221 | if normalize: 222 | if not data.dtype.kind == 'f': 223 | data = data.astype(np.float) 224 | data /= np.max(np.abs(data)) 225 | 226 | if data.dtype.kind == 'f': 227 | data *= int16_max 228 | 229 | sample_to_clip = np.sum(data > int16_max) 230 | if sample_to_clip > 0: 231 | print('Warning, clipping {} samples'.format(sample_to_clip)) 232 | data = np.clip(data, int16_min, int16_max) 233 | data = data.astype(np.int16) 234 | 235 | if threaded: 236 | threading.Thread(target=wav_write, 237 | args=(path, samplerate, data)).start() 238 | else: 239 | wav_write(path, samplerate, data) 240 | 241 | return sample_to_clip 242 | -------------------------------------------------------------------------------- /matlab/create_wav_3speakers.m: -------------------------------------------------------------------------------- 1 | % create_wav_3_speakers.m 2 | % 3 | % Create 3-speaker mixtures 4 | % 5 | % This script assumes that WSJ0's wv1 sphere files have already 6 | % been converted to wav files, using the original folder structure 7 | % under wsj0/, e.g., 8 | % 11-1.1/wsj0/si_tr_s/01t/01to030v.wv1 is converted to wav and 9 | % stored in YOUR_PATH/wsj0/si_tr_s/01t/01to030v.wav, and 10 | % 11-6.1/wsj0/si_dt_05/050/050a0501.wv1 is converted to wav and 11 | % stored in YOUR_PATH/wsj0/si_dt_05/050/050a0501.wav. 12 | % Relevant data from all disks are assumed merged under YOUR_PATH/wsj0/ 13 | % 14 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 15 | % Copyright (C) 2016 Mitsubishi Electric Research Labs 16 | % (Jonathan Le Roux, John R. 
Hershey, Zhuo Chen) 17 | % Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 18 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 19 | 20 | addpath('./voicebox') 21 | data_type = {'tr','cv','tt'}; 22 | wsj0root = './'; % YOUR_PATH/, the folder containing wsj0/ 23 | output_dir16k='./data/3speakers/wav16k'; 24 | output_dir8k='./data/3speakers/wav8k'; 25 | 26 | min_max = {'min'}; %{'min','max'}; 27 | 28 | for i_mm = 1:length(min_max) 29 | for i_type = 1:length(data_type) 30 | if ~exist([output_dir16k '/' min_max{i_mm} '/' data_type{i_type}],'dir') 31 | mkdir([output_dir16k '/' min_max{i_mm} '/' data_type{i_type}]); 32 | end 33 | if ~exist([output_dir8k '/' min_max{i_mm} '/' data_type{i_type}],'dir') 34 | mkdir([output_dir8k '/' min_max{i_mm} '/' data_type{i_type}]); 35 | end 36 | status = mkdir([output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/s1/']); %#ok 37 | status = mkdir([output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/s2/']); %#ok 38 | status = mkdir([output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/s3/']); %#ok 39 | status = mkdir([output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/mix/']); %#ok 40 | status = mkdir([output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/s1/']); %#ok 41 | status = mkdir([output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/s2/']); %#ok 42 | status = mkdir([output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/s3/']); %#ok 43 | status = mkdir([output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/mix/']); 44 | 45 | TaskFile = ['mix_3_spk_' data_type{i_type} '.txt']; 46 | fid=fopen(TaskFile,'r'); 47 | C=textscan(fid,'%s %f %s %f %s %f'); 48 | 49 | Source1File = ['mix_3_spk_' min_max{i_mm} '_' data_type{i_type} '_1']; 50 | Source2File = ['mix_3_spk_' min_max{i_mm} '_' data_type{i_type} '_2']; 51 | Source3File = ['mix_3_spk_' min_max{i_mm} '_' data_type{i_type} '_3']; 52 | MixFile = ['mix_3_spk_' min_max{i_mm} '_' data_type{i_type} '_mix']; 53 | fid_s1 = fopen(Source1File,'w'); 54 | fid_s2 = fopen(Source2File,'w'); 55 | fid_s3 = fopen(Source3File,'w'); 56 | fid_m = fopen(MixFile,'w'); 57 | 58 | num_files = length(C{1}); 59 | fs8k=8000; 60 | 61 | scaling_16k = zeros(num_files,3); 62 | scaling_8k = zeros(num_files,3); 63 | scaling16bit_16k = zeros(num_files,1); 64 | scaling16bit_8k = zeros(num_files,1); 65 | fprintf(1,'%s\n',[min_max{i_mm} '_' data_type{i_type}]); 66 | for i = 1:num_files 67 | [inwav1_dir,invwav1_name,inwav1_ext] = fileparts(C{1}{i}); 68 | [inwav2_dir,invwav2_name,inwav2_ext] = fileparts(C{3}{i}); 69 | [inwav3_dir,invwav3_name,inwav3_ext] = fileparts(C{5}{i}); 70 | fprintf(fid_s1,'%s\n',C{1}{i});%[inwav1_dir,'/',invwav1_name,inwav1_ext]); 71 | fprintf(fid_s2,'%s\n',C{3}{i});%[inwav2_dir,'/',invwav2_name,inwav2_ext]); 72 | fprintf(fid_s3,'%s\n',C{5}{i});%[inwav3_dir,'/',invwav3_name,inwav3_ext]); 73 | inwav1_snr = C{2}(i); 74 | inwav2_snr = C{4}(i); 75 | inwav3_snr = C{6}(i); 76 | mix_name = [invwav1_name,'_',num2str(inwav1_snr),... 77 | '_',invwav2_name,'_',num2str(inwav2_snr),... 
78 | '_',invwav3_name,'_',num2str(inwav3_snr)]; 79 | fprintf(fid_m,'%s\n',mix_name); 80 | 81 | % get input wavs 82 | [s1, fs] = wavread([wsj0root C{1}{i}]); 83 | s2 = wavread([wsj0root C{3}{i}]); 84 | s3 = wavread([wsj0root C{5}{i}]); 85 | 86 | % resample, normalize 8 kHz file, save scaling factor 87 | s1_8k=resample(s1,fs8k,fs); 88 | [s1_8k,lev1]=activlev(s1_8k,fs8k,'n'); % y_norm = y /sqrt(lev); 89 | s2_8k=resample(s2,fs8k,fs); 90 | [s2_8k,lev2]=activlev(s2_8k,fs8k,'n'); 91 | s3_8k=resample(s3,fs8k,fs); 92 | [s3_8k,lev3]=activlev(s3_8k,fs8k,'n'); 93 | 94 | weight_1=10^(inwav1_snr/20); 95 | weight_2=10^(inwav2_snr/20); 96 | weight_3=10^(inwav3_snr/20); 97 | 98 | s1_8k = weight_1 * s1_8k; 99 | s2_8k = weight_2 * s2_8k; 100 | s3_8k = weight_3 * s3_8k; 101 | 102 | switch min_max{i_mm} 103 | case 'max' 104 | mix_8k_length = max([length(s1_8k),length(s2_8k),length(s3_8k)]); 105 | s1_8k = cat(1,s1_8k,zeros(mix_8k_length - length(s1_8k),1)); 106 | s2_8k = cat(1,s2_8k,zeros(mix_8k_length - length(s2_8k),1)); 107 | s3_8k = cat(1,s3_8k,zeros(mix_8k_length - length(s3_8k),1)); 108 | case 'min' 109 | mix_8k_length = min([length(s1_8k),length(s2_8k),length(s3_8k)]); 110 | s1_8k = s1_8k(1:mix_8k_length); 111 | s2_8k = s2_8k(1:mix_8k_length); 112 | s3_8k = s3_8k(1:mix_8k_length); 113 | end 114 | mix_8k = s1_8k + s2_8k + s3_8k; 115 | 116 | max_amp_8k = max(cat(1,abs(mix_8k(:)),abs(s1_8k(:)),abs(s2_8k(:)),abs(s3_8k(:)))); 117 | mix_scaling_8k = 1/max_amp_8k*0.9; 118 | s1_8k = mix_scaling_8k * s1_8k; 119 | s2_8k = mix_scaling_8k * s2_8k; 120 | s3_8k = mix_scaling_8k * s3_8k; 121 | mix_8k = mix_scaling_8k * mix_8k; 122 | 123 | % apply same gain to 16 kHz file 124 | s1_16k = weight_1 * s1 / sqrt(lev1); 125 | s2_16k = weight_2 * s2 / sqrt(lev2); 126 | s3_16k = weight_3 * s3 / sqrt(lev3); 127 | 128 | switch min_max{i_mm} 129 | case 'max' 130 | mix_16k_length = max([length(s1_16k),length(s2_16k),length(s3_16k)]); 131 | s1_16k = cat(1,s1_16k,zeros(mix_16k_length - length(s1_16k),1)); 132 | s2_16k = cat(1,s2_16k,zeros(mix_16k_length - length(s2_16k),1)); 133 | s3_16k = cat(1,s3_16k,zeros(mix_16k_length - length(s3_16k),1)); 134 | case 'min' 135 | mix_16k_length = min([length(s1_16k),length(s2_16k),length(s3_16k)]); 136 | s1_16k = s1_16k(1:mix_16k_length); 137 | s2_16k = s2_16k(1:mix_16k_length); 138 | s3_16k = s3_16k(1:mix_16k_length); 139 | end 140 | mix_16k = s1_16k + s2_16k + s3_16k; 141 | 142 | max_amp_16k = max(cat(1,abs(mix_16k(:)),abs(s1_16k(:)),abs(s2_16k(:)),abs(s3_16k(:)))); 143 | mix_scaling_16k = 1/max_amp_16k*0.9; 144 | s1_16k = mix_scaling_16k * s1_16k; 145 | s2_16k = mix_scaling_16k * s2_16k; 146 | s3_16k = mix_scaling_16k * s3_16k; 147 | mix_16k = mix_scaling_16k * mix_16k; 148 | 149 | % save 8 kHz and 16 kHz mixtures, as well as 150 | % necessary scaling factors 151 | 152 | scaling_16k(i,1) = weight_1 * mix_scaling_16k/ sqrt(lev1); 153 | scaling_16k(i,2) = weight_2 * mix_scaling_16k/ sqrt(lev2); 154 | scaling_16k(i,3) = weight_3 * mix_scaling_16k/ sqrt(lev3); 155 | scaling_8k(i,1) = weight_1 * mix_scaling_8k/ sqrt(lev1); 156 | scaling_8k(i,2) = weight_2 * mix_scaling_8k/ sqrt(lev2); 157 | scaling_8k(i,3) = weight_3 * mix_scaling_8k/ sqrt(lev3); 158 | 159 | scaling16bit_16k(i) = mix_scaling_16k; 160 | scaling16bit_8k(i) = mix_scaling_8k; 161 | 162 | wavwrite(s1_8k,fs8k,[output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/s1/' mix_name '.wav']); 163 | wavwrite(s1_16k,fs,[output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/s1/' mix_name '.wav']); 164 | wavwrite(s2_8k,fs8k,[output_dir8k '/' 
min_max{i_mm} '/' data_type{i_type} '/s2/' mix_name '.wav']); 165 | wavwrite(s2_16k,fs,[output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/s2/' mix_name '.wav']); 166 | wavwrite(s3_8k,fs8k,[output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/s3/' mix_name '.wav']); 167 | wavwrite(s3_16k,fs,[output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/s3/' mix_name '.wav']); 168 | wavwrite(mix_8k,fs8k,[output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/mix/' mix_name '.wav']); 169 | wavwrite(mix_16k,fs,[output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/mix/' mix_name '.wav']); 170 | 171 | if mod(i,10)==0 172 | fprintf(1,'.'); 173 | if mod(i,200)==0 174 | fprintf(1,'\n'); 175 | end 176 | end 177 | 178 | end 179 | save([output_dir8k '/' min_max{i_mm} '/' data_type{i_type} '/scaling.mat'],'scaling_8k','scaling16bit_8k'); 180 | save([output_dir16k '/' min_max{i_mm} '/' data_type{i_type} '/scaling.mat'],'scaling_16k','scaling16bit_16k'); 181 | 182 | fclose(fid); 183 | fclose(fid_s1); 184 | fclose(fid_s2); 185 | fclose(fid_s3); 186 | fclose(fid_m); 187 | end 188 | end 189 | -------------------------------------------------------------------------------- /model/blstm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2017 Sining Sun (Northwestern Polytechnical University, China) 5 | # Jiaqiang Liu (Northwestern Polytechnical University, China) 6 | 7 | """ 8 | Build the LSTM(BLSTM) neural networks for PIT speech separation. 9 | 10 | """ 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import sys 17 | import time 18 | 19 | import tensorflow as tf 20 | from tensorflow.contrib.rnn.python.ops import rnn 21 | import numpy as np 22 | 23 | class LSTM(object): 24 | """Build a BLSTM or LSTM model with PIT loss functions. 25 | If you use this module to train your model, make sure that you prepare the data in the right format! 27 | 28 | Attributes: 29 | config: Used to configure our model 30 | config.input_size: feature (input) size 31 | config.output_size: the final (output) layer size 32 | config.rnn_size: the number of RNN cells 33 | config.batch_size: the batch size for training 34 | config.rnn_num_layers: the number of RNN layers 35 | config.keep_prob: the dropout keep probability 36 | inputs: the mixed speech features; also kept as the mixture that the estimated masks are applied to 37 | labels: spk1's and spk2's target features, concatenated along the last dimension 38 | lengths: the frame count of every utterance 39 | genders: the genders of the two speakers in each mixture 40 | infer: bool, False for training, True for testing 41 | """ 42 | 43 | def __init__(self, config, inputs, labels, lengths, genders, infer=False): 44 | self._inputs = inputs 45 | self._mixed = inputs 46 | self._labels1 = tf.slice(labels, [0,0,0], [-1,-1, config.output_size]) 47 | self._labels2 = tf.slice(labels, [0,0,config.output_size], [-1,-1, -1]) 48 | self._lengths = lengths 49 | self._genders = genders 50 | self._model_type = config.model_type 51 | 52 | outputs = self._inputs 53 | ## The first layer is a feed-forward layer that transforms 54 | ## the input to the right size before feeding it into the RNN 55 | 56 | with tf.variable_scope('forward1'): 57 | outputs = tf.reshape(outputs, [-1, config.input_size]) 58 | outputs = tf.layers.dense(outputs, units=config.rnn_size, 59 | activation=tf.nn.tanh, 60 | reuse=tf.get_variable_scope().reuse) 61 | outputs = tf.reshape( 62 | outputs, [config.batch_size,-1, config.rnn_size]) 63 | 64 | ## Configure the LSTM or BLSTM model. 65 | ## Both branches below use LSTMCell; you can swap in other cell 66 | ## types and test the performance... 67 | def lstm_cell(): 68 | return tf.contrib.rnn.LSTMCell( 69 | config.rnn_size, forget_bias=1.0, use_peepholes=True, 70 | initializer=tf.contrib.layers.xavier_initializer(), 71 | state_is_tuple=True, activation=tf.tanh) 72 | attn_cell = lstm_cell 73 | if not infer and config.keep_prob < 1.0: 74 | def attn_cell(): 75 | return tf.contrib.rnn.DropoutWrapper(lstm_cell(), output_keep_prob=config.keep_prob) 76 | 77 | if config.model_type.lower() == 'blstm': 78 | with tf.variable_scope('blstm'): 79 | 80 | lstm_fw_cell = tf.contrib.rnn.MultiRNNCell([attn_cell() for _ in range(config.rnn_num_layers)],state_is_tuple=True) 81 | lstm_bw_cell = tf.contrib.rnn.MultiRNNCell([attn_cell() for _ in range(config.rnn_num_layers)],state_is_tuple=True) 82 | 83 | lstm_fw_cell = _unpack_cell(lstm_fw_cell) 84 | lstm_bw_cell = _unpack_cell(lstm_bw_cell) 85 | result = rnn.stack_bidirectional_dynamic_rnn( 86 | cells_fw = lstm_fw_cell, 87 | cells_bw = lstm_bw_cell, 88 | inputs=outputs, 89 | dtype=tf.float32, 90 | sequence_length=self._lengths) 91 | outputs, fw_final_states, bw_final_states = result 92 | if config.model_type.lower() == 'lstm': 93 | with tf.variable_scope('lstm'): 94 | cell = tf.contrib.rnn.MultiRNNCell( 95 | [attn_cell() for _ in range(config.rnn_num_layers)], 96 | state_is_tuple=True) 97 | self._initial_state = cell.zero_state(config.batch_size, tf.float32) 98 | state = self.initial_state 99 | outputs, state = tf.nn.dynamic_rnn( 100 | cell, outputs, 101 | dtype=tf.float32, 102 | sequence_length=self._lengths, 103 | initial_state=self.initial_state) 104 | self._final_state = state 105 | 106 | ## Feed forward layer.
Transform the RNN output to the right output size 107 | 108 | with tf.variable_scope('forward2'): 109 | if self._model_type.lower() == 'blstm': 110 | outputs = tf.reshape(outputs, [-1, 2*config.rnn_size]) 111 | in_size=2*config.rnn_size 112 | else: 113 | outputs = tf.reshape(outputs, [-1, config.rnn_size]) 114 | in_size = config.rnn_size 115 | #w1,b1 =self. _weight_and_bias("L_1",in_size,256) 116 | #outputs1 = tf.nn.relu(tf.matmul(outputs,w1)+b1) 117 | #w2,b2 = self._weight_and_bias("L_2",256,256) 118 | #outputs2 = tf.nn.relu(tf.matmul(outputs1,w2)+b2+outputs1) 119 | out_size = config.output_size 120 | #in_size=256 121 | weights1 = tf.get_variable('weights1', [in_size, out_size], 122 | initializer=tf.random_normal_initializer(stddev=0.01)) 123 | biases1 = tf.get_variable('biases1', [out_size], 124 | initializer=tf.constant_initializer(0.0)) 125 | weights2 = tf.get_variable('weights2', [in_size, out_size], 126 | initializer=tf.random_normal_initializer(stddev=0.01)) 127 | biases2 = tf.get_variable('biases2', [out_size], 128 | initializer=tf.constant_initializer(0.0)) 129 | mask1 = tf.nn.relu(tf.matmul(outputs, weights1) + biases1) 130 | mask2 = tf.nn.relu(tf.matmul(outputs, weights2) + biases2) 131 | self._activations1 = tf.reshape( 132 | mask1, [config.batch_size, -1, config.output_size]) 133 | self._activations2 = tf.reshape( 134 | mask2, [config.batch_size, -1, config.output_size]) 135 | # In general, config.czt_dim == 0. However, we found that if we concatenate 136 | # 128-dim chirp-z transform feats to the FFT feats, we get better SDR performance 137 | # for the same-gender case. 138 | 139 | # So, if you don't use czt feats (just the FFT feats), set config.czt_dim = 0. 140 | self._cleaned1 = self._activations1*self._mixed 141 | self._cleaned2 = self._activations2*self._mixed 142 | # Ability to save the model 143 | self.saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=30) 144 | 145 | if infer: return 146 | 147 | 148 | # Compute loss (MSE) with PIT: score both output-target assignments and keep the cheaper one. 149 | cost1 = tf.reduce_mean( tf.reduce_sum(tf.pow(self._cleaned1-self._labels1,2),1) 150 | +tf.reduce_sum(tf.pow(self._cleaned2-self._labels2,2),1) 151 | ,1) 152 | cost2 = tf.reduce_mean( tf.reduce_sum(tf.pow(self._cleaned2-self._labels1,2),1) 153 | +tf.reduce_sum(tf.pow(self._cleaned1-self._labels2,2),1) 154 | ,1) 155 | 156 | idx = tf.cast(cost1>cost2,tf.float32) # 1 where the swapped assignment is cheaper 157 | self._loss = tf.reduce_sum(idx*cost2+(1-idx)*cost1) 158 | if tf.get_variable_scope().reuse: return 159 | 160 | self._lr = tf.Variable(0.0, trainable=False) 161 | tvars = tf.trainable_variables() 162 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars), 163 | config.max_grad_norm) 164 | optimizer = tf.train.AdamOptimizer(self.lr) 165 | #optimizer = tf.train.GradientDescentOptimizer(self.lr) 166 | self._train_op = optimizer.apply_gradients(zip(grads, tvars)) 167 | 168 | self._new_lr = tf.placeholder( 169 | tf.float32, shape=[], name='new_learning_rate') 170 | self._lr_update = tf.assign(self._lr, self._new_lr) 171 | 172 | def assign_lr(self, session, lr_value): 173 | session.run(self._lr_update, feed_dict={self._new_lr: lr_value}) 174 | def get_opt_output(self): 175 | ''' 176 | This function is only for PIT testing with the optimal (oracle) assignment 177 | ''' 178 | 179 | cost1 = tf.reduce_sum(tf.pow(self._cleaned1-self._labels1,2),2)+tf.reduce_sum(tf.pow(self._cleaned2-self._labels2,2),2) 180 | cost2 = tf.reduce_sum(tf.pow(self._cleaned2-self._labels1,2),2)+tf.reduce_sum(tf.pow(self._cleaned1-self._labels2,2),2) 181 | idx = tf.slice(cost1, [0, 0], [1, -1]) > tf.slice(cost2, [0,
0], [1, -1]) 182 | idx = tf.cast(idx, tf.float32) 183 | idx = tf.reduce_mean(idx,reduction_indices=0) 184 | idx = tf.reshape(idx, [tf.shape(idx)[0], 1]) 185 | x1 = self._cleaned1[0,:,:] * (1-idx) + self._cleaned2[0,:, :]*idx 186 | x2 = self._cleaned1[0,:,:]*idx + self._cleaned2[0,:,:]*(1-idx) 187 | row = tf.shape(x1)[0] 188 | col = tf.shape(x1)[1] 189 | x1 = tf.reshape(x1, [1, row, col]) 190 | x2 = tf.reshape(x2, [1, row, col]) 191 | return x1, x2 192 | 193 | @property 194 | def inputs(self): 195 | return self._inputs 196 | 197 | @property 198 | def labels(self): 199 | return self._labels1,self._labels2 200 | 201 | @property 202 | def initial_state(self): 203 | return self._initial_state 204 | 205 | @property 206 | def final_state(self): 207 | return self._final_state 208 | 209 | @property 210 | def lr(self): 211 | return self._lr 212 | 213 | @property 214 | def activations(self): 215 | return self._activations1, self._activations2 216 | 217 | @property 218 | def loss(self): 219 | return self._loss 220 | 221 | @property 222 | def train_op(self): 223 | return self._train_op 224 | 225 | @staticmethod 226 | def _weight_and_bias(name,in_size, out_size): 227 | # Create variable named "weights". 228 | weights = tf.get_variable(name+"_w", [in_size, out_size], 229 | initializer=tf.random_normal_initializer(stddev=0.01)) 230 | # Create variable named "biases". 231 | biases = tf.get_variable(name+"_b", [out_size], 232 | initializer=tf.constant_initializer(0.0)) 233 | return weights, biases 234 | def _unpack_cell(cell): 235 | if isinstance(cell,tf.contrib.rnn.MultiRNNCell): 236 | return cell._cells 237 | else: 238 | return [cell] 239 | -------------------------------------------------------------------------------- /io_funcs/kaldi_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2017 5 | # Ke Wang 6 | # Sining Sun 7 | # Yuchao Zhang 8 | 9 | """IO classes for reading and writing kaldi .ark 10 | 11 | This module provides io interfaces for reading and writing kaldi .ark files. 12 | Currently, this module only supports binary-formatted .ark files. Text .ark 13 | files are not supported. 14 | 15 | To use this module, you need to provide kaldi .scp files only. The .ark 16 | locations with corresponding offsets can be retrieved from .scp files. 17 | """ 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import sys 23 | import struct 24 | import random 25 | import numpy as np 26 | 27 | class GlobalHeader(object): 28 | """ Compressed ark format header. """ 29 | def __init__(self, format, header): 30 | self.format = format 31 | self.min_value = header[0] 32 | self.range = header[1] 33 | self.num_rows = header[2] 34 | self.num_cols = header[3] 35 | 36 | class PerColHeader(object): 37 | """ Compressed ark format per-column header. """ 38 | def __init__(self, header): 39 | self.percentile_0 = header[0] 40 | self.percentile_25 = header[1] 41 | self.percentile_75 = header[2] 42 | self.percentile_100 = header[3] 43 | 44 | class ArkReader(object): 45 | """ Class to read Kaldi ark format. 46 | 47 | Each time, it reads one line of the .scp file and reads in the 48 | corresponding features into a numpy matrix. It only supports 49 | binary-formatted .ark files. Text files are not supported. 50 | 51 | Attributes: 52 | utt_ids: A list saving utterance identities. 53 | scp_data: A list saving .ark path and offset for items in utt_ids.
54 | scp_position: An integer indicating which utt_id and corresponding 55 | scp_data will be read next. 56 | """ 57 | 58 | def __init__(self, scp_path): 59 | """Init utt_ids along with scp_data according to .scp file.""" 60 | self.scp_position = 0 61 | fin = open(scp_path,"r") 62 | self.utt_ids = [] 63 | self.scp_data = [] 64 | line = fin.readline() 65 | while line != '' and line is not None: 66 | utt_id, path_pos = line.replace('\n','').split(' ') 67 | path, pos = path_pos.split(':') 68 | self.utt_ids.append(utt_id) 69 | self.scp_data.append((path, pos)) 70 | line = fin.readline() 71 | 72 | fin.close() 73 | 74 | def shuffle(self): 75 | """Shuffle utt_ids along with scp_data and reset scp_position.""" 76 | zipped = list(zip(self.utt_ids, self.scp_data)) # list() so shuffling also works on Python 3 77 | random.shuffle(zipped) 78 | self.utt_ids, self.scp_data = zip(*zipped) # unzip and assign 79 | self.scp_position = 0 80 | 81 | 82 | def read_ark(self, ark_file, ark_offset=0): 83 | """Read data from the archive (.ark from kaldi). 84 | 85 | Returns: 86 | A numpy matrix containing data of ark_file. 87 | """ 88 | ark_read_buffer = open(ark_file, 'rb') 89 | ark_read_buffer.seek(int(ark_offset), 0) 90 | header = struct.unpack('<xcccc', ark_read_buffer.read(5)) # assumed little-endian Kaldi binary header 175 | if self.scp_position >= len(self.scp_data): #if at end of file loop around 176 | looped = True 177 | self.scp_position = 0 178 | else: 179 | looped = False 180 | 181 | self.scp_position += 1 182 | 183 | utt_ids = self.utt_ids[self.scp_position-1] 184 | utt_data = self.read_utt_data_from_index(self.scp_position-1) 185 | 186 | return utt_ids, utt_data, looped 187 | 188 | def read_next_scp(self): 189 | """Read the next utterance ID but don't read the data. 190 | 191 | Returns: 192 | The utterance ID of the utterance that was read. 193 | """ 194 | if self.scp_position >= len(self.scp_data): #if at end of file loop around 195 | self.scp_position = 0 196 | 197 | self.scp_position += 1 198 | 199 | return self.utt_ids[self.scp_position-1] 200 | 201 | def read_previous_scp(self): 202 | """Read the previous utterance ID but don't read the data. 203 | 204 | Returns: 205 | The utterance ID of the utterance that was read. 206 | """ 207 | if self.scp_position < 0: #if at beginning of file loop around 208 | self.scp_position = len(self.scp_data) - 1 209 | 210 | self.scp_position -= 1 211 | 212 | return self.utt_ids[self.scp_position+1] 213 | 214 | def read_utt_data_from_id(self, utt_id): 215 | """Read the data of a certain utterance ID. 216 | 217 | Args: 218 | utt_id: A string indicating a certain utterance ID. 219 | 220 | Returns: 221 | A numpy array containing the utterance data corresponding to the ID. 222 | """ 223 | utt_mat = self.read_utt_data_from_index(self.utt_ids.index(utt_id)) 224 | 225 | return utt_mat 226 | 227 | def read_utt_data_from_index(self, index): 228 | """Read the data of a certain index. 229 | 230 | Args: 231 | index: An integer index corresponding to a certain utterance ID. 232 | 233 | Returns: 234 | A numpy array containing the utterance data corresponding to the 235 | index. 236 | """ 237 | return self.read_ark(self.scp_data[index][0], self.scp_data[index][1]) 238 | 239 | def split(self): 240 | """Split off the data that was read so far.""" 241 | self.scp_data = self.scp_data[self.scp_position:-1] 242 | self.utt_ids = self.utt_ids[self.scp_position:-1] 243 | 244 | 245 | class ArkWriter(object): 246 | """Class to write numpy matrices into Kaldi .ark file and create the 247 | corresponding .scp file. It only supports binary-formatted .ark files. 248 | Text and compressed .ark files are not supported.
249 | 250 | Attributes: 251 | scp_path: The path to the .scp file that will be written. 252 | scp_file_write: The file object corresponding to scp_path. 253 | 254 | 255 | """ 256 | 257 | def __init__(self, scp_path): 258 | """ArkWriter constructor.""" 259 | self.scp_path = scp_path 260 | self.scp_file_write = open(self.scp_path, "w") 261 | 262 | def write_next_utt(self, ark_path, utt_id, utt_mat): 263 | """Write an utterance to the archive. 264 | 265 | Args: 266 | ark_path: Path to the .ark file that will be used for writing. 267 | utt_id: The utterance ID. 268 | utt_mat: A numpy array containing the utterance data. 269 | """ 270 | ark_file_write = open(ark_path,"ab") 271 | utt_mat = np.asarray(utt_mat, dtype=np.float32) 272 | rows, cols = utt_mat.shape 273 | ark_file_write.write(struct.pack('<%ds'%(len(utt_id)), utt_id)) 274 | pos = ark_file_write.tell() 275 | ark_file_write.write(struct.pack('<cxcccc', ' ', 'B', 'F', 'M', ' ')) # assumed Kaldi binary 'BFM ' float-matrix marker
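A minimal usage sketch of the ArkReader/ArkWriter pair above, assuming a Kaldi-style .scp file whose lines read `utt_id /path/to/file.ark:offset`; the feats.scp and feats.ark paths below are hypothetical:

    from io_funcs.kaldi_io import ArkReader, ArkWriter
    import numpy as np

    reader = ArkReader('data/tr/feats.scp')         # hypothetical input .scp
    writer = ArkWriter('exp/out/feats.scp')         # the target .ark is passed per call

    utt_id, feats, looped = reader.read_next_utt()  # feats: [frames, dims] numpy matrix
    writer.write_next_utt('exp/out/feats.ark', utt_id,
                          np.asarray(feats, dtype=np.float32))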