├── README.md
└── projections.py

/README.md:
--------------------------------------------------------------------------------

## sphereface2_speaker_verification

Official implementation of SphereFace2 for speaker verification, described in [Exploring Binary Classification Loss for Speaker Verification](https://ieeexplore.ieee.org/abstract/document/10094954). This code is based on [wespeaker](https://github.com/wenet-e2e/wespeaker).

### Advantages

1. Better performance, especially on hard trials
2. Robust to noisy labels
3. Natural parallelization of the classifier layer

## Running

SphereFace2 is supported in the wespeaker toolkit; see [#173](https://github.com/wenet-e2e/wespeaker/pull/173). You are welcome to use wespeaker for development and research.

## Results

ResNet34-TSTP-emb256, EER (%) on the VoxCeleb1 trial lists:

| Model | Margin | Params | LM | AS-Norm | vox1-O-clean | vox1-E-clean | vox1-H-clean |
|:-------------:|:------:|:------:|:--:|:-------:|:------------:|:------------:|:------------:|
| AAM | 0.2 | 6.63M | × | × | 0.867 | 1.049 | 1.959 |
| | 0.2 | | × | √ | 0.787 | 0.964 | 1.726 |
| | 0.5 | | √ | × | 0.797 | 0.937 | 1.695 |
| | 0.5 | | √ | √ | 0.723 | 0.867 | 1.532 |
| C-SphereFace2 | 0.2 | 6.63M | × | × | 0.904 | 0.973 | 1.737 |
| | 0.2 | | × | √ | 0.835 | 0.931 | 1.652 |
| | 0.3 | | √ | × | 0.830 | 0.862 | 1.510 |
| | 0.3 | | √ | √ | 0.755 | 0.833 | 1.449 |
| A-SphereFace2 | 0.15 | 6.63M | × | × | 0.835 | 0.975 | 1.742 |
| | 0.15 | | × | √ | 0.761 | 0.938 | 1.630 |
| | 0.25 | | √ | × | 0.766 | 0.899 | 1.590 |
| | 0.25 | | √ | √ | 0.686 | 0.852 | 1.480 |

## Citations

If you find this work useful, please cite it as:

```bibtex
@inproceedings{han2023exploring,
  title={Exploring Binary Classification Loss for Speaker Verification},
  author={Han, Bing and Chen, Zhengyang and Qian, Yanmin},
  booktitle={ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={1--5},
  year={2023},
  organization={IEEE}
}

@inproceedings{wen2021sphereface2,
  title={SphereFace2: Binary Classification is All You Need for Deep Face Recognition},
  author={Wen, Yandong and Liu, Weiyang and Weller, Adrian and Raj, Bhiksha and Singh, Rita},
  booktitle={ICLR},
  year={2022}
}
```
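
For reference, here is a minimal sketch of how the `sphereface2` projection in `projections.py` can be instantiated and called. The config keys mirror `get_projection`; the embedding size and speaker count are illustrative only:

```python
import torch

from projections import get_projection

conf = {
    'project_type': 'sphereface2',
    'embed_dim': 256,      # embedding size of the speaker model
    'num_class': 5994,     # e.g. number of training speakers
    'scale': 32.0,
    't': 3,
    'lanbuda': 0.7,
    'margin_type': 'C',    # 'C': cos(theta) - m, 'A': cos(theta + m)
}

projection = get_projection(conf)   # margin starts at 0.0
projection.update(margin=0.2)       # typically raised by a margin scheduler

embed = torch.randn(16, 256)        # (batch, embed_dim)
label = torch.randint(5994, (16,))  # speaker labels
output, loss = projection(embed, label)  # SphereFace2 returns logits and loss
```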
--------------------------------------------------------------------------------
/projections.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)
#               2021 Zhengyang Chen (chenzhengyang117@gmail.com)
#               2022 Hongji Wang (jijijiang77@gmail.com)
#               2023 Bing Han (hanjiameng0321@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_projection(conf):
    # margin is 0.0 at construction; it is expected to be raised via each
    # projection's update() during training (e.g. by a margin scheduler).
    if conf['project_type'] == 'add_margin':
        projection = AddMarginProduct(conf['embed_dim'],
                                      conf['num_class'],
                                      scale=conf['scale'],
                                      margin=0.0)
    elif conf['project_type'] == 'arc_margin':
        projection = ArcMarginProduct(conf['embed_dim'],
                                      conf['num_class'],
                                      scale=conf['scale'],
                                      margin=0.0,
                                      easy_margin=conf['easy_margin'])
    elif conf['project_type'] == 'arc_margin_intertopk_subcenter':
        projection = ArcMarginProduct_intertopk_subcenter(
            conf['embed_dim'],
            conf['num_class'],
            scale=conf['scale'],
            margin=0.0,
            easy_margin=conf['easy_margin'],
            K=conf.get('K', 3),
            mp=conf.get('mp', 0.06),
            k_top=conf.get('k_top', 5),
            do_lm=conf.get('do_lm', False))
    elif conf['project_type'] == 'sphere':
        projection = SphereProduct(conf['embed_dim'],
                                   conf['num_class'],
                                   margin=4)
    elif conf['project_type'] == 'sphereface2':
        projection = SphereFace2(conf['embed_dim'],
                                 conf['num_class'],
                                 scale=conf['scale'],
                                 margin=0.0,
                                 t=conf.get('t', 3),
                                 lanbuda=conf.get('lanbuda', 0.7),
                                 margin_type=conf.get('margin_type', 'C'))
    else:
        projection = Linear(conf['embed_dim'], conf['num_class'])

    return projection


def fun_g(z, t: int):
    """g(z) = 2 * ((z + 1) / 2) ** t - 1: monotonic map of the cosine score."""
    gz = 2 * torch.pow((z + 1) / 2, t) - 1
    return gz
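

# Reference values for g(z) with the default t = 3 (illustrative note, not in
# the original file): g(-1) = -1, g(1) = 1, and g(0) = 2 * 0.5**3 - 1 = -0.75.
# For t > 1 the map keeps the endpoints fixed while pushing mid-range cosine
# scores toward -1, which is how t adjusts the score distribution.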


class SphereFace2(nn.Module):
    r"""Implementation of SphereFace2 for speaker verification.
    Reference:
        [1] Exploring Binary Classification Loss for Speaker Verification
            https://ieeexplore.ieee.org/abstract/document/10094954
        [2] SphereFace2: Binary Classification is All You Need for
            Deep Face Recognition
            https://arxiv.org/pdf/2108.01513
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        scale: norm of input feature
        margin: margin
        lanbuda: weight balancing positive and negative pairs
            (lambda in the paper)
        t: parameter for adjusting the score distribution
        margin_type: 'A' for cos(theta + margin) or 'C' for
            cos(theta) - margin; recommended margins are 0.2 for 'C'
            and 0.15 for 'A'
    """

    def __init__(self,
                 in_features,
                 out_features,
                 scale=32.0,
                 margin=0.2,
                 lanbuda=0.7,
                 t=3,
                 margin_type='C'):
        super(SphereFace2, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.weight = nn.Parameter(torch.FloatTensor(out_features,
                                                     in_features))
        nn.init.xavier_uniform_(self.weight)
        self.bias = nn.Parameter(torch.zeros(1, 1))
        self.t = t
        self.lanbuda = lanbuda
        self.margin_type = margin_type

        ########
        self.margin = margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin)
        self.mmm = 1.0 + math.cos(math.pi - margin)
        ########

    def update(self, margin=0.2):
        self.margin = margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin)
        self.mmm = 1.0 + math.cos(math.pi - margin)

    def forward(self, input, label):
        # compute cosine similarity between embeddings and class weights
        cos = F.linear(F.normalize(input), F.normalize(self.weight))

        if self.margin_type == 'A':  # arcface type
            sin = torch.sqrt(1.0 - torch.pow(cos, 2))
            cos_m_theta_p = self.scale * fun_g(
                torch.where(cos > self.th,
                            cos * self.cos_m - sin * self.sin_m,
                            cos - self.mmm), self.t) + self.bias[0][0]
            cos_m_theta_n = self.scale * fun_g(
                cos * self.cos_m + sin * self.sin_m, self.t) + self.bias[0][0]
            cos_p_theta = self.lanbuda * torch.log(
                1 + torch.exp(-1.0 * cos_m_theta_p))
            cos_n_theta = (1 - self.lanbuda) * torch.log(
                1 + torch.exp(cos_m_theta_n))
        else:  # cosface type
            cos_m_theta_p = self.scale * (fun_g(cos, self.t) -
                                          self.margin) + self.bias[0][0]
            cos_m_theta_n = self.scale * (fun_g(cos, self.t) +
                                          self.margin) + self.bias[0][0]
            cos_p_theta = self.lanbuda * torch.log(
                1 + torch.exp(-1.0 * cos_m_theta_p))
            cos_n_theta = (1 - self.lanbuda) * torch.log(
                1 + torch.exp(cos_m_theta_n))

        target_mask = input.new_zeros(cos.size())
        target_mask.scatter_(1, label.view(-1, 1).long(), 1.0)
        nontarget_mask = 1 - target_mask
        cos1 = (cos - self.margin) * target_mask + cos * nontarget_mask
        output = self.scale * cos1  # used only to compute training accuracy
        loss = (target_mask * cos_p_theta +
                nontarget_mask * cos_n_theta).sum(1).mean()
        return output, loss

    def extra_repr(self):
        return ('in_features={}, out_features={}, scale={}, lanbuda={}, '
                'margin={}, t={}, margin_type={}'.format(
                    self.in_features, self.out_features, self.scale,
                    self.lanbuda, self.margin, self.t, self.margin_type))
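

# Loss summary (explanatory note, not in the original file): with s = scale,
# m = margin, b = bias, g = fun_g and lambda = lanbuda, the 'C' branch above
# minimizes, averaged over the batch,
#     sum_j [ y_j * lambda * log(1 + exp(-(s * (g(cos_j) - m) + b)))
#           + (1 - y_j) * (1 - lambda) * log(1 + exp(s * (g(cos_j) + m) + b)) ]
# where y_j is the one-hot target mask. Each of the out_features classes is
# thus an independent binary (one-vs-rest) problem rather than one multi-class
# softmax, which is what makes the classifier layer robust to noisy labels
# and easy to parallelize.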


class ArcMarginProduct(nn.Module):
    r"""Implementation of large-margin arc distance (ArcFace).
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        scale: norm of input feature
        margin: margin
        cos(theta + margin)
    """

    def __init__(self,
                 in_features,
                 out_features,
                 scale=32.0,
                 margin=0.2,
                 easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.weight = nn.Parameter(torch.FloatTensor(out_features,
                                                     in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin
        self.mmm = 1.0 + math.cos(
            math.pi - margin)  # this can make the output more continuous
        ########
        self.m = self.margin
        ########

    def update(self, margin=0.2):
        self.margin = margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin
        self.m = self.margin
        self.mmm = 1.0 + math.cos(math.pi - margin)

    def forward(self, input, label):
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + margin)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            ########
            # phi = torch.where(cosine > self.th, phi, cosine - self.mm)
            phi = torch.where(cosine > self.th, phi, cosine - self.mmm)
            ########

        one_hot = input.new_zeros(cosine.size())
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.scale

        return output

    def extra_repr(self):
        return ('in_features={}, out_features={}, scale={}, '
                'margin={}, easy_margin={}'.format(self.in_features,
                                                   self.out_features,
                                                   self.scale, self.margin,
                                                   self.easy_margin))
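

# Why `cosine - self.mmm` in the fallback branch above (explanatory note, not
# in the original file): cos(theta + m) is only a valid penalty while
# theta + m <= pi, i.e. while cosine > cos(pi - m) = th. Past that point a
# linear fallback is used, and mmm = 1 + cos(pi - m) makes the two branches
# agree at the switch point:
#     cos((pi - m) + m) = -1   and   cos(pi - m) - mmm = -1,
# keeping the logit continuous, whereas the commented-out `cosine - self.mm`
# variant leaves a jump there.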


class ArcMarginProduct_intertopk_subcenter(nn.Module):
    r"""Implementation of large-margin arc distance with inter-topK and
    sub-center.
    Reference:
        Multi-Query Multi-Head Attention Pooling and Inter-TopK Penalty
        for Speaker Verification.
        https://arxiv.org/pdf/2110.05042.pdf
        Sub-center ArcFace: Boosting Face Recognition by
        Large-Scale Noisy Web Faces.
        https://ibug.doc.ic.ac.uk/media/uploads/documents/eccv_1445.pdf
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        scale: norm of input feature
        margin: margin
        cos(theta + margin)
        K: number of sub-centers
        k_top: number of hard samples
        mp: margin penalty of hard samples
        do_lm: whether to do large-margin fine-tuning
    """

    def __init__(self,
                 in_features,
                 out_features,
                 scale=32.0,
                 margin=0.2,
                 easy_margin=False,
                 K=3,
                 mp=0.06,
                 k_top=5,
                 do_lm=False):
        super(ArcMarginProduct_intertopk_subcenter, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.do_lm = do_lm

        # intertopk + subcenter
        self.K = K
        if do_lm:
            # large-margin fine-tuning: remove the hard-sample penalty
            self.mp = 0.0
            self.k_top = 0
        else:
            self.mp = mp
            self.k_top = k_top

        # initialize classifier
        self.weight = nn.Parameter(
            torch.FloatTensor(self.K * out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin
        self.mmm = 1.0 + math.cos(
            math.pi - margin)  # this can make the output more continuous
        ########
        self.m = self.margin
        ########
        self.cos_mp = math.cos(0.0)
        self.sin_mp = math.sin(0.0)

    def update(self, margin=0.2):
        self.margin = margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin
        self.m = self.margin
        self.mmm = 1.0 + math.cos(math.pi - margin)

        # the hard-sample margin grows in proportion to the main margin
        if margin > 0.001:
            mp = self.mp * (margin / 0.2)
        else:
            mp = 0.0
        self.cos_mp = math.cos(mp)
        self.sin_mp = math.sin(mp)

    def forward(self, input, label):
        cosine = F.linear(F.normalize(input),
                          F.normalize(self.weight))  # (batch, out_dim * K)
        cosine = torch.reshape(
            cosine, (-1, self.out_features, self.K))  # (batch, out_dim, K)
        cosine, _ = torch.max(cosine, 2)  # (batch, out_dim)

        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        phi_mp = cosine * self.cos_mp + sine * self.sin_mp  # cos(theta - mp)

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            ########
            # phi = torch.where(cosine > self.th, phi, cosine - self.mm)
            phi = torch.where(cosine > self.th, phi, cosine - self.mmm)
            ########

        one_hot = input.new_zeros(cosine.size())
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)

        if self.k_top > 0:
            # top-k (j != y_i)
            _, top_k_index = torch.topk(cosine - 2 * one_hot,
                                        self.k_top)  # exclude j = y_i
            top_k_one_hot = input.new_zeros(cosine.size()).scatter_(
                1, top_k_index, 1)

            # sum
            output = (one_hot * phi) + (top_k_one_hot * phi_mp) + (
                (1.0 - one_hot - top_k_one_hot) * cosine)
        else:
            output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.scale
        return output

    def extra_repr(self):
        return ('in_features={}, out_features={}, scale={}, margin={}, '
                'easy_margin={}, K={}, mp={}, k_top={}, do_lm={}'.format(
                    self.in_features, self.out_features, self.scale,
                    self.margin, self.easy_margin, self.K, self.mp,
                    self.k_top, self.do_lm))


class AddMarginProduct(nn.Module):
    r"""Implementation of large-margin cosine distance (CosFace).
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        scale: norm of input feature
        margin: margin
        cos(theta) - margin
    """

    def __init__(self, in_features, out_features, scale=32.0, margin=0.20):
        super(AddMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.weight = nn.Parameter(torch.FloatTensor(out_features,
                                                     in_features))
        nn.init.xavier_uniform_(self.weight)

    def update(self, margin):
        self.margin = margin

    def forward(self, input, label):
        # ---------------- cos(theta) & phi(theta) ---------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        phi = cosine - self.margin
        # ---------------- convert label to one-hot ---------------
        one_hot = input.new_zeros(cosine.size())
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.scale
        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', scale=' + str(self.scale) \
               + ', margin=' + str(self.margin) + ')'
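

# Margin cheat-sheet for this file (explanatory note, not in the original
# source):
#   ArcMarginProduct      -> cos(theta + m)  (additive angular margin, ArcFace/AAM)
#   AddMarginProduct      -> cos(theta) - m  (additive cosine margin, CosFace/AM)
#   SphereProduct (below) -> cos(m * theta)  (multiplicative angular margin, A-Softmax)
# SphereFace2 reuses the 'A' and 'C' margin styles, but inside per-class
# binary losses instead of a multi-class softmax.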


class SphereProduct(nn.Module):
    r"""Implementation of the multiplicative angular margin
    (A-Softmax, SphereFace).
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        margin: margin
        cos(margin * theta)
    """

    def __init__(self, in_features, out_features, margin=2):
        super(SphereProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.margin = margin
        self.base = 1000.0
        self.gamma = 0.12
        self.power = 1
        self.LambdaMin = 5.0
        self.iter = 0
        self.weight = nn.Parameter(torch.FloatTensor(out_features,
                                                     in_features))
        nn.init.xavier_uniform_(self.weight)

        # multiple-angle formulas: cos(m * theta) as a polynomial in
        # cos(theta) (Chebyshev polynomials), for m = 0..5
        self.mlambda = [
            lambda x: x**0, lambda x: x**1, lambda x: 2 * x**2 - 1,
            lambda x: 4 * x**3 - 3 * x, lambda x: 8 * x**4 - 8 * x**2 + 1,
            lambda x: 16 * x**5 - 20 * x**3 + 5 * x
        ]
        assert self.margin < 6

    def forward(self, input, label):
        # lambda = max(lambda_min, base * (1 + gamma * iteration) ** (-power))
        self.iter += 1
        self.lamb = max(
            self.LambdaMin,
            self.base * (1 + self.gamma * self.iter)**(-1 * self.power))

        cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
        cos_theta = cos_theta.clamp(-1, 1)
        cos_m_theta = self.mlambda[self.margin](cos_theta)
        theta = cos_theta.data.acos()
        k = (self.margin * theta / math.pi).floor()
        phi_theta = ((-1.0)**k) * cos_m_theta - 2 * k
        NormOfFeature = torch.norm(input, 2, 1)
        one_hot = input.new_zeros(cos_theta.size())
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * (phi_theta - cos_theta) /
                  (1 + self.lamb)) + cos_theta
        output *= NormOfFeature.view(-1, 1)

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', margin=' + str(self.margin) + ')'


class Linear(nn.Module):
    """
    The linear transform for the simple softmax loss.
    """

    def __init__(self, emb_dim=512, class_num=1000):
        super(Linear, self).__init__()

        self.trans = nn.Sequential(nn.BatchNorm1d(emb_dim),
                                   nn.ReLU(inplace=True),
                                   nn.Linear(emb_dim, class_num))

    def forward(self, input, label):
        # label is unused; it is kept so all projections share one interface
        out = self.trans(input)
        return out


if __name__ == '__main__':
    projection = ArcMarginProduct_intertopk_subcenter(100,
                                                      200,
                                                      scale=32.0,
                                                      margin=0.0,
                                                      easy_margin=False,
                                                      K=3,
                                                      mp=0.06,
                                                      k_top=5)
    print(hasattr(projection, 'update'))
    projection.update(0.2)
    print(projection)
    embed = torch.randn(16, 100)
    label = torch.randint(200, (16, ))
    out = projection(embed, label)
    print(out.size())
--------------------------------------------------------------------------------