'''
GE2E-Loss

PyTorch implementation of the Generalized End-to-End (GE2E) loss for speaker
verification, proposed in https://arxiv.org/pdf/1710.10467.pdf [1].

The ``loss_method`` argument selects between the 'softmax' and 'contrast'
variants of the loss (Equations (6) and (7) of [1], respectively). The cosine
similarity scores of every utterance embedding against all speaker centroids
are computed with vectorized tensor operations rather than Python loops.

Example usage (the example values for some parameters are taken from [1])::

    import torch
    from ge2e import GE2ELoss

    criterion = GE2ELoss(init_w=10.0, init_b=-5.0, loss_method='softmax')   # for softmax loss
    criterion = GE2ELoss(init_w=10.0, init_b=-5.0, loss_method='contrast')  # for contrast loss

    N = 64   # number of speakers in a batch
    M = 10   # number of utterances per speaker
    D = 256  # dimensionality of the speaker embeddings, e.g. a d-vector or x-vector

    test_input = torch.rand(N, M, D)
    loss = criterion(test_input)  # output is a scalar
    loss.backward()

[1] Generalized End-to-End Loss for Speaker Verification,
    https://arxiv.org/pdf/1710.10467.pdf
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class GE2ELoss(nn.Module):

    # Lower bound on the product of norms in the cosine-similarity denominator,
    # so that a zero-length embedding/centroid does not cause a division by zero.
    _COS_EPS = 1e-6
    # Lower bound applied to the scale w; [1] constrains w > 0.
    _MIN_W = 1e-6

    def __init__(self, init_w=10.0, init_b=-5.0, loss_method='softmax'):
        '''
        Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1]

        Accepts an input of size (N, M, D)

            where N is the number of speakers in the batch,
            M is the number of utterances per speaker,
            and D is the dimensionality of the embedding vector (e.g. d-vector)

        Args:
            - init_w (float): defines the initial value of w in Equation (5) of [1]
            - init_b (float): defines the initial value of b in Equation (5) of [1]
            - loss_method (str): 'softmax' (Equation (6) of [1]) or
              'contrast' (Equation (7) of [1])

        Raises:
            ValueError: if loss_method is not 'softmax' or 'contrast'.
        '''
        super(GE2ELoss, self).__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if loss_method not in ('softmax', 'contrast'):
            raise ValueError(
                "loss_method must be 'softmax' or 'contrast', got {!r}".format(loss_method))

        # float() so integer arguments still yield a floating-point (trainable) parameter.
        self.w = nn.Parameter(torch.tensor(float(init_w)))
        self.b = nn.Parameter(torch.tensor(float(init_b)))
        self.loss_method = loss_method

        if loss_method == 'softmax':
            self.embed_loss = self.embed_loss_softmax
        else:
            self.embed_loss = self.embed_loss_contrast

    def calc_new_centroids(self, dvecs, centroids, spkr, utt):
        '''
        Calculates the centroids with speaker `spkr`'s centroid recomputed
        excluding the reference utterance `utt`.

        Kept as a public helper; `calc_cosine_sim` computes all of these
        leave-one-out centroids at once and does not call it.

        Args:
            - dvecs: tensor of shape (N, M, D)
            - centroids: tensor of shape (N, D)
            - spkr (int): non-negative speaker index
            - utt (int): non-negative utterance index within that speaker

        Returns:
            A tensor of shape (N, D): `centroids` with row `spkr` replaced.
        '''
        excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1:])).mean(dim=0)
        return torch.cat((centroids[:spkr], excl.unsqueeze(0), centroids[spkr + 1:]))

    def calc_cosine_sim(self, dvecs, centroids):
        '''
        Make the cosine similarity matrix with dims (N, M, N).

        Entry [j, i, k] is cos(e_ji, c_k) for k != j, and cos(e_ji, c_j^(-i))
        for k == j, where c_j^(-i) is speaker j's centroid computed without
        utterance i (Equation (9) of [1]).

        Raises:
            ValueError: if there are fewer than 2 utterances per speaker, since
            the leave-one-out centroid would then be empty.
        '''
        N, M, _ = dvecs.shape
        if M < 2:
            raise ValueError(
                'GE2E needs at least 2 utterances per speaker, got M={}'.format(M))

        # Leave-one-out centroids for every utterance at once:
        # c_j^(-i) = (sum_m e_jm - e_ji) / (M - 1), shape (N, M, D).
        excl_centroids = (dvecs.sum(dim=1, keepdim=True) - dvecs) / (M - 1)
        utt_norms = dvecs.norm(dim=2)  # (N, M)

        # Similarity of every utterance to every full centroid, shape (N, M, N).
        # The epsilon guards the denominator only; the similarity itself keeps
        # its full [-1, 1] range so negative scores still receive gradient.
        dots = torch.matmul(dvecs, centroids.transpose(0, 1))
        norms = utt_norms.unsqueeze(2) * centroids.norm(dim=1)
        cos_sim = dots / norms.clamp(min=self._COS_EPS)

        # Same-speaker similarity against the leave-one-out centroid, shape (N, M).
        excl_dots = (dvecs * excl_centroids).sum(dim=2)
        excl_norms = utt_norms * excl_centroids.norm(dim=2)
        excl_sim = excl_dots / excl_norms.clamp(min=self._COS_EPS)

        # Swap the k == j entries for the leave-one-out similarity.
        same_spkr = torch.eye(N, dtype=torch.bool, device=dvecs.device).unsqueeze(1)  # (N, 1, N)
        return torch.where(same_spkr, excl_sim.unsqueeze(2), cos_sim)

    def embed_loss_softmax(self, dvecs, cos_sim_matrix):
        '''
        Calculates the loss on each embedding $L(e_{ji})$ by taking softmax
        (Equation (6) of [1]): -log softmax over the N centroids, evaluated at
        the utterance's own speaker.

        Returns:
            A tensor of shape (N, M).
        '''
        N = dvecs.shape[0]
        idx = torch.arange(N, device=cos_sim_matrix.device)
        # Indexing dims 0 and 2 with the same idx picks [j, :, j] -> shape (N, M).
        return -F.log_softmax(cos_sim_matrix, dim=2)[idx, :, idx]

    def embed_loss_contrast(self, dvecs, cos_sim_matrix):
        '''
        Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with
        the closest other centroid (Equation (7) of [1]):
        1 - sigmoid(S_ji,j) + max_{k != j} sigmoid(S_ji,k).

        Returns:
            A tensor of shape (N, M).

        Raises:
            ValueError: if there are fewer than 2 speakers (no "other" centroid).
        '''
        N = dvecs.shape[0]
        if N < 2:
            raise ValueError(
                'The contrast loss needs at least 2 speakers, got N={}'.format(N))
        sig = torch.sigmoid(cos_sim_matrix)
        idx = torch.arange(N, device=cos_sim_matrix.device)
        same_spkr = torch.eye(N, dtype=torch.bool, device=cos_sim_matrix.device).unsqueeze(1)
        # Mask out the own-speaker column so max only sees the other centroids.
        closest_other = sig.masked_fill(same_spkr, float('-inf')).max(dim=2)[0]
        return 1. - sig[idx, :, idx] + closest_other

    def forward(self, dvecs):
        '''
        Calculates the GE2E loss for an input of dimensions
        (num_speakers, num_utts_per_speaker, dvec_feats).

        Returns:
            A scalar tensor: the sum of the per-embedding losses L(e_ji).

        Raises:
            ValueError: if `dvecs` is not 3-dimensional, has fewer than 2
            utterances per speaker, or (contrast loss only) fewer than 2 speakers.
        '''
        if dvecs.dim() != 3:
            raise ValueError(
                'Expected input of shape (N, M, D), got {}'.format(tuple(dvecs.shape)))

        # Calculate centroids
        centroids = torch.mean(dvecs, 1)

        # Calculate the cosine similarity matrix
        cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids)

        # [1] constrains w > 0 so that similarity grows with cosine similarity.
        # The clamped value must actually be used. While self.w sits below the
        # floor, clamp passes no gradient back to it.
        w = torch.clamp(self.w, min=self._MIN_W)
        cos_sim_matrix = cos_sim_matrix * w + self.b
        L = self.embed_loss(dvecs, cos_sim_matrix)
        return L.sum()