├── .gitignore ├── LICENSE.md ├── README.md ├── Stage1 ├── README.md ├── dataLoader.py ├── encoder.py ├── loss.py ├── main_train.py ├── model.py ├── pytorch_revgrad │ ├── README.md │ ├── __init__.py │ ├── functional.py │ ├── module.py │ └── version.py ├── rir.npy ├── run.sh └── tools.py ├── Stage2 ├── README.md ├── dataLoader.py ├── encoder.py ├── loss.py ├── main_train.py ├── model.py ├── run_LGL.sh ├── run_baseline.sh └── tools.py └── utils ├── LGL.png ├── requirements.txt ├── test_list.txt └── train_mini.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Other files 2 | *.model 3 | *.pth 4 | *.wav 5 | *.mp4 6 | *.txt 7 | *.pcm 8 | data/ 9 | tests/ 10 | Stage1/exp/ 11 | Stage2/exp/ 12 | demo/* 13 | utils/log.txt 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Tao Ruijie 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-supervised speaker recognition with LGL 2 | 3 | This repository contains the code for our ICASSP 2022 paper: ["Self-supervised speaker recognition with loss-gated learning"](https://arxiv.org/pdf/2110.03869.pdf). We propose to filter out the unreliable pseudo labels in Stage II, so that the system is trained only on reliable pseudo labels to boost its performance. 4 | 5 | ![LGL.png](utils/LGL.png) 6 | 7 | ## Results (train on VoxCeleb2 without labels, test EER on Vox1_O) 8 | 9 | | System | Stage 1 | Stage 2 | 10 | | :-: | :-: | :-: | 11 | | EER | 7.36 | 1.66 | 12 | 13 | ## Differences between our paper and this code 14 | 15 | * In our paper, we extend the channel size of the speaker encoder to 1024 in iteration 5; in this code we remove this setting to simplify the code. You can do that in the last iteration to get a better result. 16 | 17 | * In our paper, we manually determine the end of each iteration, which is not user-friendly. In this code, we end an iteration if the EER does not improve for N = 4 consecutive epochs. You can increase N to improve the performance. 18 | 19 | I do not have time to run the entire code again. I have checked Stage 1 and obtained EER=7.36, and I believe an EER smaller than 2.00 can easily be obtained in Stage 2 with this code. 20 | 21 | *** 22 | 23 | ## Dependencies 24 | 25 | Note: this is the setting for my device; you can modify the torch and torchaudio versions to match your device. 26 | 27 | ``` 28 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 29 | conda install -c pytorch faiss-gpu 30 | pip install -r utils/requirements.txt 31 | ``` 32 | 33 | ## Data preparation 34 | 35 | Please follow the official code to prepare your VoxCeleb2 dataset, as described in the 'Data preparation' part of [this repository](https://github.com/clovaai/voxceleb_trainer). 36 | 37 | Datasets for training: 38 | 39 | 1) VoxCeleb2 training set; 40 | 41 | 2) MUSAN dataset; 42 | 43 | 3) RIR dataset. 44 | 45 | Datasets for evaluation: 46 | 47 | 1) VoxCeleb1 test set for [Vox1_O](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt) 48 | 49 | 2) VoxCeleb1 train set for [Vox1_E](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/list_test_all2.txt) and [Vox1_H](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/list_test_hard2.txt) (optional) 50 | 51 | I have added the test_list (Vox1_O) in `utils`. The train_list contains the length of each utterance; an example of the list format is given at the end of this section. 52 | 53 | `train_mini.txt` is a subset of VoxCeleb2. It contains 100k utterances from 4082 speakers. 54 | 55 | Download `train_list.txt` from [here](https://drive.google.com/u/0/uc?id=1eraQWNKNHS_s6SnPjoZrQ_1HOeUREh9R&export=download) and put it in `utils`.
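For reference, each line of the training list is expected to contain the speaker ID, the wav path relative to the VoxCeleb2 wav folder, and the utterance length in seconds: the path is used to load the audio, the length (last column) is used to group utterances of similar duration for clustering, and the speaker ID is only used to compute NMI for analysis (see `Stage2/dataLoader.py`). A minimal illustrative example (the values below are made up):

```
id00012 id00012/21Uxsk56VDQ/00001.wav 8.01
id00012 id00012/21Uxsk56VDQ/00002.wav 5.82
id00015 id00015/0fijmz4vTVU/00001.wav 12.34
```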
56 | 57 | ## Stage I: Contrastive Learning 58 | 59 | First, you need to train a basic speaker encoder in a contrastive learning framework. Change directory to the folder `Stage1` and use: 60 | 61 | ``` 62 | bash run.sh 63 | ``` 64 | 65 | Every `test_interval` epochs, the system is evaluated on the Vox1_O set and the EER is printed. 66 | 67 | The results will be saved in `Stage1/exp/exp1/scores.txt` and the models in `Stage1/exp/exp1/model`. I also provide the [model](https://drive.google.com/u/0/uc?id=1GTKG04Hs0rr--SOUOYpu9ZTkBT2UZHQk&export=download) with EER=7.36. 68 | 69 | In my case, I trained for 50 epochs on one 3090 GPU. Each epoch takes about 40 minutes, so the total training time is about 35 hours. 70 | 71 | ## Stage II: Classification Training (Baseline) 72 | 73 | For the baseline approach in Stage II, change directory to the folder `Stage2` and use: 74 | 75 | ``` 76 | bash run_baseline.sh 77 | ``` 78 | 79 | Please modify the `init_model` path in `run_baseline.sh`. `init_model` is the path to the best model from Stage I. 80 | 81 | This is end-to-end code. The system will: 82 | 83 | 1) Do clustering; 84 | 85 | 2) Train the speaker encoder for classification; 86 | 87 | 3) Repeat 1) and 2) once the EER in 2) has not improved for 4 consecutive epochs. 88 | 89 | Here we do 5 iterations. Each training epoch takes about 20 minutes and clustering takes about 18 minutes. 90 | 91 | ## Stage II: Classification Training with LGL (Ours) 92 | 93 | For our LGL approach in Stage II, change directory to the folder `Stage2` and use: 94 | 95 | ``` 96 | bash run_LGL.sh 97 | ``` 98 | 99 | This is also end-to-end code. The system will: 100 | 101 | 1) Do clustering; 102 | 103 | 2) Train the speaker encoder for classification; 104 | 105 | 3) Train the speaker encoder for classification with LGL once the EER in 2) has not improved for 4 consecutive epochs; 106 | 107 | 4) Repeat 1), 2) and 3) once the EER in 3) has not improved for 4 consecutive epochs. 108 | 109 | ## Notes 110 | 111 | I have added annotations to make the code as clear as possible; please read them carefully. If you have questions, please post them in the `Issues` section. 112 | 113 | ## Reference 114 | ``` 115 | @inproceedings{tao2022self, 116 | title={Self-supervised speaker recognition with loss-gated learning}, 117 | author={Tao, Ruijie and Lee, Kong Aik and Das, Rohan Kumar and Hautam{\"a}ki, Ville and Li, Haizhou}, 118 | booktitle={ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 119 | pages={6142--6146}, 120 | year={2022}, 121 | organization={IEEE} 122 | } 123 | ``` 124 | 125 | ## Acknowledgements 126 | 127 | We studied many useful projects during our coding process, including: 128 | 129 | [clovaai/voxceleb_trainer](https://github.com/clovaai/voxceleb_trainer). 130 | 131 | [joonson/voxceleb_unsupervised](https://github.com/joonson/voxceleb_unsupervised). 132 | 133 | [lawlict/ECAPA-TDNN](https://github.com/lawlict/ECAPA-TDNN/blob/master/ecapa_tdnn.py). 134 | 135 | Thanks to these authors for open-sourcing their code! 136 | 137 | ### Cooperation 138 | 139 | If you are interested in working on this topic and have ideas to implement, I am glad to collaborate and contribute my experience & knowledge in this area. Please contact me at ruijie.tao@u.nus.edu.
140 | -------------------------------------------------------------------------------- /Stage1/README.md: -------------------------------------------------------------------------------- 1 | ## Stage I: Contrastive Learning 2 | 3 | 1) Here we use a contrastive learning framework to train the basic speaker encoder. The EER (Vox_O) in our paper is 7.36, we modify a bit recently and now the result is 7.36 in 50 epochs. I believe it can get better if train for more epochs. 4 | 5 | 2) Any other self-supervised learning framework can be used in Stage I. In my experience, the EER smaller than 7.5 on Vox_O in the first stage is robust to the clustering-training based learning. A bad EER will lead to a bad clustering result, so that Stage II can not work. 6 | 7 | 3) Notice that our framework contains the [Augmentation adversarial training for unsupervised speaker recognition](https://arxiv.org/pdf/2007.12085.pdf]). Here [AAT](https://github.com/joonson/voxceleb_unsupervised) makes the result better. Our code in this part is also modified based on their project. Thanks for their open-source code! 8 | -------------------------------------------------------------------------------- /Stage1/dataLoader.py: -------------------------------------------------------------------------------- 1 | import torch, numpy, random, os, math, glob, soundfile 2 | from torch.utils.data import Dataset, DataLoader 3 | from scipy import signal 4 | 5 | class train_loader(Dataset): 6 | def __init__(self, max_frames, train_list, train_path, musan_path, **kwargs): 7 | self.max_frames = max_frames 8 | self.data_list = [] 9 | self.noisetypes = ['noise','speech','music'] # Type of noise 10 | self.noisesnr = {'noise':[0,15],'speech':[13,20],'music':[5,15]} # The range of SNR 11 | self.noiselist = {} 12 | augment_files = glob.glob(os.path.join(musan_path,'*/*/*/*.wav')) # All noise files in list 13 | for file in augment_files: 14 | if not file.split('/')[-4] in self.noiselist: 15 | self.noiselist[file.split('/')[-4]] = [] 16 | self.noiselist[file.split('/')[-4]].append(file) # All noise files in dic 17 | self.rir_files = numpy.load('rir.npy') # Load the rir file 18 | for line in open(train_list).read().splitlines(): 19 | filename = os.path.join(train_path, line.split()[1]) 20 | self.data_list.append(filename) # Load the training data list 21 | 22 | def __getitem__(self, index): 23 | audio = loadWAVSplit(self.data_list[index], self.max_frames).astype(numpy.float) # Load one utterance 24 | augment_profiles, audio_aug = [], [] 25 | for ii in range(0,2): # Two segments of one utterance 26 | rir_gains = numpy.random.uniform(-7,3,1) 27 | rir_filts = random.choice(self.rir_files) 28 | noisecat = random.choice(self.noisetypes) 29 | noisefile = random.choice(self.noiselist[noisecat].copy()) # Augmentation information for each segment 30 | snr = [random.uniform(self.noisesnr[noisecat][0],self.noisesnr[noisecat][1])] 31 | p = random.random() 32 | if p < 0.25: # Add rir only 33 | augment_profiles.append({'rir_filt':rir_filts, 'rir_gain':rir_gains, 'add_noise': None, 'add_snr': None}) 34 | elif p < 0.50: # Add noise only 35 | augment_profiles.append({'rir_filt':None, 'rir_gain':None, 'add_noise': noisefile, 'add_snr': snr}) 36 | else: # Add both 37 | augment_profiles.append({'rir_filt':rir_filts, 'rir_gain':rir_gains, 'add_noise': noisefile, 'add_snr': snr}) 38 | audio_aug.append(self.augment_wav(audio[0],augment_profiles[0])) # Segment 0 with augmentation method 0 39 | audio_aug.append(self.augment_wav(audio[1],augment_profiles[0])) # Segment 
1 with augmentation method 0, used for AAT 40 | audio_aug.append(self.augment_wav(audio[1],augment_profiles[1])) # Segment 1 with augmentation method 1 41 | audio_aug = numpy.concatenate(audio_aug,axis=0) # Concate and return 42 | return torch.FloatTensor(audio_aug) 43 | 44 | def __len__(self): 45 | return len(self.data_list) 46 | 47 | def augment_wav(self,audio,augment): 48 | if augment['rir_filt'] is not None: 49 | rir = numpy.multiply(augment['rir_filt'], pow(10, 0.1 * augment['rir_gain'])) 50 | audio = signal.convolve(audio, rir, mode='full')[:len(audio)] 51 | if augment['add_noise'] is not None: 52 | noiseaudio = loadWAV(augment['add_noise'], self.max_frames).astype(numpy.float) 53 | noise_db = 10 * numpy.log10(numpy.mean(noiseaudio[0] ** 2)+1e-4) 54 | clean_db = 10 * numpy.log10(numpy.mean(audio ** 2)+1e-4) 55 | noise = numpy.sqrt(10 ** ((clean_db - noise_db - augment['add_snr']) / 10)) * noiseaudio 56 | audio = audio + noise 57 | else: 58 | audio = numpy.expand_dims(audio, 0) 59 | return audio 60 | 61 | def loadWAV(filename, max_frames): 62 | max_audio = max_frames * 160 + 240 # 240 is for padding, for 15ms since window is 25ms and step is 10ms. 63 | audio, _ = soundfile.read(filename) 64 | audiosize = audio.shape[0] 65 | if audiosize <= max_audio: # Padding if the length is not enough 66 | shortage = math.floor( ( max_audio - audiosize + 1 ) / 2 ) 67 | audio = numpy.pad(audio, (shortage, shortage), 'wrap') 68 | audiosize = audio.shape[0] 69 | startframe = numpy.int64(random.random()*(audiosize-max_audio)) # Randomly select a start frame to extract audio 70 | feat = numpy.stack([audio[int(startframe):int(startframe)+max_audio]],axis=0) 71 | return feat 72 | 73 | def loadWAVSplit(filename, max_frames): # Load two segments 74 | max_audio = max_frames * 160 + 240 75 | audio, _ = soundfile.read(filename) 76 | audiosize = audio.shape[0] 77 | if audiosize <= max_audio: 78 | shortage = math.floor( ( max_audio - audiosize) / 2 ) 79 | audio = numpy.pad(audio, (shortage, shortage), 'wrap') 80 | audiosize = audio.shape[0] 81 | randsize = audiosize - (max_audio*2) # Select two segments 82 | startframe = random.sample(range(0, randsize), 2) 83 | startframe.sort() 84 | startframe[1] += max_audio # Non-overlapped two segments 85 | startframe = numpy.array(startframe) 86 | numpy.random.shuffle(startframe) 87 | feats = [] 88 | for asf in startframe: # Startframe[0] means the 1st segment, Startframe[1] means the 2nd segment 89 | feats.append(audio[int(asf):int(asf)+max_audio]) 90 | feat = numpy.stack(feats,axis=0) 91 | return feat 92 | 93 | def worker_init_fn(worker_id): 94 | numpy.random.seed(numpy.random.get_state()[1][0] + worker_id) 95 | 96 | def get_loader(args): # Define the data loader 97 | trainLoader = train_loader(**vars(args)) 98 | trainLoader = torch.utils.data.DataLoader( 99 | trainLoader, 100 | batch_size=args.batch_size, 101 | shuffle=True, 102 | num_workers=args.n_cpu, 103 | pin_memory=False, 104 | drop_last=True, 105 | worker_init_fn=worker_init_fn, 106 | prefetch_factor=5, 107 | ) 108 | return trainLoader 109 | -------------------------------------------------------------------------------- /Stage1/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchaudio 6 | 7 | class SEModule(nn.Module): 8 | def __init__(self, channels, bottleneck=128): 9 | super(SEModule, self).__init__() 10 | self.se = nn.Sequential( 11 | nn.AdaptiveAvgPool1d(1), 
12 | nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0), 13 | nn.ReLU(), 14 | nn.BatchNorm1d(bottleneck), 15 | nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0), 16 | nn.Sigmoid(), 17 | ) 18 | 19 | def forward(self, input): 20 | x = self.se(input) 21 | return input * x 22 | 23 | class Bottle2neck(nn.Module): 24 | def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale = 8): 25 | super(Bottle2neck, self).__init__() 26 | 27 | width = int(math.floor(planes / scale)) 28 | self.conv1 = nn.Conv1d(inplanes, width*scale, kernel_size=1) 29 | self.bn1 = nn.BatchNorm1d(width*scale) 30 | self.nums = scale -1 31 | convs = [] 32 | bns = [] 33 | num_pad = math.floor(kernel_size/2)*dilation 34 | for i in range(self.nums): 35 | convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad)) 36 | bns.append(nn.BatchNorm1d(width)) 37 | self.convs = nn.ModuleList(convs) 38 | self.bns = nn.ModuleList(bns) 39 | self.conv3 = nn.Conv1d(width*scale, planes, kernel_size=1) 40 | self.bn3 = nn.BatchNorm1d(planes) 41 | self.relu = nn.ReLU() 42 | self.width = width 43 | self.se = SEModule(planes) 44 | 45 | def forward(self, x): 46 | residual = x 47 | out = self.conv1(x) 48 | out = self.relu(out) 49 | out = self.bn1(out) 50 | spx = torch.split(out, self.width, 1) 51 | for i in range(self.nums): 52 | if i==0: 53 | sp = spx[i] 54 | else: 55 | sp = sp + spx[i] 56 | sp = self.convs[i](sp) 57 | sp = self.relu(sp) 58 | sp = self.bns[i](sp) 59 | if i==0: 60 | out = sp 61 | else: 62 | out = torch.cat((out, sp), 1) 63 | out = torch.cat((out, spx[self.nums]),1) 64 | out = self.conv3(out) 65 | out = self.relu(out) 66 | out = self.bn3(out) 67 | out = self.se(out) 68 | out += residual 69 | 70 | return out 71 | 72 | class PreEmphasis(torch.nn.Module): 73 | 74 | def __init__(self, coef: float = 0.97): 75 | super().__init__() 76 | self.coef = coef 77 | self.register_buffer( 78 | 'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0) 79 | ) 80 | 81 | def forward(self, input: torch.tensor) -> torch.tensor: 82 | input = input.unsqueeze(1) 83 | input = F.pad(input, (1, 0), 'reflect') 84 | return F.conv1d(input, self.flipped_filter).squeeze(1) 85 | 86 | class ECAPA_TDNN(nn.Module): # Here we use a small ECAPA-TDNN, C=512. In my experiences, C=1024 slightly improves the performance but need more training time. 
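    # Input: a batch of raw waveforms with shape (batch, num_samples); output: 192-dimensional speaker embeddings with shape (batch, 192).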
87 | def __init__(self, C = 512, **kwargs): 88 | super(ECAPA_TDNN, self).__init__() 89 | self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2) 90 | self.relu = nn.ReLU() 91 | self.bn1 = nn.BatchNorm1d(C) 92 | self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8) 93 | self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8) 94 | self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8) 95 | self.layer4 = nn.Conv1d(3*C, 1536, kernel_size=1) 96 | self.torchfbank = torch.nn.Sequential( 97 | PreEmphasis(), 98 | torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400, hop_length=160, f_min = 20, f_max = 7600, window_fn=torch.hamming_window, n_mels=80), 99 | ) 100 | self.attention = nn.Sequential( 101 | nn.Conv1d(4608, 256, kernel_size=1), 102 | nn.ReLU(), 103 | nn.BatchNorm1d(256), 104 | nn.Tanh(), 105 | nn.Conv1d(256, 1536, kernel_size=1), 106 | nn.Softmax(dim=2), 107 | ) 108 | self.bn5 = nn.BatchNorm1d(3072) 109 | self.fc6 = nn.Linear(3072, 192) 110 | self.bn6 = nn.BatchNorm1d(192) 111 | 112 | def forward(self, x): 113 | with torch.no_grad(): 114 | x = self.torchfbank(x)+1e-6 115 | x = x.log() 116 | x = x - torch.mean(x, dim=-1, keepdim=True) 117 | x = self.conv1(x) 118 | x = self.relu(x) 119 | x = self.bn1(x) 120 | x1 = self.layer1(x) 121 | x2 = self.layer2(x+x1) 122 | x3 = self.layer3(x+x1+x2) 123 | x = self.layer4(torch.cat((x1,x2,x3),dim=1)) 124 | x = self.relu(x) 125 | t = x.size()[-1] 126 | global_x = torch.cat((x,torch.mean(x,dim=2,keepdim=True).repeat(1,1,t), torch.sqrt(torch.var(x,dim=2,keepdim=True).clamp(min=1e-4)).repeat(1,1,t)), dim=1) 127 | w = self.attention(global_x) 128 | mu = torch.sum(x * w, dim=2) 129 | sg = torch.sqrt( ( torch.sum((x**2) * w, dim=2) - mu**2 ).clamp(min=1e-4) ) 130 | x = torch.cat((mu,sg),1) 131 | x = self.bn5(x) 132 | x = self.fc6(x) 133 | x = self.bn6(x) 134 | return x 135 | 136 | from pytorch_revgrad import RevGrad 137 | 138 | class AATNet(nn.Module): # AAT system 139 | def __init__(self, **kwargs): 140 | super(AATNet, self).__init__() 141 | layers = [] 142 | layers.append(torch.nn.Sequential( 143 | nn.BatchNorm1d(384), 144 | torch.nn.ReLU(inplace=True), 145 | torch.nn.Linear(384,512), 146 | )) 147 | layers.append(torch.nn.Sequential( 148 | nn.BatchNorm1d(512), 149 | torch.nn.ReLU(inplace=True), 150 | torch.nn.Linear(512,2), 151 | )) 152 | self.matcher = torch.nn.Sequential(*layers) 153 | 154 | def reset_parameters(self): 155 | self.matcher.reset_parameters() 156 | 157 | def forward(self, x): 158 | return self.matcher(x) 159 | 160 | class Reverse(nn.Module): 161 | def __init__(self, **kwargs): 162 | super(Reverse, self).__init__() 163 | layers = [RevGrad()] 164 | self.matcher = torch.nn.Sequential(*layers) 165 | 166 | def reset_parameters(self): 167 | self.matcher.reset_parameters() 168 | 169 | def forward(self, x): 170 | return self.matcher(x) -------------------------------------------------------------------------------- /Stage1/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from tools import * 5 | import numpy 6 | 7 | # This function is modified from https://github.com/HobbitLong/SupContrast/blob/master/losses.py 8 | class LossFunction(nn.Module): 9 | def __init__(self, init_w=10.0, init_b=-5.0, **kwargs): # No temp param 10 | super(LossFunction, self).__init__() 11 | self.w = nn.Parameter(torch.tensor(init_w)) 12 | self.b = 
nn.Parameter(torch.tensor(init_b)) 13 | 14 | def forward(self, features): 15 | batch_size = features.shape[0] 16 | mask = torch.eye(batch_size, dtype=torch.float32).to(torch.device('cuda')) 17 | count = features.shape[1] 18 | feature = torch.cat(torch.unbind(features, dim=1), dim=0) 19 | dot_feature = F.cosine_similarity(feature.unsqueeze(-1),feature.unsqueeze(-1).transpose(0,2)) 20 | torch.clamp(self.w, 1e-6) 21 | dot_feature = dot_feature * self.w + self.b # We add this from angle protocol loss. 22 | logits_max, _ = torch.max(dot_feature, dim=1, keepdim=True) 23 | logits = dot_feature - logits_max.detach() 24 | mask = mask.repeat(count, count) 25 | logits_mask = torch.scatter( 26 | torch.ones_like(mask), 27 | 1, 28 | torch.arange(batch_size * count).view(-1, 1).to(torch.device('cuda')), 29 | 0 30 | ) 31 | mask = mask * logits_mask 32 | exp_logits = torch.exp(logits) * logits_mask 33 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 34 | loss = -(mask * log_prob).sum(1) / mask.sum(1) 35 | loss = loss.view(count, batch_size).mean() 36 | n = batch_size * 2 37 | label = torch.from_numpy(numpy.asarray(list(range(batch_size - 1,batch_size*2 - 1)) + list(range(0,batch_size)))).cuda() 38 | logits = logits.flatten()[1:].view(n-1, n+1)[:,:-1].reshape(n, n-1) 39 | prec1, _ = accuracy(logits.detach().cpu(), label.detach().cpu(), topk=(1, 2)) # Compute the training acc 40 | 41 | return loss, prec1 -------------------------------------------------------------------------------- /Stage1/main_train.py: -------------------------------------------------------------------------------- 1 | import sys, time, os, argparse, warnings, glob, torch 2 | from tools import * 3 | from model import * 4 | from dataLoader import * 5 | 6 | # Training settings 7 | parser = argparse.ArgumentParser(description = "Stage I, self-supervsied speaker recognition with contrastive learning.") 8 | parser.add_argument('--max_frames', type=int, default=180, help='Input length to the network, 1.8s') 9 | parser.add_argument('--batch_size', type=int, default=300, help='Batch size, bigger is better') 10 | parser.add_argument('--n_cpu', type=int, default=4, help='Number of loader threads') 11 | parser.add_argument('--test_interval', type=int, default=1, help='Test and save every [test_interval] epochs') 12 | parser.add_argument('--max_epoch', type=int, default=80, help='Maximum number of epochs') 13 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') 14 | parser.add_argument("--lr_decay", type=float, default=0.95, help='Learning rate decay every [test_interval] epochs') 15 | parser.add_argument('--initial_model', type=str, default="", help='Initial model path') 16 | parser.add_argument('--save_path', type=str, default="", help='Path for model and scores.txt') 17 | parser.add_argument('--train_list', type=str, default="", help='Path for Vox2 list, https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/train_list.txt') 18 | parser.add_argument('--val_list', type=str, default="", help='Path for Vox_O list, https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt') 19 | parser.add_argument('--train_path', type=str, default="", help='Path to the Vox2 set') 20 | parser.add_argument('--val_path', type=str, default="", help='Path to the Vox_O set') 21 | parser.add_argument('--musan_path', type=str, default="", help='Path to the musan set') 22 | parser.add_argument('--eval', dest='eval', action='store_true', help='Do evaluation only') 23 | args = parser.parse_args() 24 | 25 | # Initialization 26 | 
model_save_path = args.save_path+"/model" 27 | result_save_path = args.save_path+"/result" 28 | torch.multiprocessing.set_sharing_strategy('file_system') 29 | warnings.filterwarnings("ignore") 30 | os.makedirs(model_save_path, exist_ok = True) 31 | scorefile = open(args.save_path+"/scores.txt", "a+") 32 | it = 1 33 | 34 | Trainer = model(**vars(args)) # Define the framework 35 | modelfiles = glob.glob('%s/model0*.model'%model_save_path) # Search the existed model files 36 | modelfiles.sort() 37 | 38 | if(args.initial_model != ""): # If initial_model is exist, system will train from the initial_model 39 | Trainer.load_network(args.initial_model) 40 | elif len(modelfiles) >= 1: # Otherwise, system will try to start from the saved model&epoch 41 | Trainer.load_network(modelfiles[-1]) 42 | it = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][5:]) + 1 43 | 44 | if args.eval == True: # Do evaluation only 45 | EER, minDCF = Trainer.evaluate_network(**vars(args)) 46 | print('EER %2.4f, minDCF %.3f\n'%(EER, minDCF)) 47 | quit() 48 | 49 | trainLoader = get_loader(args) # Define the dataloader 50 | 51 | while it < args.max_epoch: 52 | # Train for one epoch 53 | loss, traineer, lr = Trainer.train_network(loader=trainLoader, epoch = it) 54 | 55 | # Evaluation every [test_interval] epochs, record the training loss, training acc, evaluation EER/minDCF 56 | if it % args.test_interval == 0: 57 | Trainer.save_network(model_save_path+"/model%09d.model"%it) 58 | EER, minDCF = Trainer.evaluate_network(**vars(args)) 59 | print(time.strftime("%Y-%m-%d %H:%M:%S"), "LR %f, Acc %2.2f, LOSS %f, EER %2.4f, minDCF %.3f"%( lr, traineer, loss, EER, minDCF)) 60 | scorefile.write("Epoch %d, LR %f, Acc %2.2f, LOSS %f, EER %2.4f, minDCF %.3f\n"%(it, lr, traineer, loss, EER, minDCF)) 61 | scorefile.flush() 62 | # Otherwise, recored the training loss and acc 63 | else: 64 | print(time.strftime("%Y-%m-%d %H:%M:%S"), "LR %f, Acc %2.2f, LOSS %f"%( lr, traineer, loss)) 65 | scorefile.write("Epoch %d, LR %f, Acc %2.2f, LOSS %f\n"%(it, lr, traineer, loss)) 66 | scorefile.flush() 67 | 68 | it += 1 69 | print("") -------------------------------------------------------------------------------- /Stage1/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy, tqdm, sys, time, soundfile 6 | 7 | from loss import * 8 | from encoder import * 9 | from tools import * 10 | 11 | class model(nn.Module): 12 | def __init__(self, lr, lr_decay, **kwargs): 13 | super(model, self).__init__() 14 | self.Network = ECAPA_TDNN().cuda() # Speaker encoder 15 | self.Loss = LossFunction().cuda() # Contrastive loss 16 | self.AATNet = AATNet().cuda() # AAT, which is used to improve the performace 17 | self.Reverse = Reverse().cuda() # AAT 18 | self.OptimNet = torch.optim.Adam(list(self.Network.parameters()) + list(self.Loss.parameters()), lr = lr) 19 | self.OptimAAT = torch.optim.Adam(self.AATNet.parameters(), lr = lr) 20 | self.Scheduler = torch.optim.lr_scheduler.StepLR(self.OptimNet, step_size = 5, gamma=lr_decay) 21 | print("Model para number = %.2f"%(sum(param.numel() for param in self.Network.parameters()) / 1024 / 1024)) 22 | 23 | def train_network(self, loader, epoch): 24 | # Contrastive learning with AAT, for more details about AAT, please check here: https://github.com/joonson/voxceleb_unsupervised 25 | self.train() 26 | self.Scheduler.step(epoch - 1) # Update the learning rate 27 | loss, top1 = 0, 0 28 | lr = 
self.OptimNet.param_groups[0]['lr'] # Read the current learning rate 29 | criterion = torch.nn.CrossEntropyLoss() # Use for AAT 30 | AAT_labels = torch.LongTensor([1]*loader.batch_size+[0]*loader.batch_size).cuda() # AAT labels 31 | tstart = time.time() # Used to monitor the training speed 32 | for counter, data in enumerate(loader, start = 1): 33 | data = data.transpose(0,1) 34 | feat = [] 35 | for inp in data: 36 | feat.append(self.Network.forward(torch.FloatTensor(inp).cuda())) # Feed the segments to get the speaker embeddings 37 | feat = torch.stack(feat,dim=1).squeeze() 38 | self.zero_grad() 39 | # Train discriminator 40 | out_a, out_s, out_p = feat[:,0,:].detach(), feat[:,1,:].detach(), feat[:,2,:].detach() 41 | in_AAT = torch.cat((torch.cat((out_a,out_s),1),torch.cat((out_a,out_p),1)),0) 42 | out_AAT = self.AATNet(in_AAT) 43 | dloss = criterion(out_AAT, AAT_labels) 44 | dloss.backward() 45 | self.OptimAAT.step() 46 | # Train model 47 | self.zero_grad() 48 | in_AAT = torch.cat((torch.cat((feat[:,0,:],feat[:,1,:]),1),torch.cat((feat[:,0,:],feat[:,2,:]),1)),0) 49 | out_AAT = self.AATNet(self.Reverse(in_AAT)) 50 | closs = criterion(out_AAT, AAT_labels) # AAT loss 51 | sloss, prec1 = self.Loss.forward(feat[:,[0,2],:]) # speaker loss 52 | nloss = sloss + closs * 3 # Total loss 53 | loss += nloss.detach().cpu() 54 | top1 += prec1 # Training acc 55 | nloss.backward() 56 | self.OptimNet.step() 57 | time_used = time.time() - tstart # Time for this epoch 58 | sys.stdout.write("[%2d] Lr: %5f, %.2f%% (est %.1f mins) Loss %f EER/TAcc %2.3f%% \r"%(epoch, lr, 100 * (counter / loader.__len__()), time_used * loader.__len__() / counter / 60, loss/counter, top1/counter)) 59 | sys.stdout.flush() 60 | sys.stdout.write("\n") 61 | return loss/counter, top1/counter, lr 62 | 63 | def evaluate_network(self, val_list, val_path, **kwargs): 64 | self.eval() 65 | files, feats = [], {} 66 | for line in open(val_list).read().splitlines(): 67 | data = line.split() 68 | files.append(data[1]) 69 | files.append(data[2]) 70 | setfiles = list(set(files)) 71 | setfiles.sort() # Read the list of wav files 72 | for idx, file in tqdm.tqdm(enumerate(setfiles), total = len(setfiles)): 73 | audio, _ = soundfile.read(os.path.join(val_path, file)) 74 | feat = torch.FloatTensor(numpy.stack([audio], axis=0)).cuda() 75 | with torch.no_grad(): 76 | ref_feat = self.Network.forward(feat).detach().cpu() 77 | feats[file] = ref_feat # Extract features for each data, get the feature dict 78 | scores, labels = [], [] 79 | for line in open(val_list).read().splitlines(): 80 | data = line.split() 81 | ref_feat = F.normalize(feats[data[1]].cuda(), p=2, dim=1) # feature 1 82 | com_feat = F.normalize(feats[data[2]].cuda(), p=2, dim=1) # feature 2 83 | score = numpy.mean(torch.matmul(ref_feat, com_feat.T).detach().cpu().numpy()) # Get the score 84 | scores.append(score) 85 | labels.append(int(data[0])) 86 | EER = tuneThresholdfromScore(scores, labels, [1, 0.1])[1] 87 | fnrs, fprs, thresholds = ComputeErrorRates(scores, labels) 88 | minDCF, _ = ComputeMinDcf(fnrs, fprs, thresholds, 0.05, 1, 1) 89 | return EER, minDCF 90 | 91 | def save_network(self, path): # Save the model 92 | torch.save(self.state_dict(), path) 93 | 94 | def load_network(self, path): # Load the parameters of the pretrain model 95 | self_state = self.state_dict() 96 | loaded_state = torch.load(path) 97 | print("Model %s loaded!"%(path)) 98 | for name, param in loaded_state.items(): 99 | origname = name 100 | if name not in self_state: 101 | name = name.replace("module.", "") 102 | 
if name not in self_state: 103 | print("%s is not in the model."%origname) 104 | continue 105 | if self_state[name].size() != loaded_state[origname].size(): 106 | print("Wrong parameter length: %s, model: %s, loaded: %s"%(origname, self_state[name].size(), loaded_state[origname].size())) 107 | continue 108 | self_state[name].copy_(param) 109 | 110 | -------------------------------------------------------------------------------- /Stage1/pytorch_revgrad/README.md: -------------------------------------------------------------------------------- 1 | This module is learnt from [Augmentation adversarial training for unsupervised speaker recognition](https://github.com/joonson/voxceleb_unsupervised). In my experience, AAT can improve the EER from 8.5 to 7.5 in Vox\_O. Check their paper for more details. -------------------------------------------------------------------------------- /Stage1/pytorch_revgrad/__init__.py: -------------------------------------------------------------------------------- 1 | """A pytorch module (and function) to reverse gradients.""" 2 | from .module import RevGrad # noqa: F401 3 | from .version import __version__ # noqa: F401 4 | -------------------------------------------------------------------------------- /Stage1/pytorch_revgrad/functional.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | 3 | 4 | class RevGrad(Function): 5 | @staticmethod 6 | def forward(ctx, input_): 7 | ctx.save_for_backward(input_) 8 | output = input_ 9 | return output 10 | 11 | @staticmethod 12 | def backward(ctx, grad_output): # pragma: no cover 13 | grad_input = None 14 | if ctx.needs_input_grad[0]: 15 | grad_input = -grad_output 16 | return grad_input 17 | 18 | 19 | revgrad = RevGrad.apply 20 | -------------------------------------------------------------------------------- /Stage1/pytorch_revgrad/module.py: -------------------------------------------------------------------------------- 1 | from .functional import revgrad 2 | from torch.nn import Module 3 | 4 | 5 | class RevGrad(Module): 6 | def __init__(self, *args, **kwargs): 7 | """ 8 | A gradient reversal layer. 9 | 10 | This layer has no parameters, and simply reverses the gradient 11 | in the backward pass. 
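        For example, ``RevGrad()(x)`` returns ``x`` unchanged in the forward
        pass, while the gradient flowing back through it is multiplied by -1
        (see ``functional.py``).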
12 | """ 13 | 14 | super().__init__(*args, **kwargs) 15 | 16 | def forward(self, input_): 17 | return revgrad(input_) 18 | -------------------------------------------------------------------------------- /Stage1/pytorch_revgrad/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.1" 2 | version = __version__ 3 | -------------------------------------------------------------------------------- /Stage1/rir.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoRuijie/Loss-Gated-Learning/665a00ad9a62a94004bd9c89c48d0f1f5cecb79d/Stage1/rir.npy -------------------------------------------------------------------------------- /Stage1/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python main_train.py \ 4 | --save_path exp/exp1 \ 5 | --batch_size 300 \ 6 | --lr 0.001 \ 7 | --lr_decay 0.90 \ 8 | --train_list ../utils/train_list.txt \ 9 | --val_list ../utils/test_list.txt \ 10 | --train_path /data08/VoxCeleb2/wav \ 11 | --val_path /data08/VoxCeleb1/wav \ 12 | --musan_path /data08/Others/musan_split \ 13 | --test_interval 1 \ -------------------------------------------------------------------------------- /Stage1/tools.py: -------------------------------------------------------------------------------- 1 | import torch, os, math, numpy 2 | from sklearn.metrics import accuracy_score 3 | from sklearn import metrics 4 | from operator import itemgetter 5 | 6 | def tuneThresholdfromScore(scores, labels, target_fa, target_fr = None): 7 | 8 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=1) 9 | fnr = 1 - tpr 10 | tunedThreshold = []; 11 | if target_fr: 12 | for tfr in target_fr: 13 | idx = numpy.nanargmin(numpy.absolute((tfr - fnr))) 14 | tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]]) 15 | for tfa in target_fa: 16 | idx = numpy.nanargmin(numpy.absolute((tfa - fpr))) # numpy.where(fpr<=tfa)[0][-1] 17 | tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]]) 18 | idxE = numpy.nanargmin(numpy.absolute((fnr - fpr))) 19 | eer = max(fpr[idxE],fnr[idxE])*100 20 | 21 | return tunedThreshold, eer, fpr, fnr 22 | 23 | # Creates a list of false-negative rates, a list of false-positive rates 24 | # and a list of decision thresholds that give those error-rates. 25 | def ComputeErrorRates(scores, labels): 26 | 27 | # Sort the scores from smallest to largest, and also get the corresponding 28 | # indexes of the sorted scores. We will treat the sorted scores as the 29 | # thresholds at which the the error-rates are evaluated. 30 | sorted_indexes, thresholds = zip(*sorted( 31 | [(index, threshold) for index, threshold in enumerate(scores)], 32 | key=itemgetter(1))) 33 | sorted_labels = [] 34 | labels = [labels[i] for i in sorted_indexes] 35 | fnrs = [] 36 | fprs = [] 37 | 38 | # At the end of this loop, fnrs[i] is the number of errors made by 39 | # incorrectly rejecting scores less than thresholds[i]. And, fprs[i] 40 | # is the total number of times that we have correctly accepted scores 41 | # greater than thresholds[i]. 
42 | for i in range(0, len(labels)): 43 | if i == 0: 44 | fnrs.append(labels[i]) 45 | fprs.append(1 - labels[i]) 46 | else: 47 | fnrs.append(fnrs[i-1] + labels[i]) 48 | fprs.append(fprs[i-1] + 1 - labels[i]) 49 | fnrs_norm = sum(labels) 50 | fprs_norm = len(labels) - fnrs_norm 51 | 52 | # Now divide by the total number of false negative errors to 53 | # obtain the false positive rates across all thresholds 54 | fnrs = [x / float(fnrs_norm) for x in fnrs] 55 | 56 | # Divide by the total number of corret positives to get the 57 | # true positive rate. Subtract these quantities from 1 to 58 | # get the false positive rates. 59 | fprs = [1 - x / float(fprs_norm) for x in fprs] 60 | return fnrs, fprs, thresholds 61 | 62 | # Computes the minimum of the detection cost function. The comments refer to 63 | # equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan. 64 | def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa): 65 | min_c_det = float("inf") 66 | min_c_det_threshold = thresholds[0] 67 | for i in range(0, len(fnrs)): 68 | # See Equation (2). it is a weighted sum of false negative 69 | # and false positive errors. 70 | c_det = c_miss * fnrs[i] * p_target + c_fa * fprs[i] * (1 - p_target) 71 | if c_det < min_c_det: 72 | min_c_det = c_det 73 | min_c_det_threshold = thresholds[i] 74 | # See Equations (3) and (4). Now we normalize the cost. 75 | c_def = min(c_miss * p_target, c_fa * (1 - p_target)) 76 | min_dcf = min_c_det / c_def 77 | return min_dcf, min_c_det_threshold 78 | 79 | # Compute the training acc based on outputs and labels 80 | def accuracy(output, target, topk=(1,)): 81 | maxk = max(topk) 82 | batch_size = target.size(0) 83 | 84 | _, pred = output.topk(maxk, 1, True, True) 85 | 86 | pred = pred.t() 87 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 88 | 89 | res = [] 90 | for k in topk: 91 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 92 | res.append(correct_k.mul_(100.0 / batch_size)) 93 | return res -------------------------------------------------------------------------------- /Stage2/README.md: -------------------------------------------------------------------------------- 1 | In Stage 2, the baseline and our method shared the same code, I add the `LGL` in the args to simplify the code. 2 | 3 | As I mentioned, to make the code clear, this code and the code we used in our paper is a bit different: 4 | 5 | * In our paper, we extend the channel size of speaker encoder to `1024` in the iteration 5, in this code we remove this setting to simply the code. You can do that in the last iteration to get the better result. 6 | 7 | * In our paper, we manually determinate the end of each iteration, that is not user-friendly. In this code, we end the iteration if EER can not improve in continuous `N = 4` epochs. You can increase it to improve the performance. 8 | 9 | I do not have time&resource to run enough epochs for the Stage 2 (I only run 50 epochs get about `EER=2.2` already), so if you get the final performance, I appreciate if you can share your score.txt file with me. Thanks! 
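For reference, the core of LGL in `Stage2/loss.py` is a per-sample gate on the AAM-softmax cross-entropy: only samples whose loss falls below the gate contribute to the update, since a small loss indicates a reliable pseudo label, and setting the gate to a very large value keeps every sample, which recovers the baseline (this is what `run_baseline.sh` does through `main_train.py`). Below is a minimal sketch of the selection rule, assuming a per-sample loss tensor; the function name and the empty-selection guard are illustrative additions, not code from this repository:

```
import torch

def loss_gated_mean(per_sample_loss, gate):
    # per_sample_loss: cross-entropy with reduction='none', shape (batch,)
    # gate: loss threshold; a very large gate selects every sample (baseline)
    mask = per_sample_loss <= gate      # keep only samples with reliable pseudo labels
    n_select = mask.sum()               # number of selected samples in this batch
    if n_select == 0:                   # guard against an empty selection
        return per_sample_loss.sum() * 0.0, 0
    loss = (per_sample_loss * mask).sum() / n_select
    return loss, int(n_select)

# e.g. loss_gated_mean(torch.tensor([0.5, 2.0, 9.0]), gate=3.0) selects 2 of the 3 samples
```

In the actual loss, this gated mean is applied on top of the AAM-softmax margin, and the training accuracy is likewise computed on the selected samples only.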
10 | -------------------------------------------------------------------------------- /Stage2/dataLoader.py: -------------------------------------------------------------------------------- 1 | import glob, numpy, os, random, soundfile, torch, wave 2 | from scipy import signal 3 | from tools import * 4 | 5 | def get_Loader(args, dic_label = None, cluster_only = False): 6 | # Get the loader for the cluster, batch_size is set as 1 to handlle the variable length input. Details see 1.2 part from here: https://github.com/TaoRuijie/TalkNet-ASD/blob/main/FAQ.md 7 | clusterLoader = cluster_loader(**vars(args)) 8 | clusterLoader = torch.utils.data.DataLoader(clusterLoader, batch_size = 1, shuffle = True, num_workers = args.n_cpu, drop_last = False) 9 | 10 | if cluster_only == True: # Only do clustering 11 | return clusterLoader 12 | # Get the loader for training 13 | trainLoader = train_loader(dic_label = dic_label, **vars(args)) 14 | trainLoader = torch.utils.data.DataLoader(trainLoader, batch_size = args.batch_size, shuffle = True, num_workers = args.n_cpu, drop_last = True) 15 | 16 | return trainLoader, clusterLoader 17 | 18 | class train_loader(object): 19 | def __init__(self, train_list, train_path, musan_path, rir_path, max_frames, dic_label, **kwargs): 20 | self.train_path = train_path 21 | self.max_frames = max_frames * 160 + 240 # Length of segment for training 22 | self.dic_label = dic_label # Pseudo labels dict 23 | self.noisetypes = ['noise','speech','music'] 24 | self.noisesnr = {'noise':[0,15],'speech':[13,20],'music':[5,15]} 25 | self.numnoise = {'noise':[1,1], 'speech':[3,8], 'music':[1,1]} 26 | self.noiselist = {} 27 | augment_files = glob.glob(os.path.join(musan_path,'*/*/*/*.wav')) 28 | for file in augment_files: 29 | if file.split('/')[-4] not in self.noiselist: 30 | self.noiselist[file.split('/')[-4]] = [] 31 | self.noiselist[file.split('/')[-4]].append(file) 32 | self.rir_files = glob.glob(os.path.join(rir_path,'*/*/*.wav')) 33 | self.data_list = [] 34 | lines = open(train_list).read().splitlines() 35 | for index, line in enumerate(lines): 36 | file_name = line.split()[1] 37 | self.data_list.append(file_name) 38 | 39 | def __getitem__(self, index): 40 | file = self.data_list[index] # Get the filename 41 | label = self.dic_label[file] # Load the pseudo label 42 | segments = self.load_wav(file = file) # Load the augmented segment 43 | segments = torch.FloatTensor(numpy.array(segments)) 44 | return segments, label 45 | 46 | def load_wav(self, file): 47 | utterance, _ = soundfile.read(os.path.join(self.train_path, file)) # Read the wav file 48 | if utterance.shape[0] <= self.max_frames: # Padding if less than required length 49 | shortage = self.max_frames - utterance.shape[0] 50 | utterance = numpy.pad(utterance, (0, shortage), 'wrap') 51 | startframe = random.choice(range(0, utterance.shape[0] - (self.max_frames))) # Choose the startframe randomly 52 | segment = numpy.expand_dims(numpy.array(utterance[int(startframe):int(startframe)+self.max_frames]), axis = 0) 53 | 54 | if random.random() <= 0.5: 55 | segment = self.add_rev(segment, length = self.max_frames) # Rever 56 | if random.random() <= 0.5: 57 | segment = self.add_noise(segment, random.choice(['music', 'speech', 'noise']), length = self.max_frames) # Noise 58 | 59 | return segment[0] 60 | 61 | def __len__(self): 62 | return len(self.data_list) 63 | 64 | def add_rev(self, audio, length): 65 | rir_file = random.choice(self.rir_files) 66 | rir, sr = soundfile.read(rir_file) 67 | rir = 
numpy.expand_dims(rir.astype(numpy.float),0) 68 | rir = rir / numpy.sqrt(numpy.sum(rir**2)) 69 | return signal.convolve(audio, rir, mode='full')[:,:length] 70 | 71 | def add_noise(self, audio, noisecat, length): 72 | clean_db = 10 * numpy.log10(numpy.mean(audio ** 2)+1e-4) 73 | numnoise = self.numnoise[noisecat] 74 | noiselist = random.sample(self.noiselist[noisecat], random.randint(numnoise[0],numnoise[1])) 75 | noises = [] 76 | for noise in noiselist: 77 | noiselength = wave.open(noise, 'rb').getnframes() # Read the length of the noise file 78 | if noiselength <= length: 79 | noiseaudio, _ = soundfile.read(noise) 80 | noiseaudio = numpy.pad(noiseaudio, (0, length - noiselength), 'wrap') 81 | else: 82 | start_frame = numpy.int64(random.random()*(noiselength-length)) # If length is enough 83 | noiseaudio, _ = soundfile.read(noise, start = start_frame, stop = start_frame + length) # Only read some part to improve speed 84 | noiseaudio = numpy.stack([noiseaudio],axis=0) 85 | noise_db = 10 * numpy.log10(numpy.mean(noiseaudio ** 2)+1e-4) 86 | noisesnr = random.uniform(self.noisesnr[noisecat][0],self.noisesnr[noisecat][1]) 87 | noises.append(numpy.sqrt(10 ** ((clean_db - noise_db - noisesnr) / 10)) * noiseaudio) 88 | noise = numpy.sum(numpy.concatenate(noises,axis=0),axis=0,keepdims=True) 89 | return noise + audio 90 | 91 | class cluster_loader(object): 92 | def __init__(self, train_list, train_path, **kwargs): 93 | self.data_list, self.data_length, self.data_label = [], [], [] 94 | self.train_path = train_path 95 | lines = open(train_list).read().splitlines() 96 | # Get the ground-truth labels, that is used to compute the NMI for post-analyze. 97 | dictkeys = list(set([x.split()[0] for x in lines])) 98 | dictkeys.sort() 99 | dictkeys = { key : ii for ii, key in enumerate(dictkeys) } 100 | 101 | for lidx, line in enumerate(lines): 102 | data = line.split() 103 | file_name = data[1] 104 | file_length = float(data[-1]) 105 | speaker_label = dictkeys[data[0]] 106 | self.data_list.append(file_name) # Filename 107 | self.data_length.append(file_length) # Filelength 108 | self.data_label.append(speaker_label) # GT Speaker label 109 | 110 | # sort the training set by the length of the audios, audio with similar length are saved togethor. 111 | inds = numpy.array(self.data_length).argsort() 112 | self.data_list, self.data_length, self.data_label = numpy.array(self.data_list)[inds], \ 113 | numpy.array(self.data_length)[inds], \ 114 | numpy.array(self.data_label)[inds] 115 | self.minibatch = [] 116 | start = 0 117 | while True: # Genearte each minibatch, audio with similar length are saved togethor. 
118 | frame_length = self.data_length[start] 119 | minibatch_size = max(1, int(1600 // frame_length)) 120 | end = min(len(self.data_list), start + minibatch_size) 121 | self.minibatch.append([self.data_list[start:end], frame_length, self.data_label[start:end]]) 122 | if end == len(self.data_list): 123 | break 124 | start = end 125 | 126 | def __getitem__(self, index): 127 | data_lists, frame_length, data_labels = self.minibatch[index] # Get one minibatch 128 | filenames, labels, segments = [], [], [] 129 | for num in range(len(data_lists)): 130 | filename = data_lists[num] # Read filename 131 | label = data_labels[num] # Read GT label 132 | audio, sr = soundfile.read(os.path.join(self.train_path, filename)) 133 | if len(audio) < int(frame_length * sr): 134 | shortage = int(frame_length * sr) - len(audio) + 1 135 | audio = numpy.pad(audio, (0, shortage), 'wrap') 136 | audio = numpy.array(audio[:int(frame_length * sr)]) # Get clean utterance, better for clustering 137 | segments.append(audio) 138 | filenames.append(filename) 139 | labels.append(label) 140 | segments = torch.FloatTensor(numpy.array(segments)) 141 | return segments, filenames, labels 142 | 143 | def __len__(self): 144 | return len(self.minibatch) -------------------------------------------------------------------------------- /Stage2/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchaudio 6 | 7 | class SEModule(nn.Module): 8 | def __init__(self, channels, bottleneck=128): 9 | super(SEModule, self).__init__() 10 | self.se = nn.Sequential( 11 | nn.AdaptiveAvgPool1d(1), 12 | nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0), 13 | nn.ReLU(), 14 | nn.BatchNorm1d(bottleneck), 15 | nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0), 16 | nn.Sigmoid(), 17 | ) 18 | 19 | def forward(self, input): 20 | x = self.se(input) 21 | return input * x 22 | 23 | class Bottle2neck(nn.Module): 24 | def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale = 8): 25 | super(Bottle2neck, self).__init__() 26 | 27 | width = int(math.floor(planes / scale)) 28 | self.conv1 = nn.Conv1d(inplanes, width*scale, kernel_size=1) 29 | self.bn1 = nn.BatchNorm1d(width*scale) 30 | self.nums = scale -1 31 | convs = [] 32 | bns = [] 33 | num_pad = math.floor(kernel_size/2)*dilation 34 | for i in range(self.nums): 35 | convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad)) 36 | bns.append(nn.BatchNorm1d(width)) 37 | self.convs = nn.ModuleList(convs) 38 | self.bns = nn.ModuleList(bns) 39 | self.conv3 = nn.Conv1d(width*scale, planes, kernel_size=1) 40 | self.bn3 = nn.BatchNorm1d(planes) 41 | self.relu = nn.ReLU() 42 | self.width = width 43 | self.se = SEModule(planes) 44 | 45 | def forward(self, x): 46 | residual = x 47 | out = self.conv1(x) 48 | out = self.relu(out) 49 | out = self.bn1(out) 50 | spx = torch.split(out, self.width, 1) 51 | for i in range(self.nums): 52 | if i==0: 53 | sp = spx[i] 54 | else: 55 | sp = sp + spx[i] 56 | sp = self.convs[i](sp) 57 | sp = self.relu(sp) 58 | sp = self.bns[i](sp) 59 | if i==0: 60 | out = sp 61 | else: 62 | out = torch.cat((out, sp), 1) 63 | out = torch.cat((out, spx[self.nums]),1) 64 | out = self.conv3(out) 65 | out = self.relu(out) 66 | out = self.bn3(out) 67 | out = self.se(out) 68 | out += residual 69 | 70 | return out 71 | 72 | class PreEmphasis(torch.nn.Module): 73 | 74 | def __init__(self, coef: 
float = 0.97): 75 | super().__init__() 76 | self.coef = coef 77 | self.register_buffer( 78 | 'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0) 79 | ) 80 | 81 | def forward(self, input: torch.tensor) -> torch.tensor: 82 | input = input.unsqueeze(1) 83 | input = F.pad(input, (1, 0), 'reflect') 84 | return F.conv1d(input, self.flipped_filter).squeeze(1) 85 | 86 | class ECAPA_TDNN(nn.Module): # Spec aug can also be added to improve the performance. 87 | def __init__(self, C = 512, **kwargs): 88 | super(ECAPA_TDNN, self).__init__() 89 | self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2) 90 | self.relu = nn.ReLU() 91 | self.bn1 = nn.BatchNorm1d(C) 92 | self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8) 93 | self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8) 94 | self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8) 95 | self.layer4 = nn.Conv1d(3*C, 1536, kernel_size=1) 96 | self.torchfbank = torch.nn.Sequential( 97 | PreEmphasis(), 98 | torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400, hop_length=160, f_min = 20, f_max = 7600, window_fn=torch.hamming_window, n_mels=80), 99 | ) 100 | self.attention = nn.Sequential( 101 | nn.Conv1d(4608, 256, kernel_size=1), 102 | nn.ReLU(), 103 | nn.BatchNorm1d(256), 104 | nn.Tanh(), 105 | nn.Conv1d(256, 1536, kernel_size=1), 106 | nn.Softmax(dim=2), 107 | ) 108 | self.bn5 = nn.BatchNorm1d(3072) 109 | self.fc6 = nn.Linear(3072, 192) 110 | self.bn6 = nn.BatchNorm1d(192) 111 | 112 | def forward(self, x): 113 | with torch.no_grad(): 114 | x = self.torchfbank(x)+1e-6 115 | x = x.log() 116 | x = x - torch.mean(x, dim=-1, keepdim=True) 117 | x = self.conv1(x) 118 | x = self.relu(x) 119 | x = self.bn1(x) 120 | x1 = self.layer1(x) 121 | x2 = self.layer2(x+x1) 122 | x3 = self.layer3(x+x1+x2) 123 | x = self.layer4(torch.cat((x1,x2,x3),dim=1)) 124 | x = self.relu(x) 125 | t = x.size()[-1] 126 | global_x = torch.cat((x,torch.mean(x,dim=2,keepdim=True).repeat(1,1,t), torch.sqrt(torch.var(x,dim=2,keepdim=True).clamp(min=1e-4)).repeat(1,1,t)), dim=1) 127 | w = self.attention(global_x) 128 | mu = torch.sum(x * w, dim=2) 129 | sg = torch.sqrt( ( torch.sum((x**2) * w, dim=2) - mu**2 ).clamp(min=1e-4) ) 130 | x = torch.cat((mu,sg),1) 131 | x = self.bn5(x) 132 | x = self.fc6(x) 133 | x = self.bn6(x) 134 | return x -------------------------------------------------------------------------------- /Stage2/loss.py: -------------------------------------------------------------------------------- 1 | import torch, math, numpy 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from tools import * 5 | 6 | # AAM softmax loss same to that in voxceleb_trainer 7 | class LossFunction(nn.Module): 8 | def __init__(self, n_class, n_out = 192, m = 0.2, s = 30): 9 | 10 | super(LossFunction, self).__init__() 11 | self.m = m 12 | self.s = s 13 | self.weight = torch.nn.Parameter(torch.FloatTensor(n_class, n_out), requires_grad=True) 14 | self.ce = nn.CrossEntropyLoss(reduction = 'none') 15 | nn.init.xavier_normal_(self.weight, gain=1) 16 | self.cos_m = math.cos(self.m) 17 | self.sin_m = math.sin(self.m) 18 | self.th = math.cos(math.pi - self.m) 19 | self.mm = math.sin(math.pi - self.m) * self.m 20 | 21 | def forward(self, x, label, gate): 22 | 23 | cosine = F.linear(F.normalize(x), F.normalize(self.weight)) 24 | sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1)) 25 | phi = cosine * self.cos_m - sine * self.sin_m 26 | phi = torch.where((cosine - 
self.th) > 0, phi, cosine - self.mm) 27 | one_hot = torch.zeros_like(cosine) 28 | one_hot.scatter_(1, label.view(-1, 1), 1) 29 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 30 | output = output * self.s 31 | 32 | ce = self.ce(output, label) 33 | # LGL 34 | mask = ce <= gate # Find the sample that loss smaller that gate 35 | nselect = sum(mask).detach() # Count the num 36 | loss = torch.sum(ce * mask, dim = -1) / nselect # Compute the loss for the selected data only 37 | prec1 = accuracy(output.detach(), label * mask.detach(), topk=(1,))[0] * x.size()[0] # Compute the training acc for these selected data only 38 | return loss, prec1, nselect -------------------------------------------------------------------------------- /Stage2/main_train.py: -------------------------------------------------------------------------------- 1 | import os, argparse, pickle, glob 2 | from model import * 3 | from dataLoader import * 4 | from tools import * 5 | 6 | parser = argparse.ArgumentParser(description = "Loss Gated Learning") 7 | parser.add_argument('--n_cpu', type=int, default=8) 8 | parser.add_argument('--max_frames', type=int, default=300) 9 | parser.add_argument('--batch_size', type=int, default=512) 10 | parser.add_argument('--init_model', type=str, default="") 11 | parser.add_argument('--save_path', type=str, default="") 12 | parser.add_argument('--train_list', type=str, default="",help='Path for Vox2 list, https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/train_list.txt') 13 | parser.add_argument('--val_list', type=str, default="", help='Path for Vox_O list, https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt') 14 | parser.add_argument('--train_path', type=str, default="", help='Path to the Vox2 set') 15 | parser.add_argument('--val_path', type=str, default="", help='Path to the Vox_O set') 16 | parser.add_argument('--musan_path', type=str, default="", help='Path to the musan set') 17 | parser.add_argument('--rir_path', type=str, default="", help='Path to the rir set') 18 | parser.add_argument('--lr', type=float, default=0.001) 19 | parser.add_argument('--n_cluster', type=int, default=6000, help='Number of clusters') 20 | parser.add_argument('--test_interval',type=int, default=1) 21 | parser.add_argument('--max_epoch', type=int, default=100) 22 | parser.add_argument('--LGL', dest='LGL', action='store_true', help='Use LGL or baseline only') 23 | args = parser.parse_args() 24 | 25 | torch.multiprocessing.set_sharing_strategy('file_system') 26 | warnings.filterwarnings("ignore") 27 | inf_max = 10**3 28 | if args.LGL: 29 | gates = [1, 3, 3, 5, 6] # Set the gates in each iterations, which is different from our paper because we use stronger augmentation in dataloader 30 | else: 31 | gates = [inf_max, inf_max, inf_max, inf_max, inf_max] # Set the gate as a very large value = No gate (baseline) 32 | 33 | args.model_folder = os.path.join(args.save_path, 'model') # Path for the saved models 34 | args.dic_folder = os.path.join(args.save_path, 'dic') # Path for the saved pseudo label dic 35 | args.score_path = os.path.join(args.save_path, 'score.txt') # Path for the score file 36 | os.makedirs(args.model_folder, exist_ok = True) 37 | os.makedirs(args.dic_folder, exist_ok = True) 38 | score_file = open(args.score_path, "a+") 39 | 40 | stage, best_epoch, next_epoch, iteration = check_clustering(args.score_path, args.LGL) # Check the state of this epoch 41 | print(stage, best_epoch, next_epoch, iteration) 42 | 43 | Trainer = trainer(**vars(args)) # Define the framework 44 | modelfiles 
= glob.glob('%s/model0*.model'%args.model_folder) # Collect all saved models 45 | modelfiles.sort() 46 | 47 | if len(modelfiles) >= 1: # Load the previous model 48 | Trainer.load_parameters(modelfiles[-1]) 49 | args.epoch = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][6:]) + 1 50 | else: 51 | args.epoch = 1 # Start from the first epoch 52 | for items in vars(args): # Save the parameters in args 53 | score_file.write('%s %s\n'%(items, vars(args)[items])); 54 | score_file.flush() 55 | 56 | if args.epoch == 1: # Do clustering in the first epoch 57 | Trainer.load_parameters(args.init_model) # Load the init_model 58 | clusterLoader = get_Loader(args, cluster_only = True) # Data loader 59 | dic_label, NMI = Trainer.cluster_network(loader = clusterLoader, n_cluster = args.n_cluster) # Do clustering 60 | pickle.dump(dic_label, open(args.dic_folder + "/label%04d.pkl"%args.epoch, "wb")) # Save the pseudo labels 61 | print_write(type = 'C', text = [args.epoch, NMI], score_file = score_file) 62 | 63 | labelfiles = glob.glob('%s/label0*.pkl'%args.dic_folder) # Read the latest pseudo labels 64 | labelfiles.sort() 65 | dic_label = pickle.load(open(labelfiles[-1], "rb")) 66 | print("Dic %s loaded!"%labelfiles[-1]) 67 | trainLoader, clusterLoader = get_Loader(args, dic_label) # Data loaders with the pseudo labels 68 | 69 | while args.epoch <= args.max_epoch: 70 | stage, best_epoch, next_epoch, iteration = check_clustering(args.score_path, args.LGL) # Check the state of this epoch 71 | 72 | if stage == 'T': # Classification training 73 | loss, acc, nselects = Trainer.train_network(epoch = args.epoch, loader = trainLoader, gate = inf_max) 74 | print_write(type = 'T', text = [args.epoch, loss, acc, nselects], score_file = score_file) 75 | 76 | elif stage == 'L': # LGL training 77 | if best_epoch != None: # LGL starts from the best model of the 'T' stage 78 | Trainer.load_parameters('%s/model0%03d.model'%(args.model_folder, best_epoch)) # Load the best model 79 | loss, acc, nselects = Trainer.train_network(epoch = args.epoch, loader = trainLoader, gate = gates[iteration - 1]) # Train with the loss gate for this iteration (see the illustrative sketch below) 80 | print_write(type = 'L', text = [args.epoch, loss, acc, nselects, gates[iteration - 1]], score_file = score_file) 81 | 82 | elif stage == 'C': # Clustering 83 | iteration += 1 84 | if iteration > 5: # Maximum number of iterations is 5 85 | quit() 86 | Trainer.load_parameters('%s/model0%03d.model'%(args.model_folder, best_epoch)) # Load the best model 87 | clusterLoader = get_Loader(args, cluster_only = True) # Cluster loader 88 | dic_label, NMI = Trainer.cluster_network(loader = clusterLoader, n_cluster = args.n_cluster) # Clustering 89 | args.epoch = next_epoch 90 | print_write(type = 'C', text = [args.epoch, NMI], score_file = score_file) 91 | pickle.dump(dic_label, open(args.dic_folder + "/label%04d.pkl"%args.epoch, "wb")) # Save the pseudo label dic 92 | print("Dic %s saved!"%(args.dic_folder + "/label%04d.pkl"%args.epoch)) 93 | Trainer = trainer(**vars(args)) # Define the framework 94 | Trainer.load_parameters(args.init_model) # Load the init_model 95 | trainLoader, clusterLoader = get_Loader(args, dic_label) # Get new data loaders with the new label dic 96 | 97 | if args.epoch % args.test_interval == 0 and stage != 'C': # Evaluation 98 | Trainer.save_parameters(args.model_folder + "/model%04d.model"%args.epoch) # Save the model 99 | EER, minDCF = Trainer.eval_network(**vars(args)) 100 | print_write(type = 'E', text = [args.epoch, EER, minDCF], score_file = score_file) 101 | 102 | args.epoch += 1 103 |
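The loss gate passed to `Trainer.train_network` above is applied inside `Stage2/loss.py` (shown earlier): only samples whose classification loss falls below the gate contribute to the batch loss. The following minimal sketch is illustrative only and not part of the repository; the tensor values are made up to show the selection rule.

```python
import torch

# Toy per-sample cross-entropy losses for a batch of 5 utterances (made-up values)
ce = torch.tensor([0.4, 1.2, 2.9, 5.7, 9.3])
gate = 3.0                             # loss gate for the current iteration

mask = ce <= gate                      # keep only the reliable (low-loss) pseudo-labelled samples
nselect = mask.sum()                   # number of selected samples -> tensor(3)
loss = torch.sum(ce * mask) / nselect  # average loss over the selected samples only
print(nselect.item(), loss.item())     # 3 and approximately 1.5 = (0.4 + 1.2 + 2.9) / 3
```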
-------------------------------------------------------------------------------- /Stage2/model.py: -------------------------------------------------------------------------------- 1 | import torch, sys, os, tqdm, numpy, time, faiss, gc, soundfile 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from tools import * 6 | from loss import * 7 | from encoder import * 8 | from collections import defaultdict 9 | from sklearn.metrics.cluster import normalized_mutual_info_score 10 | 11 | class trainer(nn.Module): 12 | def __init__(self, lr, n_cluster, **kwargs): 13 | super(trainer, self).__init__() 14 | self.Network = ECAPA_TDNN(C = 512).cuda() # Speaker encoder 15 | self.Loss = LossFunction(n_class = n_cluster).cuda() # Classification layer 16 | self.Optim = torch.optim.Adam(list(self.Network.parameters()) + list(self.Loss.parameters()), lr = lr) # Adam, learning rate is fixed 17 | 18 | def train_network(self, epoch, loader, gate): 19 | self.train() 20 | loss, index, nselects, top1 = 0, 0, 0, 0 21 | time_start = time.time() 22 | for num, (data, label) in enumerate(loader, start = 1): 23 | self.zero_grad() 24 | out = self.Network.forward(data.cuda()) # Input the audio segments, output the speaker embeddings 25 | nloss, prec1, nselect = self.Loss.forward(out, label.cuda(), gate) # Get the loss, training acc and the number of selected samples 26 | nloss.backward() 27 | self.Optim.step() 28 | loss += nloss.detach().cpu().numpy() 29 | index += len(label) 30 | nselects += nselect 31 | top1 += prec1 32 | time_used = time.time() - time_start 33 | sys.stderr.write(" [%2d] %.2f%% (est %.1f mins), Loss: %.3f, ACC: %.2f%%, select: %.2f%%, gate: %.1f\r" %\ 34 | (epoch, 100 * (num / loader.__len__()), time_used * loader.__len__() / num / 60, \ 35 | loss/num, top1/nselects, nselects/index*100, gate)) 36 | sys.stderr.flush() 37 | sys.stdout.write("\n") 38 | torch.cuda.empty_cache() 39 | return loss / num, top1/nselects, nselects/index*100 40 | 41 | def cluster_network(self, loader, n_cluster): 42 | self.eval() 43 | out_all, filenames_all, labels_all = [], [], [] 44 | for data, filenames, labels in tqdm.tqdm(loader): 45 | with torch.no_grad(): 46 | out = self.Network.forward(data[0].cuda()) # Get the embeddings 47 | out = F.normalize(out, p=2, dim=1) # Normalization 48 | for i in range(len(filenames)): # Save the filename, label, and embedding into the lists [the labels are only used to compute NMI] 49 | filenames_all.append(filenames[i][0]) 50 | labels_all.append(labels[i].cpu().numpy()[0]) 51 | out_all.append(out[i].detach().cpu().numpy()) 52 | out_all = numpy.array(out_all) 53 | # Clustering using faiss https://github.com/facebookresearch/deepcluster 54 | clus = faiss.Clustering(out_all.shape[1], n_cluster) 55 | n, d = out_all.shape 56 | flat_config = faiss.GpuIndexFlatConfig() 57 | flat_config.useFloat16 = False 58 | flat_config.device = 0 59 | index = faiss.GpuIndexFlatIP(faiss.StandardGpuResources(), d, flat_config) 60 | clus.train(out_all, index) # Clustering 61 | preds = [int(i[0]) for i in index.search(out_all, 1)[1]] # Get the cluster assignments 62 | del out_all 63 | gc.collect() 64 | dic_label = defaultdict(list) # Pseudo label dict 65 | 66 | for i in range(len(preds)): 67 | pred_label = preds[i] # pseudo label 68 | filename = filenames_all[i] # its filename 69 | dic_label[filename] = pred_label # save into the dic 70 | NMI = normalized_mutual_info_score(labels_all, preds) * 100 # Compute the NMI.
71 | torch.cuda.empty_cache() 72 | return dic_label, NMI 73 | 74 | def eval_network(self, val_list, val_path, **kwargs): 75 | self.eval() 76 | files, feats = [], {} 77 | for line in open(val_list).read().splitlines(): 78 | data = line.split() 79 | files.append(data[1]) 80 | files.append(data[2]) 81 | setfiles = list(set(files)) 82 | setfiles.sort() # Read the list of wav files 83 | for idx, file in tqdm.tqdm(enumerate(setfiles), total = len(setfiles)): 84 | audio, _ = soundfile.read(os.path.join(val_path, file)) 85 | feat = torch.FloatTensor(numpy.stack([audio], axis=0)).cuda() 86 | with torch.no_grad(): 87 | ref_feat = self.Network.forward(feat).detach().cpu() 88 | feats[file] = ref_feat # Extract the embedding for each file to build the feature dict 89 | scores, labels = [], [] 90 | for line in open(val_list).read().splitlines(): 91 | data = line.split() 92 | ref_feat = F.normalize(feats[data[1]].cuda(), p=2, dim=1) # feature 1 93 | com_feat = F.normalize(feats[data[2]].cuda(), p=2, dim=1) # feature 2 94 | score = numpy.mean(torch.matmul(ref_feat, com_feat.T).detach().cpu().numpy()) # Get the score 95 | scores.append(score) 96 | labels.append(int(data[0])) 97 | EER = tuneThresholdfromScore(scores, labels, [1, 0.1])[1] 98 | fnrs, fprs, thresholds = ComputeErrorRates(scores, labels) 99 | minDCF, _ = ComputeMinDcf(fnrs, fprs, thresholds, 0.05, 1, 1) 100 | return [EER, minDCF] 101 | 102 | def save_parameters(self, path): 103 | torch.save(self.state_dict(), path) 104 | 105 | def load_parameters(self, path): 106 | self_state = self.state_dict() 107 | loaded_state = torch.load(path) 108 | print("Model %s loaded!"%(path)) 109 | for name, param in loaded_state.items(): 110 | origname = name 111 | if name not in self_state: 112 | name = name.replace("module.", "") 113 | if name not in self_state: 114 | # print("%s is not in the model."%origname) 115 | continue 116 | if self_state[name].size() != loaded_state[origname].size(): 117 | print("Wrong parameter size: %s, model: %s, loaded: %s"%(origname, self_state[name].size(), loaded_state[origname].size())) 118 | continue 119 | self_state[name].copy_(param) -------------------------------------------------------------------------------- /Stage2/run_LGL.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python main_train.py \ 4 | --save_path exp/LGL \ 5 | --batch_size 400 \ 6 | --lr 0.001 \ 7 | --train_list ../utils/train_list.txt \ 8 | --val_list ../utils/test_list.txt \ 9 | --train_path /data08/VoxCeleb2/wav \ 10 | --val_path /data08/VoxCeleb1/wav \ 11 | --musan_path /data08/Others/musan_split \ 12 | --rir_path /data08/Others/RIRS_NOISES/simulated_rirs \ 13 | --init_model /home/ruijie/workspace/sslsr/Stage1/exp/exp1/model/model000000043.model \ 14 | --test_interval 1 \ 15 | --n_cluster 6000 \ 16 | --LGL -------------------------------------------------------------------------------- /Stage2/run_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python main_train.py \ 4 | --save_path exp/baseline \ 5 | --batch_size 400 \ 6 | --lr 0.001 \ 7 | --train_list ../utils/train_list.txt \ 8 | --val_list ../utils/test_list.txt \ 9 | --train_path /data08/VoxCeleb2/wav \ 10 | --val_path /data08/VoxCeleb1/wav \ 11 | --musan_path /data08/Others/musan_split \ 12 | --rir_path /data08/Others/RIRS_NOISES/simulated_rirs \ 13 | --init_model /home/ruijie/workspace/Loss-Gated-Learning/Stage1/exp/exp1/model/model000000043.model \ 14 | --test_interval 1 \ 15
| --n_cluster 6000 16 | -------------------------------------------------------------------------------- /Stage2/tools.py: -------------------------------------------------------------------------------- 1 | import warnings, torch, os, math, numpy 2 | from sklearn.metrics import accuracy_score 3 | from sklearn import metrics 4 | from operator import itemgetter 5 | 6 | def print_write(type, text, score_file): # A helper function to print the text and write the log 7 | if type == 'T': # Classification training without LGL (Baseline) 8 | epoch, loss, acc, nselects = text 9 | print("%d epoch, LOSS %f, ACC %.2f%%, nselects %.2f%%\n"%(epoch, loss, acc, nselects)) 10 | score_file.write("[T], %d epoch, LOSS %f, ACC %.2f%%, nselects %.2f%%\n"%(epoch, loss, acc, nselects)) 11 | elif type == 'L': # Classification training with LGL (Proposed) 12 | epoch, loss, acc, nselects, gate = text 13 | print("%d epoch, LOSS %f, ACC %.2f%%, nselects %.2f%%, Gate %.1f \n"%(epoch, loss, acc, nselects, gate)) 14 | score_file.write("[L], %d epoch, LOSS %f, ACC %.2f%%, nselects %.2f%%, Gate %.1f \n"%(epoch, loss, acc, nselects, gate)) 15 | elif type == 'C': # Clustering step 16 | epoch, NMI = text 17 | print("%d epoch, NMI %.2f\n"%(epoch, NMI)) 18 | score_file.write("[C], %d epoch, NMI %.2f\n"%(epoch, NMI)) 19 | elif type == 'E': # Evaluation step 20 | epoch, EER, minDCF = text 21 | print("EER %2.2f%%, minDCF %2.3f%%\n"%(EER, minDCF)) 22 | score_file.write("[E], %d epoch, EER %2.2f%%, minDCF %2.3f%%\n"%(epoch, EER, minDCF)) 23 | score_file.flush() 24 | 25 | def check_clustering(score_path, LGL): # Read the score.txt file and decide the next stage 26 | lines = open(score_path).read().splitlines() 27 | 28 | if LGL == True: # For LGL, the order is 29 | # Iteration 1: (C-T-T...-T-L-L...-L-) 30 | # Iteration 2: (C-T-T...-T-L-L...-L-) 31 | # ... 32 | EERs_T, epochs_T, EERs_L, epochs_L = [], [], [], [] 33 | iteration = 0 34 | train_type = 'T' 35 | for line in lines: 36 | if line.split(',')[0] == '[C]': # Clear all results after clustering 37 | EERs_T, EERs_L, epochs_T, epochs_L = [], [], [], [] 38 | train_type = 'T' 39 | iteration += 1 40 | elif line.split(',')[0] == '[E]': # Save the evaluation result in this iteration 41 | epoch = int(line.split(',')[1].split()[0]) 42 | EER = float(line.split(',')[-2].split()[-1][:-1]) 43 | if train_type == 'T': 44 | epochs_T.append(epoch) 45 | EERs_T.append(EER) # Result in [T] 46 | elif train_type == 'L': 47 | epochs_L.append(epoch) 48 | EERs_L.append(EER) # Result in [L] 49 | elif line.split(',')[0] == '[T]': # If the stage is [T], record it 50 | train_type = 'T' 51 | elif line.split(',')[0] == '[L]': # If the stage is [L], record it 52 | train_type = 'L' 53 | 54 | if train_type == 'T': # The current stage is [T], so decide whether to keep training [T] or switch to LGL [L]
55 | if len(EERs_T) < 4: # Too few evaluated epochs, keep training 56 | return 'T', None, None, iteration 57 | else: 58 | if EERs_T[-1] > min(EERs_T) and EERs_T[-2] > min(EERs_T) and EERs_T[-3] > min(EERs_T): # The best training result has already been reached, switch to LGL 59 | best_epoch = epochs_T[EERs_T.index(min(EERs_T))] 60 | next_epoch = epochs_T[-1] 61 | return 'L', best_epoch, next_epoch, iteration 62 | else: 63 | return 'T', None, None, iteration # EER can still drop, keep training 64 | 65 | elif train_type == 'L': 66 | if len(EERs_L) < 4: # Too few evaluated epochs, keep LGL training 67 | return 'L', None, None, iteration 68 | else: 69 | if EERs_L[-1] > min(EERs_L) and EERs_L[-2] > min(EERs_L) and EERs_L[-3] > min(EERs_L): # The best LGL result has already been reached, go to clustering 70 | best_epoch = epochs_L[EERs_L.index(min(EERs_L))] 71 | next_epoch = epochs_L[-1] 72 | return 'C', best_epoch, next_epoch, iteration # Clustering based on the best epoch is more robust 73 | else: 74 | return 'L', None, None, iteration # EER can still drop, keep training 75 | 76 | else: # Baseline approach without LGL 77 | EERs_T, epochs_T = [], [] 78 | iteration = 0 79 | for line in lines: 80 | if line.split(',')[0] == '[C]': # Clear all results after clustering 81 | EERs_T, epochs_T = [], [] 82 | iteration += 1 83 | elif line.split(',')[0] == '[E]': # Save the evaluation result 84 | epoch = int(line.split(',')[1].split()[0]) 85 | EER = float(line.split(',')[-2].split()[-1][:-1]) 86 | epochs_T.append(epoch) 87 | EERs_T.append(EER) 88 | 89 | if len(EERs_T) < 4: # Too few evaluated epochs, keep training 90 | return 'T', None, None, iteration 91 | else: 92 | if EERs_T[-1] > min(EERs_T) and EERs_T[-2] > min(EERs_T) and EERs_T[-3] > min(EERs_T): # The best training result has been reached, go to clustering 93 | best_epoch = epochs_T[EERs_T.index(min(EERs_T))] 94 | next_epoch = epochs_T[-1] 95 | return 'C', best_epoch, next_epoch, iteration 96 | else: 97 | return 'T', None, None, iteration # EER can still drop, keep training 98 | 99 | def tuneThresholdfromScore(scores, labels, target_fa, target_fr = None): 100 | 101 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=1) 102 | fnr = 1 - tpr 103 | tunedThreshold = []; 104 | if target_fr: 105 | for tfr in target_fr: 106 | idx = numpy.nanargmin(numpy.absolute((tfr - fnr))) 107 | tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]]) 108 | for tfa in target_fa: 109 | idx = numpy.nanargmin(numpy.absolute((tfa - fpr))) # numpy.where(fpr<=tfa)[0][-1] 110 | tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]]) 111 | idxE = numpy.nanargmin(numpy.absolute((fnr - fpr))) 112 | eer = max(fpr[idxE],fnr[idxE])*100 113 | 114 | return tunedThreshold, eer, fpr, fnr 115 | 116 | # Creates a list of false-negative rates, a list of false-positive rates 117 | # and a list of decision thresholds that give those error-rates. 118 | def ComputeErrorRates(scores, labels): 119 | 120 | # Sort the scores from smallest to largest, and also get the corresponding 121 | # indexes of the sorted scores. We will treat the sorted scores as the 122 | # thresholds at which the error-rates are evaluated. 123 | sorted_indexes, thresholds = zip(*sorted( 124 | [(index, threshold) for index, threshold in enumerate(scores)], 125 | key=itemgetter(1))) 126 | sorted_labels = [] 127 | labels = [labels[i] for i in sorted_indexes] 128 | fnrs = [] 129 | fprs = [] 130 | 131 | # At the end of this loop, fnrs[i] is the number of errors made by 132 | # incorrectly rejecting scores less than thresholds[i].
And, fprs[i] 133 | # is the total number of times that we have correctly accepted scores 134 | # greater than thresholds[i]. 135 | for i in range(0, len(labels)): 136 | if i == 0: 137 | fnrs.append(labels[i]) 138 | fprs.append(1 - labels[i]) 139 | else: 140 | fnrs.append(fnrs[i-1] + labels[i]) 141 | fprs.append(fprs[i-1] + 1 - labels[i]) 142 | fnrs_norm = sum(labels) 143 | fprs_norm = len(labels) - fnrs_norm 144 | 145 | # Now divide by the total number of false negative errors to 146 | # obtain the false negative rates across all thresholds 147 | fnrs = [x / float(fnrs_norm) for x in fnrs] 148 | 149 | # Divide by the total number of correct positives to get the 150 | # true positive rate. Subtract these quantities from 1 to 151 | # get the false positive rates. 152 | fprs = [1 - x / float(fprs_norm) for x in fprs] 153 | return fnrs, fprs, thresholds 154 | 155 | # Computes the minimum of the detection cost function. The comments refer to 156 | # equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan. 157 | def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa): 158 | min_c_det = float("inf") 159 | min_c_det_threshold = thresholds[0] 160 | for i in range(0, len(fnrs)): 161 | # See Equation (2). It is a weighted sum of false negative 162 | # and false positive errors. 163 | c_det = c_miss * fnrs[i] * p_target + c_fa * fprs[i] * (1 - p_target) 164 | if c_det < min_c_det: 165 | min_c_det = c_det 166 | min_c_det_threshold = thresholds[i] 167 | # See Equations (3) and (4). Now we normalize the cost. 168 | c_def = min(c_miss * p_target, c_fa * (1 - p_target)) 169 | min_dcf = min_c_det / c_def 170 | return min_dcf, min_c_det_threshold 171 | 172 | def accuracy(output, target, topk=(1,)): 173 | 174 | maxk = max(topk) 175 | batch_size = target.size(0) 176 | _, pred = output.topk(maxk, 1, True, True) 177 | pred = pred.t() 178 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 179 | res = [] 180 | for k in topk: 181 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 182 | res.append(correct_k.mul_(100.0 / batch_size)) 183 | 184 | return res -------------------------------------------------------------------------------- /utils/LGL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TaoRuijie/Loss-Gated-Learning/665a00ad9a62a94004bd9c89c48d0f1f5cecb79d/utils/LGL.png -------------------------------------------------------------------------------- /utils/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | scikit-learn 4 | tqdm 5 | torchvision 6 | soundfile 7 | faiss --------------------------------------------------------------------------------
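For reference, the evaluation helpers in `Stage2/tools.py` above can also be exercised on their own, mirroring the call pattern of `eval_network` in `Stage2/model.py` (p_target = 0.05, c_miss = c_fa = 1). The snippet below is a minimal usage sketch, not part of the repository: the scores and labels are made up, and it assumes it is run from the `Stage2` directory so that `tools.py` is importable.

```python
# Minimal usage sketch for the scoring utilities in Stage2/tools.py (illustrative only).
from tools import tuneThresholdfromScore, ComputeErrorRates, ComputeMinDcf

# Made-up verification scores (higher = more likely the same speaker) and trial labels
scores = [0.82, 0.15, 0.67, 0.05, 0.91, 0.40]
labels = [1, 0, 1, 0, 1, 0]

EER = tuneThresholdfromScore(scores, labels, [1, 0.1])[1]      # EER in percent
fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
minDCF, _ = ComputeMinDcf(fnrs, fprs, thresholds, 0.05, 1, 1)  # same settings as eval_network
print("EER %.2f%%, minDCF %.4f" % (EER, minDCF))
```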