├── chime6_rttm ├── dev_rttm ├── dev_rttm.scoring ├── eval_rttm └── eval_rttm.scoring ├── conf └── tsvad_config.json ├── inference.py ├── model ├── __pycache__ │ ├── tsvad.cpython-36.pyc │ ├── tsvad_dprnn.cpython-36.pyc │ ├── tsvad_dprnn_v2.cpython-36.pyc │ ├── tsvad_dprnn_v3.cpython-36.pyc │ ├── tsvad_tf.cpython-36.pyc │ └── tsvad_v2.cpython-36.pyc ├── tsvad.py └── tsvad_dprnn.py ├── run.sh ├── test.sh ├── train.py ├── trainer ├── __pycache__ │ ├── basic.cpython-36.pyc │ ├── mix_wave.cpython-36.pyc │ └── radam.cpython-36.pyc ├── basic.py └── radam.py ├── uem_file.scoring └── util ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── dataset_loader.cpython-36.pyc ├── dataset_loader.cpython-37.pyc ├── utils.cpython-36.pyc └── utils.cpython-37.pyc ├── convert_prob_to_rttm.py ├── dataset_loader.py ├── parse_options.sh └── utils.py /conf/tsvad_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_config": { 3 | "training_dir": "data/train", 4 | "output_directory": "checkpoints/tsvad", 5 | "trainer_type": "basic", 6 | "model_type": "tsvad_dprnn_v3", 7 | "max_iter": 500000, 8 | "optimize_param": { 9 | "optim_type": "RAdam", 10 | "learning_rate": 1e-4, 11 | "max_grad_norm": 10, 12 | "lr_scheduler":{ 13 | "step_size": 100000, 14 | "gamma": 0.5, 15 | "last_epoch": -1 16 | } 17 | }, 18 | "batch_size": 32, 19 | "nframes": 40, 20 | "iters_per_checkpoint": 10000, 21 | "iters_per_log": 1, 22 | "seed": 1234, 23 | "checkpoint_path": "" 24 | }, 25 | "infer_config": { 26 | "model_type": "tsvad_dprnn_v3", 27 | "model_path": "checkpoints_tsvad", 28 | "output_dir": "hyp" 29 | }, 30 | "model_config": { 31 | "out_channels": [ 64, 64, 128, 128], 32 | "nproj": 384, 33 | "cell": 896, 34 | "dprnn_layers": 6 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0,os.getcwd()) 4 | 5 | import numpy 6 | import torch 7 | import logging 8 | from tqdm import tqdm 9 | from kaldiio import WriteHelper 10 | from importlib import import_module 11 | 12 | from torch.utils.data import DataLoader 13 | from util.dataset_loader import EvalDataset 14 | 15 | def compute_tsvad_weights(writer, utt, preds): 16 | for i in range(4): 17 | pred = preds[:, i] 18 | uid = utt + '-' + str(i+1) 19 | writer(uid, pred) 20 | 21 | def inference(infer_config): 22 | # Initial 23 | model_type = infer_config.get('model_type', 'tsvad') 24 | model_path = infer_config.get('model_path', '') 25 | output_dir = infer_config.get('output_dir', '') 26 | feats_dir = infer_config.get('feats_dir', '') 27 | ivectors_dir = infer_config.get('ivectors_dir', '') 28 | 29 | # Load Model 30 | module = import_module('model.{}'.format(model_type)) 31 | MODEL = getattr(module, 'Model') 32 | model = MODEL() 33 | model.load_state_dict(torch.load(model_path)['model']) 34 | 35 | print (model) 36 | 37 | model = model.cuda() 38 | 39 | # Load evaluation data 40 | evalset = EvalDataset(feats_dir=feats_dir, ivectors_dir=ivectors_dir) 41 | eval_loader = DataLoader(evalset, num_workers=0, shuffle=False, batch_size=1) 42 | 43 | # Prepare logger 44 | logger = logging.getLogger("logger") 45 | handler1 = logging.StreamHandler() 46 | logger.setLevel(logging.INFO) 47 | 48 | formatter = logging.Formatter("%(asctime)s %(message)s", 49 | datefmt="%m-%d %H:%M:%S") 50 | handler1.setFormatter(formatter) 51 | 
logger.addHandler(handler1) 52 | 53 | logger.info("Evaluation utterances: {}".format(len(evalset))) 54 | 55 | # ================ MAIN EVALUATION LOOP! =================== 56 | 57 | logger.info("Start evaluation...") 58 | 59 | model.eval() 60 | with WriteHelper('ark,t:{}/weights.ark'.format(output_dir)) as writer: 61 | for i, batch in tqdm(enumerate(eval_loader)): 62 | utt, _, _ = batch 63 | with torch.no_grad(): 64 | preds = model.inference(batch).squeeze(0).cpu().numpy() 65 | compute_tsvad_weights(writer, utt[0], preds) 66 | 67 | if __name__ == "__main__": 68 | import argparse 69 | import json 70 | 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument('-c', '--config', type=str, default='conf/config_dcase.json', 73 | help='JSON file for configuration') 74 | parser.add_argument('-p', '--model_path', type=str, default=None, 75 | help='model path to load') 76 | parser.add_argument('-o', '--output_dir', type=str, default=None, 77 | help='output directory') 78 | parser.add_argument('-f', '--feats_dir', type=str, default=None, 79 | help='output directory') 80 | parser.add_argument('-i', '--ivectors_dir', type=str, default=None, 81 | help='output directory') 82 | parser.add_argument('-g', '--gpu', type=str, default='0', 83 | help='Using gpu #') 84 | args = parser.parse_args() 85 | 86 | # Parse configs. Globals nicer in this case 87 | with open(args.config) as f: 88 | data = f.read() 89 | config = json.loads(data) 90 | infer_config = config["infer_config"] 91 | global model_config 92 | model_config = config['model_config'] 93 | 94 | if args.model_path is not None: 95 | infer_config['model_path'] = args.model_path 96 | if args.output_dir is not None: 97 | infer_config['output_dir'] = args.output_dir 98 | if args.feats_dir is not None: 99 | infer_config['feats_dir'] = args.feats_dir 100 | if args.ivectors_dir is not None: 101 | infer_config['ivectors_dir'] = args.ivectors_dir 102 | 103 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 104 | 105 | torch.backends.cudnn.enabled = True 106 | torch.backends.cudnn.benchmark = False 107 | 108 | inference(infer_config) 109 | -------------------------------------------------------------------------------- /model/__pycache__/tsvad.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/model/__pycache__/tsvad.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/tsvad_dprnn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/model/__pycache__/tsvad_dprnn.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/tsvad_dprnn_v2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/model/__pycache__/tsvad_dprnn_v2.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/tsvad_dprnn_v3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/model/__pycache__/tsvad_dprnn_v3.cpython-36.pyc 
-------------------------------------------------------------------------------- /model/__pycache__/tsvad_tf.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/model/__pycache__/tsvad_tf.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/tsvad_v2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/model/__pycache__/tsvad_v2.cpython-36.pyc -------------------------------------------------------------------------------- /model/tsvad.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | 6 | 7 | class CNN_ReLU_BatchNorm(nn.Module): 8 | def __init__(self, in_channels=1, out_channels=64, kernel_size=3, stride=(1, 1), padding=1): 9 | super(CNN_ReLU_BatchNorm, self).__init__() 10 | self.cnn = nn.Sequential( 11 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 12 | nn.ReLU(), 13 | nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99), 14 | ) 15 | 16 | def forward(self, feature): 17 | feature = self.cnn(feature) 18 | return feature 19 | 20 | 21 | class BLSTMP(nn.Module): 22 | def __init__(self, n_in, n_hidden, nproj=160, dropout=0, num_layers=1): 23 | super(BLSTMP, self).__init__() 24 | 25 | self.num_layers = num_layers 26 | 27 | self.rnns = nn.ModuleList([nn.LSTM(n_in, n_hidden, bidirectional=True, dropout=dropout, batch_first=True)]) 28 | self.linears = nn.ModuleList([nn.Linear(2*n_hidden, 2*nproj)]) 29 | 30 | for i in range(num_layers-1): 31 | self.rnns.append(nn.LSTM(2*nproj, n_hidden, bidirectional=True, dropout=dropout, batch_first=True)) 32 | self.linears.append(nn.Linear(2*n_hidden, 2*nproj)) 33 | 34 | def forward(self, feature): 35 | recurrent, _ = self.rnns[0](feature) 36 | output = self.linears[0](recurrent) 37 | 38 | for i in range(self.num_layers-1): 39 | output, _ = self.rnns[i+1](output) 40 | output = self.linears[i+1](output) 41 | 42 | return output 43 | 44 | 45 | class Model(nn.Module): 46 | def __init__(self, out_channels=[64, 64, 128, 128], rproj=128, nproj=160, cell=896): 47 | super(Model, self).__init__() 48 | 49 | batchnorm = nn.BatchNorm2d(1, eps=0.001, momentum=0.99) 50 | 51 | cnn_relu_batchnorm1 = CNN_ReLU_BatchNorm(in_channels=1, out_channels=out_channels[0]) 52 | cnn_relu_batchnorm2 = CNN_ReLU_BatchNorm(in_channels=out_channels[0], out_channels=out_channels[1]) 53 | cnn_relu_batchnorm3 = CNN_ReLU_BatchNorm(in_channels=out_channels[1], out_channels=out_channels[2], stride=(1, 2)) 54 | cnn_relu_batchnorm4 = CNN_ReLU_BatchNorm(in_channels=out_channels[2], out_channels=out_channels[3]) 55 | 56 | self.cnn = nn.Sequential( 57 | batchnorm, 58 | cnn_relu_batchnorm1, 59 | cnn_relu_batchnorm2, 60 | cnn_relu_batchnorm3, 61 | cnn_relu_batchnorm4 62 | ) 63 | 64 | self.linear = nn.Linear(out_channels[-1]*20+100, 3*rproj) 65 | self.rnn_speaker_detection = BLSTMP(3*rproj, cell, num_layers=2) 66 | self.rnn_combine = BLSTMP(8*nproj, cell) 67 | 68 | self.output_layer = nn.Linear(nproj//2, 1) 69 | 70 | def forward(self, batch): 71 | feats, targets, ivectors = batch 72 | 73 | feats = self.cnn(feats) 74 | bs, chan, tframe, dim = feats.size() 75 | 76 | feats = 
feats.permute(0, 2, 1, 3) 77 | feats = feats.contiguous().view(bs, tframe, chan*dim) # B x 1 x T x 2560 78 | feats = feats.unsqueeze(1).repeat(1, 4, 1, 1) # B x 4 x T x 2560 79 | ivectors = ivectors.view(bs, 4, 100).unsqueeze(2) # B x 4 x 1 x 100 80 | ivectors = ivectors.repeat(1, 1, tframe, 1) # B x 4 x T x 100 81 | 82 | sd_in = torch.cat((feats, ivectors), dim=-1) # B x 4 x T x 2660 83 | sd_in = self.linear(sd_in).view(4*bs, tframe, -1) # 4B x T x 384 84 | sd_out = self.rnn_speaker_detection(sd_in) # 4B x T x 320 85 | sd_out = sd_out.contiguous().view(bs, 4, tframe, -1) # B x 4 x T x 320 86 | sd_out = sd_out.permute(0, 2, 1, 3) # B x T x 4 x 320 87 | sd_out = sd_out.contiguous().view(bs, tframe, -1) # B x T x 1280 88 | 89 | outputs = self.rnn_combine(sd_out) # B x T x 320 90 | outputs = outputs.contiguous().view(bs, tframe, 4, -1) # B x T x 4 x 80 91 | preds = self.output_layer(outputs).squeeze(-1) # B x T x 4 92 | preds = nn.Sigmoid()(preds) 93 | 94 | loss = nn.BCELoss(reduction='sum')(preds, targets) / tframe / bs 95 | loss_detail = {"diarization loss": loss.item()} 96 | 97 | return loss, loss_detail 98 | 99 | def inference(self, batch): 100 | _, feats, ivectors = batch 101 | 102 | feats = self.cnn(feats) 103 | bs, chan, tframe, dim = feats.size() 104 | 105 | feats = feats.permute(0, 2, 1, 3) 106 | feats = feats.contiguous().view(bs, tframe, chan*dim) # B x 1 x T x 2560 107 | feats = feats.unsqueeze(1).repeat(1, 4, 1, 1) # B x 4 x T x 2560 108 | ivectors = ivectors.view(bs, 4, 100).unsqueeze(2) # B x 4 x 1 x 100 109 | ivectors = ivectors.repeat(1, 1, tframe, 1) # B x 4 x T x 100 110 | 111 | sd_in = torch.cat((feats, ivectors), dim=-1) # B x 4 x T x 2660 112 | sd_in = self.linear(sd_in).view(4*bs, tframe, -1) # 4B x T x 384 113 | sd_out = self.rnn_speaker_detection(sd_in) # 4B x T x 320 114 | sd_out = sd_out.contiguous().view(bs, 4, tframe, -1) # B x 4 x T x 320 115 | sd_out = sd_out.permute(0, 2, 1, 3) # B x T x 4 x 320 116 | sd_out = sd_out.contiguous().view(bs, tframe, -1) # B x T x 1280 117 | 118 | outputs = self.rnn_combine(sd_out) # B x T x 320 119 | outputs = outputs.contiguous().view(bs, tframe, 4, -1) # B x T x 4 x 80 120 | preds = self.output_layer(outputs).squeeze(-1) # B x T x 4 121 | preds = nn.Sigmoid()(preds) 122 | 123 | return preds 124 | -------------------------------------------------------------------------------- /model/tsvad_dprnn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | 6 | 7 | class CNN_ReLU_BatchNorm(nn.Module): 8 | def __init__(self, in_channels=1, out_channels=64, kernel_size=3, stride=(1, 1), padding=1): 9 | super(CNN_ReLU_BatchNorm, self).__init__() 10 | self.cnn = nn.Sequential( 11 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding), 12 | nn.ReLU(), 13 | nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.99), 14 | ) 15 | 16 | def forward(self, feature): 17 | feature = self.cnn(feature) 18 | return feature 19 | 20 | 21 | class BLSTMP(nn.Module): 22 | def __init__(self, n_in, n_hidden, nproj=384, dropout=0, num_layers=1): 23 | super(BLSTMP, self).__init__() 24 | 25 | self.num_layers = num_layers 26 | 27 | self.rnns = nn.ModuleList([nn.LSTM(n_in, n_hidden, dropout=dropout, bidirectional=True, batch_first=True)]) 28 | self.linears = nn.ModuleList([nn.Linear(n_hidden*2, nproj)]) 29 | 30 | for i in range(num_layers-1): 31 | self.rnns.append(nn.LSTM(nproj, n_hidden, 
dropout=dropout, bidirectional=True, batch_first=True)) 32 | self.linears.append(nn.Linear(n_hidden*2, nproj)) 33 | 34 | def forward(self, feature): 35 | recurrent, _ = self.rnns[0](feature) 36 | output = self.linears[0](recurrent) 37 | 38 | for i in range(self.num_layers-1): 39 | output, _ = self.rnns[i+1](output) 40 | output = self.linears[i+1](output) 41 | 42 | return output 43 | 44 | 45 | class DPRNN(nn.Module): 46 | def __init__(self, nproj=384, cell=896, num_layers=6): 47 | super(DPRNN, self).__init__() 48 | 49 | self.num_layers = num_layers 50 | 51 | self.dprnn_t = nn.ModuleList([BLSTMP(nproj, cell) for i in range(num_layers)]) 52 | self.t_norm = nn.ModuleList([nn.GroupNorm(1, nproj, eps=1e-8) for i in range(num_layers)]) 53 | self.dprnn_c = nn.ModuleList([BLSTMP(nproj, cell) for i in range(num_layers)]) 54 | self.c_norm = nn.ModuleList([nn.GroupNorm(1, nproj, eps=1e-8) for i in range(num_layers)]) 55 | 56 | def forward(self, output): 57 | bs, dim, tframe, spks = output.size() 58 | for i in range(self.num_layers): 59 | dprnn_t_in = output.permute(0, 3, 2, 1).contiguous().view(bs*4, tframe, -1) # B*4 x T x 384 60 | dprnn_t_out = self.dprnn_t[i](dprnn_t_in).contiguous().view(bs, 4, tframe, -1) # B x 4 x T x 384 61 | dprnn_t_out = dprnn_t_out.permute(0, 3, 2, 1) # B x 384 x T x 4 62 | 63 | dprnn_t_out = self.t_norm[i](dprnn_t_out) 64 | output = dprnn_t_out + output # B x 384 x T x 4 65 | 66 | dprnn_c_in = output.permute(0, 2, 3, 1).contiguous().view(bs*tframe, 4, -1) # B*T x 4 x 384 67 | dprnn_c_out = self.dprnn_c[i](dprnn_c_in).contiguous().view(bs, tframe, 4, -1) # B x T x 4 x 384 68 | dprnn_c_out = dprnn_c_out.permute(0, 3, 1, 2) # B x 384 x T x 4 69 | 70 | dprnn_c_out = self.c_norm[i](dprnn_c_out) 71 | output = dprnn_c_out + output # B x 384 x T x 4 72 | 73 | output = output.permute(0, 2, 3, 1) # B x T x 4 x 384 74 | 75 | return output 76 | 77 | 78 | class Model(nn.Module): 79 | def __init__(self, out_channels=[64, 64, 128, 128], nproj=384, cell=896, dprnn_layers=6): 80 | super(Model, self).__init__() 81 | 82 | batchnorm = nn.BatchNorm2d(1, eps=0.001, momentum=0.99) 83 | 84 | cnn_relu_batchnorm1 = CNN_ReLU_BatchNorm(in_channels=1, out_channels=out_channels[0]) 85 | cnn_relu_batchnorm2 = CNN_ReLU_BatchNorm(in_channels=out_channels[0], out_channels=out_channels[1]) 86 | cnn_relu_batchnorm3 = CNN_ReLU_BatchNorm(in_channels=out_channels[1], out_channels=out_channels[2], stride=(1, 2)) 87 | cnn_relu_batchnorm4 = CNN_ReLU_BatchNorm(in_channels=out_channels[2], out_channels=out_channels[3]) 88 | 89 | self.cnn = nn.Sequential( 90 | batchnorm, 91 | cnn_relu_batchnorm1, 92 | cnn_relu_batchnorm2, 93 | cnn_relu_batchnorm3, 94 | cnn_relu_batchnorm4 95 | ) 96 | 97 | self.linear = nn.Linear(out_channels[-1]*20+100, nproj) 98 | self.rnn_speaker_detection = BLSTMP(nproj, cell, num_layers=2) 99 | self.rnn_combine = DPRNN(num_layers=dprnn_layers) 100 | self.output_layer = nn.Linear(nproj, 1) 101 | 102 | def forward(self, batch): 103 | feats, targets, ivectors = batch 104 | 105 | feats = self.cnn(feats) 106 | bs, chan, tframe, dim = feats.size() 107 | 108 | feats = feats.permute(0, 2, 1, 3) 109 | feats = feats.contiguous().view(bs, tframe, chan*dim) # B x 1 x T x 2560 110 | feats = feats.unsqueeze(1).repeat(1, 4, 1, 1) # B x 4 x T x 2560 111 | ivectors = ivectors.view(bs, 4, 100).unsqueeze(2) # B x 4 x 1 x 100 112 | ivectors = ivectors.repeat(1, 1, tframe, 1) # B x 4 x T x 100 113 | 114 | sd_in = torch.cat((feats, ivectors), dim=-1) # B x 4 x T x 2660 115 | sd_in = self.linear(sd_in).view(4*bs, 
tframe, -1) # B*4 x T x 384 116 | sd_out = self.rnn_speaker_detection(sd_in) # B*4 x T x 384 117 | 118 | # DPRNN 119 | output = sd_out.contiguous().view(bs, 4, tframe, -1) # B x 4 x T x 384 120 | output = output.permute(0, 3, 2, 1) # B x 384 x T x 4 121 | output = self.rnn_combine(output) # B x T x 4 x 384 122 | 123 | preds = self.output_layer(output).squeeze(-1) # B x T x 4 124 | preds = nn.Sigmoid()(preds) 125 | 126 | loss = nn.BCELoss(reduction='sum')(preds, targets) / tframe / bs 127 | loss_detail = {"diarization loss": loss.item()} 128 | 129 | return loss, loss_detail 130 | 131 | def inference(self, batch): 132 | _, feats, ivectors = batch 133 | 134 | feats = self.cnn(feats) 135 | bs, chan, tframe, dim = feats.size() 136 | 137 | feats = feats.permute(0, 2, 1, 3) 138 | feats = feats.contiguous().view(bs, tframe, chan*dim) # B x 1 x T x 2560 139 | feats = feats.unsqueeze(1).repeat(1, 4, 1, 1) # B x 4 x T x 2560 140 | ivectors = ivectors.view(bs, 4, 100).unsqueeze(2) # B x 4 x 1 x 100 141 | ivectors = ivectors.repeat(1, 1, tframe, 1) # B x 4 x T x 100 142 | 143 | sd_in = torch.cat((feats, ivectors), dim=-1) # B x 4 x T x 2660 144 | sd_in = self.linear(sd_in).view(4*bs, tframe, -1) # B*4 x T x 384 145 | sd_out = self.rnn_speaker_detection(sd_in) # B*4 x T x 384 146 | 147 | # DPRNN 148 | output = sd_out.contiguous().view(bs, 4, tframe, -1) # B x 4 x T x 384 149 | output = output.permute(0, 3, 2, 1) # B x 384 x T x 4 150 | output = self.rnn_combine(output) # B x T x 4 x 384 151 | 152 | preds = self.output_layer(output).squeeze(-1) # B x T x 4 153 | preds = nn.Sigmoid()(preds) 154 | 155 | return preds 156 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | source /home/dodohow1011/miniconda3/bin/activate py36 2 | 3 | output_dir=tsvad_dprnn_v3_nframes40_b64 4 | gpu=2 5 | 6 | . ./util/parse_options.sh 7 | 8 | python train.py -c conf/tsvad_config.json \ 9 | -o checkpoints/$output_dir -g $gpu 10 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | source /home/dodohow1011/miniconda3/bin/activate py36 2 | 3 | dset=eval 4 | out=dprnn_v3_${dset}_460000_track1 5 | stage=0 6 | stop_stage=3 7 | 8 | . ./util/parse_options.sh 9 | 10 | if [ $stage -le 0 -a $stop_stage -ge 0 ]; then 11 | if [ ! 
-f prediction/$out/.done ]; then 12 | mkdir -p prediction/$out 13 | python inference.py -c conf/tsvad_config.json \ 14 | -p checkpoints/tsvad_dprnn_v3_nframes40_b64/01-08_16-34_460000 -g 3 \ 15 | -o prediction/$out -f data/$dset -i data/$dset 16 | 17 | touch prediction/$out/.done 18 | fi 19 | fi 20 | 21 | #TS-VAD probabilities post-processing and DER scoring 22 | 23 | scoring=prediction/$out/scoring 24 | hyp_rttm=$scoring/rttm 25 | ref_rttm=chime6_rttm/${dset}_rttm 26 | thr=0.4 27 | window=51 28 | min_silence=0.3 29 | min_speech=0.2 30 | 31 | if [ $stage -le 1 -a $stop_stage -ge 1 ]; then 32 | python util/convert_prob_to_rttm.py --threshold $thr --window $window --min_silence $min_silence --min_speech $min_speech ark:"sort prediction/$out/weights.ark |" $hyp_rttm || exit 1; 33 | fi 34 | 35 | if [ $stage -le 2 -a $stop_stage -ge 2 ]; then 36 | echo "Diarization results for $dset" 37 | sed 's/_U0[1-6]\.ENH//g' $ref_rttm > $ref_rttm.scoring 38 | sed 's/_U0[1-6]\.ENH//g' $hyp_rttm > $hyp_rttm.scoring 39 | ref_rttm_path=$(readlink -f ${ref_rttm}.scoring) 40 | hyp_rttm_path=$(readlink -f ${hyp_rttm}.scoring) 41 | cd dscore && python score.py -u ../uem_file.scoring -r $ref_rttm_path \ 42 | -s $hyp_rttm_path 2>&1 | tee -a ../$scoring/DER && cd .. || exit 1; 43 | fi 44 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0,os.getcwd()) 4 | 5 | import time 6 | import torch 7 | import logging 8 | import numpy as np 9 | from pathlib import Path 10 | from importlib import import_module 11 | 12 | from torch.utils.data import DataLoader 13 | from util.dataset_loader import Dataset 14 | 15 | 16 | def train(train_config): 17 | # Initial 18 | output_directory = train_config.get('output_directory', '') 19 | max_iter = train_config.get('max_iter', 100000) 20 | batch_size = train_config.get('batch_size', 128) 21 | nframes = train_config.get('nframes', 40) 22 | iters_per_checkpoint = train_config.get('iters_per_checkpoint', 10000) 23 | iters_per_log = train_config.get('iters_per_log', 1000) 24 | seed = train_config.get('seed', 1234) 25 | checkpoint_path = train_config.get('checkpoint_path', '') 26 | trainer_type = train_config.get('trainer_type', 'basic') 27 | 28 | # Setup 29 | np.random.seed(seed) 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed(seed) 32 | 33 | # Initial trainer 34 | module = import_module('trainer.{}'.format(trainer_type), package=None) 35 | TRAINER = getattr( module, 'Trainer') 36 | trainer = TRAINER( train_config, model_config) 37 | try: 38 | collate_fn = getattr( module, 'collate') 39 | except: 40 | collate_fn = None 41 | 42 | # Load checkpoint if the path is given 43 | iteration = 1 44 | if checkpoint_path != "": 45 | iteration = trainer.load_checkpoint( checkpoint_path) 46 | iteration += 1 # next iteration is iteration + 1 47 | 48 | # Load training data 49 | trainset = Dataset(train_config['training_dir'], nframes) 50 | train_loader = DataLoader(trainset, num_workers=32, shuffle=True, 51 | batch_size=batch_size, 52 | pin_memory=True, 53 | drop_last=True, 54 | collate_fn=collate_fn) 55 | 56 | # Get shared output_directory ready 57 | output_directory = Path(output_directory) 58 | output_directory.mkdir(parents=True, exist_ok=True) 59 | 60 | # Prepare logger 61 | logger = logging.getLogger("logger") 62 | handler1 = logging.StreamHandler() 63 | handler2 = logging.FileHandler(filename=str(output_directory/'Stat')) 64 |
logger.setLevel(logging.INFO) 65 | 66 | formatter = logging.Formatter("%(asctime)s %(message)s", 67 | datefmt="%m-%d %H:%M:%S") 68 | handler1.setFormatter(formatter) 69 | handler2.setFormatter(formatter) 70 | logger.addHandler(handler1) 71 | logger.addHandler(handler2) 72 | 73 | logger.info("Output directory: {}".format(output_directory)) 74 | logger.info("Training utterances: {}".format(len(trainset))) 75 | logger.info("Batch size: {}".format(batch_size)) 76 | logger.info("# of frames per sample: {}".format(nframes)) 77 | 78 | # ================ MAIN TRAINING LOOP! =================== 79 | 80 | logger.info("Start training...") 81 | 82 | loss_log = dict() 83 | while iteration <= max_iter: 84 | for i, batch in enumerate(train_loader): 85 | 86 | iteration, loss_detail, lr = trainer.step(batch, iteration=iteration) 87 | 88 | # Keep Loss detail 89 | for key,val in loss_detail.items(): 90 | if key not in loss_log.keys(): 91 | loss_log[key] = list() 92 | loss_log[key].append(val) 93 | 94 | # Save model per N iterations 95 | if iteration % iters_per_checkpoint == 0: 96 | checkpoint_path = output_directory / "{}_{}".format(time.strftime("%m-%d_%H-%M", time.localtime()),iteration) 97 | trainer.save_checkpoint( checkpoint_path) 98 | 99 | # Show log per M iterations 100 | if iteration % iters_per_log == 0 and len(loss_log.keys()) > 0: 101 | mseg = 'Iter {}:'.format( iteration) 102 | for key,val in loss_log.items(): 103 | mseg += ' {}: {:.6f}'.format(key,np.mean(val)) 104 | mseg += ' lr: {:.6f}'.format(lr) 105 | logger.info(mseg) 106 | loss_log = dict() 107 | 108 | if iteration > max_iter: 109 | break 110 | 111 | print('Finished') 112 | 113 | 114 | if __name__ == "__main__": 115 | import argparse 116 | import json 117 | 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument('-c', '--config', type=str, default='tsvad_config.json', 120 | help='JSON file for configuration') 121 | parser.add_argument('-o', '--output_directory', type=str, default=None, 122 | help='Directory for checkpoint output') 123 | parser.add_argument('-p', '--checkpoint_path', type=str, default=None, 124 | help='checkpoint path to keep training') 125 | parser.add_argument('-T', '--training_dir', type=str, default=None, 126 | help='Training directory path') 127 | 128 | parser.add_argument('-g', '--gpu', type=str, default='0', 129 | help='Using gpu #') 130 | args = parser.parse_args() 131 | 132 | # Parse configs.
Globals nicer in this case 133 | with open(args.config) as f: 134 | data = f.read() 135 | config = json.loads(data) 136 | train_config = config["train_config"] 137 | global model_config 138 | model_config = config["model_config"] 139 | 140 | if args.output_directory is not None: 141 | train_config['output_directory'] = args.output_directory 142 | if args.checkpoint_path is not None: 143 | train_config['checkpoint_path'] = args.checkpoint_path 144 | 145 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 146 | 147 | torch.backends.cudnn.enabled = True 148 | torch.backends.cudnn.benchmark = False 149 | 150 | train(train_config) 151 | -------------------------------------------------------------------------------- /trainer/__pycache__/basic.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/trainer/__pycache__/basic.cpython-36.pyc -------------------------------------------------------------------------------- /trainer/__pycache__/mix_wave.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/trainer/__pycache__/mix_wave.cpython-36.pyc -------------------------------------------------------------------------------- /trainer/__pycache__/radam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/trainer/__pycache__/radam.cpython-36.pyc -------------------------------------------------------------------------------- /trainer/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .radam import RAdam 4 | from importlib import import_module 5 | 6 | 7 | class Trainer(object): 8 | def __init__(self, train_config, model_config): 9 | learning_rate = train_config.get('learning_rate', 1e-4) 10 | model_type = train_config.get('model_type', 'tsvad') 11 | self.opt_param = train_config.get('optimize_param', { 12 | 'optim_type': 'RAdam', 13 | 'learning_rate': 1e-4, 14 | 'max_grad_norm': 10, 15 | 'lr_scheduler':{ 16 | 'step_size': 100000, 17 | 'gamma': 0.5, 18 | 'last_epoch': -1 19 | } 20 | }) 21 | 22 | module = import_module('model.{}'.format(model_type), package=None) 23 | MODEL = getattr(module, 'Model') 24 | model = MODEL().cuda() 25 | 26 | print(model) 27 | 28 | self.model = model.cuda() 29 | self.learning_rate = learning_rate 30 | 31 | if self.opt_param['optim_type'].upper() == 'RADAM': 32 | self.optimizer = RAdam( self.model.parameters(), 33 | lr=self.opt_param['learning_rate'], 34 | betas=(0.5,0.999), 35 | weight_decay=0.0) 36 | else: 37 | self.optimizer = torch.optim.Adam( self.model.parameters(), 38 | lr=self.opt_param['learning_rate'], 39 | betas=(0.5,0.999), 40 | weight_decay=0.0) 41 | 42 | if 'lr_scheduler' in self.opt_param.keys(): 43 | self.scheduler = torch.optim.lr_scheduler.StepLR( 44 | optimizer=self.optimizer, 45 | **self.opt_param['lr_scheduler'] 46 | ) 47 | else: 48 | self.scheduler = None 49 | 50 | 51 | self.iteration = 0 52 | self.model.train() 53 | 54 | def step(self, input, iteration=None): 55 | assert self.model.training 56 | self.model.zero_grad() 57 | 58 | input = [x.cuda() for x in input] 59 | loss, loss_detail = self.model(input) 60 | 61 | loss.backward() 62 | if self.opt_param['max_grad_norm'] > 0: 63 | 
torch.nn.utils.clip_grad_norm_( 64 | self.model.parameters(), 65 | self.opt_param['max_grad_norm']) 66 | self.optimizer.step() 67 | for param_group in self.optimizer.param_groups: 68 | learning_rate = param_group['lr'] 69 | 70 | if self.scheduler is not None: 71 | self.scheduler.step() 72 | 73 | if iteration is not None: 74 | self.iteration = iteration + 1 75 | else: 76 | self.iteration += 1 77 | 78 | return self.iteration, loss_detail, learning_rate 79 | 80 | 81 | def save_checkpoint(self, checkpoint_path): 82 | torch.save( { 83 | 'model': self.model.state_dict(), 84 | 'optimizer': self.optimizer.state_dict(), 85 | 'iteration': self.iteration, 86 | }, checkpoint_path) 87 | print("Saved state dict. to {}".format(checkpoint_path)) 88 | 89 | 90 | def load_checkpoint(self, checkpoint_path): 91 | checkpoint_data = torch.load(checkpoint_path, map_location='cpu') 92 | self.model.load_state_dict(checkpoint_data['model']) 93 | # self.optimizer.load_state_dict(checkpoint_data['optimizer']) 94 | return checkpoint_data['iteration'] 95 | -------------------------------------------------------------------------------- /trainer/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class RAdam(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 8 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 9 | self.buffer = [[None, None, None] for ind in range(10)] 10 | super(RAdam, self).__init__(params, defaults) 11 | 12 | def __setstate__(self, state): 13 | super(RAdam, self).__setstate__(state) 14 | 15 | def step(self, closure=None): 16 | 17 | loss = None 18 | if closure is not None: 19 | loss = closure() 20 | 21 | for group in self.param_groups: 22 | 23 | for p in group['params']: 24 | if p.grad is None: 25 | continue 26 | grad = p.grad.data.float() 27 | if grad.is_sparse: 28 | raise RuntimeError('RAdam does not support sparse gradients') 29 | 30 | p_data_fp32 = p.data.float() 31 | 32 | state = self.state[p] 33 | 34 | if len(state) == 0: 35 | state['step'] = 0 36 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 37 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 38 | else: 39 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 40 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 41 | 42 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 43 | beta1, beta2 = group['betas'] 44 | 45 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 46 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 47 | 48 | state['step'] += 1 49 | buffered = self.buffer[int(state['step'] % 10)] 50 | if state['step'] == buffered[0]: 51 | N_sma, step_size = buffered[1], buffered[2] 52 | else: 53 | buffered[0] = state['step'] 54 | beta2_t = beta2 ** state['step'] 55 | N_sma_max = 2 / (1 - beta2) - 1 56 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 57 | buffered[1] = N_sma 58 | 59 | # more conservative since it's an approximated value 60 | if N_sma >= 5: 61 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 62 | else: 63 | step_size = 1.0 / (1 - beta1 ** state['step']) 64 | buffered[2] = step_size 65 | 66 | if group['weight_decay'] != 0: 67 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 68 | 69 | # more conservative since it's an approximated value 70 | if 
N_sma >= 5: 71 | denom = exp_avg_sq.sqrt().add_(group['eps']) 72 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 73 | else: 74 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 75 | 76 | p.data.copy_(p_data_fp32) 77 | 78 | return loss 79 | 80 | class PlainRAdam(Optimizer): 81 | 82 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 83 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 84 | 85 | super(PlainRAdam, self).__init__(params, defaults) 86 | 87 | def __setstate__(self, state): 88 | super(PlainRAdam, self).__setstate__(state) 89 | 90 | def step(self, closure=None): 91 | 92 | loss = None 93 | if closure is not None: 94 | loss = closure() 95 | 96 | for group in self.param_groups: 97 | 98 | for p in group['params']: 99 | if p.grad is None: 100 | continue 101 | grad = p.grad.data.float() 102 | if grad.is_sparse: 103 | raise RuntimeError('RAdam does not support sparse gradients') 104 | 105 | p_data_fp32 = p.data.float() 106 | 107 | state = self.state[p] 108 | 109 | if len(state) == 0: 110 | state['step'] = 0 111 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 112 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 113 | else: 114 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 115 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 116 | 117 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 118 | beta1, beta2 = group['betas'] 119 | 120 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 121 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 122 | 123 | state['step'] += 1 124 | beta2_t = beta2 ** state['step'] 125 | N_sma_max = 2 / (1 - beta2) - 1 126 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 127 | 128 | if group['weight_decay'] != 0: 129 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 130 | 131 | # more conservative since it's an approximated value 132 | if N_sma >= 5: 133 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 134 | denom = exp_avg_sq.sqrt().add_(group['eps']) 135 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 136 | else: 137 | step_size = group['lr'] / (1 - beta1 ** state['step']) 138 | p_data_fp32.add_(-step_size, exp_avg) 139 | 140 | p.data.copy_(p_data_fp32) 141 | 142 | return loss 143 | 144 | 145 | class AdamW(Optimizer): 146 | 147 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 148 | defaults = dict(lr=lr, betas=betas, eps=eps, 149 | weight_decay=weight_decay, warmup = warmup) 150 | super(AdamW, self).__init__(params, defaults) 151 | 152 | def __setstate__(self, state): 153 | super(AdamW, self).__setstate__(state) 154 | 155 | def step(self, closure=None): 156 | loss = None 157 | if closure is not None: 158 | loss = closure() 159 | 160 | for group in self.param_groups: 161 | 162 | for p in group['params']: 163 | if p.grad is None: 164 | continue 165 | grad = p.grad.data.float() 166 | if grad.is_sparse: 167 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 168 | 169 | p_data_fp32 = p.data.float() 170 | 171 | state = self.state[p] 172 | 173 | if len(state) == 0: 174 | state['step'] = 0 175 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 176 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 177 | else: 178 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 179 | state['exp_avg_sq'] = 
state['exp_avg_sq'].type_as(p_data_fp32) 180 | 181 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 182 | beta1, beta2 = group['betas'] 183 | 184 | state['step'] += 1 185 | 186 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 187 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 188 | 189 | denom = exp_avg_sq.sqrt().add_(group['eps']) 190 | bias_correction1 = 1 - beta1 ** state['step'] 191 | bias_correction2 = 1 - beta2 ** state['step'] 192 | 193 | if group['warmup'] > state['step']: 194 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 195 | else: 196 | scheduled_lr = group['lr'] 197 | 198 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 199 | 200 | if group['weight_decay'] != 0: 201 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 202 | 203 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 204 | 205 | p.data.copy_(p_data_fp32) 206 | 207 | return loss 208 | -------------------------------------------------------------------------------- /uem_file.scoring: -------------------------------------------------------------------------------- 1 | S01 1 0 12000 2 | S02 1 75 12000 3 | S09 1 64 12000 4 | S21 1 59 12000 5 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/util/__init__.py -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/util/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/dataset_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/util/__pycache__/dataset_loader.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/dataset_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/util/__pycache__/dataset_loader.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/util/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dodohow1011/TS-VAD/360afd9d57e504f2d95722aef9469292e9b4c113/util/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /util/convert_prob_to_rttm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2020 Yuri Khokhlov, Ivan Medennikov (STC-innovations Ltd) 4 | # Apache 2.0. 5 | 6 | """This script converts TS-VAD output probabilities to a NIST RTTM file. 7 | 8 | The segments file format is: 9 | 10 | The labels file format is: 11 | 12 | 13 | The output RTTM format is: 14 | \ 15 | 16 | where: 17 | = "SPEAKER" 18 | = 19 | = "0" 20 | = start time of segment 21 | = duration of segment 22 | = "" 23 | = "" 24 | = 25 | = "" 26 | = "" 27 | """ 28 | 29 | 30 | import os 31 | import argparse 32 | import regex as re 33 | import numpy as np 34 | from scipy import signal, ndimage 35 | from kaldiio import ReadHelper 36 | 37 | 38 | class Segment: 39 | def __init__(self, begin, end, label): 40 | self.begin = begin 41 | self.end = end 42 | self.label = label 43 | 44 | def length(self): 45 | return self.end - self.begin 46 | 47 | 48 | class VadProbSet: 49 | def __init__(self, vad_rspec, reg_exp): 50 | data = dict() 51 | prev = -1 52 | with ReadHelper(vad_rspec) as reader: 53 | for utid, prob in reader: 54 | result = reg_exp.match(utid) 55 | assert result is not None, 'Wrong utterance ID format: \"{}\"'.format(utid) 56 | sess_indx = result.group(1) 57 | spkr = result.group(2) 58 | 59 | result = reg_exp.match(sess_indx) 60 | assert result is not None, 'Wrong utterance ID format: \"{}\"'.format(sess_indx) 61 | sess = result.group(1) 62 | indx = int(result.group(2)) 63 | 64 | sess = sess + '-' + spkr 65 | 66 | if sess not in data.keys(): 67 | assert indx == 1 68 | prev = -1 69 | data[sess] = list() 70 | assert indx >= prev 71 | data[sess].append(prob) 72 | prev = indx 73 | reader.close() 74 | print(' loaded {} sessions'.format(len(data))) 75 | print(' combining fragments') 76 | self.data = dict() 77 | for sess, items in data.items(): 78 | self.data[sess] = np.hstack(items) 79 | 80 | common_thresh = dict() 81 | for sess, items in self.data.items(): 82 | s = sess.split('-')[0] 83 | if s not in common_thresh.keys(): 84 | common_thresh[s] = list() 85 | common_thresh[s].append(items) 86 | 87 | self.common_thresh = dict() 88 | for s, items in common_thresh.items(): 89 | thresh = np.max(np.vstack(items), axis=0) 90 | self.common_thresh[s] = thresh 91 | 92 | def vectorize(self, prob, common_thresh, threshold=0.4): 93 | p = np.zeros(prob.shape) 94 | for i in range(prob.shape[0]): 95 | p[i] = 1.0 if prob[i] > threshold and common_thresh[i]-prob[i] < 0.3 else 0.0 96 | return p 97 | 98 | def apply_filter(self, window, threshold, threshold_first): 99 | for sess in self.data.keys(): 100 | if threshold_first: 101 | self.data[sess] = np.vectorize(lambda value: 1.0 if value > threshold else 0.0)(self.data[sess]).astype(dtype=np.int32) 102 | if window > 1: 103 | self.data[sess] = signal.medfilt(self.data[sess], window).astype(dtype=np.int32) 104 | else: 105 | if window > 1: 106 | self.data[sess] = signal.medfilt(self.data[sess], window) 107 | self.data[sess] = np.vectorize(lambda value: 1.0 if value > threshold else 0.0)(self.data[sess]).astype(dtype=np.int32) 108 | # s = sess.split('-')[0] 109 | # self.data[sess] = self.vectorize(self.data[sess], self.common_thresh[s]) 110 | 111 | def convert(self, frame_shift, min_silence, min_speech, out_rttm): 112 | 
min_silence = int(round(min_silence / frame_shift)) 113 | min_speech = int(round(min_speech / frame_shift)) 114 | with open(out_rttm, 'wt', encoding='utf-8') as wstream: 115 | for sess, prob in self.data.items(): 116 | print(' session: {} num_frames: {} duration: {:.2f} hrs'.format(sess, len(prob), len(prob) * frame_shift / 60 / 60)) 117 | segments = list() 118 | for i, label in enumerate(prob): 119 | if (len(segments) == 0) or (segments[-1].label != label): 120 | segments.append(Segment(i, i + 1, label)) 121 | else: 122 | segments[-1].end += 1 123 | if (min_silence > 0) or (min_speech > 0): 124 | items = segments 125 | segments = list() 126 | for segm in items: 127 | if len(segments) == 0: 128 | segments.append(segm) 129 | elif segm.label == segments[-1].label: 130 | segments[-1].end = segm.end 131 | else: 132 | min_length = min_silence if segm.label == 0 else min_speech 133 | if segm.length() < min_length: 134 | segments[-1].end = segm.end 135 | else: 136 | segments.append(segm) 137 | for segm in segments: 138 | if segm.label == 1: 139 | begin = frame_shift * segm.begin 140 | length = frame_shift * segm.length() 141 | result = reg_exp.match(sess) 142 | assert result is not None, 'Wrong format: \"{}\"'.format(sess) 143 | utid = result.group(1) 144 | spk = result.group(2) 145 | wstream.write('SPEAKER {} 1 {:7.3f} {:7.3f} {} \n'.format(utid, begin, length, spk)) 146 | wstream.close() 147 | 148 | if __name__ == '__main__': 149 | parser = argparse.ArgumentParser(description='Usage: convert_prob_to_wa.py ') 150 | parser.add_argument("--frame_shift", "-s", type=float, default=0.010) 151 | parser.add_argument("--reg_exp", "-x", type=str, default=r'^(\S+)-(\d+)$') 152 | parser.add_argument("--window", "-w", type=int, default=1) 153 | parser.add_argument("--threshold", "-t", type=float, default=0.5) 154 | parser.add_argument("--threshold_first", "-r", action="store_true") 155 | parser.add_argument("--min_silence", "-k", type=float, default=0.0) 156 | parser.add_argument("--min_speech", "-m", type=float, default=0.0) 157 | parser.add_argument('vad_rspec', type=str) 158 | parser.add_argument('out_rttm', type=str) 159 | args = parser.parse_args() 160 | 161 | print('Options:') 162 | print(' Frame shift in sec: {}'.format(args.frame_shift)) 163 | print(' Utterance ID regexp: {}'.format(args.reg_exp)) 164 | print(' Med. filter window: {}'.format(args.window)) 165 | print(' Prob. threshold: {}'.format(args.threshold)) 166 | print(' Apply thresh. 
first: {}'.format(args.threshold_first)) 167 | print(' Min silence length: {}'.format(args.min_silence)) 168 | print(' Min speech length: {}'.format(args.min_speech)) 169 | print(' VAD rspec: {}'.format(args.vad_rspec)) 170 | print(' Output rttm file: {}'.format(args.out_rttm)) 171 | 172 | reg_exp = re.compile(args.reg_exp) 173 | 174 | parent = os.path.dirname(os.path.abspath(args.out_rttm)) 175 | if not os.path.exists(parent): 176 | os.makedirs(parent) 177 | 178 | print('Loading VAD probabilities') 179 | vad_prob = VadProbSet(args.vad_rspec, reg_exp) 180 | 181 | print('Applying filtering') 182 | vad_prob.apply_filter(args.window, args.threshold, args.threshold_first) 183 | 184 | print('Writing rttm') 185 | vad_prob.convert(args.frame_shift, args.min_silence, args.min_speech, args.out_rttm) 186 | -------------------------------------------------------------------------------- /util/dataset_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import random 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from kaldi.util.io import read_matrix 8 | 9 | from .utils import (files_to_list, files_to_dict, load_scp_to_torch) 10 | 11 | class Dataset(torch.utils.data.Dataset): 12 | """ 13 | This is the main class that calculates the spectrogram and returns the 14 | spectrogram, audio pair. 15 | """ 16 | def __init__(self, training_dir='./data/train', nframes=40): 17 | 18 | self.utt2feat = files_to_dict(os.path.join(training_dir,'feats.scp')) 19 | self.utt2nframe = files_to_dict(os.path.join(training_dir,'utt2num_frames')) 20 | self.utt2target = files_to_dict(os.path.join(training_dir,'dense_targets.scp')) 21 | self.utt2iv = files_to_dict(os.path.join(training_dir,'ivector_online.scp')) 22 | self.utt_list = [k for k in self.utt2target.keys() if int(self.utt2nframe[k]) >= nframes ] 23 | 24 | self.nframes = nframes 25 | 26 | def __getitem__(self, index): 27 | utt = self.utt_list[index] 28 | feat_length = int(self.utt2nframe[utt]) 29 | 30 | assert feat_length >= self.nframes 31 | 32 | feat = load_scp_to_torch(self.utt2feat[utt]).unsqueeze(0) 33 | target = load_scp_to_torch(self.utt2target[utt])[:, 1::2] 34 | ivectors = load_scp_to_torch(self.utt2iv[utt]).mean(dim=0) 35 | 36 | max_start = feat_length - self.nframes 37 | feat_start = random.randint(0, max_start) 38 | feat = feat[:, feat_start:(feat_start+self.nframes)] 39 | target = target[feat_start:(feat_start+self.nframes)] 40 | 41 | return feat, target, ivectors 42 | 43 | def __len__(self): 44 | return len(self.utt_list) 45 | 46 | 47 | class EvalDataset(torch.utils.data.Dataset): 48 | """ 49 | This is the main class that calculates the spectrogram and returns the 50 | spectrogram, audio pair. 
51 | """ 52 | def __init__(self, feats_dir, ivectors_dir): 53 | 54 | self.utt2feat = files_to_dict(os.path.join(feats_dir,'feats.scp')) 55 | self.utt2iv = files_to_dict(os.path.join(ivectors_dir,'ivector_online.scp')) 56 | self.utt_list = [k for k in self.utt2feat.keys()] 57 | 58 | def __getitem__(self, index): 59 | utt = self.utt_list[index] 60 | 61 | feat = load_scp_to_torch(self.utt2feat[utt]).unsqueeze(0).cuda() 62 | ivectors = load_scp_to_torch(self.utt2iv[utt]).mean(dim=0).cuda() 63 | 64 | return utt, feat, ivectors 65 | 66 | def __len__(self): 67 | return len(self.utt_list) 68 | -------------------------------------------------------------------------------- /util/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### Now we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. 
We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 98 | -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import random 5 | from kaldi.util.io import read_matrix 6 | 7 | def files_to_list(filename): 8 | """ 9 | Takes a text file of filenames and makes a list of filenames 10 | """ 11 | with open(filename, encoding='utf-8') as f: 12 | files = f.readlines() 13 | files = [f.rstrip().split() for f in files] 14 | return files 15 | 16 | 17 | def files_to_dict(filename): 18 | """ 19 | Takes a text file of filenames and makes a dict of filenames 20 | """ 21 | with open(filename, encoding='utf-8') as f: 22 | files = f.readlines() 23 | files = dict([f.rstrip().split() for f in files]) 24 | return files 25 | 26 | def load_scp_to_torch(scp_path): 27 | """ 28 | Loads data into torch array 29 | """ 30 | data = read_matrix(scp_path).numpy() 31 | return torch.from_numpy(data).float() 32 | 33 | def load_wav_to_torch(scp_path): 34 | """ 35 | Loads wavdata into torch array 36 | """ 37 | data = read_matrix(scp_path).numpy().reshape(-1) 38 | data = data / MAX_WAV_VALUE 39 | return torch.from_numpy(data).float() 40 | 41 | 42 | def load_spk_to_torch(spk_id, length=1): 43 | """ 44 | Loads spk_id into torch tensor 45 | """ 46 | return torch.ones(1, dtype=torch.long) * int(spk_id) 47 | --------------------------------------------------------------------------------
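The layer sizes in model/tsvad_dprnn.py pin down the tensor shapes the data pipeline must supply: an input feature chunk of B x 1 x T x 40 (any feature dimension that the stride-(1,2) convolution maps to 20 would fit, since the projection layer is Linear(128*20+100, nproj)), per-frame 0/1 activity targets for four speakers (B x T x 4), and a flattened block of four 100-dim i-vectors (B x 400, reshaped by ivectors.view(bs, 4, 100)). Below is a minimal shape-check sketch against that model, run from the repository root with the model_config values from conf/tsvad_config.json; the 40-dim feature assumption and the random dummy tensors are illustrative only, not part of the recipe.

# Hedged smoke test: dimensions are inferred from the model code above; replace the
# random tensors with real Kaldi features/i-vectors (see util/dataset_loader.py) to use it.
import torch
from model.tsvad_dprnn import Model

model = Model(out_channels=[64, 64, 128, 128], nproj=384, cell=896, dprnn_layers=6)

B, T = 2, 40                                       # batch size, frames per chunk ("nframes" in the config)
feats = torch.randn(B, 1, T, 40)                   # B x 1 x T x 40 feature chunk (40-dim assumed)
targets = torch.randint(0, 2, (B, T, 4)).float()   # per-frame speech activity for the 4 target speakers
ivectors = torch.randn(B, 400)                     # 4 speakers x 100-dim i-vectors, flattened

loss, loss_detail = model((feats, targets, ivectors))   # same tuple order as Dataset.__getitem__
preds = model.inference((None, feats, ivectors))        # EvalDataset yields (utt, feats, ivectors)
print(loss_detail, preds.shape)                         # {'diarization loss': ...} torch.Size([2, 40, 4])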