├── requirements.txt ├── .gitattributes ├── models └── exampleModel │ ├── model.chkpt │ └── model_config.pkl ├── scripts ├── getEmbeddingExample.py ├── featureExtractor.py ├── utils.py ├── data.py ├── loss.py ├── model.py ├── CNNs.py ├── poolings.py └── train.py └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.5 2 | torch==1.5 3 | librosa==0.7.2 4 | soundfile==0.10.3.post1 5 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.chkpt filter=lfs diff=lfs merge=lfs -text 2 | *.pkl filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /models/exampleModel/model.chkpt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5cd912d0ec6357420e49b794627db64073ec91753f1038f7ff6de788f0aabeae 3 | size 258295171 4 | -------------------------------------------------------------------------------- /models/exampleModel/model_config.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:be45ae3257e30d75f36ada35bfde553dffff3fce3674afa542dc5a6de12faa34 3 | size 904 4 | -------------------------------------------------------------------------------- /scripts/getEmbeddingExample.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import argparse 4 | from model import * 5 | from featureExtractor import * 6 | 7 | def prepareInput(features, device): 8 | 9 | inputs = torch.FloatTensor(features) 10 | inputs = inputs.to(device) 11 | inputs = inputs.unsqueeze(0) 12 | return inputs 13 | 14 | 15 | def getAudioEmbedding(audioPath, net, device): 16 | 17 | features = extractFeatures(audioPath) 18 | with torch.no_grad(): 19 | networkInputs = prepareInput(features, device) 20 | return net.getEmbedding(networkInputs) 21 | 22 | 23 | def main(opt,params): 24 | 25 | print('Loading Model') 26 | device = torch.device(params.device) 27 | net_dict = torch.load(params.modelCheckpoint, map_location=device) 28 | opt = net_dict['settings'] 29 | 30 | if torch.cuda.is_available(): 31 | print(torch.cuda.get_device_name(0)) 32 | 33 | net = SpeakerClassifier(opt, device) 34 | net.load_state_dict(net_dict['model']) 35 | net.to(device) 36 | net.eval() 37 | 38 | embedding = getAudioEmbedding(params.audioPath, net, device) 39 | print(embedding) 40 | 41 | if __name__ == "__main__": 42 | 43 | parser = argparse.ArgumentParser(description='score a trained model') 44 | parser.add_argument('--audioPath', type=str, required=True) 45 | parser.add_argument('--modelConfig', type=str, required=True) 46 | parser.add_argument('--modelCheckpoint', type=str, required=True) 47 | parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'cuda']) 48 | 49 | params = parser.parse_args() 50 | 51 | with open(params.modelConfig, 'rb') as handle: 52 | opt = pickle.load(handle) 53 | 54 | main(opt,params) 55 | 56 | -------------------------------------------------------------------------------- /scripts/featureExtractor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import soundfile as sf 3 | import sys 4 | import pickle 5 | import librosa 6 | import argparse 7 | 8 
| def mfsc(y, sfr, window_size=0.025, window_stride=0.010, window='hamming', n_mels=80, preemCoef=0.97):
9 |     win_length = int(sfr * window_size)
10 |     hop_length = int(sfr * window_stride)
11 |     n_fft = 512
12 |     lowfreq = 0
13 |     highfreq = sfr/2
14 | 
15 |     # melspectrogram
16 |     y *= 32768
17 |     y[1:] = y[1:] - preemCoef*y[:-1]
18 |     y[0] *= (1 - preemCoef)
19 |     S = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, center=False)
20 |     D = np.abs(S)
21 |     param = librosa.feature.melspectrogram(S=D, sr=sfr, n_mels=n_mels, fmin=lowfreq, fmax=highfreq, norm=None)
22 |     mf = np.log(np.maximum(1, param))
23 |     return mf
24 | 
25 | def normalize(features):
26 |     return features-np.mean(features, axis=0)
27 | 
28 | 
29 | def extractFeatures(audioPath):
30 | 
31 |     y, sfreq = sf.read(audioPath)
32 |     features = mfsc(y, sfreq)
33 |     return normalize(np.transpose(features))
34 | 
35 | def main(params):
36 | 
37 |     with open(params.audioFilesList,'r') as filesFile:
38 |         for featureFile in filesFile:
39 |             print(featureFile[:-1])
40 |             y, sfreq = sf.read('{}'.format(featureFile[:-1]))
41 |             mf = mfsc(y, sfreq)
42 |             with open('{}.pickle'.format(featureFile[:-5]), 'wb') as handle:
43 |                 pickle.dump(mf,handle)
44 | 
45 | if __name__=='__main__':
46 | 
47 | 
48 |     parser = argparse.ArgumentParser(description='Extract features. Reads a list of .wav files and extracts the features of each one.')
49 |     parser.add_argument('--audioFilesList', '-i', type=str, required=True, default='', help='Wav Files List.')
50 |     params=parser.parse_args()
51 |     main(params)
52 | 
53 | 
54 | 
55 | 
--------------------------------------------------------------------------------
/scripts/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | 
4 | 
5 | def Score(SC, th, rate):
6 |     score_count = 0.0
7 |     for sc in SC:
8 |         if rate=='FAR':
9 |             if float(sc)>=float(th):
10 |                 score_count+=1
11 |         elif rate=='FRR':
12 |             if float(sc)<float(th):
13 |                 score_count+=1
14 | 
15 |     return round(score_count*100/len(SC), 3)
16 | 
17 | def scoreCosineDistance(emb1, emb2):
18 | 
19 |     # cosine similarity between two batches of embeddings
20 |     return F.cosine_similarity(emb1, emb2, dim=1, eps=1e-08)
21 | 
22 | def chkptsave(opt, model, optimizer, epoch, step):
23 | 
24 |     # unwrap nn.DataParallel before saving when more than one GPU is used
25 |     if torch.cuda.device_count() > 1:
26 |         checkpoint = {
27 |             'model': model.module.state_dict(),
28 |             'optimizer': optimizer.state_dict(),
29 |             'settings': opt,
30 |             'epoch': epoch,
31 |             'step':step}
32 |     else:
33 |         checkpoint = {
34 |             'model': model.state_dict(),
35 |             'optimizer': optimizer.state_dict(),
36 |             'settings': opt,
37 |             'epoch': epoch,
38 |             'step':step}
39 | 
40 |     torch.save(checkpoint,'{}/{}_{}.chkpt'.format(opt.out_dir, opt.model_name,step))
41 | 
42 | def Accuracy(pred, labels):
43 | 
44 |     acc = 0.0
45 |     num_pred = pred.size()[0]
46 |     pred = torch.max(pred, 1)[1]
47 |     for idx in range(num_pred):
48 |         if pred[idx].item() == labels[idx].item():
49 |             acc += 1
50 | 
51 |     return acc/num_pred
52 | 
53 | def getNumberOfSpeakers(labelsFilePath):
54 | 
55 |     speakersDict = dict()
56 |     with open(labelsFilePath,'r') as labelsFile:
57 |         for line in labelsFile.readlines():
58 |             speakersDict[line.split()[1]] = 0
59 |     return len(speakersDict)
60 | 
61 | def getModelName(params):
62 | 
63 |     model_name = params.model_name
64 | 
65 |     model_name = model_name + '_{}'.format(params.front_end) + '_{}'.format(params.window_size) + '_{}batchSize'.format(params.batch_size*params.gradientAccumulation) + '_{}lr'.format(params.learning_rate) + '_{}weightDecay'.format(params.weight_decay) + '_{}kernel'.format(params.kernel_size) +'_{}embSize'.format(params.embedding_size) + '_{}s'.format(params.scalingFactor) + '_{}m'.format(params.marginFactor)
66 | 
67 |     model_name += '_{}'.format(params.pooling_method) + '_{}'.format(params.heads_number)
68 | 
69 |     return model_name
70 | 
-------------------------------------------------------------------------------- /scripts/data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from random import randint, randrange 4 | from torch.utils import data 5 | import soundfile as sf 6 | 7 | def featureReader(featurePath, VAD=None): 8 | 9 | with open(featurePath,'rb') as pickleFile: 10 | features = pickle.load(pickleFile) 11 | if VAD is not None: 12 | filtered_features = VAD.filter(features) 13 | else: 14 | filtered_features = features 15 | 16 | if filtered_features.shape[1]>0.: 17 | return np.transpose(filtered_features) 18 | else: 19 | return np.transpose(features) 20 | 21 | def normalizeFeatures(features, normalization='cmn'): 22 | 23 | mean = np.mean(features, axis=0) 24 | features -= mean 25 | if normalization=='cmn': 26 | return features 27 | if normalization=='cmvn': 28 | std = np.std(features, axis=0) 29 | std = np.where(std>0.01,std,1.0) 30 | return features/std 31 | 32 | class Dataset(data.Dataset): 33 | 34 | def __init__(self, utterances, parameters): 35 | 'Initialization' 36 | self.utterances = utterances 37 | self.parameters = parameters 38 | self.num_samples = len(utterances) 39 | 40 | def __normalize(self, features): 41 | mean = np.mean(features, axis=0) 42 | features -= mean 43 | if self.parameters.normalization=='cmn': 44 | return features 45 | if self.parameters.normalization=='cmvn': 46 | std = np.std(features, axis=0) 47 | std = np.where(std>0.01,std,1.0) 48 | return features/std 49 | 50 | def __sampleSpectogramWindow(self, features): 51 | file_size = features.shape[0] 52 | windowSizeInFrames = self.parameters.window_size*100 53 | index = randint(0, max(0,file_size-windowSizeInFrames-1)) 54 | a = np.array(range(min(file_size, int(windowSizeInFrames))))+index 55 | return features[a,:] 56 | 57 | def __getFeatureVector(self, utteranceName): 58 | 59 | with open(utteranceName + '.pickle','rb') as pickleFile: 60 | features = pickle.load(pickleFile) 61 | windowedFeatures = self.__sampleSpectogramWindow(self.__normalize(np.transpose(features))) 62 | return windowedFeatures 63 | 64 | def __len__(self): 65 | return self.num_samples 66 | 67 | def __getitem__(self, index): 68 | 'Generates one sample of data' 69 | utteranceTuple = self.utterances[index].strip().split() 70 | utteranceName = self.parameters.train_data_dir + '/' + utteranceTuple[0] 71 | utteranceLabel = int(utteranceTuple[1]) 72 | 73 | return self.__getFeatureVector(utteranceName), np.array(utteranceLabel) 74 | 75 | -------------------------------------------------------------------------------- /scripts/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from torch import nn 4 | 5 | class AMSoftmax(nn.Module): 6 | 7 | ''' 8 | Additve Margin Softmax as proposed in: 9 | https://arxiv.org/pdf/1801.05599.pdf 10 | Implementation Extracted From 11 | https://github.com/clovaai/voxceleb_trainer/blob/master/loss/cosface.py 12 | ''' 13 | 14 | def __init__(self, in_feats, n_classes, m=0.3, s=15, annealing=False): 15 | super(AMSoftmax, self).__init__() 16 | self.in_feats = in_feats 17 | self.m = m 18 | self.s = s 19 | self.annealing = annealing 20 | self.W = torch.nn.Parameter(torch.randn(in_feats, n_classes), requires_grad=True) 21 | nn.init.xavier_normal_(self.W, gain=1) 22 | self.annealing=annealing 23 | 24 | def getAnnealedFactor(self,step): 25 | alpha = self.__getAlpha(step) if self.annealing else 0. 
26 | return 1/(1+alpha) 27 | 28 | def __getAlpha(self,step): 29 | return max(0, 1000./(pow(1.+0.0001*float(step),2.))) 30 | 31 | def __getCombinedCosth(self, costh, costh_m, step): 32 | 33 | alpha = self.__getAlpha(step) if self.annealing else 0. 34 | costh_combined = costh_m + alpha*costh 35 | return costh_combined/(1+alpha) 36 | 37 | def forward(self, x, label=None, step=0): 38 | assert x.size()[0] == label.size()[0] 39 | assert x.size()[1] == self.in_feats 40 | x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12) 41 | x_norm = torch.div(x, x_norm) 42 | w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12) 43 | w_norm = torch.div(self.W, w_norm) 44 | costh = torch.mm(x_norm, w_norm) 45 | label_view = label.view(-1, 1) 46 | if label_view.is_cuda: label_view = label_view.cpu() 47 | delt_costh = torch.zeros(costh.size()).scatter_(1, label_view, self.m) 48 | if x.is_cuda: delt_costh = delt_costh.cuda() 49 | costh_m = costh - delt_costh 50 | costh_combined = self.__getCombinedCosth(costh, costh_m, step) 51 | costh_m_s = self.s * costh_combined 52 | return costh, costh_m_s 53 | 54 | class FocalSoftmax(nn.Module): 55 | ''' 56 | Focal softmax as proposed in: 57 | "Focal Loss for Dense Object Detection" 58 | by T-Y. Lin et al. 59 | https://github.com/foamliu/InsightFace-v2/blob/master/focal_loss.py 60 | ''' 61 | def __init__(self, gamma=2): 62 | super(FocalSoftmax, self).__init__() 63 | self.gamma = gamma 64 | self.ce = nn.CrossEntropyLoss() 65 | 66 | def forward(self, input, target): 67 | logp = self.ce(input, target) 68 | p = torch.exp(-logp) 69 | loss = (1 - p) ** self.gamma * logp 70 | return loss.mean() 71 | 72 | -------------------------------------------------------------------------------- /scripts/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from poolings import * 5 | from CNNs import * 6 | from loss import * 7 | 8 | class SpeakerClassifier(nn.Module): 9 | 10 | def __init__(self, parameters, device): 11 | super().__init__() 12 | 13 | parameters.feature_size = 80 14 | self.device = device 15 | self.__initFrontEnd(parameters) 16 | self.__initPoolingLayers(parameters) 17 | self.__initFullyConnectedBlock(parameters) 18 | self.predictionLayer = AMSoftmax(parameters.embedding_size, parameters.num_spkrs, s=parameters.scalingFactor, m=parameters.marginFactor, annealing = parameters.annealing) 19 | 20 | 21 | def __initFrontEnd(self, parameters): 22 | 23 | if parameters.front_end=='VGG3L': 24 | self.vector_size = getVGG3LOutputDimension(parameters.feature_size, outputChannel=parameters.kernel_size) 25 | self.front_end = VGG3L(parameters.kernel_size) 26 | 27 | if parameters.front_end=='VGG4L': 28 | self.vector_size = getVGG4LOutputDimension(parameters.feature_size, outputChannel=parameters.kernel_size) 29 | self.front_end = VGG4L(parameters.kernel_size) 30 | 31 | def __initPoolingLayers(self,parameters): 32 | 33 | self.pooling_method = parameters.pooling_method 34 | 35 | if self.pooling_method == 'Attention': 36 | self.poolingLayer = Attention(self.vector_size) 37 | elif self.pooling_method == 'MHA': 38 | self.poolingLayer = MultiHeadAttention(self.vector_size, parameters.heads_number) 39 | elif self.pooling_method == 'DoubleMHA': 40 | self.poolingLayer = DoubleMHA(self.vector_size, parameters.heads_number, mask_prob = parameters.mask_prob) 41 | self.vector_size = self.vector_size//parameters.heads_number 42 | 43 | def 
__initFullyConnectedBlock(self, parameters): 44 | 45 | self.fc1 = nn.Linear(self.vector_size, parameters.embedding_size) 46 | self.b1 = nn.BatchNorm1d(parameters.embedding_size) 47 | self.fc2 = nn.Linear(parameters.embedding_size, parameters.embedding_size) 48 | self.b2 = nn.BatchNorm1d(parameters.embedding_size) 49 | self.preLayer = nn.Linear(parameters.embedding_size, parameters.embedding_size) 50 | self.b3 = nn.BatchNorm1d(parameters.embedding_size) 51 | 52 | def getEmbedding(self,x): 53 | 54 | encoder_output = self.front_end(x) 55 | embedding0, alignment = self.poolingLayer(encoder_output) 56 | embedding1 = F.relu(self.fc1(embedding0)) 57 | embedding2 = self.b2(F.relu(self.fc2(embedding1))) 58 | 59 | return embedding2 60 | 61 | def forward(self, x, label=None, step=0): 62 | 63 | encoder_output = self.front_end(x) 64 | 65 | embedding0, alignment = self.poolingLayer(encoder_output) 66 | embedding1 = F.relu(self.fc1(embedding0)) 67 | embedding2 = self.b2(F.relu(self.fc2(embedding1))) 68 | embedding3 = self.preLayer(embedding2) 69 | prediction, ouputTensor = self.predictionLayer(embedding3, label, step) 70 | 71 | return prediction, ouputTensor 72 | 73 | -------------------------------------------------------------------------------- /scripts/CNNs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | 7 | def getVGG3LOutputDimension(inputDimension, outputChannel=128): 8 | 9 | outputDimension = np.ceil(np.array(inputDimension, dtype=np.float32)/2) 10 | outputDimension = np.ceil(np.array(outputDimension, dtype=np.float32)/2) 11 | outputDimension = np.ceil(np.array(outputDimension, dtype=np.float32)/2) 12 | return int(outputDimension) * outputChannel 13 | 14 | def getVGG4LOutputDimension(inputDimension, outputChannel=128): 15 | 16 | outputDimension = np.ceil(np.array(inputDimension, dtype=np.float32)/2) 17 | outputDimension = np.ceil(np.array(outputDimension, dtype=np.float32)/2) 18 | outputDimension = np.ceil(np.array(outputDimension, dtype=np.float32)/2) 19 | outputDimension = np.ceil(np.array(outputDimension, dtype=np.float32)/2) 20 | return int(outputDimension) * outputChannel 21 | 22 | class VGG3L(torch.nn.Module): 23 | 24 | def __init__(self, kernel_size): 25 | super(VGG3L, self).__init__() 26 | 27 | self.conv11 = torch.nn.Conv2d(1, int(kernel_size/4), 3, stride=1, padding=1) 28 | self.conv12 = torch.nn.Conv2d(int(kernel_size/4), int(kernel_size/4), 3, stride=1, padding=1) 29 | self.conv21 = torch.nn.Conv2d(int(kernel_size/4), int(kernel_size/2), 3, stride=1, padding=1) 30 | self.conv22 = torch.nn.Conv2d(int(kernel_size/2), int(kernel_size/2), 3, stride=1, padding=1) 31 | self.conv31 = torch.nn.Conv2d(int(kernel_size/2), int(kernel_size), 3, stride=1, padding=1) 32 | self.conv32 = torch.nn.Conv2d(int(kernel_size), int(kernel_size), 3, stride=1, padding=1) 33 | 34 | def forward(self, paddedInputTensor): 35 | 36 | paddedInputTensor = paddedInputTensor.view( paddedInputTensor.size(0), paddedInputTensor.size(1), 1, paddedInputTensor.size(2)).transpose(1, 2) 37 | 38 | encodedTensorLayer1 = F.relu(self.conv11(paddedInputTensor)) 39 | encodedTensorLayer1 = F.relu(self.conv12(encodedTensorLayer1)) 40 | encodedTensorLayer1 = F.max_pool2d(encodedTensorLayer1, 2, stride=2, ceil_mode=True) 41 | 42 | encodedTensorLayer2 = F.relu(self.conv21(encodedTensorLayer1)) 43 | encodedTensorLayer2 = F.relu(self.conv22(encodedTensorLayer2)) 44 | 
encodedTensorLayer2 = F.max_pool2d(encodedTensorLayer2, 2, stride=2, ceil_mode=True) 45 | 46 | encodedTensorLayer3 = F.relu(self.conv31(encodedTensorLayer2)) 47 | encodedTensorLayer3 = F.relu(self.conv32(encodedTensorLayer3)) 48 | encodedTensorLayer3 = F.max_pool2d(encodedTensorLayer3, 2, stride=2, ceil_mode=True) 49 | outputTensor = encodedTensorLayer3.transpose(1, 2) 50 | outputTensor = outputTensor.contiguous().view(outputTensor.size(0), outputTensor.size(1), outputTensor.size(2) * outputTensor.size(3)) 51 | 52 | return outputTensor 53 | 54 | class VGG4L(torch.nn.Module): 55 | 56 | def __init__(self, kernel_size): 57 | super(VGG4L, self).__init__() 58 | 59 | self.conv11 = torch.nn.Conv2d(1, int(kernel_size/8), 3, stride=1, padding=1) 60 | self.conv12 = torch.nn.Conv2d(int(kernel_size/8), int(kernel_size/8), 3, stride=1, padding=1) 61 | self.conv21 = torch.nn.Conv2d(int(kernel_size/8), int(kernel_size/4), 3, stride=1, padding=1) 62 | self.conv22 = torch.nn.Conv2d(int(kernel_size/4), int(kernel_size/4), 3, stride=1, padding=1) 63 | self.conv31 = torch.nn.Conv2d(int(kernel_size/4), int(kernel_size/2), 3, stride=1, padding=1) 64 | self.conv32 = torch.nn.Conv2d(int(kernel_size/2), int(kernel_size/2), 3, stride=1, padding=1) 65 | self.conv41 = torch.nn.Conv2d(int(kernel_size/2), int(kernel_size), 3, stride=1, padding=1) 66 | self.conv42 = torch.nn.Conv2d(int(kernel_size), int(kernel_size), 3, stride=1, padding=1) 67 | 68 | def forward(self, paddedInputTensor): 69 | 70 | paddedInputTensor = paddedInputTensor.view( paddedInputTensor.size(0), paddedInputTensor.size(1), 1, paddedInputTensor.size(2)).transpose(1, 2) 71 | 72 | encodedTensorLayer1 = F.relu(self.conv11(paddedInputTensor)) 73 | encodedTensorLayer1 = F.relu(self.conv12(encodedTensorLayer1)) 74 | encodedTensorLayer1 = F.max_pool2d(encodedTensorLayer1, 2, stride=2, ceil_mode=True) 75 | 76 | encodedTensorLayer2 = F.relu(self.conv21(encodedTensorLayer1)) 77 | encodedTensorLayer2 = F.relu(self.conv22(encodedTensorLayer2)) 78 | encodedTensorLayer2 = F.max_pool2d(encodedTensorLayer2, 2, stride=2, ceil_mode=True) 79 | 80 | encodedTensorLayer3 = F.relu(self.conv31(encodedTensorLayer2)) 81 | encodedTensorLayer3 = F.relu(self.conv32(encodedTensorLayer3)) 82 | encodedTensorLayer3 = F.max_pool2d(encodedTensorLayer3, 2, stride=2, ceil_mode=True) 83 | 84 | encodedTensorLayer4 = F.relu(self.conv41(encodedTensorLayer3)) 85 | encodedTensorLayer4 = F.relu(self.conv42(encodedTensorLayer4)) 86 | encodedTensorLayer4 = F.max_pool2d(encodedTensorLayer4, 2, stride=2, ceil_mode=True) 87 | 88 | outputTensor = encodedTensorLayer4.transpose(1, 2) 89 | outputTensor = outputTensor.contiguous().view(outputTensor.size(0), outputTensor.size(1), outputTensor.size(2) * outputTensor.size(3)) 90 | 91 | return outputTensor 92 | 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DoubleAttentionSpeakerVerification 2 | 3 | Pytorch implemenation of the model proposed in the paper: 4 | 5 | [Double Multi-Head Attention for Speaker Verification](https://arxiv.org/abs/2007.13199) 6 | 7 | ## Installation 8 | 9 | This repository has been created using python3.6. You can find the python3 10 | dependencies on requirements.txt. Hence you can install it by: 11 | 12 | ```bash 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | Note that soundfile library also needs the C libsndfile library. 
You can find
17 | more details about its installation in [Soundfile](https://pysoundfile.readthedocs.io/en/latest/).
18 | 
19 | ## Usage
20 | 
21 | This repository should allow you to train a speaker embedding extractor according to the setup described in the paper. The speaker embedding extractor is based on a VGG classifier which identifies speaker identities given variable-length audio utterances. The network used for this work takes log mel-spectrogram features as input. Hence, we have added here the instructions to reproduce the feature extraction, the network training and the speaker embedding extraction steps. Feel free to ask any questions via GitHub issues, [twitter](https://twitter.com/mikiindia) or mail (miquel.angel.india@upc.edu).
22 | 
23 | ### Feature Extraction
24 | 
25 | You can find in `scripts/featureExtractor.py` several functions which extract and normalize the log mel-spectrogram descriptors. If you want to run the feature extraction over a whole set of audios, you can run the following command:
26 | 
27 | ```bash
28 | python scripts/featureExtractor.py -i files.lst
29 | ```
30 | 
31 | where `files.lst` contains the paths of the audios to parameterize. Each row of the file must contain an audio path including the .wav extension (the extension is stripped to name the output pickle). A short sketch for generating such a list is shown right after the example. Example:
32 | 
33 | 
34 | audiosPath/audio1.wav
35 | audiosPath/audio2.wav
36 | ...
37 | audiosPath/audioN.wav
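
If the audio files sit under a common root directory, a list in this format can be generated with a few lines of Python (a minimal sketch; `audiosPath` and `files.lst` are just the names used in the example above):

```python
# Sketch: collect every .wav file under audiosPath into files.lst, one path per line.
import os

with open('files.lst', 'w') as out:
    for root, _, names in os.walk('audiosPath'):
        for name in sorted(names):
            if name.endswith('.wav'):
                out.write(os.path.join(root, name) + '\n')
```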
38 | 
39 | This script will extract a feature matrix for each audio file and store it as a pickle file in the same path as the audio.
40 | 
41 | ### Network Training
42 | 
43 | Once you have extracted the features from all the audios you want to use, you need to prepare some path files for the training step. The proposed models are trained as speaker classifiers, hence a classification-based loss and an accuracy metric are used to monitor the training progress. In the validation step, however, an EER estimation is used to validate the network progress. The motivation behind this is that the best-accuracy models do not always have the best inter/intra-speaker variability. Therefore we prefer to validate the model directly with a task-based metric instead of a classification one. Two different kinds of path files are then needed for the training/validation procedures:
44 | 
45 | Train Labels File (`train_labels_path`):
46 | 
47 | This file must have three columns separated by a blank space. The first column must contain the audio utterance paths, the second column must contain the speaker labels, and the third one must be filled with -1. It is assumed that the labels correspond to the network output labels. Hence, if you are working with an N-speaker database, the speaker label values should be in the 0 to N-1 range. A sketch that builds such a file from a speaker/utterance directory layout is shown after the example.
48 | 
49 | File Example:
50 | 
51 | 
52 | audiosPath/speaker1/audio1 0 -1
53 | audiosPath/speaker1/audio2 0 -1
54 | ...
55 | audiosPath/speakerN/audio4 N-1 -1
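
A file in this format can be built from a `speaker/utterance` directory layout with a short script (a minimal sketch under the assumption that every speaker has its own subdirectory; `audiosPath` and `train_labels.ndx` are only placeholder names):

```python
# Sketch: map each speaker directory to an integer label and write "path label -1" rows.
# Paths are written without the .wav extension, since the dataloader appends ".pickle".
import os

root = 'audiosPath'
with open('train_labels.ndx', 'w') as out:
    for label, speaker in enumerate(sorted(os.listdir(root))):
        for wav in sorted(os.listdir(os.path.join(root, speaker))):
            if wav.endswith('.wav'):
                utterance = os.path.join(root, speaker, wav[:-4])
                out.write('{} {} -1\n'.format(utterance, label))
```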
56 | 
57 | We have also added a `--train_data_dir` path argument. The dataloader will then look for the features in `--train_data_dir` + `audiosPath/speakeri/audioj` paths.
58 | 
59 | Valid Labels File:
60 | 
61 | For the validation step, a pair of client/impostor trial files is needed. The client trials file (`valid_clients`) must contain pairs of audio utterances from the same speaker, and the impostor trials file (`valid_impostors`) must contain audio utterance pairs from different speakers. The two paths of each pair must be separated by a blank space. A sketch of how these trials are scored during validation is shown after the example:
62 | 
63 | File Example (Clients):
64 | 
65 | 
66 | audiosPath/speaker1/audio1 audiosPath/speaker1/audio2
67 | audiosPath/speaker1/audio1 audiosPath/speaker1/audio3
68 | 
69 | ...
70 | audiosPath/speakerN/audio4 audiosPath/speakerN/audio3
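
During validation, each trial pair is scored with the cosine similarity between the two utterance embeddings, and the EER is obtained by sweeping a decision threshold over the client and impostor score lists. Below is a condensed sketch of that idea (not a line-by-line copy of the validation code in `scripts/train.py`):

```python
# Sketch: score client/impostor trials with cosine similarity and estimate the EER.
import numpy as np
import torch.nn.functional as F

def trial_score(emb1, emb2):
    # emb1 and emb2 are the 1 x embedding_size tensors returned by net.getEmbedding()
    return F.cosine_similarity(emb1, emb2).item()

def equal_error_rate(client_scores, impostor_scores):
    # Sweep thresholds over the cosine range and take the point where FAR and FRR meet.
    thresholds = np.arange(-1, 1, 0.01)
    frr = np.array([np.mean(np.array(client_scores) < th) for th in thresholds]) * 100
    far = np.array([np.mean(np.array(impostor_scores) >= th) for th in thresholds]) * 100
    idx = int(np.argmin(np.abs(far - frr)))
    return (far[idx] + frr[idx]) / 2
```

The scores lie in the [-1, 1] cosine range, which is why the threshold sweep covers that interval.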
71 | 
72 | Similar to the train file, we have also added a `--valid_data_dir` argument.
73 | 
74 | Once you have all these data files ready, you can launch a model training with the following command:
75 | 
76 | 
77 | ```bash
78 | python scripts/train.py
79 | ```
80 | 
81 | With this script you will launch the model training with the default setup defined in `scripts/train.py`. The model will be trained following the methods and procedures described in the paper. The best models found will be saved in the `--out_dir` directory. There you will find a `.pkl` file with the training/model configuration and several `.chkpt` checkpoint files which store the model weights, optimizer state values, etc. The best saved models correspond to the last saved checkpoints.
82 | 
83 | 
84 | ### Speaker Embedding Extraction
85 | 
86 | Given a trained model, it can be used to extract a speaker embedding from a variable-length audio. We have added an example script to show how to use the models to extract speaker embeddings. These embeddings can then be used to compute similarity scores between audios by taking the cosine distance between their embeddings. Run the following command:
87 | 
88 | ```bash
89 | python scripts/getEmbeddingExample.py --audioPath --modelConfig --modelCheckpoint
90 | ```
91 | This script will load the model and extract/print the embedding of the input audio.
92 | 
93 | 
--------------------------------------------------------------------------------
/scripts/poolings.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | from torch.nn import functional as F
5 | import numpy as np
6 | import math
7 | import copy
8 | 
9 | def new_parameter(*size):
10 |     out = torch.nn.Parameter(torch.FloatTensor(*size))
11 |     torch.nn.init.xavier_normal_(out)
12 |     return out
13 | 
14 | class Attention(nn.Module):
15 | 
16 |     def __init__(self, embedding_size):
17 | 
18 |         super(Attention, self).__init__()
19 |         self.embedding_size = embedding_size
20 |         self.att=new_parameter(self.embedding_size,1)
21 | 
22 |     def forward(self,ht):
23 |         attention_score = torch.matmul(ht, self.att).squeeze()
24 |         attention_score = F.softmax(attention_score, dim=-1).view(ht.size(0), ht.size(1),1)
25 |         ct = torch.sum(ht * attention_score,dim=1)
26 | 
27 |         return ct, attention_score
28 | 
29 | class HeadAttention(nn.Module):
30 | 
31 |     def __init__(self, encoder_size, heads_number, mask_prob = 0.25, attentionSmoothing=False):
32 | 
33 |         super(HeadAttention, self).__init__()
34 |         self.embedding_size = encoder_size//heads_number
35 |         self.att=new_parameter(self.embedding_size,1)
36 |         self.mask_prob = int(1/mask_prob)
37 |         self.attentionSmoothing = attentionSmoothing
38 | 
39 |     def __maskAttention(self, attention_score, mask_value = -float('inf')):
40 | 
41 |         mask = torch.cuda.FloatTensor(attention_score.size()).random_(self.mask_prob)>0
42 |         attention_score[~mask] = mask_value
43 |         return attention_score
44 | 
45 |     def __narrowAttention(self, new_ht):
46 | 
47 |         attention_score = torch.matmul(new_ht, self.att).squeeze()
48 |         if self.training:
49 |             attention_score = self.__maskAttention(attention_score)
50 |         attention_score = F.softmax(attention_score, dim=-1).view(new_ht.size(0), new_ht.size(1),1)
51 |         return attention_score
52 | 
53 |     def __wideAttention(self, new_ht):
54 | 
55 |         attention_score = torch.matmul(new_ht, self.att).squeeze()
56 |         if self.training:
57 |             attention_score = self.__maskAttention(attention_score, mask_value = -1)
58 |         attention_score 
/= torch.sum(attention_score, dim=1).unsqueeze(1) 59 | return attention_score.view(new_ht.size(0), new_ht.size(1),1) 60 | 61 | def forward(self,ht): 62 | 63 | if self.attentionSmoothing: 64 | attention_score = self.__wideAttention(ht) 65 | else: 66 | attention_score = self.__narrowAttention(ht) 67 | 68 | weighted_ht = ht * attention_score 69 | ct = torch.sum(weighted_ht,dim=1) 70 | 71 | return ct, attention_score 72 | 73 | def innerKeyValueAttention(query, key, value): 74 | 75 | d_k = query.size(-1) 76 | scores = torch.diagonal(torch.matmul(key, query) / math.sqrt(d_k), dim1=-2, dim2=-1).view(value.size(0),value.size(1), value.size(2)) 77 | p_attn = F.softmax(scores, dim = -2) 78 | weighted_vector = value * p_attn.unsqueeze(-1) 79 | ct = torch.sum(weighted_vector, dim=1) 80 | return ct, p_attn 81 | 82 | 83 | class MultiHeadAttention(nn.Module): 84 | def __init__(self, encoder_size, heads_number): 85 | super(MultiHeadAttention, self).__init__() 86 | self.encoder_size = encoder_size 87 | assert self.encoder_size % heads_number == 0 # d_model 88 | self.head_size = self.encoder_size // heads_number 89 | self.heads_number = heads_number 90 | self.query = new_parameter(self.head_size, self.heads_number) 91 | self.aligmment = None 92 | 93 | def getAlignments(self,ht): 94 | batch_size = ht.size(0) 95 | key = ht.view(batch_size*ht.size(1), self.heads_number, self.head_size) 96 | value = ht.view(batch_size,-1,self.heads_number, self.head_size) 97 | headsContextVectors, self.alignment = innerKeyValueAttention(self.query, key, value) 98 | return self.alignment 99 | 100 | def getHeadsContextVectors(self,ht): 101 | batch_size = ht.size(0) 102 | key = ht.view(batch_size*ht.size(1), self.heads_number, self.head_size) 103 | value = ht.view(batch_size,-1,self.heads_number, self.head_size) 104 | headsContextVectors, self.alignment = innerKeyValueAttention(self.query, key, value) 105 | return headsContextVectors 106 | 107 | def forward(self, ht): 108 | headsContextVectors = self.getHeadsContextVectors(ht) 109 | return headsContextVectors.view(headsContextVectors.size(0),-1), copy.copy(self.alignment) 110 | 111 | 112 | class DoubleMHA(nn.Module): 113 | def __init__(self, encoder_size, heads_number, mask_prob=0.2): 114 | super(DoubleMHA, self).__init__() 115 | self.heads_number = heads_number 116 | self.utteranceAttention = MultiHeadAttention(encoder_size, heads_number) 117 | self.heads_size = encoder_size // heads_number 118 | self.headsAttention = HeadAttention(encoder_size, heads_number, mask_prob=mask_prob, attentionSmoothing=False) 119 | 120 | def getAlignments(self, x): 121 | 122 | utteranceRepresentation, alignment = self.utteranceAttention(x) 123 | headAlignments = self.headsAttention(utteranceRepresentation.view(utteranceRepresentation.size(0), self.heads_number, self.heads_size))[1] 124 | return alignment, headAlignments 125 | 126 | def forward(self, x): 127 | utteranceRepresentation, alignment = self.utteranceAttention(x) 128 | compressedRepresentation = self.headsAttention(utteranceRepresentation.view(utteranceRepresentation.size(0), self.heads_number, self.heads_size))[0] 129 | return compressedRepresentation, alignment 130 | 131 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import numpy as np 5 | import random 6 | import pickle 7 | import time 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import 
functional as F 11 | from torch import optim 12 | from torch.utils.data import DataLoader 13 | sys.path.append('./scripts/') 14 | from data import * 15 | from model import SpeakerClassifier 16 | from loss import * 17 | from utils import * 18 | 19 | class Trainer: 20 | 21 | def __init__(self, params, device): 22 | 23 | self.params = params 24 | self.device = device 25 | self.__load_network() 26 | self.__load_data() 27 | self.__load_optimizer() 28 | self.__load_criterion() 29 | self.__initialize_training_variables() 30 | 31 | def __load_previous_states(self): 32 | 33 | list_files = os.listdir(self.params.out_dir) 34 | list_files = [self.params.out_dir + '/' + f for f in list_files if '.chkpt' in f] 35 | if list_files: 36 | file2load = max(list_files, key=os.path.getctime) 37 | checkpoint = torch.load(file2load, map_location=self.device) 38 | try: 39 | self.net.load_state_dict(checkpoint['model']) 40 | except RuntimeError: 41 | self.net.module.load_state_dict(checkpoint['model']) 42 | self.optimizer.load_state_dict(checkpoint['optimizer']) 43 | self.params = checkpoint['settings'] 44 | self.starting_epoch = checkpoint['epoch']+1 45 | self.step = checkpoint['step'] 46 | print('Model "%s" is Loaded for requeue process' % file2load) 47 | else: 48 | self.step = 0 49 | self.starting_epoch = 1 50 | 51 | def __initialize_training_variables(self): 52 | 53 | if self.params.requeue: 54 | self.__load_previous_states() 55 | else: 56 | self.step = 0 57 | self.starting_epoch = 0 58 | 59 | self.best_EER = 50.0 60 | self.stopping = 0.0 61 | 62 | 63 | def __load_network(self): 64 | 65 | self.net = SpeakerClassifier(self.params, self.device) 66 | self.net.to(self.device) 67 | 68 | if torch.cuda.device_count() > 1: 69 | print("Let's use", torch.cuda.device_count(), "GPUs!") 70 | self.net = nn.DataParallel(self.net) 71 | 72 | 73 | def __load_data(self): 74 | print('Loading Data and Labels') 75 | with open(self.params.train_labels_path, 'r') as data_labels_file: 76 | train_labels=data_labels_file.readlines() 77 | 78 | data_loader_parameters = {'batch_size': self.params.batch_size, 'shuffle': True, 'num_workers': self.params.num_workers} 79 | self.training_generator = DataLoader(Dataset(train_labels, self.params), **data_loader_parameters) 80 | 81 | 82 | def __load_optimizer(self): 83 | if self.params.optimizer == 'Adam': 84 | self.optimizer = optim.Adam(self.net.parameters(), lr=self.params.learning_rate, weight_decay=self.params.weight_decay) 85 | if self.params.optimizer == 'SGD': 86 | self.optimizer = optim.SGD(self.net.parameters(), lr=self.params.learning_rate, weight_decay=self.params.weight_decay) 87 | if self.params.optimizer == 'RMSprop': 88 | self.optimizer = optim.RMSprop(self.net.parameters(), lr=self.params.learning_rate, weight_decay=self.params.weight_decay) 89 | 90 | def __update_optimizer(self): 91 | 92 | if self.params.optimizer == 'SGD' or self.params.optimizer == 'Adam': 93 | for paramGroup in self.optimizer.param_groups: 94 | paramGroup['lr'] *= 0.5 95 | print('New Learning Rate: {}'.format(paramGroup['lr'])) 96 | 97 | def __load_criterion(self): 98 | self.criterion = nn.CrossEntropyLoss() 99 | 100 | def __initialize_batch_variables(self): 101 | 102 | self.print_time = time.time() 103 | self.train_loss = 0.0 104 | self.train_accuracy = 0.0 105 | self.train_batch = 0 106 | 107 | def __extractInputFromFeature(self, sline): 108 | 109 | features1 = normalizeFeatures(featureReader(self.params.valid_data_dir + '/' + sline[0] + '.pickle'), normalization=self.params.normalization) 110 | features2 = 
normalizeFeatures(featureReader(self.params.valid_data_dir + '/' + sline[1] + '.pickle'), normalization=self.params.normalization) 111 | 112 | input1 = torch.FloatTensor(features1).to(self.device) 113 | input2 = torch.FloatTensor(features2).to(self.device) 114 | 115 | return input1.unsqueeze(0), input2.unsqueeze(0) 116 | 117 | def __extract_scores(self, trials): 118 | 119 | scores = [] 120 | for line in trials: 121 | sline = line[:-1].split() 122 | 123 | input1, input2 = self.__extractInputFromFeature(sline) 124 | 125 | if torch.cuda.device_count() > 1: 126 | emb1, emb2 = self.net.module.getEmbedding(input1), self.net.module.getEmbedding(input2) 127 | else: 128 | emb1, emb2 = self.net.getEmbedding(input1), self.net.getEmbedding(input2) 129 | 130 | dist = scoreCosineDistance(emb1, emb2) 131 | scores.append(dist.item()) 132 | 133 | return scores 134 | 135 | def __calculate_EER(self, CL, IM): 136 | 137 | thresholds = np.arange(-1,1,0.01) 138 | FRR, FAR = np.zeros(len(thresholds)), np.zeros(len(thresholds)) 139 | for idx,th in enumerate(thresholds): 140 | FRR[idx] = Score(CL, th,'FRR') 141 | FAR[idx] = Score(IM, th,'FAR') 142 | 143 | EER_Idx = np.argwhere(np.diff(np.sign(FAR - FRR)) != 0).reshape(-1) 144 | if len(EER_Idx)>0: 145 | if len(EER_Idx)>1: 146 | EER_Idx = EER_Idx[0] 147 | EER = round((FAR[int(EER_Idx)] + FRR[int(EER_Idx)])/2,4) 148 | else: 149 | EER = 50.00 150 | return EER 151 | 152 | def __getAnnealedFactor(self): 153 | if torch.cuda.device_count() > 1: 154 | return self.net.module.predictionLayer.getAnnealedFactor(self.step) 155 | else: 156 | return self.net.predictionLayer.getAnnealedFactor(self.step) 157 | 158 | def __validate(self): 159 | 160 | with torch.no_grad(): 161 | valid_time = time.time() 162 | self.net.eval() 163 | # EER Validation 164 | with open(params.valid_clients,'r') as clients_in, open(params.valid_impostors,'r') as impostors_in: 165 | # score clients 166 | CL = self.__extract_scores(clients_in) 167 | IM = self.__extract_scores(impostors_in) 168 | # Compute EER 169 | EER = self.__calculate_EER(CL, IM) 170 | 171 | annealedFactor = self.__getAnnealedFactor() 172 | print('Annealed Factor is {}.'.format(annealedFactor)) 173 | print('--Validation Epoch:{epoch: d}, Updates:{Num_Batch: d}, EER:{eer: 3.3f}, elapse:{elapse: 3.3f} min'.format(epoch=self.epoch, Num_Batch=self.step, eer=EER, elapse=(time.time()-valid_time)/60)) 174 | # early stopping and save the best model 175 | if EER < self.best_EER: 176 | self.best_EER = EER 177 | self.stopping = 0 178 | print('We found a better model!') 179 | chkptsave(params, self.net, self.optimizer, self.epoch, self.step) 180 | else: 181 | self.stopping += 1 182 | print('Better Accuracy is: {}. 
{} epochs of no improvement'.format(self.best_EER, self.stopping)) 183 | self.print_time = time.time() 184 | self.net.train() 185 | 186 | def __update(self): 187 | 188 | self.optimizer.step() 189 | self.optimizer.zero_grad() 190 | self.step += 1 191 | 192 | if self.step % int(self.params.print_every) == 0: 193 | print('Training Epoch:{epoch: d}, Updates:{Num_Batch: d} -----> xent:{xnet: .3f}, Accuracy:{acc: .2f}, elapse:{elapse: 3.3f} min'.format(epoch=self.epoch, Num_Batch=self.step, xnet=self.train_loss / self.train_batch, acc=self.train_accuracy *100/ self.train_batch, elapse=(time.time()-self.print_time)/60)) 194 | self.__initialize_batch_variables() 195 | 196 | # validation 197 | if self.step % self.params.validate_every == 0: 198 | self.__validate() 199 | 200 | def __updateTrainningVariables(self): 201 | 202 | if (self.stopping+1)% 15 ==0: 203 | self.__update_optimizer() 204 | 205 | def __randomSlice(self, inputTensor): 206 | index = random.randrange(200,self.params.window_size*100) 207 | return inputTensor[:,:index,:] 208 | 209 | def train(self): 210 | 211 | print('Start Training') 212 | for self.epoch in range(self.starting_epoch, self.params.max_epochs): # loop over the dataset multiple times 213 | self.net.train() 214 | self.__initialize_batch_variables() 215 | for input, label in self.training_generator: 216 | input, label = input.float().to(self.device), label.long().to(self.device) 217 | input = self.__randomSlice(input) if self.params.randomSlicing else input 218 | prediction, AMPrediction = self.net(input, label=label, step=self.step) 219 | loss = self.criterion(AMPrediction, label) 220 | loss.backward() 221 | self.train_accuracy += Accuracy(prediction, label) 222 | self.train_loss += loss.item() 223 | 224 | self.train_batch += 1 225 | if self.train_batch % self.params.gradientAccumulation == 0: 226 | self.__update() 227 | 228 | if self.stopping > self.params.early_stopping: 229 | print('--Best Model EER%%: %.2f' %(self.best_EER)) 230 | break 231 | 232 | self.__updateTrainningVariables() 233 | 234 | 235 | print('Finished Training') 236 | 237 | def main(opt): 238 | 239 | torch.manual_seed(1234) 240 | np.random.seed(1234) 241 | 242 | print('Defining Device') 243 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 244 | print(device) 245 | print(torch.cuda.get_device_name(0)) 246 | 247 | print('Loading Trainer') 248 | trainer = Trainer(opt, device) 249 | trainer.train() 250 | 251 | if __name__=="__main__": 252 | 253 | parser = argparse.ArgumentParser(description='Train a VGG based Speaker Embedding Extractor') 254 | 255 | parser.add_argument('--train_data_dir', type=str, default='/scratch/speaker_databases/', help='data directory.') 256 | parser.add_argument('--valid_data_dir', type=str, default='/scratch/speaker_databases/VoxCeleb-1/wav', help='data directory.') 257 | parser.add_argument('--train_labels_path', type = str, default = 'labels/Vox2.ndx') 258 | parser.add_argument('--data_mode', type = str, default = 'normal', choices=['normal','window']) 259 | parser.add_argument('--valid_clients', type = str, default='labels/clients.ndx') 260 | parser.add_argument('--valid_impostors', type = str, default='labels/impostors.ndx') 261 | parser.add_argument('--out_dir', type=str, default='./models/model1', help='directory where data is saved') 262 | parser.add_argument('--model_name', type=str, default='CNN', help='Model associated to the model builded') 263 | parser.add_argument('--front_end', type=str, default='VGG4L', choices = ['VGG3L','VGG4L'], help='Kind 
of Front-end Used') 264 | 265 | # Network Parameteres 266 | parser.add_argument('--window_size', type=float, default=3.5, help='number of seconds per window') 267 | parser.add_argument('--randomSlicing',action='store_true') 268 | parser.add_argument('--normalization', type=str, default='cmn', choices=['cmn', 'cmvn']) 269 | parser.add_argument('--kernel_size', type=int, default=1024) 270 | parser.add_argument('--embedding_size', type=int, default=400) 271 | parser.add_argument('--heads_number', type=int, default=32) 272 | parser.add_argument('--pooling_method', type=str, default='DoubleMHA', choices=['Attention', 'MHA', 'DoubleMHA'], help='Type of pooling methods') 273 | parser.add_argument('--mask_prob', type=float, default=0.3, help='Masking Drop Probability. Only Used for Only Double MHA') 274 | 275 | # AMSoftmax Config 276 | parser.add_argument('--scalingFactor', type=float, default=30.0, help='') 277 | parser.add_argument('--marginFactor', type=float, default=0.4, help='') 278 | parser.add_argument('--annealing', action='store_true') 279 | 280 | # Optimization 281 | parser.add_argument('--optimizer', type=str, choices=['Adam', 'SGD', 'RMSprop'], default='Adam') 282 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='') 283 | parser.add_argument('--weight_decay', type=float, default=0.001, help='') 284 | parser.add_argument('--batch_size', type=int, default=64, help='number of sequences to train on in parallel') 285 | parser.add_argument('--gradientAccumulation', type=int, default=2) 286 | parser.add_argument('--max_epochs', type=int, default=1000000, help='number of full passes through the trainning data') 287 | parser.add_argument('--early_stopping', type=int, default=25, help='-1 if not early stopping') 288 | parser.add_argument('--print_every', type = int, default = 1000) 289 | parser.add_argument('--requeue',action='store_true', help='restart from the last model for requeue on slurm') 290 | parser.add_argument('--validate_every', type = int, default = 10000) 291 | parser.add_argument('--num_workers', type = int, default = 2) 292 | 293 | # parse input params 294 | params=parser.parse_args() 295 | params.model_name = getModelName(params) 296 | params.num_spkrs = getNumberOfSpeakers(params.train_labels_path) 297 | print('{} Speaker Labels'.format(params.num_spkrs)) 298 | 299 | if not os.path.exists(params.out_dir): 300 | os.makedirs(params.out_dir) 301 | 302 | with open(params.out_dir + '/' + params.model_name + '_config.pkl', 'wb') as handle: 303 | pickle.dump(params, handle, protocol=pickle.HIGHEST_PROTOCOL) 304 | 305 | main(params) 306 | --------------------------------------------------------------------------------