├── .gitignore
├── README.md
├── compute_cmvn.py
├── conf
│   └── 1.config.yaml
├── dataset.py
├── dcnet.py
├── requirements.txt
├── scripts
│   ├── run_demo.sh
│   ├── sdr_eval.sh
│   ├── sdr_eval_2spk.m
│   ├── spk2gender
│   └── train.sh
├── separate.py
├── train_dcnet.py
├── trainer.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.wav
2 | *.mat
3 | egs.py
4 | __pycache__/
5 | data/
6 | .vscode/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Deep clustering for single-channel speech separation
2 |
3 | An implementation of "Deep Clustering: Discriminative Embeddings for Segmentation and Separation" (Hershey et al., 2016)
4 |
5 | ### Requirements
6 |
7 | see [requirements.txt](requirements.txt)
8 |
9 | ### Usage
10 |
11 | 1. Configure experiments in .yaml files, for example `conf/train.yaml`
12 |
13 | 2. Training:
14 |
15 | ```shell
16 | python ./train_dcnet.py --config conf/train.yaml --num-epoches 20 > train.log 2>&1 &
17 | ```
18 |
19 | 3. Inference:
20 | ```
21 | python ./separate.py --num-spks 2 $mdl_dir/train.yaml $mdl_dir/final.pkl egs.scp
22 | ```
23 |
24 | ### Experiments
25 |
26 | | Config | Epoch | FM | FF | MM | FF/MM | AVG |
27 | | :-------: | :---: | :---: | :--: | :--: | :---: | :--: |
28 | | [config-1](conf/1.config.yaml) | 25 | 11.42 | 6.85 | 7.88 | 7.36 | 9.54 |
29 |
30 | ### Q & A
31 |
32 | 1. The format of the `.scp` file?
33 |
34 |    The `wav.scp` file follows the definition in the Kaldi toolkit: each line contains a `key value` pair, where the key is a unique string indexing the utterance and the value is the path of the audio file (a minimal script generating such a file is sketched in the appendix below). For example
35 | ```
36 | mix-utt-00001 /home/data/train/mix-utt-00001.wav
37 | ...
38 | mix-utt-XXXXX /home/data/train/mix-utt-XXXXX.wav
39 | ```
40 |
41 | 2. How to prepare the training dataset?
42 |
43 |    The original paper uses the MATLAB scripts in [create-speaker-mixtures.zip](http://www.merl.com/demos/deep-clustering/create-speaker-mixtures.zip) to simulate the two- and three-speaker datasets. You can also use your own data source (e.g. LibriSpeech, TIMIT) to create mixtures, keeping the clean sources alongside.
44 |
45 |
46 | ### Reference
47 |
48 | 1. Hershey J R, Chen Z, Le Roux J, et al. Deep clustering: Discriminative embeddings for segmentation and separation[C]//Acoustics, Speech and Signal Processing (ICASSP), 2016 IEEE International Conference on. IEEE, 2016: 31-35.
49 | 2. Isik Y, Le Roux J, Chen Z, et al. Single-channel multi-speaker separation using deep clustering[J]. arXiv preprint arXiv:1607.02173, 2016.
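### Appendix: generating `wav.scp`

A minimal sketch (a hypothetical helper, not part of the recipe) that builds a Kaldi-style `wav.scp`, assuming all wave files sit in a single flat directory and using the file stem as the utterance key:

```python
#!/usr/bin/env python
# Hypothetical helper: print "key path" pairs for every .wav in a directory.
import os
import sys

wav_dir = sys.argv[1]  # e.g. /home/data/train
for name in sorted(os.listdir(wav_dir)):
    if name.endswith(".wav"):
        key = os.path.splitext(name)[0]  # file stem as utterance key
        print("{} {}".format(key, os.path.join(wav_dir, name)))
```

Redirect the output into `mix.scp`/`spk1.scp`/`spk2.scp` as configured under the `*_scp_conf` sections of the .yaml file.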
50 |
--------------------------------------------------------------------------------
/compute_cmvn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 |
4 | # wujian@2018
5 |
6 | import argparse
7 | import pickle
8 | import tqdm
9 | import numpy as np
10 |
11 | from dataset import SpectrogramReader
12 | from utils import parse_yaml
13 |
14 | def run(args):
15 |     num_bins, conf_dict = parse_yaml(args.train_conf)
16 |     reader = SpectrogramReader(args.wave_scp, **conf_dict["spectrogram_reader"])
17 |     mean = np.zeros(num_bins)
18 |     std = np.zeros(num_bins)
19 |     num_frames = 0
20 |     # D(X) = E(X^2) - E(X)^2
21 |     for _, spectrogram in tqdm.tqdm(reader):
22 |         num_frames += spectrogram.shape[0]
23 |         mean += np.sum(spectrogram, 0)
24 |         std += np.sum(spectrogram**2, 0)  # accumulate the sum of squares first
25 |     mean = mean / num_frames
26 |     std = np.sqrt(std / num_frames - mean**2)
27 |     with open(args.cmvn_dst, "wb") as f:
28 |         cmvn_dict = {"mean": mean, "std": std}
29 |         pickle.dump(cmvn_dict, f)
30 |     print("Processed {} frames in total".format(num_frames))
31 |     print("Global mean: {}".format(mean))
32 |     print("Global std: {}".format(std))
33 |
34 |
35 | if __name__ == '__main__':
36 |     parser = argparse.ArgumentParser(
37 |         description="Command to compute global cmvn stats")
38 |     parser.add_argument(
39 |         "wave_scp", type=str, help="Location of the mixture wave script")
40 |     parser.add_argument(
41 |         "train_conf", type=str, help="Location of the training configuration file")
42 |     parser.add_argument(
43 |         "cmvn_dst", type=str, help="Location to dump cmvn stats")
44 |     args = parser.parse_args()
45 |     run(args)
--------------------------------------------------------------------------------
/conf/1.config.yaml:
--------------------------------------------------------------------------------
1 | # config for training
2 |
3 | trainer:
4 |   checkpoint: "./tune/2spk_dcnet_d"
5 |   optimizer: "rmsprop"
6 |   lr: 1e-5
7 |   momentum: 0.9
8 |   weight_decay: 0
9 |   clip_norm: 200
10 |   num_spks: 2
11 |
12 | dcnet:
13 |   rnn: "lstm"
14 |   embedding_dim: 20
15 |   num_layers: 2
16 |   hidden_size: 600
17 |   dropout: 0.5
18 |   non_linear: "tanh"
19 |   bidirectional: true
20 |
21 | spectrogram_reader:
22 |   frame_shift: 64
23 |   frame_length: 256
24 |   window: "hann"
25 |   transpose: true
26 |   apply_log: true
27 |   apply_abs: true
28 |
29 | train_scp_conf:
30 |   mixture: "./data/2spk/train/mix.scp"
31 |   spk1: "./data/2spk/train/spk1.scp"
32 |   spk2: "./data/2spk/train/spk2.scp"
33 |
34 | valid_scp_conf:
35 |   mixture: "./data/2spk/dev/mix.scp"
36 |   spk1: "./data/2spk/dev/spk1.scp"
37 |   spk2: "./data/2spk/dev/spk2.scp"
38 |
39 | debug_scp_conf:
40 |   mixture: "./data/tune/mix.scp"
41 |   spk1: "./data/tune/spk1.scp"
42 |   spk2: "./data/tune/spk2.scp"
43 |
44 | dataloader:
45 |   shuffle: true
46 |   batch_size: 1
47 |   drop_last: false
48 |   vad_threshold: 40
49 |   mvn_dict: "data/cmvn.dict"
50 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # wujian@2018
4 |
5 | import os
6 | import random
7 | import logging
8 | import pickle
9 |
10 | import numpy as np
11 | import torch as th
12 |
13 | from torch.nn.utils.rnn import pack_sequence, pad_sequence
14 |
15 | from utils import parse_scps, stft, compute_vad_mask, apply_cmvn
16 |
17 | logger = logging.getLogger(__name__)
18 | logger.setLevel(logging.INFO)
19 | handler = logging.StreamHandler()
20 | handler.setLevel(logging.INFO)
21 | formatter = logging.Formatter(
22 |     "%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s")
23 | handler.setFormatter(formatter)
24 | logger.addHandler(handler)
25 |
26 |
27 | class SpectrogramReader(object):
28 |     """
29 |     Wrapper that reads audio and computes its short-time Fourier transform on demand
30 |     """
31 |
32 |     def __init__(self, wave_scp, **kwargs):
33 |         if not os.path.exists(wave_scp):
34 |             raise FileNotFoundError("Could not find file {}".format(wave_scp))
35 |         self.stft_kwargs = kwargs
36 |         self.wave_dict = parse_scps(wave_scp)
37 |         self.wave_keys = [key for key in self.wave_dict.keys()]
38 |         logger.info(
39 |             "Create SpectrogramReader for {} with {} utterances".format(
40 |                 wave_scp, len(self.wave_dict)))
41 |
42 |     def __len__(self):
43 |         return len(self.wave_dict)
44 |
45 |     def __contains__(self, key):
46 |         return key in self.wave_dict
47 |
48 |     # stft
49 |     def _load(self, key):
50 |         return stft(self.wave_dict[key], **self.stft_kwargs)
51 |
52 |     # sequential index
53 |     def __iter__(self):
54 |         for key in self.wave_dict:
55 |             yield key, self._load(key)
56 |
57 |     # random index
58 |     def __getitem__(self, key):
59 |         if key not in self.wave_dict:
60 |             raise KeyError("Could not find utterance {}".format(key))
61 |         return self._load(key)
62 |
63 |
64 | class Dataset(object):
65 |     def __init__(self, mixture_reader, targets_reader_list):
66 |         self.mixture_reader = mixture_reader
67 |         self.keys_list = mixture_reader.wave_keys
68 |         self.targets_reader_list = targets_reader_list
69 |
70 |     def __len__(self):
71 |         return len(self.keys_list)
72 |
73 |     def _has_target(self, key):
74 |         for targets_reader in self.targets_reader_list:
75 |             if key not in targets_reader:
76 |                 return False
77 |         return True
78 |
79 |     def _index_by_key(self, key):
80 |         """
81 |         Return a tuple like (matrix, [matrix, ...])
82 |         """
83 |         if key not in self.mixture_reader or not self._has_target(key):
84 |             raise KeyError("Missing targets or mixture")
85 |         target_list = [reader[key] for reader in self.targets_reader_list]
86 |         return (self.mixture_reader[key], target_list)
87 |
88 |     def _index_by_num(self, num):
89 |         """
90 |         Return a tuple like (matrix, [matrix, ...])
91 |         """
92 |         if num >= len(self.keys_list):
93 |             raise IndexError("Index out of dataset, {} vs {}".format(
94 |                 num, len(self.keys_list)))
95 |         key = self.keys_list[num]
96 |         return self._index_by_key(key)
97 |
98 |     def _index_by_list(self, list_idx):
99 |         """
100 |         Returns a list of tuples like [
101 |             (matrix, [matrix, ...]),
102 |             (matrix, [matrix, ...]),
103 |             ...
104 |         ]
105 |         """
106 |         if max(list_idx) >= len(self.keys_list):
107 |             raise IndexError("Index list contains index out of dataset")
108 |         return [self._index_by_num(index) for index in list_idx]
109 |
110 |     def __getitem__(self, index):
111 |         """
112 |         Support multi-type indexing: by key (str), position (int) or list of positions
113 |         """
114 |         if type(index) == int:
115 |             return self._index_by_num(index)
116 |         elif type(index) == str:
117 |             return self._index_by_key(index)
118 |         elif type(index) == list:
119 |             return self._index_by_list(index)
120 |         else:
121 |             raise KeyError("Unsupported index type (expected int/str/list)")
122 |
123 |
124 | class BatchSampler(object):
125 |     def __init__(self,
126 |                  sampler_size,
127 |                  batch_size=16,
128 |                  shuffle=True,
129 |                  drop_last=False):
130 |         if batch_size <= 0:
131 |             raise ValueError(
132 |                 "Illegal batch_size(= {}) detected".format(batch_size))
133 |         self.batch_size = batch_size
134 |         self.drop_last = drop_last
135 |         self.sampler_index = list(range(sampler_size))
136 |         self.sampler_size = sampler_size
137 |         if shuffle:
138 |             random.shuffle(self.sampler_index)
139 |
140 |     def __len__(self):
141 |         return self.sampler_size  # number of utterances, not batches
142 |
143 |     def __iter__(self):
144 |         base = 0
145 |         step = self.batch_size
146 |         while True:
147 |             if base + step > self.sampler_size:
148 |                 break
149 |             yield (self.sampler_index[base:base + step]
150 |                    if step != 1 else self.sampler_index[base])
151 |             base += step
152 |         if not self.drop_last and base < self.sampler_size:
153 |             yield self.sampler_index[base:]
154 |
155 |
156 | class DataLoader(object):
157 |     """
158 |     Multi-/per-utterance loader for DCNet training
159 |     """
160 |
161 |     def __init__(self,
162 |                  dataset,
163 |                  shuffle=True,
164 |                  batch_size=16,
165 |                  drop_last=False,
166 |                  vad_threshold=40,
167 |                  mvn_dict=None):
168 |         self.dataset = dataset
169 |         self.vad_threshold = vad_threshold
170 |         self.mvn_dict = mvn_dict
171 |         self.batch_size = batch_size
172 |         self.drop_last = drop_last
173 |         self.shuffle = shuffle
174 |         if mvn_dict:
175 |             logger.info("Using cmvn dictionary from {}".format(mvn_dict))
176 |             with open(mvn_dict, "rb") as f:
177 |                 self.mvn_dict = pickle.load(f)
178 |
179 |     def __len__(self):
180 |         remain = len(self.dataset) % self.batch_size
181 |         if self.drop_last or not remain:
182 |             return len(self.dataset) // self.batch_size
183 |         else:
184 |             return len(self.dataset) // self.batch_size + 1
185 |
186 |     def _transform(self, mixture_specs, targets_specs_list):
187 |         """
188 |         Transform from numpy/list to torch types
189 |         """
190 |         # compute vad mask before cmvn
191 |         vad_mask = compute_vad_mask(
192 |             mixture_specs, self.vad_threshold, apply_exp=True)
193 |         # apply cmvn
194 |         if self.mvn_dict:
195 |             mixture_specs = apply_cmvn(mixture_specs, self.mvn_dict)
196 |         # compute target embedding index
197 |         target_attr = np.argmax(np.array(targets_specs_list), 0)
198 |         return {
199 |             "num_frames": mixture_specs.shape[0],
200 |             "spectrogram": th.tensor(mixture_specs, dtype=th.float32),
201 |             "target_attr": th.tensor(target_attr, dtype=th.int64),
202 |             "silent_mask": th.tensor(vad_mask, dtype=th.float32)
203 |         }
204 |
205 |     def _process(self, index):
206 |         if type(index) is list:
207 |             dict_list = sorted(
208 |                 [self._transform(s, t) for s, t in self.dataset[index]],
209 |                 key=lambda x: x["num_frames"],
210 |                 reverse=True)
211 |             spectrogram = pack_sequence([d["spectrogram"] for d in dict_list])
212 |             target_attr = pad_sequence(
213 |                 [d["target_attr"] for d in dict_list], batch_first=True)
214 |             silent_mask = pad_sequence(
215 |                 [d["silent_mask"] for d in dict_list], batch_first=True)
[d["silent_mask"] for d in dict_list], batch_first=True) 216 | return spectrogram, target_attr, silent_mask 217 | elif type(index) is int: 218 | s, t = self.dataset[index] 219 | data_dict = self._transform(s, t) 220 | return data_dict["spectrogram"], \ 221 | data_dict["target_attr"], \ 222 | data_dict["silent_mask"] 223 | else: 224 | raise ValueError("Unsupported index type({})".format(type(index))) 225 | 226 | def __iter__(self): 227 | sampler = BatchSampler( 228 | len(self.dataset), 229 | batch_size=self.batch_size, 230 | shuffle=self.shuffle, 231 | drop_last=self.drop_last) 232 | num_utts = 0 233 | log_period = 2000 // self.batch_size 234 | for e, index in enumerate(sampler): 235 | num_utts += (len(index) if type(index) is list else 1) 236 | if not (e + 1) % log_period: 237 | logger.info("Processed {} batches, {} utterances".format( 238 | e + 1, num_utts)) 239 | yield self._process(index) 240 | logger.info("Processed {} utterances in total".format(num_utts)) 241 | -------------------------------------------------------------------------------- /dcnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # wujian@2018 4 | 5 | import torch as th 6 | from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence 7 | 8 | 9 | def l2_loss(x): 10 | norm = th.norm(x, 2) 11 | return norm**2 12 | 13 | 14 | def l2_normalize(x, dim=0, eps=1e-12): 15 | assert (dim < x.dim()) 16 | norm = th.norm(x, 2, dim, keepdim=True) 17 | return x / (norm + eps) 18 | 19 | 20 | class DCNet(th.nn.Module): 21 | def __init__(self, 22 | num_bins, 23 | rnn="lstm", 24 | embedding_dim=20, 25 | num_layers=2, 26 | hidden_size=600, 27 | dropout=0.0, 28 | non_linear="tanh", 29 | bidirectional=True): 30 | super(DCNet, self).__init__() 31 | if non_linear not in ['tanh', 'sigmoid']: 32 | raise ValueError( 33 | "Unsupported non-linear type: {}".format(non_linear)) 34 | rnn = rnn.upper() 35 | if rnn not in ['RNN', 'LSTM', 'GRU']: 36 | raise ValueError("Unsupported rnn type: {}".format(rnn)) 37 | self.rnn = getattr(th.nn, rnn)( 38 | num_bins, 39 | hidden_size, 40 | num_layers, 41 | batch_first=True, 42 | dropout=dropout, 43 | bidirectional=bidirectional) 44 | self.drops = th.nn.Dropout(p=dropout) 45 | self.embed = th.nn.Linear( 46 | hidden_size * 2 47 | if bidirectional else hidden_size, num_bins * embedding_dim) 48 | self.non_linear = { 49 | "tanh": th.nn.functional.tanh, 50 | "sigmoid": th.nn.functional.sigmoid 51 | }[non_linear] 52 | self.embedding_dim = embedding_dim 53 | 54 | def forward(self, x, train=True): 55 | is_packed = isinstance(x, PackedSequence) 56 | if not is_packed and x.dim() != 3: 57 | x = th.unsqueeze(x, 0) 58 | x, _ = self.rnn(x) 59 | if is_packed: 60 | x, _ = pad_packed_sequence(x, batch_first=True) 61 | N = x.size(0) 62 | # N x T x H 63 | x = self.drops(x) 64 | # N x T x FD 65 | x = self.embed(x) 66 | x = self.non_linear(x) 67 | 68 | if train: 69 | # N x T x FD => N x TF x D 70 | x = x.view(N, -1, self.embedding_dim) 71 | else: 72 | # for inference 73 | # N x T x FD => NTF x D 74 | x = x.view(-1, self.embedding_dim) 75 | x = l2_normalize(x, -1) 76 | return x 77 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.13.3 2 | torch==0.4.0 3 | scipy==1.0.0 4 | librosa==0.5.1 5 | tqdm==4.19.4 6 | config==0.3.9 7 | scikit_learn==0.19.1 8 | PyYAML==5.1 9 | 
--------------------------------------------------------------------------------
/scripts/run_demo.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # wujian@2018
3 |
4 | mix_scp=./data/2spk/test/mix.scp
5 | mdl_dir=./tune/2spk_dcnet_a
6 |
7 | set -eu
8 |
9 | [ -d ./cache ] && rm -rf cache
10 |
11 | mkdir cache
12 |
13 | shuf $mix_scp | head -n30 > egs.scp
14 |
15 | ./separate.py --dump-pca --num-spks 2 $mdl_dir/train.yaml $mdl_dir/final.pkl egs.scp
16 |
17 | rm -f egs.scp
18 |
--------------------------------------------------------------------------------
/scripts/sdr_eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | set -eu
5 |
6 | [ $# -ne 1 ] && echo "format error: $0 <mix-scp>" && exit 1
7 |
8 | for x in spk2gender sdr.scp; do
9 |     [ ! -f $x ] && echo "$0: missing $x" && exit 1
10 | done
11 |
12 | src_scp=$1
13 |
14 | awk '{print $1}' $src_scp | \
15 | awk 'BEGIN{
16 |     while (getline < "spk2gender") {
17 |         if ($2 == 0)
18 |             spk2gender[$1] = "F";
19 |         else
20 |             spk2gender[$1] = "M";
21 |     }
22 | } {
23 |     split($1, t, "_");
24 |     s1 = spk2gender[substr(t[1], 1, 3)]  # awk substr is 1-based
25 |     s2 = spk2gender[substr(t[3], 1, 3)]
26 |     printf("%s\t%s%s\n", $1, s1, s2)
27 | }' | \
28 | awk 'BEGIN{
29 |     while (getline < "sdr.scp") {
30 |         mix2sdr[$1] = $2;
31 |     }
32 |     FF = 0; nFF = 0;
33 |     FM = 0; nFM = 0;
34 |     MM = 0; nMM = 0;
35 | } {
36 |     if ($2 == "FM" || $2 == "MF") {
37 |         FM += mix2sdr[$1];
38 |         nFM += 1;
39 |     } else if ($2 == "MM") {
40 |         MM += mix2sdr[$1];
41 |         nMM += 1;
42 |     } else {
43 |         FF += mix2sdr[$1];
44 |         nFF += 1;
45 |     }
46 | } END {
47 |     printf("FF sdr/num-utts = %.2f/%d\n", FF / nFF, nFF);
48 |     printf("MM sdr/num-utts = %.2f/%d\n", MM / nMM, nMM);
49 |     printf("FM sdr/num-utts = %.2f/%d\n", FM / nFM, nFM);
50 |     printf("-- sdr/num-utts = %.2f/%d\n", (FF + FM + MM) / (nFF + nMM + nFM), nFF + nMM + nFM);
51 | }'
52 |
53 |
54 |
--------------------------------------------------------------------------------
/scripts/sdr_eval_2spk.m:
--------------------------------------------------------------------------------
1 | % requires bss_eval_sources.m from bss_eval tools
2 | function sdr_eval_2spk(gt_spk1, gt_spk2, ev_spk1, ev_spk2, sdr_file)
3 |     gt_spk1 = fopen(gt_spk1);
4 |     gt_spk2 = fopen(gt_spk2);
5 |     ev_spk1 = fopen(ev_spk1);
6 |     ev_spk2 = fopen(ev_spk2);
7 |     sdr_out = fopen(sdr_file, 'w');
8 |
9 |     gt_spk1_cell = textscan(gt_spk1, '%s %s');
10 |     gt_spk2_cell = textscan(gt_spk2, '%s %s');
11 |     ev_spk1_cell = textscan(ev_spk1, '%s %s');
12 |     ev_spk2_cell = textscan(ev_spk2, '%s %s');
13 |
14 |     num_utts = length(gt_spk1_cell{1});
15 |     sdr_tot = 0;
16 |     fprintf('Evaluate %d utterances...\n', num_utts);
17 |
18 |     for uid = 1: num_utts
19 |         if mod(uid, 100) == 0
20 |             fprintf('Processed %d utterances...\n', uid);
21 |         end
22 |         gt_spk1_utt = audioread(gt_spk1_cell{2}{uid});
23 |         gt_spk2_utt = audioread(gt_spk2_cell{2}{uid});
24 |         ev_spk1_utt = audioread(ev_spk1_cell{2}{uid});
25 |         ev_spk2_utt = audioread(ev_spk2_cell{2}{uid});
26 |
27 |         gt = [gt_spk1_utt, gt_spk2_utt];
28 |         ev = [ev_spk1_utt, ev_spk2_utt];
29 |         [sdr, ~, ~, ~] = bss_eval_sources(ev', gt');
30 |         fprintf(sdr_out, '%s\t%f\n', gt_spk1_cell{1}{uid}, mean(sdr));
31 |         sdr_tot = sdr_tot + mean(sdr);
32 |     end
33 |
34 |     fprintf('Average SDR: %f\n', sdr_tot / num_utts);
35 | end
36 |
--------------------------------------------------------------------------------
/scripts/spk2gender:
--------------------------------------------------------------------------------
1 | 050 0
2 | 051 1
3 | 052 1
4 | 053 0
5 | 22g 1
6 | 22h 1
7 | 420 0
8 | 421 0
9 | 422 1
10 | 423 1
11 | 440 1
12 | 441 0
13 | 442 1
14 | 443 1
15 | 444 0
16 | 445 0
17 | 446 1
18 | 447 1
19 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # wujian@2018
3 |
4 | set -eu
5 |
6 | # [ $# -ne 1 ] && echo "format error: $0 <conf>" && exit 1
7 |
8 | conf=conf/train.yaml
9 |
10 | checkpoint=$(grep checkpoint $conf | awk '{print $2}' | sed 's:"::g')
11 |
12 | mkdir -p $checkpoint
13 |
14 | echo "start training --> $checkpoint ..."
15 |
16 | cp $conf $checkpoint/train.yaml
17 |
18 | CUDA_VISIBLE_DEVICES=0 python ./train_dcnet.py --config $conf --num-epoches 20 > $checkpoint/train.log 2>&1
19 |
20 | echo "done"
21 |
--------------------------------------------------------------------------------
/separate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # wujian@2018
4 |
5 | import argparse
6 | import os
7 | import pickle
8 | import sklearn
9 |
10 | import numpy as np
11 | import torch as th
12 | import scipy.io as sio
13 |
14 | from utils import stft, istft, parse_scps, compute_vad_mask, apply_cmvn, parse_yaml, EPSILON
15 | from dcnet import DCNet
16 |
17 | class DeepCluster(object):
18 |     def __init__(self, dcnet, dcnet_state, num_spks, pca=False, cuda=False):
19 |         if not os.path.exists(dcnet_state):
20 |             raise RuntimeError(
21 |                 "Could not find state file {}".format(dcnet_state))
22 |         self.dcnet = dcnet
23 |
24 |         self.location = "cuda" if cuda else "cpu"  # honor the argument, not the global args
25 |         self.dcnet.load_state_dict(
26 |             th.load(dcnet_state, map_location='cpu'))
27 |         self.dcnet.to(self.location)
28 |         self.dcnet.eval()
29 |         self.kmeans = sklearn.cluster.KMeans(n_clusters=num_spks)
30 |         self.pca = sklearn.decomposition.PCA(n_components=3) if pca else None
31 |         self.num_spks = num_spks
32 |
33 |     def _cluster(self, spectra, vad_mask):
34 |         """
35 |         Arguments
36 |             spectra: log-magnitude spectrogram (real numbers)
37 |             vad_mask: binary mask for non-silence bins (1 if non-silent)
38 |         return
39 |             pca_embed: PCA embedding vector (dim 3)
40 |             spk_masks: binary masks for each speaker
41 |         """
42 |         # TF x D
43 |         net_embed = self.dcnet(
44 |             th.tensor(spectra, dtype=th.float32, device=self.location),
45 |             train=False).cpu().data.numpy()
46 |         # filter silence embeddings: TF x D => N x D
47 |         active_embed = net_embed[vad_mask.reshape(-1)]
48 |         # classes: N x D
49 |         # pca_mat: N x 3
50 |         classes = self.kmeans.fit_predict(active_embed)
51 |
52 |         pca_mat = None
53 |         if self.pca:
54 |             pca_mat = self.pca.fit_transform(active_embed)
55 |
56 |         def form_mask(classes, spkid, vad_mask):
57 |             mask = ~vad_mask  # silent bins are kept in every speaker's mask
58 |             # mask = np.zeros_like(vad_mask)
59 |             mask[vad_mask] = (classes == spkid)
60 |             return mask
61 |
62 |         return pca_mat, [
63 |             form_mask(classes, spk, vad_mask) for spk in range(self.num_spks)
64 |         ]
65 |
66 |     def separate(self, spectra, cmvn=None):
67 |         """
68 |         spectra: stft complex results T x F
69 |         cmvn: python dict containing global mean/std
70 |         """
71 |         if not np.iscomplexobj(spectra):
72 |             raise ValueError("Input must be a complex-valued matrix")
73 |         # compute log-magnitude spectrogram
74 |         log_spectra = np.log(np.maximum(np.abs(spectra), EPSILON))
75 |         # compute vad mask before applying cmvn
76 |         vad_mask = compute_vad_mask(
77 |             log_spectra, threshold_db=40).astype(np.bool)
78 |
79 |         # print("Keep {} bins out of {}".format(np.sum(vad_mask), vad_mask.size))
80 |         pca_mat, spk_masks = self._cluster(
81 |             apply_cmvn(log_spectra, cmvn) if cmvn else log_spectra, vad_mask)
82 |
83 |         return pca_mat, spk_masks, [
84 |             spectra * spk_mask for spk_mask in spk_masks
85 |         ]
86 |
87 |
88 | def run(args):
89 |     num_bins, config_dict = parse_yaml(args.config)
90 |     # Load cmvn
91 |     dict_mvn = config_dict["dataloader"]["mvn_dict"]
92 |     if dict_mvn:
93 |         if not os.path.exists(dict_mvn):
94 |             raise FileNotFoundError("Could not find mvn files")
95 |         with open(dict_mvn, "rb") as f:
96 |             dict_mvn = pickle.load(f)
97 |
98 |     dcnet = DCNet(num_bins, **config_dict["dcnet"])
99 |
100 |     frame_length = config_dict["spectrogram_reader"]["frame_length"]
101 |     frame_shift = config_dict["spectrogram_reader"]["frame_shift"]
102 |     window = config_dict["spectrogram_reader"]["window"]
103 |
104 |     cluster = DeepCluster(
105 |         dcnet,
106 |         args.dcnet_state,
107 |         args.num_spks,
108 |         pca=args.dump_pca,
109 |         cuda=args.cuda)
110 |
111 |     utt_dict = parse_scps(args.wave_scp)
112 |     num_utts = 0
113 |     for key, utt in utt_dict.items():
114 |         try:
115 |             samps, stft_mat = stft(
116 |                 utt,
117 |                 frame_length=frame_length,
118 |                 frame_shift=frame_shift,
119 |                 window=window,
120 |                 center=True,
121 |                 return_samps=True)
122 |         except FileNotFoundError:
123 |             print("Skip utterance {}... not found".format(key))
124 |             continue
125 |         print("Processing utterance {}".format(key))
126 |         num_utts += 1
127 |         norm = np.linalg.norm(samps, np.inf)
128 |         pca_mat, spk_mask, spk_spectrogram = cluster.separate(
129 |             stft_mat, cmvn=dict_mvn)
130 |
131 |         for index, spk_spectra in enumerate(spk_spectrogram):
132 |             istft(
133 |                 os.path.join(args.dump_dir, '{}.spk{}.wav'.format(
134 |                     key, index + 1)),
135 |                 spk_spectra,
136 |                 frame_length=frame_length,
137 |                 frame_shift=frame_shift,
138 |                 window=window,
139 |                 center=True,
140 |                 norm=norm,
141 |                 fs=8000,
142 |                 nsamps=samps.size)
143 |             if args.dump_mask:
144 |                 sio.savemat(
145 |                     os.path.join(args.dump_dir, '{}.spk{}.mat'.format(
146 |                         key, index + 1)), {"mask": spk_mask[index]})
147 |         if args.dump_pca:
148 |             sio.savemat(
149 |                 os.path.join(args.dump_dir, '{}.mat'.format(key)),
150 |                 {"pca_matrix": pca_mat})
151 |     print("Processed {} utterances in total!".format(num_utts))
152 |
153 |
154 | if __name__ == '__main__':
155 |     parser = argparse.ArgumentParser(
156 |         description=
157 |         "Command to separate single-channel speech using masks clustered on embeddings of DCNet"
158 |     )
159 |     parser.add_argument(
160 |         "config", type=str, help="Location of the training configuration file")
161 |     parser.add_argument(
162 |         "dcnet_state", type=str, help="Location of the network state file")
163 |     parser.add_argument(
164 |         "wave_scp",
165 |         type=str,
166 |         help="Location of the input wave script in kaldi format")
167 |     parser.add_argument(
168 |         "--cuda",
169 |         default=False,
170 |         action="store_true",
171 |         dest="cuda",
172 |         help="If true, inference on GPUs")
173 |     parser.add_argument(
174 |         "--num-spks",
175 |         type=int,
176 |         default=2,
177 |         dest="num_spks",
178 |         help="Number of speakers to be separated")
179 |     parser.add_argument(
180 |         "--dump-dir",
181 |         type=str,
182 |         default="cache",
183 |         dest="dump_dir",
184 |         help="Location to dump separated speakers")
185 |     parser.add_argument(
186 |         "--dump-pca",
187 |         default=False,
188 |         action="store_true",
189 |         dest="dump_pca",
190 |         help="If true, dump pca matrix")
191 |     parser.add_argument(
192 |         "--dump-mask",
193 |         default=False,
194 |         action="store_true",
195 |         dest="dump_mask",
196 |         help="If true, dump binary mask matrix")
197 |     args = parser.parse_args()
198 |     run(args)
199 |
--------------------------------------------------------------------------------
/train_dcnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 |
4 | # wujian@2018
5 |
6 | import argparse
7 | import os
8 |
9 | from trainer import Trainer
10 | from dataset import SpectrogramReader, Dataset, DataLoader, logger
11 | from dcnet import DCNet
12 | from utils import parse_yaml
13 |
14 |
15 | def uttloader(scp_config, reader_kwargs, loader_kwargs, train=True):
16 |     mix_reader = SpectrogramReader(scp_config['mixture'], **reader_kwargs)
17 |     target_reader = [
18 |         SpectrogramReader(scp_config[spk_key], **reader_kwargs)
19 |         for spk_key in scp_config if spk_key[:3] == 'spk'
20 |     ]
21 |     dataset = Dataset(mix_reader, target_reader)
22 |     # modify shuffle status
23 |     loader_kwargs["shuffle"] = train
24 |     # validate per utterance if needed
25 |     # if not train:
26 |     #     loader_kwargs["batch_size"] = 1
27 |     # if validating, do not shuffle
28 |     utt_loader = DataLoader(dataset, **loader_kwargs)
29 |     return utt_loader
30 |
31 |
32 | def train(args):
33 |     debug = args.debug
34 |     logger.info(
35 |         "Start training in {} mode".format('debug' if debug else 'normal'))
36 |     num_bins, config_dict = parse_yaml(args.config)
37 |     reader_conf = config_dict["spectrogram_reader"]
38 |     loader_conf = config_dict["dataloader"]
39 |     dcnnet_conf = config_dict["dcnet"]
40 |
41 |     batch_size = loader_conf["batch_size"]
42 |     logger.info(
43 |         "Training in {}".format("per-utterance mode" if batch_size == 1 else
44 |                                 '{} utterances per batch'.format(batch_size)))
45 |
46 |     train_loader = uttloader(
47 |         config_dict["train_scp_conf"]
48 |         if not debug else config_dict["debug_scp_conf"],
49 |         reader_conf,
50 |         loader_conf,
51 |         train=True)
52 |     valid_loader = uttloader(
53 |         config_dict["valid_scp_conf"]
54 |         if not debug else config_dict["debug_scp_conf"],
55 |         reader_conf,
56 |         loader_conf,
57 |         train=False)
58 |     checkpoint = config_dict["trainer"]["checkpoint"]
59 |     logger.info("Training for {} epochs -> {}...".format(
60 |         args.num_epoches, "default checkpoint"
61 |         if checkpoint is None else checkpoint))
62 |
63 |     dcnet = DCNet(num_bins, **dcnnet_conf)
64 |     trainer = Trainer(dcnet, **config_dict["trainer"])
65 |     trainer.run(train_loader, valid_loader, num_epoches=args.num_epoches)
66 |
67 |
68 | if __name__ == '__main__':
69 |     parser = argparse.ArgumentParser(
70 |         description="Command to train DCNet, configured by .yaml files")
71 |     parser.add_argument(
72 |         "--config",
73 |         type=str,
74 |         default="train.yaml",
75 |         dest="config",
76 |         help="Location of the .yaml configuration file for training")
77 |     parser.add_argument(
78 |         "--debug",
79 |         default=False,
80 |         action="store_true",
81 |         dest="debug",
82 |         help="If true, start training on debug data")
83 |     parser.add_argument(
84 |         "--num-epoches",
85 |         type=int,
86 |         default=20,
87 |         dest="num_epoches",
88 |         help="Number of epochs to train DCNet")
89 |     args = parser.parse_args()
90 |     train(args)
91 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # wujian@2018
4 |
5 | import os
6 | import time
7 | import warnings
8 |
9 | import torch as th
10 | from torch.nn.utils.rnn import PackedSequence
11 |
12 | from dcnet import l2_loss
13 | from dataset import logger
14 |
15 | device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
16 |
17 |
18 | def create_optimizer(optimizer, params, **kwargs):
19 |     supported_optimizer = {
20 |         'sgd': th.optim.SGD,  # momentum, weight_decay, lr
21 |         'rmsprop': th.optim.RMSprop,  # momentum, weight_decay, lr
22 |         'adam': th.optim.Adam  # weight_decay, lr
23 |         # ...
24 |     }
25 |     if optimizer not in supported_optimizer:
26 |         raise ValueError('Unsupported optimizer {}'.format(optimizer))
27 |     if optimizer == 'adam':
28 |         del kwargs['momentum']  # Adam does not accept a momentum argument
29 |     opt = supported_optimizer[optimizer](params, **kwargs)
30 |     logger.info('Create optimizer {}({})'.format(optimizer, kwargs))
31 |     return opt
32 |
33 |
34 | class Trainer(object):
35 |     def __init__(self,
36 |                  dcnet,
37 |                  checkpoint="checkpoint",
38 |                  optimizer="adam",
39 |                  lr=1e-5,
40 |                  momentum=0.9,
41 |                  weight_decay=0,
42 |                  clip_norm=None,
43 |                  num_spks=2):
44 |         self.nnet = dcnet
45 |         logger.info("DCNet:\n{}".format(self.nnet))
46 |         self.optimizer = create_optimizer(
47 |             optimizer,
48 |             self.nnet.parameters(),
49 |             lr=lr,
50 |             momentum=momentum,
51 |             weight_decay=weight_decay)
52 |         self.nnet.to(device)
53 |         self.checkpoint = checkpoint
54 |         self.num_spks = num_spks
55 |         self.clip_norm = clip_norm
56 |         if self.clip_norm:
57 |             logger.info("Clip gradient by 2-norm {}".format(clip_norm))
58 |         if not os.path.exists(checkpoint):
59 |             os.makedirs(checkpoint)
60 |
61 |     def train(self, dataloader):
62 |         self.nnet.train()
63 |         logger.info("Training...")
64 |         tot_loss = 0
65 |         num_batches = len(dataloader)
66 |         for mix_spect, tgt_index, vad_masks in dataloader:
67 |             self.optimizer.zero_grad()
68 |             mix_spect = mix_spect.cuda() if isinstance(
69 |                 mix_spect, PackedSequence) else mix_spect.to(device)  # PackedSequence has no .to() in torch 0.4.0
70 |             tgt_index = tgt_index.to(device)
71 |             vad_masks = vad_masks.to(device)
72 |             # mix_spect = mix_spect * vad_masks
73 |             net_embed = self.nnet(mix_spect)
74 |             cur_loss = self.loss(net_embed, tgt_index, vad_masks)
75 |             tot_loss += cur_loss.item()
76 |             cur_loss.backward()
77 |             if self.clip_norm:
78 |                 th.nn.utils.clip_grad_norm_(self.nnet.parameters(),
79 |                                             self.clip_norm)
80 |             self.optimizer.step()
81 |         return tot_loss / num_batches, num_batches
82 |
83 |     def validate(self, dataloader):
84 |         self.nnet.eval()
85 |         logger.info("Evaluating...")
86 |         tot_loss = 0
87 |         num_batches = len(dataloader)
88 |         # no need to keep gradients
89 |         with th.no_grad():
90 |             for mix_spect, tgt_index, vad_masks in dataloader:
91 |                 mix_spect = mix_spect.cuda() if isinstance(
92 |                     mix_spect, PackedSequence) else mix_spect.to(device)
93 |                 tgt_index = tgt_index.to(device)
94 |                 vad_masks = vad_masks.to(device)
95 |                 # mix_spect = mix_spect * vad_masks
96 |                 net_embed = self.nnet(mix_spect)
97 |                 cur_loss = self.loss(net_embed, tgt_index, vad_masks)
98 |                 tot_loss += cur_loss.item()
99 |         return tot_loss / num_batches, num_batches
100 |
101 |     def run(self, train_set, dev_set, num_epoches=20):
102 |         init_loss, _ = self.validate(dev_set)
103 |         logger.info("Start training for {} epochs".format(num_epoches))
104 |         logger.info("Epoch {:2d}: dev = {:.4e}".format(0, init_loss))
105 |         th.save(self.nnet.state_dict(),
106 |                 os.path.join(self.checkpoint, 'dcnet.0.pkl'))
107 |         for epoch in range(1, num_epoches + 1):
108 |             on_train_start = time.time()
109 |             train_loss, train_num_batch = self.train(train_set)
110 |             on_valid_start = time.time()
111 |             valid_loss, valid_num_batch = self.validate(dev_set)
112 |             on_valid_end = time.time()
113 |             logger.info(
114 |                 "Loss(time/num-batches) - Epoch {:2d}: train = {:.4e}({:.2f}s/{:d}) |"
115 |                 " dev = {:.4e}({:.2f}s/{:d})".format(
116 |                     epoch, train_loss, on_valid_start - on_train_start,
117 |                     train_num_batch, valid_loss, on_valid_end - on_valid_start,
118 |                     valid_num_batch))
119 |             save_path = os.path.join(self.checkpoint,
120 |                                      'dcnet.{:d}.pkl'.format(epoch))
121 |             th.save(self.nnet.state_dict(), save_path)
122 |         logger.info("Training for {} epochs done!".format(num_epoches))
123 |
124 |     def loss(self, net_embed, tgt_index, binary_mask):
125 |         """
126 |         Compute the DC loss ||VV^T - YY^T||_F^2 in its expanded form. Arguments:
127 |             net_embed: N x TF x D
128 |             tgt_index: N x T x F
129 |             binary_mask: N x T x F
130 |         """
131 |         if tgt_index.shape != binary_mask.shape:
132 |             raise ValueError("Dimension mismatch {} vs {}".format(
133 |                 tgt_index.shape, binary_mask.shape))
134 |         if th.max(tgt_index) != self.num_spks - 1:
135 |             warnings.warn(
136 |                 "Maybe something is wrong with the target embedding computation")
137 |
138 |         if tgt_index.dim() == 2:
139 |             tgt_index = th.unsqueeze(tgt_index, 0)
140 |             binary_mask = th.unsqueeze(binary_mask, 0)
141 |
142 |         N, T, F = tgt_index.shape
143 |         # shape binary_mask: N x TF x 1
144 |         binary_mask = binary_mask.view(N, T * F, 1)
145 |
146 |         # encode one-hot
147 |         tgt_embed = th.zeros([N, T * F, self.num_spks], device=device)
148 |         tgt_embed.scatter_(2, tgt_index.view(N, T * F, 1), 1)
149 |
150 |         # net_embed: N x TF x D
151 |         # tgt_embed: N x TF x S
152 |         net_embed = net_embed * binary_mask
153 |         tgt_embed = tgt_embed * binary_mask
154 |
155 |         loss = l2_loss(th.bmm(th.transpose(net_embed, 1, 2), net_embed)) + \
156 |                l2_loss(th.bmm(th.transpose(tgt_embed, 1, 2), tgt_embed)) - \
157 |                l2_loss(th.bmm(th.transpose(net_embed, 1, 2), tgt_embed)) * 2
158 |
159 |         return loss / th.sum(binary_mask)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # wujian@2018
4 |
5 | import os
6 | import warnings
7 | import yaml
8 |
9 | import librosa as audio_lib
10 | import numpy as np
11 |
12 | MAX_INT16 = np.iinfo(np.int16).max
13 | EPSILON = np.finfo(np.float32).eps
14 |
15 | config_keys = [
16 |     "trainer", "dcnet", "spectrogram_reader", "dataloader", "train_scp_conf",
17 |     "valid_scp_conf", "debug_scp_conf"
18 | ]
19 |
20 |
21 | def nfft(window_size):
22 |     return int(2**np.ceil(np.log2(window_size)))  # do not truncate log2 before ceil
23 |
24 |
25 | # return F x T or T x F
26 | def stft(file,
27 |          frame_length=1024,
28 |          frame_shift=256,
29 |          center=False,
30 |          window="hann",
31 |          return_samps=False,
32 |          apply_abs=False,
33 |          apply_log=False,
34 |          apply_pow=False,
35 |          transpose=True):
36 |     if not os.path.exists(file):
37 |         raise FileNotFoundError("Input file {} does not exist!".format(file))
38 |     if apply_log and not apply_abs:
39 |         apply_abs = True
40 |         warnings.warn(
41 |             "Ignoring apply_abs=False because this function returns real values")
42 |     samps, _ = audio_lib.load(file, sr=None)
43 |     stft_mat = audio_lib.stft(
44 |         samps,
45 |         nfft(frame_length),
46 |         frame_shift,
47 |         frame_length,
48 |         window=window,
49 |         center=center)
50 |     if apply_abs:
51 |         stft_mat = np.abs(stft_mat)
52 |     if apply_pow:
53 |         stft_mat = np.power(stft_mat, 2)
54 |     if apply_log:
55 |         stft_mat = np.log(np.maximum(stft_mat, EPSILON))
56 |     if transpose:
57 |         stft_mat = np.transpose(stft_mat)
58 |     return stft_mat if not return_samps else (samps, stft_mat)
59 |
60 |
61 | def istft(file,
62 |           stft_mat,
63 |           frame_length=1024,
64 |           frame_shift=256,
65 |           center=False,
66 |           window="hann",
67 |           transpose=True,
68 |           norm=None,
69 |           fs=16000,
70 |           nsamps=None):
71 |     if transpose:
72 |         stft_mat = np.transpose(stft_mat)
73 |     samps = audio_lib.istft(
74 |         stft_mat,
75 |         frame_shift,
76 |         frame_length,
77 |         window=window,
78 |         center=center,
79 |         length=nsamps)
80 |     samps_norm = np.linalg.norm(samps, np.inf)
81 |     # renorm if needed
82 |     if norm:
83 |         samps = samps * norm / samps_norm
84 |     samps_int16 = (samps * MAX_INT16).astype(np.int16)
85 |     fdir = os.path.dirname(file)
86 |     if fdir and not os.path.exists(fdir):
87 |         os.makedirs(fdir)
88 |     audio_lib.output.write_wav(file, samps_int16, fs)
89 |
90 |
91 | def compute_vad_mask(spectra, threshold_db=40, apply_exp=True):
92 |     # to linear first if needed
93 |     if apply_exp:
94 |         spectra = np.exp(spectra)
95 |     # to dB
96 |     spectra_db = 20 * np.log10(spectra)
97 |     max_magnitude_db = np.max(spectra_db)
98 |     threshold = 10**((max_magnitude_db - threshold_db) / 20)  # back in linear magnitude
99 |     mask = np.array(spectra > threshold, dtype=np.float32)
100 |     return mask
101 |
102 |
103 | def apply_cmvn(feats, cmvn_dict):
104 |     if type(cmvn_dict) != dict:
105 |         raise TypeError("Input must be a python dictionary")
106 |     if 'mean' in cmvn_dict:
107 |         feats = feats - cmvn_dict['mean']
108 |     if 'std' in cmvn_dict:
109 |         feats = feats / cmvn_dict['std']
110 |     return feats
111 |
112 |
113 | def parse_scps(scp_path):
114 |     assert os.path.exists(scp_path)
115 |     scp_dict = dict()
116 |     with open(scp_path, 'r') as f:
117 |         for scp in f:
118 |             scp_tokens = scp.strip().split()
119 |             if len(scp_tokens) != 2:
120 |                 raise RuntimeError(
121 |                     "Wrongly formatted scp line \'{}\'".format(scp))
122 |             key, addr = scp_tokens
123 |             if key in scp_dict:
124 |                 raise ValueError("Duplicate key \'{}\' exists!".format(key))
125 |             scp_dict[key] = addr
126 |     return scp_dict
127 |
128 |
129 | def filekey(path):
130 |     fname = os.path.basename(path)
131 |     if not fname:
132 |         raise ValueError("{} (is it a directory path?)".format(path))
133 |     token = fname.split(".")
134 |     if len(token) == 1:
135 |         return token[0]
136 |     else:
137 |         return '.'.join(token[:-1])
138 |
139 |
140 | def parse_yaml(yaml_conf):
141 |     if not os.path.exists(yaml_conf):
142 |         raise FileNotFoundError(
143 |             "Could not find configuration file...{}".format(yaml_conf))
144 |     with open(yaml_conf, 'r') as f:
145 |         config_dict = yaml.safe_load(f)
146 |
147 |     for key in config_keys:
148 |         if key not in config_dict:
149 |             raise KeyError("Missing {} configs in yaml".format(key))
150 |     batch_size = config_dict["dataloader"]["batch_size"]
151 |     if batch_size <= 0:
152 |         raise ValueError("Invalid batch_size: {}".format(batch_size))
153 |
154 |     num_frames = config_dict["spectrogram_reader"]["frame_length"]
155 |     num_bins = nfft(num_frames) // 2 + 1
156 |     if len(config_dict["train_scp_conf"]) != len(
157 |             config_dict["valid_scp_conf"]):
158 |         raise ValueError("Check the configs in train_scp_conf/valid_scp_conf")
159 |     num_spks = 0
160 |     for key in config_dict["train_scp_conf"]:
161 |         if key[:3] == "spk":
162 |             num_spks += 1
163 |     if num_spks != config_dict["trainer"]["num_spks"]:
164 |         warnings.warn(
165 |             "Number of speakers configured in trainer does not match *_scp_conf,"
166 |             " corrected to {}".format(num_spks))
167 |         config_dict["trainer"]["num_spks"] = num_spks
168 |     return num_bins, config_dict
169 |
--------------------------------------------------------------------------------
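As a quick worked check of the `num_bins` derivation in `parse_yaml` above (assuming `frame_length: 256` from `conf/1.config.yaml`):

```python
from utils import nfft

frame_length = 256
num_bins = nfft(frame_length) // 2 + 1  # 2 ** ceil(log2(256)) = 256 -> 129
assert num_bins == 129
```

This is the same `num_bins` passed to `DCNet` and used by `compute_cmvn.py`, so the STFT dimension, the cmvn statistics and the network input size all stay consistent.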