├── .gitignore ├── .gitmodules ├── DatasetLoader.py ├── LICENSE.md ├── NOTICE.md ├── README.md ├── References.md ├── SpeakerNet.py ├── configs ├── RawNet3_AAM.yaml ├── ResNetSE34L_AM.yaml └── ResNetSE34L_AP.yaml ├── dataprep.py ├── lists ├── augment.txt ├── fileparts.txt └── files.txt ├── loss ├── aamsoftmax.py ├── amsoftmax.py ├── angleproto.py ├── ge2e.py ├── proto.py ├── softmax.py ├── softmaxproto.py └── triplet.py ├── models ├── RawNet3.py ├── RawNetBasicBlock.py ├── ResNetBlocks.py ├── ResNetSE34L.py ├── ResNetSE34V2.py └── VGGVox.py ├── optimizer ├── adam.py └── sgd.py ├── requirements.txt ├── scheduler └── steplr.py ├── trainSpeakerNet.py ├── tuneThreshold.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Other files 2 | *.model 3 | *.wav 4 | *.mp4 5 | *.pcm 6 | *.yaml 7 | data 8 | data/ 9 | exps/ 10 | core.* 11 | 12 | # NSML related 13 | .nsmlignore 14 | *.nsml.py 15 | setup.py 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "models/weights/RawNet3"] 2 | path = models/weights/RawNet3 3 | url = https://huggingface.co/jungjee/RawNet3 4 | -------------------------------------------------------------------------------- /DatasetLoader.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import numpy 6 | import random 7 | import pdb 8 | import os 9 | import threading 10 | import time 11 | import math 12 | import glob 13 | import soundfile 14 | from scipy import signal 15 | from scipy.io import wavfile 16 | from torch.utils.data import Dataset, DataLoader 17 | import torch.distributed as dist 18 | 19 | def round_down(num, divisor): 20 | return num - (num%divisor) 21 | 22 | def worker_init_fn(worker_id): 23 | numpy.random.seed(numpy.random.get_state()[1][0] + worker_id) 24 | 25 | 26 | def loadWAV(filename, max_frames, evalmode=True, num_eval=10): 27 | 28 | # Maximum audio length 29 | max_audio = max_frames * 160 + 240 30 | 31 | # Read wav file and convert to torch tensor 32 | audio, sample_rate = soundfile.read(filename) 33 | 34 | audiosize = audio.shape[0] 35 | 36 | if audiosize <= max_audio: 37 | shortage = max_audio - audiosize + 1 38 | audio = numpy.pad(audio, (0, shortage), 'wrap') 39 | audiosize = audio.shape[0] 40 | 41 | if evalmode: 42 | startframe = numpy.linspace(0,audiosize-max_audio,num=num_eval) 43 | else: 44 | startframe = numpy.array([numpy.int64(random.random()*(audiosize-max_audio))]) 45 | 46 | feats = [] 47 | if evalmode and max_frames == 0: 48 | feats.append(audio) 49 | else: 50 | for asf in startframe: 51 | feats.append(audio[int(asf):int(asf)+max_audio]) 52 | 53 | feat = numpy.stack(feats,axis=0).astype(numpy.float) 54 | 55 | return feat; 56 | 57 | class AugmentWAV(object): 58 | 59 | def __init__(self, musan_path, rir_path, max_frames): 60 | 61 | self.max_frames = max_frames 62 | self.max_audio = max_audio = max_frames * 160 + 240 63 | 64 | self.noisetypes = ['noise','speech','music'] 65 | 66 | self.noisesnr = {'noise':[0,15],'speech':[13,20],'music':[5,15]} 67 | self.numnoise = {'noise':[1,1], 'speech':[3,7], 'music':[1,1] } 68 | self.noiselist = {} 69 | 70 | augment_files = glob.glob(os.path.join(musan_path,'*/*/*/*.wav')); 71 | 72 | for file in augment_files: 73 | if not file.split('/')[-4] in self.noiselist: 74 | self.noiselist[file.split('/')[-4]] = [] 75 | self.noiselist[file.split('/')[-4]].append(file) 76 | 77 | self.rir_files = glob.glob(os.path.join(rir_path,'*/*/*.wav')); 78 | 79 | def additive_noise(self, noisecat, audio): 80 | 81 | clean_db = 10 * numpy.log10(numpy.mean(audio ** 2)+1e-4) 82 | 83 | numnoise = self.numnoise[noisecat] 84 | noiselist = random.sample(self.noiselist[noisecat], random.randint(numnoise[0],numnoise[1])) 85 | 86 | noises = [] 87 | 88 | for noise in noiselist: 89 | 90 | noiseaudio = loadWAV(noise, self.max_frames, evalmode=False) 91 | noise_snr = random.uniform(self.noisesnr[noisecat][0],self.noisesnr[noisecat][1]) 92 | noise_db = 10 * numpy.log10(numpy.mean(noiseaudio[0] ** 2)+1e-4) 93 | noises.append(numpy.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio) 94 | 95 | return numpy.sum(numpy.concatenate(noises,axis=0),axis=0,keepdims=True) + audio 96 | 97 | def reverberate(self, audio): 98 | 99 | rir_file = random.choice(self.rir_files) 100 | 101 | rir, fs = soundfile.read(rir_file) 102 | rir = numpy.expand_dims(rir.astype(numpy.float),0) 103 | rir = rir / numpy.sqrt(numpy.sum(rir**2)) 104 | 105 | return signal.convolve(audio, rir, mode='full')[:,:self.max_audio] 106 | 107 | 108 | class train_dataset_loader(Dataset): 109 | def __init__(self, train_list, augment, musan_path, rir_path, max_frames, train_path, **kwargs): 110 | 111 | self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames = max_frames) 112 | 113 
| self.train_list = train_list 114 | self.max_frames = max_frames; 115 | self.musan_path = musan_path 116 | self.rir_path = rir_path 117 | self.augment = augment 118 | 119 | # Read training files 120 | with open(train_list) as dataset_file: 121 | lines = dataset_file.readlines(); 122 | 123 | # Make a dictionary of ID names and ID indices 124 | dictkeys = list(set([x.split()[0] for x in lines])) 125 | dictkeys.sort() 126 | dictkeys = { key : ii for ii, key in enumerate(dictkeys) } 127 | 128 | # Parse the training list into file names and ID indices 129 | self.data_list = [] 130 | self.data_label = [] 131 | 132 | for lidx, line in enumerate(lines): 133 | data = line.strip().split(); 134 | 135 | speaker_label = dictkeys[data[0]]; 136 | filename = os.path.join(train_path,data[1]); 137 | 138 | self.data_label.append(speaker_label) 139 | self.data_list.append(filename) 140 | 141 | def __getitem__(self, indices): 142 | 143 | feat = [] 144 | 145 | for index in indices: 146 | 147 | audio = loadWAV(self.data_list[index], self.max_frames, evalmode=False) 148 | 149 | if self.augment: 150 | augtype = random.randint(0,4) 151 | if augtype == 1: 152 | audio = self.augment_wav.reverberate(audio) 153 | elif augtype == 2: 154 | audio = self.augment_wav.additive_noise('music',audio) 155 | elif augtype == 3: 156 | audio = self.augment_wav.additive_noise('speech',audio) 157 | elif augtype == 4: 158 | audio = self.augment_wav.additive_noise('noise',audio) 159 | 160 | feat.append(audio); 161 | 162 | feat = numpy.concatenate(feat, axis=0) 163 | 164 | return torch.FloatTensor(feat), self.data_label[index] 165 | 166 | def __len__(self): 167 | return len(self.data_list) 168 | 169 | 170 | 171 | class test_dataset_loader(Dataset): 172 | def __init__(self, test_list, test_path, eval_frames, num_eval, **kwargs): 173 | self.max_frames = eval_frames; 174 | self.num_eval = num_eval 175 | self.test_path = test_path 176 | self.test_list = test_list 177 | 178 | def __getitem__(self, index): 179 | audio = loadWAV(os.path.join(self.test_path,self.test_list[index]), self.max_frames, evalmode=True, num_eval=self.num_eval) 180 | return torch.FloatTensor(audio), self.test_list[index] 181 | 182 | def __len__(self): 183 | return len(self.test_list) 184 | 185 | 186 | class train_dataset_sampler(torch.utils.data.Sampler): 187 | def __init__(self, data_source, nPerSpeaker, max_seg_per_spk, batch_size, distributed, seed, **kwargs): 188 | 189 | self.data_label = data_source.data_label; 190 | self.nPerSpeaker = nPerSpeaker; 191 | self.max_seg_per_spk = max_seg_per_spk; 192 | self.batch_size = batch_size; 193 | self.epoch = 0; 194 | self.seed = seed; 195 | self.distributed = distributed; 196 | 197 | def __iter__(self): 198 | 199 | g = torch.Generator() 200 | g.manual_seed(self.seed + self.epoch) 201 | indices = torch.randperm(len(self.data_label), generator=g).tolist() 202 | 203 | data_dict = {} 204 | 205 | # Sort into dictionary of file indices for each ID 206 | for index in indices: 207 | speaker_label = self.data_label[index] 208 | if not (speaker_label in data_dict): 209 | data_dict[speaker_label] = []; 210 | data_dict[speaker_label].append(index); 211 | 212 | 213 | ## Group file indices for each class 214 | dictkeys = list(data_dict.keys()); 215 | dictkeys.sort() 216 | 217 | lol = lambda lst, sz: [lst[i:i+sz] for i in range(0, len(lst), sz)] 218 | 219 | flattened_list = [] 220 | flattened_label = [] 221 | 222 | for findex, key in enumerate(dictkeys): 223 | data = data_dict[key] 224 | numSeg = 
round_down(min(len(data),self.max_seg_per_spk),self.nPerSpeaker) 225 | 226 | rp = lol(numpy.arange(numSeg),self.nPerSpeaker) 227 | flattened_label.extend([findex] * (len(rp))) 228 | for indices in rp: 229 | flattened_list.append([data[i] for i in indices]) 230 | 231 | ## Mix data in random order 232 | mixid = torch.randperm(len(flattened_label), generator=g).tolist() 233 | mixlabel = [] 234 | mixmap = [] 235 | 236 | ## Prevent two pairs of the same speaker in the same batch 237 | for ii in mixid: 238 | startbatch = round_down(len(mixlabel), self.batch_size) 239 | if flattened_label[ii] not in mixlabel[startbatch:]: 240 | mixlabel.append(flattened_label[ii]) 241 | mixmap.append(ii) 242 | 243 | mixed_list = [flattened_list[i] for i in mixmap] 244 | 245 | ## Divide data to each GPU 246 | if self.distributed: 247 | total_size = round_down(len(mixed_list), self.batch_size * dist.get_world_size()) 248 | start_index = int ( ( dist.get_rank() ) / dist.get_world_size() * total_size ) 249 | end_index = int ( ( dist.get_rank() + 1 ) / dist.get_world_size() * total_size ) 250 | self.num_samples = end_index - start_index 251 | return iter(mixed_list[start_index:end_index]) 252 | else: 253 | total_size = round_down(len(mixed_list), self.batch_size) 254 | self.num_samples = total_size 255 | return iter(mixed_list[:total_size]) 256 | 257 | 258 | def __len__(self) -> int: 259 | return self.num_samples 260 | 261 | def set_epoch(self, epoch: int) -> None: 262 | self.epoch = epoch 263 | 264 | 265 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020-present NAVER Corp. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /NOTICE.md: -------------------------------------------------------------------------------- 1 | VoxCeleb trainer 2 | 3 | Copyright (c) 2020-present NAVER Corp. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------------- 24 | 25 | This project contains subcomponents with separate copyright notices and license terms. 26 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 27 | 28 | ===== 29 | 30 | pytorch/vision 31 | https://github.com/pytorch/vision 32 | 33 | 34 | BSD 3-Clause License 35 | 36 | Copyright (c) Soumith Chintala 2016, 37 | All rights reserved. 38 | 39 | Redistribution and use in source and binary forms, with or without 40 | modification, are permitted provided that the following conditions are met: 41 | 42 | * Redistributions of source code must retain the above copyright notice, this 43 | list of conditions and the following disclaimer. 44 | 45 | * Redistributions in binary form must reproduce the above copyright notice, 46 | this list of conditions and the following disclaimer in the documentation 47 | and/or other materials provided with the distribution. 48 | 49 | * Neither the name of the copyright holder nor the names of its 50 | contributors may be used to endorse or promote products derived from 51 | this software without specific prior written permission. 52 | 53 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 54 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 55 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 56 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 57 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 58 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 59 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 60 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 61 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 62 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
63 | 64 | ===== 65 | 66 | CoinCheung/pytorch-loss 67 | https://github.com/CoinCheung/pytorch-loss 68 | 69 | 70 | MIT License 71 | 72 | Copyright (c) 2019 CoinCheung 73 | 74 | Permission is hereby granted, free of charge, to any person obtaining a copy 75 | of this software and associated documentation files (the "Software"), to deal 76 | in the Software without restriction, including without limitation the rights 77 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 78 | copies of the Software, and to permit persons to whom the Software is 79 | furnished to do so, subject to the following conditions: 80 | 81 | The above copyright notice and this permission notice shall be included in all 82 | copies or substantial portions of the Software. 83 | 84 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 85 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 86 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 87 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 88 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 89 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 90 | SOFTWARE. 91 | 92 | ===== 93 | 94 | wujiyang/Face_Pytorch 95 | https://github.com/wujiyang/Face_Pytorch 96 | 97 | 98 | @author: wujiyang 99 | @contact: wujiyang@hust.edu.cn 100 | 101 | Licensed under the Apache License, Version 2.0 (the "License"); 102 | you may not use this file except in compliance with the License. 103 | You may obtain a copy of the License at 104 | 105 | http://www.apache.org/licenses/LICENSE-2.0 106 | 107 | Unless required by applicable law or agreed to in writing, software 108 | distributed under the License is distributed on an "AS IS" BASIS, 109 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 110 | See the License for the specific language governing permissions and 111 | limitations under the License. 112 | 113 | ===== 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VoxCeleb trainer 2 | 3 | This repository contains the framework for training speaker recognition models described in the paper '_In defence of metric learning for speaker recognition_' and '_Pushing the limits of raw waveform speaker recognition_'. 4 | 5 | ### Dependencies 6 | ``` 7 | pip install -r requirements.txt 8 | ``` 9 | 10 | ### Data preparation 11 | 12 | The following script can be used to download and prepare the VoxCeleb dataset for training. 13 | 14 | ``` 15 | python ./dataprep.py --save_path data --download --user USERNAME --password PASSWORD 16 | python ./dataprep.py --save_path data --extract 17 | python ./dataprep.py --save_path data --convert 18 | ``` 19 | In order to use data augmentation, also run: 20 | 21 | ``` 22 | python ./dataprep.py --save_path data --augment 23 | ``` 24 | 25 | In addition to the Python dependencies, `wget` and `ffmpeg` must be installed on the system. 
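After these steps, the save path (assuming the default `--save_path data`) should contain roughly the following layout; the exact contents depend on which of the steps above have been run:

```
data/
├── voxceleb1/      # VoxCeleb1 in WAV format
├── voxceleb2/      # VoxCeleb2 converted from AAC to WAV
├── musan/          # MUSAN corpus (with --augment)
├── musan_split/    # MUSAN split into short segments (with --augment)
└── RIRS_NOISES/    # simulated room impulse responses (with --augment)
```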
26 | 27 | ### Training examples 28 | 29 | - ResNetSE34L with AM-Softmax: 30 | ``` 31 | python ./trainSpeakerNet.py --config ./configs/ResNetSE34L_AM.yaml 32 | ``` 33 | 34 | - RawNet3 with AAM-Softmax 35 | ``` 36 | python ./trainSpeakerNet.py --config ./configs/RawNet3_AAM.yaml 37 | ``` 38 | 39 | - ResNetSE34L with Angular prototypical: 40 | ``` 41 | python ./trainSpeakerNet.py --config ./configs/ResNetSE34L_AP.yaml 42 | ``` 43 | 44 | You can pass individual arguments that are defined in trainSpeakerNet.py by `--{ARG_NAME} {VALUE}`. 45 | Note that the configuration file overrides the arguments passed via command line. 46 | 47 | ### Pretrained models 48 | 49 | A pretrained model, described in [1], can be downloaded from [here](http://www.robots.ox.ac.uk/~joon/data/baseline_lite_ap.model). 50 | 51 | You can check that the following script returns: `EER 2.1792`. You will be given an option to save the scores. 52 | 53 | ``` 54 | python ./trainSpeakerNet.py --eval --model ResNetSE34L --log_input True --trainfunc angleproto --save_path exps/test --eval_frames 400 --initial_model baseline_lite_ap.model 55 | ``` 56 | 57 | A larger model trained with online data augmentation, described in [2], can be downloaded from [here](http://www.robots.ox.ac.uk/~joon/data/baseline_v2_smproto.model). 58 | 59 | The following script should return: `EER 1.0180`. 60 | 61 | ``` 62 | python ./trainSpeakerNet.py --eval --model ResNetSE34V2 --log_input True --encoder_type ASP --n_mels 64 --trainfunc softmaxproto --save_path exps/test --eval_frames 400 --initial_model baseline_v2_smproto.model 63 | ``` 64 | 65 | Pretrained RawNet3, described in [3], can be downloaded via `git submodule update --init --recursive`. 66 | 67 | The following script should return `EER 0.8932`. 68 | 69 | ``` 70 | python ./trainSpeakerNet.py --eval --config ./configs/RawNet3_AAM.yaml --initial_model models/weights/RawNet3/model.pt 71 | ``` 72 | 73 | 74 | 75 | ### Implemented loss functions 76 | ``` 77 | Softmax (softmax) 78 | AM-Softmax (amsoftmax) 79 | AAM-Softmax (aamsoftmax) 80 | GE2E (ge2e) 81 | Prototypical (proto) 82 | Triplet (triplet) 83 | Angular Prototypical (angleproto) 84 | ``` 85 | 86 | ### Implemented models and encoders 87 | ``` 88 | ResNetSE34L (SAP, ASP) 89 | ResNetSE34V2 (SAP, ASP) 90 | VGGVox40 (SAP, TAP, MAX) 91 | ``` 92 | 93 | ### Data augmentation 94 | 95 | `--augment True` enables online data augmentation, described in [2]. 96 | 97 | ### Adding new models and loss functions 98 | 99 | You can add new models and loss functions to `models` and `loss` directories respectively. See the existing definitions for examples. 100 | 101 | ### Accelerating training 102 | 103 | - Use `--mixedprec` flag to enable mixed precision training. This is recommended for Tesla V100, GeForce RTX 20 series or later models. 104 | 105 | - Use `--distributed` flag to enable distributed training. 106 | 107 | - GPU indices should be set before training using the command `export CUDA_VISIBLE_DEVICES=0,1,2,3`. 108 | 109 | - If you are running more than one distributed training session, you need to change the `--port` argument. 110 | 111 | ### Data 112 | 113 | The [VoxCeleb](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/) datasets are used for these experiments. 
114 | 115 | The train list should contain the identity and the file path, one line per utterance, as follows: 116 | ``` 117 | id00000 id00000/youtube_key/12345.wav 118 | id00012 id00012/21Uxsk56VDQ/00001.wav 119 | ``` 120 | 121 | The train list for VoxCeleb2 can be download from [here](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/train_list.txt). The 122 | test lists for VoxCeleb1 can be downloaded from [here](https://mm.kaist.ac.kr/datasets/voxceleb/index.html#testlist). 123 | 124 | ### Replicating the results from the paper 125 | 126 | 1. Model definitions 127 | - `VGG-M-40` in [1] is `VGGVox` in the repository. 128 | - `Thin ResNet-34` in [1] is `ResNetSE34` in the repository. 129 | - `Fast ResNet-34` in [1] is `ResNetSE34L` in the repository. 130 | - `H / ASP` in [2] is `ResNetSE34V2` in the repository. 131 | 132 | 2. For metric learning objectives, the batch size in the paper is `nPerSpeaker` multiplied by `batch_size` in the code. For the batch size of 800 in the paper, use `--nPerSpeaker 2 --batch_size 400`, `--nPerSpeaker 3 --batch_size 266`, etc. 133 | 134 | 3. The models have been trained with `--max_frames 200` and evaluated with `--max_frames 400`. 135 | 136 | 4. You can get a good balance between speed and performance using the configuration below. 137 | 138 | ``` 139 | python ./trainSpeakerNet.py --model ResNetSE34L --trainfunc angleproto --batch_size 400 --nPerSpeaker 2 140 | ``` 141 | 142 | ### Citation 143 | 144 | Please cite [1] if you make use of the code. Please see [here](References.md) for the full list of methods used in this trainer. 145 | 146 | [1] _In defence of metric learning for speaker recognition_ 147 | ``` 148 | @inproceedings{chung2020in, 149 | title={In defence of metric learning for speaker recognition}, 150 | author={Chung, Joon Son and Huh, Jaesung and Mun, Seongkyu and Lee, Minjae and Heo, Hee Soo and Choe, Soyeon and Ham, Chiheon and Jung, Sunghwan and Lee, Bong-Jin and Han, Icksang}, 151 | booktitle={Proc. Interspeech}, 152 | year={2020} 153 | } 154 | ``` 155 | 156 | [2] _The ins and outs of speaker recognition: lessons from VoxSRC 2020_ 157 | ``` 158 | @inproceedings{kwon2021ins, 159 | title={The ins and outs of speaker recognition: lessons from {VoxSRC} 2020}, 160 | author={Kwon, Yoohwan and Heo, Hee Soo and Lee, Bong-Jin and Chung, Joon Son}, 161 | booktitle={Proc. ICASSP}, 162 | year={2021} 163 | } 164 | ``` 165 | 166 | [3] _Pushing the limits of raw waveform speaker recognition_ 167 | ``` 168 | @inproceedings{jung2022pushing, 169 | title={Pushing the limits of raw waveform speaker recognition}, 170 | author={Jung, Jee-weon and Kim, You Jin and Heo, Hee-Soo and Lee, Bong-Jin and Kwon, Youngki and Chung, Joon Son}, 171 | booktitle={Proc. Interspeech}, 172 | year={2022} 173 | } 174 | ``` 175 | 176 | ### License 177 | ``` 178 | Copyright (c) 2020-present NAVER Corp. 179 | 180 | Permission is hereby granted, free of charge, to any person obtaining a copy 181 | of this software and associated documentation files (the "Software"), to deal 182 | in the Software without restriction, including without limitation the rights 183 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 184 | copies of the Software, and to permit persons to whom the Software is 185 | furnished to do so, subject to the following conditions: 186 | 187 | The above copyright notice and this permission notice shall be included in 188 | all copies or substantial portions of the Software. 
189 | 190 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 191 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 192 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 193 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 194 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 195 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 196 | THE SOFTWARE. 197 | ``` 198 | -------------------------------------------------------------------------------- /References.md: -------------------------------------------------------------------------------- 1 | ## References 2 | 3 | Please cite the following if you make use of the code. 4 | 5 | ``` 6 | @inproceedings{chung2020in, 7 | title={In defence of metric learning for speaker recognition}, 8 | author={Chung, Joon Son and Huh, Jaesung and Mun, Seongkyu and Lee, Minjae and Heo, Hee Soo and Choe, Soyeon and Ham, Chiheon and Jung, Sunghwan and Lee, Bong-Jin and Han, Icksang}, 9 | booktitle={Interspeech}, 10 | year={2020} 11 | } 12 | ``` 13 | 14 | This trainer uses many models and loss functions that have been proposed in previous works. The suggested citations are as follows: 15 | 16 | ### Models 17 | 18 | #### VGGVox 19 | ``` 20 | @inproceedings{nagrani2017voxceleb, 21 | title={VoxCeleb: A Large-Scale Speaker Identification Dataset}, 22 | author={Nagrani, Arsha and Chung, Joon Son and Zisserman, Andrew}, 23 | booktitle={Interspeech}, 24 | pages={2616--2620}, 25 | year={2017} 26 | } 27 | ``` 28 | 29 | #### ResNet 30 | ``` 31 | @inproceedings{he2016deep, 32 | title={Deep residual learning for image recognition}, 33 | author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian}, 34 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, 35 | pages={770--778}, 36 | year={2016} 37 | } 38 | ``` 39 | 40 | 41 | 42 | ### Aggregation 43 | 44 | #### SAP 45 | ``` 46 | @inproceedings{bhattacharya2017deep, 47 | title={Deep Speaker Embeddings for Short-Duration Speaker Verification}, 48 | author={Bhattacharya, Gautam and Alam, Md Jahangir and Kenny, Patrick}, 49 | booktitle={Interspeech}, 50 | pages={1517--1521}, 51 | year={2017} 52 | } 53 | ``` 54 | 55 | #### ASP 56 | ``` 57 | @inproceedings{okabe2018attentive, 58 | title={Attentive Statistics Pooling for Deep Speaker Embedding}, 59 | author={Okabe, Koji and Koshinaka, Takafumi and Shinoda, Koichi}, 60 | booktitle={Interspeech}, 61 | pages={2252--2256}, 62 | year={2018} 63 | } 64 | ``` 65 | 66 | 67 | ### Loss functions 68 | 69 | #### Prototypical Networks 70 | ``` 71 | @inproceedings{snell2017prototypical, 72 | title={Prototypical networks for few-shot learning}, 73 | author={Snell, Jake and Swersky, Kevin and Zemel, Richard}, 74 | booktitle={Advances in Neural Information Processing Systems}, 75 | pages={4077--4087}, 76 | year={2017} 77 | } 78 | ``` 79 | 80 | #### GE2E 81 | ``` 82 | @inproceedings{wan2018generalized, 83 | title={Generalized end-to-end loss for speaker verification}, 84 | author={Wan, Li and Wang, Quan and Papir, Alan and Moreno, Ignacio Lopez}, 85 | booktitle={IEEE International Conference on Acoustics, Speech and Signal Processing}, 86 | pages={4879--4883}, 87 | year={2018} 88 | } 89 | ``` 90 | 91 | 92 | #### Triplet loss 93 | ``` 94 | @inproceedings{schroff2015facenet, 95 | title={Facenet: A unified embedding for face recognition and clustering}, 96 | author={Schroff, Florian and Kalenichenko, Dmitry and Philbin, 
James}, 97 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, 98 | pages={815--823}, 99 | year={2015} 100 | } 101 | ``` 102 | 103 | #### AM-Softmax 104 | ``` 105 | @inproceedings{wang2018cosface, 106 | title={Cosface: Large margin cosine loss for deep face recognition}, 107 | author={Wang, Hao and Wang, Yitong and Zhou, Zheng and Ji, Xing and Gong, Dihong and Zhou, Jingchao and Li, Zhifeng and Liu, Wei}, 108 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, 109 | pages={5265--5274}, 110 | year={2018} 111 | } 112 | ``` 113 | 114 | #### AAM-Softmax 115 | ``` 116 | @inproceedings{deng2019arcface, 117 | title={Arcface: Additive angular margin loss for deep face recognition}, 118 | author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos}, 119 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, 120 | pages={4690--4699}, 121 | year={2019} 122 | } 123 | ``` 124 | 125 | ### Data augmentation 126 | 127 | #### MUSAN database 128 | ``` 129 | @article{snyder2015musan, 130 | title={Musan: A music, speech, and noise corpus}, 131 | author={Snyder, David and Chen, Guoguo and Povey, Daniel}, 132 | journal={arXiv preprint arXiv:1510.08484}, 133 | year={2015} 134 | } 135 | ``` 136 | 137 | #### Room Impulse Response database 138 | ``` 139 | @inproceedings{ko2017study, 140 | title={A study on data augmentation of reverberant speech for robust speech recognition}, 141 | author={Ko, Tom and Peddinti, Vijayaditya and Povey, Daniel and Seltzer, Michael L and Khudanpur, Sanjeev}, 142 | booktitle={IEEE International Conference on Acoustics, Speech and Signal Processing}, 143 | pages={5220--5224}, 144 | year={2017} 145 | } 146 | ``` 147 | 148 | -------------------------------------------------------------------------------- /SpeakerNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy, sys, random 8 | import time, itertools, importlib 9 | 10 | from DatasetLoader import test_dataset_loader 11 | from torch.cuda.amp import autocast, GradScaler 12 | 13 | 14 | class WrappedModel(nn.Module): 15 | 16 | ## The purpose of this wrapper is to make the model structure consistent between single and multi-GPU 17 | 18 | def __init__(self, model): 19 | super(WrappedModel, self).__init__() 20 | self.module = model 21 | 22 | def forward(self, x, label=None): 23 | return self.module(x, label) 24 | 25 | 26 | class SpeakerNet(nn.Module): 27 | def __init__(self, model, optimizer, trainfunc, nPerSpeaker, **kwargs): 28 | super(SpeakerNet, self).__init__() 29 | 30 | SpeakerNetModel = importlib.import_module("models." + model).__getattribute__("MainModel") 31 | self.__S__ = SpeakerNetModel(**kwargs) 32 | 33 | LossFunction = importlib.import_module("loss." 
+ trainfunc).__getattribute__("LossFunction") 34 | self.__L__ = LossFunction(**kwargs) 35 | 36 | self.nPerSpeaker = nPerSpeaker 37 | 38 | def forward(self, data, label=None): 39 | 40 | data = data.reshape(-1, data.size()[-1]).cuda() 41 | outp = self.__S__.forward(data) 42 | 43 | if label == None: 44 | return outp 45 | 46 | else: 47 | 48 | outp = outp.reshape(self.nPerSpeaker, -1, outp.size()[-1]).transpose(1, 0).squeeze(1) 49 | 50 | nloss, prec1 = self.__L__.forward(outp, label) 51 | 52 | return nloss, prec1 53 | 54 | 55 | class ModelTrainer(object): 56 | def __init__(self, speaker_model, optimizer, scheduler, gpu, mixedprec, **kwargs): 57 | 58 | self.__model__ = speaker_model 59 | 60 | Optimizer = importlib.import_module("optimizer." + optimizer).__getattribute__("Optimizer") 61 | self.__optimizer__ = Optimizer(self.__model__.parameters(), **kwargs) 62 | 63 | Scheduler = importlib.import_module("scheduler." + scheduler).__getattribute__("Scheduler") 64 | self.__scheduler__, self.lr_step = Scheduler(self.__optimizer__, **kwargs) 65 | 66 | self.scaler = GradScaler() 67 | 68 | self.gpu = gpu 69 | 70 | self.mixedprec = mixedprec 71 | 72 | assert self.lr_step in ["epoch", "iteration"] 73 | 74 | # ## ===== ===== ===== ===== ===== ===== ===== ===== 75 | # ## Train network 76 | # ## ===== ===== ===== ===== ===== ===== ===== ===== 77 | 78 | def train_network(self, loader, verbose): 79 | 80 | self.__model__.train() 81 | 82 | stepsize = loader.batch_size 83 | 84 | counter = 0 85 | index = 0 86 | loss = 0 87 | top1 = 0 88 | # EER or accuracy 89 | 90 | tstart = time.time() 91 | 92 | for data, data_label in loader: 93 | 94 | data = data.transpose(1, 0) 95 | 96 | self.__model__.zero_grad() 97 | 98 | label = torch.LongTensor(data_label).cuda() 99 | 100 | if self.mixedprec: 101 | with autocast(): 102 | nloss, prec1 = self.__model__(data, label) 103 | self.scaler.scale(nloss).backward() 104 | self.scaler.step(self.__optimizer__) 105 | self.scaler.update() 106 | else: 107 | nloss, prec1 = self.__model__(data, label) 108 | nloss.backward() 109 | self.__optimizer__.step() 110 | 111 | loss += nloss.detach().cpu().item() 112 | top1 += prec1.detach().cpu().item() 113 | counter += 1 114 | index += stepsize 115 | 116 | telapsed = time.time() - tstart 117 | tstart = time.time() 118 | 119 | if verbose: 120 | sys.stdout.write("\rProcessing {:d} of {:d}:".format(index, loader.__len__() * loader.batch_size)) 121 | sys.stdout.write("Loss {:f} TEER/TAcc {:2.3f}% - {:.2f} Hz ".format(loss / counter, top1 / counter, stepsize / telapsed)) 122 | sys.stdout.flush() 123 | 124 | if self.lr_step == "iteration": 125 | self.__scheduler__.step() 126 | 127 | if self.lr_step == "epoch": 128 | self.__scheduler__.step() 129 | 130 | return (loss / counter, top1 / counter) 131 | 132 | ## ===== ===== ===== ===== ===== ===== ===== ===== 133 | ## Evaluate from list 134 | ## ===== ===== ===== ===== ===== ===== ===== ===== 135 | 136 | def evaluateFromList(self, test_list, test_path, nDataLoaderThread, distributed, print_interval=100, num_eval=10, **kwargs): 137 | 138 | if distributed: 139 | rank = torch.distributed.get_rank() 140 | else: 141 | rank = 0 142 | 143 | self.__model__.eval() 144 | 145 | lines = [] 146 | files = [] 147 | feats = {} 148 | tstart = time.time() 149 | 150 | ## Read all lines 151 | with open(test_list) as f: 152 | lines = f.readlines() 153 | 154 | ## Get a list of unique file names 155 | files = list(itertools.chain(*[x.strip().split()[-2:] for x in lines])) 156 | setfiles = list(set(files)) 157 | setfiles.sort() 158 | 
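## Each line of the test list holds an optional label followed by two file
## paths ("<label> <file1> <file2>"); only the last two fields are used above
## to collect the set of unique files to embed, and a random label is
## substituted further below if the label column is missing.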
159 | ## Define test data loader 160 | test_dataset = test_dataset_loader(setfiles, test_path, num_eval=num_eval, **kwargs) 161 | 162 | if distributed: 163 | sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False) 164 | else: 165 | sampler = None 166 | 167 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=nDataLoaderThread, drop_last=False, sampler=sampler) 168 | 169 | ## Extract features for every image 170 | for idx, data in enumerate(test_loader): 171 | inp1 = data[0][0].cuda() 172 | with torch.no_grad(): 173 | ref_feat = self.__model__(inp1).detach().cpu() 174 | feats[data[1][0]] = ref_feat 175 | telapsed = time.time() - tstart 176 | 177 | if idx % print_interval == 0 and rank == 0: 178 | sys.stdout.write( 179 | "\rReading {:d} of {:d}: {:.2f} Hz, embedding size {:d}".format(idx, test_loader.__len__(), idx / telapsed, ref_feat.size()[1]) 180 | ) 181 | 182 | all_scores = [] 183 | all_labels = [] 184 | all_trials = [] 185 | 186 | if distributed: 187 | ## Gather features from all GPUs 188 | feats_all = [None for _ in range(0, torch.distributed.get_world_size())] 189 | torch.distributed.all_gather_object(feats_all, feats) 190 | 191 | if rank == 0: 192 | 193 | tstart = time.time() 194 | print("") 195 | 196 | ## Combine gathered features 197 | if distributed: 198 | feats = feats_all[0] 199 | for feats_batch in feats_all[1:]: 200 | feats.update(feats_batch) 201 | 202 | ## Read files and compute all scores 203 | for idx, line in enumerate(lines): 204 | 205 | data = line.split() 206 | 207 | ## Append random label if missing 208 | if len(data) == 2: 209 | data = [random.randint(0, 1)] + data 210 | 211 | ref_feat = feats[data[1]].cuda() 212 | com_feat = feats[data[2]].cuda() 213 | 214 | if self.__model__.module.__L__.test_normalize: 215 | ref_feat = F.normalize(ref_feat, p=2, dim=1) 216 | com_feat = F.normalize(com_feat, p=2, dim=1) 217 | 218 | dist = torch.cdist(ref_feat.reshape(num_eval, -1), com_feat.reshape(num_eval, -1)).detach().cpu().numpy() 219 | 220 | score = -1 * numpy.mean(dist) 221 | 222 | all_scores.append(score) 223 | all_labels.append(int(data[0])) 224 | all_trials.append(data[1] + " " + data[2]) 225 | 226 | if idx % print_interval == 0: 227 | telapsed = time.time() - tstart 228 | sys.stdout.write("\rComputing {:d} of {:d}: {:.2f} Hz".format(idx, len(lines), idx / telapsed)) 229 | sys.stdout.flush() 230 | 231 | return (all_scores, all_labels, all_trials) 232 | 233 | ## ===== ===== ===== ===== ===== ===== ===== ===== 234 | ## Save parameters 235 | ## ===== ===== ===== ===== ===== ===== ===== ===== 236 | 237 | def saveParameters(self, path): 238 | 239 | torch.save(self.__model__.module.state_dict(), path) 240 | 241 | ## ===== ===== ===== ===== ===== ===== ===== ===== 242 | ## Load parameters 243 | ## ===== ===== ===== ===== ===== ===== ===== ===== 244 | 245 | def loadParameters(self, path): 246 | 247 | self_state = self.__model__.module.state_dict() 248 | loaded_state = torch.load(path, map_location="cuda:%d" % self.gpu) 249 | if len(loaded_state.keys()) == 1 and "model" in loaded_state: 250 | loaded_state = loaded_state["model"] 251 | newdict = {} 252 | delete_list = [] 253 | for name, param in loaded_state.items(): 254 | new_name = "__S__."+name 255 | newdict[new_name] = param 256 | delete_list.append(name) 257 | loaded_state.update(newdict) 258 | for name in delete_list: 259 | del loaded_state[name] 260 | for name, param in loaded_state.items(): 261 | origname = name 262 | if name not in self_state: 
263 | name = name.replace("module.", "") 264 | 265 | if name not in self_state: 266 | print("{} is not in the model.".format(origname)) 267 | continue 268 | 269 | if self_state[name].size() != loaded_state[origname].size(): 270 | print("Wrong parameter length: {}, model: {}, loaded: {}".format(origname, self_state[name].size(), loaded_state[origname].size())) 271 | continue 272 | 273 | self_state[name].copy_(param) 274 | -------------------------------------------------------------------------------- /configs/RawNet3_AAM.yaml: -------------------------------------------------------------------------------- 1 | model: RawNet3 2 | encoder_type: ECA 3 | nOut: 256 4 | sinc_stride: 10 5 | max_frames: 300 6 | batch_size: 32 7 | eval_frames: 400 8 | trainfunc: aamsoftmax 9 | lr_decay: 0.75 10 | weight_decay: 5e-05 11 | max_seg_per_spk: 500 12 | augment: True 13 | save_path: exps/RawNet3_AAM -------------------------------------------------------------------------------- /configs/ResNetSE34L_AM.yaml: -------------------------------------------------------------------------------- 1 | model: ResNetSE34L 2 | log_input: True 3 | encoder_type: SAP 4 | trainfunc: amsoftmax 5 | save_path: exps/ResNetSE34L_AM 6 | nLcasses: 5994 7 | batch_size: 200 8 | scale: 30 9 | margin: 0.3 -------------------------------------------------------------------------------- /configs/ResNetSE34L_AP.yaml: -------------------------------------------------------------------------------- 1 | model: ResNetSE34L 2 | log_input: True 3 | encoder_type: SAP 4 | trainfunc: angleproto 5 | save_path: exps/ResNetSE34L_AP 6 | nPerSpeaker: 2 7 | batch_size: 200 -------------------------------------------------------------------------------- /dataprep.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | # The script downloads the VoxCeleb datasets and converts all files to WAV. 4 | # Requirement: ffmpeg and wget running on a Linux system. 
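# Typical usage (see README.md):
#   python ./dataprep.py --save_path data --download --user USERNAME --password PASSWORD
#   python ./dataprep.py --save_path data --extract
#   python ./dataprep.py --save_path data --convert
#   python ./dataprep.py --save_path data --augment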
5 | 6 | import argparse 7 | import os 8 | import subprocess 9 | import pdb 10 | import hashlib 11 | import time 12 | import glob 13 | import tarfile 14 | from zipfile import ZipFile 15 | from tqdm import tqdm 16 | from scipy.io import wavfile 17 | 18 | ## ========== =========== 19 | ## Parse input arguments 20 | ## ========== =========== 21 | parser = argparse.ArgumentParser(description = "VoxCeleb downloader"); 22 | 23 | parser.add_argument('--save_path', type=str, default="data", help='Target directory'); 24 | parser.add_argument('--user', type=str, default="user", help='Username'); 25 | parser.add_argument('--password', type=str, default="pass", help='Password'); 26 | 27 | parser.add_argument('--download', dest='download', action='store_true', help='Enable download') 28 | parser.add_argument('--extract', dest='extract', action='store_true', help='Enable extract') 29 | parser.add_argument('--convert', dest='convert', action='store_true', help='Enable convert') 30 | parser.add_argument('--augment', dest='augment', action='store_true', help='Download and extract augmentation files') 31 | 32 | args = parser.parse_args(); 33 | 34 | ## ========== =========== 35 | ## MD5SUM 36 | ## ========== =========== 37 | def md5(fname): 38 | 39 | hash_md5 = hashlib.md5() 40 | with open(fname, "rb") as f: 41 | for chunk in iter(lambda: f.read(4096), b""): 42 | hash_md5.update(chunk) 43 | return hash_md5.hexdigest() 44 | 45 | ## ========== =========== 46 | ## Download with wget 47 | ## ========== =========== 48 | def download(args, lines): 49 | 50 | for line in lines: 51 | url = line.split()[0] 52 | md5gt = line.split()[1] 53 | outfile = url.split('/')[-1] 54 | 55 | ## Download files 56 | out = subprocess.call('wget %s --user %s --password %s -O %s/%s'%(url,args.user,args.password,args.save_path,outfile), shell=True) 57 | if out != 0: 58 | raise ValueError('Download failed %s. 
If download fails repeatedly, use alternate URL on the VoxCeleb website.'%url) 59 | 60 | ## Check MD5 61 | md5ck = md5('%s/%s'%(args.save_path,outfile)) 62 | if md5ck == md5gt: 63 | print('Checksum successful %s.'%outfile) 64 | else: 65 | raise Warning('Checksum failed %s.'%outfile) 66 | 67 | ## ========== =========== 68 | ## Concatenate file parts 69 | ## ========== =========== 70 | def concatenate(args,lines): 71 | 72 | for line in lines: 73 | infile = line.split()[0] 74 | outfile = line.split()[1] 75 | md5gt = line.split()[2] 76 | 77 | ## Concatenate files 78 | out = subprocess.call('cat %s/%s > %s/%s' %(args.save_path,infile,args.save_path,outfile), shell=True) 79 | 80 | ## Check MD5 81 | md5ck = md5('%s/%s'%(args.save_path,outfile)) 82 | if md5ck == md5gt: 83 | print('Checksum successful %s.'%outfile) 84 | else: 85 | raise Warning('Checksum failed %s.'%outfile) 86 | 87 | out = subprocess.call('rm %s/%s' %(args.save_path,infile), shell=True) 88 | 89 | ## ========== =========== 90 | ## Extract zip files 91 | ## ========== =========== 92 | def is_within_directory(directory, target): 93 | 94 | abs_directory = os.path.abspath(directory) 95 | abs_target = os.path.abspath(target) 96 | 97 | prefix = os.path.commonprefix([abs_directory, abs_target]) 98 | 99 | return prefix == abs_directory 100 | 101 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False): 102 | 103 | for member in tar.getmembers(): 104 | member_path = os.path.join(path, member.name) 105 | if not is_within_directory(path, member_path): 106 | raise Exception("Attempted Path Traversal in Tar File") 107 | tar.extractall(path, members, numeric_owner=numeric_owner) 108 | 109 | def full_extract(args, fname): 110 | 111 | print('Extracting %s'%fname) 112 | if fname.endswith(".tar.gz"): 113 | with tarfile.open(fname, "r:gz") as tar: 114 | safe_extract(tar, args.save_path) 115 | elif fname.endswith(".zip"): 116 | with ZipFile(fname, 'r') as zf: 117 | zf.extractall(args.save_path) 118 | 119 | 120 | ## ========== =========== 121 | ## Partially extract zip files 122 | ## ========== =========== 123 | def part_extract(args, fname, target): 124 | 125 | print('Extracting %s'%fname) 126 | with ZipFile(fname, 'r') as zf: 127 | for infile in zf.namelist(): 128 | if any([infile.startswith(x) for x in target]): 129 | zf.extract(infile,args.save_path) 130 | # pdb.set_trace() 131 | # zf.extractall(args.save_path) 132 | 133 | ## ========== =========== 134 | ## Convert 135 | ## ========== =========== 136 | def convert(args): 137 | 138 | files = glob.glob('%s/voxceleb2/*/*/*.m4a'%args.save_path) 139 | files.sort() 140 | 141 | print('Converting files from AAC to WAV') 142 | for fname in tqdm(files): 143 | outfile = fname.replace('.m4a','.wav') 144 | out = subprocess.call('ffmpeg -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s >/dev/null 2>/dev/null' %(fname,outfile), shell=True) 145 | if out != 0: 146 | raise ValueError('Conversion failed %s.'%fname) 147 | 148 | ## ========== =========== 149 | ## Split MUSAN for faster random access 150 | ## ========== =========== 151 | def split_musan(args): 152 | 153 | files = glob.glob('%s/musan/*/*/*.wav'%args.save_path) 154 | 155 | audlen = 16000*5 156 | audstr = 16000*3 157 | 158 | for idx,file in enumerate(files): 159 | fs,aud = wavfile.read(file) 160 | writedir = os.path.splitext(file.replace('/musan/','/musan_split/'))[0] 161 | os.makedirs(writedir) 162 | for st in range(0,len(aud)-audlen,audstr): 163 | wavfile.write(writedir+'/%05d.wav'%(st/fs),fs,aud[st:st+audlen]) 164 | 165 | print(idx,file) 
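# The loop above writes 5-second segments (audlen = 16000*5 samples) with a
# 3-second hop (audstr = 16000*3), so that short noise clips can be read
# quickly at augmentation time instead of seeking within full-length MUSAN files.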
166 | 167 | ## ========== =========== 168 | ## Main script 169 | ## ========== =========== 170 | if __name__ == "__main__": 171 | 172 | if not os.path.exists(args.save_path): 173 | raise ValueError('Target directory does not exist.') 174 | 175 | f = open('lists/fileparts.txt','r') 176 | fileparts = f.readlines() 177 | f.close() 178 | 179 | f = open('lists/files.txt','r') 180 | files = f.readlines() 181 | f.close() 182 | 183 | f = open('lists/augment.txt','r') 184 | augfiles = f.readlines() 185 | f.close() 186 | 187 | if args.augment: 188 | download(args,augfiles) 189 | part_extract(args,os.path.join(args.save_path,'rirs_noises.zip'),['RIRS_NOISES/simulated_rirs/mediumroom','RIRS_NOISES/simulated_rirs/smallroom']) 190 | full_extract(args,os.path.join(args.save_path,'musan.tar.gz')) 191 | split_musan(args) 192 | 193 | if args.download: 194 | download(args,fileparts) 195 | 196 | if args.extract: 197 | concatenate(args, files) 198 | for file in files: 199 | full_extract(args,os.path.join(args.save_path,file.split()[1])) 200 | out = subprocess.call('mv %s/dev/aac/* %s/aac/ && rm -r %s/dev' %(args.save_path,args.save_path,args.save_path), shell=True) 201 | out = subprocess.call('mv %s/wav %s/voxceleb1' %(args.save_path,args.save_path), shell=True) 202 | out = subprocess.call('mv %s/aac %s/voxceleb2' %(args.save_path,args.save_path), shell=True) 203 | 204 | if args.convert: 205 | convert(args) 206 | 207 | -------------------------------------------------------------------------------- /lists/augment.txt: -------------------------------------------------------------------------------- 1 | http://www.openslr.org/resources/28/rirs_noises.zip e6f48e257286e05de56413b4779d8ffb 2 | http://www.openslr.org/resources/17/musan.tar.gz 0c472d4fc0c5141eca47ad1ffeb2a7df -------------------------------------------------------------------------------- /lists/fileparts.txt: -------------------------------------------------------------------------------- 1 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox1_dev_wav_partaa e395d020928bc15670b570a21695ed96 2 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox1_dev_wav_partab bbfaaccefab65d82b21903e81a8a8020 3 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox1_dev_wav_partac 017d579a2a96a077f40042ec33e51512 4 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox1_dev_wav_partad 7bb1e9f70fddc7a678fa998ea8b3ba19 5 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partaa da070494c573e5c0564b1d11c3b20577 6 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partab 17fe6dab2b32b48abaf1676429cdd06f 7 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partac 1de58e086c5edf63625af1cb6d831528 8 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partad 5a043eb03e15c5a918ee6a52aad477f9 9 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partae cea401b624983e2d0b2a87fb5d59aa60 10 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partaf fc886d9ba90ab88e7880ee98effd6ae9 11 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partag d160ecc3f6ee3eed54d55349531cb42e 12 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_dev_aac_partah 6b84a81b9af72a9d9eecbb3b1f602e65 13 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox1_test_wav.zip 185fdc63c3c739954633d50379a3d102 14 | http://cnode01.mm.kaist.ac.kr/voxceleb/vox1a/vox2_test_aac.zip 0d2b3ea430a821c33263b5ea37ede312 15 | -------------------------------------------------------------------------------- /lists/files.txt: 
-------------------------------------------------------------------------------- 1 | vox1_dev_wav_parta* vox1_dev_wav.zip ae63e55b951748cc486645f532ba230b 2 | vox2_dev_aac_parta* vox2_dev_aac.zip bbc063c46078a602ca71605645c2a402 3 | -------------------------------------------------------------------------------- /loss/aamsoftmax.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | # Adapted from https://github.com/wujiyang/Face_Pytorch (Apache License) 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import time, pdb, numpy, math 9 | from utils import accuracy 10 | 11 | class LossFunction(nn.Module): 12 | def __init__(self, nOut, nClasses, margin=0.3, scale=15, easy_margin=False, **kwargs): 13 | super(LossFunction, self).__init__() 14 | 15 | self.test_normalize = True 16 | 17 | self.m = margin 18 | self.s = scale 19 | self.in_feats = nOut 20 | self.weight = torch.nn.Parameter(torch.FloatTensor(nClasses, nOut), requires_grad=True) 21 | self.ce = nn.CrossEntropyLoss() 22 | nn.init.xavier_normal_(self.weight, gain=1) 23 | 24 | self.easy_margin = easy_margin 25 | self.cos_m = math.cos(self.m) 26 | self.sin_m = math.sin(self.m) 27 | 28 | # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°] 29 | self.th = math.cos(math.pi - self.m) 30 | self.mm = math.sin(math.pi - self.m) * self.m 31 | 32 | print('Initialised AAMSoftmax margin %.3f scale %.3f'%(self.m,self.s)) 33 | 34 | def forward(self, x, label=None): 35 | 36 | assert x.size()[0] == label.size()[0] 37 | assert x.size()[1] == self.in_feats 38 | 39 | # cos(theta) 40 | cosine = F.linear(F.normalize(x), F.normalize(self.weight)) 41 | # cos(theta + m) 42 | sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1)) 43 | phi = cosine * self.cos_m - sine * self.sin_m 44 | 45 | if self.easy_margin: 46 | phi = torch.where(cosine > 0, phi, cosine) 47 | else: 48 | phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm) 49 | 50 | #one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu') 51 | one_hot = torch.zeros_like(cosine) 52 | one_hot.scatter_(1, label.view(-1, 1), 1) 53 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 54 | output = output * self.s 55 | 56 | loss = self.ce(output, label) 57 | prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0] 58 | return loss, prec1 -------------------------------------------------------------------------------- /loss/amsoftmax.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | # Adapted from https://github.com/CoinCheung/pytorch-loss (MIT License) 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import time, pdb, numpy 9 | from utils import accuracy 10 | 11 | class LossFunction(nn.Module): 12 | def __init__(self, nOut, nClasses, margin=0.3, scale=15, **kwargs): 13 | super(LossFunction, self).__init__() 14 | 15 | self.test_normalize = True 16 | 17 | self.m = margin 18 | self.s = scale 19 | self.in_feats = nOut 20 | self.W = torch.nn.Parameter(torch.randn(nOut, nClasses), requires_grad=True) 21 | self.ce = nn.CrossEntropyLoss() 22 | nn.init.xavier_normal_(self.W, gain=1) 23 | 24 | print('Initialised AMSoftmax m=%.3f s=%.3f'%(self.m,self.s)) 25 | 26 | def forward(self, x, label=None): 27 | 28 | assert x.size()[0] == label.size()[0] 29 | assert x.size()[1] == self.in_feats 30 | 31 | x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12) 32 | x_norm = torch.div(x, x_norm) 33 | w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12) 34 | w_norm = torch.div(self.W, w_norm) 35 | costh = torch.mm(x_norm, w_norm) 36 | label_view = label.view(-1, 1) 37 | if label_view.is_cuda: label_view = label_view.cpu() 38 | delt_costh = torch.zeros(costh.size()).scatter_(1, label_view, self.m) 39 | if x.is_cuda: delt_costh = delt_costh.cuda() 40 | costh_m = costh - delt_costh 41 | costh_m_s = self.s * costh_m 42 | loss = self.ce(costh_m_s, label) 43 | prec1 = accuracy(costh_m_s.detach(), label.detach(), topk=(1,))[0] 44 | return loss, prec1 45 | 46 | -------------------------------------------------------------------------------- /loss/angleproto.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import time, pdb, numpy 8 | from utils import accuracy 9 | 10 | class LossFunction(nn.Module): 11 | 12 | def __init__(self, init_w=10.0, init_b=-5.0, **kwargs): 13 | super(LossFunction, self).__init__() 14 | 15 | self.test_normalize = True 16 | 17 | self.w = nn.Parameter(torch.tensor(init_w)) 18 | self.b = nn.Parameter(torch.tensor(init_b)) 19 | self.criterion = torch.nn.CrossEntropyLoss() 20 | 21 | print('Initialised AngleProto') 22 | 23 | def forward(self, x, label=None): 24 | 25 | assert x.size()[1] >= 2 26 | 27 | out_anchor = torch.mean(x[:,1:,:],1) 28 | out_positive = x[:,0,:] 29 | stepsize = out_anchor.size()[0] 30 | 31 | cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2)) 32 | torch.clamp(self.w, 1e-6) 33 | cos_sim_matrix = cos_sim_matrix * self.w + self.b 34 | 35 | label = torch.from_numpy(numpy.asarray(range(0,stepsize))).cuda() 36 | nloss = self.criterion(cos_sim_matrix, label) 37 | prec1 = accuracy(cos_sim_matrix.detach(), label.detach(), topk=(1,))[0] 38 | 39 | return nloss, prec1 -------------------------------------------------------------------------------- /loss/ge2e.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ## Fast re-implementation of the GE2E loss (https://arxiv.org/abs/1710.10467) 4 | ## Numerically checked against https://github.com/cvqluu/GE2E-Loss 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import time, pdb, numpy 10 | from utils import accuracy 11 | 12 | class LossFunction(nn.Module): 13 | 14 | def __init__(self, init_w=10.0, init_b=-5.0, **kwargs): 15 | super(LossFunction, self).__init__() 16 | 17 | self.test_normalize = True 18 | 19 | self.w = nn.Parameter(torch.tensor(init_w)) 20 | self.b = nn.Parameter(torch.tensor(init_b)) 21 | self.criterion = torch.nn.CrossEntropyLoss() 22 | 23 | print('Initialised GE2E') 24 | 25 | def forward(self, x, label=None): 26 | 27 | assert x.size()[1] >= 2 28 | 29 | gsize = x.size()[1] 30 | centroids = torch.mean(x, 1) 31 | stepsize = x.size()[0] 32 | 33 | cos_sim_matrix = [] 34 | 35 | for ii in range(0,gsize): 36 | idx = [*range(0,gsize)] 37 | idx.remove(ii) 38 | exc_centroids = torch.mean(x[:,idx,:], 1) 39 | cos_sim_diag = F.cosine_similarity(x[:,ii,:],exc_centroids) 40 | cos_sim = F.cosine_similarity(x[:,ii,:].unsqueeze(-1),centroids.unsqueeze(-1).transpose(0,2)) 41 | cos_sim[range(0,stepsize),range(0,stepsize)] = cos_sim_diag 42 | cos_sim_matrix.append(torch.clamp(cos_sim,1e-6)) 43 | 44 | cos_sim_matrix = torch.stack(cos_sim_matrix,dim=1) 45 | 46 | torch.clamp(self.w, 1e-6) 47 | cos_sim_matrix = cos_sim_matrix * self.w + self.b 48 | 49 | label = torch.from_numpy(numpy.asarray(range(0,stepsize))).cuda() 50 | nloss = self.criterion(cos_sim_matrix.view(-1,stepsize), torch.repeat_interleave(label,repeats=gsize,dim=0).cuda()) 51 | prec1 = accuracy(cos_sim_matrix.view(-1,stepsize).detach(), torch.repeat_interleave(label,repeats=gsize,dim=0).detach(), topk=(1,))[0] 52 | 53 | return nloss, prec1 -------------------------------------------------------------------------------- /loss/proto.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ## Re-implementation of prototypical networks (https://arxiv.org/abs/1703.05175). 4 | ## Numerically checked against https://github.com/cyvius96/prototypical-network-pytorch 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import time, pdb, numpy 10 | from utils import accuracy 11 | 12 | class LossFunction(nn.Module): 13 | 14 | def __init__(self, **kwargs): 15 | super(LossFunction, self).__init__() 16 | 17 | self.test_normalize = False 18 | 19 | self.criterion = torch.nn.CrossEntropyLoss() 20 | 21 | print('Initialised Prototypical Loss') 22 | 23 | def forward(self, x, label=None): 24 | 25 | assert x.size()[1] >= 2 26 | 27 | out_anchor = torch.mean(x[:,1:,:],1) 28 | out_positive = x[:,0,:] 29 | stepsize = out_anchor.size()[0] 30 | 31 | output = -1 * (F.pairwise_distance(out_positive.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2))**2) 32 | label = torch.from_numpy(numpy.asarray(range(0,stepsize))).cuda() 33 | nloss = self.criterion(output, label) 34 | prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0] 35 | 36 | return nloss, prec1 -------------------------------------------------------------------------------- /loss/softmax.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import time, pdb, numpy 8 | from utils import accuracy 9 | 10 | class LossFunction(nn.Module): 11 | def __init__(self, nOut, nClasses, **kwargs): 12 | super(LossFunction, self).__init__() 13 | 14 | self.test_normalize = True 15 | 16 | self.criterion = torch.nn.CrossEntropyLoss() 17 | self.fc = nn.Linear(nOut,nClasses) 18 | 19 | print('Initialised Softmax Loss') 20 | 21 | def forward(self, x, label=None): 22 | 23 | x = self.fc(x) 24 | nloss = self.criterion(x, label) 25 | prec1 = accuracy(x.detach(), label.detach(), topk=(1,))[0] 26 | 27 | return nloss, prec1 -------------------------------------------------------------------------------- /loss/softmaxproto.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import loss.softmax as softmax 7 | import loss.angleproto as angleproto 8 | 9 | class LossFunction(nn.Module): 10 | 11 | def __init__(self, **kwargs): 12 | super(LossFunction, self).__init__() 13 | 14 | self.test_normalize = True 15 | 16 | self.softmax = softmax.LossFunction(**kwargs) 17 | self.angleproto = angleproto.LossFunction(**kwargs) 18 | 19 | print('Initialised SoftmaxPrototypical Loss') 20 | 21 | def forward(self, x, label=None): 22 | 23 | assert x.size()[1] == 2 24 | 25 | nlossS, prec1 = self.softmax(x.reshape(-1,x.size()[-1]), label.repeat_interleave(2)) 26 | 27 | nlossP, _ = self.angleproto(x,None) 28 | 29 | return nlossS+nlossP, prec1 -------------------------------------------------------------------------------- /loss/triplet.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import time, pdb, numpy 8 | from tuneThreshold import tuneThresholdfromScore 9 | import random 10 | 11 | class LossFunction(nn.Module): 12 | 13 | def __init__(self, hard_rank=0, hard_prob=0, margin=0, **kwargs): 14 | super(LossFunction, self).__init__() 15 | 16 | self.test_normalize = True 17 | 18 | self.hard_rank = hard_rank 19 | self.hard_prob = hard_prob 20 | self.margin = margin 21 | 22 | print('Initialised Triplet Loss') 23 | 24 | def forward(self, x, label=None): 25 | 26 | assert x.size()[1] == 2 27 | 28 | out_anchor = F.normalize(x[:,0,:], p=2, dim=1) 29 | out_positive = F.normalize(x[:,1,:], p=2, dim=1) 30 | stepsize = out_anchor.size()[0] 31 | 32 | output = -1 * (F.pairwise_distance(out_anchor.unsqueeze(-1),out_positive.unsqueeze(-1).transpose(0,2))**2) 33 | 34 | negidx = self.mineHardNegative(output.detach()) 35 | 36 | out_negative = out_positive[negidx,:] 37 | 38 | labelnp = numpy.array([1]*len(out_positive)+[0]*len(out_negative)) 39 | 40 | ## calculate distances 41 | pos_dist = F.pairwise_distance(out_anchor,out_positive) 42 | neg_dist = F.pairwise_distance(out_anchor,out_negative) 43 | 44 | ## loss function 45 | nloss = torch.mean(F.relu(torch.pow(pos_dist, 2) - torch.pow(neg_dist, 2) + self.margin)) 46 | 47 | scores = -1 * torch.cat([pos_dist,neg_dist],dim=0).detach().cpu().numpy() 48 | 49 | errors = tuneThresholdfromScore(scores, labelnp, []); 50 | 51 | return nloss, errors[1] 52 | 53 | ## ===== ===== ===== ===== ===== ===== ===== ===== 54 | ## Hard negative mining 55 | ## ===== ===== ===== ===== ===== ===== ===== ===== 56 | 57 | def mineHardNegative(self, output): 58 | 59 | negidx = [] 60 | 61 | for idx, similarity in enumerate(output): 62 | 63 | simval, simidx = torch.sort(similarity,descending=True) 64 | 65 | if self.hard_rank < 0: 66 | 67 | ## Semi hard negative mining 68 | 69 | semihardidx = simidx[(similarity[idx] - self.margin < simval) & (simval < similarity[idx])] 70 | 71 | if len(semihardidx) == 0: 72 | negidx.append(random.choice(simidx)) 73 | else: 74 | negidx.append(random.choice(semihardidx)) 75 | 76 | else: 77 | 78 | ## Rank based negative mining 79 | 80 | simidx = simidx[simidx!=idx] 81 | 82 | if random.random() < self.hard_prob: 83 | negidx.append(simidx[random.randint(0, self.hard_rank)]) 84 | else: 85 | negidx.append(random.choice(simidx)) 86 | 87 | return negidx -------------------------------------------------------------------------------- /models/RawNet3.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from asteroid_filterbanks import Encoder, ParamSincFB 6 | 7 | from models.RawNetBasicBlock import Bottle2neck, PreEmphasis 8 | 9 | 10 | class RawNet3(nn.Module): 11 | def __init__(self, block, model_scale, context, summed, C=1024, **kwargs): 12 | super().__init__() 13 | 14 | nOut = kwargs["nOut"] 15 | 16 | self.context = context 17 | self.encoder_type = kwargs["encoder_type"] 18 | self.log_sinc = kwargs["log_sinc"] 19 | self.norm_sinc = kwargs["norm_sinc"] 20 | self.out_bn = kwargs["out_bn"] 21 | self.summed = summed 22 | 23 | self.preprocess = nn.Sequential( 24 | PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True) 25 | ) 26 | self.conv1 = Encoder( 27 | ParamSincFB( 28 | C // 4, 29 | 251, 30 | stride=kwargs["sinc_stride"], 31 | ) 32 | ) 33 | self.relu = nn.ReLU() 34 | self.bn1 = nn.BatchNorm1d(C // 4) 
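# The three Bottle2neck blocks defined below form the Res2Net-style trunk; in forward(), x1 is max-pooled to match the frame rate of x2 and x3, the three feature maps are concatenated (3 * C channels) and projected to 1536 channels by layer4.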
35 | 36 | self.layer1 = block( 37 | C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5 38 | ) 39 | self.layer2 = block( 40 | C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3 41 | ) 42 | self.layer3 = block(C, C, kernel_size=3, dilation=4, scale=model_scale) 43 | self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1) 44 | 45 | if self.context: 46 | attn_input = 1536 * 3 47 | else: 48 | attn_input = 1536 49 | print("self.encoder_type", self.encoder_type) 50 | if self.encoder_type == "ECA": 51 | attn_output = 1536 52 | elif self.encoder_type == "ASP": 53 | attn_output = 1 54 | else: 55 | raise ValueError("Undefined encoder") 56 | 57 | self.attention = nn.Sequential( 58 | nn.Conv1d(attn_input, 128, kernel_size=1), 59 | nn.ReLU(), 60 | nn.BatchNorm1d(128), 61 | nn.Conv1d(128, attn_output, kernel_size=1), 62 | nn.Softmax(dim=2), 63 | ) 64 | 65 | self.bn5 = nn.BatchNorm1d(3072) 66 | 67 | self.fc6 = nn.Linear(3072, nOut) 68 | self.bn6 = nn.BatchNorm1d(nOut) 69 | 70 | self.mp3 = nn.MaxPool1d(3) 71 | 72 | def forward(self, x): 73 | """ 74 | :param x: input mini-batch (bs, samp) 75 | """ 76 | 77 | with torch.cuda.amp.autocast(enabled=False): 78 | x = self.preprocess(x) 79 | x = torch.abs(self.conv1(x)) 80 | if self.log_sinc: 81 | x = torch.log(x + 1e-6) 82 | if self.norm_sinc == "mean": 83 | x = x - torch.mean(x, dim=-1, keepdim=True) 84 | elif self.norm_sinc == "mean_std": 85 | m = torch.mean(x, dim=-1, keepdim=True) 86 | s = torch.std(x, dim=-1, keepdim=True) 87 | s[s < 0.001] = 0.001 88 | x = (x - m) / s 89 | 90 | if self.summed: 91 | x1 = self.layer1(x) 92 | x2 = self.layer2(x1) 93 | x3 = self.layer3(self.mp3(x1) + x2) 94 | else: 95 | x1 = self.layer1(x) 96 | x2 = self.layer2(x1) 97 | x3 = self.layer3(x2) 98 | 99 | x = self.layer4(torch.cat((self.mp3(x1), x2, x3), dim=1)) 100 | x = self.relu(x) 101 | 102 | t = x.size()[-1] 103 | 104 | if self.context: 105 | global_x = torch.cat( 106 | ( 107 | x, 108 | torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t), 109 | torch.sqrt( 110 | torch.var(x, dim=2, keepdim=True).clamp( 111 | min=1e-4, max=1e4 112 | ) 113 | ).repeat(1, 1, t), 114 | ), 115 | dim=1, 116 | ) 117 | else: 118 | global_x = x 119 | 120 | w = self.attention(global_x) 121 | 122 | mu = torch.sum(x * w, dim=2) 123 | sg = torch.sqrt( 124 | (torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4) 125 | ) 126 | 127 | x = torch.cat((mu, sg), 1) 128 | 129 | x = self.bn5(x) 130 | 131 | x = self.fc6(x) 132 | 133 | if self.out_bn: 134 | x = self.bn6(x) 135 | 136 | return x 137 | 138 | 139 | def MainModel(**kwargs): 140 | 141 | model = RawNet3( 142 | Bottle2neck, model_scale=8, context=True, summed=True, out_bn=False, log_sinc=True, norm_sinc="mean", grad_mult=1, **kwargs 143 | ) 144 | return model 145 | -------------------------------------------------------------------------------- /models/RawNetBasicBlock.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class PreEmphasis(torch.nn.Module): 9 | def __init__(self, coef: float = 0.97) -> None: 10 | super().__init__() 11 | self.coef = coef 12 | # make kernel 13 | # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped. 
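# With the flipped kernel [-coef, 1.0] and the one-sample reflect pad applied in forward(), conv1d outputs y[t] = x[t] - coef * x[t-1].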
14 | self.register_buffer( 15 | "flipped_filter", 16 | torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0), 17 | ) 18 | 19 | def forward(self, input: torch.tensor) -> torch.tensor: 20 | assert ( 21 | len(input.size()) == 2 22 | ), "The number of dimensions of input tensor must be 2!" 23 | # reflect padding to match lengths of in/out 24 | input = input.unsqueeze(1) 25 | input = F.pad(input, (1, 0), "reflect") 26 | return F.conv1d(input, self.flipped_filter) 27 | 28 | 29 | class AFMS(nn.Module): 30 | """ 31 | Alpha-Feature map scaling, added to the output of each residual block[1,2]. 32 | 33 | Reference: 34 | [1] RawNet2 : https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf 35 | [2] AMFS : https://www.koreascience.or.kr/article/JAKO202029757857763.page 36 | """ 37 | 38 | def __init__(self, nb_dim: int) -> None: 39 | super().__init__() 40 | self.alpha = nn.Parameter(torch.ones((nb_dim, 1))) 41 | self.fc = nn.Linear(nb_dim, nb_dim) 42 | self.sig = nn.Sigmoid() 43 | 44 | def forward(self, x): 45 | y = F.adaptive_avg_pool1d(x, 1).view(x.size(0), -1) 46 | y = self.sig(self.fc(y)).view(x.size(0), x.size(1), -1) 47 | 48 | x = x + self.alpha 49 | x = x * y 50 | return x 51 | 52 | 53 | class Bottle2neck(nn.Module): 54 | def __init__( 55 | self, 56 | inplanes, 57 | planes, 58 | kernel_size=None, 59 | dilation=None, 60 | scale=4, 61 | pool=False, 62 | ): 63 | 64 | super().__init__() 65 | 66 | width = int(math.floor(planes / scale)) 67 | 68 | self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1) 69 | self.bn1 = nn.BatchNorm1d(width * scale) 70 | 71 | self.nums = scale - 1 72 | 73 | convs = [] 74 | bns = [] 75 | 76 | num_pad = math.floor(kernel_size / 2) * dilation 77 | 78 | for i in range(self.nums): 79 | convs.append( 80 | nn.Conv1d( 81 | width, 82 | width, 83 | kernel_size=kernel_size, 84 | dilation=dilation, 85 | padding=num_pad, 86 | ) 87 | ) 88 | bns.append(nn.BatchNorm1d(width)) 89 | 90 | self.convs = nn.ModuleList(convs) 91 | self.bns = nn.ModuleList(bns) 92 | 93 | self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1) 94 | self.bn3 = nn.BatchNorm1d(planes) 95 | 96 | self.relu = nn.ReLU() 97 | 98 | self.width = width 99 | 100 | self.mp = nn.MaxPool1d(pool) if pool else False 101 | self.afms = AFMS(planes) 102 | 103 | if inplanes != planes: # if change in number of filters 104 | self.residual = nn.Sequential( 105 | nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False) 106 | ) 107 | else: 108 | self.residual = nn.Identity() 109 | 110 | def forward(self, x): 111 | residual = self.residual(x) 112 | 113 | out = self.conv1(x) 114 | out = self.relu(out) 115 | out = self.bn1(out) 116 | 117 | spx = torch.split(out, self.width, 1) 118 | for i in range(self.nums): 119 | if i == 0: 120 | sp = spx[i] 121 | else: 122 | sp = sp + spx[i] 123 | sp = self.convs[i](sp) 124 | sp = self.relu(sp) 125 | sp = self.bns[i](sp) 126 | if i == 0: 127 | out = sp 128 | else: 129 | out = torch.cat((out, sp), 1) 130 | 131 | out = torch.cat((out, spx[self.nums]), 1) 132 | 133 | out = self.conv3(out) 134 | out = self.relu(out) 135 | out = self.bn3(out) 136 | 137 | out += residual 138 | if self.mp: 139 | out = self.mp(out) 140 | out = self.afms(out) 141 | 142 | return out 143 | -------------------------------------------------------------------------------- /models/ResNetBlocks.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | class SEBasicBlock(nn.Module): 8 | expansion = 1 9 | 10 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): 11 | super(SEBasicBlock, self).__init__() 12 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(planes) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(planes) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.se = SELayer(planes, reduction) 18 | self.downsample = downsample 19 | self.stride = stride 20 | 21 | def forward(self, x): 22 | residual = x 23 | 24 | out = self.conv1(x) 25 | out = self.relu(out) 26 | out = self.bn1(out) 27 | 28 | out = self.conv2(out) 29 | out = self.bn2(out) 30 | out = self.se(out) 31 | 32 | if self.downsample is not None: 33 | residual = self.downsample(x) 34 | 35 | out += residual 36 | out = self.relu(out) 37 | return out 38 | 39 | 40 | class SEBottleneck(nn.Module): 41 | expansion = 4 42 | 43 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): 44 | super(SEBottleneck, self).__init__() 45 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 48 | padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * 4) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.se = SELayer(planes * 4, reduction) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | residual = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv3(out) 69 | out = self.bn3(out) 70 | out = self.se(out) 71 | 72 | if self.downsample is not None: 73 | residual = self.downsample(x) 74 | 75 | out += residual 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class SELayer(nn.Module): 82 | def __init__(self, channel, reduction=8): 83 | super(SELayer, self).__init__() 84 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 85 | self.fc = nn.Sequential( 86 | nn.Linear(channel, channel // reduction), 87 | nn.ReLU(inplace=True), 88 | nn.Linear(channel // reduction, channel), 89 | nn.Sigmoid() 90 | ) 91 | 92 | def forward(self, x): 93 | b, c, _, _ = x.size() 94 | y = self.avg_pool(x).view(b, c) 95 | y = self.fc(y).view(b, c, 1, 1) 96 | return x * y -------------------------------------------------------------------------------- /models/ResNetSE34L.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torchaudio 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn import Parameter 9 | from models.ResNetBlocks import * 10 | 11 | class ResNetSE(nn.Module): 12 | def __init__(self, block, layers, num_filters, nOut, encoder_type='SAP', n_mels=40, log_input=True, **kwargs): 13 | super(ResNetSE, self).__init__() 14 | 15 | print('Embedding size is %d, encoder %s.'%(nOut, encoder_type)) 16 | 17 | self.inplanes = num_filters[0] 18 | self.encoder_type = encoder_type 19 | self.n_mels = n_mels 20 | self.log_input = log_input 21 | 22 | self.conv1 = nn.Conv2d(1, num_filters[0] , kernel_size=7, stride=(2, 1), padding=3, 23 | bias=False) 24 | self.bn1 = nn.BatchNorm2d(num_filters[0]) 25 | self.relu = nn.ReLU(inplace=True) 26 | 27 | self.layer1 = self._make_layer(block, num_filters[0], layers[0]) 28 | self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2)) 29 | self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2)) 30 | self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(1, 1)) 31 | 32 | self.instancenorm = nn.InstanceNorm1d(n_mels) 33 | self.torchfb = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400, hop_length=160, window_fn=torch.hamming_window, n_mels=n_mels) 34 | 35 | if self.encoder_type == "SAP": 36 | self.sap_linear = nn.Linear(num_filters[3] * block.expansion, num_filters[3] * block.expansion) 37 | self.attention = self.new_parameter(num_filters[3] * block.expansion, 1) 38 | out_dim = num_filters[3] * block.expansion 39 | elif self.encoder_type == "ASP": 40 | self.sap_linear = nn.Linear(num_filters[3] * block.expansion, num_filters[3] * block.expansion) 41 | self.attention = self.new_parameter(num_filters[3] * block.expansion, 1) 42 | out_dim = num_filters[3] * block.expansion * 2 43 | else: 44 | raise ValueError('Undefined encoder') 45 | 46 | self.fc = nn.Linear(out_dim, nOut) 47 | 48 | for m in self.modules(): 49 | if isinstance(m, nn.Conv2d): 50 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 51 | elif isinstance(m, nn.BatchNorm2d): 52 | nn.init.constant_(m.weight, 1) 53 | nn.init.constant_(m.bias, 0) 54 | 55 | def _make_layer(self, block, planes, blocks, stride=1): 56 | downsample = None 57 | if stride != 1 or self.inplanes != planes * block.expansion: 58 | downsample = nn.Sequential( 59 | nn.Conv2d(self.inplanes, planes * block.expansion, 60 | kernel_size=1, stride=stride, bias=False), 61 | nn.BatchNorm2d(planes * block.expansion), 62 | ) 63 | 64 | layers = [] 65 | layers.append(block(self.inplanes, planes, stride, downsample)) 66 | self.inplanes = planes * block.expansion 67 | for i in range(1, blocks): 68 | layers.append(block(self.inplanes, planes)) 69 | 70 | return nn.Sequential(*layers) 71 | 72 | def new_parameter(self, *size): 73 | out = nn.Parameter(torch.FloatTensor(*size)) 74 | nn.init.xavier_normal_(out) 75 | return out 76 | 77 | def forward(self, x): 78 | 79 | with torch.no_grad(): 80 | with torch.cuda.amp.autocast(enabled=False): 81 | x = self.torchfb(x)+1e-6 82 | if self.log_input: x = x.log() 83 | x = self.instancenorm(x).unsqueeze(1).detach() 84 | 85 | x = self.conv1(x) 86 | x = self.bn1(x) 87 | x = self.relu(x) 88 | 89 | x = self.layer1(x) 90 | x = self.layer2(x) 91 | x = self.layer3(x) 92 | x = self.layer4(x) 93 | 94 | x = torch.mean(x, dim=2, keepdim=True) 95 | 96 | if self.encoder_type == "SAP": 97 | x = x.permute(0,3,1,2).squeeze(-1) 98 | h = 
torch.tanh(self.sap_linear(x)) 99 | w = torch.matmul(h, self.attention).squeeze(dim=2) 100 | w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1) 101 | x = torch.sum(x * w, dim=1) 102 | elif self.encoder_type == "ASP": 103 | x = x.permute(0,3,1,2).squeeze(-1) 104 | h = torch.tanh(self.sap_linear(x)) 105 | w = torch.matmul(h, self.attention).squeeze(dim=2) 106 | w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1) 107 | mu = torch.sum(x * w, dim=1) 108 | rh = torch.sqrt( ( torch.sum((x**2) * w, dim=1) - mu**2 ).clamp(min=1e-5) ) 109 | x = torch.cat((mu,rh),1) 110 | 111 | x = x.view(x.size()[0], -1) 112 | x = self.fc(x) 113 | 114 | return x 115 | 116 | 117 | def MainModel(nOut=256, **kwargs): 118 | # Number of filters 119 | num_filters = [16, 32, 64, 128] 120 | model = ResNetSE(SEBasicBlock, [3, 4, 6, 3], num_filters, nOut, **kwargs) 121 | return model 122 | -------------------------------------------------------------------------------- /models/ResNetSE34V2.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torchaudio 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn import Parameter 9 | from models.ResNetBlocks import * 10 | from utils import PreEmphasis 11 | 12 | class ResNetSE(nn.Module): 13 | def __init__(self, block, layers, num_filters, nOut, encoder_type='SAP', n_mels=40, log_input=True, **kwargs): 14 | super(ResNetSE, self).__init__() 15 | 16 | print('Embedding size is %d, encoder %s.'%(nOut, encoder_type)) 17 | 18 | self.inplanes = num_filters[0] 19 | self.encoder_type = encoder_type 20 | self.n_mels = n_mels 21 | self.log_input = log_input 22 | 23 | self.conv1 = nn.Conv2d(1, num_filters[0] , kernel_size=3, stride=1, padding=1) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.bn1 = nn.BatchNorm2d(num_filters[0]) 26 | 27 | 28 | self.layer1 = self._make_layer(block, num_filters[0], layers[0]) 29 | self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2)) 30 | self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2)) 31 | self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(2, 2)) 32 | 33 | self.instancenorm = nn.InstanceNorm1d(n_mels) 34 | self.torchfb = torch.nn.Sequential( 35 | PreEmphasis(), 36 | torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400, hop_length=160, window_fn=torch.hamming_window, n_mels=n_mels) 37 | ) 38 | 39 | outmap_size = int(self.n_mels/8) 40 | 41 | self.attention = nn.Sequential( 42 | nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), 43 | nn.ReLU(), 44 | nn.BatchNorm1d(128), 45 | nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), 46 | nn.Softmax(dim=2), 47 | ) 48 | 49 | if self.encoder_type == "SAP": 50 | out_dim = num_filters[3] * outmap_size 51 | elif self.encoder_type == "ASP": 52 | out_dim = num_filters[3] * outmap_size * 2 53 | else: 54 | raise ValueError('Undefined encoder') 55 | 56 | self.fc = nn.Linear(out_dim, nOut) 57 | 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 61 | elif isinstance(m, nn.BatchNorm2d): 62 | nn.init.constant_(m.weight, 1) 63 | nn.init.constant_(m.bias, 0) 64 | 65 | def _make_layer(self, block, planes, blocks, stride=1): 66 | downsample = None 67 | if stride != 1 or self.inplanes != planes * block.expansion: 68 | downsample = nn.Sequential( 69 | nn.Conv2d(self.inplanes, planes * 
block.expansion, 70 | kernel_size=1, stride=stride, bias=False), 71 | nn.BatchNorm2d(planes * block.expansion), 72 | ) 73 | 74 | layers = [] 75 | layers.append(block(self.inplanes, planes, stride, downsample)) 76 | self.inplanes = planes * block.expansion 77 | for i in range(1, blocks): 78 | layers.append(block(self.inplanes, planes)) 79 | 80 | return nn.Sequential(*layers) 81 | 82 | def new_parameter(self, *size): 83 | out = nn.Parameter(torch.FloatTensor(*size)) 84 | nn.init.xavier_normal_(out) 85 | return out 86 | 87 | def forward(self, x): 88 | 89 | with torch.no_grad(): 90 | with torch.cuda.amp.autocast(enabled=False): 91 | x = self.torchfb(x)+1e-6 92 | if self.log_input: x = x.log() 93 | x = self.instancenorm(x).unsqueeze(1) 94 | 95 | x = self.conv1(x) 96 | x = self.relu(x) 97 | x = self.bn1(x) 98 | 99 | x = self.layer1(x) 100 | x = self.layer2(x) 101 | x = self.layer3(x) 102 | x = self.layer4(x) 103 | 104 | x = x.reshape(x.size()[0],-1,x.size()[-1]) 105 | 106 | w = self.attention(x) 107 | 108 | if self.encoder_type == "SAP": 109 | x = torch.sum(x * w, dim=2) 110 | elif self.encoder_type == "ASP": 111 | mu = torch.sum(x * w, dim=2) 112 | sg = torch.sqrt( ( torch.sum((x**2) * w, dim=2) - mu**2 ).clamp(min=1e-5) ) 113 | x = torch.cat((mu,sg),1) 114 | 115 | x = x.view(x.size()[0], -1) 116 | x = self.fc(x) 117 | 118 | return x 119 | 120 | 121 | def MainModel(nOut=256, **kwargs): 122 | # Number of filters 123 | num_filters = [32, 64, 128, 256] 124 | model = ResNetSE(SEBasicBlock, [3, 4, 6, 3], num_filters, nOut, **kwargs) 125 | return model 126 | 127 | -------------------------------------------------------------------------------- /models/VGGVox.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torchaudio 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn import Parameter 9 | 10 | class MainModel(nn.Module): 11 | def __init__(self, nOut = 1024, encoder_type='SAP', log_input=True, **kwargs): 12 | super(MainModel, self).__init__(); 13 | 14 | print('Embedding size is %d, encoder %s.'%(nOut, encoder_type)) 15 | 16 | self.encoder_type = encoder_type 17 | self.log_input = log_input 18 | 19 | self.netcnn = nn.Sequential( 20 | nn.Conv2d(1, 96, kernel_size=(5,7), stride=(1,2), padding=(2,2)), 21 | nn.BatchNorm2d(96), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(kernel_size=(1,3), stride=(1,2)), 24 | 25 | nn.Conv2d(96, 256, kernel_size=(5,5), stride=(2,2), padding=(1,1)), 26 | nn.BatchNorm2d(256), 27 | nn.ReLU(inplace=True), 28 | nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)), 29 | 30 | nn.Conv2d(256, 384, kernel_size=(3,3), padding=(1,1)), 31 | nn.BatchNorm2d(384), 32 | nn.ReLU(inplace=True), 33 | 34 | nn.Conv2d(384, 256, kernel_size=(3,3), padding=(1,1)), 35 | nn.BatchNorm2d(256), 36 | nn.ReLU(inplace=True), 37 | 38 | nn.Conv2d(256, 256, kernel_size=(3,3), padding=(1,1)), 39 | nn.BatchNorm2d(256), 40 | nn.ReLU(inplace=True), 41 | nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)), 42 | 43 | nn.Conv2d(256, 512, kernel_size=(4,1), padding=(0,0)), 44 | nn.BatchNorm2d(512), 45 | nn.ReLU(inplace=True), 46 | 47 | ); 48 | 49 | if self.encoder_type == "MAX": 50 | self.encoder = nn.AdaptiveMaxPool2d((1,1)) 51 | out_dim = 512 52 | elif self.encoder_type == "TAP": 53 | self.encoder = nn.AdaptiveAvgPool2d((1,1)) 54 | out_dim = 512 55 | elif self.encoder_type == "SAP": 56 | self.sap_linear = nn.Linear(512, 512) 57 | self.attention = self.new_parameter(512, 1) 58 | 
out_dim = 512 59 | else: 60 | raise ValueError('Undefined encoder') 61 | 62 | self.fc = nn.Linear(out_dim, nOut) 63 | 64 | self.instancenorm = nn.InstanceNorm1d(40) 65 | self.torchfb = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400, hop_length=160, f_min=0.0, f_max=8000, pad=0, n_mels=40) 66 | 67 | def new_parameter(self, *size): 68 | out = nn.Parameter(torch.FloatTensor(*size)) 69 | nn.init.xavier_normal_(out) 70 | return out 71 | 72 | def forward(self, x): 73 | 74 | with torch.no_grad(): 75 | with torch.cuda.amp.autocast(enabled=False): 76 | x = self.torchfb(x)+1e-6 77 | if self.log_input: x = x.log() 78 | x = self.instancenorm(x).unsqueeze(1) 79 | 80 | x = self.netcnn(x); 81 | 82 | if self.encoder_type == "MAX" or self.encoder_type == "TAP": 83 | x = self.encoder(x) 84 | x = x.view((x.size()[0], -1)) 85 | 86 | elif self.encoder_type == "SAP": 87 | x = x.permute(0, 2, 1, 3) 88 | x = x.squeeze(dim=1).permute(0, 2, 1) # batch * L * D 89 | h = torch.tanh(self.sap_linear(x)) 90 | w = torch.matmul(h, self.attention).squeeze(dim=2) 91 | w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1) 92 | x = torch.sum(x * w, dim=1) 93 | 94 | x = self.fc(x); 95 | 96 | return x; 97 | 98 | -------------------------------------------------------------------------------- /optimizer/adam.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | 6 | def Optimizer(parameters, lr, weight_decay, **kwargs): 7 | 8 | print('Initialised Adam optimizer') 9 | 10 | return torch.optim.Adam(parameters, lr = lr, weight_decay = weight_decay); 11 | -------------------------------------------------------------------------------- /optimizer/sgd.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | 6 | def Optimizer(parameters, lr, weight_decay, **kwargs): 7 | 8 | print('Initialised SGD optimizer') 9 | 10 | return torch.optim.SGD(parameters, lr = lr, momentum = 0.9, weight_decay=weight_decay); 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.0 2 | torchaudio>=0.7.0 3 | asteroid_filterbanks==0.4.0 4 | numpy 5 | scipy 6 | scikit-learn 7 | tqdm 8 | pyyaml 9 | soundfile 10 | -------------------------------------------------------------------------------- /scheduler/steplr.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | 6 | def Scheduler(optimizer, test_interval, max_epoch, lr_decay, **kwargs): 7 | 8 | sche_fn = torch.optim.lr_scheduler.StepLR(optimizer, step_size=test_interval, gamma=lr_decay) 9 | 10 | lr_step = 'epoch' 11 | 12 | print('Initialised step LR scheduler') 13 | 14 | return sche_fn, lr_step 15 | -------------------------------------------------------------------------------- /trainSpeakerNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | import sys, time, os, argparse 5 | import yaml 6 | import numpy 7 | import torch 8 | import glob 9 | import zipfile 10 | import warnings 11 | import datetime 12 | from tuneThreshold import * 13 | from SpeakerNet import * 14 | from DatasetLoader import * 15 | import torch.distributed as dist 16 | import torch.multiprocessing as mp 17 | warnings.simplefilter("ignore") 18 | 19 | ## ===== ===== ===== ===== ===== ===== ===== ===== 20 | ## Parse arguments 21 | ## ===== ===== ===== ===== ===== ===== ===== ===== 22 | 23 | parser = argparse.ArgumentParser(description = "SpeakerNet") 24 | 25 | parser.add_argument('--config', type=str, default=None, help='Config YAML file') 26 | 27 | ## Data loader 28 | parser.add_argument('--max_frames', type=int, default=200, help='Input length to the network for training') 29 | parser.add_argument('--eval_frames', type=int, default=300, help='Input length to the network for testing 0 uses the whole files') 30 | parser.add_argument('--batch_size', type=int, default=200, help='Batch size, number of speakers per batch') 31 | parser.add_argument('--max_seg_per_spk', type=int, default=500, help='Maximum number of utterances per speaker per epoch') 32 | parser.add_argument('--nDataLoaderThread', type=int, default=5, help='Number of loader threads') 33 | parser.add_argument('--augment', type=bool, default=False, help='Augment input') 34 | parser.add_argument('--seed', type=int, default=10, help='Seed for the random number generator') 35 | 36 | ## Training details 37 | parser.add_argument('--test_interval', type=int, default=10, help='Test and save every [test_interval] epochs') 38 | parser.add_argument('--max_epoch', type=int, default=500, help='Maximum number of epochs') 39 | parser.add_argument('--trainfunc', type=str, default="", help='Loss function') 40 | 41 | ## Optimizer 42 | parser.add_argument('--optimizer', type=str, default="adam", help='sgd or adam') 43 | parser.add_argument('--scheduler', type=str, default="steplr", help='Learning rate scheduler') 44 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') 45 | parser.add_argument("--lr_decay", type=float, default=0.95, help='Learning rate decay every [test_interval] epochs') 46 | parser.add_argument('--weight_decay', type=float, default=0, help='Weight decay in the optimizer') 47 | 48 | ## Loss functions 49 | parser.add_argument("--hard_prob", type=float, default=0.5, help='Hard negative mining probability, otherwise random, only for some loss functions') 50 | parser.add_argument("--hard_rank", type=int, default=10, help='Hard negative mining rank in the batch, only for some loss functions') 51 | parser.add_argument('--margin', type=float, default=0.1, help='Loss margin, only for some loss functions') 52 | parser.add_argument('--scale', type=float, default=30, help='Loss scale, only for some loss functions') 53 | parser.add_argument('--nPerSpeaker', type=int, default=1, help='Number of 
utterances per speaker per batch, only for metric learning based losses') 54 | parser.add_argument('--nClasses', type=int, default=5994, help='Number of speakers in the softmax layer, only for softmax-based losses') 55 | 56 | ## Evaluation parameters 57 | parser.add_argument('--dcf_p_target', type=float, default=0.05, help='A priori probability of the specified target speaker') 58 | parser.add_argument('--dcf_c_miss', type=float, default=1, help='Cost of a missed detection') 59 | parser.add_argument('--dcf_c_fa', type=float, default=1, help='Cost of a spurious detection') 60 | 61 | ## Load and save 62 | parser.add_argument('--initial_model', type=str, default="", help='Initial model weights') 63 | parser.add_argument('--save_path', type=str, default="exps/exp1", help='Path for model and logs') 64 | 65 | ## Training and test data 66 | parser.add_argument('--train_list', type=str, default="data/train_list.txt", help='Train list') 67 | parser.add_argument('--test_list', type=str, default="data/test_list.txt", help='Evaluation list') 68 | parser.add_argument('--train_path', type=str, default="data/voxceleb2", help='Absolute path to the train set') 69 | parser.add_argument('--test_path', type=str, default="data/voxceleb1", help='Absolute path to the test set') 70 | parser.add_argument('--musan_path', type=str, default="data/musan_split", help='Absolute path to the test set') 71 | parser.add_argument('--rir_path', type=str, default="data/RIRS_NOISES/simulated_rirs", help='Absolute path to the test set') 72 | 73 | ## Model definition 74 | parser.add_argument('--n_mels', type=int, default=40, help='Number of mel filterbanks') 75 | parser.add_argument('--log_input', type=bool, default=False, help='Log input features') 76 | parser.add_argument('--model', type=str, default="", help='Name of model definition') 77 | parser.add_argument('--encoder_type', type=str, default="SAP", help='Type of encoder') 78 | parser.add_argument('--nOut', type=int, default=512, help='Embedding size in the last FC layer') 79 | parser.add_argument('--sinc_stride', type=int, default=10, help='Stride size of the first analytic filterbank layer of RawNet3') 80 | 81 | ## For test only 82 | parser.add_argument('--eval', dest='eval', action='store_true', help='Eval only') 83 | 84 | ## Distributed and mixed precision training 85 | parser.add_argument('--port', type=str, default="8888", help='Port for distributed training, input as text') 86 | parser.add_argument('--distributed', dest='distributed', action='store_true', help='Enable distributed training') 87 | parser.add_argument('--mixedprec', dest='mixedprec', action='store_true', help='Enable mixed precision training') 88 | 89 | args = parser.parse_args() 90 | 91 | ## Parse YAML 92 | def find_option_type(key, parser): 93 | for opt in parser._get_optional_actions(): 94 | if ('--' + key) in opt.option_strings: 95 | return opt.type 96 | raise ValueError 97 | 98 | if args.config is not None: 99 | with open(args.config, "r") as f: 100 | yml_config = yaml.load(f, Loader=yaml.FullLoader) 101 | for k, v in yml_config.items(): 102 | if k in args.__dict__: 103 | typ = find_option_type(k, parser) 104 | args.__dict__[k] = typ(v) 105 | else: 106 | sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k)) 107 | 108 | 109 | ## ===== ===== ===== ===== ===== ===== ===== ===== 110 | ## Trainer script 111 | ## ===== ===== ===== ===== ===== ===== ===== ===== 112 | 113 | def main_worker(gpu, ngpus_per_node, args): 114 | 115 | args.gpu = gpu 116 | 117 | ## Load models 118 | s = 
SpeakerNet(**vars(args)) 119 | 120 | if args.distributed: 121 | os.environ['MASTER_ADDR']='localhost' 122 | os.environ['MASTER_PORT']=args.port 123 | 124 | dist.init_process_group(backend='nccl', world_size=ngpus_per_node, rank=args.gpu) 125 | 126 | torch.cuda.set_device(args.gpu) 127 | s.cuda(args.gpu) 128 | 129 | s = torch.nn.parallel.DistributedDataParallel(s, device_ids=[args.gpu], find_unused_parameters=True) 130 | 131 | print('Loaded the model on GPU {:d}'.format(args.gpu)) 132 | 133 | else: 134 | s = WrappedModel(s).cuda(args.gpu) 135 | 136 | it = 1 137 | eers = [100] 138 | 139 | if args.gpu == 0: 140 | ## Write args to scorefile 141 | scorefile = open(args.result_save_path+"/scores.txt", "a+") 142 | 143 | ## Initialise trainer and data loader 144 | train_dataset = train_dataset_loader(**vars(args)) 145 | 146 | train_sampler = train_dataset_sampler(train_dataset, **vars(args)) 147 | 148 | train_loader = torch.utils.data.DataLoader( 149 | train_dataset, 150 | batch_size=args.batch_size, 151 | num_workers=args.nDataLoaderThread, 152 | sampler=train_sampler, 153 | pin_memory=False, 154 | worker_init_fn=worker_init_fn, 155 | drop_last=True, 156 | ) 157 | 158 | trainer = ModelTrainer(s, **vars(args)) 159 | 160 | ## Load model weights 161 | modelfiles = glob.glob('%s/model0*.model'%args.model_save_path) 162 | modelfiles.sort() 163 | 164 | if(args.initial_model != ""): 165 | trainer.loadParameters(args.initial_model) 166 | print("Model {} loaded!".format(args.initial_model)) 167 | elif len(modelfiles) >= 1: 168 | trainer.loadParameters(modelfiles[-1]) 169 | print("Model {} loaded from previous state!".format(modelfiles[-1])) 170 | it = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][5:]) + 1 171 | 172 | for ii in range(1,it): 173 | trainer.__scheduler__.step() 174 | 175 | ## Evaluation code - must run on single GPU 176 | if args.eval == True: 177 | 178 | pytorch_total_params = sum(p.numel() for p in s.module.__S__.parameters()) 179 | 180 | print('Total parameters: ',pytorch_total_params) 181 | print('Test list',args.test_list) 182 | 183 | sc, lab, _ = trainer.evaluateFromList(**vars(args)) 184 | 185 | if args.gpu == 0: 186 | 187 | result = tuneThresholdfromScore(sc, lab, [1, 0.1]) 188 | 189 | fnrs, fprs, thresholds = ComputeErrorRates(sc, lab) 190 | mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, args.dcf_p_target, args.dcf_c_miss, args.dcf_c_fa) 191 | 192 | print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "VEER {:2.4f}".format(result[1]), "MinDCF {:2.5f}".format(mindcf)) 193 | 194 | return 195 | 196 | ## Save training code and params 197 | if args.gpu == 0: 198 | pyfiles = glob.glob('./*.py') 199 | strtime = datetime.datetime.now().strftime("%Y%m%d%H%M%S") 200 | 201 | zipf = zipfile.ZipFile(args.result_save_path+ '/run%s.zip'%strtime, 'w', zipfile.ZIP_DEFLATED) 202 | for file in pyfiles: 203 | zipf.write(file) 204 | zipf.close() 205 | 206 | with open(args.result_save_path + '/run%s.cmd'%strtime, 'w') as f: 207 | f.write('%s'%args) 208 | 209 | ## Core training script 210 | for it in range(it,args.max_epoch+1): 211 | 212 | train_sampler.set_epoch(it) 213 | 214 | clr = [x['lr'] for x in trainer.__optimizer__.param_groups] 215 | 216 | loss, traineer = trainer.train_network(train_loader, verbose=(args.gpu == 0)) 217 | 218 | if args.gpu == 0: 219 | print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "Epoch {:d}, TEER/TAcc {:2.2f}, TLOSS {:f}, LR {:f}".format(it, traineer, loss, max(clr))) 220 | scorefile.write("Epoch {:d}, TEER/TAcc {:2.2f}, TLOSS {:f}, LR {:f} \n".format(it, 
traineer, loss, max(clr))) 221 | 222 | if it % args.test_interval == 0: 223 | 224 | sc, lab, _ = trainer.evaluateFromList(**vars(args)) 225 | 226 | if args.gpu == 0: 227 | 228 | result = tuneThresholdfromScore(sc, lab, [1, 0.1]) 229 | 230 | fnrs, fprs, thresholds = ComputeErrorRates(sc, lab) 231 | mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, args.dcf_p_target, args.dcf_c_miss, args.dcf_c_fa) 232 | 233 | eers.append(result[1]) 234 | 235 | print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "Epoch {:d}, VEER {:2.4f}, MinDCF {:2.5f}".format(it, result[1], mindcf)) 236 | scorefile.write("Epoch {:d}, VEER {:2.4f}, MinDCF {:2.5f}\n".format(it, result[1], mindcf)) 237 | 238 | trainer.saveParameters(args.model_save_path+"/model%09d.model"%it) 239 | 240 | with open(args.model_save_path+"/model%09d.eer"%it, 'w') as eerfile: 241 | eerfile.write('{:2.4f}'.format(result[1])) 242 | 243 | scorefile.flush() 244 | 245 | if args.gpu == 0: 246 | scorefile.close() 247 | 248 | 249 | ## ===== ===== ===== ===== ===== ===== ===== ===== 250 | ## Main function 251 | ## ===== ===== ===== ===== ===== ===== ===== ===== 252 | 253 | 254 | def main(): 255 | args.model_save_path = args.save_path+"/model" 256 | args.result_save_path = args.save_path+"/result" 257 | args.feat_save_path = "" 258 | 259 | os.makedirs(args.model_save_path, exist_ok=True) 260 | os.makedirs(args.result_save_path, exist_ok=True) 261 | 262 | n_gpus = torch.cuda.device_count() 263 | 264 | print('Python Version:', sys.version) 265 | print('PyTorch Version:', torch.__version__) 266 | print('Number of GPUs:', torch.cuda.device_count()) 267 | print('Save path:',args.save_path) 268 | 269 | if args.distributed: 270 | mp.spawn(main_worker, nprocs=n_gpus, args=(n_gpus, args)) 271 | else: 272 | main_worker(0, None, args) 273 | 274 | 275 | if __name__ == '__main__': 276 | main() -------------------------------------------------------------------------------- /tuneThreshold.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | import os 5 | import glob 6 | import sys 7 | import time 8 | from sklearn import metrics 9 | import numpy 10 | import pdb 11 | from operator import itemgetter 12 | 13 | def tuneThresholdfromScore(scores, labels, target_fa, target_fr = None): 14 | 15 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=1) 16 | fnr = 1 - tpr 17 | 18 | tunedThreshold = []; 19 | if target_fr: 20 | for tfr in target_fr: 21 | idx = numpy.nanargmin(numpy.absolute((tfr - fnr))) 22 | tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]]); 23 | 24 | for tfa in target_fa: 25 | idx = numpy.nanargmin(numpy.absolute((tfa - fpr))) # numpy.where(fpr<=tfa)[0][-1] 26 | tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]]); 27 | 28 | idxE = numpy.nanargmin(numpy.absolute((fnr - fpr))) 29 | eer = max(fpr[idxE],fnr[idxE])*100 30 | 31 | return (tunedThreshold, eer, fpr, fnr); 32 | 33 | # Creates a list of false-negative rates, a list of false-positive rates 34 | # and a list of decision thresholds that give those error-rates. 35 | def ComputeErrorRates(scores, labels): 36 | 37 | # Sort the scores from smallest to largest, and also get the corresponding 38 | # indexes of the sorted scores. We will treat the sorted scores as the 39 | # thresholds at which the the error-rates are evaluated. 
40 | sorted_indexes, thresholds = zip(*sorted( 41 | [(index, threshold) for index, threshold in enumerate(scores)], 42 | key=itemgetter(1))) 43 | sorted_labels = [] 44 | labels = [labels[i] for i in sorted_indexes] 45 | fnrs = [] 46 | fprs = [] 47 | 48 | # At the end of this loop, fnrs[i] is the number of errors made by 49 | # incorrectly rejecting scores less than thresholds[i]. And, fprs[i] 50 | # is the total number of times that we have correctly accepted scores 51 | # greater than thresholds[i]. 52 | for i in range(0, len(labels)): 53 | if i == 0: 54 | fnrs.append(labels[i]) 55 | fprs.append(1 - labels[i]) 56 | else: 57 | fnrs.append(fnrs[i-1] + labels[i]) 58 | fprs.append(fprs[i-1] + 1 - labels[i]) 59 | fnrs_norm = sum(labels) 60 | fprs_norm = len(labels) - fnrs_norm 61 | 62 | # Now divide by the total number of false negative errors to 63 | # obtain the false positive rates across all thresholds 64 | fnrs = [x / float(fnrs_norm) for x in fnrs] 65 | 66 | # Divide by the total number of corret positives to get the 67 | # true positive rate. Subtract these quantities from 1 to 68 | # get the false positive rates. 69 | fprs = [1 - x / float(fprs_norm) for x in fprs] 70 | return fnrs, fprs, thresholds 71 | 72 | # Computes the minimum of the detection cost function. The comments refer to 73 | # equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan. 74 | def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa): 75 | min_c_det = float("inf") 76 | min_c_det_threshold = thresholds[0] 77 | for i in range(0, len(fnrs)): 78 | # See Equation (2). it is a weighted sum of false negative 79 | # and false positive errors. 80 | c_det = c_miss * fnrs[i] * p_target + c_fa * fprs[i] * (1 - p_target) 81 | if c_det < min_c_det: 82 | min_c_det = c_det 83 | min_c_det_threshold = thresholds[i] 84 | # See Equations (3) and (4). Now we normalize the cost. 85 | c_def = min(c_miss * p_target, c_fa * (1 - p_target)) 86 | min_dcf = min_c_det / c_def 87 | return min_dcf, min_c_det_threshold -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | def accuracy(output, target, topk=(1,)): 8 | """Computes the precision@k for the specified values of k""" 9 | maxk = max(topk) 10 | batch_size = target.size(0) 11 | 12 | _, pred = output.topk(maxk, 1, True, True) 13 | pred = pred.t() 14 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 15 | 16 | res = [] 17 | for k in topk: 18 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 19 | res.append(correct_k.mul_(100.0 / batch_size)) 20 | return res 21 | 22 | class PreEmphasis(torch.nn.Module): 23 | 24 | def __init__(self, coef: float = 0.97): 25 | super().__init__() 26 | self.coef = coef 27 | # make kernel 28 | # In pytorch, the convolution operation uses cross-correlation. So, filter is flipped. 29 | self.register_buffer( 30 | 'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0) 31 | ) 32 | 33 | def forward(self, input: torch.tensor) -> torch.tensor: 34 | assert len(input.size()) == 2, 'The number of dimensions of input tensor must be 2!' 
35 | # reflect padding to match lengths of in/out 36 | input = input.unsqueeze(1) 37 | input = F.pad(input, (1, 0), 'reflect') 38 | return F.conv1d(input, self.flipped_filter).squeeze(1) --------------------------------------------------------------------------------
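The following is an illustrative usage sketch rather than a file from the repository: it shows how the evaluation helpers in tuneThreshold.py fit together to report EER and MinDCF. The trial scores and labels are made-up toy values, and the detection-cost parameters mirror the defaults of trainSpeakerNet.py (dcf_p_target=0.05, dcf_c_miss=1, dcf_c_fa=1).

#!/usr/bin/python
# -*- encoding: utf-8 -*-
# Illustrative example only: toy scores and labels, not part of the training pipeline.

import numpy
from tuneThreshold import tuneThresholdfromScore, ComputeErrorRates, ComputeMinDcf

# Toy verification trials: one similarity score per trial, label 1 = same speaker, 0 = different.
scores = numpy.array([0.91, 0.12, 0.75, 0.40, 0.88, 0.05])
labels = [1, 0, 1, 0, 1, 0]

# EER and tuned thresholds at the requested false-alarm operating points,
# mirroring the call tuneThresholdfromScore(sc, lab, [1, 0.1]) in trainSpeakerNet.py.
tuned, eer, fpr, fnr = tuneThresholdfromScore(scores, labels, [1, 0.1])
print('EER {:2.4f}'.format(eer))

# Minimum detection cost at the default operating point.
fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, 0.05, 1, 1)
print('MinDCF {:2.5f} at threshold {:f}'.format(mindcf, threshold))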