├── LICENSE ├── README.md ├── data ├── tr_cmvn.npz ├── uniq_target_ref_dur.txt ├── wsj0_2mix │ ├── cv │ │ ├── aux.scp │ │ ├── mix.scp │ │ └── ref.scp │ ├── tr │ │ ├── aux.scp │ │ ├── mix.scp │ │ └── ref.scp │ └── tt │ │ ├── aux.scp │ │ ├── mix.scp │ │ └── ref.scp └── wsj0_2mix_extr_tr.spk ├── decode.sh ├── nnet ├── __pycache__ │ ├── conf.cpython-36.pyc │ ├── conf.cpython-38.pyc │ ├── conv_tas_net.cpython-36.pyc │ ├── conv_tas_net.cpython-38.pyc │ ├── conv_tas_net_decode.cpython-36.pyc │ ├── conv_tas_net_decode.cpython-38.pyc │ └── sincnet.cpython-36.pyc ├── compute_si_snr.py ├── conf.py ├── conv_tas_net.py ├── conv_tas_net_decode.py ├── libs │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── audio.cpython-36.pyc │ │ ├── audio.cpython-38.pyc │ │ ├── dataset.cpython-36.pyc │ │ ├── dataset.cpython-38.pyc │ │ ├── kaldi_io.cpython-36.pyc │ │ ├── kaldi_io.cpython-38.pyc │ │ ├── trainer.cpython-36.pyc │ │ ├── trainer.cpython-38.pyc │ │ ├── utils.cpython-36.pyc │ │ └── utils.cpython-38.pyc │ ├── audio.py │ ├── dataset.py │ ├── kaldi_io.py │ ├── metric.py │ ├── trainer.py │ └── utils.py ├── separate.py └── train.py ├── pretrain_model └── link.txt ├── requirements.txt └── train.sh /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Meng Ge 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SpEx_Plus 2 | -------------------------------------------------------------------------------- /data/tr_cmvn.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/data/tr_cmvn.npz -------------------------------------------------------------------------------- /data/wsj0_2mix_extr_tr.spk: -------------------------------------------------------------------------------- 1 | 011 2 | 012 3 | 013 4 | 014 5 | 015 6 | 016 7 | 017 8 | 018 9 | 019 10 | 01a 11 | 01b 12 | 01c 13 | 01d 14 | 01e 15 | 01f 16 | 01g 17 | 01i 18 | 01j 19 | 01k 20 | 01l 21 | 01m 22 | 01n 23 | 01o 24 | 01p 25 | 01q 26 | 01r 27 | 01s 28 | 01t 29 | 01u 30 | 01v 31 | 01w 32 | 01x 33 | 01y 34 | 01z 35 | 020 36 | 021 37 | 022 38 | 023 39 | 024 40 | 025 41 | 026 42 | 027 43 | 028 44 | 029 45 | 02a 46 | 02b 47 | 02c 48 | 02d 49 | 02e 50 | 204 51 | 205 52 | 206 53 | 207 54 | 208 55 | 209 56 | 20a 57 | 20b 58 | 20c 59 | 20d 60 | 20e 61 | 20f 62 | 20g 63 | 20h 64 | 20i 65 | 20j 66 | 20k 67 | 20l 68 | 20m 69 | 20n 70 | 20o 71 | 20p 72 | 20q 73 | 20r 74 | 20s 75 | 20t 76 | 20u 77 | 20v 78 | 401 79 | 403 80 | 404 81 | 405 82 | 406 83 | 407 84 | 408 85 | 409 86 | 40a 87 | 40b 88 | 40c 89 | 40d 90 | 40e 91 | 40f 92 | 40g 93 | 40h 94 | 40i 95 | 40j 96 | 40k 97 | 40l 98 | 40m 99 | 40n 100 | 40o 101 | 40p 102 | -------------------------------------------------------------------------------- /decode.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eu 4 | 5 | #cpt_dir=exp/conv_tasnet 6 | #epochs=100 7 | # constrainted by GPU number & memory 8 | #batch_size=20 9 | #cache_size=16 10 | 11 | #[ $# -ne 2 ] && echo "Script error: $0 " && exit 1 12 | 13 | #./nnet/train.py --gpu "1,2,3,4,5,7" --epochs $epochs --batch-size $batch_size --checkpoint $cpt_dir/conv-net 14 | ./nnet/separate.py 15 | -------------------------------------------------------------------------------- /nnet/__pycache__/conf.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/__pycache__/conf.cpython-36.pyc -------------------------------------------------------------------------------- /nnet/__pycache__/conf.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/__pycache__/conf.cpython-38.pyc -------------------------------------------------------------------------------- /nnet/__pycache__/conv_tas_net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/__pycache__/conv_tas_net.cpython-36.pyc -------------------------------------------------------------------------------- /nnet/__pycache__/conv_tas_net.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/__pycache__/conv_tas_net.cpython-38.pyc 
-------------------------------------------------------------------------------- /nnet/__pycache__/conv_tas_net_decode.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/__pycache__/conv_tas_net_decode.cpython-36.pyc -------------------------------------------------------------------------------- /nnet/__pycache__/conv_tas_net_decode.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/__pycache__/conv_tas_net_decode.cpython-38.pyc -------------------------------------------------------------------------------- /nnet/__pycache__/sincnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/__pycache__/sincnet.cpython-36.pyc -------------------------------------------------------------------------------- /nnet/compute_si_snr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Compute SI-SDR as the evaluation metric 4 | """ 5 | 6 | import argparse 7 | 8 | from tqdm import tqdm 9 | 10 | from collections import defaultdict 11 | from libs.metric import si_snr, permute_si_snr 12 | from libs.audio import WaveReader, Reader 13 | 14 | 15 | class SpeakersReader(object): 16 | def __init__(self, scps): 17 | split_scps = scps.split(",") 18 | if len(split_scps) == 1: 19 | raise RuntimeError( 20 | "Construct SpeakersReader need more than one script, got {}". 21 | format(scps)) 22 | self.readers = [WaveReader(scp) for scp in split_scps] 23 | 24 | def __len__(self): 25 | first_reader = self.readers[0] 26 | return len(first_reader) 27 | 28 | def __getitem__(self, key): 29 | return [reader[key] for reader in self.readers] 30 | 31 | def __iter__(self): 32 | first_reader = self.readers[0] 33 | for key in first_reader.index_keys: 34 | yield key, self[key] 35 | 36 | 37 | class Report(object): 38 | def __init__(self, spk2gender=None): 39 | self.s2g = Reader(spk2gender) if spk2gender else None 40 | self.snr = defaultdict(float) 41 | self.cnt = defaultdict(int) 42 | 43 | def add(self, key, val): 44 | gender = "NG" 45 | if self.s2g: 46 | gender = self.s2g[key] 47 | self.snr[gender] += val 48 | self.cnt[gender] += 1 49 | 50 | def report(self): 51 | print("SI-SDR(dB) Report: ") 52 | for gender in self.snr: 53 | tot_snrs = self.snr[gender] 54 | num_utts = self.cnt[gender] 55 | print("{}: {:d}/{:.3f}".format(gender, num_utts, 56 | tot_snrs / num_utts)) 57 | 58 | 59 | def run(args): 60 | single_speaker = len(args.sep_scp.split(",")) == 1 61 | reporter = Report(args.spk2gender) 62 | 63 | if single_speaker: 64 | sep_reader = WaveReader(args.sep_scp) 65 | ref_reader = WaveReader(args.ref_scp) 66 | for key, sep in tqdm(sep_reader): 67 | ref = ref_reader[key] 68 | if sep.size != ref.size: 69 | end = min(sep.size, ref.size) 70 | sep = sep[:end] 71 | ref = ref[:end] 72 | snr = si_snr(sep, ref) 73 | reporter.add(key, snr) 74 | else: 75 | sep_reader = SpeakersReader(args.sep_scp) 76 | ref_reader = SpeakersReader(args.ref_scp) 77 | for key, sep_list in tqdm(sep_reader): 78 | ref_list = ref_reader[key] 79 | if sep_list[0].size != ref_list[0].size: 80 | end = min(sep_list[0].size, ref_list[0].size) 81 | sep_list = [s[:end] for s in 
sep_list] 82 | ref_list = [s[:end] for s in ref_list] 83 | snr = permute_si_snr(sep_list, ref_list) 84 | reporter.add(key, snr) 85 | reporter.report() 86 | 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser( 90 | description= 91 | "Command to compute SI-SDR, as metric of the separation quality", 92 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 93 | parser.add_argument( 94 | "sep_scp", 95 | type=str, 96 | help="Separated speech scripts, waiting for measure" 97 | "(support multi-speaker, egs: spk1.scp,spk2.scp)") 98 | parser.add_argument( 99 | "ref_scp", 100 | type=str, 101 | help="Reference speech scripts, as ground truth for" 102 | " SI-SDR computation") 103 | parser.add_argument( 104 | "--spk2gender", 105 | type=str, 106 | default="", 107 | help="If assigned, report results per gender") 108 | args = parser.parse_args() 109 | run(args) 110 | -------------------------------------------------------------------------------- /nnet/conf.py: -------------------------------------------------------------------------------- 1 | fs = 8000 2 | chunk_len = 4 # (s) 3 | chunk_size = chunk_len * fs 4 | num_spks = 1 5 | 6 | # network configure 7 | nnet_conf = { 8 | "L": 20, 9 | "N": 256, 10 | "X": 8, 11 | "R": 4, 12 | "B": 256, 13 | "H": 512, 14 | "P": 3, 15 | "norm": "gLN", 16 | "num_spks": num_spks, 17 | "non_linear": "relu" 18 | } 19 | 20 | # data configure: 21 | train_dir = "data/wsj0_2mix/tr/" 22 | dev_dir = "data/wsj0_2mix/cv/" 23 | spk_list = "data/wsj0_2mix_extr_tr.spk" 24 | 25 | train_data = { 26 | "mix_scp": 27 | train_dir + "mix.scp", 28 | "ref_scp": 29 | train_dir + "ref.scp", 30 | "aux_scp": 31 | train_dir + "aux.scp", 32 | "spk_list": spk_list, 33 | "sample_rate": 34 | fs, 35 | } 36 | 37 | dev_data = { 38 | "mix_scp": 39 | dev_dir + "mix.scp", 40 | "ref_scp": 41 | dev_dir + "ref.scp", 42 | "aux_scp": 43 | dev_dir + "aux.scp", 44 | "spk_list": spk_list, 45 | "sample_rate": fs, 46 | } 47 | 48 | # trainer config 49 | adam_kwargs = { 50 | "lr": 1e-3, 51 | "weight_decay": 1e-5, 52 | } 53 | 54 | trainer_conf = { 55 | "optimizer": "adam", 56 | "optimizer_kwargs": adam_kwargs, 57 | "min_lr": 1e-8, 58 | "patience": 2, 59 | "factor": 0.5, 60 | "logging_period": 200 # batch number 61 | } 62 | -------------------------------------------------------------------------------- /nnet/conv_tas_net.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def param(nnet, Mb=True): 7 | """ 8 | Return number parameters(not bytes) in nnet 9 | """ 10 | neles = sum([param.nelement() for param in nnet.parameters()]) 11 | return neles / 10**6 if Mb else neles 12 | 13 | 14 | class ChannelWiseLayerNorm(nn.LayerNorm): 15 | """ 16 | Channel wise layer normalization 17 | """ 18 | 19 | def __init__(self, *args, **kwargs): 20 | super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs) 21 | 22 | def forward(self, x): 23 | """ 24 | x: N x C x T 25 | """ 26 | if x.dim() != 3: 27 | raise RuntimeError("{} accept 3D tensor as input".format( 28 | self.__name__)) 29 | # N x C x T => N x T x C 30 | x = th.transpose(x, 1, 2) 31 | # LN 32 | x = super().forward(x) 33 | # N x C x T => N x T x C 34 | x = th.transpose(x, 1, 2) 35 | return x 36 | 37 | 38 | class GlobalChannelLayerNorm(nn.Module): 39 | """ 40 | Global channel layer normalization 41 | """ 42 | 43 | def __init__(self, dim, eps=1e-05, elementwise_affine=True): 44 | super(GlobalChannelLayerNorm, self).__init__() 45 | 
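        # gLN ("global layer norm") computes its statistics over both the channel
        # and time axes of an N x C x T input: in forward() below,
        #     y = gamma * (x - E[x]) / sqrt(Var[x] + eps) + beta,
        # with E[.] and Var[.] taken over dims (1, 2). ChannelWiseLayerNorm above
        # instead normalizes each time frame over the channel axis only.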
self.eps = eps 46 | self.normalized_dim = dim 47 | self.elementwise_affine = elementwise_affine 48 | if elementwise_affine: 49 | self.beta = nn.Parameter(th.zeros(dim, 1)) 50 | self.gamma = nn.Parameter(th.ones(dim, 1)) 51 | else: 52 | self.register_parameter("weight", None) 53 | self.register_parameter("bias", None) 54 | 55 | def forward(self, x): 56 | """ 57 | x: N x C x T 58 | """ 59 | if x.dim() != 3: 60 | raise RuntimeError("{} accept 3D tensor as input".format( 61 | self.__name__)) 62 | # N x 1 x 1 63 | mean = th.mean(x, (1, 2), keepdim=True) 64 | var = th.mean((x - mean)**2, (1, 2), keepdim=True) 65 | # N x T x C 66 | if self.elementwise_affine: 67 | x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta 68 | else: 69 | x = (x - mean) / th.sqrt(var + self.eps) 70 | return x 71 | 72 | def extra_repr(self): 73 | return "{normalized_dim}, eps={eps}, " \ 74 | "elementwise_affine={elementwise_affine}".format(**self.__dict__) 75 | 76 | 77 | def build_norm(norm, dim): 78 | """ 79 | Build normalize layer 80 | LN cost more memory than BN 81 | """ 82 | if norm not in ["cLN", "gLN", "BN"]: 83 | raise RuntimeError("Unsupported normalize layer: {}".format(norm)) 84 | if norm == "cLN": 85 | return ChannelWiseLayerNorm(dim, elementwise_affine=True) 86 | elif norm == "BN": 87 | return nn.BatchNorm1d(dim) 88 | else: 89 | return GlobalChannelLayerNorm(dim, elementwise_affine=True) 90 | 91 | 92 | class Conv1D(nn.Conv1d): 93 | """ 94 | 1D conv in ConvTasNet 95 | """ 96 | 97 | def __init__(self, *args, **kwargs): 98 | super(Conv1D, self).__init__(*args, **kwargs) 99 | 100 | def forward(self, x, squeeze=False): 101 | """ 102 | x: N x L or N x C x L 103 | """ 104 | if x.dim() not in [2, 3]: 105 | raise RuntimeError("{} accept 2/3D tensor as input".format( 106 | self.__name__)) 107 | x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1)) 108 | if squeeze: 109 | x = th.squeeze(x) 110 | return x 111 | 112 | 113 | class ConvTrans1D(nn.ConvTranspose1d): 114 | """ 115 | 1D conv transpose in ConvTasNet 116 | """ 117 | 118 | def __init__(self, *args, **kwargs): 119 | super(ConvTrans1D, self).__init__(*args, **kwargs) 120 | 121 | def forward(self, x, squeeze=False): 122 | """ 123 | x: N x L or N x C x L 124 | """ 125 | if x.dim() not in [2, 3]: 126 | raise RuntimeError("{} accept 2/3D tensor as input".format( 127 | self.__name__)) 128 | x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1)) 129 | if squeeze: 130 | x = th.squeeze(x) 131 | return x 132 | 133 | 134 | class Conv1DBlock(nn.Module): 135 | """ 136 | 1D convolutional block: 137 | Conv1x1 - PReLU - Norm - DConv - PReLU - Norm - SConv 138 | """ 139 | 140 | def __init__(self, 141 | in_channels=256, 142 | conv_channels=512, 143 | kernel_size=3, 144 | dilation=1, 145 | norm="cLN", 146 | causal=False): 147 | super(Conv1DBlock, self).__init__() 148 | # 1x1 conv 149 | self.conv1x1 = Conv1D(in_channels, conv_channels, 1) 150 | self.prelu1 = nn.PReLU() 151 | self.lnorm1 = build_norm(norm, conv_channels) 152 | dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else ( 153 | dilation * (kernel_size - 1)) 154 | # depthwise conv 155 | self.dconv = nn.Conv1d( 156 | conv_channels, 157 | conv_channels, 158 | kernel_size, 159 | groups=conv_channels, 160 | padding=dconv_pad, 161 | dilation=dilation, 162 | bias=True) 163 | self.prelu2 = nn.PReLU() 164 | self.lnorm2 = build_norm(norm, conv_channels) 165 | # 1x1 conv cross channel 166 | self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True) 167 | # different padding way 168 | 
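        # Non-causal blocks use symmetric padding of dilation*(kernel_size-1)//2,
        # so the depthwise conv keeps the frame count T unchanged; causal blocks
        # pad by dilation*(kernel_size-1) and forward() trims that many frames from
        # the right, so each output frame depends only on current and past input.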
self.causal = causal 169 | self.dconv_pad = dconv_pad 170 | 171 | def forward(self, x): 172 | y = self.conv1x1(x) 173 | y = self.lnorm1(self.prelu1(y)) 174 | y = self.dconv(y) 175 | if self.causal: 176 | y = y[:, :, :-self.dconv_pad] 177 | y = self.lnorm2(self.prelu2(y)) 178 | y = self.sconv(y) 179 | x = x + y 180 | return x 181 | 182 | class Conv1DBlock_v2(nn.Module): 183 | """ 184 | 1D convolutional block: 185 | Conv1x1 - PReLU - Norm - DConv - PReLU - Norm - SConv 186 | """ 187 | 188 | def __init__(self, 189 | in_channels=256, 190 | spk_embed_dim=100, 191 | conv_channels=512, 192 | kernel_size=3, 193 | dilation=1, 194 | norm="cLN", 195 | causal=False): 196 | super(Conv1DBlock_v2, self).__init__() 197 | # 1x1 conv 198 | self.conv1x1 = Conv1D(in_channels+spk_embed_dim, conv_channels, 1) 199 | self.prelu1 = nn.PReLU() 200 | self.lnorm1 = build_norm(norm, conv_channels) 201 | dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else ( 202 | dilation * (kernel_size - 1)) 203 | # depthwise conv 204 | self.dconv = nn.Conv1d( 205 | conv_channels, 206 | conv_channels, 207 | kernel_size, 208 | groups=conv_channels, 209 | padding=dconv_pad, 210 | dilation=dilation, 211 | bias=True) 212 | self.prelu2 = nn.PReLU() 213 | self.lnorm2 = build_norm(norm, conv_channels) 214 | # 1x1 conv cross channel 215 | self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True) 216 | # different padding way 217 | self.causal = causal 218 | self.dconv_pad = dconv_pad 219 | 220 | def forward(self, x, aux): 221 | #print(x.shape) 222 | T = x.shape[-1] 223 | #print(aux.shape) 224 | aux = th.unsqueeze(aux, -1) 225 | #print(aux.shape) 226 | aux = aux.repeat(1,1,T) 227 | y = th.cat([x, aux], 1) 228 | y = self.conv1x1(y) 229 | y = self.lnorm1(self.prelu1(y)) 230 | y = self.dconv(y) 231 | if self.causal: 232 | y = y[:, :, :-self.dconv_pad] 233 | y = self.lnorm2(self.prelu2(y)) 234 | y = self.sconv(y) 235 | x = x + y 236 | return x 237 | 238 | class ResBlock(nn.Module): 239 | """ 240 | ref to 241 | https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py 242 | and 243 | https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py 244 | """ 245 | def __init__(self, in_dims, out_dims): 246 | super().__init__() 247 | self.conv1 = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False) 248 | self.conv2 = nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False) 249 | self.batch_norm1 = nn.BatchNorm1d(out_dims) 250 | self.batch_norm2 = nn.BatchNorm1d(out_dims) 251 | self.prelu1 = nn.PReLU() 252 | self.prelu2 = nn.PReLU() 253 | self.mp = nn.MaxPool1d(3) 254 | if in_dims != out_dims: 255 | self.downsample = True 256 | self.conv_downsample = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False) 257 | else: 258 | self.downsample = False 259 | 260 | def forward(self, x): 261 | residual = x 262 | x = self.conv1(x) 263 | x = self.batch_norm1(x) 264 | x = self.prelu1(x) 265 | x = self.conv2(x) 266 | x = self.batch_norm2(x) 267 | if self.downsample: 268 | residual = self.conv_downsample(residual) 269 | x = x + residual 270 | x = self.prelu2(x) 271 | return self.mp(x) 272 | 273 | class ConvTasNet(nn.Module): 274 | def __init__(self, 275 | L=20, 276 | N=256, 277 | X=8, 278 | R=4, 279 | B=256, 280 | H=512, 281 | P=3, 282 | norm="cLN", 283 | num_spks=1, 284 | non_linear="relu", 285 | causal=False): 286 | super(ConvTasNet, self).__init__() 287 | supported_nonlinear = { 288 | "relu": F.relu, 289 | "sigmoid": th.sigmoid, 290 | "softmax": F.softmax 291 | } 292 | if non_linear not in supported_nonlinear: 293 
| raise RuntimeError("Unsupported non-linear function: {}", 294 | format(non_linear)) 295 | self.non_linear_type = non_linear 296 | self.non_linear = supported_nonlinear[non_linear] 297 | 298 | # Multi-scale Encoder 299 | # n x S => n x N x T, S = 4s*8000 = 32000 300 | self.L1 = L 301 | self.L2 = 80 302 | self.L3 = 160 303 | self.encoder_1d_short = Conv1D(1, N, L, stride=L // 2, padding=0) 304 | self.encoder_1d_middle = Conv1D(1, N, 80, stride=L // 2, padding=0) 305 | self.encoder_1d_long = Conv1D(1, N, 160, stride=L // 2, padding=0) 306 | # keep T not change 307 | # T = int((xlen - L) / (L // 2)) + 1 308 | # before repeat blocks, always cLN 309 | self.ln = ChannelWiseLayerNorm(3*N) 310 | # n x N x T => n x B x T 311 | self.proj = Conv1D(3*N, B, 1) 312 | 313 | # Repeat Conv Blocks 314 | # n x B x T => n x B x T 315 | self.conv_block_1 = Conv1DBlock_v2(spk_embed_dim=256, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal, dilation=1) 316 | self.conv_block_1_other = self._build_blocks(num_blocks=X, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal) 317 | self.conv_block_2 = Conv1DBlock_v2(spk_embed_dim=256, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal, dilation=1) 318 | self.conv_block_2_other = self._build_blocks(num_blocks=X, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal) 319 | self.conv_block_3 = Conv1DBlock_v2(spk_embed_dim=256, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal, dilation=1) 320 | self.conv_block_3_other = self._build_blocks(num_blocks=X, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal) 321 | self.conv_block_4 = Conv1DBlock_v2(spk_embed_dim=256, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal, dilation=1) 322 | self.conv_block_4_other = self._build_blocks(num_blocks=X, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal) 323 | 324 | # Multi-scale Decoder 325 | # output 1x1 conv 326 | # n x B x T => n x N x T 327 | # NOTE: using ModuleList not python list 328 | # self.conv1x1_2 = th.nn.ModuleList( 329 | # [Conv1D(B, N, 1) for _ in range(num_spks)]) 330 | # n x B x T => n x 2N x T 331 | self.mask1 = Conv1D(B, N, 1) 332 | self.mask2 = Conv1D(B, N, 1) 333 | self.mask3 = Conv1D(B, N, 1) 334 | 335 | # using ConvTrans1D: n x N x T => n x 1 x To 336 | # To = (T - 1) * L // 2 + L 337 | self.decoder_1d_1 = ConvTrans1D(N, 1, kernel_size=L, stride=L // 2, bias=True) 338 | self.decoder_1d_2 = ConvTrans1D(N, 1, kernel_size=80, stride=L // 2, bias=True) 339 | self.decoder_1d_3 = ConvTrans1D(N, 1, kernel_size=160, stride=L // 2, bias=True) 340 | #self.num_spks = num_spks 341 | 342 | # Speaker Encoder 343 | self.aux_enc3 = nn.Sequential( 344 | ChannelWiseLayerNorm(3*256), 345 | Conv1D(3*256, 256, 1), 346 | ResBlock(256, 256), 347 | ResBlock(256, 512), 348 | ResBlock(512, 512), 349 | Conv1D(512, 256, 1), 350 | ) 351 | self.pred_linear = nn.Linear(256,101) 352 | 353 | def flatten_parameters(self): 354 | self.lstm.flatten_parameters() 355 | 356 | def _build_blocks(self, num_blocks, **block_kwargs): 357 | """ 358 | Build Conv1D block 359 | """ 360 | blocks = [ 361 | Conv1DBlock(**block_kwargs, dilation=(2**b)) 362 | for b in range(1,num_blocks) 363 | ] 364 | return nn.Sequential(*blocks) 365 | 366 | def _build_repeats(self, num_repeats, num_blocks, **block_kwargs): 367 | """ 368 | Build Conv1D block repeats 369 | """ 370 | repeats = [ 371 | self._build_blocks(num_blocks, **block_kwargs) 372 | for r in 
range(num_repeats) 373 | ] 374 | return nn.Sequential(*repeats) 375 | 376 | def forward(self, x, aux, aux_len): 377 | if x.dim() >= 3: 378 | raise RuntimeError( 379 | "{} accept 1/2D tensor as input, but got {:d}".format( 380 | self.__name__, x.dim())) 381 | # when inference, only one utt 382 | if x.dim() == 1: 383 | x = th.unsqueeze(x, 0) 384 | 385 | # Multi-scale Encoder (Mixture audio input) 386 | w1 = F.relu(self.encoder_1d_short(x)) 387 | T = w1.shape[-1] 388 | xlen1 = x.shape[-1] 389 | xlen2 = (T - 1) * (self.L1 // 2) + self.L2 390 | xlen3 = (T - 1) * (self.L1 // 2) + self.L3 391 | w2 = F.relu(self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), "constant", 0))) 392 | w3 = F.relu(self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), "constant", 0))) 393 | # n x 3N x T 394 | y = self.ln(th.cat([w1, w2, w3], 1)) 395 | # n x B x T 396 | y = self.proj(y) 397 | 398 | # Multi-scale Encoder (Reference audio input) 399 | aux_w1 = F.relu(self.encoder_1d_short(aux)) 400 | aux_T_shape = aux_w1.shape[-1] 401 | aux_len1 = aux.shape[-1] 402 | aux_len2 = (aux_T_shape - 1) * (self.L1 // 2) + self.L2 403 | aux_len3 = (aux_T_shape - 1) * (self.L1 // 2) + self.L3 404 | aux_w2 = F.relu(self.encoder_1d_middle(F.pad(aux, (0, aux_len2 - aux_len1), "constant", 0))) 405 | aux_w3 = F.relu(self.encoder_1d_long(F.pad(aux, (0, aux_len3 - aux_len1), "constant", 0))) 406 | 407 | # Speaker Encoder 408 | aux = self.aux_enc3(th.cat([aux_w1, aux_w2, aux_w3], 1)) 409 | aux_T = (aux_len - self.L1) // (self.L1 // 2) + 1 410 | aux_T = ((aux_T // 3) // 3) // 3 411 | aux = th.sum(aux, -1)/aux_T.view(-1,1).float() 412 | 413 | # Speaker Extractor 414 | y = self.conv_block_1(y, aux) 415 | y = self.conv_block_1_other(y) 416 | y = self.conv_block_2(y, aux) 417 | y = self.conv_block_2_other(y) 418 | y = self.conv_block_3(y, aux) 419 | y = self.conv_block_3_other(y) 420 | y = self.conv_block_4(y, aux) 421 | y = self.conv_block_4_other(y) 422 | 423 | # Multi-scale Decoder 424 | m1 = self.non_linear(self.mask1(y)) 425 | m2 = self.non_linear(self.mask2(y)) 426 | m3 = self.non_linear(self.mask3(y)) 427 | s1 = w1 * m1 428 | s2 = w2 * m2 429 | s3 = w3 * m3 430 | 431 | return self.decoder_1d_1(s1, squeeze=True), self.decoder_1d_2(s2, squeeze=True)[:, :xlen1], self.decoder_1d_3(s3, squeeze=True)[:, :xlen1], self.pred_linear(aux) 432 | #return self.decoder_1d_1(s1, squeeze=True).unsqueeze(0), self.decoder_1d_2(s2, squeeze=True).unsqueeze(0)[:, :xlen1], self.decoder_1d_3(s3, squeeze=True).unsqueeze(0)[:, :xlen1], self.pred_linear(aux) 433 | 434 | def foo_conv1d_block(): 435 | nnet = Conv1DBlock(256, 512, 3, 20) 436 | print(param(nnet)) 437 | 438 | def foo_layernorm(): 439 | C, T = 256, 20 440 | nnet1 = nn.LayerNorm([C, T], elementwise_affine=True) 441 | print(param(nnet1, Mb=False)) 442 | nnet2 = nn.LayerNorm([C, T], elementwise_affine=False) 443 | print(param(nnet2, Mb=False)) 444 | 445 | def foo_conv_tas_net(): 446 | x = th.rand(4, 1000) 447 | nnet = ConvTasNet(norm="cLN", causal=False) 448 | print("ConvTasNet #param: {:.2f}".format(param(nnet))) 449 | x = nnet(x) 450 | s1 = x[0] 451 | print(s1.shape) 452 | 453 | if __name__ == "__main__": 454 | foo_conv_tas_net() 455 | -------------------------------------------------------------------------------- /nnet/conv_tas_net_decode.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def param(nnet, Mb=True): 7 | """ 8 | Return number parameters(not bytes) in nnet 9 | """ 
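    # nelement() returns the entry count of each parameter tensor; summing over
    # nnet.parameters() and dividing by 1e6 reports the size in millions of
    # parameters when Mb=True (despite the name, this is not megabytes).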
10 | neles = sum([param.nelement() for param in nnet.parameters()]) 11 | return neles / 10**6 if Mb else neles 12 | 13 | 14 | class ChannelWiseLayerNorm(nn.LayerNorm): 15 | """ 16 | Channel wise layer normalization 17 | """ 18 | 19 | def __init__(self, *args, **kwargs): 20 | super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs) 21 | 22 | def forward(self, x): 23 | """ 24 | x: N x C x T 25 | """ 26 | if x.dim() != 3: 27 | raise RuntimeError("{} accept 3D tensor as input".format( 28 | self.__name__)) 29 | # N x C x T => N x T x C 30 | x = th.transpose(x, 1, 2) 31 | # LN 32 | x = super().forward(x) 33 | # N x C x T => N x T x C 34 | x = th.transpose(x, 1, 2) 35 | return x 36 | 37 | 38 | class GlobalChannelLayerNorm(nn.Module): 39 | """ 40 | Global channel layer normalization 41 | """ 42 | 43 | def __init__(self, dim, eps=1e-05, elementwise_affine=True): 44 | super(GlobalChannelLayerNorm, self).__init__() 45 | self.eps = eps 46 | self.normalized_dim = dim 47 | self.elementwise_affine = elementwise_affine 48 | if elementwise_affine: 49 | self.beta = nn.Parameter(th.zeros(dim, 1)) 50 | self.gamma = nn.Parameter(th.ones(dim, 1)) 51 | else: 52 | self.register_parameter("weight", None) 53 | self.register_parameter("bias", None) 54 | 55 | def forward(self, x): 56 | """ 57 | x: N x C x T 58 | """ 59 | if x.dim() != 3: 60 | raise RuntimeError("{} accept 3D tensor as input".format( 61 | self.__name__)) 62 | # N x 1 x 1 63 | mean = th.mean(x, (1, 2), keepdim=True) 64 | var = th.mean((x - mean)**2, (1, 2), keepdim=True) 65 | # N x T x C 66 | if self.elementwise_affine: 67 | x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta 68 | else: 69 | x = (x - mean) / th.sqrt(var + self.eps) 70 | return x 71 | 72 | def extra_repr(self): 73 | return "{normalized_dim}, eps={eps}, " \ 74 | "elementwise_affine={elementwise_affine}".format(**self.__dict__) 75 | 76 | 77 | def build_norm(norm, dim): 78 | """ 79 | Build normalize layer 80 | LN cost more memory than BN 81 | """ 82 | if norm not in ["cLN", "gLN", "BN"]: 83 | raise RuntimeError("Unsupported normalize layer: {}".format(norm)) 84 | if norm == "cLN": 85 | return ChannelWiseLayerNorm(dim, elementwise_affine=True) 86 | elif norm == "BN": 87 | return nn.BatchNorm1d(dim) 88 | else: 89 | return GlobalChannelLayerNorm(dim, elementwise_affine=True) 90 | 91 | 92 | class Conv1D(nn.Conv1d): 93 | """ 94 | 1D conv in ConvTasNet 95 | """ 96 | 97 | def __init__(self, *args, **kwargs): 98 | super(Conv1D, self).__init__(*args, **kwargs) 99 | 100 | def forward(self, x, squeeze=False): 101 | """ 102 | x: N x L or N x C x L 103 | """ 104 | if x.dim() not in [2, 3]: 105 | raise RuntimeError("{} accept 2/3D tensor as input".format( 106 | self.__name__)) 107 | x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1)) 108 | if squeeze: 109 | x = th.squeeze(x) 110 | return x 111 | 112 | 113 | class ConvTrans1D(nn.ConvTranspose1d): 114 | """ 115 | 1D conv transpose in ConvTasNet 116 | """ 117 | 118 | def __init__(self, *args, **kwargs): 119 | super(ConvTrans1D, self).__init__(*args, **kwargs) 120 | 121 | def forward(self, x, squeeze=False): 122 | """ 123 | x: N x L or N x C x L 124 | """ 125 | if x.dim() not in [2, 3]: 126 | raise RuntimeError("{} accept 2/3D tensor as input".format( 127 | self.__name__)) 128 | x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1)) 129 | if squeeze: 130 | x = th.squeeze(x) 131 | return x 132 | 133 | 134 | class Conv1DBlock(nn.Module): 135 | """ 136 | 1D convolutional block: 137 | Conv1x1 - PReLU - Norm - DConv - PReLU - 
Norm - SConv 138 | """ 139 | 140 | def __init__(self, 141 | in_channels=256, 142 | conv_channels=512, 143 | kernel_size=3, 144 | dilation=1, 145 | norm="cLN", 146 | causal=False): 147 | super(Conv1DBlock, self).__init__() 148 | # 1x1 conv 149 | self.conv1x1 = Conv1D(in_channels, conv_channels, 1) 150 | self.prelu1 = nn.PReLU() 151 | self.lnorm1 = build_norm(norm, conv_channels) 152 | dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else ( 153 | dilation * (kernel_size - 1)) 154 | # depthwise conv 155 | self.dconv = nn.Conv1d( 156 | conv_channels, 157 | conv_channels, 158 | kernel_size, 159 | groups=conv_channels, 160 | padding=dconv_pad, 161 | dilation=dilation, 162 | bias=True) 163 | self.prelu2 = nn.PReLU() 164 | self.lnorm2 = build_norm(norm, conv_channels) 165 | # 1x1 conv cross channel 166 | self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True) 167 | # different padding way 168 | self.causal = causal 169 | self.dconv_pad = dconv_pad 170 | 171 | def forward(self, x): 172 | y = self.conv1x1(x) 173 | y = self.lnorm1(self.prelu1(y)) 174 | y = self.dconv(y) 175 | if self.causal: 176 | y = y[:, :, :-self.dconv_pad] 177 | y = self.lnorm2(self.prelu2(y)) 178 | y = self.sconv(y) 179 | x = x + y 180 | return x 181 | 182 | class Conv1DBlock_v2(nn.Module): 183 | """ 184 | 1D convolutional block: 185 | Conv1x1 - PReLU - Norm - DConv - PReLU - Norm - SConv 186 | """ 187 | 188 | def __init__(self, 189 | in_channels=256, 190 | spk_embed_dim=100, 191 | conv_channels=512, 192 | kernel_size=3, 193 | dilation=1, 194 | norm="cLN", 195 | causal=False): 196 | super(Conv1DBlock_v2, self).__init__() 197 | # 1x1 conv 198 | self.conv1x1 = Conv1D(in_channels+spk_embed_dim, conv_channels, 1) 199 | self.prelu1 = nn.PReLU() 200 | self.lnorm1 = build_norm(norm, conv_channels) 201 | dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else ( 202 | dilation * (kernel_size - 1)) 203 | # depthwise conv 204 | self.dconv = nn.Conv1d( 205 | conv_channels, 206 | conv_channels, 207 | kernel_size, 208 | groups=conv_channels, 209 | padding=dconv_pad, 210 | dilation=dilation, 211 | bias=True) 212 | self.prelu2 = nn.PReLU() 213 | self.lnorm2 = build_norm(norm, conv_channels) 214 | # 1x1 conv cross channel 215 | self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True) 216 | # different padding way 217 | self.causal = causal 218 | self.dconv_pad = dconv_pad 219 | 220 | def forward(self, x, aux): 221 | #print(x.shape) 222 | T = x.shape[-1] 223 | #print(aux.shape) 224 | aux = th.unsqueeze(aux, -1) 225 | #print(aux.shape) 226 | aux = aux.repeat(1,1,T) 227 | y = th.cat([x, aux], 1) 228 | y = self.conv1x1(y) 229 | y = self.lnorm1(self.prelu1(y)) 230 | y = self.dconv(y) 231 | if self.causal: 232 | y = y[:, :, :-self.dconv_pad] 233 | y = self.lnorm2(self.prelu2(y)) 234 | y = self.sconv(y) 235 | x = x + y 236 | return x 237 | 238 | class ResBlock(nn.Module): 239 | """ 240 | ref to 241 | https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py 242 | and 243 | https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py 244 | """ 245 | def __init__(self, in_dims, out_dims): 246 | super().__init__() 247 | self.conv1 = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False) 248 | self.conv2 = nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False) 249 | self.batch_norm1 = nn.BatchNorm1d(out_dims) 250 | self.batch_norm2 = nn.BatchNorm1d(out_dims) 251 | self.prelu1 = nn.PReLU() 252 | self.prelu2 = nn.PReLU() 253 | self.mp = nn.MaxPool1d(3) 254 | if in_dims != out_dims: 255 | 
self.downsample = True 256 | self.conv_downsample = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False) 257 | else: 258 | self.downsample = False 259 | 260 | def forward(self, x): 261 | residual = x 262 | x = self.conv1(x) 263 | x = self.batch_norm1(x) 264 | x = self.prelu1(x) 265 | x = self.conv2(x) 266 | x = self.batch_norm2(x) 267 | if self.downsample: 268 | residual = self.conv_downsample(residual) 269 | x = x + residual 270 | x = self.prelu2(x) 271 | return self.mp(x) 272 | 273 | class ConvTasNet(nn.Module): 274 | def __init__(self, 275 | L=20, 276 | N=256, 277 | X=8, 278 | R=4, 279 | B=256, 280 | H=512, 281 | P=3, 282 | norm="cLN", 283 | num_spks=1, 284 | non_linear="relu", 285 | causal=False): 286 | super(ConvTasNet, self).__init__() 287 | supported_nonlinear = { 288 | "relu": F.relu, 289 | "sigmoid": th.sigmoid, 290 | "softmax": F.softmax 291 | } 292 | if non_linear not in supported_nonlinear: 293 | raise RuntimeError("Unsupported non-linear function: {}", 294 | format(non_linear)) 295 | self.non_linear_type = non_linear 296 | self.non_linear = supported_nonlinear[non_linear] 297 | 298 | # Multi-scale Encoder 299 | # n x S => n x N x T, S = 4s*8000 = 32000 300 | self.L1 = L 301 | self.L2 = 80 302 | self.L3 = 160 303 | self.encoder_1d_short = Conv1D(1, N, L, stride=L // 2, padding=0) 304 | self.encoder_1d_middle = Conv1D(1, N, 80, stride=L // 2, padding=0) 305 | self.encoder_1d_long = Conv1D(1, N, 160, stride=L // 2, padding=0) 306 | # keep T not change 307 | # T = int((xlen - L) / (L // 2)) + 1 308 | # before repeat blocks, always cLN 309 | self.ln = ChannelWiseLayerNorm(3*N) 310 | # n x N x T => n x B x T 311 | self.proj = Conv1D(3*N, B, 1) 312 | 313 | # Repeat Conv Blocks 314 | # n x B x T => n x B x T 315 | self.conv_block_1 = Conv1DBlock_v2(spk_embed_dim=256, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal, dilation=1) 316 | self.conv_block_1_other = self._build_blocks(num_blocks=X, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal) 317 | self.conv_block_2 = Conv1DBlock_v2(spk_embed_dim=256, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal, dilation=1) 318 | self.conv_block_2_other = self._build_blocks(num_blocks=X, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal) 319 | self.conv_block_3 = Conv1DBlock_v2(spk_embed_dim=256, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal, dilation=1) 320 | self.conv_block_3_other = self._build_blocks(num_blocks=X, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal) 321 | self.conv_block_4 = Conv1DBlock_v2(spk_embed_dim=256, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal, dilation=1) 322 | self.conv_block_4_other = self._build_blocks(num_blocks=X, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal) 323 | 324 | # Multi-scale Decoder 325 | # output 1x1 conv 326 | # n x B x T => n x N x T 327 | # NOTE: using ModuleList not python list 328 | # self.conv1x1_2 = th.nn.ModuleList( 329 | # [Conv1D(B, N, 1) for _ in range(num_spks)]) 330 | # n x B x T => n x 2N x T 331 | self.mask1 = Conv1D(B, N, 1) 332 | self.mask2 = Conv1D(B, N, 1) 333 | self.mask3 = Conv1D(B, N, 1) 334 | 335 | # using ConvTrans1D: n x N x T => n x 1 x To 336 | # To = (T - 1) * L // 2 + L 337 | self.decoder_1d_1 = ConvTrans1D(N, 1, kernel_size=L, stride=L // 2, bias=True) 338 | self.decoder_1d_2 = ConvTrans1D(N, 1, kernel_size=80, stride=L // 2, bias=True) 339 | self.decoder_1d_3 = ConvTrans1D(N, 1, 
kernel_size=160, stride=L // 2, bias=True) 340 | #self.num_spks = num_spks 341 | 342 | # Speaker Encoder 343 | self.aux_enc3 = nn.Sequential( 344 | ChannelWiseLayerNorm(3*256), 345 | Conv1D(3*256, 256, 1), 346 | ResBlock(256, 256), 347 | ResBlock(256, 512), 348 | ResBlock(512, 512), 349 | Conv1D(512, 256, 1), 350 | ) 351 | self.pred_linear = nn.Linear(256,101) 352 | 353 | def flatten_parameters(self): 354 | self.lstm.flatten_parameters() 355 | 356 | def _build_blocks(self, num_blocks, **block_kwargs): 357 | """ 358 | Build Conv1D block 359 | """ 360 | blocks = [ 361 | Conv1DBlock(**block_kwargs, dilation=(2**b)) 362 | for b in range(1,num_blocks) 363 | ] 364 | return nn.Sequential(*blocks) 365 | 366 | def _build_repeats(self, num_repeats, num_blocks, **block_kwargs): 367 | """ 368 | Build Conv1D block repeats 369 | """ 370 | repeats = [ 371 | self._build_blocks(num_blocks, **block_kwargs) 372 | for r in range(num_repeats) 373 | ] 374 | return nn.Sequential(*repeats) 375 | 376 | def forward(self, x, aux, aux_len): 377 | if x.dim() >= 3: 378 | raise RuntimeError( 379 | "{} accept 1/2D tensor as input, but got {:d}".format( 380 | self.__name__, x.dim())) 381 | # when inference, only one utt 382 | if x.dim() == 1: 383 | x = th.unsqueeze(x, 0) 384 | 385 | # Multi-scale Encoder (Mixture audio input) 386 | w1 = F.relu(self.encoder_1d_short(x)) 387 | T = w1.shape[-1] 388 | xlen1 = x.shape[-1] 389 | xlen2 = (T - 1) * (self.L1 // 2) + self.L2 390 | xlen3 = (T - 1) * (self.L1 // 2) + self.L3 391 | w2 = F.relu(self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), "constant", 0))) 392 | w3 = F.relu(self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), "constant", 0))) 393 | # n x 3N x T 394 | y = self.ln(th.cat([w1, w2, w3], 1)) 395 | # n x B x T 396 | y = self.proj(y) 397 | 398 | # Multi-scale Encoder (Reference audio input) 399 | aux_w1 = F.relu(self.encoder_1d_short(aux)) 400 | aux_T_shape = aux_w1.shape[-1] 401 | aux_len1 = aux.shape[-1] 402 | aux_len2 = (aux_T_shape - 1) * (self.L1 // 2) + self.L2 403 | aux_len3 = (aux_T_shape - 1) * (self.L1 // 2) + self.L3 404 | aux_w2 = F.relu(self.encoder_1d_middle(F.pad(aux, (0, aux_len2 - aux_len1), "constant", 0))) 405 | aux_w3 = F.relu(self.encoder_1d_long(F.pad(aux, (0, aux_len3 - aux_len1), "constant", 0))) 406 | 407 | # Speaker Encoder 408 | aux = self.aux_enc3(th.cat([aux_w1, aux_w2, aux_w3], 1)) 409 | aux_T = (aux_len - self.L1) // (self.L1 // 2) + 1 410 | aux_T = ((aux_T // 3) // 3) // 3 411 | aux = th.sum(aux, -1)/aux_T.view(-1,1).float() 412 | 413 | # Speaker Extractor 414 | y = self.conv_block_1(y, aux) 415 | y = self.conv_block_1_other(y) 416 | y = self.conv_block_2(y, aux) 417 | y = self.conv_block_2_other(y) 418 | y = self.conv_block_3(y, aux) 419 | y = self.conv_block_3_other(y) 420 | y = self.conv_block_4(y, aux) 421 | y = self.conv_block_4_other(y) 422 | 423 | # Multi-scale Decoder 424 | m1 = self.non_linear(self.mask1(y)) 425 | m2 = self.non_linear(self.mask2(y)) 426 | m3 = self.non_linear(self.mask3(y)) 427 | s1 = w1 * m1 428 | s2 = w2 * m2 429 | s3 = w3 * m3 430 | 431 | #return self.decoder_1d_1(s1, squeeze=True), self.decoder_1d_2(s2, squeeze=True)[:, :xlen1], self.decoder_1d_3(s3, squeeze=True)[:, :xlen1], self.pred_linear(aux) 432 | return self.decoder_1d_1(s1, squeeze=True).unsqueeze(0), self.decoder_1d_2(s2, squeeze=True).unsqueeze(0)[:, :xlen1], self.decoder_1d_3(s3, squeeze=True).unsqueeze(0)[:, :xlen1], self.pred_linear(aux) 433 | 434 | def foo_conv1d_block(): 435 | nnet = Conv1DBlock(256, 512, 3, 20) 436 | print(param(nnet)) 437 | 438 
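# The smoke test foo_conv_tas_net() below was kept from the original single-input
# Conv-TasNet and calls nnet(x) without a reference utterance, which no longer
# matches the SpEx+ forward(x, aux, aux_len) signature defined above. A minimal
# decode-time sketch (illustrative only: foo_spex_plus and the 4 s / 3 s durations
# are assumptions; shapes follow the 8 kHz, L=20 defaults in nnet/conf.py):
def foo_spex_plus():
    nnet = ConvTasNet(norm="cLN", causal=False)
    nnet.eval()                      # inference mode
    mix = th.rand(32000)             # one 4 s mixture at 8 kHz, 1D at decode time
    aux = th.rand(1, 24000)          # one 3 s reference utterance of the target speaker
    aux_len = th.tensor([24000])     # lengths as a tensor: forward() calls aux_len.view(-1, 1)
    s1, s2, s3, spk_logits = nnet(mix, aux, aux_len)
    # s1/s2/s3: time-domain estimates from the three decoder scales (1 x 32000 here);
    # spk_logits: 1 x 101 logits from the speaker-classification head.
    print(s1.shape, s2.shape, s3.shape, spk_logits.shape)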
| def foo_layernorm(): 439 | C, T = 256, 20 440 | nnet1 = nn.LayerNorm([C, T], elementwise_affine=True) 441 | print(param(nnet1, Mb=False)) 442 | nnet2 = nn.LayerNorm([C, T], elementwise_affine=False) 443 | print(param(nnet2, Mb=False)) 444 | 445 | def foo_conv_tas_net(): 446 | x = th.rand(4, 1000) 447 | nnet = ConvTasNet(norm="cLN", causal=False) 448 | print("ConvTasNet #param: {:.2f}".format(param(nnet))) 449 | x = nnet(x) 450 | s1 = x[0] 451 | print(s1.shape) 452 | 453 | if __name__ == "__main__": 454 | foo_conv_tas_net() 455 | -------------------------------------------------------------------------------- /nnet/libs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__init__.py -------------------------------------------------------------------------------- /nnet/libs/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/audio.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__pycache__/audio.cpython-36.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/audio.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__pycache__/audio.cpython-38.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/kaldi_io.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__pycache__/kaldi_io.cpython-36.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/kaldi_io.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__pycache__/kaldi_io.cpython-38.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__pycache__/trainer.cpython-38.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /nnet/libs/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gemengtju/SpEx_Plus/9fe15e1483989f97cd22f9a7ed4fed56738e4e9c/nnet/libs/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /nnet/libs/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.io.wavfile as wf 4 | import soundfile as sf 5 | 6 | MAX_INT16 = np.iinfo(np.int16).max 7 | 8 | 9 | def write_wav(fname, samps, fs=16000, normalize=True): 10 | """ 11 | Write wav files in int16, support single/multi-channel 12 | """ 13 | #if normalize: 14 | # samps = samps * MAX_INT16 15 | ## scipy.io.wavfile.write could write single/multi-channel files 16 | ## for multi-channel, accept ndarray [Nsamples, Nchannels] 17 | #if samps.ndim != 1 and samps.shape[0] < samps.shape[1]: 18 | # samps = np.transpose(samps) 19 | # samps = np.squeeze(samps) 20 | ## same as MATLAB and kaldi 21 | #samps_int16 = samps.astype(np.int16) 22 | #fdir = os.path.dirname(fname) 23 | #if fdir and not os.path.exists(fdir): 24 | # os.makedirs(fdir) 25 | ## NOTE: librosa 0.6.0 seems could not write non-float narray 26 | ## so use scipy.io.wavfile instead 27 | #wf.write(fname, fs, samps_int16) 28 | 29 | # wham and whamr mixture and clean data are float 32, can not use scipy.io.wavfile to read and write int16 30 | # change to soundfile to read and write, although reference speech is int16, soundfile still can read and outputs as float 31 | fdir = os.path.dirname(fname) 32 | if fdir and not os.path.exists(fdir): 33 | os.makedirs(fdir) 34 | sf.write(fname, samps, fs, subtype='FLOAT') 35 | 36 | 37 | def read_wav(fname, normalize=True, return_rate=False): 38 | """ 39 | Read wave files using scipy.io.wavfile(support multi-channel) 40 | """ 41 | # samps_int16: N x C or N 42 | # N: number of samples 43 | # C: number of channels 44 | #samp_rate, samps_int16 = wf.read(fname) 45 | ## N x C => C x N 46 | #samps = samps_int16.astype(np.float) 47 | ## tranpose because I used to put channel axis first 48 | #if samps.ndim != 1: 49 | # samps = np.transpose(samps) 50 | ## normalize like MATLAB and librosa 
51 | #if normalize: 52 | # samps = samps / MAX_INT16 53 | #if return_rate: 54 | # return samp_rate, samps 55 | #return samps 56 | 57 | # wham and whamr mixture and clean data are float 32, can not use scipy.io.wavfile to read and write int16 58 | # change to soundfile to read and write, although reference speech is int16, soundfile still can read and outputs as float 59 | samps, samp_rate = sf.read(fname) 60 | if return_rate: 61 | return samp_rate, samps 62 | return samps 63 | 64 | 65 | def parse_scripts(scp_path, value_processor=lambda x: x, num_tokens=2): 66 | """ 67 | Parse kaldi's script(.scp) file 68 | If num_tokens >= 2, function will check token number 69 | """ 70 | scp_dict = dict() 71 | line = 0 72 | with open(scp_path, "r") as f: 73 | for raw_line in f: 74 | scp_tokens = raw_line.strip().split() 75 | line += 1 76 | if num_tokens >= 2 and len(scp_tokens) != num_tokens or len( 77 | scp_tokens) < 2: 78 | raise RuntimeError( 79 | "For {}, format error in line[{:d}]: {}".format( 80 | scp_path, line, raw_line)) 81 | if num_tokens == 2: 82 | key, value = scp_tokens 83 | else: 84 | key, value = scp_tokens[0], scp_tokens[1:] 85 | if key in scp_dict: 86 | raise ValueError("Duplicated key \'{0}\' exists in {1}".format( 87 | key, scp_path)) 88 | scp_dict[key] = value_processor(value) 89 | return scp_dict 90 | 91 | 92 | class Reader(object): 93 | """ 94 | Basic Reader Class 95 | """ 96 | 97 | def __init__(self, scp_path, value_processor=lambda x: x): 98 | self.index_dict = parse_scripts( 99 | scp_path, value_processor=value_processor, num_tokens=2) 100 | self.index_keys = list(self.index_dict.keys()) 101 | 102 | def _load(self, key): 103 | # return path 104 | return self.index_dict[key] 105 | 106 | # number of utterance 107 | def __len__(self): 108 | return len(self.index_dict) 109 | 110 | # avoid key error 111 | def __contains__(self, key): 112 | return key in self.index_dict 113 | 114 | # sequential index 115 | def __iter__(self): 116 | for key in self.index_keys: 117 | yield key, self._load(key) 118 | 119 | # random index, support str/int as index 120 | def __getitem__(self, index): 121 | if type(index) not in [int, str]: 122 | raise IndexError("Unsupported index type: {}".format(type(index))) 123 | if type(index) == int: 124 | # from int index to key 125 | num_utts = len(self.index_keys) 126 | if index >= num_utts or index < 0: 127 | raise KeyError( 128 | "Interger index out of range, {:d} vs {:d}".format( 129 | index, num_utts)) 130 | index = self.index_keys[index] 131 | if index not in self.index_dict: 132 | raise KeyError("Missing utterance {}!".format(index)) 133 | return self._load(index) 134 | 135 | 136 | class WaveReader(Reader): 137 | """ 138 | Sequential/Random Reader for single channel wave 139 | Format of wav.scp follows Kaldi's definition: 140 | key1 /path/to/wav 141 | ... 
142 | """ 143 | 144 | def __init__(self, wav_scp, sample_rate=None, normalize=True): 145 | super(WaveReader, self).__init__(wav_scp) 146 | self.samp_rate = sample_rate 147 | self.normalize = normalize 148 | 149 | def _load(self, key): 150 | # return C x N or N 151 | samp_rate, samps = read_wav( 152 | self.index_dict[key], normalize=self.normalize, return_rate=True) 153 | # if given samp_rate, check it 154 | if self.samp_rate is not None and samp_rate != self.samp_rate: 155 | raise RuntimeError("SampleRate mismatch: {:d} vs {:d}".format( 156 | samp_rate, self.samp_rate)) 157 | return samps 158 | -------------------------------------------------------------------------------- /nnet/libs/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch as th 3 | import numpy as np 4 | 5 | from torch.utils.data.dataloader import default_collate 6 | import torch.utils.data as dat 7 | from torch.nn.utils.rnn import pad_sequence 8 | from .kaldi_io import read_vec_flt, read_mat 9 | from .audio import WaveReader 10 | 11 | def make_dataloader(train=True, 12 | data_kwargs=None, 13 | num_workers=4, 14 | chunk_size=32000, 15 | batch_size=16): 16 | dataset = Dataset(**data_kwargs) 17 | return DataLoader(dataset, 18 | train=train, 19 | chunk_size=chunk_size, 20 | batch_size=batch_size, 21 | num_workers=num_workers) 22 | 23 | class Dataset(object): 24 | """ 25 | Per Utterance Loader 26 | """ 27 | def __init__(self, mix_scp="", ref_scp=None, aux_scp=None, spk_list=None, sample_rate=8000): 28 | self.mix = WaveReader(mix_scp, sample_rate=sample_rate) 29 | self.ref = WaveReader(ref_scp, sample_rate=sample_rate) 30 | self.aux = WaveReader(aux_scp, sample_rate=sample_rate) 31 | # If use WSJ0-2mix data (min version), don't need this part 32 | self.ref_dur = WaveReader("data/uniq_target_ref_dur.txt", sample_rate=sample_rate) 33 | self.sample_rate = sample_rate 34 | self.spk_list = self._load_spk(spk_list) 35 | print(self.spk_list) 36 | 37 | def _load_spk(self, spk_list_path): 38 | if spk_list_path is None: 39 | return [] 40 | lines = open(spk_list_path).readlines() 41 | new_lines = [] 42 | for line in lines: 43 | new_lines.append(line.strip()) 44 | 45 | return new_lines 46 | 47 | def __len__(self): 48 | return len(self.mix) 49 | 50 | def __getitem__(self, index): 51 | key = self.mix.index_keys[index] 52 | mix = self.mix[key] 53 | ref = self.ref[key] 54 | aux = self.aux[key] 55 | 56 | target1_dur = self.ref_dur.index_dict[key.split('_')[0]] 57 | target2_dur = self.ref_dur.index_dict[key.split('_')[2]] 58 | end_idx = int(min(float(target1_dur), float(target2_dur)) * self.sample_rate) 59 | 60 | spk_idx = self.spk_list.index(key.split('_')[-1][0:3]) 61 | return { 62 | "mix": mix[:end_idx].astype(np.float32), 63 | "ref": ref[:end_idx].astype(np.float32), 64 | "aux": aux.astype(np.float32), 65 | "aux_len": len(aux), 66 | "spk_idx": spk_idx 67 | } 68 | 69 | 70 | class ChunkSplitter(object): 71 | """ 72 | Split utterance into small chunks 73 | """ 74 | def __init__(self, chunk_size, train=True, least=16000): 75 | self.chunk_size = chunk_size 76 | self.least = least 77 | self.train = train 78 | 79 | def _make_chunk(self, eg, s): 80 | """ 81 | Make a chunk instance, which contains: 82 | "mix": ndarray, 83 | "ref": [ndarray...] 
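        "aux": ndarray (full reference utterance, left unchunked),
        "aux_len": int, original reference length (used later to pad the batch),
        "valid_len": int, number of non-padded mixture samples in the chunk,
        "spk_idx": int, index of the target speaker in spk_list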
84 | """ 85 | chunk = dict() 86 | chunk["mix"] = eg["mix"][s:s + self.chunk_size] 87 | chunk["ref"] = eg["ref"][s:s + self.chunk_size] 88 | chunk["aux"] = eg["aux"] 89 | chunk["aux_len"] = eg["aux_len"] 90 | chunk["valid_len"] = int(self.chunk_size) 91 | chunk["spk_idx"] = eg["spk_idx"] 92 | return chunk 93 | 94 | def split(self, eg): 95 | N = eg["mix"].size 96 | # too short, throw away 97 | if N < self.least: 98 | return [] 99 | chunks = [] 100 | # padding zeros 101 | if N < self.chunk_size: 102 | P = self.chunk_size - N 103 | chunk = dict() 104 | chunk["mix"] = np.pad(eg["mix"], (0, P), "constant") 105 | chunk["ref"] = np.pad(eg["ref"], (0, P), "constant") 106 | chunk["aux"] = eg["aux"] 107 | chunk["aux_len"] = eg["aux_len"] 108 | chunk["valid_len"] = int(N) 109 | chunk["spk_idx"] = eg["spk_idx"] 110 | chunks.append(chunk) 111 | else: 112 | # random select start point for training 113 | s = random.randint(0, N % self.least) if self.train else 0 114 | while True: 115 | if s + self.chunk_size > N: 116 | break 117 | chunk = self._make_chunk(eg, s) 118 | chunks.append(chunk) 119 | s += self.least 120 | return chunks 121 | 122 | 123 | class DataLoader(object): 124 | """ 125 | Online dataloader for chunk-level PIT 126 | """ 127 | def __init__(self, 128 | dataset, 129 | num_workers=4, 130 | chunk_size=32000, 131 | batch_size=16, 132 | train=True): 133 | self.batch_size = batch_size 134 | self.train = train 135 | self.splitter = ChunkSplitter(chunk_size, 136 | train=train, 137 | least=chunk_size // 2) 138 | # just return batch of egs, support multiple workers 139 | self.eg_loader = dat.DataLoader(dataset, 140 | batch_size=batch_size // 2, 141 | num_workers=num_workers, 142 | shuffle=train, 143 | collate_fn=self._collate) 144 | 145 | def _collate(self, batch): 146 | """ 147 | Online split utterances 148 | """ 149 | chunk = [] 150 | for eg in batch: 151 | chunk += self.splitter.split(eg) 152 | return chunk 153 | 154 | def _pad_aux(self, chunk_list): 155 | lens_list = [] 156 | for chunk_item in chunk_list: 157 | #lens_list.append(chunk_item['aux_mfcc_len']) 158 | lens_list.append(chunk_item['aux_len']) 159 | max_len = np.max(lens_list) 160 | 161 | 162 | for idx in range(len(chunk_list)): 163 | P = max_len - len(chunk_list[idx]["aux"]) 164 | chunk_list[idx]["aux"] = np.pad(chunk_list[idx]["aux"], (0, P), "constant") 165 | 166 | return chunk_list 167 | 168 | def _merge(self, chunk_list): 169 | """ 170 | Merge chunk list into mini-batch 171 | """ 172 | N = len(chunk_list) 173 | if self.train: 174 | random.shuffle(chunk_list) 175 | blist = [] 176 | for s in range(0, N - self.batch_size + 1, self.batch_size): 177 | # padding aux info 178 | #self._pad_aux(chunk_list[s:s + self.batch_size]) 179 | batch = default_collate(self._pad_aux(chunk_list[s:s + self.batch_size])) 180 | blist.append(batch) 181 | rn = N % self.batch_size 182 | return blist, chunk_list[-rn:] if rn else [] 183 | 184 | def __iter__(self): 185 | chunk_list = [] 186 | for chunks in self.eg_loader: 187 | chunk_list += chunks 188 | batch, chunk_list = self._merge(chunk_list) 189 | for obj in batch: 190 | yield obj 191 | -------------------------------------------------------------------------------- /nnet/libs/kaldi_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2014-2019 Brno University of Technology (author: Karel Vesely) 5 | # Licensed under the Apache License, Version 2.0 (the "License") 6 | 7 | from __future__ import 
print_function 8 | from __future__ import division 9 | 10 | import numpy as np 11 | import sys, os, re, gzip, struct 12 | 13 | ################################################# 14 | # Adding kaldi tools to shell path, 15 | 16 | # Select kaldi, 17 | if not 'KALDI_ROOT' in os.environ: 18 | # Default! To change run python with 'export KALDI_ROOT=/some_dir python' 19 | os.environ['KALDI_ROOT']='/mnt/matylda5/iveselyk/Tools/kaldi-trunk' 20 | 21 | # Add kaldi tools to path, 22 | os.environ['PATH'] = os.popen('echo $KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin:$KALDI_ROOT/src/nnet3bin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/').readline().strip() + ':' + os.environ['PATH'] 23 | 24 | 25 | ################################################# 26 | # Define all custom exceptions, 27 | class UnsupportedDataType(Exception): pass 28 | class UnknownVectorHeader(Exception): pass 29 | class UnknownMatrixHeader(Exception): pass 30 | 31 | class BadSampleSize(Exception): pass 32 | class BadInputFormat(Exception): pass 33 | 34 | class SubprocessFailed(Exception): pass 35 | 36 | ################################################# 37 | # Data-type independent helper functions, 38 | 39 | def open_or_fd(file, mode='rb'): 40 | """ fd = open_or_fd(file) 41 | Open file, gzipped file, pipe, or forward the file-descriptor. 42 | Eventually seeks in the 'file' argument contains ':offset' suffix. 43 | """ 44 | offset = None 45 | try: 46 | # strip 'ark:' prefix from r{x,w}filename (optional), 47 | if re.search('^(ark|scp)(,scp|,b|,t|,n?f|,n?p|,b?o|,n?s|,n?cs)*:', file): 48 | (prefix,file) = file.split(':',1) 49 | # separate offset from filename (optional), 50 | if re.search(':[0-9]+$', file): 51 | (file,offset) = file.rsplit(':',1) 52 | # input pipe? 53 | if file[-1] == '|': 54 | fd = popen(file[:-1], 'rb') # custom, 55 | # output pipe? 56 | elif file[0] == '|': 57 | fd = popen(file[1:], 'wb') # custom, 58 | # is it gzipped? 59 | elif file.split('.')[-1] == 'gz': 60 | fd = gzip.open(file, mode) 61 | # a normal file... 62 | else: 63 | fd = open(file, mode) 64 | except TypeError: 65 | # 'file' is opened file descriptor, 66 | fd = file 67 | # Eventually seek to offset, 68 | if offset != None: fd.seek(int(offset)) 69 | return fd 70 | 71 | # based on '/usr/local/lib/python3.6/os.py' 72 | def popen(cmd, mode="rb"): 73 | if not isinstance(cmd, str): 74 | raise TypeError("invalid cmd type (%s, expected string)" % type(cmd)) 75 | 76 | import subprocess, io, threading 77 | 78 | # cleanup function for subprocesses, 79 | def cleanup(proc, cmd): 80 | ret = proc.wait() 81 | if ret > 0: 82 | raise SubprocessFailed('cmd %s returned %d !' 
% (cmd,ret)) 83 | return 84 | 85 | # text-mode, 86 | if mode == "r": 87 | proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=sys.stderr) 88 | threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, 89 | return io.TextIOWrapper(proc.stdout) 90 | elif mode == "w": 91 | proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stderr=sys.stderr) 92 | threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, 93 | return io.TextIOWrapper(proc.stdin) 94 | # binary, 95 | elif mode == "rb": 96 | proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=sys.stderr) 97 | threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, 98 | return proc.stdout 99 | elif mode == "wb": 100 | proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stderr=sys.stderr) 101 | threading.Thread(target=cleanup,args=(proc,cmd)).start() # clean-up thread, 102 | return proc.stdin 103 | # sanity, 104 | else: 105 | raise ValueError("invalid mode %s" % mode) 106 | 107 | 108 | def read_key(fd): 109 | """ [key] = read_key(fd) 110 | Read the utterance-key from the opened ark/stream descriptor 'fd'. 111 | """ 112 | key = '' 113 | while 1: 114 | char = fd.read(1).decode("latin1") 115 | if char == '' : break 116 | if char == ' ' : break 117 | key += char 118 | key = key.strip() 119 | if key == '': return None # end of file, 120 | assert(re.match('^\S+$',key) != None) # check format (no whitespace!) 121 | return key 122 | 123 | 124 | ################################################# 125 | # Integer vectors (alignments, ...), 126 | 127 | def read_ali_ark(file_or_fd): 128 | """ Alias to 'read_vec_int_ark()' """ 129 | return read_vec_int_ark(file_or_fd) 130 | 131 | def read_vec_int_ark(file_or_fd): 132 | """ generator(key,vec) = read_vec_int_ark(file_or_fd) 133 | Create generator of (key,vector) tuples, which reads from the ark file/stream. 134 | file_or_fd : ark, gzipped ark, pipe or opened file descriptor. 135 | 136 | Read ark to a 'dictionary': 137 | d = { u:d for u,d in kaldi_io.read_vec_int_ark(file) } 138 | """ 139 | fd = open_or_fd(file_or_fd) 140 | try: 141 | key = read_key(fd) 142 | while key: 143 | ali = read_vec_int(fd) 144 | yield key, ali 145 | key = read_key(fd) 146 | finally: 147 | if fd is not file_or_fd: fd.close() 148 | 149 | def read_vec_int(file_or_fd): 150 | """ [int-vec] = read_vec_int(file_or_fd) 151 | Read kaldi integer vector, ascii or binary input, 152 | """ 153 | fd = open_or_fd(file_or_fd) 154 | binary = fd.read(2).decode() 155 | if binary == '\0B': # binary flag 156 | assert(fd.read(1).decode() == '\4'); # int-size 157 | vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim 158 | if vec_size == 0: 159 | return np.array([], dtype='int32') 160 | # Elements from int32 vector are sored in tuples: (sizeof(int32), value), 161 | vec = np.frombuffer(fd.read(vec_size*5), dtype=[('size','int8'),('value','int32')], count=vec_size) 162 | assert(vec[0]['size'] == 4) # int32 size, 163 | ans = vec[:]['value'] # values are in 2nd column, 164 | else: # ascii, 165 | arr = (binary + fd.readline().decode()).strip().split() 166 | try: 167 | arr.remove('['); arr.remove(']') # optionally 168 | except ValueError: 169 | pass 170 | ans = np.array(arr, dtype=int) 171 | if fd is not file_or_fd : fd.close() # cleanup 172 | return ans 173 | 174 | # Writing, 175 | def write_vec_int(file_or_fd, v, key=''): 176 | """ write_vec_int(f, v, key='') 177 | Write a binary kaldi integer vector to filename or stream. 
178 | Arguments: 179 | file_or_fd : filename or opened file descriptor for writing, 180 | v : the vector to be stored, 181 | key (optional) : used for writing ark-file, the utterance-id gets written before the vector. 182 | 183 | Example of writing single vector: 184 | kaldi_io.write_vec_int(filename, vec) 185 | 186 | Example of writing arkfile: 187 | with open(ark_file,'w') as f: 188 | for key,vec in dict.iteritems(): 189 | kaldi_io.write_vec_flt(f, vec, key=key) 190 | """ 191 | fd = open_or_fd(file_or_fd, mode='wb') 192 | if sys.version_info[0] == 3: assert(fd.mode == 'wb') 193 | try: 194 | if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), 195 | fd.write('\0B'.encode()) # we write binary! 196 | # dim, 197 | fd.write('\4'.encode()) # int32 type, 198 | fd.write(struct.pack(np.dtype('int32').char, v.shape[0])) 199 | # data, 200 | for i in range(len(v)): 201 | fd.write('\4'.encode()) # int32 type, 202 | fd.write(struct.pack(np.dtype('int32').char, v[i])) # binary, 203 | finally: 204 | if fd is not file_or_fd : fd.close() 205 | 206 | 207 | ################################################# 208 | # Float vectors (confidences, ivectors, ...), 209 | 210 | # Reading, 211 | def read_vec_flt_scp(file_or_fd): 212 | """ generator(key,mat) = read_vec_flt_scp(file_or_fd) 213 | Returns generator of (key,vector) tuples, read according to kaldi scp. 214 | file_or_fd : scp, gzipped scp, pipe or opened file descriptor. 215 | 216 | Iterate the scp: 217 | for key,vec in kaldi_io.read_vec_flt_scp(file): 218 | ... 219 | 220 | Read scp to a 'dictionary': 221 | d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) } 222 | """ 223 | fd = open_or_fd(file_or_fd) 224 | try: 225 | for line in fd: 226 | (key,rxfile) = line.decode().split(' ') 227 | vec = read_vec_flt(rxfile) 228 | yield key, vec 229 | finally: 230 | if fd is not file_or_fd : fd.close() 231 | 232 | def read_vec_flt_ark(file_or_fd): 233 | """ generator(key,vec) = read_vec_flt_ark(file_or_fd) 234 | Create generator of (key,vector) tuples, reading from an ark file/stream. 235 | file_or_fd : ark, gzipped ark, pipe or opened file descriptor. 
236 | 237 | Read ark to a 'dictionary': 238 | d = { u:d for u,d in kaldi_io.read_vec_flt_ark(file) } 239 | """ 240 | fd = open_or_fd(file_or_fd) 241 | try: 242 | key = read_key(fd) 243 | while key: 244 | ali = read_vec_flt(fd) 245 | yield key, ali 246 | key = read_key(fd) 247 | finally: 248 | if fd is not file_or_fd : fd.close() 249 | 250 | def read_vec_flt(file_or_fd): 251 | """ [flt-vec] = read_vec_flt(file_or_fd) 252 | Read kaldi float vector, ascii or binary input, 253 | """ 254 | fd = open_or_fd(file_or_fd) 255 | binary = fd.read(2).decode() 256 | if binary == '\0B': # binary flag 257 | ans = _read_vec_flt_binary(fd) 258 | else: # ascii, 259 | arr = (binary + fd.readline().decode()).strip().split() 260 | try: 261 | arr.remove('['); arr.remove(']') # optionally 262 | except ValueError: 263 | pass 264 | ans = np.array(arr, dtype=float) 265 | if fd is not file_or_fd : fd.close() # cleanup 266 | return ans 267 | 268 | def _read_vec_flt_binary(fd): 269 | header = fd.read(3).decode() 270 | if header == 'FV ' : sample_size = 4 # floats 271 | elif header == 'DV ' : sample_size = 8 # doubles 272 | else : raise UnknownVectorHeader("The header contained '%s'" % header) 273 | assert (sample_size > 0) 274 | # Dimension, 275 | assert (fd.read(1).decode() == '\4'); # int-size 276 | vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # vector dim 277 | if vec_size == 0: 278 | return np.array([], dtype='float32') 279 | # Read whole vector, 280 | buf = fd.read(vec_size * sample_size) 281 | if sample_size == 4 : ans = np.frombuffer(buf, dtype='float32') 282 | elif sample_size == 8 : ans = np.frombuffer(buf, dtype='float64') 283 | else : raise BadSampleSize 284 | return ans 285 | 286 | 287 | # Writing, 288 | def write_vec_flt(file_or_fd, v, key=''): 289 | """ write_vec_flt(f, v, key='') 290 | Write a binary kaldi vector to filename or stream. Supports 32bit and 64bit floats. 291 | Arguments: 292 | file_or_fd : filename or opened file descriptor for writing, 293 | v : the vector to be stored, 294 | key (optional) : used for writing ark-file, the utterance-id gets written before the vector. 295 | 296 | Example of writing single vector: 297 | kaldi_io.write_vec_flt(filename, vec) 298 | 299 | Example of writing arkfile: 300 | with open(ark_file,'w') as f: 301 | for key,vec in dict.iteritems(): 302 | kaldi_io.write_vec_flt(f, vec, key=key) 303 | """ 304 | fd = open_or_fd(file_or_fd, mode='wb') 305 | if sys.version_info[0] == 3: assert(fd.mode == 'wb') 306 | try: 307 | if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), 308 | fd.write('\0B'.encode()) # we write binary! 309 | # Data-type, 310 | if v.dtype == 'float32': fd.write('FV '.encode()) 311 | elif v.dtype == 'float64': fd.write('DV '.encode()) 312 | else: raise UnsupportedDataType("'%s', please use 'float32' or 'float64'" % v.dtype) 313 | # Dim, 314 | fd.write('\04'.encode()) 315 | fd.write(struct.pack(np.dtype('uint32').char, v.shape[0])) # dim 316 | # Data, 317 | fd.write(v.tobytes()) 318 | finally: 319 | if fd is not file_or_fd : fd.close() 320 | 321 | 322 | ################################################# 323 | # Float matrices (features, transformations, ...), 324 | 325 | # Reading, 326 | def read_mat_scp(file_or_fd): 327 | """ generator(key,mat) = read_mat_scp(file_or_fd) 328 | Returns generator of (key,matrix) tuples, read according to kaldi scp. 329 | file_or_fd : scp, gzipped scp, pipe or opened file descriptor. 
330 | 331 | Iterate the scp: 332 | for key,mat in kaldi_io.read_mat_scp(file): 333 | ... 334 | 335 | Read scp to a 'dictionary': 336 | d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) } 337 | """ 338 | fd = open_or_fd(file_or_fd) 339 | try: 340 | for line in fd: 341 | (key,rxfile) = line.decode().split(' ') 342 | mat = read_mat(rxfile) 343 | yield key, mat 344 | finally: 345 | if fd is not file_or_fd : fd.close() 346 | 347 | def read_mat_ark(file_or_fd): 348 | """ generator(key,mat) = read_mat_ark(file_or_fd) 349 | Returns generator of (key,matrix) tuples, read from ark file/stream. 350 | file_or_fd : scp, gzipped scp, pipe or opened file descriptor. 351 | 352 | Iterate the ark: 353 | for key,mat in kaldi_io.read_mat_ark(file): 354 | ... 355 | 356 | Read ark to a 'dictionary': 357 | d = { key:mat for key,mat in kaldi_io.read_mat_ark(file) } 358 | """ 359 | fd = open_or_fd(file_or_fd) 360 | try: 361 | key = read_key(fd) 362 | while key: 363 | mat = read_mat(fd) 364 | yield key, mat 365 | key = read_key(fd) 366 | finally: 367 | if fd is not file_or_fd : fd.close() 368 | 369 | def read_mat(file_or_fd): 370 | """ [mat] = read_mat(file_or_fd) 371 | Reads single kaldi matrix, supports ascii and binary. 372 | file_or_fd : file, gzipped file, pipe or opened file descriptor. 373 | """ 374 | fd = open_or_fd(file_or_fd) 375 | try: 376 | binary = fd.read(2).decode() 377 | if binary == '\0B' : 378 | mat = _read_mat_binary(fd) 379 | else: 380 | assert(binary == ' [') 381 | mat = _read_mat_ascii(fd) 382 | finally: 383 | if fd is not file_or_fd: fd.close() 384 | return mat 385 | 386 | def _read_mat_binary(fd): 387 | # Data type 388 | header = fd.read(3).decode() 389 | # 'CM', 'CM2', 'CM3' are possible values, 390 | if header.startswith('CM'): return _read_compressed_mat(fd, header) 391 | elif header == 'FM ': sample_size = 4 # floats 392 | elif header == 'DM ': sample_size = 8 # doubles 393 | else: raise UnknownMatrixHeader("The header contained '%s'" % header) 394 | assert(sample_size > 0) 395 | # Dimensions 396 | s1, rows, s2, cols = np.frombuffer(fd.read(10), dtype='int8,int32,int8,int32', count=1)[0] 397 | # Read whole matrix 398 | buf = fd.read(rows * cols * sample_size) 399 | if sample_size == 4 : vec = np.frombuffer(buf, dtype='float32') 400 | elif sample_size == 8 : vec = np.frombuffer(buf, dtype='float64') 401 | else : raise BadSampleSize 402 | mat = np.reshape(vec,(rows,cols)) 403 | return mat 404 | 405 | def _read_mat_ascii(fd): 406 | rows = [] 407 | while 1: 408 | line = fd.readline().decode() 409 | if (len(line) == 0) : raise BadInputFormat # eof, should not happen! 410 | if len(line.strip()) == 0 : continue # skip empty line 411 | arr = line.strip().split() 412 | if arr[-1] != ']': 413 | rows.append(np.array(arr,dtype='float32')) # not last line 414 | else: 415 | rows.append(np.array(arr[:-1],dtype='float32')) # last line 416 | mat = np.vstack(rows) 417 | return mat 418 | 419 | 420 | def _read_compressed_mat(fd, format): 421 | """ Read a compressed matrix, 422 | see: https://github.com/kaldi-asr/kaldi/blob/master/src/matrix/compressed-matrix.h 423 | methods: CompressedMatrix::Read(...), CompressedMatrix::CopyToMat(...), 424 | """ 425 | assert(format == 'CM ') # The formats CM2, CM3 are not supported... 
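    # Layout of a Kaldi 'CM ' compressed matrix, as decoded below:
    #   * 16-byte global header: min value (float32), range (float32), num_rows (int32), num_cols (int32)
    #   * one 8-byte header per column with four uint16 anchors (0th/25th/75th/100th percentile),
    #     each mapped back to float as global_min + global_range * value / 65535 (1/65535 ~= 1.5259e-05)
    #   * num_rows * num_cols uint8 codes stored column-major; each code is linearly interpolated
    #     between its column's percentile anchors (code ranges 0-64, 65-192, 193-255), and the
    #     reconstructed matrix is finally transposed back to row-major.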
426 | 427 | # Format of header 'struct', 428 | global_header = np.dtype([('minvalue','float32'),('range','float32'),('num_rows','int32'),('num_cols','int32')]) # member '.format' is not written, 429 | per_col_header = np.dtype([('percentile_0','uint16'),('percentile_25','uint16'),('percentile_75','uint16'),('percentile_100','uint16')]) 430 | 431 | # Read global header, 432 | globmin, globrange, rows, cols = np.frombuffer(fd.read(16), dtype=global_header, count=1)[0] 433 | 434 | # The data is structed as [Colheader, ... , Colheader, Data, Data , .... ] 435 | # { cols }{ size } 436 | col_headers = np.frombuffer(fd.read(cols*8), dtype=per_col_header, count=cols) 437 | col_headers = np.array([np.array([x for x in y]) * globrange * 1.52590218966964e-05 + globmin for y in col_headers], dtype=np.float32) 438 | data = np.reshape(np.frombuffer(fd.read(cols*rows), dtype='uint8', count=cols*rows), newshape=(cols,rows)) # stored as col-major, 439 | 440 | mat = np.zeros((cols,rows), dtype='float32') 441 | p0 = col_headers[:, 0].reshape(-1, 1) 442 | p25 = col_headers[:, 1].reshape(-1, 1) 443 | p75 = col_headers[:, 2].reshape(-1, 1) 444 | p100 = col_headers[:, 3].reshape(-1, 1) 445 | mask_0_64 = (data <= 64) 446 | mask_193_255 = (data > 192) 447 | mask_65_192 = (~(mask_0_64 | mask_193_255)) 448 | 449 | mat += (p0 + (p25 - p0) / 64. * data) * mask_0_64.astype(np.float32) 450 | mat += (p25 + (p75 - p25) / 128. * (data - 64)) * mask_65_192.astype(np.float32) 451 | mat += (p75 + (p100 - p75) / 63. * (data - 192)) * mask_193_255.astype(np.float32) 452 | 453 | return mat.T # transpose! col-major -> row-major, 454 | 455 | 456 | # Writing, 457 | def write_mat(file_or_fd, m, key=''): 458 | """ write_mat(f, m, key='') 459 | Write a binary kaldi matrix to filename or stream. Supports 32bit and 64bit floats. 460 | Arguments: 461 | file_or_fd : filename of opened file descriptor for writing, 462 | m : the matrix to be stored, 463 | key (optional) : used for writing ark-file, the utterance-id gets written before the matrix. 464 | 465 | Example of writing single matrix: 466 | kaldi_io.write_mat(filename, mat) 467 | 468 | Example of writing arkfile: 469 | with open(ark_file,'w') as f: 470 | for key,mat in dict.iteritems(): 471 | kaldi_io.write_mat(f, mat, key=key) 472 | """ 473 | fd = open_or_fd(file_or_fd, mode='wb') 474 | if sys.version_info[0] == 3: assert(fd.mode == 'wb') 475 | try: 476 | if key != '' : fd.write((key+' ').encode("latin1")) # ark-files have keys (utterance-id), 477 | fd.write('\0B'.encode()) # we write binary! 478 | # Data-type, 479 | if m.dtype == 'float32': fd.write('FM '.encode()) 480 | elif m.dtype == 'float64': fd.write('DM '.encode()) 481 | else: raise UnsupportedDataType("'%s', please use 'float32' or 'float64'" % m.dtype) 482 | # Dims, 483 | fd.write('\04'.encode()) 484 | fd.write(struct.pack(np.dtype('uint32').char, m.shape[0])) # rows 485 | fd.write('\04'.encode()) 486 | fd.write(struct.pack(np.dtype('uint32').char, m.shape[1])) # cols 487 | # Data, 488 | fd.write(m.tobytes()) 489 | finally: 490 | if fd is not file_or_fd : fd.close() 491 | 492 | 493 | ################################################# 494 | # 'Posterior' kaldi type (posteriors, confusion network, nnet1 training targets, ...) 
495 | # Corresponds to: vector > > 496 | # - outer vector: time axis 497 | # - inner vector: records at the time 498 | # - tuple: int = index, float = value 499 | # 500 | 501 | def read_cnet_ark(file_or_fd): 502 | """ Alias of function 'read_post_ark()', 'cnet' = confusion network """ 503 | return read_post_ark(file_or_fd) 504 | 505 | def read_post_rxspec(file_): 506 | """ adaptor to read both 'ark:...' and 'scp:...' inputs of posteriors, 507 | """ 508 | if file_.startswith("ark:"): 509 | return read_post_ark(file_) 510 | elif file_.startswith("scp:"): 511 | return read_post_scp(file_) 512 | else: 513 | print("unsupported intput type: %s" % file_) 514 | print("it should begint with 'ark:' or 'scp:'") 515 | sys.exit(1) 516 | 517 | def read_post_scp(file_or_fd): 518 | """ generator(key,post) = read_post_scp(file_or_fd) 519 | Returns generator of (key,post) tuples, read according to kaldi scp. 520 | file_or_fd : scp, gzipped scp, pipe or opened file descriptor. 521 | 522 | Iterate the scp: 523 | for key,post in kaldi_io.read_post_scp(file): 524 | ... 525 | 526 | Read scp to a 'dictionary': 527 | d = { key:post for key,post in kaldi_io.read_post_scp(file) } 528 | """ 529 | fd = open_or_fd(file_or_fd) 530 | try: 531 | for line in fd: 532 | (key,rxfile) = line.decode().split(' ') 533 | post = read_post(rxfile) 534 | yield key, post 535 | finally: 536 | if fd is not file_or_fd : fd.close() 537 | 538 | def read_post_ark(file_or_fd): 539 | """ generator(key,vec>) = read_post_ark(file) 540 | Returns generator of (key,posterior) tuples, read from ark file. 541 | file_or_fd : ark, gzipped ark, pipe or opened file descriptor. 542 | 543 | Iterate the ark: 544 | for key,post in kaldi_io.read_post_ark(file): 545 | ... 546 | 547 | Read ark to a 'dictionary': 548 | d = { key:post for key,post in kaldi_io.read_post_ark(file) } 549 | """ 550 | fd = open_or_fd(file_or_fd) 551 | try: 552 | key = read_key(fd) 553 | while key: 554 | post = read_post(fd) 555 | yield key, post 556 | key = read_key(fd) 557 | finally: 558 | if fd is not file_or_fd: fd.close() 559 | 560 | def read_post(file_or_fd): 561 | """ [post] = read_post(file_or_fd) 562 | Reads single kaldi 'Posterior' in binary format. 563 | 564 | The 'Posterior' is C++ type 'vector > >', 565 | the outer-vector is usually time axis, inner-vector are the records 566 | at given time, and the tuple is composed of an 'index' (integer) 567 | and a 'float-value'. The 'float-value' can represent a probability 568 | or any other numeric value. 569 | 570 | Returns vector of vectors of tuples. 
571 | """ 572 | fd = open_or_fd(file_or_fd) 573 | ans=[] 574 | binary = fd.read(2).decode(); assert(binary == '\0B'); # binary flag 575 | assert(fd.read(1).decode() == '\4'); # int-size 576 | outer_vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of frames (or bins) 577 | 578 | # Loop over 'outer-vector', 579 | for i in range(outer_vec_size): 580 | assert(fd.read(1).decode() == '\4'); # int-size 581 | inner_vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of records for frame (or bin) 582 | data = np.frombuffer(fd.read(inner_vec_size*10), dtype=[('size_idx','int8'),('idx','int32'),('size_post','int8'),('post','float32')], count=inner_vec_size) 583 | assert(data[0]['size_idx'] == 4) 584 | assert(data[0]['size_post'] == 4) 585 | ans.append(data[['idx','post']].tolist()) 586 | 587 | if fd is not file_or_fd: fd.close() 588 | return ans 589 | 590 | 591 | ################################################# 592 | # Kaldi Confusion Network bin begin/end times, 593 | # (kaldi stores CNs time info separately from the Posterior). 594 | # 595 | 596 | def read_cntime_ark(file_or_fd): 597 | """ generator(key,vec>) = read_cntime_ark(file_or_fd) 598 | Returns generator of (key,cntime) tuples, read from ark file. 599 | file_or_fd : file, gzipped file, pipe or opened file descriptor. 600 | 601 | Iterate the ark: 602 | for key,time in kaldi_io.read_cntime_ark(file): 603 | ... 604 | 605 | Read ark to a 'dictionary': 606 | d = { key:time for key,time in kaldi_io.read_post_ark(file) } 607 | """ 608 | fd = open_or_fd(file_or_fd) 609 | try: 610 | key = read_key(fd) 611 | while key: 612 | cntime = read_cntime(fd) 613 | yield key, cntime 614 | key = read_key(fd) 615 | finally: 616 | if fd is not file_or_fd : fd.close() 617 | 618 | def read_cntime(file_or_fd): 619 | """ [cntime] = read_cntime(file_or_fd) 620 | Reads single kaldi 'Confusion Network time info', in binary format: 621 | C++ type: vector >. 622 | (begin/end times of bins at the confusion network). 623 | 624 | Binary layout is ' ...' 625 | 626 | file_or_fd : file, gzipped file, pipe or opened file descriptor. 627 | 628 | Returns vector of tuples. 
629 | """ 630 | fd = open_or_fd(file_or_fd) 631 | binary = fd.read(2).decode(); assert(binary == '\0B'); # assuming it's binary 632 | 633 | assert(fd.read(1).decode() == '\4'); # int-size 634 | vec_size = np.frombuffer(fd.read(4), dtype='int32', count=1)[0] # number of frames (or bins) 635 | 636 | data = np.frombuffer(fd.read(vec_size*10), dtype=[('size_beg','int8'),('t_beg','float32'),('size_end','int8'),('t_end','float32')], count=vec_size) 637 | assert(data[0]['size_beg'] == 4) 638 | assert(data[0]['size_end'] == 4) 639 | ans = data[['t_beg','t_end']].tolist() # Return vector of tuples (t_beg,t_end), 640 | 641 | if fd is not file_or_fd : fd.close() 642 | return ans 643 | 644 | 645 | ################################################# 646 | # Segments related, 647 | # 648 | 649 | # Segments as 'Bool vectors' can be handy, 650 | # - for 'superposing' the segmentations, 651 | # - for frame-selection in Speaker-ID experiments, 652 | def read_segments_as_bool_vec(segments_file): 653 | """ [ bool_vec ] = read_segments_as_bool_vec(segments_file) 654 | using kaldi 'segments' file for 1 wav, format : ' ' 655 | - t-beg, t-end is in seconds, 656 | - assumed 100 frames/second, 657 | """ 658 | segs = np.loadtxt(segments_file, dtype='object,object,f,f', ndmin=1) 659 | # Sanity checks, 660 | assert(len(segs) > 0) # empty segmentation is an error, 661 | assert(len(np.unique([rec[1] for rec in segs ])) == 1) # segments with only 1 wav-file, 662 | # Convert time to frame-indexes, 663 | start = np.rint([100 * rec[2] for rec in segs]).astype(int) 664 | end = np.rint([100 * rec[3] for rec in segs]).astype(int) 665 | # Taken from 'read_lab_to_bool_vec', htk.py, 666 | frms = np.repeat(np.r_[np.tile([False,True], len(end)), False], 667 | np.r_[np.c_[start - np.r_[0, end[:-1]], end-start].flat, 0]) 668 | assert np.sum(end-start) == np.sum(frms) 669 | return frms 670 | 671 | 672 | -------------------------------------------------------------------------------- /nnet/libs/metric.py: -------------------------------------------------------------------------------- 1 | """ 2 | SI-SNR(scale-invariant SNR/SDR) measure of speech separation 3 | """ 4 | 5 | import numpy as np 6 | 7 | from itertools import permutations 8 | 9 | 10 | def si_snr(x, s, remove_dc=True): 11 | """ 12 | Compute SI-SNR 13 | Arguments: 14 | x: vector, enhanced/separated signal 15 | s: vector, reference signal(ground truth) 16 | """ 17 | 18 | def vec_l2norm(x): 19 | return np.linalg.norm(x, 2) 20 | 21 | # zero mean, seems do not hurt results 22 | if remove_dc: 23 | x_zm = x - np.mean(x) 24 | s_zm = s - np.mean(s) 25 | t = np.inner(x_zm, s_zm) * s_zm / vec_l2norm(s_zm)**2 26 | n = x_zm - t 27 | else: 28 | t = np.inner(x, s) * s / vec_l2norm(s)**2 29 | n = x - t 30 | return 20 * np.log10(vec_l2norm(t) / vec_l2norm(n)) 31 | 32 | 33 | def permute_si_snr(xlist, slist): 34 | """ 35 | Compute SI-SNR between N pairs 36 | Arguments: 37 | x: list[vector], enhanced/separated signal 38 | s: list[vector], reference signal(ground truth) 39 | """ 40 | 41 | def si_snr_avg(xlist, slist): 42 | return sum([si_snr(x, s) for x, s in zip(xlist, slist)]) / len(xlist) 43 | 44 | N = len(xlist) 45 | if N != len(slist): 46 | raise RuntimeError( 47 | "size do not match between xlist and slist: {:d} vs {:d}".format( 48 | N, len(slist))) 49 | si_snrs = [] 50 | for order in permutations(range(N)): 51 | si_snrs.append(si_snr_avg(xlist, [slist[n] for n in order])) 52 | return max(si_snrs) 53 | -------------------------------------------------------------------------------- 
/nnet/libs/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | from itertools import permutations 6 | from collections import defaultdict 7 | 8 | import torch as th 9 | import torch.nn.functional as F 10 | from torch.optim.lr_scheduler import ReduceLROnPlateau 11 | from torch.nn.utils import clip_grad_norm_ 12 | 13 | from .utils import get_logger 14 | 15 | #from pudb import set_trace 16 | #set_trace() 17 | 18 | 19 | def load_obj(obj, device): 20 | """ 21 | Offload tensor object in obj to cuda device 22 | """ 23 | 24 | def cuda(obj): 25 | return obj.to(device) if isinstance(obj, th.Tensor) else obj 26 | 27 | if isinstance(obj, dict): 28 | return {key: load_obj(obj[key], device) for key in obj} 29 | elif isinstance(obj, list): 30 | return [load_obj(val, device) for val in obj] 31 | else: 32 | return cuda(obj) 33 | 34 | 35 | class SimpleTimer(object): 36 | """ 37 | A simple timer 38 | """ 39 | 40 | def __init__(self): 41 | self.reset() 42 | 43 | def reset(self): 44 | self.start = time.time() 45 | 46 | def elapsed(self): 47 | return (time.time() - self.start) / 60 48 | 49 | 50 | class ProgressReporter(object): 51 | """ 52 | A simple progress reporter 53 | """ 54 | 55 | def __init__(self, logger, period=100): 56 | self.period = period 57 | self.logger = logger 58 | self.loss = [] 59 | self.timer = SimpleTimer() 60 | 61 | def add(self, loss): 62 | self.loss.append(loss) 63 | N = len(self.loss) 64 | if not N % self.period: 65 | avg = sum(self.loss[-self.period:]) / self.period 66 | self.logger.info("Processed {:d} batches" 67 | "(loss = {:+.2f})...".format(N, avg)) 68 | 69 | def report(self, details=False): 70 | N = len(self.loss) 71 | if details: 72 | sstr = ",".join(map(lambda f: "{:.2f}".format(f), self.loss)) 73 | self.logger.info("Loss on {:d} batches: {}".format(N, sstr)) 74 | return { 75 | "loss": sum(self.loss) / N, 76 | "batches": N, 77 | "cost": self.timer.elapsed() 78 | } 79 | 80 | 81 | class Trainer(object): 82 | def __init__(self, 83 | nnet, 84 | checkpoint="checkpoint", 85 | optimizer="adam", 86 | gpuid=0, 87 | optimizer_kwargs=None, 88 | clip_norm=None, 89 | min_lr=0, 90 | patience=0, 91 | factor=0.5, 92 | logging_period=100, 93 | resume=None, 94 | no_impr=6): 95 | if not th.cuda.is_available(): 96 | raise RuntimeError("CUDA device unavailable...exist") 97 | if not isinstance(gpuid, tuple): 98 | gpuid = (gpuid, ) 99 | self.device = th.device("cuda:{}".format(gpuid[0])) 100 | self.gpuid = gpuid 101 | if checkpoint and not os.path.exists(checkpoint): 102 | os.makedirs(checkpoint) 103 | self.checkpoint = checkpoint 104 | self.logger = get_logger( 105 | os.path.join(checkpoint, "trainer.log"), file=True) 106 | 107 | self.clip_norm = clip_norm 108 | self.logging_period = logging_period 109 | self.cur_epoch = 0 # zero based 110 | self.no_impr = no_impr 111 | 112 | if resume: 113 | if not os.path.exists(resume): 114 | raise FileNotFoundError( 115 | "Could not find resume checkpoint: {}".format(resume)) 116 | cpt = th.load(resume, map_location="cpu") 117 | self.cur_epoch = cpt["epoch"] 118 | self.logger.info("Resume from checkpoint {}: epoch {:d}".format( 119 | resume, self.cur_epoch)) 120 | # load nnet 121 | nnet.load_state_dict(cpt["model_state_dict"]) 122 | self.nnet = nnet.to(self.device) 123 | self.optimizer = self.create_optimizer( 124 | optimizer, optimizer_kwargs, state=cpt["optim_state_dict"]) 125 | else: 126 | self.nnet = nnet.to(self.device) 127 | self.optimizer = 
self.create_optimizer(optimizer, optimizer_kwargs) 128 | self.scheduler = ReduceLROnPlateau( 129 | self.optimizer, 130 | mode="min", 131 | factor=factor, 132 | patience=patience, 133 | min_lr=min_lr, 134 | verbose=True) 135 | self.num_params = sum( 136 | [param.nelement() for param in nnet.parameters()]) / 10.0**6 137 | 138 | # logging 139 | self.logger.info("Model summary:\n{}".format(nnet)) 140 | self.logger.info("Loading model to GPUs:{}, #param: {:.2f}M".format( 141 | gpuid, self.num_params)) 142 | if clip_norm: 143 | self.logger.info( 144 | "Gradient clipping by {}, default L2".format(clip_norm)) 145 | 146 | def save_checkpoint(self, best=True): 147 | cpt = { 148 | "epoch": self.cur_epoch, 149 | "model_state_dict": self.nnet.state_dict(), 150 | "optim_state_dict": self.optimizer.state_dict() 151 | } 152 | th.save( 153 | cpt, 154 | os.path.join(self.checkpoint, 155 | "{0}.pt.tar".format("best" if best else "last"))) 156 | 157 | def create_optimizer(self, optimizer, kwargs, state=None): 158 | supported_optimizer = { 159 | "sgd": th.optim.SGD, # momentum, weight_decay, lr 160 | "rmsprop": th.optim.RMSprop, # momentum, weight_decay, lr 161 | "adam": th.optim.Adam, # weight_decay, lr 162 | "adadelta": th.optim.Adadelta, # weight_decay, lr 163 | "adagrad": th.optim.Adagrad, # lr, lr_decay, weight_decay 164 | "adamax": th.optim.Adamax # lr, weight_decay 165 | # ... 166 | } 167 | if optimizer not in supported_optimizer: 168 | raise ValueError("Now only support optimizer {}".format(optimizer)) 169 | opt = supported_optimizer[optimizer](self.nnet.parameters(), **kwargs) 170 | self.logger.info("Create optimizer {0}: {1}".format(optimizer, kwargs)) 171 | if state is not None: 172 | opt.load_state_dict(state) 173 | self.logger.info("Load optimizer state dict from checkpoint") 174 | return opt 175 | 176 | def compute_loss(self, egs): 177 | raise NotImplementedError 178 | 179 | def train(self, data_loader): 180 | self.logger.info("Set train mode...") 181 | self.nnet.train() 182 | reporter = ProgressReporter(self.logger, period=self.logging_period) 183 | 184 | for egs in data_loader: 185 | ### 186 | #egs['aux_mfcc'] = list(egs['aux_mfcc']) 187 | ### 188 | 189 | # load to gpu 190 | egs = load_obj(egs, self.device) 191 | 192 | self.optimizer.zero_grad() 193 | loss = self.compute_loss(egs) 194 | loss.backward() 195 | if self.clip_norm: 196 | clip_grad_norm_(self.nnet.parameters(), self.clip_norm) 197 | self.optimizer.step() 198 | 199 | reporter.add(loss.item()) 200 | return reporter.report() 201 | 202 | def eval(self, data_loader): 203 | self.logger.info("Set eval mode...") 204 | self.nnet.eval() 205 | reporter = ProgressReporter(self.logger, period=self.logging_period) 206 | 207 | with th.no_grad(): 208 | for egs in data_loader: 209 | ### 210 | #egs['aux_mfcc'] = list(egs['aux_mfcc']) 211 | ### 212 | egs = load_obj(egs, self.device) 213 | loss = self.compute_loss(egs) 214 | reporter.add(loss.item()) 215 | return reporter.report(details=True) 216 | 217 | def run(self, train_loader, dev_loader, num_epochs=50): 218 | # avoid alloc memory from gpu0 219 | with th.cuda.device(self.gpuid[0]): 220 | stats = dict() 221 | # check if save is OK 222 | self.save_checkpoint(best=False) 223 | cv = self.eval(dev_loader) 224 | best_loss = cv["loss"] 225 | self.logger.info("START FROM EPOCH {:d}, LOSS = {:.4f}".format( 226 | self.cur_epoch, best_loss)) 227 | no_impr = 0 228 | # make sure not inf 229 | self.scheduler.best = best_loss 230 | while self.cur_epoch < num_epochs: 231 | self.cur_epoch += 1 232 | cur_lr = 
self.optimizer.param_groups[0]["lr"]
233 |                 stats[
234 |                     "title"] = "Loss(time/N, lr={:.3e}) - Epoch {:2d}:".format(
235 |                         cur_lr, self.cur_epoch)
236 |                 tr = self.train(train_loader)
237 |                 stats["tr"] = "train = {:+.4f}({:.2f}m/{:d})".format(
238 |                     tr["loss"], tr["cost"], tr["batches"])
239 |                 cv = self.eval(dev_loader)
240 |                 stats["cv"] = "dev = {:+.4f}({:.2f}m/{:d})".format(
241 |                     cv["loss"], cv["cost"], cv["batches"])
242 |                 stats["scheduler"] = ""
243 |                 if cv["loss"] > best_loss:
244 |                     no_impr += 1
245 |                     stats["scheduler"] = "| no impr, best = {:.4f}".format(
246 |                         self.scheduler.best)
247 |                 else:
248 |                     best_loss = cv["loss"]
249 |                     no_impr = 0
250 |                     self.save_checkpoint(best=True)
251 |                 self.logger.info(
252 |                     "{title} {tr} | {cv} {scheduler}".format(**stats))
253 |                 # schedule here
254 |                 self.scheduler.step(cv["loss"])
255 |                 # flush scheduler info
256 |                 sys.stdout.flush()
257 |                 # save last checkpoint
258 |                 self.save_checkpoint(best=False)
259 |                 if no_impr == self.no_impr:
260 |                     self.logger.info(
261 |                         "Stop training: no improvement for {:d} epochs".format(
262 |                             no_impr))
263 |                     break
264 |             self.logger.info("Training for {:d}/{:d} epochs done!".format(
265 |                 self.cur_epoch, num_epochs))
266 | 
267 | 
268 | class SiSnrTrainer(Trainer):
269 |     def __init__(self, *args, **kwargs):
270 |         super(SiSnrTrainer, self).__init__(*args, **kwargs)
271 | 
272 |     def sisnr(self, x, s, eps=1e-8):
273 |         """
274 |         Arguments:
275 |         x: separated signal, N x S tensor
276 |         s: reference signal, N x S tensor
277 |         Return:
278 |         sisnr: N tensor
279 |         """
280 | 
281 |         def l2norm(mat, keepdim=False):
282 |             return th.norm(mat, dim=-1, keepdim=keepdim)
283 | 
284 |         if x.shape != s.shape:
285 |             raise RuntimeError(
286 |                 "Dimension mismatch when calculating SI-SNR, {} vs {}".format(
287 |                     x.shape, s.shape))
288 |         x_zm = x - th.mean(x, dim=-1, keepdim=True)
289 |         s_zm = s - th.mean(s, dim=-1, keepdim=True)
290 |         t = th.sum(
291 |             x_zm * s_zm, dim=-1,
292 |             keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps)
293 |         return 20 * th.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))
294 | 
295 |     def mask_by_length(self, xs, lengths, fill=0):
296 |         """Mask tensor according to length.
297 | 
298 |         Args:
299 |             xs (Tensor): Batch of input tensor (B, `*`).
300 |             lengths (LongTensor or List): Batch of lengths (B,).
301 |             fill (int or float): Value to fill masked part.
302 | 
303 |         Returns:
304 |             Tensor: Batch of masked input tensor (B, `*`).
305 | 306 | Examples: 307 | >>> x = torch.arange(5).repeat(3, 1) + 1 308 | >>> x 309 | tensor([[1, 2, 3, 4, 5], 310 | [1, 2, 3, 4, 5], 311 | [1, 2, 3, 4, 5]]) 312 | >>> lengths = [5, 3, 2] 313 | >>> mask_by_length(x, lengths) 314 | tensor([[1, 2, 3, 4, 5], 315 | [1, 2, 3, 0, 0], 316 | [1, 2, 0, 0, 0]]) 317 | 318 | """ 319 | assert xs.size(0) == len(lengths) 320 | ret = xs.data.new(*xs.size()).fill_(fill) 321 | for i, l in enumerate(lengths): 322 | ret[i, :l] = xs[i, :l] 323 | return ret 324 | 325 | def compute_loss(self, egs): 326 | ests, ests2, ests3, spk_pred = th.nn.parallel.data_parallel( 327 | self.nnet, (egs["mix"], egs["aux"], egs["aux_len"]), device_ids=self.gpuid) 328 | refs = egs["ref"] 329 | 330 | ## P x N 331 | N = egs["mix"].size(0) 332 | valid_len = egs["valid_len"] 333 | ests = self.mask_by_length(ests, valid_len) 334 | ests2 = self.mask_by_length(ests2, valid_len) 335 | ests3 = self.mask_by_length(ests3, valid_len) 336 | refs = self.mask_by_length(refs, valid_len) 337 | 338 | snr1 = self.sisnr(ests, refs) 339 | snr2 = self.sisnr(ests2, refs) 340 | snr3 = self.sisnr(ests3, refs) 341 | snr_loss = (-0.8*th.sum(snr1)-0.1*th.sum(snr2)-0.1*th.sum(snr3)) / N 342 | 343 | ce = th.nn.CrossEntropyLoss() 344 | ce_loss = ce(spk_pred, egs["spk_idx"]) 345 | #return snr_loss + 0.5 * ce_loss 346 | # PS: we found we can get similar result when we set the scale param to 0.5 or 1.0 347 | return snr_loss + 0.5 * ce_loss 348 | -------------------------------------------------------------------------------- /nnet/libs/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | 5 | 6 | def get_logger( 7 | name, 8 | format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s", 9 | date_format="%Y-%m-%d %H:%M:%S", 10 | file=False): 11 | """ 12 | Get python logger instance 13 | """ 14 | logger = logging.getLogger(name) 15 | logger.setLevel(logging.INFO) 16 | # file or console 17 | handler = logging.StreamHandler() if not file else logging.FileHandler( 18 | name) 19 | handler.setLevel(logging.INFO) 20 | formatter = logging.Formatter(fmt=format_str, datefmt=date_format) 21 | handler.setFormatter(formatter) 22 | logger.addHandler(handler) 23 | return logger 24 | 25 | 26 | def dump_json(obj, fdir, name): 27 | """ 28 | Dump python object in json 29 | """ 30 | if fdir and not os.path.exists(fdir): 31 | os.makedirs(fdir) 32 | with open(os.path.join(fdir, name), "w") as f: 33 | json.dump(obj, f, indent=4, sort_keys=False) 34 | 35 | 36 | def load_json(fdir, name): 37 | """ 38 | Load json as python object 39 | """ 40 | path = os.path.join(fdir, name) 41 | if not os.path.exists(path): 42 | raise FileNotFoundError("Could not find json file: {}".format(path)) 43 | with open(path, "r") as f: 44 | obj = json.load(f) 45 | return obj 46 | -------------------------------------------------------------------------------- /nnet/separate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # wujian@2018 4 | 5 | import os 6 | import argparse 7 | 8 | import torch as th 9 | import numpy as np 10 | 11 | from conv_tas_net_decode import ConvTasNet 12 | 13 | from libs.utils import load_json, get_logger 14 | from libs.audio import WaveReader, write_wav 15 | from libs.kaldi_io import read_vec_flt, read_mat 16 | logger = get_logger(__name__) 17 | 18 | #from pudb import set_trace 19 | #set_trace() 20 | 21 | class NnetComputer(object): 22 | def __init__(self, cpt_dir, 
gpuid): 23 | self.device = th.device( 24 | "cuda:{}".format(gpuid)) if gpuid >= 0 else th.device("cpu") 25 | nnet = self._load_nnet(cpt_dir) 26 | self.nnet = nnet.to(self.device) if gpuid >= 0 else nnet 27 | # set eval model 28 | self.nnet.eval() 29 | 30 | def _load_nnet(self, cpt_dir): 31 | nnet_conf = load_json(cpt_dir, "mdl.json") 32 | nnet = ConvTasNet(**nnet_conf) 33 | cpt_fname = os.path.join(cpt_dir, "best.pt.tar") 34 | cpt = th.load(cpt_fname, map_location="cpu") 35 | nnet.load_state_dict(cpt["model_state_dict"]) 36 | logger.info("Load checkpoint from {}, epoch {:d}".format( 37 | cpt_fname, cpt["epoch"])) 38 | return nnet 39 | 40 | def compute(self, samps, aux_samps, aux_samps_len): 41 | with th.no_grad(): 42 | raw = th.tensor(samps, dtype=th.float32, device=self.device) 43 | aux = th.tensor(aux_samps, dtype=th.float32, device=self.device) 44 | aux_len = th.tensor(aux_samps_len, dtype=th.float32, device=self.device) 45 | aux = aux.unsqueeze(0) 46 | sps,sps2,sps3,spk_pred = self.nnet(raw, aux, aux_len) 47 | sp_samps = np.squeeze(sps.detach().cpu().numpy()) 48 | return [sp_samps] 49 | 50 | 51 | def run(): 52 | mix_input = WaveReader("data/wsj0_2mix/tt/mix.scp", sample_rate=8000) 53 | aux_input = WaveReader("data/wsj0_2mix/tt/aux.scp", sample_rate=8000) 54 | computer = NnetComputer("exp_epoch114/conv_tasnet/conv-net", 3) 55 | for key, mix_samps in mix_input: 56 | aux_samps = aux_input[key] 57 | logger.info("Compute on utterance {}...".format(key)) 58 | spks = computer.compute(mix_samps, aux_samps, len(aux_samps)) 59 | norm = np.linalg.norm(mix_samps, np.inf) 60 | for idx, samps in enumerate(spks): 61 | samps = samps[:mix_samps.size] 62 | # norm 63 | samps = samps * norm / np.max(np.abs(samps)) 64 | write_wav( 65 | os.path.join("rec/", "{}.wav".format(key)), 66 | samps, 67 | fs=8000) 68 | logger.info("Compute over {:d} utterances".format(len(mix_input))) 69 | 70 | 71 | if __name__ == "__main__": 72 | run() 73 | -------------------------------------------------------------------------------- /nnet/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import pprint 4 | import argparse 5 | import random 6 | 7 | from libs.trainer import SiSnrTrainer 8 | from libs.dataset import make_dataloader 9 | from libs.utils import dump_json, get_logger 10 | 11 | from conv_tas_net import ConvTasNet 12 | from conf import trainer_conf, nnet_conf, train_data, dev_data, chunk_size 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | def run(args): 18 | gpuids = tuple(map(int, args.gpus.split(","))) 19 | 20 | nnet = ConvTasNet(**nnet_conf) 21 | trainer = SiSnrTrainer(nnet, 22 | gpuid=gpuids, 23 | checkpoint=args.checkpoint, 24 | resume=args.resume, 25 | **trainer_conf) 26 | 27 | data_conf = { 28 | "train": train_data, 29 | "dev": dev_data, 30 | "chunk_size": chunk_size 31 | } 32 | for conf, fname in zip([nnet_conf, trainer_conf, data_conf], 33 | ["mdl.json", "trainer.json", "data.json"]): 34 | dump_json(conf, args.checkpoint, fname) 35 | 36 | train_loader = make_dataloader(train=True, 37 | data_kwargs=train_data, 38 | batch_size=args.batch_size, 39 | chunk_size=chunk_size, 40 | num_workers=args.num_workers) 41 | dev_loader = make_dataloader(train=False, 42 | data_kwargs=dev_data, 43 | batch_size=args.batch_size, 44 | chunk_size=chunk_size, 45 | num_workers=args.num_workers) 46 | 47 | trainer.run(train_loader, dev_loader, num_epochs=args.epochs) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser( 
52 |         description=
53 |         "Command to start ConvTasNet training, configured from conf.py",
54 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
55 |     parser.add_argument("--gpus",
56 |                         type=str,
57 |                         default="0,1",
58 |                         help="Training on which GPUs "
59 |                         "(one or more, e.g. 0 or \"0,1\")")
60 |     parser.add_argument("--epochs",
61 |                         type=int,
62 |                         default=50,
63 |                         help="Number of training epochs")
64 |     parser.add_argument("--checkpoint",
65 |                         type=str,
66 |                         required=True,
67 |                         help="Directory to dump models")
68 |     parser.add_argument("--resume",
69 |                         type=str,
70 |                         default="",
71 |                         help="Existing checkpoint to resume training from")
72 |     parser.add_argument("--batch-size",
73 |                         type=int,
74 |                         default=16,
75 |                         help="Number of utterances in each batch")
76 |     parser.add_argument("--num-workers",
77 |                         type=int,
78 |                         default=4,
79 |                         help="Number of workers used in data loader")
80 |     args = parser.parse_args()
81 |     logger.info("Arguments in command:\n{}".format(pprint.pformat(vars(args))))
82 | 
83 |     run(args)
84 | 
--------------------------------------------------------------------------------
/pretrain_model/link.txt:
--------------------------------------------------------------------------------
1 | Link:https://pan.baidu.com/s/1Yts3r9_-yX5ydI8s1sFtlQ
2 | Password:q5e3
3 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.0.0
2 | tqdm==4.26.0
3 | numpy==1.15.2
4 | scipy==1.1.0
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | set -eu
4 | 
5 | cpt_dir=exp/conv_tasnet
6 | epochs=100
7 | # constrained by GPU count & memory
8 | batch_size=20
9 | cache_size=16
10 | 
11 | #[ $# -ne 2 ] && echo "Script error: $0 " && exit 1
12 | 
13 | ./nnet/train.py --gpus "0,1,2" --epochs $epochs --batch-size $batch_size --checkpoint $cpt_dir/conv-net
14 | 
--------------------------------------------------------------------------------
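A closing illustration of how the pieces above fit together once training and decoding have run (a sketch under assumptions, not code from the repository: file names in rec/ follow what nnet/separate.py writes, the script is assumed to live in nnet/ and be run from the repository root like separate.py, and the repo's own nnet/compute_si_snr.py presumably covers this more completely):

#!/usr/bin/env python
# score the wavs written by nnet/separate.py against the test-set references
import os
import numpy as np

from libs.audio import WaveReader, read_wav
from libs.metric import si_snr

ref_reader = WaveReader("data/wsj0_2mix/tt/ref.scp", sample_rate=8000)
scores = []
for key, ref in ref_reader:
    est_path = os.path.join("rec", "{}.wav".format(key))
    if not os.path.exists(est_path):     # skip utterances that were not decoded
        continue
    _, est = read_wav(est_path, return_rate=True)
    n = min(est.size, ref.size)          # separate.py trims its output to the mixture length
    scores.append(si_snr(est[:n], ref[:n]))
print("Mean SI-SNR: {:.2f} dB over {:d} utterances".format(np.mean(scores), len(scores)))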