├── .gitignore ├── LICENSE ├── README.md ├── configs └── example.cfg ├── data_io.py ├── diarize.py ├── figures └── multitask.png ├── inference.py ├── models ├── classifiers.py ├── criteria.py ├── extractors.py └── sim_predictors.py ├── scotus_data_prep ├── filter_scp.pl ├── local_info.py ├── step1_downloadmp3.py ├── step2_scrape_dob.py ├── step3_prepdata.py ├── step4_extract_feats.sh └── step5_trim_split_data.py ├── train.py ├── utils.py └── voxceleb_data_prep ├── data ├── nationality_to_country.tsv └── us_states.csv └── scrape_nationalities.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | 141 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Chau 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi Task Learning Speaker Embeddings 2 | Code for the paper: ["Leveraging speaker attribute information using multi task learning for speaker verification and diarization"](https://arxiv.org/abs/2010.14269) 3 | 4 | The overall concept of this paper is that training speaker embedding extractors on auxiliary attributes (such as age or nationality) alongside speaker classification can lead to increased performance for verification and diarization. Training the embeddings in this multi-task fashion can improve the descriptiveness of the embedding space. 5 | 6 | This is implemented by having multiple task-specific "heads" acting on the embeddings, such as x-vectors. Alongside speaker classification, one might employ age classification/regression, or nationality classification. A general diagram for this can be seen below: 7 | 8 |
![Multi-task training with task-specific heads acting on a shared embedding](figures/multitask.png) 9 | 10 |
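To make the multi-task setup concrete, here is a minimal, self-contained sketch of a shared extractor feeding several task-specific heads, with a weighted sum of per-task losses. The toy dimensions and plain `nn.Linear` modules are illustrative assumptions only; the actual extractor and heads used by `train.py` live in `models/extractors.py` and `models/classifiers.py`.

```python
import torch
import torch.nn as nn

embed_dim, num_speakers, num_age_bins = 256, 1000, 10

# Stand-in for an x-vector style extractor producing one embedding per utterance
extractor = nn.Sequential(nn.Linear(30, embed_dim), nn.ReLU())

# One classification head per task, all acting on the same embedding
heads = nn.ModuleDict({
    'speaker': nn.Linear(embed_dim, num_speakers),
    'age': nn.Linear(embed_dim, num_age_bins),
})
loss_weights = {'speaker': 1.0, 'age': 0.1}  # cf. classifier_loss_weights in the config
criterion = nn.CrossEntropyLoss()

feats = torch.randn(8, 30)  # fake batch of utterance-level features
labels = {'speaker': torch.randint(num_speakers, (8,)),
          'age': torch.randint(num_age_bins, (8,))}

embeds = extractor(feats)
total_loss = sum(loss_weights[task] * criterion(head(embeds), labels[task])
                 for task, head in heads.items())
total_loss.backward()  # gradients flow into every head and the shared extractor
```

Only the shared extractor is used at embedding extraction time; the auxiliary heads exist to shape the embedding space during training.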
11 | 12 | Along with experimental code, this repository covers other data preparation tasks used in the paper, such as webscraping age information for lawyers in SCOTUS, and nationality for Wikipedia celebrities. 13 | 14 | # Contents 15 | 16 | - Requirements 17 | - SCOTUS Data Preparation 18 | - VoxCeleb Data Preparation 19 | - Experiments 20 | 21 | # Requirements 22 | 23 | This was tested on Python 3.6.8. 24 | 25 | - General requirements: 26 | - numpy 27 | - tqdm 28 | - sklearn 29 | - scipy 30 | - Web-scraping: 31 | - beautifulsoup4 32 | - wikipedia 33 | - wptools 34 | - Levenshtein 35 | - python-dateutil 36 | - Experiments: 37 | - Kaldi 38 | - torch 39 | - kaldi_io 40 | 41 | You will also need access to the avvo.com API, along with a custom Google search API key; instructions on how to obtain/set these up are in the SCOTUS Data Preparation section. 42 | 43 | # SCOTUS Data Preparation 44 | 45 | ## Obtaining case information 46 | 47 | This work utilizes the information available via the [Oyez project](https://www.oyez.org/about-audio), which provides audio files and transcriptions for US Supreme Court oral arguments. 48 | 49 | The Oyez project has an API, an example of which can be seen here: https://api.oyez.org/cases?per_page=100 50 | 51 | Thankfully, another GitHub repo exists which has already scraped the information provided by the API for each case and auto-updates every week: https://github.com/walkerdb/supreme_court_transcripts. This ~3.5GB repo will be used to extract the cases that we want audio and transcripts for. 52 | 53 | The following command will clone this repo (specifically my fork, which doesn't update with new cases, to ensure consistency with the number of cases that I have tested my code with): 54 | 55 | ```sh 56 | git clone https://github.com/cvqluu/supreme_court_transcripts 57 | ``` 58 | 59 | The location of this repository, which contains the case information, will be referred to as `$SCOTUS_CASE_REPO` throughout the rest of the SCOTUS section. Of course, if you want to obtain the most up-to-date recordings, then you should clone the original repo by walkerdb. 60 | 61 | ## Downloading Audio and Filtering Cases 62 | 63 | Now that `$SCOTUS_CASE_REPO` has been cloned from GitHub, we can make use of the MTL-Speaker-Embeddings repo: 64 | 65 | ```sh 66 | git clone https://github.com/cvqluu/MTL-Speaker-Embeddings 67 | cd MTL-Speaker-Embeddings/scotus_data_prep 68 | ``` 69 | 70 | which puts us inside the `scotus_data_prep` folder, where all the SCOTUS data prep scripts are located. 71 | 72 | Next, we want to run the first step, which downloads mp3 files and filters the cases which we can use. This takes as input the location of the case JSONs in `$SCOTUS_CASE_REPO` and produces a data folder (whose location is up to you), which we will call `$BASE_OUTFOLDER` and which will store all the SCOTUS data to be used later on. 73 | 74 | ```sh 75 | python step1_downloadmp3.py --case-folder $SCOTUS_CASE_REPO/oyez/cases --base-outfolder $BASE_OUTFOLDER 76 | ``` 77 | 78 | This parses through the cases in `$SCOTUS_CASE_REPO` and eliminates cases argued before October 2005, as this is when the Supreme Court started recording digitally, instead of using a reel-to-reel taping system. The taped recordings were excluded as there were a number of problems and defects with these recordings, as detailed here: https://www.oyez.org/about-audio.
This script also eliminates cases where audio can't be found, as well as ones where the transcription has many invalid speaker turns (speaker turns with zero or negative duration). 79 | 80 | The outcome should be a file structure in `$BASE_OUTFOLDER` something like this: 81 | 82 | ``` 83 | $BASE_OUTFOLDER 84 | ├── speaker_ids.json 85 | ├── audio 86 | | ├── 2007.07-330.mp3 87 | | └── ... 88 | └── transcripts 89 | ├── 2007.07-330.json 90 | └── ... 91 | ``` 92 | 93 | Sometimes, downloading audio files may fail, but succeed on subsequent attempts. This script is safe to run multiple times, and will try to re-download files that are missing (while skipping ones that have already been processed). 94 | 95 | If you have used my fork for `$SCOTUS_CASE_REPO`, then you should end up with 2035 mp3s/jsons in the audio and transcripts folders respectively, although this may not be consistent, depending on the availability of every mp3. 96 | 97 | ## Web scraping 98 | 99 | Next, we want to scrape the approximate DoB of each speaker found in the downloaded recordings, whose names are stored in `$BASE_OUTFOLDER/speaker_ids.json`. 100 | 101 | First of all, we need to set up and obtain a few things in order to fill out `local_info.py` inside of `scotus_data_prep`. 102 | 103 | ### Avvo API 104 | 105 | Instructions on how to obtain access to the Avvo API are here: http://avvo.github.io/api-doc/ 106 | 107 | Once this is set up, you can fill out `local_info.py` with your own `AVVO_API_ACCESS_TOKEN`. (Note: `AVVO_CSE_ID` is covered in the next section.) 108 | 109 | ### Custom Google Search 110 | 111 | A custom Google search engine will be needed, which can be set up here: 112 | 113 | https://cse.google.com/cse/all 114 | 115 | You will want to click `Add` to create a new search engine, and under `Sites to search`, you will want to enter `avvo.com`, with the Language specified as English. 116 | 117 | After clicking Create, the custom search engine should be successfully created. The control panel will take you to where you can find the `Search Engine ID`, which you can fill into `AVVO_CSE_ID` in `local_info.py`. 118 | 119 | You will also need a Google API key, which can be set up here: https://developers.google.com/maps/documentation/javascript/get-api-key 120 | 121 | Once this is obtained, fill out `GOOGLE_API_KEY` in `local_info.py`. 122 | 123 | ### Running the webscraper 124 | 125 | ```sh 126 | python step2_scrape_dob.py --base_outfolder $BASE_OUTFOLDER 127 | ``` 128 | 129 | This will go through the names in `speaker_ids.json` and try to scrape their dates of birth (DoB) from the following sources: 130 | 131 | - [Wikipedia](https://en.wikipedia.org/) 132 | - [Supreme court clerkship graduation dates](https://en.wikipedia.org/wiki/List_of_law_clerks_of_the_Supreme_Court_of_the_United_States_(Chief_Justice)) 133 | - JUSTIA.com 134 | - Avvo.com 135 | 136 | For the final three sources, 25 years are subtracted from the year of graduation/admission to practice law to obtain an approximate date of birth. The dates of birth, along with additional information about the source of each DoB, are stored in pickled dictionaries in `$BASE_OUTFOLDER`: 137 | 138 | ``` 139 | $BASE_OUTFOLDER 140 | ├── speaker_ids.json 141 | ├── dobs.p 142 | ├── dobs_info.p 143 | ├── audio 144 | | └── ... 145 | └── transcripts 146 | └── ... 147 | ``` 148 | 149 | Like `step1_downloadmp3.py`, this script is also safe to re-run, and will try to re-scrape names for which no DoB has been found.
If you wish to skip the names which have already been attempted, and are present in the `dobs.p` dictionary, the step2 script can be run with a `--skip-attempted` flag. 150 | 151 | ## Prepping data for feature extraction 152 | 153 | ```sh 154 | python step3_prepdata.py --base_outfolder $BASE_OUTFOLDER 155 | ``` 156 | 157 | This prepares verification and diarization data folders that are ready for Kaldi to extract features from. As part of this, the recordings are renamed into the consistent format `YEAR-XYZ`, and the mapping from original names to new names is stored in a JSON file. 158 | 159 | Verification/training data is made by splitting up long utterances into non-overlapping 10s segments, with minimum length 4s. Utterances are named in a consistent fashion, and the age (in days) of each speaker at the time of each utterance is calculated and placed into an `utt2age` file. 160 | 161 | Diarization data is made by splitting into 1.5s segments with a 0.75s shift. A `ref.rttm` file is also created. 162 | 163 | This should result in the following file structure (in addition to what was previously shown): 164 | 165 | ``` 166 | $BASE_OUTFOLDER 167 | ├── ... 168 | ├── orec_recid_mapping.json 169 | ├── recid_orec_mapping.json 170 | ├── veri_data 171 | | ├── utt2spk 172 | | ├── spk2utt 173 | | ├── utt2age 174 | | ├── wav.scp 175 | | ├── segments 176 | | ├── real_utt2spk 177 | | └── real_spk2utt 178 | └── diar_data 179 | ├── utt2spk 180 | ├── spk2utt 181 | ├── ref.rttm 182 | ├── wav.scp 183 | ├── segments 184 | ├── real_utt2spk 185 | └── real_spk2utt 186 | 187 | ``` 188 | 189 | The `real_{utt2spk|spk2utt}` files are there because Kaldi feature extraction insists on speaker IDs being the prefix of an utterance name - these will be sorted out later on. 190 | 191 | ## Feature extraction 192 | 193 | We will use Kaldi to extract the features. The feature extraction script is found in `step4_extract_feats.sh`. 194 | 195 | You will need to edit the top of this script and fill in your own `$BASE_OUTFOLDER`: 196 | 197 | ```sh 198 | #!/bin/bash 199 | 200 | step=0 201 | nj=20 202 | base_outfolder=/PATH/TO/BASE_OUTFOLDER <--- Edit here 203 | 204 | ... 205 | ``` 206 | 207 | We used the `egs/voxceleb/v2` recipe folder to carry out this script (although the only thing specific to this recipe is the `conf/mfcc.conf`). 208 | 209 | To carry this out, we will assume you have Kaldi installed at `$KALDI_ROOT`: 210 | 211 | ```sh 212 | cp step4_extract_feats.sh $KALDI_ROOT/egs/voxceleb/v2/ 213 | cd $KALDI_ROOT/egs/voxceleb/v2 214 | source path.sh 215 | bash step4_extract_feats.sh 216 | ``` 217 | 218 | 219 | ## Making train/test splits 220 | 221 | After changing your working directory back to `scotus_data_prep`, we can run the final stage script. 222 | 223 | ```sh 224 | python step5_trim_split_data.py --base_outfolder $BASE_OUTFOLDER --train-proportion 0.8 --pos-per-spk 12 225 | ``` 226 | 227 | This trims down the data folders to match what features have been extracted, and splits the recordings into train and test according to `--train-proportion`. It also makes a verification task for test set utterances, excluding speakers seen in the training set, generating trials of pairs of utterances to compare. 228 | 229 | The `--pos-per-spk` option determines how many positive trials there are per test set speaker. If too high a value is selected, this may error out as it cannot select enough utterances, so if that occurs, try lowering this value.
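As a rough illustration of why this can fail: positive trials can only be drawn from same-speaker utterance pairs, so a speaker with few test utterances may not have enough pairs to sample from. The following is a hypothetical sketch of that sampling logic, not the actual code in `step5_trim_split_data.py`:

```python
import random
from itertools import combinations

def sample_positive_trials(spk2utt, pos_per_spk=12):
    """Draw `pos_per_spk` same-speaker (positive) trial pairs per speaker."""
    trials = []
    for spk, utts in spk2utt.items():
        pairs = list(combinations(utts, 2))  # every same-speaker utterance pair
        if len(pairs) < pos_per_spk:
            # The failure mode described above: not enough utterances for this speaker
            raise ValueError('{}: only {} possible positive pairs, need {}; '
                             'try lowering --pos-per-spk'.format(spk, len(pairs), pos_per_spk))
        trials.extend((u0, u1, 1) for u0, u1 in random.sample(pairs, pos_per_spk))
    return trials
```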
230 | 231 | This should yield the following (once again in addition to what was produced before): 232 | 233 | ``` 234 | $BASE_OUTFOLDER 235 | ├── ... 236 | ├── veri_data_nosil 237 | | ├── train 238 | | | ├── utt2spk 239 | | | ├── spk2utt 240 | | | ├── utt2age 241 | | | ├── feats.scp 242 | | | └── utts 243 | | ├── test 244 | | | ├── veri_pairs 245 | | | └── ... 246 | | └── ... 247 | └── diar_data_nosil 248 | ├── train 249 | | ├── ref.rttm 250 | | └── ... 251 | ├── test 252 | | ├── ref.rttm 253 | | └── ... 254 | └── ... 255 | ``` 256 | 257 | # VoxCeleb Data Preparation 258 | 259 | TODO 260 | 261 | 262 | # Experiments 263 | 264 | Training an embedding extractor is done via `train.py`: 265 | 266 | ```sh 267 | python train.py --cfg configs/example.cfg 268 | ``` 269 | 270 | This runs according to the config file `configs/example.cfg` which we will detail and explain below. The following config file trains an xvector architecture embedding extractor with an age classification head with 10 classes and 0.1 weighting of the age loss function. 271 | 272 | 273 | ## Config File 274 | 275 | ```ini 276 | [Datasets] 277 | # Path to the datasets 278 | train = /PATH_TO_BASE_OUTFOLDER/veri_data_nosil/train 279 | test = /PATH_TO_BASE_OUTFOLDER/veri_data_nosil/test 280 | 281 | [Model] 282 | # Allowed model_type : ['XTDNN', 'ETDNN'] 283 | model_type = XTDNN 284 | # Allowed classifier_heads types: 285 | # ['speaker', 'nationality', 'gender', 'age', 'age_regression', 'rec'] 286 | classifier_heads = speaker,age 287 | 288 | [Optim] 289 | # Allowed classifier_types: 290 | # ['xvec', 'adm', 'adacos', 'l2softmax', 'xvec_regression', 'arcface', 'sphereface'] 291 | classifier_types = xvec,xvec 292 | classifier_lr_mults = [1.0, 1.0] 293 | classifier_loss_weights = [1.0, 0.1] 294 | # Allowed smooth_types: 295 | # ['twoneighbour', 'uniform'] 296 | classifier_smooth_types = none,none 297 | 298 | [Hyperparams] 299 | input_dim = 30 300 | lr = 0.2 301 | batch_size = 500 302 | max_seq_len = 350 303 | no_cuda = False 304 | seed = 1234 305 | num_iterations = 50000 306 | momentum = 0.5 307 | scheduler_steps = [40000] 308 | scheduler_lambda = 0.5 309 | multi_gpu = False 310 | classifier_lr_mult = 1. 311 | embedding_dim = 256 312 | 313 | [Outputs] 314 | model_dir = exp/example_exp 315 | # At each checkpoint, the model will be evaluated on the test data, and the model will be saved 316 | checkpoint_interval = 500 317 | 318 | [Misc] 319 | num_age_bins = 10 320 | ``` 321 | 322 | Most of the parameters in this configuration file are fairly self explanatory. 323 | 324 | The most important setup is the `classifier_heads` field, which determines what tasks are being applied to the embedding. This also determines the number of parameters in each field in `[Optim]`, which have to match the number of classifier heads. 
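For example, with `classifier_heads = speaker,age`, every per-head list in `[Optim]` must contain exactly two entries. Below is a condensed sketch of how these fields are read and checked, mirroring the length checks in `parse_config` (used by `train.py` and `inference.py`):

```python
import configparser
import json

config = configparser.ConfigParser()
config.read('configs/example.cfg')

heads = config['Model'].get('classifier_heads').split(',')                  # e.g. ['speaker', 'age']
types = config['Optim'].get('classifier_types').split(',')                  # e.g. ['xvec', 'xvec']
loss_weights = json.loads(config['Optim'].get('classifier_loss_weights'))   # e.g. [1.0, 0.1]
lr_mults = json.loads(config['Optim'].get('classifier_lr_mults'))           # e.g. [1.0, 1.0]

# Every per-head list in [Optim] needs one entry per classifier head
assert len(heads) == len(types) == len(loss_weights) == len(lr_mults)
```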
325 | 326 | 327 | ## Supported Auxiliary Tasks 328 | 329 | The currently supported tasks are `'speaker', 'nationality', 'gender', 'age', 'age_regression', 'rec'`, and each has a requirement for the data that must be present in both the train and test folders in order for it to be evaluated: 330 | 331 | - `age` and `age_regression` 332 | - Age classification or age regression 333 | - Requires: `utt2age` file 334 | - 2 column file of format: 335 | - `<utterance_id> <age>` 336 | - `gender` 337 | - Gender classification 338 | - Requires: `spk2gender` file 339 | - 2 column file of format: 340 | - `<speaker_id> <gender>` 341 | - `nationality` 342 | - Nationality classification 343 | - Requires `spk2nat` file 344 | - 2 column file of format: 345 | - `<speaker_id> <nationality>` 346 | - `rec` 347 | - Recording ID classification 348 | - This is typically accompanied by a negative loss weight. When a negative loss weight is given, a Gradient Reversal Layer (GRL) is automatically applied to that classification head. 349 | - Requires `utt2rec` file (but not in the test folder): 350 | - 2 column file of format: 351 | - `<utterance_id> <recording_id>` 352 | 353 | 354 | ## Resuming training 355 | 356 | The following command will resume an experiment from the checkpoint at 25000 iterations: 357 | 358 | ```sh 359 | python train.py --cfg configs/example.cfg --resume-checkpoint 25000 360 | ``` 361 | -------------------------------------------------------------------------------- /configs/example.cfg: -------------------------------------------------------------------------------- 1 | [Datasets] 2 | train = /PATH_TO_BASE_OUTFOLDER/veri_data_nosil/train 3 | test = /PATH_TO_BASE_OUTFOLDER/veri_data_nosil/test 4 | 5 | [Model] 6 | #allowed model_type : ['XTDNN', 'ETDNN'] 7 | model_type = XTDNN 8 | classifier_heads = speaker,age 9 | 10 | [Optim] 11 | classifier_types = xvec,xvec 12 | classifier_lr_mults = [1.0, 1.0] 13 | classifier_loss_weights = [1.0, 0.1] 14 | classifier_smooth_types = none,none 15 | 16 | [Hyperparams] 17 | input_dim = 30 18 | lr = 0.2 19 | batch_size = 500 20 | max_seq_len = 350 21 | no_cuda = False 22 | seed = 1234 23 | num_iterations = 50000 24 | momentum = 0.5 25 | scheduler_steps = [40000] 26 | scheduler_lambda = 0.5 27 | multi_gpu = False 28 | classifier_lr_mult = 1.
29 | embedding_dim = 256 30 | 31 | [Outputs] 32 | model_dir = exp/example_exp 33 | checkpoint_interval = 500 34 | 35 | [Misc] 36 | num_age_bins = 10 -------------------------------------------------------------------------------- /diarize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import concurrent.futures 3 | import configparser 4 | import glob 5 | import json 6 | import os 7 | import pickle 8 | import random 9 | import re 10 | import shutil 11 | import subprocess 12 | import time 13 | from collections import OrderedDict 14 | from pprint import pprint 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import uvloop 21 | from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering 22 | from sklearn.metrics import pairwise_distances 23 | from sklearn.preprocessing import normalize 24 | from torch.utils.data import DataLoader 25 | from tqdm import tqdm 26 | 27 | from data_io import DiarizationDataset 28 | from models.extractors import ETDNN, FTDNN, XTDNN 29 | from train import parse_config 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser(description='Diarize per recording') 34 | parser.add_argument('--cfg', type=str, 35 | default='./configs/example_speaker.cfg') 36 | parser.add_argument('--checkpoint', type=int, 37 | default=0, help='Choose a specific iteration to evaluate on instead of the best eer') 38 | parser.add_argument('--diar-data', type=str, 39 | default='/disk/scratch2/s1786813/repos/supreme_court_transcripts/oyez/scotus_diarization_nosil/test') 40 | args = parser.parse_args() 41 | assert os.path.isfile(args.cfg) 42 | args._start_time = time.ctime() 43 | return args 44 | 45 | 46 | def lines_to_file(lines, filename, wmode="w+"): 47 | with open(filename, wmode) as fp: 48 | for line in lines: 49 | fp.write(line) 50 | 51 | 52 | def get_eer_metrics(folder): 53 | rpkl_path = os.path.join(folder, 'results.p') 54 | rpkl = pickle.load(open(rpkl_path, 'rb')) 55 | iterations = list(rpkl.keys()) 56 | eers = [rpkl[k]['test_eer'] for k in rpkl] 57 | return iterations, eers, np.min(eers) 58 | 59 | 60 | def setup(): 61 | if args.model_type == 'XTDNN': 62 | generator = XTDNN(features_per_frame=args.input_dim, 63 | embed_features=args.embedding_dim) 64 | if args.model_type == 'ETDNN': 65 | generator = ETDNN(features_per_frame=args.input_dim, 66 | embed_features=args.embedding_dim) 67 | if args.model_type == 'FTDNN': 68 | generator = FTDNN(in_dim=args.input_dim, 69 | embedding_dim=args.embedding_dim) 70 | 71 | generator.eval() 72 | generator = generator 73 | return generator 74 | 75 | 76 | def agg_clustering_oracle(S, num_clusters): 77 | ahc = AgglomerativeClustering( 78 | n_clusters=num_clusters, affinity='precomputed', linkage='average', compute_full_tree=True) 79 | return ahc.fit_predict(S) 80 | 81 | 82 | def score_der(hyp=None, ref=None, outfile=None, collar=0.25): 83 | ''' 84 | Takes in hypothesis rttm and reference rttm and returns the diarization error rate 85 | Calls md-eval.pl -> writes output to file -> greps for DER value 86 | ''' 87 | assert os.path.isfile(hyp) 88 | assert os.path.isfile(ref) 89 | assert outfile 90 | cmd = 'perl md-eval.pl -1 -c {} -s {} -r {} > {}'.format( 91 | collar, hyp, ref, outfile) 92 | subprocess.call(cmd, shell=True, stdout=subprocess.DEVNULL, 93 | stderr=subprocess.DEVNULL) 94 | assert os.path.isfile(outfile) 95 | with open(outfile, 'r') as file: 96 | data = file.read().replace('\n', '') 97 | 
der_str = re.search( 98 | r'DIARIZATION\ ERROR\ =\ [0-9]+([.][0-9]+)?', data).group() 99 | der = float(der_str.split()[-1]) 100 | return der 101 | 102 | 103 | def score_der_uem(hyp=None, ref=None, outfile=None, uem=None, collar=0.25): 104 | ''' 105 | takes in hypothesis rttm and reference rttm and returns the diarization error rate 106 | calls md-eval.pl -> writes output to file -> greps for der value 107 | ''' 108 | assert os.path.isfile(hyp) 109 | assert os.path.isfile(ref) 110 | assert os.path.isfile(uem) 111 | assert outfile 112 | cmd = 'perl md-eval.pl -1 -c {} -s {} -r {} -u {} > {}'.format( 113 | collar, hyp, ref, uem, outfile) 114 | subprocess.call(cmd, shell=True, stdout=subprocess.DEVNULL, 115 | stderr=subprocess.DEVNULL) 116 | assert os.path.isfile(outfile) 117 | with open(outfile, 'r') as file: 118 | data = file.read().replace('\n', '') 119 | der_str = re.search( 120 | r'DIARIZATION\ ERROR\ =\ [0-9]+([.][0-9]+)?', data).group() 121 | der = float(der_str.split()[-1]) 122 | return der 123 | 124 | def make_rttm_lines(segcols, cluster_labels): 125 | # Make the rtttm from segments and cluster labels, resolve overlaps etc 126 | assert len(segcols[0]) == len(cluster_labels) 127 | assert len(set(segcols[1])) == 1, 'Must be from a single recording' 128 | rec_id = list(set(segcols[1]))[0] 129 | 130 | starts = segcols[2].astype(float) 131 | ends = segcols[3].astype(float) 132 | 133 | events = [{'start': starts[0], 'end': ends[0], 'label': cluster_labels[0]}] 134 | 135 | for t0, t1, lab in zip(starts, ends, cluster_labels): 136 | # TODO: Warning this only considers overlap with a single adjacent neighbour 137 | if t0 <= events[-1]['end']: 138 | if lab == events[-1]['label']: 139 | events[-1]['end'] = t1 140 | continue 141 | else: 142 | overlap = events[-1]['end'] - t0 143 | events[-1]['end'] -= overlap/2 144 | newevent = {'start': t0+overlap/2, 'end': t1, 'label': lab} 145 | events.append(newevent) 146 | else: 147 | newevent = {'start': t0, 'end': t1, 'label': lab} 148 | events.append(newevent) 149 | 150 | line_str = 'SPEAKER {} 0 {:.3f} {:.3f} {} \n' 151 | lines = [] 152 | for ev in events: 153 | offset = ev['end'] - ev['start'] 154 | if offset < 0.0: 155 | continue 156 | lines.append(line_str.format(rec_id, ev['start'], offset, ev['label'])) 157 | return lines 158 | 159 | 160 | def extract_and_diarize(generator, test_data_dir, device): 161 | all_hyp_rttms = [] 162 | all_ref_rttms = [] 163 | 164 | # if args.checkpoint == 0: 165 | # results_pkl = os.path.join(args.model_dir, 'diar_results.p') 166 | # else: 167 | results_pkl = os.path.join(args.model_dir, 'diar_results_{}.p'.format(args.checkpoint)) 168 | rttm_dir = os.path.join(args.model_dir, 'hyp_rttms') 169 | os.makedirs(rttm_dir, exist_ok=True) 170 | 171 | if os.path.isfile(results_pkl): 172 | rpkl = pickle.load(open(results_pkl, 'rb')) 173 | if rpkl['test_data'] != TEST_DATA_PATH: 174 | moved_rpkl = os.path.join( 175 | args.model_dir, rpkl['test_data'].replace(os.sep, '-') + '.p') 176 | shutil.copy(results_pkl, moved_rpkl) 177 | rpkl = OrderedDict({'test_data': TEST_DATA_PATH}) 178 | else: 179 | rpkl = OrderedDict({'test_data': TEST_DATA_PATH}) 180 | 181 | if 'full_der' in rpkl: 182 | print('({}) Full test DER: {}'.format(args.cfg, rpkl['full_der'])) 183 | if 'full_der_uem' not in rpkl: 184 | uem_file = os.path.join(TEST_DATA_PATH, 'uem') 185 | der_uem = score_der_uem(hyp=os.path.join(args.model_dir, 'final_{}_hyp.rttm'.format(args.checkpoint)), 186 | ref=os.path.join(args.model_dir, 'final_{}_ref.rttm'.format(args.checkpoint)), 187 | 
outfile=os.path.join(args.model_dir, 'final_hyp_uem.derlog'), 188 | uem=uem_file, 189 | collar=0.25) 190 | rpkl['full_der_uem'] = der_uem 191 | pickle.dump(rpkl, open(results_pkl, 'wb')) 192 | print('({}) Full test DER (uem): {}'.format(args.cfg, rpkl['full_der_uem'])) 193 | 194 | else: 195 | ds_test = DiarizationDataset(test_data_dir) 196 | recs = ds_test.recs 197 | generator = setup() 198 | generator.eval().to(device) 199 | 200 | if args.checkpoint == 0: 201 | generator_its, eers, _ = get_eer_metrics(args.model_dir) 202 | g_path = os.path.join(args.model_dir, 'generator_{}.pt'.format( 203 | generator_its[np.argmin(eers)])) 204 | else: 205 | g_path = os.path.join(args.model_dir, 'generator_{}.pt'.format(args.checkpoint)) 206 | assert os.path.isfile(g_path), "Couldn't find {}".format(g_path) 207 | 208 | generator.load_state_dict(torch.load(g_path)) 209 | 210 | with torch.no_grad(): 211 | for i, r in tqdm(enumerate(recs), total=len(recs)): 212 | ref_rec_rttm = os.path.join(rttm_dir, '{}_{}_ref.rttm'.format(r, args.checkpoint)) 213 | hyp_rec_rttm = os.path.join(rttm_dir, '{}_{}_hyp.rttm'.format(r, args.checkpoint)) 214 | if r in rpkl and (os.path.isfile(ref_rec_rttm) and os.path.isfile(hyp_rec_rttm)): 215 | all_ref_rttms.append(ref_rec_rttm) 216 | all_hyp_rttms.append(hyp_rec_rttm) 217 | continue 218 | 219 | feats, spkrs, ref_rttm_lines, segcols, rec = ds_test.__getitem__(i) 220 | num_spkrs = len(set(spkrs)) 221 | assert r == rec 222 | 223 | # Extract embeds 224 | embeds = [] 225 | for feat in feats: 226 | if len(feat) <= 15: 227 | embeds.append(embed.cpu().numpy()) 228 | else: 229 | feat = feat.unsqueeze(0).to(device) 230 | embed = generator(feat) 231 | embeds.append(embed.cpu().numpy()) 232 | embeds = np.vstack(embeds) 233 | embeds = normalize(embeds, axis=1) 234 | 235 | # Compute similarity matrix 236 | sim_matrix = pairwise_distances(embeds, metric='cosine') 237 | cluster_labels = agg_clustering_oracle(sim_matrix, num_spkrs) 238 | # TODO: Consider overlapped dataset, prototype for now 239 | 240 | # Write to rttm 241 | hyp_rttm_lines = make_rttm_lines(segcols, cluster_labels) 242 | 243 | 244 | lines_to_file(ref_rttm_lines, ref_rec_rttm) 245 | lines_to_file(hyp_rttm_lines, hyp_rec_rttm) 246 | 247 | # Eval based on recording level rttm 248 | der = score_der(hyp=hyp_rec_rttm, ref=ref_rec_rttm, 249 | outfile='/tmp/{}.derlog'.format(rec), collar=0.25) 250 | print('({}) DER for {}: {}'.format(args.cfg, rec, der)) 251 | 252 | rpkl[rec] = der 253 | pickle.dump(rpkl, open(results_pkl, 'wb')) 254 | 255 | all_ref_rttms.append(ref_rec_rttm) 256 | all_hyp_rttms.append(hyp_rec_rttm) 257 | 258 | final_hyp_rttm = os.path.join(args.model_dir, 'final_{}_hyp.rttm'.format(args.checkpoint)) 259 | final_ref_rttm = os.path.join(args.model_dir, 'final_{}_ref.rttm'.format(args.checkpoint)) 260 | 261 | os.system('cat {} > {}'.format(' '.join(all_ref_rttms), final_ref_rttm)) 262 | os.system('cat {} > {}'.format(' '.join(all_hyp_rttms), final_hyp_rttm)) 263 | 264 | time.sleep(4) 265 | 266 | full_der = score_der(hyp=final_hyp_rttm, ref=final_ref_rttm, 267 | outfile=os.path.join(args.model_dir, 'final_{}_hyp.derlog'.format(args.checkpoint)), collar=0.25) 268 | print('({}) Full test DER: {}'.format(args.cfg, full_der)) 269 | 270 | rpkl['full_der'] = full_der 271 | pickle.dump(rpkl, open(results_pkl, 'wb')) 272 | 273 | uem_file = os.path.join(TEST_DATA_PATH, 'uem') 274 | der_uem = score_der_uem(hyp=os.path.join(args.model_dir, 'final_{}_hyp.rttm'.format(args.checkpoint)), 275 | ref=os.path.join(args.model_dir, 
'final_{}_ref.rttm'.format(args.checkpoint)), 276 | outfile=os.path.join(args.model_dir, 'final_hyp_uem.derlog'), 277 | uem=uem_file, 278 | collar=0.25) 279 | rpkl['full_der_uem'] = der_uem 280 | print('({}) Full test DER (uem): {}'.format(args.cfg, rpkl['full_der_uem'])) 281 | pickle.dump(rpkl, open(results_pkl, 'wb')) 282 | 283 | 284 | if __name__ == "__main__": 285 | args = parse_args() 286 | args = parse_config(args) 287 | uvloop.install() 288 | rpkl_path = os.path.join(args.model_dir, 'results.p') 289 | if not os.path.isfile(rpkl_path): 290 | print('No results.p found') 291 | else: 292 | device = torch.device('cuda') 293 | 294 | TEST_DATA_PATH = args.diar_data 295 | 296 | generator = setup() 297 | extract_and_diarize(generator, TEST_DATA_PATH, device) 298 | -------------------------------------------------------------------------------- /figures/multitask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvqluu/MTL-Speaker-Embeddings/a1b892f788eecc3bb458bba547c62ec2dac3c661/figures/multitask.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import configparser 3 | import os 4 | import pickle 5 | import sys 6 | import h5py 7 | import json 8 | 9 | from collections import OrderedDict 10 | from glob import glob 11 | from math import floor, log10 12 | 13 | import numpy as np 14 | 15 | import kaldi_io 16 | import torch 17 | import uvloop 18 | from data_io import SpeakerTestDataset, odict_from_2_col, SpeakerDataset 19 | from kaldi_io import read_vec_flt 20 | from kaldiio import ReadHelper 21 | from models.extractors import ETDNN, FTDNN, XTDNN 22 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 23 | from sklearn.linear_model import LogisticRegression 24 | from sklearn.metrics import roc_curve 25 | from sklearn.metrics.pairwise import cosine_distances, cosine_similarity 26 | from sklearn.preprocessing import normalize 27 | from tqdm import tqdm 28 | from utils import SpeakerRecognitionMetrics 29 | 30 | 31 | 32 | def mtd(stuff, device): 33 | if isinstance(stuff, torch.Tensor): 34 | return stuff.to(device) 35 | else: 36 | return [mtd(s, device) for s in stuff] 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser(description='Test SV Model') 40 | parser.add_argument('--cfg', type=str, default='./configs/example_speaker.cfg') 41 | parser.add_argument('--best', action='store_true', default=False, help='Use best model') 42 | parser.add_argument('--checkpoint', type=int, default=-1, # which model to use, overidden by 'best' 43 | help='Use model checkpoint, default -1 uses final model') 44 | args = parser.parse_args() 45 | assert os.path.isfile(args.cfg) 46 | return args 47 | 48 | def parse_config(args): 49 | assert os.path.isfile(args.cfg) 50 | config = configparser.ConfigParser() 51 | config.read(args.cfg) 52 | 53 | args.train_data = config['Datasets'].get('train') 54 | assert args.train_data 55 | args.test_data = config['Datasets'].get('test') 56 | 57 | args.model_type = config['Model'].get('model_type', fallback='XTDNN') 58 | assert args.model_type in ['XTDNN', 'ETDNN', 'FTDNN'] 59 | 60 | args.classifier_heads = config['Model'].get('classifier_heads').split(',') 61 | assert len(args.classifier_heads) <= 3, 'Three options available' 62 | assert len(args.classifier_heads) == len(set(args.classifier_heads)) 63 | for clf in args.classifier_heads: 
64 | assert clf in ['speaker', 'nationality', 'gender', 'age', 'age_regression'] 65 | 66 | args.classifier_types = config['Optim'].get('classifier_types').split(',') 67 | assert len(args.classifier_heads) == len(args.classifier_types) 68 | 69 | args.classifier_loss_weighting_type = config['Optim'].get('classifier_loss_weighting_type', fallback='none') 70 | assert args.classifier_loss_weighting_type in ['none', 'uncertainty_kendall', 'uncertainty_liebel', 'dwa'] 71 | 72 | args.dwa_temperature = config['Optim'].getfloat('dwa_temperature', fallback=2.) 73 | 74 | args.classifier_loss_weights = np.array(json.loads(config['Optim'].get('classifier_loss_weights'))).astype(float) 75 | assert len(args.classifier_heads) == len(args.classifier_loss_weights) 76 | 77 | args.classifier_lr_mults = np.array(json.loads(config['Optim'].get('classifier_lr_mults'))).astype(float) 78 | assert len(args.classifier_heads) == len(args.classifier_lr_mults) 79 | 80 | # assert clf_type in ['l2softmax', 'adm', 'adacos', 'xvec', 'arcface', 'sphereface', 'softmax'] 81 | 82 | args.classifier_smooth_types = config['Optim'].get('classifier_smooth_types').split(',') 83 | assert len(args.classifier_smooth_types) == len(args.classifier_heads) 84 | args.classifier_smooth_types = [s.strip() for s in args.classifier_smooth_types] 85 | 86 | args.label_smooth_type = config['Optim'].get('label_smooth_type', fallback='None') 87 | assert args.label_smooth_type in ['None', 'disturb', 'uniform'] 88 | args.label_smooth_prob = config['Optim'].getfloat('label_smooth_prob', fallback=0.1) 89 | 90 | args.input_dim = config['Hyperparams'].getint('input_dim', fallback=30) 91 | args.embedding_dim = config['Hyperparams'].getint('embedding_dim', fallback=512) 92 | args.num_iterations = config['Hyperparams'].getint('num_iterations', fallback=50000) 93 | 94 | 95 | args.model_dir = config['Outputs']['model_dir'] 96 | if not hasattr(args, 'basefolder'): 97 | args.basefolder = config['Outputs'].get('basefolder', fallback=None) 98 | args.log_file = os.path.join(args.model_dir, 'train.log') 99 | args.results_pkl = os.path.join(args.model_dir, 'results.p') 100 | 101 | args.num_age_bins = config['Misc'].getint('num_age_bins', fallback=10) 102 | args.age_label_smoothing = config['Misc'].getboolean('age_label_smoothing', fallback=False) 103 | return args 104 | 105 | 106 | def test(generator, ds_test, device, mindcf=False): 107 | generator.eval() 108 | all_embeds = [] 109 | all_utts = [] 110 | num_examples = len(ds_test.veri_utts) 111 | 112 | with torch.no_grad(): 113 | for i in range(num_examples): 114 | feats, utt = ds_test.__getitem__(i) 115 | feats = feats.unsqueeze(0).to(device) 116 | embeds = generator(feats) 117 | all_embeds.append(embeds.cpu().numpy()) 118 | all_utts.append(utt) 119 | 120 | metric = SpeakerRecognitionMetrics(distance_measure='cosine') 121 | all_embeds = np.vstack(all_embeds) 122 | all_embeds = normalize(all_embeds, axis=1) 123 | all_utts = np.array(all_utts) 124 | 125 | utt_embed = OrderedDict({k: v for k, v in zip(all_utts, all_embeds)}) 126 | 127 | emb0 = np.array([utt_embed[utt] for utt in ds_test.veri_0]) 128 | emb1 = np.array([utt_embed[utt] for utt in ds_test.veri_1]) 129 | 130 | scores = metric.scores_from_pairs(emb0, emb1) 131 | eer, mindcf1 = metric.compute_min_cost(scores, 1 - ds_test.veri_labs) 132 | generator.train() 133 | if mindcf: 134 | return eer, mindcf1, None 135 | else: 136 | return eer 137 | 138 | def test_fromdata(generator, all_feats, all_utts, ds_test, device, mindcf=False): 139 | generator.eval() 140 | 
all_embeds = [] 141 | 142 | with torch.no_grad(): 143 | for feats in all_feats: 144 | feats = feats.unsqueeze(0).to(device) 145 | embeds = generator(feats) 146 | all_embeds.append(embeds.cpu().numpy()) 147 | 148 | metric = SpeakerRecognitionMetrics(distance_measure='cosine') 149 | all_embeds = np.vstack(all_embeds) 150 | all_embeds = normalize(all_embeds, axis=1) 151 | all_utts = np.array(all_utts) 152 | 153 | utt_embed = OrderedDict({k: v for k, v in zip(all_utts, all_embeds)}) 154 | 155 | emb0 = np.array([utt_embed[utt] for utt in ds_test.veri_0]) 156 | emb1 = np.array([utt_embed[utt] for utt in ds_test.veri_1]) 157 | 158 | scores = metric.scores_from_pairs(emb0, emb1) 159 | eer, mindcf1 = metric.compute_min_cost(scores, 1 - ds_test.veri_labs) 160 | generator.train() 161 | if mindcf: 162 | return eer, mindcf1, None 163 | else: 164 | return eer 165 | 166 | 167 | def test_all_factors_multids(model_dict, ds_test, dsl, label_types, device): 168 | assert ds_test.test_mode 169 | 170 | for m in model_dict: 171 | model_dict[m]['model'].eval() 172 | 173 | label_types = [l for l in label_types if l not in ['speaker', 'rec']] 174 | 175 | with torch.no_grad(): 176 | feats, label_dict, all_utts = ds_test.get_test_items() 177 | all_embeds = [] 178 | pred_dict = {m: [] for m in label_types} 179 | for feat in tqdm(feats): 180 | feat = feat.unsqueeze(0).to(device) 181 | embed = model_dict['generator']['model'](model_dict['{}_ilayer'.format(dsl)]['model'](feat)) 182 | for m in label_types: 183 | dictkey = '{}_{}'.format(dsl, m) 184 | pred = torch.argmax(model_dict[dictkey]['model'](embed, label=None), dim=1) 185 | pred_dict[m].append(pred.cpu().numpy()[0]) 186 | all_embeds.append(embed.cpu().numpy()) 187 | 188 | accuracy_dict = {m: np.equal(label_dict[m], pred_dict[m]).sum() / len(all_utts) for m in label_types} 189 | 190 | metric = SpeakerRecognitionMetrics(distance_measure='cosine') 191 | all_embeds = np.vstack(all_embeds) 192 | all_embeds = normalize(all_embeds, axis=1) 193 | all_utts = np.array(all_utts) 194 | 195 | utt_embed = OrderedDict({k: v for k, v in zip(all_utts, all_embeds)}) 196 | emb0 = np.array([utt_embed[utt] for utt in ds_test.veri_0]) 197 | emb1 = np.array([utt_embed[utt] for utt in ds_test.veri_1]) 198 | 199 | scores = metric.scores_from_pairs(emb0, emb1) 200 | print('Min score: {}, max score {}'.format(min(scores), max(scores))) 201 | eer, mindcf1 = metric.compute_min_cost(scores, 1 - ds_test.veri_labs) 202 | 203 | for m in model_dict: 204 | model_dict[m]['model'].train() 205 | 206 | return eer, mindcf1, accuracy_dict 207 | 208 | 209 | def test_all_factors(model_dict, ds_test, device): 210 | assert ds_test.test_mode 211 | 212 | for m in model_dict: 213 | model_dict[m]['model'].eval() 214 | 215 | label_types = [l for l in ds_test.label_types if l in model_dict] 216 | 217 | with torch.no_grad(): 218 | feats, label_dict, all_utts = ds_test.get_test_items() 219 | all_embeds = [] 220 | pred_dict = {m: [] for m in label_types} 221 | for feat in tqdm(feats): 222 | feat = feat.unsqueeze(0).to(device) 223 | embed = model_dict['generator']['model'](feat) 224 | for m in label_types: 225 | if m.endswith('regression'): 226 | pred = model_dict[m]['model'](embed, label=None) 227 | else: 228 | pred = torch.argmax(model_dict[m]['model'](embed, label=None), dim=1) 229 | pred_dict[m].append(pred.cpu().numpy()[0]) 230 | all_embeds.append(embed.cpu().numpy()) 231 | 232 | accuracy_dict = {m: np.equal(label_dict[m], pred_dict[m]).sum() / len(all_utts) for m in label_types if not m.endswith('regression')} 
233 | 234 | for m in label_types: 235 | if m.endswith('regression'): 236 | accuracy_dict[m] = np.mean((label_dict[m] - pred_dict[m])**2) 237 | 238 | if ds_test.veripairs: 239 | metric = SpeakerRecognitionMetrics(distance_measure='cosine') 240 | all_embeds = np.vstack(all_embeds) 241 | all_embeds = normalize(all_embeds, axis=1) 242 | all_utts = np.array(all_utts) 243 | 244 | utt_embed = OrderedDict({k: v for k, v in zip(all_utts, all_embeds)}) 245 | emb0 = np.array([utt_embed[utt] for utt in ds_test.veri_0]) 246 | emb1 = np.array([utt_embed[utt] for utt in ds_test.veri_1]) 247 | 248 | scores = metric.scores_from_pairs(emb0, emb1) 249 | print('Min score: {}, max score {}'.format(min(scores), max(scores))) 250 | eer, mindcf1 = metric.compute_min_cost(scores, 1 - ds_test.veri_labs) 251 | else: 252 | eer, mindcf1 = 0.0, 0.0 253 | 254 | for m in model_dict: 255 | model_dict[m]['model'].train() 256 | 257 | return eer, mindcf1, accuracy_dict 258 | 259 | 260 | def test_all_factors_ensemble(model_dict, ds_test, device, feats, all_utts, exclude=[], combine='sum'): 261 | assert ds_test.test_mode 262 | 263 | for m in model_dict: 264 | model_dict[m]['model'].eval() 265 | 266 | label_types = [l for l in ds_test.label_types if l in model_dict] 267 | set_veri_utts = set(list(ds_test.veri_0) + list(ds_test.veri_1)) 268 | aux_embeds_dict = {m: [] for m in label_types} 269 | 270 | with torch.no_grad(): 271 | veri_embeds = [] 272 | veri_utts = [] 273 | for feat, utt in tqdm(zip(feats, all_utts)): 274 | if utt in set_veri_utts: 275 | feat = feat.unsqueeze(0).to(device) 276 | embed = model_dict['generator']['model'](feat) 277 | veri_embeds.append(embed) 278 | veri_utts.append(utt) 279 | 280 | veri_embeds = torch.cat(veri_embeds) 281 | for m in label_types: 282 | task_embeds = model_dict[m]['model'](veri_embeds, label=None, transform=True) 283 | aux_embeds_dict[m] = normalize(task_embeds.cpu().numpy(), axis=1) 284 | 285 | veri_embeds = normalize(veri_embeds.cpu().numpy(), axis=1) 286 | 287 | aux_embeds_dict['base'] = veri_embeds 288 | 289 | metric = SpeakerRecognitionMetrics(distance_measure='cosine') 290 | 291 | total_scores = [] 292 | 293 | for key in aux_embeds_dict: 294 | if key not in exclude: 295 | utt_embed = OrderedDict({k: v for k, v in zip(veri_utts, aux_embeds_dict[key])}) 296 | emb0 = np.array([utt_embed[utt] for utt in ds_test.veri_0]) 297 | emb1 = np.array([utt_embed[utt] for utt in ds_test.veri_1]) 298 | scores = metric.scores_from_pairs(emb0, emb1) 299 | total_scores.append(scores) 300 | 301 | if combine == 'sum': 302 | total_scores = np.sum(np.array(total_scores), axis=0) 303 | eer, mindcf1 = metric.compute_min_cost(total_scores, 1. - ds_test.veri_labs) 304 | else: 305 | total_scores = np.array(total_scores).T 306 | lr_clf = LogisticRegression(solver='lbfgs') 307 | lr_clf.fit(total_scores, 1. - ds_test.veri_labs) 308 | weighted_scores = lr_clf.predict(total_scores) 309 | eer, mindcf1 = metric.compute_min_cost(weighted_scores, 1. 
- ds_test.veri_labs) 310 | 311 | for m in model_dict: 312 | model_dict[m]['model'].train() 313 | 314 | return eer, mindcf1 315 | 316 | 317 | def test_enrollment_models(generator, ds_test, device, return_scores=False, reduce_method='mean_embed'): 318 | assert reduce_method in ['mean_embed', 'mean_dist', 'max', 'min'] 319 | generator.eval() 320 | all_embeds = [] 321 | all_utts = [] 322 | num_examples = len(ds_test) 323 | 324 | with torch.no_grad(): 325 | for i in tqdm(range(num_examples)): 326 | feats, utt = ds_test.__getitem__(i) 327 | feats = feats.unsqueeze(0).to(device) 328 | embeds = generator(feats) 329 | all_embeds.append(embeds.cpu().numpy()) 330 | all_utts.append(utt) 331 | 332 | metric = SpeakerRecognitionMetrics(distance_measure='cosine') 333 | all_embeds = np.vstack(all_embeds) 334 | all_embeds = normalize(all_embeds, axis=1) 335 | all_utts = np.array(all_utts) 336 | 337 | utt_embed = OrderedDict({k: v for k, v in zip(all_utts, all_embeds)}) 338 | model_embeds = OrderedDict({}) 339 | for i, model in enumerate(ds_test.models): 340 | model_embeds[model] = np.array([utt_embed[utt] for utt in ds_test.m_utts[i]]) 341 | 342 | emb0s = np.array([model_embeds[u] for u in ds_test.models_eval]) 343 | emb1 = np.array([utt_embed[u] for u in ds_test.eval_utts]) 344 | 345 | mscore_means = [] 346 | mscore_stds = [] 347 | scores = [] 348 | for model_utts, test_utt in zip(emb0s, emb1): 349 | if reduce_method == 'mean_embed': 350 | model_std = model_utts.std(0) 351 | model_mean = model_utts.mean(0) 352 | scores.append(np.linalg.norm(test_utt - model_mean)) 353 | elif reduce_method == 'mean_dist': 354 | dist_means = np.mean(np.array([np.linalg.norm(test_utt - e) for e in model_utts])) 355 | scores.append(dist_means) 356 | elif reduce_method == 'min': 357 | scores.append(np.min(np.array([np.linalg.norm(test_utt - e) for e in model_utts]))) 358 | elif reduce_method == 'max': 359 | scores.append(np.max(np.array([np.linalg.norm(test_utt - e) for e in model_utts]))) 360 | else: 361 | print('do nothing') 362 | 363 | scores = np.array(scores) 364 | if return_scores: 365 | return scores 366 | 367 | eer, mindcf1 = metric.compute_min_cost(scores, 368 | 1 - ds_test.veri_labs) 369 | generator.train() 370 | return eer, mindcf1, scores 371 | 372 | 373 | def test_nosil(generator, ds_test, device, mindcf=False): 374 | generator.eval() 375 | all_embeds = [] 376 | all_utts = [] 377 | num_examples = len(ds_test.veri_utts) 378 | 379 | with torch.no_grad(): 380 | with ReadHelper( 381 | 'ark:apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=300 scp:{0}/feats_trimmed.scp ' 382 | 'ark:- | select-voiced-frames ark:- scp:{0}/vad_trimmed.scp ark:- |'.format( 383 | ds_test.data_base_path)) as reader: 384 | for key, feat in tqdm(reader, total=num_examples): 385 | if key in ds_test.veri_utts: 386 | all_utts.append(key) 387 | feats = torch.FloatTensor(feat).unsqueeze(0).to(device) 388 | embeds = generator(feats) 389 | all_embeds.append(embeds.cpu().numpy()) 390 | 391 | metric = SpeakerRecognitionMetrics(distance_measure='cosine') 392 | all_embeds = np.vstack(all_embeds) 393 | all_embeds = normalize(all_embeds, axis=1) 394 | all_utts = np.array(all_utts) 395 | 396 | print(all_embeds.shape, len(ds_test.veri_utts)) 397 | utt_embed = OrderedDict({k: v for k, v in zip(all_utts, all_embeds)}) 398 | 399 | emb0 = np.array([utt_embed[utt] for utt in ds_test.veri_0]) 400 | emb1 = np.array([utt_embed[utt] for utt in ds_test.veri_1]) 401 | 402 | scores = metric.scores_from_pairs(emb0, emb1) 403 | fpr, tpr, thresholds = 
roc_curve(1 - ds_test.veri_labs, scores, pos_label=1, drop_intermediate=False) 404 | eer = metric.eer_from_ers(fpr, tpr) 405 | generator.train() 406 | if mindcf: 407 | mindcf1 = metric.compute_min_dcf(fpr, tpr, thresholds, p_target=0.01) 408 | mindcf2 = metric.compute_min_dcf(fpr, tpr, thresholds, p_target=0.001) 409 | return eer, mindcf1, mindcf2 410 | else: 411 | return eer 412 | 413 | 414 | def evaluate_deepmine(generator, ds_eval, device, outfile_path='./exp'): 415 | os.makedirs(outfile_path, exist_ok=True) 416 | generator.eval() 417 | 418 | answer_col0 = [] 419 | answer_col1 = [] 420 | answer_col2 = [] 421 | 422 | with torch.no_grad(): 423 | for i in tqdm(range(len(ds_eval))): 424 | model, enrol_utts, enrol_feats, eval_utts, eval_feats = ds_eval.__getitem__(i) 425 | answer_col0.append([model for _ in range(len(eval_utts))]) 426 | answer_col1.append(eval_utts) 427 | 428 | enrol_feats = mtd(enrol_feats, device) 429 | model_embed = torch.cat([generator(x.unsqueeze(0)) for x in enrol_feats]).cpu().numpy() 430 | model_embed = np.mean(normalize(model_embed, axis=1), axis=0).reshape(1, -1) 431 | 432 | del enrol_feats 433 | eval_feats = mtd(eval_feats, device) 434 | eval_embeds = torch.cat([generator(x.unsqueeze(0)) for x in eval_feats]).cpu().numpy() 435 | eval_embeds = normalize(eval_embeds, axis=1) 436 | 437 | scores = cosine_similarity(model_embed, eval_embeds).squeeze(0) 438 | assert len(scores) == len(eval_utts) 439 | answer_col2.append(scores) 440 | del eval_feats 441 | 442 | answer_col0 = np.concatenate(answer_col0) 443 | answer_col1 = np.concatenate(answer_col1) 444 | answer_col2 = np.concatenate(answer_col2) 445 | 446 | with open(os.path.join(outfile_path, 'answer_full.txt'), 'w+') as fp: 447 | for m, ev, s in zip(answer_col0, answer_col1, answer_col2): 448 | line = '{} {} {}\n'.format(m, ev, s) 449 | fp.write(line) 450 | 451 | with open(os.path.join(outfile_path, 'answer.txt'), 'w+') as fp: 452 | for s in answer_col2: 453 | line = '{}\n'.format(s) 454 | fp.write(line) 455 | 456 | if (answer_col0 == np.array(ds_eval.models_eval)).all(): 457 | print('model ordering matched') 458 | else: 459 | print('model ordering was not correct, need to fix before submission') 460 | 461 | if (answer_col1 == np.array(ds_eval.eval_utts)).all(): 462 | print('eval utt ordering matched') 463 | else: 464 | print('eval utt ordering was not correct, need to fix before submission') 465 | 466 | 467 | def dvec_compute(generator, ds_eval, device, num_jobs=20, outfolder='./exp/example_dvecs'): 468 | # naively compute the embeddings for each window 469 | # ds_len = len(ds_feats) 470 | all_utts = ds_eval.all_utts 471 | ds_len = len(all_utts) 472 | indices = np.arange(ds_len) 473 | job_split = np.array_split(indices, num_jobs) 474 | generator.eval().to(device) 475 | for job_num, job in enumerate(tqdm(job_split)): 476 | print('Starting job {}'.format(job_num)) 477 | ark_scp_output = 'ark:| copy-vector ark:- ark,scp:{0}/xvector.{1}.ark,{0}/xvector.{1}.scp'.format(outfolder, 478 | job_num + 1) 479 | job_utts = all_utts[job] 480 | job_feats = ds_eval.get_batches(job_utts) 481 | job_feats = mtd(job_feats, device) 482 | with torch.no_grad(): 483 | job_embeds = torch.cat([generator(x.unsqueeze(0)) for x in tqdm(job_feats)]).cpu().numpy() 484 | with kaldi_io.open_or_fd(ark_scp_output, 'wb') as f: 485 | for xvec, key in zip(job_embeds, job_utts): 486 | kaldi_io.write_vec_flt(f, xvec, key=key) 487 | 488 | 489 | def evaluate_deepmine_from_xvecs(ds_eval, outfolder='./exp/example_xvecs'): 490 | if not 
os.path.isfile(os.path.join(outfolder, 'xvector.scp')): 491 | xvec_scps = glob(os.path.join(outfolder, '*.scp')) 492 | assert len(xvec_scps) != 0, 'No xvector scps found' 493 | with open(os.path.join(outfolder, 'xvector.scp'), 'w+') as outfile: 494 | for fname in xvec_scps: 495 | with open(fname) as infile: 496 | for line in infile: 497 | outfile.write(line) 498 | 499 | xvec_dict = odict_from_2_col(os.path.join(outfolder, 'xvector.scp')) 500 | answer_col0 = [] 501 | answer_col1 = [] 502 | answer_col2 = [] 503 | 504 | for i in tqdm(range(len(ds_eval))): 505 | model, enrol_utts, eval_utts, = ds_eval.get_item_utts(i) 506 | answer_col0.append([model for _ in range(len(eval_utts))]) 507 | answer_col1.append(eval_utts) 508 | 509 | model_embeds = np.array([read_vec_flt(xvec_dict[u]) for u in enrol_utts]) 510 | model_embed = np.mean(normalize(model_embeds, axis=1), axis=0).reshape(1, -1) 511 | 512 | eval_embeds = np.array([read_vec_flt(xvec_dict[u]) for u in eval_utts]) 513 | eval_embeds = normalize(eval_embeds, axis=1) 514 | 515 | scores = cosine_similarity(model_embed, eval_embeds).squeeze(0) 516 | assert len(scores) == len(eval_utts) 517 | answer_col2.append(scores) 518 | 519 | answer_col0 = np.concatenate(answer_col0) 520 | answer_col1 = np.concatenate(answer_col1) 521 | answer_col2 = np.concatenate(answer_col2) 522 | 523 | print('Writing results to file...') 524 | with open(os.path.join(outfolder, 'answer_full.txt'), 'w+') as fp: 525 | for m, ev, s in tqdm(zip(answer_col0, answer_col1, answer_col2)): 526 | line = '{} {} {}\n'.format(m, ev, s) 527 | fp.write(line) 528 | 529 | with open(os.path.join(outfolder, 'answer.txt'), 'w+') as fp: 530 | for s in tqdm(answer_col2): 531 | line = '{}\n'.format(s) 532 | fp.write(line) 533 | 534 | if (answer_col0 == np.array(ds_eval.models_eval)).all(): 535 | print('model ordering matched') 536 | else: 537 | print('model ordering was not correct, need to fix before submission') 538 | 539 | if (answer_col1 == np.array(ds_eval.eval_utts)).all(): 540 | print('eval utt ordering matched') 541 | else: 542 | print('eval utt ordering was not correct, need to fix before submission') 543 | 544 | 545 | def round_sig(x, sig=2): 546 | return round(x, sig - int(floor(log10(abs(x)))) - 1) 547 | 548 | 549 | def get_eer_metrics(folder): 550 | rpkl_path = os.path.join(folder, 'wrec_results.p') 551 | rpkl = pickle.load(open(rpkl_path, 'rb')) 552 | iterations = list(rpkl.keys()) 553 | eers = [rpkl[k]['test_eer'] for k in rpkl] 554 | return iterations, eers, np.min(eers) 555 | 556 | 557 | def extract_test_embeds(generator, ds_test, device): 558 | assert ds_test.test_mode 559 | 560 | with torch.no_grad(): 561 | feats, label_dict, all_utts = ds_test.get_test_items() 562 | all_embeds = [] 563 | for feat in tqdm(feats): 564 | feat = feat.unsqueeze(0).to(device) 565 | embed = generator(feat) 566 | all_embeds.append(embed.cpu().numpy()) 567 | 568 | all_embeds = np.vstack(all_embeds) 569 | all_embeds = normalize(all_embeds, axis=1) 570 | all_utts = np.array(all_utts) 571 | 572 | return all_utts, all_embeds 573 | 574 | if __name__ == "__main__": 575 | args = parse_args() 576 | args = parse_config(args) 577 | uvloop.install() 578 | 579 | hf = h5py.File(os.path.join(args.model_dir, 'xvectors.h5'), 'w') 580 | 581 | use_cuda = torch.cuda.is_available() 582 | print('=' * 30) 583 | print('USE_CUDA SET TO: {}'.format(use_cuda)) 584 | print('CUDA AVAILABLE?: {}'.format(torch.cuda.is_available())) 585 | print('=' * 30) 586 | device = torch.device("cuda" if use_cuda else "cpu") 587 | 588 | if 
args.checkpoint == -1: 589 | g_path = os.path.join(args.model_dir, "final_generator_{}.pt".format(args.num_iterations)) 590 | else: 591 | g_path = os.path.join(args.model_dir, "generator_{}.pt".format(args.checkpoint)) 592 | 593 | if args.model_type == 'XTDNN': 594 | generator = XTDNN(features_per_frame=args.input_dim, embed_features=args.embedding_dim) 595 | if args.model_type == 'ETDNN': 596 | generator = ETDNN(features_per_frame=args.input_dim, embed_features=args.embedding_dim) 597 | if args.model_type == 'FTDNN': 598 | generator = FTDNN(in_dim=args.input_dim, embedding_dim=args.embedding_dim) 599 | 600 | if args.best: 601 | iterations, eers, _ = get_eer_metrics(args.model_dir) 602 | best_it = iterations[np.argmin(eers)] 603 | g_path = os.path.join(args.model_dir, "generator_{}.pt".format(best_it)) 604 | 605 | assert os.path.isfile(g_path), "Couldn't find {}".format(g_path) 606 | 607 | ds_train = SpeakerDataset(args.train_data, num_age_bins=args.num_age_bins) 608 | class_enc_dict = ds_train.get_class_encs() 609 | if args.test_data: 610 | ds_test = SpeakerDataset(args.test_data, test_mode=True, 611 | class_enc_dict=class_enc_dict, num_age_bins=args.num_age_bins) 612 | 613 | generator.load_state_dict(torch.load(g_path)) 614 | model = generator 615 | model.eval().to(device) 616 | 617 | utts, embeds = extract_test_embeds(generator, ds_test, device) 618 | 619 | hf.create_dataset('embeds', data=embeds) 620 | dt = h5py.string_dtype() 621 | hf.create_dataset('utts', data=np.string_(np.array(utts)), dtype=dt) 622 | hf.close() 623 | 624 | -------------------------------------------------------------------------------- /models/classifiers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Function 10 | 11 | ''' 12 | AdaCos and Ad margin loss taken from https://github.com/4uiiurz1/pytorch-adacos 13 | ''' 14 | 15 | class DropAffine(nn.Module): 16 | 17 | def __init__(self, num_features, num_classes): 18 | super(DropAffine, self).__init__() 19 | self.fc = nn.Linear(num_features, num_classes) 20 | self.reset_parameters() 21 | 22 | def reset_parameters(self): 23 | self.fc.reset_parameters() 24 | 25 | def forward(self, input, label=None): 26 | W = self.fc.weight 27 | b = self.fc.bias 28 | logits = F.linear(input, W, b) 29 | return logits 30 | 31 | 32 | class GradientReversalFunction(Function): 33 | """ 34 | Gradient Reversal Layer from: 35 | Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015) 36 | Forward pass is the identity function. In the backward pass, 37 | the upstream gradients are multiplied by -lambda (i.e. 
gradient is reversed) 38 | """ 39 | 40 | @staticmethod 41 | def forward(ctx, x, lambda_): 42 | ctx.lambda_ = lambda_ 43 | return x.clone() 44 | 45 | @staticmethod 46 | def backward(ctx, grads): 47 | lambda_ = ctx.lambda_ 48 | lambda_ = grads.new_tensor(lambda_) 49 | dx = -lambda_ * grads 50 | return dx, None 51 | 52 | 53 | class GradientReversal(nn.Module): 54 | def __init__(self, lambda_=1): 55 | super(GradientReversal, self).__init__() 56 | self.lambda_ = lambda_ 57 | 58 | def forward(self, x, **kwargs): 59 | return GradientReversalFunction.apply(x, self.lambda_) 60 | 61 | 62 | 63 | class L2SoftMax(nn.Module): 64 | 65 | def __init__(self, num_features, num_classes): 66 | super(L2SoftMax, self).__init__() 67 | self.W = nn.Parameter(torch.FloatTensor(num_classes, num_features)) 68 | self.reset_parameters() 69 | 70 | def reset_parameters(self): 71 | nn.init.xavier_uniform_(self.W) 72 | 73 | def forward(self, input, label=None): 74 | x = F.normalize(input) 75 | W = F.normalize(self.W) 76 | logits = F.linear(x, W) 77 | return logits 78 | 79 | class SoftMax(nn.Module): 80 | 81 | def __init__(self, num_features, num_classes): 82 | super(SoftMax, self).__init__() 83 | self.W = nn.Parameter(torch.FloatTensor(num_classes, num_features)) 84 | self.reset_parameters() 85 | 86 | def reset_parameters(self): 87 | nn.init.xavier_uniform_(self.W) 88 | 89 | def forward(self, input, label=None): 90 | x = input 91 | W = self.W 92 | logits = F.linear(x, W) 93 | return logits 94 | 95 | 96 | class LinearUncertain(nn.Module): 97 | 98 | def __init__(self, in_features, out_features, bias=True): 99 | super(LinearUncertain, self).__init__() 100 | self.in_features = in_features 101 | self.out_features = out_features 102 | self.bias = bias 103 | 104 | self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features)) 105 | self.weight_beta = nn.Parameter(torch.Tensor(out_features, in_features)) 106 | 107 | if bias: 108 | self.bias = nn.Parameter(torch.Tensor(out_features)) 109 | else: 110 | self.register_parameter('bias', None) 111 | self.reset_parameters() 112 | 113 | def reset_parameters(self): 114 | nn.init.xavier_uniform_(self.weight_mu, gain=1.0) 115 | init_beta = np.log(np.exp(0.5 * np.sqrt(6)/(self.in_features+self.out_features)) - 1) 116 | nn.init.constant_(self.weight_beta, init_beta) 117 | if self.bias is not None: 118 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_mu) 119 | bound = 1 / np.sqrt(fan_in) 120 | nn.init.uniform_(self.bias, -bound, bound) 121 | 122 | 123 | def forward(self, x): 124 | if self.training: 125 | eps = torch.randn(self.out_features, self.in_features).to(self.weight_mu.device) 126 | weights = self.weight_mu + torch.log(torch.exp(self.weight_beta) + 1) * eps 127 | else: 128 | weights = self.weight_mu 129 | return F.linear(x, weights, self.bias) 130 | 131 | 132 | class XVecHead(nn.Module): 133 | 134 | def __init__(self, num_features, num_classes, hidden_features=None): 135 | super(XVecHead, self).__init__() 136 | hidden_features = num_features if not hidden_features else hidden_features 137 | self.fc_hidden = nn.Linear(num_features, hidden_features) 138 | self.nl = nn.LeakyReLU() 139 | self.bn = nn.BatchNorm1d(hidden_features) 140 | self.fc = nn.Linear(hidden_features, num_classes) 141 | self.reset_parameters() 142 | 143 | def reset_parameters(self): 144 | self.fc.reset_parameters() 145 | 146 | def forward(self, input, label=None, transform=False): 147 | input = self.fc_hidden(input) 148 | input = self.nl(input) 149 | input = self.bn(input) 150 | if transform: 151 | 
return input 152 | W = self.fc.weight 153 | b = self.fc.bias 154 | logits = F.linear(input, W, b) 155 | if logits.shape[-1] == 1: 156 | logits = torch.squeeze(logits, dim=-1) 157 | return logits 158 | 159 | class XVecHeadUncertain(nn.Module): 160 | 161 | def __init__(self, num_features, num_classes, hidden_features=None): 162 | super().__init__() 163 | hidden_features = num_features if not hidden_features else hidden_features 164 | self.fc_hidden = LinearUncertain(num_features, hidden_features) 165 | self.nl = nn.LeakyReLU() 166 | self.bn = nn.BatchNorm1d(hidden_features) 167 | self.fc = LinearUncertain(hidden_features, num_classes) 168 | self.reset_parameters() 169 | 170 | def reset_parameters(self): 171 | self.fc.reset_parameters() 172 | 173 | def forward(self, input, label=None, transform=False): 174 | input = self.fc_hidden(input) 175 | input = self.nl(input) 176 | input = self.bn(input) 177 | if transform: 178 | return input 179 | logits = self.fc(input) 180 | return logits 181 | 182 | 183 | class AMSMLoss(nn.Module): 184 | 185 | def __init__(self, num_features, num_classes, s=30.0, m=0.4): 186 | super(AMSMLoss, self).__init__() 187 | self.num_features = num_features 188 | self.n_classes = num_classes 189 | self.s = s 190 | self.m = m 191 | self.W = nn.Parameter(torch.FloatTensor(num_classes, num_features)) 192 | self.reset_parameters() 193 | 194 | def reset_parameters(self): 195 | nn.init.xavier_uniform_(self.W) 196 | 197 | def forward(self, input, label=None): 198 | # normalize features 199 | x = F.normalize(input) 200 | # normalize weights 201 | W = F.normalize(self.W) 202 | # dot product 203 | logits = F.linear(x, W) 204 | if label is None: 205 | return logits 206 | # add margin 207 | target_logits = logits - self.m 208 | one_hot = torch.zeros_like(logits) 209 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 210 | output = logits * (1 - one_hot) + target_logits * one_hot 211 | # feature re-scale 212 | output *= self.s 213 | 214 | return output 215 | 216 | 217 | 218 | class SphereFace(nn.Module): 219 | 220 | def __init__(self, num_features, num_classes, s=30.0, m=1.35): 221 | super(SphereFace, self).__init__() 222 | self.num_features = num_features 223 | self.n_classes = num_classes 224 | self.s = s 225 | self.m = m 226 | self.W = nn.Parameter(torch.FloatTensor(num_classes, num_features)) 227 | self.reset_parameters() 228 | 229 | def reset_parameters(self): 230 | nn.init.xavier_uniform_(self.W) 231 | 232 | def forward(self, input, label=None): 233 | # normalize features 234 | x = F.normalize(input) 235 | # normalize weights 236 | W = F.normalize(self.W) 237 | # dot product 238 | logits = F.linear(x, W) 239 | if label is None: 240 | return logits 241 | # add margin 242 | theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7)) 243 | target_logits = torch.cos(self.m * theta) 244 | one_hot = torch.zeros_like(logits) 245 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 246 | output = logits * (1 - one_hot) + target_logits * one_hot 247 | # feature re-scale 248 | output *= self.s 249 | 250 | return output 251 | 252 | 253 | class ArcFace(nn.Module): 254 | 255 | def __init__(self, num_features, num_classes, s=30.0, m=0.50): 256 | super(ArcFace, self).__init__() 257 | self.num_features = num_features 258 | self.n_classes = num_classes 259 | self.s = s 260 | self.m = m 261 | self.W = nn.Parameter(torch.FloatTensor(num_classes, num_features)) 262 | self.reset_parameters() 263 | 264 | def reset_parameters(self): 265 | nn.init.xavier_uniform_(self.W) 266 | 267 | def forward(self, 
input, label=None): 268 | # normalize features 269 | x = F.normalize(input) 270 | # normalize weights 271 | W = F.normalize(self.W) 272 | # dot product 273 | logits = F.linear(x, W) 274 | if label is None: 275 | return logits 276 | # add margin 277 | theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7)) 278 | target_logits = torch.cos(theta + self.m) 279 | one_hot = torch.zeros_like(logits) 280 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 281 | output = logits * (1 - one_hot) + target_logits * one_hot 282 | # feature re-scale 283 | output *= self.s 284 | 285 | return output 286 | 287 | class AdaCos(nn.Module): 288 | 289 | def __init__(self, num_features, num_classes, m=0.50): 290 | super(AdaCos, self).__init__() 291 | self.num_features = num_features 292 | self.n_classes = num_classes 293 | self.s = math.sqrt(2) * math.log(num_classes - 1) 294 | self.m = m 295 | self.W = nn.Parameter(torch.FloatTensor(num_classes, num_features)) 296 | self.reset_parameters() 297 | 298 | def reset_parameters(self): 299 | nn.init.xavier_uniform_(self.W) 300 | 301 | def forward(self, input, label=None): 302 | # normalize features 303 | x = F.normalize(input) 304 | # normalize weights 305 | W = F.normalize(self.W) 306 | # dot product 307 | logits = F.linear(x, W) 308 | if label is None: 309 | return logits 310 | # feature re-scale 311 | theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7)) 312 | one_hot = torch.zeros_like(logits) 313 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 314 | with torch.no_grad(): 315 | B_avg = torch.where(one_hot < 1, torch.exp(self.s * logits), torch.zeros_like(logits)) 316 | B_avg = torch.sum(B_avg) / input.size(0) 317 | theta_med = torch.median(theta[one_hot == 1]) 318 | self.s = torch.log(B_avg) / torch.cos(torch.min(math.pi/4 * torch.ones_like(theta_med), theta_med)) 319 | output = self.s * logits 320 | 321 | return output 322 | 323 | -------------------------------------------------------------------------------- /models/criteria.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class MultiTaskUncertaintyLossKendall(nn.Module): 5 | 6 | def __init__(self, num_tasks): 7 | """ 8 | Multi task loss with uncertainty weighting (Kendall 2016) 9 | \eta = 2*log(\sigma) 10 | """ 11 | super().__init__() 12 | self.num_tasks = num_tasks 13 | self.eta = nn.Parameter(torch.zeros(num_tasks)) 14 | 15 | def forward(self, losses): 16 | """ 17 | Input: 1-d tensor of scalar losses, shape (num_tasks,) 18 | Output: Total weighted loss 19 | """ 20 | assert len(losses) == self.num_tasks, 'Expected {} losses to weight, got {}'.format(self.num_tasks, len(losses)) 21 | # factor = torch.pow(2*torch.exp(self.eta) - 2, -1) 22 | factor = torch.exp(-self.eta)/2. 
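# Illustrative aside, assuming eta = 2*log(sigma) as stated in the class docstring: factor = exp(-eta)/2 = 1/(2*sigma^2),
# so the two lines below compute mean_i [ losses_i / (2*sigma_i^2) + log(sigma_i^2) ], the Kendall-style uncertainty weighting.
# e.g. losses = [1.0, 4.0] with eta = [0.0, log(4.0)] gives per-task terms 0.5*1.0 + 0.0 = 0.5 and 0.125*4.0 + log(4.0) ~ 1.886.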
23 | total_loss = (losses * factor + self.eta).sum() 24 | return total_loss/self.num_tasks 25 | 26 | 27 | class MultiTaskUncertaintyLossLiebel(nn.Module): 28 | 29 | def __init__(self, num_tasks): 30 | """ 31 | Multi task loss with uncertainty weighting 32 | Liebel (2018) version 33 | Regularisation term ln(1 + sigma^2) 34 | """ 35 | super().__init__() 36 | self.num_tasks = num_tasks 37 | self.sigma2 = nn.Parameter(0.25*torch.ones(num_tasks)) 38 | 39 | def forward(self, losses): 40 | """ 41 | Input: 1-d tensor of scalar losses, shape (num_tasks,) 42 | Output: Total weighted loss 43 | """ 44 | assert len(losses) == self.num_tasks, 'Expected {} losses to weight, got {}'.format(self.num_tasks, len(losses)) 45 | factor = 1./(2*self.sigma2) 46 | reg = torch.log(1. + self.sigma2) #regularisation term 47 | total_loss = (losses * factor + reg).sum() 48 | return total_loss/self.num_tasks 49 | 50 | class DisturbLabelLoss(nn.Module): 51 | 52 | def __init__(self, device, disturb_prob=0.1): 53 | super(DisturbLabelLoss, self).__init__() 54 | self.disturb_prob = disturb_prob 55 | self.ce = nn.CrossEntropyLoss() 56 | self.device = device 57 | 58 | def forward(self, pred, target): 59 | with torch.no_grad(): 60 | disturb_indexes = torch.rand(len(pred)) < self.disturb_prob 61 | target[disturb_indexes] = torch.randint(pred.shape[-1], (int(disturb_indexes.sum()),)).to(self.device) 62 | return self.ce(pred, target) 63 | 64 | 65 | class LabelSmoothingLoss(nn.Module): 66 | 67 | def __init__(self, smoothing=0.1, dim=-1): 68 | super(LabelSmoothingLoss, self).__init__() 69 | self.confidence = 1.0 - smoothing 70 | self.smoothing = smoothing 71 | self.dim = dim 72 | 73 | def forward(self, pred, target): 74 | pred = pred.log_softmax(dim=self.dim) 75 | with torch.no_grad(): 76 | true_dist = torch.zeros_like(pred) 77 | true_dist.fill_(self.smoothing / (pred.shape[-1] - 1)) 78 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 79 | return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) 80 | 81 | 82 | class TwoNeighbourSmoothingLoss(nn.Module): 83 | 84 | def __init__(self, smoothing=0.1, dim=-1): 85 | super().__init__() 86 | self.dim = dim 87 | self.smoothing = smoothing 88 | self.confidence = 1.0 - smoothing 89 | 90 | 91 | def forward(self, pred, target): 92 | pred = pred.log_softmax(dim=self.dim) 93 | num_classes = pred.shape[self.dim] 94 | with torch.no_grad(): 95 | targets = target.data.unsqueeze(1) 96 | true_dist = torch.zeros_like(pred) 97 | 98 | up_labels = targets.add(1) 99 | up_labels[up_labels >= num_classes] = num_classes - 2 100 | down_labels = targets.add(-1) 101 | down_labels[down_labels < 0] = 1 102 | 103 | smooth_values = torch.zeros_like(targets).float().add_(self.smoothing/2) 104 | true_dist.scatter_add_(1, up_labels, smooth_values) 105 | true_dist.scatter_add_(1, down_labels, smooth_values) 106 | 107 | true_dist.scatter_(1, targets, self.confidence) 108 | 109 | return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) -------------------------------------------------------------------------------- /models/extractors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Function 5 | from collections import OrderedDict 6 | import numpy as np 7 | 8 | 9 | class TDNN(nn.Module): 10 | 11 | def __init__( 12 | self, 13 | input_dim=23, 14 | output_dim=512, 15 | context_size=5, 16 | stride=1, 17 | dilation=1, 18 | batch_norm=True, 19 | dropout_p=0.0, 
20 | padding=0 21 | ): 22 | super(TDNN, self).__init__() 23 | self.context_size = context_size 24 | self.stride = stride 25 | self.input_dim = input_dim 26 | self.output_dim = output_dim 27 | self.dilation = dilation 28 | self.dropout_p = dropout_p 29 | self.padding = padding 30 | 31 | self.kernel = nn.Conv1d(self.input_dim, 32 | self.output_dim, 33 | self.context_size, 34 | stride=self.stride, 35 | padding=self.padding, 36 | dilation=self.dilation) 37 | 38 | self.nonlinearity = nn.LeakyReLU() 39 | self.batch_norm = batch_norm 40 | if batch_norm: 41 | self.bn = nn.BatchNorm1d(output_dim) 42 | self.drop = nn.Dropout(p=self.dropout_p) 43 | 44 | def forward(self, x): 45 | ''' 46 | input: size (batch, seq_len, input_features) 47 | output: size (batch, new_seq_len, output_features) 48 | ''' 49 | 50 | _, _, d = x.shape 51 | assert (d == self.input_dim), 'Input dimension was wrong. Expected ({}), got ({})'.format(self.input_dim, d) 52 | 53 | x = self.kernel(x.transpose(1,2)) 54 | x = self.nonlinearity(x) 55 | x = self.drop(x) 56 | 57 | if self.batch_norm: 58 | x = self.bn(x) 59 | return x.transpose(1,2) 60 | 61 | 62 | class StatsPool(nn.Module): 63 | 64 | def __init__(self, floor=1e-10, bessel=False): 65 | super(StatsPool, self).__init__() 66 | self.floor = floor 67 | self.bessel = bessel 68 | 69 | def forward(self, x): 70 | means = torch.mean(x, dim=1) 71 | _, t, _ = x.shape 72 | if self.bessel: 73 | t = t - 1 74 | residuals = x - means.unsqueeze(1) 75 | numerator = torch.sum(residuals**2, dim=1) 76 | stds = torch.sqrt(torch.clamp(numerator, min=self.floor)/t) 77 | x = torch.cat([means, stds], dim=1) 78 | return x 79 | 80 | 81 | class ETDNN(nn.Module): 82 | 83 | def __init__( 84 | self, 85 | features_per_frame=80, 86 | hidden_features=1024, 87 | embed_features=256, 88 | dropout_p=0.0, 89 | batch_norm=True 90 | ): 91 | super(ETDNN, self).__init__() 92 | self.features_per_frame = features_per_frame 93 | self.hidden_features = hidden_features 94 | self.embed_features = embed_features 95 | 96 | self.dropout_p = dropout_p 97 | self.batch_norm = batch_norm 98 | tdnn_kwargs = {'dropout_p':dropout_p, 'batch_norm':self.batch_norm} 99 | self.nl = nn.LeakyReLU() 100 | 101 | self.frame1 = TDNN(input_dim=self.features_per_frame, output_dim=self.hidden_features, context_size=5, dilation=1, **tdnn_kwargs) 102 | self.frame2 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=1, dilation=1, **tdnn_kwargs) 103 | self.frame3 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=3, dilation=2, **tdnn_kwargs) 104 | self.frame4 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=1, dilation=1, **tdnn_kwargs) 105 | self.frame5 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=3, dilation=3, **tdnn_kwargs) 106 | self.frame6 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=1, dilation=1, **tdnn_kwargs) 107 | self.frame7 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=3, dilation=4, **tdnn_kwargs) 108 | self.frame8 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=1, dilation=1, **tdnn_kwargs) 109 | self.frame9 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features*3, context_size=1, dilation=1, **tdnn_kwargs) 110 | 111 | self.tdnn_list = nn.Sequential(self.frame1, self.frame2, self.frame3, self.frame4, self.frame5, self.frame6, self.frame7, self.frame8, self.frame9) 112
| self.statspool = StatsPool() 113 | 114 | self.fc_embed = nn.Linear(self.hidden_features*6, self.embed_features) 115 | 116 | def forward(self, x): 117 | x = self.tdnn_list(x) 118 | x = self.statspool(x) 119 | x = self.fc_embed(x) 120 | return x 121 | 122 | 123 | class XTDNN(nn.Module): 124 | 125 | def __init__( 126 | self, 127 | features_per_frame=30, 128 | final_features=1500, 129 | embed_features=512, 130 | dropout_p=0.0, 131 | batch_norm=True 132 | ): 133 | super(XTDNN, self).__init__() 134 | self.features_per_frame = features_per_frame 135 | self.final_features = final_features 136 | self.embed_features = embed_features 137 | self.dropout_p = dropout_p 138 | self.batch_norm = batch_norm 139 | tdnn_kwargs = {'dropout_p':dropout_p, 'batch_norm':self.batch_norm} 140 | 141 | self.frame1 = TDNN(input_dim=self.features_per_frame, output_dim=512, context_size=5, dilation=1, **tdnn_kwargs) 142 | self.frame2 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=2, **tdnn_kwargs) 143 | self.frame3 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=3, **tdnn_kwargs) 144 | self.frame4 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=1, **tdnn_kwargs) 145 | self.frame5 = TDNN(input_dim=512, output_dim=self.final_features, context_size=1, dilation=1, **tdnn_kwargs) 146 | 147 | self.tdnn_list = nn.Sequential(self.frame1, self.frame2, self.frame3, self.frame4, self.frame5) 148 | self.statspool = StatsPool() 149 | 150 | self.fc_embed = nn.Linear(self.final_features*2, self.embed_features) 151 | self.nl_embed = nn.LeakyReLU() 152 | self.bn_embed = nn.BatchNorm1d(self.embed_features) 153 | self.drop_embed = nn.Dropout(p=self.dropout_p) 154 | 155 | def forward(self, x): 156 | x = self.tdnn_list(x) 157 | x = self.statspool(x) 158 | x = self.fc_embed(x) 159 | x = self.nl_embed(x) 160 | x = self.bn_embed(x) 161 | x = self.drop_embed(x) 162 | return x 163 | 164 | 165 | class XTDNN_ILayer(nn.Module): 166 | 167 | def __init__( 168 | self, 169 | features_per_frame=30, 170 | dropout_p=0.0, 171 | batch_norm=True 172 | ): 173 | super().__init__() 174 | self.features_per_frame = features_per_frame 175 | self.dropout_p = dropout_p 176 | self.batch_norm = batch_norm 177 | tdnn_kwargs = {'dropout_p':dropout_p, 'batch_norm':self.batch_norm} 178 | 179 | self.frame1 = TDNN(input_dim=self.features_per_frame, output_dim=512, context_size=5, dilation=1, **tdnn_kwargs) 180 | 181 | def forward(self, x): 182 | x = self.frame1(x) 183 | return x 184 | 185 | 186 | class XTDNN_OLayer(nn.Module): 187 | 188 | def __init__( 189 | self, 190 | final_features=1500, 191 | embed_features=512, 192 | dropout_p=0.0, 193 | batch_norm=True 194 | ): 195 | super().__init__() 196 | self.final_features = final_features 197 | self.embed_features = embed_features 198 | self.dropout_p = dropout_p 199 | self.batch_norm = batch_norm 200 | tdnn_kwargs = {'dropout_p':dropout_p, 'batch_norm':self.batch_norm} 201 | 202 | self.frame2 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=2, **tdnn_kwargs) 203 | self.frame3 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=3, **tdnn_kwargs) 204 | self.frame4 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=1, **tdnn_kwargs) 205 | self.frame5 = TDNN(input_dim=512, output_dim=self.final_features, context_size=1, dilation=1, **tdnn_kwargs) 206 | 207 | self.statspool = StatsPool() 208 | 209 | self.fc_embed = nn.Linear(self.final_features*2, self.embed_features) 210 | self.nl_embed = nn.LeakyReLU() 211 | self.bn_embed = 
nn.BatchNorm1d(self.embed_features) 212 | self.drop_embed = nn.Dropout(p=self.dropout_p) 213 | 214 | def forward(self, x): 215 | x = self.frame2(x) 216 | x = self.frame3(x) 217 | x = self.frame4(x) 218 | x = self.frame5(x) 219 | x = self.statspool(x) 220 | x = self.fc_embed(x) 221 | x = self.nl_embed(x) 222 | x = self.bn_embed(x) 223 | x = self.drop_embed(x) 224 | return x 225 | 226 | 227 | 228 | class DenseReLU(nn.Module): 229 | 230 | def __init__(self, in_dim, out_dim): 231 | super(DenseReLU, self).__init__() 232 | self.fc = nn.Linear(in_dim, out_dim) 233 | self.bn = nn.BatchNorm1d(out_dim) 234 | self.nl = nn.LeakyReLU() 235 | 236 | def forward(self, x): 237 | x = self.fc(x) 238 | x = self.nl(x) 239 | if len(x.shape) > 2: 240 | x = self.bn(x.transpose(1,2)).transpose(1,2) 241 | else: 242 | x = self.bn(x) 243 | return x 244 | 245 | 246 | class FTDNNLayer(nn.Module): 247 | 248 | def __init__(self, in_dim, out_dim, bottleneck_dim, context_size=2, dilations=None, paddings=None, alpha=0.0): 249 | ''' 250 | 3 stage factorised TDNN http://danielpovey.com/files/2018_interspeech_tdnnf.pdf 251 | ''' 252 | super(FTDNNLayer, self).__init__() 253 | paddings = [1,1,1] if not paddings else paddings 254 | dilations = [2,2,2] if not dilations else dilations 255 | kwargs = {'bias':False} 256 | self.factor1 = nn.Conv1d(in_dim, bottleneck_dim, context_size, padding=paddings[0], dilation=dilations[0], **kwargs) 257 | self.factor2 = nn.Conv1d(bottleneck_dim, bottleneck_dim, context_size, padding=paddings[1], dilation=dilations[1], **kwargs) 258 | self.factor3 = nn.Conv1d(bottleneck_dim, out_dim, context_size, padding=paddings[2], dilation=dilations[2], **kwargs) 259 | self.reset_parameters() 260 | self.nl = nn.ReLU() 261 | self.bn = nn.BatchNorm1d(out_dim) 262 | self.dropout = SharedDimScaleDropout(alpha=alpha, dim=1) 263 | 264 | def forward(self, x): 265 | ''' input (batch_size, seq_len, in_dim) ''' 266 | assert (x.shape[-1] == self.factor1.weight.shape[1]) 267 | x = self.factor1(x.transpose(1,2)) 268 | x = self.factor2(x) 269 | x = self.factor3(x) 270 | x = self.nl(x) 271 | x = self.bn(x).transpose(1,2) 272 | x = self.dropout(x) 273 | return x 274 | 275 | def step_semi_orth(self): 276 | with torch.no_grad(): 277 | factor1_M = self.get_semi_orth_weight(self.factor1) 278 | factor2_M = self.get_semi_orth_weight(self.factor2) 279 | self.factor1.weight.copy_(factor1_M) 280 | self.factor2.weight.copy_(factor2_M) 281 | 282 | def reset_parameters(self): 283 | # Standard dev of M init values is inverse of sqrt of num cols 284 | nn.init._no_grad_normal_(self.factor1.weight, 0., self.get_M_shape(self.factor1.weight)[1]**-0.5) 285 | nn.init._no_grad_normal_(self.factor2.weight, 0., self.get_M_shape(self.factor2.weight)[1]**-0.5) 286 | 287 | def orth_error(self): 288 | factor1_err = self.get_semi_orth_error(self.factor1).item() 289 | factor2_err = self.get_semi_orth_error(self.factor2).item() 290 | return factor1_err + factor2_err 291 | 292 | @staticmethod 293 | def get_semi_orth_weight(conv1dlayer): 294 | # updates conv1 weight M using update rule to make it more semi orthogonal 295 | # based off ConstrainOrthonormalInternal in nnet-utils.cc in Kaldi src/nnet3 296 | # includes the tweaks related to slowing the update speed 297 | # only an implementation of the 'floating scale' case 298 | with torch.no_grad(): 299 | update_speed = 0.125 300 | orig_shape = conv1dlayer.weight.shape 301 | # a conv weight differs slightly from TDNN formulation: 302 | # Conv weight: (out_filters, in_filters, kernel_width) 303 | # TDNN weight 
M is of shape: (in_dim, out_dim) or [rows, cols] 304 | # the in_dim of the TDNN weight is equivalent to in_filters * kernel_width of the Conv 305 | M = conv1dlayer.weight.reshape(orig_shape[0], orig_shape[1]*orig_shape[2]).T 306 | # M now has shape (in_dim[rows], out_dim[cols]) 307 | mshape = M.shape 308 | if mshape[0] > mshape[1]: # semi orthogonal constraint for rows > cols 309 | M = M.T 310 | P = torch.mm(M, M.T) 311 | PP = torch.mm(P, P.T) 312 | trace_P = torch.trace(P) 313 | trace_PP = torch.trace(PP) 314 | ratio = trace_PP * P.shape[0] / (trace_P * trace_P) 315 | 316 | # the following is the tweak to avoid divergence (more info in Kaldi) 317 | assert ratio > 0.999 318 | if ratio > 1.02: 319 | update_speed *= 0.5 320 | if ratio > 1.1: 321 | update_speed *= 0.5 322 | 323 | scale2 = trace_PP/trace_P 324 | update = P - (torch.matrix_power(P, 0) * scale2) 325 | alpha = update_speed / scale2 326 | update = (-4.0 * alpha) * torch.mm(update, M) 327 | updated = M + update 328 | # updated has shape (cols, rows) if rows > cols, else has shape (rows, cols) 329 | # Transpose (or not) to shape (cols, rows) (IMPORTANT, s.t. correct dimensions are reshaped) 330 | # Then reshape to (cols, in_filters, kernel_width) 331 | return updated.reshape(*orig_shape) if mshape[0] > mshape[1] else updated.T.reshape(*orig_shape) 332 | 333 | @staticmethod 334 | def get_M_shape(conv_weight): 335 | orig_shape = conv_weight.shape 336 | return (orig_shape[1]*orig_shape[2], orig_shape[0]) 337 | 338 | @staticmethod 339 | def get_semi_orth_error(conv1dlayer): 340 | with torch.no_grad(): 341 | orig_shape = conv1dlayer.weight.shape 342 | M = conv1dlayer.weight.reshape(orig_shape[0], orig_shape[1]*orig_shape[2]) 343 | mshape = M.shape 344 | if mshape[0] > mshape[1]: # semi orthogonal constraint for rows > cols 345 | M = M.T 346 | P = torch.mm(M, M.T) 347 | PP = torch.mm(P, P.T) 348 | trace_P = torch.trace(P) 349 | trace_PP = torch.trace(PP) 350 | ratio = trace_PP * P.shape[0] / (trace_P * trace_P) 351 | scale2 = torch.sqrt(trace_PP/trace_P) ** 2 352 | update = P - (torch.matrix_power(P, 0) * scale2) 353 | return torch.norm(update, p='fro') 354 | 355 | 356 | class FTDNN(nn.Module): 357 | 358 | def __init__(self, in_dim=30, embedding_dim=512): 359 | ''' 360 | The FTDNN architecture from 361 | "State-of-the-art speaker recognition with neural network embeddings in 362 | NIST SRE18 and Speakers in the Wild evaluations" 363 | https://www.sciencedirect.com/science/article/pii/S0885230819302700 364 | ''' 365 | super(FTDNN, self).__init__() 366 | 367 | self.layer01 = TDNN(input_dim=in_dim, output_dim=512, context_size=5, padding=2) 368 | self.layer02 = FTDNNLayer(512, 1024, 256, context_size=2, dilations=[2,2,2], paddings=[1,1,1]) 369 | self.layer03 = FTDNNLayer(1024, 1024, 256, context_size=1, dilations=[1,1,1], paddings=[0,0,0]) 370 | self.layer04 = FTDNNLayer(1024, 1024, 256, context_size=2, dilations=[3,3,2], paddings=[2,1,1]) 371 | self.layer05 = FTDNNLayer(2048, 1024, 256, context_size=1, dilations=[1,1,1], paddings=[0,0,0]) 372 | self.layer06 = FTDNNLayer(1024, 1024, 256, context_size=2, dilations=[3,3,2], paddings=[2,1,1]) 373 | self.layer07 = FTDNNLayer(3072, 1024, 256, context_size=2, dilations=[3,3,2], paddings=[2,1,1]) 374 | self.layer08 = FTDNNLayer(1024, 1024, 256, context_size=2, dilations=[3,3,2], paddings=[2,1,1]) 375 | self.layer09 = FTDNNLayer(3072, 1024, 256, context_size=1, dilations=[1,1,1], paddings=[0,0,0]) 376 | self.layer10 = DenseReLU(1024, 2048) 377 | 378 | self.layer11 = StatsPool() 379 | 380 | 
self.layer12 = DenseReLU(4096, embedding_dim) 381 | 382 | 383 | def forward(self, x): 384 | x = self.layer01(x) 385 | x_2 = self.layer02(x) 386 | x_3 = self.layer03(x_2) 387 | x_4 = self.layer04(x_3) 388 | skip_5 = torch.cat([x_4, x_3], dim=-1) 389 | x = self.layer05(skip_5) 390 | x_6 = self.layer06(x) 391 | skip_7 = torch.cat([x_6, x_4, x_2], dim=-1) 392 | x = self.layer07(skip_7) 393 | x_8 = self.layer08(x) 394 | skip_9 = torch.cat([x_8, x_6, x_4], dim=-1) 395 | x = self.layer09(skip_9) 396 | x = self.layer10(x) 397 | x = self.layer11(x) 398 | x = self.layer12(x) 399 | return x 400 | 401 | def step_ftdnn_layers(self): 402 | for layer in self.children(): 403 | if isinstance(layer, FTDNNLayer): 404 | layer.step_semi_orth() 405 | 406 | def set_dropout_alpha(self, alpha): 407 | for layer in self.children(): 408 | if isinstance(layer, FTDNNLayer): 409 | layer.dropout.alpha = alpha 410 | 411 | def get_orth_errors(self): 412 | errors = 0. 413 | with torch.no_grad(): 414 | for layer in self.children(): 415 | if isinstance(layer, FTDNNLayer): 416 | errors += layer.orth_error() 417 | return errors 418 | 419 | 420 | class SharedDimScaleDropout(nn.Module): 421 | def __init__(self, alpha: float = 0.5, dim=1): 422 | ''' 423 | Continuous scaled dropout that is const over chosen dim (usually across time) 424 | Multiplies inputs by random mask taken from Uniform([1 - 2\alpha, 1 + 2\alpha]) 425 | ''' 426 | super(SharedDimScaleDropout, self).__init__() 427 | if alpha > 0.5 or alpha < 0: 428 | raise ValueError("alpha must be between 0 and 0.5") 429 | self.alpha = alpha 430 | self.dim = dim 431 | self.register_buffer('mask', torch.tensor(0.)) 432 | 433 | def forward(self, X): 434 | if self.training: 435 | if self.alpha != 0.: 436 | # sample mask from uniform dist with dim of length 1 in self.dim and then repeat to match size 437 | tied_mask_shape = list(X.shape) 438 | tied_mask_shape[self.dim] = 1 439 | repeats = [1 if i != self.dim else X.shape[self.dim] for i in range(len(X.shape))] 440 | return X * self.mask.repeat(tied_mask_shape).uniform_(1 - 2*self.alpha, 1 + 2*self.alpha).repeat(repeats) 441 | # expected value of dropout mask is 1 so no need to scale outputs like vanilla dropout 442 | return X 443 | -------------------------------------------------------------------------------- /models/sim_predictors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class LSTMSimilarity(nn.Module): 7 | 8 | def __init__(self, input_size=256, hidden_size=256, num_layers=2): 9 | super(LSTMSimilarity, self).__init__() 10 | self.lstm = nn.LSTM(input_size, 11 | hidden_size, 12 | num_layers=num_layers, 13 | bidirectional=True, 14 | batch_first=True) 15 | self.fc1 = nn.Linear(hidden_size*2, 64) 16 | self.nl = nn.ReLU(inplace=True) 17 | self.fc2 = nn.Linear(64, 1) 18 | 19 | def forward(self, x): 20 | x, _ = self.lstm(x) 21 | x = self.fc1(x) 22 | x = self.nl(x) 23 | x = self.fc2(x).squeeze(2) 24 | return x 25 | 26 | class LSTMSimilarityCosRes(nn.Module): 27 | 28 | def __init__(self, input_size=256, hidden_size=256, num_layers=2): 29 | ''' 30 | Like the LSTM Model but the LSTM only has to learn a modification to the cosine sim: 31 | y = LSTM(x) + pwise_cos_sim(x) 32 | ''' 33 | super(LSTMSimilarityCosRes, self).__init__() 34 | self.lstm = nn.LSTM(input_size, 35 | hidden_size, 36 | num_layers=num_layers, 37 | bidirectional=True, 38 | batch_first=True) 39 | self.fc1 = 
nn.Linear(hidden_size*2, 64) 40 | self.nl = nn.ReLU(inplace=True) 41 | self.fc2 = nn.Linear(64, 1) 42 | 43 | self.pdistlayer = pCosineSiamese() 44 | 45 | def forward(self, x): 46 | cs = self.pdistlayer(x) 47 | x, _ = self.lstm(x) 48 | x = self.fc1(x) 49 | x = self.nl(x) 50 | x = self.fc2(x).squeeze(2) 51 | x += cs 52 | return x 53 | 54 | class LSTMSimilarityCosWS(nn.Module): 55 | 56 | def __init__(self, input_size=256, hidden_size=256, num_layers=2): 57 | ''' 58 | Like the LSTM Model but the LSTM only has to learn a weighted sum of it and the cosine sim: 59 | y = A*LSTM(x) + B*pwise_cos_sim(x) 60 | ''' 61 | super(LSTMSimilarityCosWS, self).__init__() 62 | self.lstm = nn.LSTM(input_size, 63 | hidden_size, 64 | num_layers=num_layers, 65 | bidirectional=True, 66 | batch_first=True) 67 | self.fc1 = nn.Linear(hidden_size*2, 64) 68 | self.nl = nn.ReLU() 69 | self.fc2 = nn.Linear(64, 1) 70 | self.weightsum = nn.Linear(2,1) 71 | self.pdistlayer = pCosineSiamese() 72 | 73 | def forward(self, x): 74 | cs = self.pdistlayer(x) 75 | x, _ = self.lstm(x) 76 | x = self.fc1(x) 77 | x = self.nl(x) 78 | x = torch.sigmoid(self.fc2(x).squeeze(2)) 79 | x = torch.stack([x, cs], dim=-1) 80 | return self.weightsum(x).squeeze(-1) 81 | 82 | 83 | class pCosineSim(nn.Module): 84 | 85 | def __init__(self): 86 | super(pCosineSim, self).__init__() 87 | 88 | def forward(self, x): 89 | cs = [] 90 | for j in range(x.shape[0]): 91 | cs_row = torch.clamp(torch.mm(x[j].unsqueeze(1).transpose(0,1), x.transpose(0,1)) / (torch.norm(x[j]) * torch.norm(x, dim=1)), 1e-6) 92 | cs.append(cs_row) 93 | return torch.cat(cs) 94 | 95 | class pbCosineSim(nn.Module): 96 | 97 | def __init__(self): 98 | super(pbCosineSim, self).__init__() 99 | 100 | def forward(self, x): 101 | ''' 102 | Batch pairwise cosine similarity: 103 | 104 | input (batch_size, seq_len, d) 105 | output (batch_size, seq_len, seq_len) 106 | ''' 107 | cs = [] 108 | for j in range(x.shape[1]): 109 | cs_row = torch.clamp(torch.bmm(x[:, j, :].unsqueeze(1), x.transpose(1,2)) / (torch.norm(x[:, j, :], dim=-1).unsqueeze(1) * torch.norm(x, dim=-1)).unsqueeze(1), 1e-6) 110 | cs.append(cs_row) 111 | return torch.cat(cs, dim=1) 112 | 113 | class pCosineSiamese(nn.Module): 114 | 115 | def __init__(self): 116 | super(pCosineSiamese, self).__init__() 117 | 118 | def forward(self, x): 119 | ''' 120 | split every element in last dimension/2 and take cosine distance 121 | ''' 122 | x1, x2 = torch.split(x, x.shape[-1]//2, dim=-1) 123 | return F.cosine_similarity(x1, x2, dim=-1) 124 | 125 | def subsequent_mask(size): 126 | "Mask out subsequent positions." 
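# Illustrative sketch: for size=3 the function below returns the additive attention mask
# [[0., -inf, -inf], [0., 0., -inf], [0., 0., 0.]], i.e. position i may only attend to positions j <= i.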
127 | attn_shape = (1, size, size) 128 | mask = np.triu(np.ones(attn_shape), k=1) 129 | mask[mask==1.0] = float('-inf') 130 | return torch.FloatTensor(mask).squeeze(0) 131 | 132 | class TransformerSim(nn.Module): 133 | 134 | def __init__(self, d_model=256, nhead=4, num_encoder_layers=2, dim_feedforward=1024): 135 | super(TransformerSim, self).__init__() 136 | 137 | self.tf = nn.Transformer(d_model=d_model, 138 | nhead=nhead, 139 | num_encoder_layers=num_encoder_layers, 140 | num_decoder_layers=num_encoder_layers, 141 | dim_feedforward=dim_feedforward) 142 | self.out_embed = nn.Embedding(3, d_model) 143 | self.generator = nn.Linear(d_model, 3) 144 | 145 | def forward(self, src, tgt, src_mask=None, tgt_mask=None): 146 | x = self.tf(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) 147 | x = self.generator(x) 148 | return x 149 | 150 | def encode(self, src): 151 | x = self.tf.encoder(src) 152 | return x 153 | 154 | class XTransformerSim(nn.Module): 155 | 156 | def __init__(self, d_model=256, nhead=4, num_encoder_layers=4, dim_feedforward=512): 157 | super(XTransformerSim, self).__init__() 158 | 159 | self.tf = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward), num_encoder_layers) 160 | # self.pdistlayer = pCosineSiamese() 161 | self.fc1 = nn.Linear(d_model, 1) 162 | # self.weightsum = nn.Linear(2, 1) 163 | 164 | def forward(self, src): 165 | # cs = self.pdistlayer(src) 166 | x = self.tf(src) 167 | x = self.fc1(x).squeeze(-1) 168 | # x = torch.stack([x, cs], dim=-1) 169 | # x = self.weightsum(x).squeeze(-1) 170 | return x 171 | 172 | 173 | class XTransformerLSTMSim(nn.Module): 174 | 175 | def __init__(self, d_model=256, nhead=4, num_encoder_layers=2, dim_feedforward=1024): 176 | super(XTransformerLSTMSim, self).__init__() 177 | self.tf = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward), num_encoder_layers) 178 | self.lstm = nn.LSTM(d_model, 179 | d_model, 180 | num_layers=2, 181 | bidirectional=True) 182 | self.fc1 = nn.Linear(d_model*2, 64) 183 | self.fc2 = nn.Linear(64, 1) 184 | 185 | def forward(self, src): 186 | out = self.tf(src) 187 | out, _ = self.lstm(out) 188 | out = F.relu(self.fc1(out)) 189 | out = self.fc2(out) 190 | return out 191 | 192 | class AttnDecoderRNN(nn.Module): 193 | def __init__(self, hidden_size=256, output_size=1, dropout_p=0.1, max_length=250): 194 | super(AttnDecoderRNN, self).__init__() 195 | self.hidden_size = hidden_size 196 | self.output_size = output_size 197 | self.dropout_p = dropout_p 198 | self.max_length = max_length 199 | 200 | # self.embedding = nn.Embedding(self.output_size, self.hidden_size) 201 | self.attn = nn.Linear(self.hidden_size * 2, self.max_length) 202 | self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size) 203 | self.dropout = nn.Dropout(self.dropout_p) 204 | self.gru = nn.GRU(self.hidden_size, self.hidden_size) 205 | self.out = nn.Linear(self.hidden_size, self.output_size) 206 | 207 | def forward(self, input, hidden, encoder_outputs): 208 | embedded = self.embedding(input).view(1, 1, -1) 209 | embedded = self.dropout(embedded) 210 | 211 | attn_weights = F.softmax( 212 | self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1) 213 | attn_applied = torch.bmm(attn_weights.unsqueeze(0), 214 | encoder_outputs.unsqueeze(0)) 215 | 216 | output = torch.cat((embedded[0], attn_applied[0]), 1) 217 | output = self.attn_combine(output).unsqueeze(0) 218 | 219 | output = F.relu(output) 220 | output, hidden = self.gru(output, hidden) 221 | output = self.out(output[0]) 222 | 
return output, hidden, attn_weights 223 | 224 | def initHidden(self, device): 225 | return torch.zeros(1, 1, self.hidden_size, device=device) 226 | 227 | 228 | class XTransformer(nn.Module): 229 | 230 | def __init__(self, d_model=128, nhead=8, num_encoder_layers=6, dim_feedforward=1024): 231 | super(XTransformer, self).__init__() 232 | 233 | self.tf = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward), num_encoder_layers) 234 | # self.fc1 = nn.Linear(d_model, d_model) 235 | # self.nl = nn.ReLU(inplace=True) 236 | # self.fc2 = nn.Linear(d_model, d_model) 237 | 238 | self.pdist = pCosineSim() 239 | 240 | def forward(self, src): 241 | x = self.tf(src) 242 | x = x.squeeze(1) 243 | # x = self.fc1(x) 244 | # x = self.nl(x) 245 | # x = self.fc2(x) 246 | x = F.normalize(x, p=2, dim=-1) 247 | sim = self.pdist(x) 248 | return sim 249 | 250 | 251 | class XTransformerRes(nn.Module): 252 | 253 | def __init__(self, d_model=128, nhead=8, num_encoder_layers=6, dim_feedforward=1024): 254 | super(XTransformerRes, self).__init__() 255 | self.tf = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward), num_encoder_layers) 256 | self.pdist = pCosineSim() 257 | 258 | def forward(self, src): 259 | cs = self.pdist(src.squeeze(1)) 260 | x = self.tf(src) 261 | x = x.squeeze(1) 262 | x = F.normalize(x, p=2, dim=-1) 263 | sim = self.pdist(x) 264 | sim += cs 265 | return torch.clamp(sim/2, 1e-16, 1.-1e-16) 266 | 267 | 268 | class XTransformerMask(nn.Module): 269 | 270 | def __init__(self, d_model=128, nhead=8, num_encoder_layers=6, dim_feedforward=1024): 271 | 272 | super(XTransformerMask, self).__init__() 273 | 274 | self.tf = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward), num_encoder_layers) 275 | 276 | self.pdist = pCosineSim() 277 | 278 | def forward(self, src): 279 | mask = self.tf(src) 280 | mask = mask.squeeze(1) 281 | mask = torch.sigmoid(mask) 282 | out = F.normalize(mask * src.squeeze(1), p=2, dim=-1) 283 | sim = self.pdist(out) 284 | return sim 285 | 286 | 287 | class XTransformerMaskRes(nn.Module): 288 | 289 | def __init__(self, d_model=128, nhead=8, num_encoder_layers=6, dim_feedforward=1024): 290 | 291 | super(XTransformerMaskRes, self).__init__() 292 | self.tf = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward), num_encoder_layers) 293 | self.pdist = pbCosineSim() 294 | 295 | def forward(self, src, src_mask=None): 296 | cs = self.pdist(src.transpose(0, 1)) 297 | mask = self.tf(src, src_key_padding_mask=src_mask) 298 | mask = torch.sigmoid(mask) 299 | out = F.normalize(mask * src, p=2, dim=-1) 300 | sim = self.pdist(out.transpose(0, 1)) 301 | sim += cs 302 | return torch.clamp(sim/2, 1e-1, 1.-1e-1) 303 | 304 | 305 | 306 | class LSTMTransform(nn.Module): 307 | 308 | def __init__(self, input_size=128, hidden_size=256, num_layers=2): 309 | super(LSTMTransform, self).__init__() 310 | self.lstm = nn.LSTM(input_size, 311 | hidden_size, 312 | num_layers=num_layers, 313 | bidirectional=True, 314 | batch_first=True) 315 | 316 | self.fc1 = nn.Linear(512, 256) 317 | self.nl = nn.ReLU(inplace=True) 318 | self.fc2 = nn.Linear(256, 256) 319 | 320 | self.pdist = pCosineSim() 321 | 322 | def forward(self, x): 323 | x, _ = self.lstm(x) 324 | x = x.squeeze(0) 325 | x = self.fc1(x) 326 | x = self.nl(x) 327 | x = self.fc2(x) 328 | sim = self.pdist(x) 329 | return 1. 
- sim 330 | 331 | 332 | class ConvSim(nn.Module): 333 | 334 | def __init__(self, input_dim=256): 335 | super(ConvSim, self).__init__() 336 | self.input_dim = input_dim 337 | self.layer1 = nn.Sequential( 338 | nn.Conv1d(input_dim, 32, kernel_size=3, stride=1, padding=1), 339 | nn.ReLU(), 340 | nn.BatchNorm1d(32)) 341 | self.layer2 = nn.Sequential( 342 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1, dilation=1), 343 | nn.ReLU(), 344 | nn.BatchNorm1d(32)) 345 | self.layer3 = nn.Sequential( 346 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1, dilation=2), 347 | nn.ReLU(), 348 | nn.BatchNorm1d(32)) 349 | self.layer4 = nn.Sequential( 350 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=3, dilation=3), 351 | nn.ReLU(), 352 | nn.BatchNorm1d(32)) 353 | self.layer5 = nn.Sequential( 354 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=3, dilation=3), 355 | nn.ReLU(), 356 | nn.BatchNorm1d(32)) 357 | self.layer6 = nn.Sequential( 358 | nn.Conv1d(32, 1, kernel_size=1, stride=1, padding=1), 359 | nn.ReLU(), 360 | nn.BatchNorm1d(1)) 361 | 362 | def forward(self, x): 363 | if x.shape[-1] == self.input_dim: 364 | x = x.permute(0,2,1) 365 | x = self.layer1(x) 366 | x = self.layer2(x) 367 | x = self.layer3(x) 368 | x = self.layer4(x) 369 | x = self.layer5(x) 370 | x = self.layer6(x) 371 | return x.squeeze(1) 372 | 373 | class ConvCosResSim(nn.Module): 374 | 375 | def __init__(self, input_dim=256): 376 | super(ConvCosResSim, self).__init__() 377 | self.pdistlayer = pCosineSiamese() 378 | self.input_dim = input_dim 379 | # self.drop1 = nn.Dropout() 380 | # self.drop2 = nn.Dropout() 381 | # self.drop3 = nn.Dropout() 382 | # self.drop4 = nn.Dropout() 383 | # self.drop5 = nn.Dropout() 384 | self.layer1 = nn.Sequential( 385 | nn.Conv1d(input_dim, 32, kernel_size=3, stride=1, padding=1), 386 | nn.ReLU(), 387 | nn.BatchNorm1d(32)) 388 | self.layer2 = nn.Sequential( 389 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1, dilation=1), 390 | nn.ReLU(), 391 | nn.BatchNorm1d(32)) 392 | self.layer3 = nn.Sequential( 393 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1, dilation=2), 394 | nn.ReLU(), 395 | nn.BatchNorm1d(32)) 396 | self.layer4 = nn.Sequential( 397 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=3, dilation=3), 398 | nn.ReLU(), 399 | nn.BatchNorm1d(32)) 400 | self.layer5 = nn.Sequential( 401 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=3, dilation=3), 402 | nn.ReLU(), 403 | nn.BatchNorm1d(32)) 404 | self.layer6 = nn.Sequential( 405 | nn.Conv1d(32, 1, kernel_size=1, stride=1, padding=1), 406 | nn.ReLU(), 407 | nn.BatchNorm1d(1)) 408 | 409 | def forward(self, x): 410 | cs = self.pdistlayer(x) 411 | if x.shape[-1] == self.input_dim: 412 | x = x.permute(0,2,1) 413 | x = self.layer1(x) 414 | # x = self.drop1(x) 415 | x = self.layer2(x) 416 | # x = self.drop2(x) 417 | x = self.layer3(x) 418 | # x = self.drop3(x) 419 | x = self.layer4(x) 420 | # x = self.drop4(x) 421 | x = self.layer5(x) 422 | # x = self.drop5(x) 423 | x = self.layer6(x).squeeze(1) 424 | x += cs 425 | return x 426 | 427 | def set_dropout_p(self, p): 428 | for layer in self.children(): 429 | if isinstance(layer, nn.Dropout): 430 | layer.p = p -------------------------------------------------------------------------------- /scotus_data_prep/filter_scp.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # Copyright 2010-2012 Microsoft Corporation 3 | # Johns Hopkins University (author: Daniel Povey) 4 | 5 | # Licensed under the Apache License, 
Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 12 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 13 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 14 | # MERCHANTABLITY OR NON-INFRINGEMENT. 15 | # See the Apache 2 License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | 19 | # This script takes a list of utterance-ids or any file whose first field 20 | # of each line is an utterance-id, and filters an scp 21 | # file (or any file whose "n-th" field is an utterance id), printing 22 | # out only those lines whose "n-th" field is in id_list. The index of 23 | # the "n-th" field is 1, by default, but can be changed by using 24 | # the -f switch 25 | 26 | $exclude = 0; 27 | $field = 1; 28 | $shifted = 0; 29 | 30 | do { 31 | $shifted=0; 32 | if ($ARGV[0] eq "--exclude") { 33 | $exclude = 1; 34 | shift @ARGV; 35 | $shifted=1; 36 | } 37 | if ($ARGV[0] eq "-f") { 38 | $field = $ARGV[1]; 39 | shift @ARGV; shift @ARGV; 40 | $shifted=1 41 | } 42 | } while ($shifted); 43 | 44 | if(@ARGV < 1 || @ARGV > 2) { 45 | die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . 46 | "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . 47 | "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . 48 | "only the lines that were *not* in id_list.\n" . 49 | "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . 50 | "If your older scripts (written before Oct 2014) stopped working and you used the\n" . 51 | "-f option, add 1 to the argument.\n" . 52 | "See also: utils/filter_scp.pl .\n"; 53 | } 54 | 55 | 56 | $idlist = shift @ARGV; 57 | open(F, "<$idlist") || die "Could not open id-list file $idlist"; 58 | while(<F>) { 59 | @A = split; 60 | @A>=1 || die "Invalid id-list file line $_"; 61 | $seen{$A[0]} = 1; 62 | } 63 | 64 | if ($field == 1) { # Treat this as special case, since it is common. 65 | while(<>) { 66 | $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; 67 | # $1 is what we filter on. 68 | if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { 69 | print $_; 70 | } 71 | } 72 | } else { 73 | while(<>) { 74 | @A = split; 75 | @A > 0 || die "Invalid scp file line $_"; 76 | @A >= $field || die "Invalid scp file line $_"; 77 | if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { 78 | print $_; 79 | } 80 | } 81 | } 82 | 83 | # tests: 84 | # the following should print "foo 1" 85 | # ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) 86 | # the following should print "bar 2".
87 | # ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) 88 | -------------------------------------------------------------------------------- /scotus_data_prep/local_info.py: -------------------------------------------------------------------------------- 1 | AVVO_API_ACCESS_TOKEN='' 2 | AVVO_CSE_ID='' 3 | GOOGLE_API_KEY='' 4 | -------------------------------------------------------------------------------- /scotus_data_prep/step1_downloadmp3.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import ssl 5 | import sys 6 | import urllib 7 | from glob import glob 8 | 9 | import numpy as np 10 | import requests 11 | from tqdm import tqdm 12 | 13 | ssl._create_default_https_context = ssl._create_unverified_context 14 | from collections import OrderedDict 15 | from multiprocessing import Pool 16 | 17 | import dateutil.parser as dparser 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description='Download mp3s and process the case jsons into something usable') 22 | parser.add_argument('--case-folder', type=str, help='Location of the case json files') 23 | parser.add_argument('--base-outfolder', type=str, help='Location of the base outfolder') 24 | args = parser.parse_args() 25 | return args 26 | 27 | def get_mp3_and_transcript(c): 28 | ''' 29 | Get the mp3 from case jsons 30 | ''' 31 | js = json.load(open(c, encoding='utf-8'), object_pairs_hook=OrderedDict) 32 | 33 | rec_name = os.path.splitext(os.path.basename(c))[0] 34 | case_year = int(rec_name[:4]) 35 | 36 | # Only want digital recordings, after October 2005 37 | if case_year <= 2004: 38 | return rec_name, False, False, None 39 | 40 | cutoff_date = dparser.parse('1 October 2005', fuzzy=True) 41 | 42 | if 'oral_argument_audio' in js: 43 | if js['oral_argument_audio']: 44 | if 'href' in js['oral_argument_audio']: 45 | t_url = js['oral_argument_audio']['href'] 46 | resp = requests.get(t_url, timeout=20) 47 | if resp.ok: 48 | js = json.loads(resp.content, object_pairs_hook=OrderedDict) 49 | else: 50 | return rec_name, False, False, None 51 | 52 | elif len(js['oral_argument_audio']) == 1: 53 | if 'href' in js['oral_argument_audio'][0]: 54 | t_url = js['oral_argument_audio'][0]['href'] 55 | resp = requests.get(t_url, timeout=20) 56 | if resp.ok: 57 | js = json.loads(resp.content, object_pairs_hook=OrderedDict) 58 | else: 59 | return rec_name, False, False, None 60 | 61 | if 'title' in js: 62 | dashind = js['title'].rfind('-') 63 | reduced_title = js['title'][dashind + 1:] 64 | start = reduced_title.find('(') 65 | end = reduced_title.find(')') 66 | if start != -1 and end != -1: 67 | reduced_title = reduced_title[:start - 1] 68 | try: 69 | case_date = dparser.parse(reduced_title, fuzzy=True) 70 | if case_date < cutoff_date: 71 | # print('Recording was too early') 72 | return rec_name, False, False, None 73 | except ValueError: 74 | # print('Couldnt figure out date {}'.format(reduced_title)) 75 | return rec_name, False, js, None 76 | 77 | if 'media_file' in js: 78 | if 'transcript' in js: 79 | if js['media_file'] and js['transcript']: 80 | for obj in js['media_file']: 81 | if obj: 82 | if 'href' in obj: 83 | url = obj['href'] 84 | if url.endswith('mp3'): 85 | return rec_name, url, js, case_date 86 | 87 | return rec_name, False, js, None 88 | 89 | 90 | def download_mp3(rec_name, url, outfile): 91 | filename = outfile 92 | if os.path.isfile(filename): 93 | return True 94 | try: 95 | request = urllib.request.urlopen(url, 
timeout=30) 96 | with open(filename+'_tmp', 'wb') as f: 97 | f.write(request.read()) 98 | os.rename(filename+'_tmp', filename) 99 | return True 100 | except: 101 | if os.path.isfile(filename+'_tmp'): os.remove(filename+'_tmp') 102 | return False 103 | 104 | 105 | def filter_and_process_case(js, 106 | valid_to_invalid_ratio=4, 107 | min_num_utterances=5): 108 | ''' 109 | This goes through and removes invalid utterances from a case json 110 | If the invalid utterances outnumber the valid ones by a ratio greater than 111 | valid_to_invalid_ratio, then False is returned 112 | else, the transcription is returned 113 | ''' 114 | utts = [] 115 | utts_spkr = [] 116 | rec_speakers = {} 117 | 118 | invalid_utts = 0 119 | # Iterate through each speaker turn 120 | for sec in js['transcript']['sections']: 121 | for turn in sec['turns']: 122 | if not turn['speaker']: 123 | # ignore turns where no speaker is labelled 124 | invalid_utts += 1 125 | continue 126 | 127 | speaker = turn['speaker'] 128 | speaker_id = speaker['identifier'] 129 | 130 | if not speaker_id.strip(): 131 | # ignore turns where no speaker is labelled 132 | invalid_utts += 1 133 | continue 134 | 135 | if speaker_id not in rec_speakers: 136 | rec_speakers[speaker_id] = speaker 137 | 138 | for utt in turn['text_blocks']: 139 | utt['speaker_id'] = speaker_id 140 | utt['start'] = float(utt['start']) 141 | utt['stop'] = float(utt['stop']) 142 | 143 | if utt['start'] >= utt['stop']: 144 | invalid_utts += 1 145 | continue 146 | else: 147 | utts.append(utt) 148 | 149 | transcription = {'utts': utts, 'rec_speakers': rec_speakers} 150 | 151 | if len(utts) >= min_num_utterances: 152 | if invalid_utts > 0: 153 | if len(utts)/invalid_utts >= valid_to_invalid_ratio: 154 | return transcription 155 | else: 156 | return False 157 | else: 158 | return transcription 159 | else: 160 | return False 161 | 162 | def process_case(c): 163 | ''' 164 | Processes a single case json: downloads the mp3 and writes its transcription 165 | ''' 166 | rec_name, url, js, case_date = get_mp3_and_transcript(c) 167 | if url: # If mp3 was found 168 | transcription = filter_and_process_case(js) 169 | if transcription: # If transcription is valid 170 | mp3_outfile = os.path.join(base_audio_folder, '{}.mp3'.format(rec_name)) 171 | transcript_outfile = os.path.join(base_transcript_folder, '{}.json'.format(rec_name)) 172 | 173 | transcription['case_date'] = case_date.strftime('%Y/%m/%d') 174 | download_success = download_mp3(rec_name, url, mp3_outfile) 175 | 176 | if download_success: 177 | with open(transcript_outfile, 'w', encoding='utf-8') as outfile: 178 | json.dump(transcription, outfile) 179 | 180 | 181 | def collate_speakers(): 182 | ''' 183 | Goes inside base_transcript_folder and collates all the speakers 184 | Writes to a dictionary of all speakers: 185 | 186 | speaker_id_dict = {speaker_id_key:{DICT}, etc.} 187 | 188 | This will be useful later on for parsing their full names, instead of the speaker_id_key 189 | ''' 190 | transcripts = glob(os.path.join(base_transcript_folder, '*.json')) 191 | assert len(transcripts) > 1, 'Could only find 1 or less transcription json files' 192 | all_speaker_dict = {} 193 | 194 | for t in transcripts: 195 | js = json.load(open(t, encoding='utf-8'), object_pairs_hook=OrderedDict) 196 | for s in js['rec_speakers']: 197 | if s not in all_speaker_dict: 198 | all_speaker_dict[s] = js['rec_speakers'][s] 199 | 200 | outfile = os.path.join(base_outfolder, 'speaker_ids.json') 201 | with open(outfile, 'w', encoding='utf-8') as outfile: 202 | json.dump(all_speaker_dict, outfile) 203 | 204 | 205 | def
mp3_and_transcript_exist(c): 206 | mp3_exist = os.path.isfile(os.path.join(base_audio_folder, '{}.mp3'.format(os.path.splitext(os.path.basename(c))[0]))) 207 | transcript_exist = os.path.isfile(os.path.join(base_transcript_folder, '{}.json'.format(os.path.splitext(os.path.basename(c))[0]))) 208 | if mp3_exist and transcript_exist: 209 | return True 210 | else: 211 | return False 212 | 213 | 214 | if __name__ == '__main__': 215 | args = parse_args() 216 | case_folder = args.case_folder 217 | base_outfolder = args.base_outfolder 218 | os.makedirs(base_outfolder, exist_ok=True) 219 | 220 | base_audio_folder = os.path.join(base_outfolder, 'audio') 221 | os.makedirs(base_audio_folder, exist_ok=True) 222 | 223 | base_transcript_folder = os.path.join(base_outfolder, 'transcripts') 224 | os.makedirs(base_transcript_folder, exist_ok=True) 225 | 226 | cases = glob(os.path.join(case_folder, '20*.json')) 227 | print('{} cases found'.format(len(cases))) 228 | 229 | assert len(cases) > 1, "Could only find 1 or less case json files" 230 | 231 | trimmed_cases = [c for c in cases if not mp3_and_transcript_exist(c)] 232 | 233 | print('Processing {} cases...'.format(len(trimmed_cases))) 234 | 235 | with Pool(20) as p: 236 | for _ in tqdm(p.imap_unordered(process_case, trimmed_cases), total=len(trimmed_cases)): 237 | pass 238 | 239 | collate_speakers() 240 | 241 | -------------------------------------------------------------------------------- /scotus_data_prep/step2_scrape_dob.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import os 5 | import pickle 6 | import re 7 | import ssl 8 | import string 9 | import sys 10 | import time 11 | import warnings 12 | from collections import OrderedDict 13 | 14 | import numpy as np 15 | import requests 16 | import wikipedia 17 | import wptools 18 | from bs4 import BeautifulSoup 19 | from dateutil.relativedelta import relativedelta 20 | from Levenshtein import distance as lev_dist 21 | from local_info import (AVVO_API_ACCESS_TOKEN, AVVO_CSE_ID, GOOGLE_API_KEY) 22 | from tqdm import tqdm 23 | 24 | ssl._create_default_https_context = ssl._create_unverified_context 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser(description='Reads through speaker_ids.json and tries to webscrape the DOB of each lawyer') 29 | parser.add_argument('--base-outfolder', type=str, help='Location of the base outfolder') 30 | parser.add_argument('--overwrite', action='store_true', default=False, help='Overwrite dob pickle (default: False)') 31 | parser.add_argument('--skip-attempted', action='store_true', default=False, help='Skip names which have been attempted before') 32 | args = parser.parse_args() 33 | return args 34 | 35 | class LawyerDOBParser: 36 | 37 | def __init__(self, graduation_dob_offset=25, distance_threshold=4, minimum_age=18): 38 | self.graduation_dob_offset = graduation_dob_offset 39 | self.distance_threshold = distance_threshold 40 | self.scotus_clerks_populated = False 41 | self.minimum_age = minimum_age 42 | self.minimum_dob = datetime.datetime(2005, 10, 1) - relativedelta(years=minimum_age) 43 | 44 | self.google_api_key = GOOGLE_API_KEY 45 | self.avvo_cse_id = AVVO_CSE_ID 46 | 47 | def parse_name(self, name): 48 | ''' 49 | Parse name looking at wikipedia, then SCOTUS clerks, then the JUSTIA website 50 | 51 | Input: name 52 | Output: datetime object for D.O.B 53 | ''' 54 | if not self.scotus_clerks_populated: 55 | self.get_scotus_clerks() 56 | 57 | print('Searching for
DOB of {}....'.format(name)) 58 | # Search wikipedia for person 59 | wiki_dob, wiki_info = self.parse_wiki(name) 60 | if wiki_dob: 61 | if wiki_dob <= self.minimum_dob: 62 | return wiki_dob, wiki_info 63 | 64 | # Search through supreme court clerks for person 65 | scotus_dob, scotus_info = self.search_scotus_clerks(name) 66 | if scotus_dob: 67 | scotus_dob = datetime.datetime(scotus_dob, 7, 2) 68 | if scotus_dob <= self.minimum_dob: 69 | return scotus_dob, scotus_info 70 | 71 | # Search through JUSTIA website 72 | justia_dob, justia_info = self.parse_justia(name) 73 | if justia_dob: 74 | justia_dob = datetime.datetime(justia_dob, 7, 2) 75 | if justia_dob <= self.minimum_dob: 76 | return justia_dob, justia_info 77 | 78 | # Search through AVVO website 79 | avvo_dob, avvo_info = self.parse_avvo(name) 80 | if avvo_dob: 81 | time.sleep(1) 82 | avvo_dob = datetime.datetime(avvo_dob, 7, 2) 83 | if avvo_dob <= self.minimum_dob: 84 | return avvo_dob, avvo_info 85 | 86 | print("Couldn't find any age for {}".format(name)) 87 | info_list = [wiki_info, scotus_info, justia_info, avvo_info] 88 | collated_info = {'info': {'type': None, 'error': 'no info found', 'collated_info': info_list}} 89 | return None, collated_info 90 | 91 | def parse_wiki(self, name): 92 | search = wikipedia.search(name) 93 | if search: 94 | if self.name_distance(name, search[0]) <= self.distance_threshold: 95 | name = search[0] 96 | 97 | wpage = wptools.page(name, silent=True) 98 | info = {'info': {'type': 'wiki', 'error': None, 'name': name}} 99 | try: 100 | page = wpage.get_parse() 101 | except: 102 | info['info']['error'] = 'page not found' 103 | return None, info 104 | try: 105 | if page.data: 106 | if 'infobox' in page.data: 107 | if 'birth_date' in page.data['infobox']: 108 | dob = page.data['infobox']['birth_date'].strip('{}').split('|') 109 | dinfo = [] 110 | for d in dob: 111 | try: 112 | dinfo.append(int(d)) 113 | except: 114 | continue 115 | if dinfo: 116 | if len(dinfo) > 3: 117 | dinfo = dinfo[-3:] 118 | if dinfo[0] > 1900: # simple check if 4-digit year recognised 119 | prelim_date = [1, 1, 1] 120 | for i, d in enumerate(dinfo): 121 | prelim_date[i] = d 122 | dob = datetime.datetime(*prelim_date) 123 | info['info']['links'] = page.data['iwlinks'] 124 | return dob, info 125 | info['info']['error'] = 'page couldnt be parsed' 126 | return None, info 127 | except: 128 | info['info']['error'] = 'page couldnt be parsed' 129 | return None, info 130 | 131 | def parse_justia(self, name): 132 | searched_name, distance, justia_url = self.search_justia(name) 133 | info = {'info': {'type': 'justia', 'searched_name': searched_name, 'justia_url': justia_url, 'error': None}} 134 | if distance <= self.distance_threshold: 135 | grad_year = self.parse_justia_lawyer(justia_url) 136 | if grad_year: 137 | return grad_year - self.graduation_dob_offset, info 138 | else: 139 | info['info']['error'] = 'no year found' 140 | return None, info 141 | else: 142 | info['info']['error'] = 'distance threshold not met' 143 | return None, info 144 | 145 | def search_justia(self, name): 146 | """ 147 | Input: Name to search, i.e. Anthony A. 
Yang (str,) 148 | Output: Matched name, Levenshtein distance to input, JUSTIA url 149 | """ 150 | base_search_url = 'https://lawyers.justia.com/search?profile-id-field=&practice-id-field=&query={}&location=' 151 | base_name = name.translate(str.maketrans('', '', string.punctuation)).lower() 152 | name_query = '+'.join(base_name.split()) 153 | search_url = base_search_url.format(name_query) 154 | 155 | search_url = base_search_url.format(name_query) 156 | search_request = requests.get(search_url) 157 | soup = BeautifulSoup(search_request.content, 'lxml') 158 | lawyer_avatars = soup.findAll('a', attrs={'class': 'lawyer-avatar'}) 159 | 160 | if lawyer_avatars: 161 | search_names = [] 162 | search_urls = [] 163 | 164 | for a in lawyer_avatars: 165 | search_names.append(a['title']) 166 | search_urls.append(a['href']) 167 | 168 | search_names = np.array(search_names) 169 | search_names_base = [n.translate(str.maketrans('', '', string.punctuation)).lower() for n in search_names] 170 | 171 | distances = np.array([self.name_distance(name, n) for n in search_names]) 172 | search_urls = np.array(search_urls) 173 | 174 | dist_order = np.argsort(distances) 175 | distances = distances[dist_order] 176 | search_urls = search_urls[dist_order] 177 | search_names = search_names[dist_order] 178 | 179 | return search_names[0], distances[0], search_urls[0] 180 | else: 181 | return 'None', 100000, 'None' 182 | 183 | @staticmethod 184 | def parse_justia_lawyer(lawyer_url): 185 | """ 186 | Input: Justia lawyer page url 187 | Output: Graduation year 188 | """ 189 | r = requests.get(lawyer_url) 190 | soup = BeautifulSoup(r.content, 'lxml') 191 | 192 | jurisdictions = soup.find('div', attrs={'id': 'jurisdictions-block'}) 193 | if jurisdictions: 194 | jd_admitted_year = [] 195 | for j in jurisdictions: 196 | try: 197 | jd_admitted_year.append(int(j.find('time')['datetime'])) 198 | except: 199 | continue 200 | if jd_admitted_year: 201 | return min(jd_admitted_year) 202 | else: 203 | # look for professional associations if jurisdictions is emtpy 204 | prof_assoc = None 205 | education = None 206 | blocks = soup.findAll('div', attrs={'class': 'block'}) 207 | for block in blocks: 208 | subdivs = block.findAll('div') 209 | for subdiv in subdivs: 210 | if subdiv.text == 'Professional Associations': 211 | prof_assoc = block 212 | break 213 | if subdiv.text == 'Education': 214 | education = block 215 | break 216 | 217 | if prof_assoc: 218 | prof_assoc_year = [] 219 | professional_associations = prof_assoc.findAll('time') 220 | for p in professional_associations: 221 | try: 222 | prof_assoc_year.append(int(p['datetime'])) 223 | except: 224 | continue 225 | if prof_assoc_year: 226 | return min(prof_assoc_year) 227 | 228 | if education: 229 | education_years = [] 230 | education_history = education.findAll('dl') 231 | for e in education_history: 232 | degree_type = e.find('dd').text 233 | if degree_type.strip().translate(str.maketrans('', '', string.punctuation)).lower() == 'jd': 234 | try: 235 | return int(e.find('time')['datetime']) 236 | except: 237 | continue 238 | 239 | def search_scotus_clerks(self, query_name): 240 | assert self.clerk_dob_dict, 'get_scotus_clerks must be called before this function' 241 | # query_name = query_name.translate(str.maketrans('', '', string.punctuation)).lower() 242 | distances = np.array([self.name_distance(query_name, k) for k in self.scotus_clerks]) 243 | closest_match = np.argmin(distances) 244 | info = {'info': {'type': 'clerk', 'closest_match': closest_match, 'error': None}} 245 | if 
distances[closest_match] <= self.distance_threshold: 246 | return self.clerk_dob_dict[self.scotus_clerks[closest_match]], info 247 | else: 248 | info['info']['error'] = 'distance threshold not met' 249 | return None, info 250 | 251 | def get_scotus_clerks(self): 252 | """ 253 | Populates self.clerk_dob_dict with dates of birth for SCOTUS clerks 254 | """ 255 | base_url = 'https://en.wikipedia.org/wiki/List_of_law_clerks_of_the_Supreme_Court_of_the_United_States_({})' 256 | seats = ['Chief_Justice', 'Seat_1', 'Seat_2', 257 | 'Seat_3', 'Seat_4', 'Seat_6', 'Seat_8', 258 | 'Seat_9', 'Seat_10'] 259 | urls = [base_url.format(s) for s in seats] 260 | 261 | self.all_cdicts = [] 262 | self.clerk_dob_dict = OrderedDict({}) 263 | 264 | for url in urls: 265 | mini_clerk_dict = self.parse_clerk_wiki(url) 266 | self.all_cdicts.append(mini_clerk_dict) 267 | 268 | for cdict in self.all_cdicts: 269 | self.clerk_dob_dict = {**self.clerk_dob_dict, **cdict} 270 | 271 | self.scotus_clerks = np.array(list(self.clerk_dob_dict.keys())) 272 | self.scotus_clerks_populated = True 273 | 274 | def parse_clerk_wiki(self, url): 275 | r = requests.get(url) 276 | soup = BeautifulSoup(r.content, 'lxml') 277 | tables = soup.findAll('table', attrs={'class': 'wikitable'}) 278 | clerk_dict = {} 279 | for table in tables: 280 | for tr in table.findAll('tr'): 281 | row_entries = tr.findAll('td') 282 | if len(row_entries) != 5: 283 | continue 284 | else: 285 | name = row_entries[0].text 286 | u = row_entries[3].text 287 | year_candidates = re.findall(r'\d{4}', u) 288 | 289 | if year_candidates: 290 | year = int(year_candidates[0]) 291 | else: 292 | continue 293 | 294 | cleaned_name = re.sub(r'\([^)]*\)', '', name) 295 | cleaned_name = re.sub(r'\[[^)]*\]', '', cleaned_name).strip() 296 | clerk_dict[cleaned_name] = year - self.graduation_dob_offset 297 | 298 | return clerk_dict 299 | 300 | def parse_avvo(self, name): 301 | avvo_ids, response = self.get_avvo_ids_google(name) 302 | info = {'info': {'type': 'avvo', 'google_response': response, 'error': None}} 303 | if avvo_ids: 304 | for aid in avvo_ids: 305 | avvo_resp = self.get_from_avvo_api(aid) 306 | dob_estimate = self.parse_avvo_api_response(avvo_resp, name) 307 | if dob_estimate: 308 | info['info']['avvo_id'] = aid 309 | return dob_estimate, info 310 | else: 311 | continue 312 | info['info']['error'] = 'no avvo ids yielded a dob estimate' 313 | info['info']['avvo_ids_attempted'] = avvo_ids 314 | return None, info 315 | else: 316 | info['info']['error'] = 'no links found.. 
check response' 317 | return None, info 318 | 319 | def parse_avvo_api_response(self, r, name): 320 | # r: response as dict 321 | if r['lawyers']: 322 | lawyer = r['lawyers'][0] 323 | licensed_since = lawyer['licensed_since'] 324 | if licensed_since: 325 | lawyer_name = '{} {} {}'.format(lawyer['firstname'], lawyer['middlename'], lawyer['lastname']) 326 | if self.name_distance(name, lawyer_name) <= self.distance_threshold: 327 | return licensed_since - self.graduation_dob_offset 328 | 329 | def search_google_avvo(self, query): 330 | r = requests.get('https://www.googleapis.com/customsearch/v1/siterestrict?key={}&cx={}&num=3&q="{}"'.format( 331 | self.google_api_key, self.avvo_cse_id, query)) 332 | r = json.loads(r.content) 333 | if 'items' in r: 334 | links = [l['link'] for l in r['items']] 335 | return links, r 336 | else: 337 | return None, r 338 | 339 | def get_avvo_ids_google(self, name): 340 | links, response = self.search_google_avvo(name) 341 | if links: 342 | avvo_ids = [self.get_avvo_id_from_link(l) for l in links] 343 | avvo_ids = [i for i in avvo_ids if i] 344 | return avvo_ids, response 345 | else: 346 | return None, response 347 | 348 | @staticmethod 349 | def get_avvo_id_from_link(link): 350 | if link.startswith('https://www.avvo.com/attorneys/'): 351 | page_path = os.path.splitext(link.split('/')[-1])[0] 352 | avvo_id = page_path.split('-')[-1] 353 | if avvo_id.isnumeric(): 354 | return avvo_id 355 | 356 | @staticmethod 357 | def get_from_avvo_api(avvo_id): 358 | headers = {'Authorization': 'Bearer {}'.format(AVVO_API_ACCESS_TOKEN)} 359 | url = 'https://api.avvo.com/api/4/lawyers/{}.json'.format(avvo_id) 360 | r = requests.get(url, headers=headers) 361 | return json.loads(r.content) 362 | 363 | @classmethod 364 | def name_distance(cls, string1, string2, wrong_initial_penalty=5): 365 | ''' 366 | levenshtein distance accommodating for initials 367 | First and last initials must match 368 | TODO: allow for hyphenated names with last name partial match 369 | ''' 370 | name1 = string1.lower().translate(str.maketrans('', '', string.punctuation)).lower() 371 | name2 = string2.lower().translate(str.maketrans('', '', string.punctuation)).lower() 372 | base_dist = lev_dist(name1, name2) 373 | 374 | if base_dist == 0: 375 | return 0 376 | 377 | if '-' in string1.split()[-1] or '-' in string2.split()[-1]: 378 | s1_perms = cls.hyphenation_perm(string1) 379 | s2_perms = cls.hyphenation_perm(string2) 380 | dists = [] 381 | for s1 in s1_perms: 382 | for s2 in s2_perms: 383 | dists.append(cls.name_distance(s1, s2)) 384 | return min(dists) 385 | 386 | name1_split = name1.split() 387 | name2_split = name2.split() 388 | 389 | if name1_split[0] == name2_split[0] and name1_split[-1] == name2_split[-1]: 390 | if len(name1_split) == 2 or len(name2_split) == 2: 391 | return lev_dist(' '.join([name1_split[0], name1_split[-1]]), 392 | ' '.join([name2_split[0], name2_split[-1]])) 393 | 394 | newname1 = ' '.join([n[0] if (1 <= i < len(name1_split) - 1) else n for i, n in enumerate(name1_split)]) 395 | newname2 = ' '.join([n[0] if (1 <= i < len(name2_split) - 1) else n for i, n in enumerate(name2_split)]) 396 | return lev_dist(newname1, newname2) 397 | else: 398 | return base_dist + wrong_initial_penalty 399 | 400 | @staticmethod 401 | def hyphenation_perm(name): 402 | splitup = name.split() 403 | lastname = splitup[-1] 404 | if '-' in lastname: 405 | lname_candidates = [' '.join(splitup[:-1] + [l]) for l in lastname.split('-')] 406 | return lname_candidates 407 | else: 408 | return [name] 409 | 410 | 
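# --------------------------------------------------------------------------
# Illustrative sketch (added commentary, not part of the original script):
# the key idea in name_distance() above is that middle names are collapsed
# to initials before the Levenshtein comparison, so e.g. "Anthony A Yang"
# and "Anthony Albert Yang" count as an exact match. A minimal, standalone
# re-implementation of just that step, assuming the python-Levenshtein
# package; the example names and helper name are made up:

from Levenshtein import distance as lev_dist


def initials_aware_distance(name_a, name_b, wrong_initial_penalty=5):
    """Levenshtein distance with middle names reduced to their initials."""
    a, b = name_a.lower().split(), name_b.lower().split()
    if a[0] == b[0] and a[-1] == b[-1]:
        # First and last names agree: shrink any middle names to initials
        def shrink(parts):
            return ' '.join(p[0] if 0 < i < len(parts) - 1 else p
                            for i, p in enumerate(parts))
        return lev_dist(shrink(a), shrink(b))
    # First or last name disagrees: penalise on top of the raw distance
    return lev_dist(' '.join(a), ' '.join(b)) + wrong_initial_penalty

# initials_aware_distance('Anthony A Yang', 'Anthony Albert Yang') -> 0
# (the full method above additionally strips punctuation and handles
# hyphenated surnames via hyphenation_perm)
# --------------------------------------------------------------------------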
411 | if __name__ == "__main__": 412 | assert AVVO_API_ACCESS_TOKEN, "AVVO_API_ACCESS_TOKEN in local_info.py needs to be filled out" 413 | assert AVVO_CSE_ID, "AVVO_CSE_ID in local_info.py needs to be filled out" 414 | assert GOOGLE_API_KEY, "GOOGLE_API_KEY in local_info.py needs to be filled out" 415 | 416 | args = parse_args() 417 | base_outfolder = args.base_outfolder 418 | assert os.path.isdir(base_outfolder) 419 | 420 | pickle_path = os.path.join(base_outfolder, 'dob.p') 421 | info_pickle_path = os.path.join(base_outfolder, 'dob_info.p') 422 | 423 | speaker_id_path = os.path.join(base_outfolder, 'speaker_ids.json') 424 | assert os.path.isfile(speaker_id_path), "Can't find speaker_ids.json" 425 | 426 | speaker_ids = json.load(open(speaker_id_path, encoding='utf-8'), 427 | object_pairs_hook=OrderedDict) 428 | 429 | parser = LawyerDOBParser() 430 | parser.get_scotus_clerks() 431 | 432 | if args.overwrite or not os.path.isfile(pickle_path): 433 | dobs = OrderedDict({}) 434 | infos = OrderedDict({}) 435 | speakers_to_scrape = sorted(speaker_ids.keys()) 436 | else: 437 | infos = pickle.load(open(info_pickle_path, 'rb')) 438 | dobs = pickle.load(open(pickle_path, 'rb')) 439 | if args.skip_attempted: 440 | speakers_to_scrape = set(speaker_ids.keys()) - set(dobs.keys()) 441 | else: 442 | speakers_to_scrape = set(speaker_ids.keys()) - set([s for s in dobs if dobs[s]]) 443 | 444 | if speakers_to_scrape: 445 | speakers_to_scrape = sorted(list(speakers_to_scrape)) 446 | 447 | for s in tqdm(speakers_to_scrape): 448 | query_name = speaker_ids[s]['name'] 449 | parsed_dob, info = parser.parse_name(query_name) 450 | dobs[s] = parsed_dob 451 | infos[s] = info 452 | pickle.dump(dobs, open(pickle_path, 'wb')) 453 | pickle.dump(infos, open(info_pickle_path, 'wb')) 454 | 455 | num_dob_speakers = sum([1 for s in dobs if dobs[s]]) 456 | 457 | print('Found DoB for {} out of {} speakers'.format(num_dob_speakers, len(speaker_ids))) 458 | print('Done!') 459 | 460 | -------------------------------------------------------------------------------- /scotus_data_prep/step3_prepdata.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import os 5 | import pickle 6 | import sys 7 | import shutil 8 | from collections import OrderedDict 9 | from copy import copy 10 | from glob import glob 11 | 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description='Prep the data for verification, diarization and feature extraction') 18 | parser.add_argument('--base-outfolder', type=str, help='Location of the base outfolder') 19 | parser.add_argument('--subsegment-length', type=float, default=1.5, help='Length of diarization subsegments (default: 1.5s)') 20 | parser.add_argument('--subsegment-shift', type=float, default=0.75, help='Subsegment shift duration (default: 0.75s)') 21 | args = parser.parse_args() 22 | return args 23 | 24 | def write_json(outfile, outdict): 25 | with open(outfile, 'w', encoding='utf-8') as wp: 26 | json.dump(outdict, wp) 27 | 28 | def assign_newrecnames(caselist): 29 | """ 30 | Converts recording names to a Kaldi friendly standard format of YEAR-XYZ 31 | 32 | input: recording json list 33 | output: (original to newrecname mapping, newrecname to original mapping) 34 | """ 35 | recid_orec_mapping = OrderedDict({}) 36 | 37 | for r in caselist: 38 | rec_name = os.path.splitext(os.path.basename(r))[0] 39 | year = rec_name[:4] 40 | assert len(year) == 
4, 'Year length was not 4, something is wrong' 41 | index = 0 42 | new_recid = "{0}-{1:0=3d}".format(str(year), index) 43 | while True: 44 | if new_recid in recid_orec_mapping: 45 | index += 1 46 | new_recid = "{0}-{1:0=3d}".format(str(year), index) 47 | else: 48 | recid_orec_mapping[new_recid] = rec_name 49 | break 50 | 51 | orec_recid_mapping = OrderedDict({recid_orec_mapping[k]: k for k in recid_orec_mapping}) 52 | return orec_recid_mapping, recid_orec_mapping 53 | 54 | 55 | def make_wavscp(base_folder, caselist, orec_recid_mapping): 56 | """ 57 | Make wav.scp with new recnames 58 | 59 | inputs: (outfile, list of case jsons, original to new recording mapping) 60 | outputs: None [file written to outfile] 61 | """ 62 | wavlines = [] 63 | wavscp = os.path.join(base_folder, 'wav.scp') 64 | for r in caselist: 65 | orec_name = os.path.splitext(os.path.basename(r))[0] 66 | newrec_id = orec_recid_mapping[orec_name] 67 | mp3file = os.path.join(os.path.abspath(base_folder), 'audio/{}.mp3'.format(orec_name)) 68 | assert os.path.isfile(mp3file), "Couldn't find {}".format(mp3file) 69 | wavline = '{} ffmpeg -v 8 -i {} -f wav -ar 16000 -acodec pcm_s16le -|\n'.format(newrec_id, mp3file) 70 | wavlines.append(wavline) 71 | 72 | with open(wavscp, 'w+') as wp: 73 | for line in wavlines: 74 | wp.write(line) 75 | 76 | 77 | def process_utts_dob(utts, dobs): 78 | ''' 79 | Removes utterances from speakers with unknown dob 80 | ''' 81 | new_utts = [] 82 | for u in utts: 83 | if u['speaker_id'] in dobs: 84 | if dobs[u['speaker_id']]: 85 | new_utts.append(u) 86 | assert new_utts, "No more utts remain: this should not occur, check DoB pickle" 87 | return new_utts 88 | 89 | 90 | def join_up_utterances(utt_list, 91 | cutoff_dur=10000, 92 | soft_min_length=0.0): 93 | new_utt_list = [utt_list[0].copy()] 94 | for i, utt in enumerate(utt_list): 95 | if i == 0: 96 | continue 97 | if utt['start'] == new_utt_list[-1]['stop'] and utt['speaker_id'] == new_utt_list[-1]['speaker_id']: 98 | if new_utt_list[-1]['stop'] - new_utt_list[-1]['start'] < cutoff_dur or utt_list[i]['stop'] - utt_list[i][ 99 | 'start'] < soft_min_length: 100 | new_utt_list[-1]['stop'] = utt['stop'] 101 | new_utt_list[-1]['text'] += ' {}'.format(utt['text']) 102 | 103 | else: 104 | new_utt_list.append(utt_list[i].copy()) 105 | return new_utt_list 106 | 107 | 108 | def split_up_single_utterance(utt, 109 | target_utt_length=10.0, 110 | min_utt_length=4.0): 111 | duration = utt['stop'] - utt['start'] 112 | if duration < min_utt_length + target_utt_length: 113 | return [utt] 114 | else: 115 | remaining_duration = copy(duration) 116 | new_utt_list = [] 117 | iterations = 0 118 | while remaining_duration >= min_utt_length + target_utt_length: 119 | new_utt = OrderedDict({}) 120 | new_utt['start'] = utt['start'] + (iterations * target_utt_length) 121 | new_utt['stop'] = new_utt['start'] + target_utt_length 122 | new_utt['text'] = utt['text'] 123 | new_utt['speaker_id'] = utt['speaker_id'] 124 | 125 | new_utt_list.append(new_utt) 126 | 127 | remaining_duration -= target_utt_length 128 | iterations += 1 129 | 130 | new_utt = OrderedDict({}) 131 | new_utt['start'] = new_utt_list[-1]['stop'] 132 | new_utt['stop'] = utt['stop'] 133 | new_utt['text'] = utt['text'] 134 | new_utt['speaker_id'] = utt['speaker_id'] 135 | new_utt_list.append(new_utt) 136 | 137 | return new_utt_list 138 | 139 | 140 | def split_up_long_utterances(utt_list, 141 | target_utt_length=10.0, 142 | min_utt_length=4.0): 143 | new_utt_list = [] 144 | for utt in utt_list: 145 | splitup = 
split_up_single_utterance(utt, 146 | target_utt_length=target_utt_length, 147 | min_utt_length=min_utt_length) 148 | new_utt_list.extend(splitup) 149 | return new_utt_list 150 | 151 | 152 | def make_segments(utts, recname, min_utt_len=2.0): 153 | """ 154 | Make kaldi friendly segments of the format recname-VWXYZ 155 | 156 | Input: uttlist imported from transcription json, recname 157 | 158 | """ 159 | seglines = [] 160 | speakers = [] 161 | utt_ids = [] 162 | 163 | i = 0 164 | for utt in utts: 165 | utt_id = '{0}-{1:0=5d}'.format(recname, i) 166 | start = float(utt['start']) 167 | stop = float(utt['stop']) 168 | 169 | if float(start) + min_utt_len > float(stop): 170 | # Discard too short snippets 171 | continue 172 | 173 | speaker_id = utt['speaker_id'] 174 | if not speaker_id.strip(): 175 | # Discard empty speaker_ids 176 | continue 177 | 178 | line = '{} {} {} {}\n'.format(utt_id, recname, start, stop) 179 | 180 | seglines.append(line) 181 | speakers.append(speaker_id) 182 | utt_ids.append(utt_id) 183 | i += 1 184 | 185 | return seglines, utt_ids, speakers 186 | 187 | 188 | def prep_utts(utts): 189 | # Converts start and stop to float 190 | for u in utts: 191 | u['start'] = float(u['start']) 192 | u['stop'] = float(u['stop']) 193 | return utts 194 | 195 | 196 | def calculate_speaker_ages(speakers, dobs, rec_date): 197 | """ 198 | Calculate the age in days of a list of speakers based on recording date and DoB 199 | 200 | inputs: (list of speakers, DoB dictionary of datetimes, recording datetime) 201 | output: list of speaker ages 202 | """ 203 | set_speakers = set(speakers) 204 | age_dict = {} 205 | for s in set_speakers: 206 | dob = dobs[s] 207 | delta_days = abs((rec_date - dob).days) 208 | age_dict[s] = delta_days 209 | speaker_ages = [age_dict[s] for s in speakers] 210 | return speaker_ages 211 | 212 | 213 | def make_verification_dataset(base_outfolder, caselist, orec_recid_mapping, dobs): 214 | """ 215 | Makes a verification/training dataset: Long utterances split up 216 | """ 217 | veri_data_path = os.path.join(base_outfolder, 'veri_data') 218 | os.makedirs(veri_data_path, exist_ok=True) 219 | wavscp_path = os.path.join(args.base_outfolder, 'wav.scp') 220 | shutil.copy(wavscp_path, veri_data_path) 221 | 222 | all_seglines = [] 223 | all_uttids = [] 224 | all_speakers = [] 225 | all_recs = [] 226 | all_ages = [] 227 | 228 | for case in tqdm(caselist): 229 | rec_name = os.path.splitext(os.path.basename(case))[0] 230 | new_recid = orec_recid_mapping[rec_name] 231 | js = json.load(open(case, encoding='utf-8'), object_pairs_hook=OrderedDict) 232 | 233 | utts = js['utts'] 234 | utts = process_utts_dob(utts, dobs) 235 | utts = prep_utts(utts) 236 | utts = join_up_utterances(utts) 237 | utts = split_up_long_utterances(utts, target_utt_length=10.0, min_utt_length=4.0) 238 | 239 | seglines, utt_ids, speakers = make_segments(utts, new_recid) 240 | 241 | case_date = datetime.datetime.strptime(js['case_date'], '%Y/%m/%d') 242 | utt_ages = calculate_speaker_ages(speakers, dobs, case_date) 243 | 244 | all_recs.extend([new_recid for _ in utt_ids]) 245 | all_seglines.extend(seglines) 246 | all_uttids.extend(utt_ids) 247 | all_speakers.extend(speakers) 248 | all_ages.extend(utt_ages) 249 | 250 | with open(os.path.join(veri_data_path, 'segments'), 'w+') as fp: 251 | for line in all_seglines: 252 | fp.write(line) 253 | 254 | with open(os.path.join(veri_data_path, 'real_utt2spk'), 'w+') as fp: 255 | for u, s in zip(all_uttids, all_speakers): 256 | line = '{} {}\n'.format(u, s) 257 | fp.write(line) 258 | 
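    # Added commentary (not in the original source): the files written in this
    # function follow the plain-text Kaldi data-directory conventions, one
    # space-separated entry per line. Illustrative rows (made-up ids/values):
    #   segments       2005-000-00001 2005-000 13.42 23.42
    #   real_utt2spk   2005-000-00001 <oyez_speaker_identifier>
    #   utt2age        2005-000-00001 18215        (speaker age in days on the case date)
    #   utt2spk        2005-000-00001 2005-000
    # Note that utt2spk maps each utterance to its recording id rather than to
    # the true speaker (kept separately in real_utt2spk), presumably so the
    # Kaldi scripts in step4 group utterances per recording.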
259 | with open(os.path.join(veri_data_path, 'utt2age'), 'w+') as fp: 260 | for u, a in zip(all_uttids, all_ages): 261 | line = '{} {}\n'.format(u, a) 262 | fp.write(line) 263 | 264 | with open(os.path.join(veri_data_path, 'utt2spk'), 'w+') as fp: 265 | for u, r in zip(all_uttids, all_recs): 266 | line = '{} {}\n'.format(u, r) 267 | fp.write(line) 268 | 269 | utt2spk_to_spk2utt(os.path.join(veri_data_path, 'utt2spk'), 270 | outfile=os.path.join(veri_data_path, 'spk2utt')) 271 | 272 | utt2spk_to_spk2utt(os.path.join(veri_data_path, 'real_utt2spk'), 273 | outfile=os.path.join(veri_data_path, 'real_spk2utt')) 274 | 275 | 276 | 277 | def split_up_single_utterance_subsegments(utt, 278 | target_utt_length=1.5, 279 | min_utt_length=1.4, 280 | shift=0.75): 281 | """ 282 | Split up a single utterance into subsegments based on input variables 283 | """ 284 | duration = utt['stop'] - utt['start'] 285 | if duration < min_utt_length + target_utt_length - shift: 286 | return [utt] 287 | else: 288 | new_utt_list = [] 289 | current_start = copy(utt['start']) 290 | while current_start <= utt['stop'] - min_utt_length: 291 | new_utt = OrderedDict({}) 292 | new_utt['start'] = current_start 293 | new_utt['stop'] = new_utt['start'] + target_utt_length 294 | new_utt['byte_start'] = utt['byte_start'] # todo fix 295 | new_utt['byte_stop'] = utt['byte_stop'] # todo fix 296 | new_utt['text'] = 'n/a' 297 | new_utt['speaker_id'] = utt['speaker_id'] 298 | 299 | new_utt_list.append(new_utt) 300 | 301 | current_start += shift 302 | 303 | new_utt = OrderedDict({}) 304 | new_utt['start'] = current_start 305 | new_utt['stop'] = utt['stop'] 306 | new_utt['byte_start'] = utt['byte_start'] 307 | new_utt['byte_stop'] = utt['byte_stop'] 308 | new_utt['text'] = 'n/a' 309 | new_utt['speaker_id'] = utt['speaker_id'] 310 | new_utt_list.append(new_utt) 311 | 312 | return new_utt_list 313 | 314 | 315 | def split_up_uttlist_subsegments(utt_list, 316 | target_utt_length=1.5, 317 | min_utt_length=1.4, 318 | shift=0.75): 319 | """ 320 | Split up a list of utterances into subsegments based in input variables 321 | """ 322 | new_utt_list = [] 323 | for utt in utt_list: 324 | splitup = split_up_single_utterance_subsegments(utt, 325 | target_utt_length=target_utt_length, 326 | min_utt_length=min_utt_length, 327 | shift=shift) 328 | new_utt_list.extend(splitup) 329 | return new_utt_list 330 | 331 | def rttm_lines_from_uttlist(utts, rec_id): 332 | utts = join_up_utterances(utts) 333 | rttm_lines = [] 334 | 335 | rttm_line = 'SPEAKER {} 0 {:.2f} {:.2f} {} \n' 336 | 337 | for u in utts: 338 | spk = u['speaker_id'] 339 | start = float(u['start']) 340 | stop = float(u['stop']) 341 | offset = stop - start 342 | if offset < 0.0: 343 | continue 344 | rttm_lines.append(rttm_line.format(rec_id, start, offset, spk)) 345 | 346 | return rttm_lines 347 | 348 | 349 | 350 | def make_diarization_dataset(base_outfolder, caselist, orec_recid_mapping, dobs, 351 | target_utt_length=1.5, min_utt_length=1.4, shift=0.75): 352 | """ 353 | Makes a diarization dataset, utterances split into overlapping segments 354 | Makes rttms 355 | """ 356 | diar_data_path = os.path.join(base_outfolder, 'diar_data') 357 | os.makedirs(diar_data_path, exist_ok=True) 358 | wavscp_path = os.path.join(args.base_outfolder, 'wav.scp') 359 | shutil.copy(wavscp_path, diar_data_path) 360 | 361 | all_seglines = [] 362 | all_uttids = [] 363 | all_speakers = [] 364 | all_recs = [] 365 | 366 | all_rttm_lines = [] 367 | 368 | for case in tqdm(caselist): 369 | rec_name = 
os.path.splitext(os.path.basename(case))[0] 370 | new_recid = orec_recid_mapping[rec_name] 371 | js = json.load(open(case, encoding='utf-8'), object_pairs_hook=OrderedDict) 372 | 373 | utts = js['utts'] 374 | utts = prep_utts(utts) 375 | utts = join_up_utterances(utts, cutoff_dur=1e20) 376 | 377 | rec_rttm_lines = rttm_lines_from_uttlist(utts, new_recid) 378 | all_rttm_lines.extend(rec_rttm_lines) 379 | 380 | utts = split_up_uttlist_subsegments(utts, 381 | target_utt_length=1.5, 382 | min_utt_length=1.4, 383 | shift=0.75) 384 | 385 | seglines, utt_ids, speakers = make_segments(utts, new_recid, min_utt_len=min_utt_length) 386 | 387 | all_recs.extend([new_recid for _ in utt_ids]) 388 | all_seglines.extend(seglines) 389 | all_uttids.extend(utt_ids) 390 | all_speakers.extend(speakers) 391 | 392 | with open(os.path.join(diar_data_path, 'segments'), 'w+') as fp: 393 | for line in all_seglines: 394 | fp.write(line) 395 | 396 | with open(os.path.join(diar_data_path, 'ref.rttm'), 'w+') as fp: 397 | for line in all_rttm_lines: 398 | fp.write(line) 399 | 400 | with open(os.path.join(diar_data_path, 'real_utt2spk'), 'w+') as fp: 401 | for u, s in zip(all_uttids, all_speakers): 402 | line = '{} {}\n'.format(u, s) 403 | fp.write(line) 404 | 405 | with open(os.path.join(diar_data_path, 'utt2spk'), 'w+') as fp: 406 | for u, r in zip(all_uttids, all_recs): 407 | line = '{} {}\n'.format(u, r) 408 | fp.write(line) 409 | 410 | utt2spk_to_spk2utt(os.path.join(diar_data_path, 'utt2spk'), 411 | outfile=os.path.join(diar_data_path, 'spk2utt')) 412 | 413 | utt2spk_to_spk2utt(os.path.join(diar_data_path, 'real_utt2spk'), 414 | outfile=os.path.join(diar_data_path, 'real_spk2utt')) 415 | 416 | 417 | def utt2spk_to_spk2utt(utt2spk_path, outfile=None): 418 | utts = [] 419 | spks = [] 420 | with open(utt2spk_path) as fp: 421 | for line in fp: 422 | splitup = line.strip().split(' ') 423 | assert len(splitup) == 2, 'Got more or less columns that was expected: (Got {}, expected 2)'.format( 424 | len(splitup)) 425 | utts.append(splitup[0]) 426 | spks.append(splitup[1]) 427 | set_spks = sorted(list(set(spks))) 428 | spk2utt_dict = OrderedDict({k: [] for k in set_spks}) 429 | for u, s in zip(utts, spks): 430 | spk2utt_dict[s].append(u) 431 | 432 | if outfile: 433 | with open(outfile, 'w+') as fp: 434 | for spk in spk2utt_dict: 435 | line = '{} {}\n'.format(spk, ' '.join(spk2utt_dict[spk])) 436 | fp.write(line) 437 | return spk2utt_dict 438 | 439 | 440 | 441 | if __name__ == "__main__": 442 | args = parse_args() 443 | base_outfolder = args.base_outfolder 444 | assert os.path.isdir(base_outfolder), 'Outfolder does not exist' 445 | wavscp_path = os.path.join(args.base_outfolder, 'wav.scp') 446 | assert os.path.isfile(wavscp_path), "Can't find {}".format(wavscp_path) 447 | 448 | dobs_pkl_path = os.path.join(args.base_outfolder, 'dob.p') 449 | assert os.path.isfile(dobs_pkl_path), "Couldn't find {}".format(dobs_pkl_path) 450 | dobs = pickle.load(open(dobs_pkl_path, 'rb')) 451 | 452 | audio_folder = os.path.join(base_outfolder, 'audio') 453 | transcript_folder = os.path.join(base_outfolder, 'transcripts') 454 | 455 | caselist = glob(os.path.join(transcript_folder, '*.json')) 456 | caselist = sorted(caselist) 457 | 458 | print('Assigning new recording names...') 459 | orec_recid_mapping, recid_orec_mapping = assign_newrecnames(caselist) 460 | 461 | write_json(os.path.join(base_outfolder, 'orec_recid_mapping.json'), orec_recid_mapping) 462 | write_json(os.path.join(base_outfolder, 'recid_orec_mapping.json'), recid_orec_mapping) 463 | 
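    # Added commentary (not in the original source): make_wavscp() writes Kaldi
    # "piped" wav.scp entries that decode each mp3 to 16 kHz, 16-bit PCM on the
    # fly, one line per recording. An illustrative row (made-up path):
    #   2005-000 ffmpeg -v 8 -i /path/to/base_outfolder/audio/2005.10.31_04-1234.mp3 -f wav -ar 16000 -acodec pcm_s16le -|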
464 | print('Making wav.scp in {}'.format(os.path.join(base_outfolder, 'wav.scp'))) 465 | make_wavscp(base_outfolder, caselist, orec_recid_mapping) 466 | 467 | print('Making base verification/training dataset...') 468 | make_verification_dataset(base_outfolder, caselist, orec_recid_mapping, dobs) 469 | 470 | print('Making diarization dataset...') 471 | make_diarization_dataset(base_outfolder, caselist, orec_recid_mapping, dobs, 472 | target_utt_length=args.subsegment_length, min_utt_length=args.subsegment_length-0.1, 473 | shift=args.subsegment_shift) 474 | 475 | -------------------------------------------------------------------------------- /scotus_data_prep/step4_extract_feats.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | step=0 4 | nj=20 5 | base_outfolder=/PATH/TO/BASE_OUTFOLDER 6 | scotus_veri_dir=$base_outfolder/veri_data 7 | scotus_diar_dir=$base_outfolder/diar_data 8 | 9 | if [ $step -le 0 ]; then 10 | #make the feats 11 | utils/fix_data_dir.sh $scotus_veri_dir 12 | steps/make_mfcc.sh --write-utt2num-frames true --mfcc-config conf/mfcc.conf --nj $nj \ 13 | --cmd run.pl $scotus_veri_dir 14 | 15 | utils/fix_data_dir.sh $scotus_veri_dir 16 | sid/compute_vad_decision.sh --nj $nj --cmd run.pl $scotus_veri_dir 17 | utils/fix_data_dir.sh $scotus_veri_dir 18 | fi 19 | 20 | if [ $step -le 1 ]; then 21 | local/nnet3/xvector/prepare_feats_for_egs.sh --nj $nj --cmd run.pl \ 22 | $scotus_veri_dir ${scotus_veri_dir}_nosil $scotus_veri_dir/nosil_feats 23 | utils/fix_data_dir.sh ${scotus_veri_dir}_nosil 24 | fi 25 | 26 | if [ $step -le 2 ]; then 27 | #make the feats 28 | utils/fix_data_dir.sh $scotus_diar_dir 29 | steps/make_mfcc.sh --write-utt2num-frames true --mfcc-config conf/mfcc.conf --nj $nj \ 30 | --cmd run.pl $scotus_diar_dir 31 | 32 | utils/fix_data_dir.sh $scotus_diar_dir 33 | sid/compute_vad_decision.sh --nj $nj --cmd run.pl $scotus_diar_dir 34 | utils/fix_data_dir.sh $scotus_diar_dir 35 | fi 36 | 37 | if [ $step -le 3 ]; then 38 | local/nnet3/xvector/prepare_feats_for_egs.sh --nj $nj --cmd run.pl \ 39 | $scotus_diar_dir ${scotus_diar_dir}_nosil $scotus_diar_dir/nosil_feats 40 | utils/fix_data_dir.sh ${scotus_diar_dir}_nosil 41 | fi -------------------------------------------------------------------------------- /scotus_data_prep/step5_trim_split_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import json 4 | import os 5 | import pickle 6 | import random 7 | import shutil 8 | import sys 9 | from collections import OrderedDict 10 | from itertools import combinations 11 | 12 | import numpy as np 13 | from scipy.special import comb 14 | from tqdm import tqdm 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description='Prep the data for verification, diarization and feature extraction') 19 | parser.add_argument('--base-outfolder', type=str, help='Location of the base outfolder') 20 | parser.add_argument('--train-proportion', type=float, default=0.8, help='Train proportion (default: 0.8)') 21 | parser.add_argument('--pos-per-spk', type=int, default=10, help='Positive trials per speaker') 22 | args = parser.parse_args() 23 | return args 24 | 25 | 26 | def load_n_col(file): 27 | data = [] 28 | with open(file) as fp: 29 | for line in fp: 30 | data.append(line.strip().split(' ')) 31 | columns = list(zip(*data)) 32 | columns = [list(i) for i in columns] 33 | return columns 34 | 35 | 36 | def write_lines(lines, file): 37 | with 
open(file, 'w+') as fp: 38 | for line in lines: 39 | fp.write(line) 40 | 41 | 42 | def fix_data_dir(data_dir): 43 | """ 44 | Files: real_utt2spk, real_spk2utt, utt2age, wav.scp, utt2spk, spk2utt, segments 45 | Cleans and fixes all the right files to agree with kaldi data processing 46 | """ 47 | backup_data_dir = os.path.join(data_dir, '.mybackup') 48 | os.makedirs(backup_data_dir, exist_ok=True) 49 | 50 | files = glob.glob(os.path.join(data_dir, '*')) 51 | files = [f for f in files if os.path.isfile(f)] 52 | _ = [shutil.copy(f, backup_data_dir) for f in files] 53 | 54 | utt2spk_dict = OrderedDict({k: v for k, v in zip(*load_n_col(os.path.join(data_dir, 'real_utt2spk')))}) 55 | utt2fpath_dict = OrderedDict({k: v for k, v in zip(*load_n_col(os.path.join(data_dir, 'feats.scp')))}) 56 | 57 | utt2age = os.path.isfile(os.path.join(data_dir, 'utt2age')) 58 | if utt2age: 59 | utt2age_dict = OrderedDict({k: v for k, v in zip(*load_n_col(os.path.join(data_dir, 'utt2age')))}) 60 | 61 | complete_utts = set(utt2fpath_dict.keys()).intersection(set(utt2spk_dict.keys())) 62 | complete_utts = sorted(list(complete_utts)) 63 | 64 | print('Reducing real utt2spk ({}) to {} utts...'.format(len(utt2spk_dict), 65 | len(complete_utts))) 66 | 67 | blank_utts = os.path.join(data_dir, 'utts') 68 | with open(blank_utts, 'w+') as fp: 69 | for u in complete_utts: 70 | fp.write('{}\n'.format(u)) 71 | 72 | os.system('./filter_scp.pl {} {} > {}'.format(blank_utts, os.path.join(data_dir, 'real_utt2spk'), 73 | os.path.join(data_dir, 'utt2spk'))) 74 | 75 | os.rename(os.path.join(data_dir, 'segments'), os.path.join(data_dir, 'segments_old')) 76 | os.system('./filter_scp.pl {} {} > {}'.format(blank_utts, os.path.join(data_dir, 'segments_old'), 77 | os.path.join(data_dir, 'segments'))) 78 | 79 | if utt2age: 80 | os.rename(os.path.join(data_dir, 'utt2age'), os.path.join(data_dir, 'utt2age_old')) 81 | os.system('./filter_scp.pl {} {} > {}'.format(blank_utts, os.path.join(data_dir, 'utt2age_old'), 82 | os.path.join(data_dir, 'utt2age'))) 83 | 84 | 85 | set_spks = sorted(list(set(utt2spk_dict.values()))) 86 | spk2utt_dict = OrderedDict({k: [] for k in set_spks}) 87 | for u in complete_utts: 88 | spk2utt_dict[utt2spk_dict[u]].append(u) 89 | 90 | with open(os.path.join(data_dir, 'spk2utt'), 'w+') as fp: 91 | for s in spk2utt_dict: 92 | if spk2utt_dict[s]: 93 | line = '{} {}\n'.format(s, ' '.join(spk2utt_dict[s])) 94 | fp.write(line) 95 | 96 | 97 | def load_one_tomany(file): 98 | one = [] 99 | many = [] 100 | with open(file) as fp: 101 | for line in fp: 102 | line = line.strip().split(' ', 1) 103 | one.append(line[0]) 104 | m = line[1].split(' ') 105 | many.append(m) 106 | return one, many 107 | 108 | 109 | def utt2spk_to_spk2utt(utt2spk_path, outfile=None): 110 | utts = [] 111 | spks = [] 112 | with open(utt2spk_path) as fp: 113 | for line in fp: 114 | splitup = line.strip().split(' ') 115 | assert len(splitup) == 2, 'Got more or less columns that was expected: (Got {}, expected 2)'.format( 116 | len(splitup)) 117 | utts.append(splitup[0]) 118 | spks.append(splitup[1]) 119 | set_spks = sorted(list(set(spks))) 120 | spk2utt_dict = OrderedDict({k: [] for k in set_spks}) 121 | for u, s in zip(utts, spks): 122 | spk2utt_dict[s].append(u) 123 | 124 | if outfile: 125 | with open(outfile, 'w+') as fp: 126 | for spk in spk2utt_dict: 127 | line = '{} {}\n'.format(spk, ' '.join(spk2utt_dict[spk])) 128 | fp.write(line) 129 | return spk2utt_dict 130 | 131 | 132 | def split_recordings(data_dir, train_proportion=0.8): 133 | """ 134 | Split the 
recordings based on train proportion 135 | 136 | returns train_recordings, test_recordings 137 | """ 138 | np.random.seed(1234) 139 | random.seed(1234) 140 | segments_path = os.path.join(data_dir, 'segments') 141 | assert os.path.isfile(segments_path), "Couldn't find {}".format(segments_path) 142 | 143 | utts, urecs, _, _ = load_n_col(segments_path) 144 | 145 | setrecs = sorted(list(set(urecs))) 146 | num_train_recs = int(np.floor(train_proportion * len(setrecs))) 147 | 148 | train_recs = np.random.choice(setrecs, size=num_train_recs, replace=False) 149 | test_recs = [r for r in setrecs if r not in train_recs] 150 | return train_recs, test_recs 151 | 152 | 153 | def split_data_dir(data_dir, train_recs, test_recs): 154 | """ 155 | Split recordings and files into train and test subfolders based on train_recs, test_recs 156 | 157 | Filters feats, utt2spk, spk2utt, segments, (rttm, utt2age) 158 | """ 159 | segments_path = os.path.join(data_dir, 'segments') 160 | assert os.path.isfile(segments_path), "Couldn't find {}".format(segments_path) 161 | 162 | train_dir = os.path.join(data_dir, 'train') 163 | test_dir = os.path.join(data_dir, 'test') 164 | 165 | os.makedirs(train_dir, exist_ok=True) 166 | os.makedirs(test_dir, exist_ok=True) 167 | 168 | utts, urecs, _, _ = load_n_col(segments_path) 169 | utt2rec_dict = OrderedDict({k:v for k,v in zip(utts, urecs)}) 170 | 171 | train_utts = [u for u in utts if utt2rec_dict[u] in train_recs] 172 | test_utts = [u for u in utts if utt2rec_dict[u] in test_recs] 173 | 174 | tr_u = os.path.join(train_dir, 'utts') 175 | with open(tr_u, 'w+') as fp: 176 | for u in train_utts: 177 | fp.write('{}\n'.format(u)) 178 | 179 | te_u = os.path.join(test_dir, 'utts') 180 | with open(te_u, 'w+') as fp: 181 | for u in test_utts: 182 | fp.write('{}\n'.format(u)) 183 | 184 | os.system('./filter_scp.pl {} {} > {}'.format(tr_u, os.path.join(data_dir, 'utt2spk'), 185 | os.path.join(train_dir, 'utt2spk'))) 186 | os.system('./filter_scp.pl {} {} > {}'.format(te_u, os.path.join(data_dir, 'utt2spk'), 187 | os.path.join(test_dir, 'utt2spk'))) 188 | 189 | os.system('./filter_scp.pl {} {} > {}'.format(tr_u, os.path.join(data_dir, 'feats.scp'), 190 | os.path.join(train_dir, 'feats.scp'))) 191 | os.system('./filter_scp.pl {} {} > {}'.format(te_u, os.path.join(data_dir, 'feats.scp'), 192 | os.path.join(test_dir, 'feats.scp'))) 193 | 194 | os.system('./filter_scp.pl {} {} > {}'.format(tr_u, os.path.join(data_dir, 'segments'), 195 | os.path.join(train_dir, 'segments'))) 196 | os.system('./filter_scp.pl {} {} > {}'.format(te_u, os.path.join(data_dir, 'segments'), 197 | os.path.join(test_dir, 'segments'))) 198 | 199 | utt2spk_to_spk2utt(os.path.join(train_dir, 'utt2spk'), outfile=os.path.join(train_dir, 'spk2utt')) 200 | utt2spk_to_spk2utt(os.path.join(test_dir, 'utt2spk'), outfile=os.path.join(test_dir, 'spk2utt')) 201 | 202 | if os.path.isfile(os.path.join(data_dir, 'utt2age')): 203 | os.system('./filter_scp.pl {} {} > {}'.format(tr_u, os.path.join(data_dir, 'utt2age'), 204 | os.path.join(train_dir, 'utt2age'))) 205 | os.system('./filter_scp.pl {} {} > {}'.format(te_u, os.path.join(data_dir, 'utt2age'), 206 | os.path.join(test_dir, 'utt2age'))) 207 | 208 | rttm_path = os.path.join(data_dir, 'ref.rttm') 209 | if os.path.isfile(rttm_path): 210 | with open(rttm_path) as fp: 211 | rttm_lines = [line for line in fp] 212 | 213 | rttm_recs = [line.split()[1].strip() for line in rttm_lines] 214 | 215 | tr_rttmlines = [l for l, r in zip(rttm_lines, rttm_recs) if r in train_recs] 216 | te_rttmlines 
= [l for l, r in zip(rttm_lines, rttm_recs) if r in test_recs] 217 | 218 | write_lines(tr_rttmlines, os.path.join(train_dir, 'ref.rttm')) 219 | write_lines(te_rttmlines, os.path.join(test_dir, 'ref.rttm')) 220 | 221 | 222 | def nonoverlapped_utts(veri_data_dir): 223 | ''' 224 | Retrieve only the utterances that belong to speakers not seen in the training set 225 | ''' 226 | train_dir = os.path.join(veri_data_dir, 'train') 227 | test_dir = os.path.join(veri_data_dir, 'test') 228 | 229 | train_utts, train_spkrs = load_n_col(os.path.join(train_dir, 'utt2spk')) 230 | test_utts, test_spkrs = load_n_col(os.path.join(test_dir, 'utt2spk')) 231 | 232 | set_tr_spkrs = set(train_spkrs) 233 | set_te_spkrs = set(test_spkrs) 234 | 235 | non_overlapped_speakers = set_te_spkrs - set_tr_spkrs 236 | assert len(non_overlapped_speakers) >= 10, "Something has gone wrong if less than 10 speakers are left" 237 | 238 | valid_utts, valid_spkrs = [], [] 239 | for u, s in zip(test_utts, test_spkrs): 240 | if s in non_overlapped_speakers: 241 | valid_utts.append(u) 242 | valid_spkrs.append(s) 243 | 244 | return valid_utts, valid_spkrs 245 | 246 | 247 | def generate_veri_pairs(utts, spkrs, pos_per_spk=15): 248 | # Randomly creates pairs for a verification list 249 | # pos_per_spk determines the number of same-speaker pairs, which is always paired with an equal number of negatives 250 | np.random.seed(1234) 251 | random.seed(1234) 252 | setspkrs = sorted(list(set(spkrs))) 253 | spk2utt_dict = OrderedDict({s: [] for s in setspkrs}) 254 | for utt, s in zip(utts, spkrs): 255 | spk2utt_dict[s].append(utt) 256 | 257 | u0 = [] 258 | u1 = [] 259 | labs = [] 260 | 261 | for s in setspkrs: 262 | random.shuffle(spk2utt_dict[s]) 263 | 264 | for s in setspkrs: 265 | positives = [spk2utt_dict[s].pop() for _ in range(pos_per_spk)] 266 | num_neg_trials = int(comb(pos_per_spk, 2)) 267 | negatives = [spk2utt_dict[np.random.choice(list(set(setspkrs) - set([s])))].pop() for _ in 268 | range(num_neg_trials)] 269 | for a, b in combinations(positives, 2): 270 | u0.append(a) 271 | u1.append(b) 272 | labs.append(1) 273 | for a, b in zip(np.random.choice(positives, size=num_neg_trials, replace=True), negatives): 274 | u0.append(a) 275 | u1.append(b) 276 | labs.append(0) 277 | 278 | return u0, u1, labs 279 | 280 | 281 | if __name__ == "__main__": 282 | args = parse_args() 283 | 284 | veri_data_dir = os.path.join(args.base_outfolder, 'veri_data_nosil') 285 | assert os.path.isdir(veri_data_dir), "Couldn't find {}".format(veri_data_dir) 286 | shutil.copy(os.path.join(args.base_outfolder, 'veri_data/segments'), veri_data_dir) 287 | shutil.copy(os.path.join(args.base_outfolder, 'veri_data/utt2age'), veri_data_dir) 288 | shutil.copy(os.path.join(args.base_outfolder, 'veri_data/real_utt2spk'), veri_data_dir) 289 | 290 | print('Fixing data dir: {}'.format(veri_data_dir)) 291 | fix_data_dir(veri_data_dir) 292 | train_recs, test_recs = split_recordings(veri_data_dir, train_proportion=args.train_proportion) 293 | 294 | print('Making recording split...') 295 | with open(os.path.join(args.base_outfolder, 'recording_split.json'), 'w+', encoding='utf-8') as fp: 296 | recdict = {'train': list(train_recs), 'test': list(test_recs)} 297 | json.dump(recdict, fp) 298 | 299 | print('Splitting verification data...') 300 | split_data_dir(veri_data_dir, train_recs, test_recs) 301 | 302 | print('Making verification pairs for test portion of verification data...') 303 | valid_utts, valid_spkrs = nonoverlapped_utts(veri_data_dir) 304 | u0, u1, labs = 
generate_veri_pairs(valid_utts, valid_spkrs, pos_per_spk=args.pos_per_spk) 305 | veri_lines = ['{} {} {}\n'.format(l, a, b) for l, a, b in zip(labs, u0, u1)] 306 | write_lines(veri_lines, os.path.join(veri_data_dir, 'test/veri_pairs')) 307 | 308 | print('Now fixing diarization data...') 309 | diar_data_dir = os.path.join(args.base_outfolder, 'diar_data_nosil') 310 | assert os.path.isdir(diar_data_dir), "Couldn't find {}".format(diar_data_dir) 311 | shutil.copy(os.path.join(args.base_outfolder, 'diar_data/real_utt2spk'), diar_data_dir) 312 | shutil.copy(os.path.join(args.base_outfolder, 'diar_data/segments'), diar_data_dir) 313 | shutil.copy(os.path.join(args.base_outfolder, 'diar_data/ref.rttm'), diar_data_dir) 314 | 315 | fix_data_dir(diar_data_dir) 316 | 317 | print('Splitting diarization data...') 318 | split_data_dir(diar_data_dir, train_recs, test_recs) 319 | 320 | 321 | print('Done!!') 322 | 323 | 324 | 325 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import configparser 3 | import glob 4 | import json 5 | import os 6 | import pickle 7 | import random 8 | import shutil 9 | import time 10 | from collections import OrderedDict 11 | from pprint import pprint 12 | from inference import test, test_all_factors, test_enrollment_models, test_nosil 13 | 14 | import numpy as np 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import uvloop 20 | from data_io import SpeakerDataset, SpeakerModelTestDataset, SpeakerTestDataset 21 | from models.classifiers import (AdaCos, AMSMLoss, ArcFace, L2SoftMax, SoftMax, 22 | SphereFace, XVecHead, XVecHeadUncertain, GradientReversal) 23 | from models.criteria import (DisturbLabelLoss, LabelSmoothingLoss, 24 | TwoNeighbourSmoothingLoss, MultiTaskUncertaintyLossKendall, 25 | MultiTaskUncertaintyLossLiebel) 26 | from models.extractors import ETDNN, FTDNN, XTDNN 27 | from torch.utils.data import DataLoader 28 | from torch.utils.tensorboard import SummaryWriter 29 | from tqdm import tqdm 30 | from utils import SpeakerRecognitionMetrics, schedule_lr 31 | 32 | 33 | def get_lr(optimizer): 34 | for param_group in optimizer.param_groups: 35 | return param_group['lr'] 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser(description='Train SV model') 40 | parser.add_argument('--cfg', type=str, default='./configs/example_speaker.cfg') 41 | parser.add_argument('--transfer-learning', action='store_true', default=False, 42 | help='Start from g_start.pt in exp folder') 43 | parser.add_argument('--resume-checkpoint', type=int, default=0) 44 | args = parser.parse_args() 45 | assert os.path.isfile(args.cfg) 46 | args._start_time = time.ctime() 47 | return args 48 | 49 | 50 | def parse_config(args): 51 | assert os.path.isfile(args.cfg) 52 | config = configparser.ConfigParser() 53 | config.read(args.cfg) 54 | 55 | args.train_data = config['Datasets'].get('train') 56 | assert args.train_data 57 | args.test_data = config['Datasets'].get('test') 58 | 59 | args.model_type = config['Model'].get('model_type', fallback='XTDNN') 60 | assert args.model_type in ['XTDNN', 'ETDNN', 'FTDNN'] 61 | 62 | args.classifier_heads = config['Model'].get('classifier_heads').split(',') 63 | assert len(args.classifier_heads) <= 3, 'Three options available' 64 | assert len(args.classifier_heads) == len(set(args.classifier_heads)) 65 | for clf in args.classifier_heads: 66 | assert clf in 
['speaker', 'nationality', 'gender', 'age', 'age_regression', 'rec'] 67 | 68 | args.classifier_types = config['Optim'].get('classifier_types').split(',') 69 | assert len(args.classifier_heads) == len(args.classifier_types) 70 | 71 | args.classifier_loss_weighting_type = config['Optim'].get('classifier_loss_weighting_type', fallback='none') 72 | assert args.classifier_loss_weighting_type in ['none', 'uncertainty_kendall', 'uncertainty_liebel', 'dwa'] 73 | 74 | args.dwa_temperature = config['Optim'].getfloat('dwa_temperature', fallback=2.) 75 | 76 | args.classifier_loss_weights = np.array(json.loads(config['Optim'].get('classifier_loss_weights'))).astype(float) 77 | assert len(args.classifier_heads) == len(args.classifier_loss_weights) 78 | 79 | args.classifier_lr_mults = np.array(json.loads(config['Optim'].get('classifier_lr_mults'))).astype(float) 80 | assert len(args.classifier_heads) == len(args.classifier_lr_mults) 81 | 82 | # assert clf_type in ['l2softmax', 'adm', 'adacos', 'xvec', 'arcface', 'sphereface', 'softmax'] 83 | 84 | args.classifier_smooth_types = config['Optim'].get('classifier_smooth_types').split(',') 85 | assert len(args.classifier_smooth_types) == len(args.classifier_heads) 86 | args.classifier_smooth_types = [s.strip() for s in args.classifier_smooth_types] 87 | 88 | args.label_smooth_type = config['Optim'].get('label_smooth_type', fallback='None') 89 | assert args.label_smooth_type in ['None', 'disturb', 'uniform'] 90 | args.label_smooth_prob = config['Optim'].getfloat('label_smooth_prob', fallback=0.1) 91 | 92 | args.input_dim = config['Hyperparams'].getint('input_dim', fallback=30) 93 | args.embedding_dim = config['Hyperparams'].getint('embedding_dim', fallback=512) 94 | args.lr = config['Hyperparams'].getfloat('lr', fallback=0.2) 95 | args.batch_size = config['Hyperparams'].getint('batch_size', fallback=400) 96 | args.max_seq_len = config['Hyperparams'].getint('max_seq_len', fallback=400) 97 | args.no_cuda = config['Hyperparams'].getboolean('no_cuda', fallback=False) 98 | args.seed = config['Hyperparams'].getint('seed', fallback=123) 99 | args.num_iterations = config['Hyperparams'].getint('num_iterations', fallback=50000) 100 | args.momentum = config['Hyperparams'].getfloat('momentum', fallback=0.9) 101 | args.scheduler_steps = np.array(json.loads(config.get('Hyperparams', 'scheduler_steps'))).astype(int) 102 | args.scheduler_lambda = config['Hyperparams'].getfloat('scheduler_lambda', fallback=0.5) 103 | args.multi_gpu = config['Hyperparams'].getboolean('multi_gpu', fallback=False) 104 | args.classifier_lr_mult = config['Hyperparams'].getfloat('classifier_lr_mult', fallback=1.) 
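    # Added commentary (not in the original source): for reference, a minimal
    # config covering the required keys parsed in this function. This is only
    # a sketch with made-up values and paths, not the repository's
    # configs/example.cfg:
    #
    #   [Datasets]
    #   train = /path/to/veri_data_nosil/train
    #   test = /path/to/veri_data_nosil/test
    #
    #   [Model]
    #   model_type = XTDNN
    #   classifier_heads = speaker,age
    #
    #   [Optim]
    #   classifier_types = adm,xvec
    #   classifier_smooth_types = none,none
    #   classifier_loss_weights = [1.0, 0.1]
    #   classifier_lr_mults = [1.0, 1.0]
    #   classifier_loss_weighting_type = none
    #
    #   [Hyperparams]
    #   lr = 0.2
    #   batch_size = 400
    #   max_seq_len = 350
    #   num_iterations = 50000
    #   scheduler_steps = [25000, 40000]
    #   scheduler_lambda = 0.5
    #
    #   [Outputs]
    #   model_dir = exp/example_model
    #   checkpoint_interval = 1000
    #
    #   [Misc]
    #   num_age_bins = 10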
105 | args.dropout = config['Hyperparams'].getboolean('dropout', fallback=True) 106 | 107 | args.model_dir = config['Outputs']['model_dir'] 108 | if not hasattr(args, 'basefolder'): 109 | args.basefolder = config['Outputs'].get('basefolder', fallback=None) 110 | args.log_file = os.path.join(args.model_dir, 'train.log') 111 | args.checkpoint_interval = config['Outputs'].getint('checkpoint_interval') 112 | args.results_pkl = os.path.join(args.model_dir, 'results.p') 113 | 114 | args.num_age_bins = config['Misc'].getint('num_age_bins', fallback=10) 115 | args.age_label_smoothing = config['Misc'].getboolean('age_label_smoothing', fallback=False) 116 | return args 117 | 118 | 119 | def train(ds_train): 120 | use_cuda = not args.no_cuda and torch.cuda.is_available() 121 | print('=' * 30) 122 | print('USE_CUDA SET TO: {}'.format(use_cuda)) 123 | print('CUDA AVAILABLE?: {}'.format(torch.cuda.is_available())) 124 | print('=' * 30) 125 | device = torch.device("cuda" if use_cuda else "cpu") 126 | 127 | writer = SummaryWriter(comment=os.path.basename(args.cfg)) 128 | 129 | if args.model_type == 'XTDNN': 130 | generator = XTDNN(features_per_frame=args.input_dim, embed_features=args.embedding_dim) 131 | if args.model_type == 'ETDNN': 132 | generator = ETDNN(features_per_frame=args.input_dim, embed_features=args.embedding_dim) 133 | if args.model_type == 'FTDNN': 134 | generator = FTDNN(in_dim=args.input_dim, embedding_dim=args.embedding_dim) 135 | 136 | generator.train() 137 | generator = generator.to(device) 138 | 139 | model_dict = {'generator': {'model': generator, 'lr_mult': 1., 'loss_weight': None}} 140 | clf_head_dict = {k: {'model': None, 'lr_mult': lr_mult, 'loss_weight': loss_weight} for k, lr_mult, loss_weight in 141 | zip(args.classifier_heads, args.classifier_lr_mults, args.classifier_loss_weights)} 142 | 143 | num_cls_per_task = [ds_train.num_classes[t] for t in args.classifier_heads] 144 | 145 | for clf_target, clf_type, num_classes, clf_smooth_type in zip(args.classifier_heads, 146 | args.classifier_types, 147 | num_cls_per_task, 148 | args.classifier_smooth_types): 149 | if clf_type == 'adm': 150 | clf = AMSMLoss(args.embedding_dim, num_classes) 151 | elif clf_type == 'adacos': 152 | clf = AdaCos(args.embedding_dim, num_classes) 153 | elif clf_type == 'l2softmax': 154 | clf = L2SoftMax(args.embedding_dim, num_classes) 155 | elif clf_type == 'softmax': 156 | clf = SoftMax(args.embedding_dim, num_classes) 157 | elif clf_type == 'xvec': 158 | clf = XVecHead(args.embedding_dim, num_classes) 159 | elif clf_type == 'xvec_regression': 160 | clf = XVecHead(args.embedding_dim, 1) 161 | elif clf_type == 'xvec_uncertain': 162 | clf = XVecHeadUncertain(args.embedding_dim, num_classes) 163 | elif clf_type == 'arcface': 164 | clf = ArcFace(args.embedding_dim, num_classes) 165 | elif clf_type == 'sphereface': 166 | clf = SphereFace(args.embedding_dim, num_classes) 167 | else: 168 | assert None, 'Classifier type {} not found'.format(clf_type) 169 | 170 | if clf_head_dict[clf_target]['loss_weight'] >= 0.0: 171 | clf_head_dict[clf_target]['model'] = clf.train().to(device) 172 | else: 173 | # GRL for negative loss weight 174 | abs_lw = np.abs(clf_head_dict[clf_target]['loss_weight']) 175 | clf_head_dict[clf_target]['model'] = nn.Sequential( 176 | GradientReversal(lambda_=abs_lw), 177 | clf 178 | ).train().to(device) 179 | clf_head_dict[clf_target]['loss_weight'] = 1.0 # this is lambda_ in the GRL 180 | 181 | if clf_smooth_type == 'none': 182 | if clf_target.endswith('regression'): 183 | clf_smooth = 
nn.SmoothL1Loss() 184 | else: 185 | clf_smooth = nn.CrossEntropyLoss() 186 | elif clf_smooth_type == 'twoneighbour': 187 | clf_smooth = TwoNeighbourSmoothingLoss(smoothing=args.label_smooth_prob) 188 | elif clf_smooth_type == 'uniform': 189 | clf_smooth = LabelSmoothingLoss(smoothing=args.label_smooth_prob) 190 | elif clf_smooth_type == 'disturb': 191 | clf_smooth = DisturbLabelLoss(device, disturb_prob=args.label_smooth_prob) 192 | else: 193 | assert None, 'Smooth type not found: {}'.format(clf_smooth_type) 194 | 195 | clf_head_dict[clf_target]['criterion'] = clf_smooth 196 | 197 | model_dict.update(clf_head_dict) 198 | 199 | if args.classifier_loss_weighting_type == 'uncertainty_kendall': 200 | model_dict['loss_aggregator'] = { 201 | 'model': MultiTaskUncertaintyLossKendall(len(args.classifier_heads)).to(device), 202 | 'lr_mult': 1., 203 | 'loss_weight': None 204 | } 205 | if args.classifier_loss_weighting_type == 'uncertainty_liebel': 206 | model_dict['loss_aggregator'] = { 207 | 'model': MultiTaskUncertaintyLossLiebel(len(args.classifier_heads)).to(device), 208 | 'lr_mult': 1., 209 | 'loss_weight': None 210 | } 211 | 212 | if args.resume_checkpoint != 0: 213 | model_str = os.path.join(args.model_dir, '{}_{}.pt') 214 | for m in model_dict: 215 | model_dict[m]['model'].load_state_dict(torch.load(model_str.format(m, args.resume_checkpoint))) 216 | 217 | optimizer = torch.optim.SGD( 218 | [{'params': model_dict[m]['model'].parameters(), 'lr': args.lr * model_dict[m]['lr_mult']} for m in model_dict], 219 | momentum=args.momentum) 220 | 221 | 222 | iterations = 0 223 | 224 | total_loss = 0 225 | running_loss = [np.nan for _ in range(500)] 226 | 227 | non_spk_clf_heads = [a for a in args.classifier_heads if a != 'speaker'] 228 | 229 | best_test_eer = (-1, 1.0) 230 | best_test_dcf = (-1, 1.0) 231 | best_acc = {k: (-1, 0.0) for k in non_spk_clf_heads} 232 | 233 | if os.path.isfile(args.results_pkl) and args.resume_checkpoint != 0: 234 | rpkl = pickle.load(open(args.results_pkl, "rb")) 235 | keylist = list(rpkl.keys()) 236 | 237 | if args.test_data: 238 | test_eers = [(rpkl[key]['test_eer'], key) for i, key in enumerate(rpkl)] 239 | best_teer = min(test_eers) 240 | best_test_eer = (best_teer[1], best_teer[0]) 241 | 242 | test_dcfs = [(rpkl[key]['test_dcf'], key) for i, key in enumerate(rpkl)] 243 | besttest_dcf = min(test_dcfs) 244 | best_test_dcf = (besttest_dcf[1], besttest_dcf[0]) 245 | 246 | else: 247 | rpkl = OrderedDict({}) 248 | 249 | if args.multi_gpu: 250 | dpp_generator = nn.DataParallel(generator).to(device) 251 | 252 | data_generator = ds_train.get_batches(batch_size=args.batch_size, max_seq_len=args.max_seq_len) 253 | 254 | if args.model_type == 'FTDNN': 255 | drop_indexes = np.linspace(0, 1, args.num_iterations) 256 | drop_sch = ([0, 0.5, 1], [0, 0.5, 0]) 257 | drop_schedule = np.interp(drop_indexes, drop_sch[0], drop_sch[1]) 258 | 259 | for iterations in range(1, args.num_iterations + 1): 260 | if iterations > args.num_iterations: 261 | break 262 | if iterations in args.scheduler_steps: 263 | schedule_lr(optimizer, factor=args.scheduler_lambda) 264 | if iterations <= args.resume_checkpoint: 265 | print('Skipping iteration {}'.format(iterations), file=open(args.log_file, "a")) 266 | continue 267 | 268 | if args.model_type == 'FTDNN': 269 | if args.dropout: 270 | generator.set_dropout_alpha(drop_schedule[iterations - 1]) 271 | 272 | feats, labels = next(data_generator) 273 | feats = feats.to(device) 274 | 275 | if args.multi_gpu: 276 | embeds = dpp_generator(feats) 277 | else: 278 | 
embeds = generator(feats) 279 | 280 | total_loss = 0 281 | losses = [] 282 | 283 | loss_tensors = [] 284 | 285 | for m in args.classifier_heads: 286 | lab = labels[m].to(device) 287 | if m == 'rec': 288 | preds = model_dict[m]['model'](embeds) 289 | else: 290 | preds = model_dict[m]['model'](embeds, lab) 291 | loss = model_dict[m]['criterion'](preds, lab) 292 | if args.classifier_loss_weighting_type == 'none': 293 | total_loss += loss * model_dict[m]['loss_weight'] 294 | else: 295 | loss_tensors.append(loss) 296 | losses.append(round(loss.item(), 4)) 297 | 298 | if args.classifier_loss_weighting_type.startswith('uncertainty'): 299 | loss_tensors = torch.stack(loss_tensors).to(device)  # stack keeps the autograd graph so the task losses backpropagate through the aggregator 300 | total_loss = model_dict['loss_aggregator']['model'](loss_tensors) 301 | 302 | if args.classifier_loss_weighting_type == 'dwa': 303 | loss_tensors = loss_tensors 304 | if iterations < 4: 305 | loss_t_1 = np.ones(len(loss_tensors)) 306 | for l in loss_tensors: 307 | total_loss += l 308 | else: 309 | dwa_w = loss_t_1/loss_t_2 310 | K = len(loss_tensors) 311 | per_task_weight = torch.FloatTensor(dwa_w/args.dwa_temperature) #lambda_k 312 | per_task_weight = torch.nn.functional.softmax(per_task_weight, dim=0) * K 313 | per_task_weight = per_task_weight.numpy() 314 | for l, w in zip(loss_tensors, per_task_weight): 315 | total_loss += l * w 316 | 317 | loss_t_2 = loss_t_1.copy() 318 | loss_t_1 = torch.FloatTensor(loss_tensors).detach().cpu().numpy() 319 | 320 | 321 | optimizer.zero_grad() 322 | total_loss.backward() 323 | optimizer.step() 324 | 325 | if args.model_type == 'FTDNN': 326 | generator.step_ftdnn_layers() 327 | 328 | running_loss.pop(0) 329 | running_loss.append(total_loss.item()) 330 | rmean_loss = np.nanmean(np.array(running_loss)) 331 | 332 | if iterations % 10 == 0: 333 | msg = "{}: {}: [{}/{}] \t C-Loss:{:.4f}, AvgLoss:{:.4f}, losses: {}, lr: {}, bs: {}".format( 334 | args.model_dir, 335 | time.ctime(), 336 | iterations, 337 | args.num_iterations, 338 | total_loss.item(), 339 | rmean_loss, 340 | losses, 341 | get_lr(optimizer), 342 | len(feats)) 343 | print(msg) 344 | print(msg, file=open(args.log_file, "a")) 345 | 346 | writer.add_scalar('combined loss', total_loss.item(), iterations) 347 | writer.add_scalar('Avg loss', rmean_loss, iterations) 348 | 349 | if iterations % args.checkpoint_interval == 0: 350 | for m in model_dict: 351 | model_dict[m]['model'].eval().cpu() 352 | cp_filename = "{}_{}.pt".format(m, iterations) 353 | cp_model_path = os.path.join(args.model_dir, cp_filename) 354 | torch.save(model_dict[m]['model'].state_dict(), cp_model_path) 355 | model_dict[m]['model'].to(device).train() 356 | 357 | if args.test_data: 358 | rpkl, best_test_eer, best_test_dcf = eval_step(model_dict, device, ds_test, iterations, rpkl, writer, 359 | best_test_eer, best_test_dcf, best_acc) 360 | 361 | # ---- Final model saving ----- 362 | for m in model_dict: 363 | model_dict[m]['model'].eval().cpu() 364 | cp_filename = "final_{}_{}.pt".format(m, iterations) 365 | cp_model_path = os.path.join(args.model_dir, cp_filename) 366 | torch.save(model_dict[m]['model'].state_dict(), cp_model_path) 367 | 368 | 369 | def eval_step(model_dict, device, ds_test, iterations, rpkl, writer, best_test_eer, best_test_dcf, best_acc): 370 | rpkl[iterations] = {} 371 | print('Evaluating on test/validation for {} iterations'.format(iterations)) 372 | if args.test_data: 373 | test_eer, test_dcf, acc_dict = test_all_factors(model_dict, ds_test, device) 374 | print(acc_dict) 375 | for att in acc_dict: 376 | 
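# acc_dict (from test_all_factors) maps each auxiliary classifier head, e.g. an age or nationality head,
# to its accuracy on the held-out set; each entry is logged to TensorBoard and tracked as a best-so-far below.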
writer.add_scalar(att, acc_dict[att], iterations) 377 | if acc_dict[att] > best_acc[att][1]: 378 | best_acc[att] = (iterations, acc_dict[att]) 379 | print('{} accuracy on Test Set: {}'.format(att, acc_dict[att])) 380 | print('{} accuracy on Test Set: {}'.format(att, acc_dict[att]), file=open(args.log_file, "a")) 381 | print('Best test {} acc: {}'.format(att, best_acc[att])) 382 | print('Best test {} acc: {}'.format(att, best_acc[att]), file=open(args.log_file, "a")) 383 | rpkl[iterations][att] = acc_dict[att] 384 | 385 | print('EER on Test Set: {} ({})'.format(test_eer, args.test_data)) 386 | print('EER on Test Set: {} ({})'.format(test_eer, args.test_data), file=open(args.log_file, "a")) 387 | writer.add_scalar('test_eer', test_eer, iterations) 388 | if test_eer < best_test_eer[1]: 389 | best_test_eer = (iterations, test_eer) 390 | print('Best test EER: {}'.format(best_test_eer)) 391 | print('Best test EER: {}'.format(best_test_eer), file=open(args.log_file, "a")) 392 | rpkl[iterations]['test_eer'] = test_eer 393 | 394 | print('minDCF on Test Set: {} ({})'.format(test_dcf, args.test_data)) 395 | print('minDCF on Test Set: {} ({})'.format(test_dcf, args.test_data), file=open(args.log_file, "a")) 396 | writer.add_scalar('test_dcf', test_dcf, iterations) 397 | if test_dcf < best_test_dcf[1]: 398 | best_test_dcf = (iterations, test_dcf) 399 | print('Best test minDCF: {}'.format(best_test_dcf)) 400 | print('Best test minDCF: {}'.format(best_test_dcf), file=open(args.log_file, "a")) 401 | rpkl[iterations]['test_dcf'] = test_dcf 402 | 403 | pickle.dump(rpkl, open(args.results_pkl, "wb")) 404 | return rpkl, best_test_eer, best_test_dcf 405 | 406 | 407 | if __name__ == "__main__": 408 | args = parse_args() 409 | args = parse_config(args) 410 | os.makedirs(args.model_dir, exist_ok=True) 411 | if args.resume_checkpoint == 0: 412 | shutil.copy(args.cfg, os.path.join(args.model_dir, 'experiment_settings.cfg')) 413 | else: 414 | shutil.copy(args.cfg, os.path.join(args.model_dir, 'experiment_settings_resume.cfg')) 415 | if os.path.exists(args.log_file): 416 | os.remove(args.log_file) 417 | pprint(vars(args)) 418 | torch.manual_seed(args.seed) 419 | np.random.seed(seed=args.seed) 420 | random.seed(args.seed) 421 | uvloop.install() 422 | ds_train = SpeakerDataset(args.train_data, num_age_bins=args.num_age_bins) 423 | class_enc_dict = ds_train.get_class_encs() 424 | if args.test_data: 425 | ds_test = SpeakerDataset(args.test_data, test_mode=True, 426 | class_enc_dict=class_enc_dict, num_age_bins=args.num_age_bins) 427 | train(ds_train) 428 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from operator import itemgetter 2 | from scipy.interpolate import interp1d 3 | from scipy.optimize import brentq 4 | from sklearn.metrics import pairwise_distances, roc_curve, accuracy_score 5 | from sklearn.metrics.pairwise import paired_distances 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | def mtd(stuff, device): 13 | if isinstance(stuff, torch.Tensor): 14 | return stuff.to(device) 15 | else: 16 | return [mtd(s, device) for s in stuff] 17 | 18 | class SpeakerRecognitionMetrics: 19 | ''' 20 | This doesn't need to be a class [remnant of old structuring]. 
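    A minimal usage sketch (embeddings: (N, D) array, labels: length-N integer speaker labels):
        metric = SpeakerRecognitionMetrics(distance_measure='cosine')
        eer = metric.get_eer(embeddings, labels)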
21 | To be reworked 22 | ''' 23 | 24 | def __init__(self, distance_measure=None): 25 | if not distance_measure: 26 | distance_measure = 'cosine' 27 | self.distance_measure = distance_measure 28 | 29 | def get_labels_scores(self, vectors, labels): 30 | labels = labels[:, np.newaxis] 31 | pair_labels = pairwise_distances(labels, metric='hamming').astype(int).flatten() 32 | pair_scores = pairwise_distances(vectors, metric=self.distance_measure).flatten() 33 | return pair_labels, pair_scores 34 | 35 | def get_roc(self, vectors, labels): 36 | pair_labels, pair_scores = self.get_labels_scores(vectors, labels) 37 | fpr, tpr, threshold = roc_curve(pair_labels, pair_scores, pos_label=1, drop_intermediate=False) 38 | # fnr = 1. - tpr 39 | return fpr, tpr, threshold 40 | 41 | def get_eer(self, vectors, labels): 42 | fpr, tpr, _ = self.get_roc(vectors, labels) 43 | # fnr = 1 - self.tpr 44 | # eer = self.fpr[np.nanargmin(np.absolute((fnr - self.fpr)))] 45 | eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.) 46 | return eer 47 | 48 | def eer_from_pairs(self, pair_labels, pair_scores): 49 | self.fpr, self.tpr, self.thresholds = roc_curve(pair_labels, pair_scores, pos_label=1, drop_intermediate=False) 50 | fnr = 1 - self.tpr 51 | eer = self.fpr[np.nanargmin(np.absolute((fnr - self.fpr)))] 52 | return eer 53 | 54 | def eer_from_ers(self, fpr, tpr): 55 | fnr = 1 - tpr 56 | eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] 57 | return eer 58 | 59 | def scores_from_pairs(self, vecs0, vecs1): 60 | return paired_distances(vecs0, vecs1, metric=self.distance_measure) 61 | 62 | def compute_min_dcf(self, fpr, tpr, thresholds, p_target=0.01, c_miss=10, c_fa=1): 63 | #adapted from compute_min_dcf.py in kaldi sid 64 | # thresholds, fpr, tpr = list(zip(*sorted(zip(thresholds, fpr, tpr)))) 65 | incr_score_indices = np.argsort(thresholds, kind="mergesort") 66 | thresholds = thresholds[incr_score_indices] 67 | fpr = fpr[incr_score_indices] 68 | tpr = tpr[incr_score_indices] 69 | 70 | fnr = 1. - tpr 71 | min_c_det = float("inf") 72 | for i in range(0, len(fnr)): 73 | c_det = c_miss * fnr[i] * p_target + c_fa * fpr[i] * (1 - p_target) 74 | if c_det < min_c_det: 75 | min_c_det = c_det 76 | 77 | c_def = min(c_miss * p_target, c_fa * (1 - p_target)) 78 | min_dcf = min_c_det / c_def 79 | return min_dcf 80 | 81 | def compute_eer(self, fnr, fpr): 82 | """ computes the equal error rate (EER) given FNR and FPR values calculated 83 | for a range of operating points on the DET curve 84 | """ 85 | 86 | diff_pm_fa = fnr - fpr 87 | x1 = np.flatnonzero(diff_pm_fa >= 0)[0] 88 | x2 = np.flatnonzero(diff_pm_fa < 0)[-1] 89 | a = (fnr[x1] - fpr[x1]) / (fpr[x2] - fpr[x1] - (fnr[x2] - fnr[x1])) 90 | return fnr[x1] + a * (fnr[x2] - fnr[x1]) 91 | 92 | def compute_pmiss_pfa(self, scores, labels): 93 | """ computes false positive rate (FPR) and false negative rate (FNR) 94 | given trial scores and their labels. A weights option is also provided 95 | to equalize the counts over score partitions (if there is such 96 | partitioning). 
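        Sketch of the logic below: trials are sorted by score in ascending order; assuming higher
        scores indicate target (same-speaker) trials, fnr[i] is the fraction of target trials scoring
        at or below the i-th sorted score (misses), and fpr[i] is the fraction of impostor trials
        scoring above it (false accepts).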
97 | """ 98 | 99 | sorted_ndx = np.argsort(scores) 100 | labels = labels[sorted_ndx] 101 | 102 | tgt = (labels == 1).astype('f8') 103 | imp = (labels == 0).astype('f8') 104 | 105 | fnr = np.cumsum(tgt) / np.sum(tgt) 106 | fpr = 1 - np.cumsum(imp) / np.sum(imp) 107 | return fnr, fpr 108 | 109 | 110 | def compute_min_cost(self, scores, labels, p_target=0.01): 111 | fnr, fpr = self.compute_pmiss_pfa(scores, labels) 112 | eer = self.compute_eer(fnr, fpr) 113 | min_c = self.compute_c_norm(fnr, fpr, p_target) 114 | return eer, min_c 115 | 116 | def compute_c_norm(self, fnr, fpr, p_target, c_miss=10, c_fa=1): 117 | """ computes normalized minimum detection cost function (DCF) given 118 | the costs for false accepts and false rejects as well as a priori 119 | probability for target speakers 120 | """ 121 | 122 | dcf = c_miss * fnr * p_target + c_fa * fpr * (1 - p_target) 123 | c_det = np.min(dcf) 124 | c_def = min(c_miss * p_target, c_fa * (1 - p_target)) 125 | 126 | return c_det/c_def 127 | 128 | 129 | 130 | def warm_up_lr(batch, num_batch_warm_up, init_lr, optimizer): 131 | for params in optimizer.param_groups: 132 | params['lr'] = batch * init_lr / num_batch_warm_up 133 | 134 | def schedule_lr(optimizer, factor=0.1): 135 | for params in optimizer.param_groups: 136 | params['lr'] *= factor 137 | print(optimizer) 138 | 139 | def set_lr(optimizer, lr): 140 | for params in optimizer.param_groups: 141 | params['lr'] = lr 142 | print(optimizer) 143 | 144 | 145 | -------------------------------------------------------------------------------- /voxceleb_data_prep/data/nationality_to_country.tsv: -------------------------------------------------------------------------------- 1 | Nationality Country 2 | Abkhazia Abkhazia 3 | Abkhaz Abkhazia 4 | Abkhazian Abkhazia 5 | Afghanistan Afghanistan 6 | Afghan Afghanistan 7 | Åland Islands Finland 8 | Åland Island Finland 9 | Albania Albania 10 | Albanian Albania 11 | Algeria Algeria 12 | Algerian Algeria 13 | American Samoa American Samoa 14 | American Samoan American Samoa 15 | Andorra Andorra 16 | Andorran Andorra 17 | Angola Angola 18 | Angolan Angola 19 | Anguilla Anguilla 20 | Anguillan Anguilla 21 | Antigua and Barbuda Antigua and Barbuda 22 | Antiguan Antigua and Barbuda 23 | Barbudan Antigua and Barbuda 24 | Argentina Argentina 25 | Argentine Argentina 26 | Argentinian Argentina 27 | Armenia Armenia 28 | Armenian Armenia 29 | Aruba Aruba 30 | Aruban Aruba 31 | Australia Australia 32 | Australian Australia 33 | Austria Austria 34 | Austrian Austria 35 | Azerbaijan Azerbaijan 36 | Azerbaijani Azerbaijan 37 | Azeri Azerbaijan 38 | Bahamas Bahamas 39 | Bahamian Bahamas 40 | Bahrain Bahrain 41 | Bahraini Bahrain 42 | Bangladesh Bangladesh 43 | Bangladeshi Bangladesh 44 | Barbados Barbados 45 | Barbadian Barbados 46 | Belarus Belarus 47 | Belarusian Belarus 48 | Belgium Belgium 49 | Belgian Belgium 50 | Belize Belize 51 | Belizean Belize 52 | Benin Benin 53 | Beninese Benin 54 | Beninois Benin 55 | Bermuda Bermuda 56 | Bermudian Bermuda 57 | Bermudan Bermuda 58 | Bhutan Bhutan 59 | Bhutanese Bhutan 60 | Bolivia Bolivia 61 | Bolivian Bolivia 62 | Bonaire Curaçao 63 | Bosnia and Herzegovina Bosnia and Herzegovina 64 | Bosnian Bosnia and Herzegovina 65 | Herzegovinian Bosnia and Herzegovina 66 | Botswana Botswana 67 | Motswana Botswana 68 | Botswanan Botswana 69 | Batswana Botswana 70 | Bouvet Island Norway 71 | Bouvet Islander Norway 72 | Brazil Brazil 73 | Brazilian Brazil 74 | British Indian Ocean Territory United Kingdom 75 | BIOT United Kingdom 76 | Brunei 
Brunei 77 | Bruneian Brunei 78 | Bulgaria Bulgaria 79 | Bulgarian Bulgaria 80 | Burkina Faso Burkina Faso 81 | Burkinabé Burkina Faso 82 | Burkinabè/Burkinabé Burkina Faso 83 | Burma Myanmar 84 | Burmese Myanmar 85 | Bamar Myanmar 86 | Burundi Burundi 87 | Burundian Burundi 88 | Barundi Burundi 89 | Cabo Verde Cape Verde 90 | Cabo Verdean Cape Verde 91 | Cambodia Cambodia 92 | Cambodian Cambodia 93 | Cameroon Cameroon 94 | Cameroonian Cameroon 95 | Canada Canada 96 | Canadian Canada 97 | Cayman Islands Cayman Islands 98 | Caymanian Cayman Islands 99 | Central African Republic Central African Republic 100 | Central African Central African Republic 101 | Chad Chad 102 | Chadian Chad 103 | Chile Chile 104 | Chilean Chile 105 | China China 106 | Chinese China 107 | Taiwanese Taiwan 108 | Formosan Taiwan 109 | Christmas Island Christmas Island 110 | Christmas Islanders Christmas Island 111 | Cocos (Keeling) Islands Cocos Islands 112 | Cocos Island Cocos Islands 113 | Colombia Colombia 114 | Colombian Colombia 115 | Comoros Comoros 116 | Comoran Comoros 117 | Comorian Comoros 118 | Democratic Republic of the Congo Democratic Republic of the Congo 119 | Congolese Congo 120 | Congo Congo 121 | Republic of the Congo Congo 122 | Cook Islands Cook Islands 123 | Cook Island Cook Islands 124 | Costa Rica Costa Rica 125 | Costa Rican Costa Rica 126 | Croatia Croatia 127 | Croatian Croatia 128 | Croats Croatia 129 | Cuba Cuba 130 | Cuban Cuba 131 | Curaçao Curaçao 132 | Curaçaoan Curaçao 133 | Cyprus Cyprus 134 | Cypriot Cyprus 135 | Czech Republic Czech Republic 136 | Czech Czech Republic 137 | Denmark Denmark 138 | Danish Denmark 139 | Djibouti Djibouti 140 | Djiboutian Djibouti 141 | Dominica Dominica 142 | Dominican Dominican Republic 143 | Dominicans Dominican Republic 144 | Dominican Republic Dominican Republic 145 | East Timor East Timor 146 | Timorese East Timor 147 | Ecuador Ecuador 148 | Ecuadorian Ecuador 149 | Egypt Egypt 150 | Egyptian Egypt 151 | El Salvador El Salvador 152 | Salvadoran El Salvador 153 | Salvadorean El Salvador 154 | England United Kingdom 155 | English United Kingdom 156 | Equatorial Guinea Equatorial Guinea 157 | Equatorial Guinean Equatorial Guinea 158 | Equatoguinean Equatorial Guinea 159 | Eritrea Eritrea 160 | Eritrean Eritrea 161 | Estonia Estonia 162 | Estonian Estonia 163 | Eswatini (Swaziland) Swaziland 164 | Swazi Swaziland 165 | Swati Swaziland 166 | Ethiopia Ethiopia 167 | Ethiopian Ethiopia 168 | Habesha Ethiopia 169 | Falkland Islands Falkland Islands 170 | Falkland Island Falkland Islands 171 | Faroe Islands Faroe Islands 172 | Faroese Faroe Islands 173 | Fiji Fiji 174 | Fijian Fiji 175 | Finland Finland 176 | Finnish Finland 177 | France France 178 | French Guiana French Guiana 179 | French Guianese French Guiana 180 | French Polynesia French Polynesia 181 | French Polynesian French Polynesia 182 | French Southern Territories France 183 | French France 184 | Gabon Gabon 185 | Gabonese Gabon 186 | Gabonaise Gabon 187 | Gambia Gambia 188 | Gambian Gambia 189 | Gambians Gambia 190 | Georgia Georgia 191 | Georgian Georgia 192 | Germany Germany 193 | German Germany 194 | Ghana Ghana 195 | Ghanaian Ghana 196 | Gibraltar Gibraltar 197 | Gibraltarian Gibraltar 198 | Greece Greece 199 | Greek Greece 200 | Greenland Greenland 201 | Greenlandic Greenland 202 | Grenada Grenada 203 | Grenadian Grenada 204 | Guadeloupe Dominica 205 | Guadeloupians Dominica 206 | Guadeloupeans Dominica 207 | Guam Guam 208 | Guamanian Guam 209 | Guatemala Guatemala 210 | Guatemalan 
Guatemala 211 | Guernsey Guernsey 212 | Channel Island Guernsey 213 | Channel Islanders Guernsey 214 | Guinea Guinea 215 | Guinean Guinea 216 | Guinea-Bissau Guinea-Bissau 217 | Bissau-Guinean Guinea-Bissau 218 | Guyana Guyana 219 | Guyanese Guyana 220 | Haiti Haiti 221 | Haitian Haiti 222 | Honduras Honduras 223 | Honduran Honduras 224 | Hondurans Honduras 225 | Hong Kong Hong Kong 226 | Cantonese China 227 | Hong Konger Hong Kong 228 | Hongkongers Hong Kong 229 | Hong Kongese Hong Kong 230 | Hungary Hungary 231 | Hungarian Hungary 232 | Magyar Hungary 233 | Iceland Iceland 234 | Icelandic Iceland 235 | Icelander Iceland 236 | India India 237 | Indian India 238 | Indonesia Indonesia 239 | Indonesian Indonesia 240 | Iran Iran 241 | Iranian Iran 242 | Persian Iran 243 | Iraq Iraq 244 | Iraqi Iraq 245 | Ireland Ireland 246 | Irish Ireland 247 | Isle of Man Isle of Man 248 | Manx Isle of Man 249 | Israel Israel 250 | Israeli Israel 251 | Italy Italy 252 | Italian Italy 253 | Ivory Coast Ivory Coast 254 | Ivorian Ivory Coast 255 | Jamaica Jamaica 256 | Jamaican Jamaica 257 | Jan Mayen Norway 258 | Jan Mayen Norway 259 | Jan Mayen residents Norway 260 | Japan Japan 261 | Japanese Japan 262 | Jersey Jersey 263 | Channel Island Jersey 264 | Channel Islanders Jersey 265 | Jordan Jordan 266 | Jordanian Jordan 267 | Kazakhstan Kazakhstan 268 | Kazakhstani Kazakhstan 269 | Kazakh Kazakhstan 270 | Kenya Kenya 271 | Kenyan Kenya 272 | Kiribati Kiribati 273 | I-Kiribati Kiribati 274 | Korean South Korea 275 | North Korea North Korea 276 | South Korea South Korea 277 | North Korean North Korea 278 | South Korean South Korea 279 | Kosovo Kosovo 280 | Kosovar Kosovo 281 | Kosovan Kosovo 282 | Kuwait Kuwait 283 | Kuwaiti Kuwait 284 | Kyrgyzstan Kyrgyzstan 285 | Kyrgyzstani Kyrgyzstan 286 | Kyrgyz Kyrgyzstan 287 | Kirgiz Kyrgyzstan 288 | Kirghiz Kyrgyzstan 289 | Lao People's Democratic Republic Laos 290 | Lao Laos 291 | Laotian Laos 292 | Laos Laos 293 | Latvia Latvia 294 | Latvian Latvia 295 | Lettish Latvia 296 | Lebanon Lebanon 297 | Lebanese Lebanon 298 | Lesotho Lesotho 299 | Basotho Lesotho 300 | Liberia Liberia 301 | Liberian Liberia 302 | Libya Libya 303 | Libyan Libya 304 | Liechtenstein Liechtenstein 305 | Liechtensteiner Liechtenstein 306 | Lithuania Lithuania 307 | Lithuanian Lithuania 308 | Luxembourg Luxembourg 309 | Luxembourgish Luxembourg 310 | Luxembourger Luxembourg 311 | Macau Macau 312 | Macanese Macau 313 | Macedonia North Macedonia 314 | Macedonian North Macedonia 315 | Madagascar Madagascar 316 | Malagasy Madagascar 317 | Malawi Malawi 318 | Malawian Malawi 319 | Malaysia Malaysia 320 | Malaysian Malaysia 321 | Maldives Maldives 322 | Maldivian Maldives 323 | Mali Mali 324 | Malian Mali 325 | Malinese Mali 326 | Malta Malta 327 | Maltese Malta 328 | Marshall Islands Marshall Islands 329 | Marshallese Marshall Islands 330 | Martinique Saint Lucia 331 | Martiniquais Saint Lucia 332 | Martinican Saint Lucia 333 | Martiniquaises Saint Lucia 334 | Mauritania Mauritania 335 | Mauritanian Mauritania 336 | Mauritius Mauritius 337 | Mauritian Mauritius 338 | Mayotte Comoros 339 | Mahoran Comoros 340 | Mexico Mexico 341 | Mexican Mexico 342 | Micronesia Federated States of Micronesia 343 | Micronesian Federated States of Micronesia 344 | Moldova Moldova 345 | Moldovan Moldova 346 | Monaco Monaco 347 | Monégasque Monaco 348 | Monacan Monaco 349 | Mongolia Mongolia 350 | Mongolian Mongolia 351 | Mongol Mongolia 352 | Montenegro Montenegro 353 | Montenegrin Montenegro 354 | Montserrat Montserrat 
355 | Montserratian Montserrat 356 | Morocco Morocco 357 | Moroccan Morocco 358 | Mozambique Mozambique 359 | Mozambican Mozambique 360 | Myanmar Myanmar 361 | Burmese Myanmar 362 | Bamar Myanmar 363 | Namibia Namibia 364 | Namibian Namibia 365 | Nauru Nauru 366 | Nauruan Nauru 367 | Nepal Nepal 368 | Nepali Nepal 369 | Nepalese Nepal 370 | Netherlands Netherlands 371 | Dutch Netherlands 372 | Netherlandic Netherlands 373 | New Caledonia New Caledonia 374 | New Caledonian New Caledonia 375 | New Zealand New Zealand 376 | New Zealandic New Zealand 377 | New Zealander New Zealand 378 | Nicaragua Nicaragua 379 | Nicaraguan Nicaragua 380 | Niger Niger 381 | Nigerien Niger 382 | Nigeria Nigeria 383 | Nigerian Nigeria 384 | Niue Niue 385 | Niuean Niue 386 | Norfolk Island Norfolk Island 387 | Norfolk Islanders Norfolk Island 388 | Northern Ireland United Kingdom 389 | Northern Irish United Kingdom 390 | Northern Mariana Islands Northern Mariana Islands 391 | Northern Marianan Northern Mariana Islands 392 | Norway Norway 393 | Norwegian Norway 394 | Oman Oman 395 | Omani Oman 396 | Pakistan Pakistan 397 | Pakistani Pakistan 398 | Palau Palau 399 | Palauan Palau 400 | Palestine Palestine 401 | Palestinian Palestine 402 | Panama Panama 403 | Panamanian Panama 404 | Papua New Guinea Papua New Guinea 405 | Papua New Guinean Papua New Guinea 406 | Paraguay Paraguay 407 | Paraguayan Paraguay 408 | Peru Peru 409 | Peruvian Peru 410 | Philippines Philippines 411 | Philippine Philippines 412 | Filipino Philippines 413 | Filipina Philippines 414 | Pitcairn Islands Pitcairn Islands 415 | Pitcairn Island Pitcairn Islands 416 | Pitcairn Islander Pitcairn Islands 417 | Poland Poland 418 | Polish Poland 419 | Portugal Portugal 420 | Portuguese Portugal 421 | Puerto Rico Puerto Rico 422 | Puerto Rican Puerto Rico 423 | Qatar Qatar 424 | Qatari Qatar 425 | Réunion Mauritius 426 | Réunionese Mauritius 427 | Réunionnais Mauritius 428 | Romania Romania 429 | Romanian Romania 430 | Russia Russia 431 | Russian Russia 432 | Rwanda Rwanda 433 | Rwandan Rwanda 434 | Banyarwanda Rwanda 435 | Saba Saint Kitts and Nevis 436 | Saba Dutch Saint Kitts and Nevis 437 | Saint Barthélemy Saint Barthélemy 438 | Barthélemois Saint Barthélemy 439 | Barthélemois/Barthélemoises Saint Barthélemy 440 | Saint Kitts and Nevis Saint Kitts and Nevis 441 | Kittitian Saint Kitts and Nevis 442 | Nevisian Saint Kitts and Nevis 443 | Saint Lucia Saint Lucia 444 | Saint Lucian Saint Lucia 445 | Saint Martin Saint Martin 446 | Saint-Martinoise Saint Martin 447 | Saint-Martinois/Saint-Martinoises Saint Martin 448 | Saint Pierre and Miquelon Saint Pierre and Miquelon 449 | Saint-Pierrais Saint Pierre and Miquelon 450 | Miquelonnais Saint Pierre and Miquelon 451 | Saint-Pierraises Saint Pierre and Miquelon 452 | Miquelonnaises Saint Pierre and Miquelon 453 | Saint Vincent and the Grenadines Saint Vincent and the Grenadines 454 | Saint Vincentian Saint Vincent and the Grenadines 455 | Vincentian Saint Vincent and the Grenadines 456 | Samoa Samoa 457 | Samoan Samoa 458 | San Marino San Marino 459 | Sammarinese San Marino 460 | São Tomé and Príncipe São Tomé and Príncipe 461 | São Toméan São Tomé and Príncipe 462 | Saudi Arabia Saudi Arabia 463 | Saudi Saudi Arabia 464 | Saudi Arabian Saudi Arabia 465 | Scotland United Kingdom 466 | Scottish United Kingdom 467 | Senegal Senegal 468 | Senegalese Senegal 469 | Serbia Serbia 470 | Serbian Serbia 471 | Seychelles Seychelles 472 | Seychellois Seychelles 473 | Seychellois/Seychelloises Seychelles 474 | Sierra 
Leone Sierra Leone 475 | Sierra Leonean Sierra Leone 476 | Singapore Singapore 477 | Singaporean Singapore 478 | Sint Eustatius Saint Kitts and Nevis 479 | Statian Saint Kitts and Nevis 480 | Sint Maarten Sint Maarten 481 | Sint Maartener Sint Maarten 482 | Slovakia Slovakia 483 | Slovak Slovakia 484 | Slovakian Slovakia 485 | Slovenia Slovenia 486 | Slovenian Slovenia 487 | Slovene Slovenia 488 | Solomon Islands Solomon Islands 489 | Solomon Island Solomon Islands 490 | Solomon Islander Solomon Islands 491 | Somalia Somalia 492 | Somali Somalia 493 | Somaliland Somalia 494 | Somalilander Somalia 495 | South Africa South Africa 496 | South Africa South Africa 497 | South Georgia and the South Sandwich Islands United Kingdom 498 | South Georgia Island United Kingdom 499 | South Sandwich Island United Kingdom 500 | South Ossetia South Ossetia 501 | South Ossetian South Ossetia 502 | South Sudan South Sudan 503 | South Sudanese South Sudan 504 | Spain Spain 505 | Spanish Spain 506 | Spaniard Spain 507 | Sri Lanka Sri Lanka 508 | Sri Lankan Sri Lanka 509 | Sudan Sudan 510 | Sudanese Sudan 511 | Suriname Suriname 512 | Surinamese Suriname 513 | Surinamer Suriname 514 | Svalbard Norway 515 | Svalbard resident Norway 516 | Swaziland Eswatini 517 | Swazi Eswatini 518 | Swati Eswatini 519 | Swazis Eswatini 520 | Sweden Sweden 521 | Swedish Sweden 522 | Switzerland Switzerland 523 | Swiss Switzerland 524 | Syria Syria 525 | Syrian Syria 526 | Tajikistan Tajikistan 527 | Tajikistani Tajikistan 528 | Tajiks Tajikistan 529 | Tanzania Tanzania 530 | Tanzanian Tanzania 531 | Thailand Thailand 532 | Thai Thailand 533 | Timor-Leste East Timor 534 | Timorese East Timor 535 | Togo Togo 536 | Togolese Togo 537 | Tokelau Tokelau 538 | Tokelauan Tokelau 539 | Tokelauans Tokelau 540 | Tonga Tonga 541 | Tongan Tonga 542 | Trinidad and Tobago Trinidad and Tobago 543 | Trinidadian Trinidad and Tobago 544 | Tobagonian Trinidad and Tobago 545 | Tunisia Tunisia 546 | Tunisian Tunisia 547 | Turkey Turkey 548 | Turkish Turkey 549 | Turkmenistan Turkmenistan 550 | Turkmen Turkmenistan 551 | Turks and Caicos Islands Turks and Caicos Islands 552 | Turks and Caicos Island Turks and Caicos Islands 553 | Turks and Caicos Islander Turks and Caicos Islands 554 | Tuvalu Tuvalu 555 | Tuvaluan Tuvalu 556 | Uganda Uganda 557 | Ugandan Uganda 558 | Ukraine Ukraine 559 | Ukrainian Ukraine 560 | United Arab Emirates United Arab Emirates 561 | Emirati United Arab Emirates 562 | Emirian United Arab Emirates 563 | Emiri United Arab Emirates 564 | Great Britain United Kingdom 565 | United Kingdom United Kingdom 566 | British United Kingdom 567 | UK United Kingdom 568 | U.K. United Kingdom 569 | U.K United Kingdom 570 | United States of America United States 571 | United States United States 572 | U.S. United States 573 | American United States 574 | U.S.A United States 575 | U.S.A. United States 576 | USA United States 577 | Uruguay Uruguay 578 | Uruguayan Uruguay 579 | Uzbekistan Uzbekistan 580 | Uzbekistani Uzbekistan 581 | Uzbek Uzbekistan 582 | Vanuatu Vanuatu 583 | Vanuatuan Vanuatu 584 | Ni-Vanuatu Vanuatu 585 | Vatican City Vatican City 586 | Vatican Vatican City 587 | Vatican citizen Vatican City 588 | Vatican City State Vatican City 589 | Venezuelan Venezuela 590 | Venezuela Venezuela 591 | Vietnam Vietnam 592 | Vietnamese Vietnam 593 | British Virgin Islands British Virgin Islands 594 | British Virgin Island British Virgin Islands 595 | British Virgin Islanders British Virgin Islands 596 | Virgin Islands U.S. 
Virgin Islands 597 | U.S. Virgin Island U.S. Virgin Islands 598 | U.S. Virgin Islanders U.S. Virgin Islands 599 | Wales United Kingdom 600 | Welsh United Kingdom 601 | Wallis and Futuna Wallis and Futuna 602 | Wallisian Wallis and Futuna 603 | Futunan Wallis and Futuna 604 | Western Sahara Western Sahara 605 | Sahrawi Western Sahara 606 | Sahrawian Western Sahara 607 | Sahraouian Western Sahara 608 | Yemen Yemen 609 | Yemeni Yemen 610 | Zambia Zambia 611 | Zambian Zambia 612 | Zimbabwe Zimbabwe 613 | Zimbabwean Zimbabwe 614 | -------------------------------------------------------------------------------- /voxceleb_data_prep/data/us_states.csv: -------------------------------------------------------------------------------- 1 | States 2 | Alabama 3 | Alaska 4 | Arizona 5 | Arkansas 6 | California 7 | Colorado 8 | Connecticut 9 | Delaware 10 | Florida 11 | Georgia 12 | Hawaii 13 | Idaho 14 | Illinois 15 | Indiana 16 | Iowa 17 | Kansas 18 | Kentucky 19 | Louisiana 20 | Maine 21 | Maryland 22 | Massachusetts 23 | Michigan 24 | Minnesota 25 | Mississippi 26 | Missouri 27 | Montana 28 | Nebraska 29 | Nevada 30 | New Hampshire 31 | New Jersey 32 | New Mexico 33 | New York 34 | North Carolina 35 | North Dakota 36 | Ohio 37 | Oklahoma 38 | Oregon 39 | Pennsylvania 40 | Rhode Island 41 | South Carolina 42 | South Dakota 43 | Tennessee 44 | Texas 45 | Utah 46 | Vermont 47 | Virginia 48 | Washington 49 | West Virginia 50 | Wisconsin 51 | Wyoming 52 | -------------------------------------------------------------------------------- /voxceleb_data_prep/scrape_nationalities.py: -------------------------------------------------------------------------------- 1 | from collections import Counter, OrderedDict 2 | from pprint import pprint 3 | 4 | import pandas as pd 5 | import spacy 6 | import wikipedia 7 | import wptools 8 | from tqdm import tqdm 9 | from wikipedia import DisambiguationError 10 | 11 | 12 | def wiki_infobox(text): 13 | try: 14 | page = wptools.page(text, silent=True).get_parse() 15 | infobox = page.data['infobox'] 16 | except: 17 | infobox = {} 18 | return infobox 19 | 20 | 21 | if __name__ == "__main__": 22 | # Meta information of vox1, same as original filename 23 | vox1 = pd.read_csv('./data/vox1_meta.csv', delimiter='\t') 24 | 25 | # Meta information of vox2, same as original filename 26 | vox2 = pd.read_csv('./data/vox2_meta.csv') 27 | 28 | # Meta information of vggface2, original filename meta/identity_meta.csv 29 | vgg2 = pd.read_csv('./data/vggface2_meta.csv', quotechar='"', skipinitialspace=True) 30 | 31 | us_states = set(pd.read_csv('./data/us_states.csv')['States'].str.lower().values) 32 | 33 | vgg_id_to_name = {k:v.strip() for k,v in zip(vgg2['Class_ID'].values, vgg2['Name'].values)} 34 | vox2_ids_dict = {k:v.strip() for k,v in zip(vox2['VoxCeleb2 ID '].values, vox2['VGGFace2 ID '])} 35 | vox2_id_to_name = {k:vgg_id_to_name[vox2_ids_dict[k]] for k in vox2_ids_dict} 36 | vox2_name_to_id = {k:v for v, k in vox2_id_to_name.items()} 37 | 38 | natcountry = pd.read_csv('./data/nationality_to_country.tsv', delimiter='\t') 39 | country_set = set(natcountry.Country.values) 40 | country_nat_dict = {k.lower():[] for k in country_set} 41 | 42 | for c, n in zip(natcountry.Country.values, natcountry.Nationality.values): 43 | country_nat_dict[c.lower()].append(n.lower()) 44 | 45 | nat_country_dict = {n.lower():c.lower() for c, n in zip(natcountry.Country.values, natcountry.Nationality.values)} 46 | 47 | # Reorder country nat_dict based on demographics of vox1 48 | # Most populous first 49 | # 
Then by length of country name 50 | # This is to make sure the country names which are substrings are not checked first 51 | common_nats = vox1.Nationality.str.lower().value_counts().keys() 52 | 53 | common_nats_keys = [] 54 | for c in common_nats: 55 | common_nats_keys.append(nat_country_dict[c]) 56 | 57 | ordered_country_nat_dict = OrderedDict({}) 58 | for c in common_nats_keys: 59 | ordered_country_nat_dict[c] = country_nat_dict[c] 60 | 61 | countries_by_len = sorted(country_nat_dict.keys(), key=len, reverse=True) 62 | 63 | for c in countries_by_len: 64 | if c not in ordered_country_nat_dict: 65 | ordered_country_nat_dict[c] = country_nat_dict[c] 66 | 67 | vox2_names = list(vox2_id_to_name.values()) 68 | vox2_nationalities = OrderedDict({k:[] for k in vox2_names}) 69 | 70 | nlp = spacy.load("en_core_web_sm") 71 | 72 | for i, name in enumerate(tqdm(vox2_nationalities)): 73 | if vox2_nationalities[name]: 74 | continue 75 | 76 | # Get the wikipedia page and summary text 77 | qname = ' '.join(name.split('_')) 78 | try: 79 | text = wikipedia.summary(qname) 80 | except: 81 | search = wikipedia.search(qname, results=3) 82 | if len(search) == 0: 83 | print(name) 84 | continue 85 | else: 86 | index = 0 87 | while True: 88 | if index >= len(search): 89 | qname = 'nan' 90 | text = '' 91 | break 92 | try: 93 | qname = search[index] 94 | text = wikipedia.summary(qname, auto_suggest=False) 95 | break 96 | except DisambiguationError: 97 | index += 1 98 | 99 | if qname == 'nan': 100 | #Couldn't find a good wikipedia page 101 | print(name) 102 | continue 103 | 104 | #Try the infobox first 105 | try: 106 | person_infobox = wiki_infobox(qname) 107 | if 'birth_place' in person_infobox: 108 | place = person_infobox['birth_place'].replace('[', '').replace(']', '').lower() 109 | 110 | for s in us_states: 111 | if s in place: 112 | vox2_nationalities[name] = ['united states'] 113 | break 114 | 115 | if vox2_nationalities[name]: 116 | continue 117 | 118 | for c in ordered_country_nat_dict: 119 | if c in place: 120 | vox2_nationalities[name] = [c] 121 | break 122 | 123 | if vox2_nationalities[name]: 124 | continue 125 | 126 | place_infobox = wiki_infobox(place) 127 | if 'subdivision_name' in place_infobox: 128 | subd_country = place_infobox['subdivision_name'].lower() 129 | if subd_country in country_nat_dict: 130 | vox2_nationalities[name] = [subd_country] 131 | continue 132 | except: 133 | pass 134 | 135 | 136 | #Otherwise try the summary text 137 | doc = nlp(text) 138 | all_stopwords = nlp.Defaults.stop_words 139 | 140 | nat_candidates = [] 141 | for j, tok in enumerate(doc): 142 | if tok.text.lower() in nat_country_dict: 143 | if doc[j-1].text.lower() == 'new': 144 | continue 145 | if tok.text.lower() == 'american' and doc[j+1].text.lower() == 'football': 146 | continue 147 | nat_candidates.append(nat_country_dict[tok.text.lower()]) 148 | if doc[j+1].text.lower() == '-': 149 | if doc[j+2].text.lower() == 'born': 150 | c = doc[j+3].text.lower() 151 | if c in nat_country_dict: 152 | nat_candidates = [nat_country_dict[c]] 153 | break 154 | vox2_nationalities[name] = nat_candidates 155 | 156 | vox2_nats_final = OrderedDict({}) 157 | for name in vox2_nationalities: 158 | # Take the most frequent nat in the list 159 | # In case of ties, takes the first occuring one 160 | if vox2_nationalities[name]: 161 | best_nat = Counter(vox2_nationalities[name]).most_common(1)[0][0] 162 | vox2_nats_final[name] = best_nat 163 | else: 164 | vox2_nats_final[name] = 'unk' 165 | 166 | 167 | with open('./spk2nat', 'w') as fp: 168 | 
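        # Output format, one line per VoxCeleb2 speaker: '<vox2_id> <country>', with spaces in the
        # country name replaced by underscores; speakers whose nationality could not be resolved are written as 'unk'.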
for name in vox2_nats_final: 169 | spk_id = vox2_name_to_id[name] 170 | nat = '_'.join(vox2_nats_final[name].split()) 171 | line = '{} {}\n'.format(spk_id, nat) 172 | fp.write(line) 173 | 174 | --------------------------------------------------------------------------------