├── .gitignore
├── LICENSE
├── README.md
├── configs
│   └── example.cfg
├── data_io.py
├── diarize.py
├── figures
│   └── multitask.png
├── inference.py
├── models
│   ├── classifiers.py
│   ├── criteria.py
│   ├── extractors.py
│   └── sim_predictors.py
├── scotus_data_prep
│   ├── filter_scp.pl
│   ├── local_info.py
│   ├── step1_downloadmp3.py
│   ├── step2_scrape_dob.py
│   ├── step3_prepdata.py
│   ├── step4_extract_feats.sh
│   └── step5_trim_split_data.py
├── train.py
├── utils.py
└── voxceleb_data_prep
    ├── data
    │   ├── nationality_to_country.tsv
    │   └── us_states.csv
    └── scrape_nationalities.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
140 |
141 | .DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Chau
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi Task Learning Speaker Embeddings
2 | Code for the paper: ["Leveraging speaker attribute information using multi task learning for speaker verification and diarization"](https://arxiv.org/abs/2010.14269)
3 |
4 | The overall concept of this paper is that training speaker embedding extractors on auxiliary attributes (such as age or nationality) alongside speaker classification can lead to increased performance for verification and diarization. Training the embeddings in this multi-task fashion can improve the descriptiveness of the embedding space.
5 |
6 | This is implemented by having multiple task-specific "heads" acting on the embeddings, such as x-vectors. Alongside speaker classification, one might employ age classification/regression, or nationality classification. A general diagram for this can be seen below:
7 |
8 | ![Multi-task learning with task-specific heads acting on a shared speaker embedding](figures/multitask.png)
9 | 
10 | 
11 | 
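Conceptually, the training objective is a weighted sum of per-head losses computed on the shared embedding. Below is a minimal sketch of that idea (not the repo's actual training loop; the head modules, label dicts and weights are illustrative, mirroring the `classifier_heads` and `classifier_loss_weights` config fields described later):

```python
import torch
import torch.nn.functional as F

def multitask_loss(embeds, heads, labels, loss_weights):
    """embeds: (batch, embed_dim) output of the shared extractor.
    heads: dict mapping task name -> classification head (nn.Module).
    labels: dict mapping task name -> (batch,) LongTensor of targets.
    loss_weights: dict mapping task name -> float weight."""
    total = embeds.new_zeros(())
    for task, head in heads.items():
        logits = head(embeds)  # task-specific head applied to the shared embedding
        total = total + loss_weights[task] * F.cross_entropy(logits, labels[task])
    return total
```
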
12 | Along with experimental code, this repository covers other data preparation tasks used in the paper, such as webscraping age information for lawyers in SCOTUS, and nationality for Wikipedia celebrities.
13 |
14 | # Contents
15 |
16 | - Requirements
17 | - SCOTUS Data Preparation
18 | - VoxCeleb Data Preparation
19 | - Experiments
20 |
21 | # Requirements
22 |
23 | This was tested on Python 3.6.8.
24 |
25 | - General requirements:
26 | - numpy
27 | - tqdm
28 | - sklearn
29 | - scipy
30 | - Web-scraping:
31 | - beautifulsoup4
32 | - wikipedia
33 | - wptools
34 | - Levenshtein
35 | - python-dateutil
36 | - Experiments:
37 | - Kaldi
38 | - torch
39 | - kaldi_io
40 |
41 | You will also need access to the avvo.com API, along with a custom Google search API key; instructions on how to obtain and set these up are in the SCOTUS Data Preparation section.
42 |
43 | # SCOTUS Data Preparation
44 |
45 | ## Obtaining case information
46 |
47 | This work utilizes the information available via the [Oyez project](https://www.oyez.org/about-audio), which provides audio files and transcriptions for US Supreme Court oral arguments.
48 |
49 | The Oyez project has an API, an example of which can be seen here: https://api.oyez.org/cases?per_page=100
50 |
51 | Thankfully, another GitHub repo exists which has already scraped the information provided by the API for each case and auto-updates every week: https://github.com/walkerdb/supreme_court_transcripts. This ~3.5GB repo will be used to extract the cases that we want audio and transcripts for.
52 |
53 | The following command will clone this repo (specifically my fork, which doesn't update with new cases, to ensure consistency with the number of cases that I have tested my code with):
54 |
55 | ```sh
56 | git clone https://github.com/cvqluu/supreme_court_transcripts
57 | ```
58 |
59 | The location of this repository containing the case information will be referred to as `$SCOTUS_CASE_REPO` throughout the rest of the SCOTUS section. Of course, if you want to obtain the most up-to-date recordings, you should clone the original repo by walkerdb.
60 |
61 | ## Downloading Audio and Filtering Cases
62 |
63 | Now that `$SCOTUS_CASE_REPO` exists and has been cloned from GitHub, we can make use of the MTL-Speaker-Embeddings repo:
64 |
65 | ```sh
66 | git clone https://github.com/cvqluu/MTL-Speaker-Embeddings
67 | cd MTL-Speaker-Embeddings/scotus_data_prep
68 | ```
69 |
70 | which puts us inside the `scotus_data_prep` folder, where all of the SCOTUS data prep scripts live.
71 |
72 | Next, we run the first step, which downloads mp3 files and filters the cases we can use. It takes as input the location of the case JSONs in `$SCOTUS_CASE_REPO` and produces a data folder (at a location of your choosing), which we will call `$BASE_OUTFOLDER`; this will store all the SCOTUS data used later on.
73 |
74 | ```sh
75 | python step1_downloadmp3.py --case-folder $SCOTUS_CASE_REPO/oyez/cases --base-outfolder $BASE_OUTFOLDER
76 | ```
77 |
78 | This parses the cases in `$SCOTUS_CASE_REPO` and eliminates cases argued before October 2005, which is when the Supreme Court started recording digitally instead of using a reel-to-reel taping system. The taped recordings were excluded because of the problems and defects detailed here: https://www.oyez.org/about-audio. The script also eliminates cases where audio can't be found, as well as ones where the transcription has many invalid speaker turns (turns with zero or negative duration).
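
As a rough sketch of the turn-validity check described above (the field layout and rejection threshold here are illustrative assumptions, not necessarily what `step1_downloadmp3.py` uses):

```python
def fraction_invalid_turns(turns):
    """turns: list of (start, stop) times in seconds for one transcript."""
    if not turns:
        return 1.0
    bad = sum(1 for start, stop in turns if stop - start <= 0)
    return bad / len(turns)

# Hypothetical usage: reject a case if too many of its turns have non-positive duration.
keep_case = fraction_invalid_turns([(0.0, 12.3), (12.3, 40.1)]) < 0.1
```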
79 |
80 | The outcome should be a file structure in `$BASE_OUTFOLDER` something like so:
81 |
82 | ```
83 | $BASE_OUTFOLDER
84 | ├── speaker_ids.json
85 | ├── audio
86 | | ├── 2007.07-330.mp3
87 | | └── ...
88 | └── transcripts
89 | ├── 2007.07-330.json
90 | └── ...
91 | ```
92 |
93 | Sometimes downloading an audio file may fail but succeed on a subsequent attempt. This script is safe to run multiple times, and will try to re-download missing files (while skipping ones that have already been processed).
94 |
95 | If you used my fork for `$SCOTUS_CASE_REPO`, you should end up with 2035 mp3s/JSONs in the audio and transcripts folders respectively, although this may vary depending on the availability of each mp3.
96 |
97 | ## Web scraping
98 |
99 | Next, we want to scrape the approximate date of birth (DoB) of each speaker found in the downloaded recordings; the list of speakers is stored in `$BASE_OUTFOLDER/speaker_ids.json`.
100 |
101 | First of all, we need to set up and obtain a few things in order to fill out `local_info.py` inside of `scotus_data_prep`.
102 |
103 | ### Avvo API
104 |
105 | Instructions on how to obtain access to the Avvo API are here: http://avvo.github.io/api-doc/
106 |
107 | Once this is set up, you can fill out `local_info.py` with your own `AVVO_API_ACCESS_TOKEN`. (Note: `AVVO_CSE_ID` is covered in the next section)
108 |
109 | ### Custom Google Search
110 |
111 | A custom google search will be needed, which can be set up here:
112 |
113 | https://cse.google.com/cse/all
114 |
115 | Click `Add` to create a new search engine and, under `Sites to search`, enter `avvo.com`, with the Language specified as English.
116 |
117 | After clicking Create, the custom search engine should be created successfully. From its control panel you can find the `Search Engine ID`, which you can fill into `AVVO_CSE_ID` in `local_info.py`.
118 |
119 | You will also need a google API key, which can be set up here: https://developers.google.com/maps/documentation/javascript/get-api-key
120 |
121 | Once this is obtained, fill out `GOOGLE_API_KEY` in `local_info.py`.
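
Once all three values have been obtained, `local_info.py` simply needs to hold them as module-level variables, along these lines (placeholder values shown):

```python
# local_info.py -- placeholder values; substitute your own credentials
AVVO_API_ACCESS_TOKEN = 'your-avvo-api-token'
AVVO_CSE_ID = 'your-custom-search-engine-id'
GOOGLE_API_KEY = 'your-google-api-key'
```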
122 |
123 | ### Running the webscraper
124 |
125 | ```sh
126 | python step2_scrape_dob.py --base_outfolder $BASE_OUTFOLDER
127 | ```
128 |
129 | This will go through the names in `speaker_ids.json` and try to scrape their dates of birth (DoB) from the following sources:
130 |
131 | - [Wikipedia](https://en.wikipedia.org/)
132 | - [Supreme court clerkship graduation dates](https://en.wikipedia.org/wiki/List_of_law_clerks_of_the_Supreme_Court_of_the_United_States_(Chief_Justice))
133 | - JUSTIA.com
134 | - Avvo.com
135 |
136 | For the final three sources, 25 is subtracted from the discovered year of graduation/admission to practice law to obtain an approximate date of birth (a small illustration follows the folder listing below). The dates of birth, along with additional information about the source of each DoB, are stored in pickled dictionaries in `$BASE_OUTFOLDER`:
137 |
138 | ```
139 | $BASE_OUTFOLDER
140 | ├── speaker_ids.json
141 | ├── dobs.p
142 | ├── dobs_info.p
143 | ├── audio
144 | | └── ...
145 | └── transcripts
146 | └── ...
147 | ```
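
As a small illustration of the minus-25 heuristic mentioned above (the exact day/month convention used by `step2_scrape_dob.py` may differ; this is only a sketch):

```python
from datetime import date

def approx_dob_from_year(grad_or_admission_year, offset_years=25):
    # Approximate the DoB by assuming graduation/bar admission at ~25 years old.
    return date(grad_or_admission_year - offset_years, 1, 1)

print(approx_dob_from_year(1995))  # -> 1970-01-01
```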
148 |
149 | Like `step1_downloadmp3.py`, this script is also safe to re-run, and will try to re-scrape names for which no DoB has been found. If you wish to skip the names which have already been attempted and are present in the `dobs.p` dictionary, the step2 script can be run with the `--skip-attempted` flag.
150 |
151 | ## Prepping data for feature extraction
152 |
153 | ```sh
154 | python step3_prepdata.py --base_outfolder $BASE_OUTFOLDER
155 | ```
156 |
157 | This prepares verification and diarization data folders that are ready for Kaldi to extract features from. As part of this, the recordings are renamed to the consistent format `YEAR-XYZ`, and the mapping from original names to new names is stored in a JSON file.
158 |
159 | Verification/training data is made by splitting up long utterances into non-overlapping 10s segments, with minimum length 4s. Utterances are named in a consistent fashion, and the age (in days) of each speaker at the time of each utterance is calculated, and placed into an `utt2age` file.
160 |
161 | Diarization data is made by splitting into 1.5s segments with 0.75s shift. A `ref.rttm` file is also created.
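
A minimal sketch of the two windowing schemes described above, working in seconds (the exact edge handling in `step3_prepdata.py` may differ slightly):

```python
def fixed_segments(start, end, seg_len=10.0, min_len=4.0):
    """Non-overlapping segments for the verification/training data."""
    segs, t = [], start
    while t < end:
        seg_end = min(t + seg_len, end)
        if seg_end - t >= min_len:
            segs.append((t, seg_end))
        t += seg_len
    return segs

def sliding_segments(start, end, seg_len=1.5, shift=0.75):
    """Overlapping 1.5 s windows with a 0.75 s shift for the diarization data."""
    segs, t = [], start
    while t + seg_len <= end:
        segs.append((t, t + seg_len))
        t += shift
    return segs
```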
162 |
163 | This should result in the following file structure (in addition to what was previously shown):
164 |
165 | ```
166 | $BASE_OUTFOLDER
167 | ├── ...
168 | ├── orec_recid_mapping.json
169 | ├── recid_orec_mapping.json
170 | ├── veri_data
171 | | ├── utt2spk
172 | | ├── spk2utt
173 | | ├── utt2age
174 | | ├── wav.scp
175 | | ├── segments
176 | | ├── real_utt2spk
177 | | └── real_spk2utt
178 | └── diar_data
179 | ├── utt2spk
180 | ├── spk2utt
181 | ├── ref.rttm
182 | ├── wav.scp
183 | ├── segments
184 | ├── real_utt2spk
185 | └── real_spk2utt
186 |
187 | ```
188 |
189 | The `real_{utt2spk|spk2utt}` files are there because Kaldi feature extraction insists on speaker IDs being the prefix of an utterance name; these will be sorted out later on.
190 |
191 | ## Feature extraction
192 |
193 | We will use Kaldi to extract the features. The Kaldi feature extraction script is `step4_extract_feats.sh`.
194 |
195 | You will need to edit the top of this script and fill in your own `$BASE_OUTFOLDER`:
196 |
197 | ```sh
198 | #!/bin/bash
199 |
200 | step=0
201 | nj=20
202 | base_outfolder=/PATH/TO/BASE_OUTFOLDER <--- Edit here
203 |
204 | ...
205 | ```
206 |
207 | We ran this script from the `egs/voxceleb/v2` recipe folder (although the only thing specific to this recipe is `conf/mfcc.conf`).
208 |
209 | To carry this out, we will assume you have Kaldi installed at `$KALDI_ROOT`
210 |
211 | ```sh
212 | cp step4_extract_feats.sh $KALDI_ROOT/egs/voxceleb/v2/
213 | cd $KALDI_ROOT/egs/voxceleb/v2
214 | source path.sh
215 | bash step4_extract_feats.sh
216 | ```
217 |
218 |
219 | ## Making train/test splits
220 |
221 | After changing your working directory back to `scotus_data_prep`, we can run the final stage script.
222 |
223 | ```sh
224 | python step5_trim_split_data.py --base_outfolder $BASE_OUTFOLDER --train-proportion 0.8 --pos-per-spk 12
225 | ```
226 |
227 | This trims down the data folders to match the features that have been extracted, and splits the recordings into train and test according to `--train-proportion`. It also builds a verification task for test-set utterances, excluding speakers seen in the training set, by generating trials of pairs of utterances to compare.
228 |
229 | The `--pos-per-spk` option determines how many positive trials there are per test-set speaker. If too high a value is selected, the script may error out because it cannot select enough utterances; if that occurs, try lowering this value.
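
For intuition, trial generation amounts to something like the sketch below (utterance IDs and sampling strategy are illustrative; the actual script avoids reusing utterances, which is why it can run out of them when `--pos-per-spk` is too high):

```python
import random

def make_trials(spk2utts, pos_per_spk=12):
    """spk2utts: dict mapping each test speaker to a list of its utterance IDs."""
    trials = []  # (utt_a, utt_b, label) with label 1 = same speaker, 0 = different
    speakers = list(spk2utts)
    for spk, utts in spk2utts.items():
        for _ in range(pos_per_spk):
            a, b = random.sample(utts, 2)              # positive pair (needs >= 2 utts)
            trials.append((a, b, 1))
            other = random.choice([s for s in speakers if s != spk])
            trials.append((a, random.choice(spk2utts[other]), 0))  # matched negative
    return trials
```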
230 |
231 | This should yield the following (once again in addition to what was produced before):
232 |
233 | ```
234 | $BASE_OUTFOLDER
235 | ├── ...
236 | ├── veri_data_nosil
237 | | ├── train
238 | | | ├── utt2spk
239 | | | ├── spk2utt
240 | | | ├── utt2age
241 | | | ├── feats.scp
242 | | | └── utts
243 | | ├── test
244 | | | ├── veri_pairs
245 | | | └── ...
246 | | └── ...
247 | └── diar_data_nosil
248 | ├── train
249 | | ├── ref.rttm
250 | | └── ...
251 | ├── test
252 | | ├── ref.rttm
253 | | └── ...
254 | └── ...
255 | ```
256 |
257 | # VoxCeleb Data Preparation
258 |
259 | TODO
260 |
261 |
262 | # Experiments
263 |
264 | Training an embedding extractor is done via `train.py`:
265 |
266 | ```sh
267 | python train.py --cfg configs/example.cfg
268 | ```
269 |
270 | This runs according to the config file `configs/example.cfg`, which is detailed below. This example config trains an x-vector architecture embedding extractor with a speaker classification head and an age classification head with 10 classes, using a 0.1 weighting on the age loss.
271 |
272 |
273 | ## Config File
274 |
275 | ```ini
276 | [Datasets]
277 | # Path to the datasets
278 | train = /PATH_TO_BASE_OUTFOLDER/veri_data_nosil/train
279 | test = /PATH_TO_BASE_OUTFOLDER/veri_data_nosil/test
280 |
281 | [Model]
282 | # Allowed model_type : ['XTDNN', 'ETDNN']
283 | model_type = XTDNN
284 | # Allowed classifier_heads types:
285 | # ['speaker', 'nationality', 'gender', 'age', 'age_regression', 'rec']
286 | classifier_heads = speaker,age
287 |
288 | [Optim]
289 | # Allowed classifier_types:
290 | # ['xvec', 'adm', 'adacos', 'l2softmax', 'xvec_regression', 'arcface', 'sphereface']
291 | classifier_types = xvec,xvec
292 | classifier_lr_mults = [1.0, 1.0]
293 | classifier_loss_weights = [1.0, 0.1]
294 | # Allowed smooth_types:
295 | # ['twoneighbour', 'uniform']
296 | classifier_smooth_types = none,none
297 |
298 | [Hyperparams]
299 | input_dim = 30
300 | lr = 0.2
301 | batch_size = 500
302 | max_seq_len = 350
303 | no_cuda = False
304 | seed = 1234
305 | num_iterations = 50000
306 | momentum = 0.5
307 | scheduler_steps = [40000]
308 | scheduler_lambda = 0.5
309 | multi_gpu = False
310 | classifier_lr_mult = 1.
311 | embedding_dim = 256
312 |
313 | [Outputs]
314 | model_dir = exp/example_exp
315 | # At each checkpoint, the model will be evaluated on the test data, and the model will be saved
316 | checkpoint_interval = 500
317 |
318 | [Misc]
319 | num_age_bins = 10
320 | ```
321 |
322 | Most of the parameters in this configuration file are fairly self-explanatory.
323 |
324 | The most important setting is the `classifier_heads` field, which determines which tasks are applied to the embedding. It also determines the number of entries required in each field of `[Optim]`, which must match the number of classifier heads.
325 |
326 |
327 | ## Supported Auxiliary Tasks
328 |
329 | The currently supported tasks are `'speaker', 'nationality', 'gender', 'age', 'age_regression', 'rec'`. Each auxiliary task has a data requirement that must be present in both the train and test folders in order for it to be evaluated:
330 | 
331 | - `age` and `age_regression`
332 |   - Age classification or age regression
333 |   - Requires: `utt2age` file
334 |     - 2 column file of format:
335 |     - `<utterance-id> <age>`
336 | - `gender`
337 |   - Gender classification
338 |   - Requires: `spk2gender` file
339 |     - 2 column file of format:
340 |     - `<speaker-id> <gender>`
341 | - `nationality`
342 |   - Nationality classification
343 |   - Requires: `spk2nat` file
344 |     - 2 column file of format:
345 |     - `<speaker-id> <nationality>`
346 | - `rec`
347 |   - Recording ID classification
348 |   - This is typically accompanied by a negative loss weight; when a negative loss weight is given, a Gradient Reversal Layer (GRL) is automatically applied to that classification head (see the example config snippet after this list).
349 |   - Requires: `utt2rec` file (not needed in the test folder)
350 |     - 2 column file of format:
351 |     - `<utterance-id> <recording-id>`
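
For example, a three-head setup combining speaker and age classification with an adversarial recording-ID head might look like this in the config (illustrative values; note that every `[Optim]` list has one entry per head, and the negative weight on `rec` is what triggers the GRL):

```ini
[Model]
model_type = XTDNN
classifier_heads = speaker,age,rec

[Optim]
classifier_types = xvec,xvec,xvec
classifier_lr_mults = [1.0, 1.0, 1.0]
classifier_loss_weights = [1.0, 0.1, -0.2]
classifier_smooth_types = none,none,none
```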
352 |
353 |
354 | ## Resuming training
355 |
356 | The following command will resume an experiment from the checkpoint at 25000 iterations:
357 |
358 | ```sh
359 | python train.py --cfg configs/example.cfg --resume-checkpoint 25000
360 | ```
361 |
--------------------------------------------------------------------------------
/configs/example.cfg:
--------------------------------------------------------------------------------
1 | [Datasets]
2 | train = /PATH_TO_BASE_OUTFOLDER/veri_data_nosil/train
3 | test = /PATH_TO_BASE_OUTFOLDER/veri_data_nosil/test
4 |
5 | [Model]
6 | #allowed model_type : ['XTDNN', 'ETDNN']
7 | model_type = XTDNN
8 | classifier_heads = speaker,age
9 |
10 | [Optim]
11 | classifier_types = xvec,xvec
12 | classifier_lr_mults = [1.0, 1.0]
13 | classifier_loss_weights = [1.0, 0.1]
14 | classifier_smooth_types = none,none
15 |
16 | [Hyperparams]
17 | input_dim = 30
18 | lr = 0.2
19 | batch_size = 500
20 | max_seq_len = 350
21 | no_cuda = False
22 | seed = 1234
23 | num_iterations = 50000
24 | momentum = 0.5
25 | scheduler_steps = [40000]
26 | scheduler_lambda = 0.5
27 | multi_gpu = False
28 | classifier_lr_mult = 1.
29 | embedding_dim = 256
30 |
31 | [Outputs]
32 | model_dir = exp/example_exp
33 | checkpoint_interval = 500
34 |
35 | [Misc]
36 | num_age_bins = 10
--------------------------------------------------------------------------------
/diarize.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import concurrent.futures
3 | import configparser
4 | import glob
5 | import json
6 | import os
7 | import pickle
8 | import random
9 | import re
10 | import shutil
11 | import subprocess
12 | import time
13 | from collections import OrderedDict
14 | from pprint import pprint
15 |
16 | import numpy as np
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | import uvloop
21 | from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering
22 | from sklearn.metrics import pairwise_distances
23 | from sklearn.preprocessing import normalize
24 | from torch.utils.data import DataLoader
25 | from tqdm import tqdm
26 |
27 | from data_io import DiarizationDataset
28 | from models.extractors import ETDNN, FTDNN, XTDNN
29 | from train import parse_config
30 |
31 |
32 | def parse_args():
33 | parser = argparse.ArgumentParser(description='Diarize per recording')
34 | parser.add_argument('--cfg', type=str,
35 | default='./configs/example_speaker.cfg')
36 | parser.add_argument('--checkpoint', type=int,
37 | default=0, help='Choose a specific iteration to evaluate on instead of the best eer')
38 | parser.add_argument('--diar-data', type=str,
39 | default='/disk/scratch2/s1786813/repos/supreme_court_transcripts/oyez/scotus_diarization_nosil/test')
40 | args = parser.parse_args()
41 | assert os.path.isfile(args.cfg)
42 | args._start_time = time.ctime()
43 | return args
44 |
45 |
46 | def lines_to_file(lines, filename, wmode="w+"):
47 | with open(filename, wmode) as fp:
48 | for line in lines:
49 | fp.write(line)
50 |
51 |
52 | def get_eer_metrics(folder):
53 | rpkl_path = os.path.join(folder, 'results.p')
54 | rpkl = pickle.load(open(rpkl_path, 'rb'))
55 | iterations = list(rpkl.keys())
56 | eers = [rpkl[k]['test_eer'] for k in rpkl]
57 | return iterations, eers, np.min(eers)
58 |
59 |
60 | def setup():
61 | if args.model_type == 'XTDNN':
62 | generator = XTDNN(features_per_frame=args.input_dim,
63 | embed_features=args.embedding_dim)
64 | if args.model_type == 'ETDNN':
65 | generator = ETDNN(features_per_frame=args.input_dim,
66 | embed_features=args.embedding_dim)
67 | if args.model_type == 'FTDNN':
68 | generator = FTDNN(in_dim=args.input_dim,
69 | embedding_dim=args.embedding_dim)
70 |
71 | generator.eval()
72 | generator = generator
73 | return generator
74 |
75 |
76 | def agg_clustering_oracle(S, num_clusters):
77 | ahc = AgglomerativeClustering(
78 | n_clusters=num_clusters, affinity='precomputed', linkage='average', compute_full_tree=True)
79 | return ahc.fit_predict(S)
80 |
81 |
82 | def score_der(hyp=None, ref=None, outfile=None, collar=0.25):
83 | '''
84 | Takes in hypothesis rttm and reference rttm and returns the diarization error rate
85 | Calls md-eval.pl -> writes output to file -> greps for DER value
86 | '''
87 | assert os.path.isfile(hyp)
88 | assert os.path.isfile(ref)
89 | assert outfile
90 | cmd = 'perl md-eval.pl -1 -c {} -s {} -r {} > {}'.format(
91 | collar, hyp, ref, outfile)
92 | subprocess.call(cmd, shell=True, stdout=subprocess.DEVNULL,
93 | stderr=subprocess.DEVNULL)
94 | assert os.path.isfile(outfile)
95 | with open(outfile, 'r') as file:
96 | data = file.read().replace('\n', '')
97 | der_str = re.search(
98 | r'DIARIZATION\ ERROR\ =\ [0-9]+([.][0-9]+)?', data).group()
99 | der = float(der_str.split()[-1])
100 | return der
101 |
102 |
103 | def score_der_uem(hyp=None, ref=None, outfile=None, uem=None, collar=0.25):
104 | '''
105 | takes in hypothesis rttm and reference rttm and returns the diarization error rate
106 | calls md-eval.pl -> writes output to file -> greps for der value
107 | '''
108 | assert os.path.isfile(hyp)
109 | assert os.path.isfile(ref)
110 | assert os.path.isfile(uem)
111 | assert outfile
112 | cmd = 'perl md-eval.pl -1 -c {} -s {} -r {} -u {} > {}'.format(
113 | collar, hyp, ref, uem, outfile)
114 | subprocess.call(cmd, shell=True, stdout=subprocess.DEVNULL,
115 | stderr=subprocess.DEVNULL)
116 | assert os.path.isfile(outfile)
117 | with open(outfile, 'r') as file:
118 | data = file.read().replace('\n', '')
119 | der_str = re.search(
120 | r'DIARIZATION\ ERROR\ =\ [0-9]+([.][0-9]+)?', data).group()
121 | der = float(der_str.split()[-1])
122 | return der
123 |
124 | def make_rttm_lines(segcols, cluster_labels):
125 |     # Make the rttm from segments and cluster labels, resolve overlaps etc
126 | assert len(segcols[0]) == len(cluster_labels)
127 | assert len(set(segcols[1])) == 1, 'Must be from a single recording'
128 | rec_id = list(set(segcols[1]))[0]
129 |
130 | starts = segcols[2].astype(float)
131 | ends = segcols[3].astype(float)
132 |
133 | events = [{'start': starts[0], 'end': ends[0], 'label': cluster_labels[0]}]
134 |
135 | for t0, t1, lab in zip(starts, ends, cluster_labels):
136 | # TODO: Warning this only considers overlap with a single adjacent neighbour
137 | if t0 <= events[-1]['end']:
138 | if lab == events[-1]['label']:
139 | events[-1]['end'] = t1
140 | continue
141 | else:
142 | overlap = events[-1]['end'] - t0
143 | events[-1]['end'] -= overlap/2
144 | newevent = {'start': t0+overlap/2, 'end': t1, 'label': lab}
145 | events.append(newevent)
146 | else:
147 | newevent = {'start': t0, 'end': t1, 'label': lab}
148 | events.append(newevent)
149 |
150 | line_str = 'SPEAKER {} 0 {:.3f} {:.3f} {} \n'
151 | lines = []
152 | for ev in events:
153 | offset = ev['end'] - ev['start']
154 | if offset < 0.0:
155 | continue
156 | lines.append(line_str.format(rec_id, ev['start'], offset, ev['label']))
157 | return lines
158 |
159 |
160 | def extract_and_diarize(generator, test_data_dir, device):
161 | all_hyp_rttms = []
162 | all_ref_rttms = []
163 |
164 | # if args.checkpoint == 0:
165 | # results_pkl = os.path.join(args.model_dir, 'diar_results.p')
166 | # else:
167 | results_pkl = os.path.join(args.model_dir, 'diar_results_{}.p'.format(args.checkpoint))
168 | rttm_dir = os.path.join(args.model_dir, 'hyp_rttms')
169 | os.makedirs(rttm_dir, exist_ok=True)
170 |
171 | if os.path.isfile(results_pkl):
172 | rpkl = pickle.load(open(results_pkl, 'rb'))
173 | if rpkl['test_data'] != TEST_DATA_PATH:
174 | moved_rpkl = os.path.join(
175 | args.model_dir, rpkl['test_data'].replace(os.sep, '-') + '.p')
176 | shutil.copy(results_pkl, moved_rpkl)
177 | rpkl = OrderedDict({'test_data': TEST_DATA_PATH})
178 | else:
179 | rpkl = OrderedDict({'test_data': TEST_DATA_PATH})
180 |
181 | if 'full_der' in rpkl:
182 | print('({}) Full test DER: {}'.format(args.cfg, rpkl['full_der']))
183 | if 'full_der_uem' not in rpkl:
184 | uem_file = os.path.join(TEST_DATA_PATH, 'uem')
185 | der_uem = score_der_uem(hyp=os.path.join(args.model_dir, 'final_{}_hyp.rttm'.format(args.checkpoint)),
186 | ref=os.path.join(args.model_dir, 'final_{}_ref.rttm'.format(args.checkpoint)),
187 | outfile=os.path.join(args.model_dir, 'final_hyp_uem.derlog'),
188 | uem=uem_file,
189 | collar=0.25)
190 | rpkl['full_der_uem'] = der_uem
191 | pickle.dump(rpkl, open(results_pkl, 'wb'))
192 | print('({}) Full test DER (uem): {}'.format(args.cfg, rpkl['full_der_uem']))
193 |
194 | else:
195 | ds_test = DiarizationDataset(test_data_dir)
196 | recs = ds_test.recs
197 | generator = setup()
198 | generator.eval().to(device)
199 |
200 | if args.checkpoint == 0:
201 | generator_its, eers, _ = get_eer_metrics(args.model_dir)
202 | g_path = os.path.join(args.model_dir, 'generator_{}.pt'.format(
203 | generator_its[np.argmin(eers)]))
204 | else:
205 | g_path = os.path.join(args.model_dir, 'generator_{}.pt'.format(args.checkpoint))
206 | assert os.path.isfile(g_path), "Couldn't find {}".format(g_path)
207 |
208 | generator.load_state_dict(torch.load(g_path))
209 |
210 | with torch.no_grad():
211 | for i, r in tqdm(enumerate(recs), total=len(recs)):
212 | ref_rec_rttm = os.path.join(rttm_dir, '{}_{}_ref.rttm'.format(r, args.checkpoint))
213 | hyp_rec_rttm = os.path.join(rttm_dir, '{}_{}_hyp.rttm'.format(r, args.checkpoint))
214 | if r in rpkl and (os.path.isfile(ref_rec_rttm) and os.path.isfile(hyp_rec_rttm)):
215 | all_ref_rttms.append(ref_rec_rttm)
216 | all_hyp_rttms.append(hyp_rec_rttm)
217 | continue
218 |
219 | feats, spkrs, ref_rttm_lines, segcols, rec = ds_test.__getitem__(i)
220 | num_spkrs = len(set(spkrs))
221 | assert r == rec
222 |
223 |                 # Extract embeds (segments of <= 15 frames reuse the previous embedding)
224 |                 embeds = []
225 |                 for feat in feats:
226 |                     if len(feat) <= 15 and embeds:
227 |                         embeds.append(embeds[-1])
228 |                     else:
229 |                         feat = feat.unsqueeze(0).to(device)
230 |                         embed = generator(feat)
231 |                         embeds.append(embed.cpu().numpy())
232 | embeds = np.vstack(embeds)
233 | embeds = normalize(embeds, axis=1)
234 |
235 |                 # Compute pairwise cosine distance matrix (used as precomputed affinity for AHC)
236 | sim_matrix = pairwise_distances(embeds, metric='cosine')
237 | cluster_labels = agg_clustering_oracle(sim_matrix, num_spkrs)
238 | # TODO: Consider overlapped dataset, prototype for now
239 |
240 | # Write to rttm
241 | hyp_rttm_lines = make_rttm_lines(segcols, cluster_labels)
242 |
243 |
244 | lines_to_file(ref_rttm_lines, ref_rec_rttm)
245 | lines_to_file(hyp_rttm_lines, hyp_rec_rttm)
246 |
247 | # Eval based on recording level rttm
248 | der = score_der(hyp=hyp_rec_rttm, ref=ref_rec_rttm,
249 | outfile='/tmp/{}.derlog'.format(rec), collar=0.25)
250 | print('({}) DER for {}: {}'.format(args.cfg, rec, der))
251 |
252 | rpkl[rec] = der
253 | pickle.dump(rpkl, open(results_pkl, 'wb'))
254 |
255 | all_ref_rttms.append(ref_rec_rttm)
256 | all_hyp_rttms.append(hyp_rec_rttm)
257 |
258 | final_hyp_rttm = os.path.join(args.model_dir, 'final_{}_hyp.rttm'.format(args.checkpoint))
259 | final_ref_rttm = os.path.join(args.model_dir, 'final_{}_ref.rttm'.format(args.checkpoint))
260 |
261 | os.system('cat {} > {}'.format(' '.join(all_ref_rttms), final_ref_rttm))
262 | os.system('cat {} > {}'.format(' '.join(all_hyp_rttms), final_hyp_rttm))
263 |
264 | time.sleep(4)
265 |
266 | full_der = score_der(hyp=final_hyp_rttm, ref=final_ref_rttm,
267 | outfile=os.path.join(args.model_dir, 'final_{}_hyp.derlog'.format(args.checkpoint)), collar=0.25)
268 | print('({}) Full test DER: {}'.format(args.cfg, full_der))
269 |
270 | rpkl['full_der'] = full_der
271 | pickle.dump(rpkl, open(results_pkl, 'wb'))
272 |
273 | uem_file = os.path.join(TEST_DATA_PATH, 'uem')
274 | der_uem = score_der_uem(hyp=os.path.join(args.model_dir, 'final_{}_hyp.rttm'.format(args.checkpoint)),
275 | ref=os.path.join(args.model_dir, 'final_{}_ref.rttm'.format(args.checkpoint)),
276 | outfile=os.path.join(args.model_dir, 'final_hyp_uem.derlog'),
277 | uem=uem_file,
278 | collar=0.25)
279 | rpkl['full_der_uem'] = der_uem
280 | print('({}) Full test DER (uem): {}'.format(args.cfg, rpkl['full_der_uem']))
281 | pickle.dump(rpkl, open(results_pkl, 'wb'))
282 |
283 |
284 | if __name__ == "__main__":
285 | args = parse_args()
286 | args = parse_config(args)
287 | uvloop.install()
288 | rpkl_path = os.path.join(args.model_dir, 'results.p')
289 | if not os.path.isfile(rpkl_path):
290 | print('No results.p found')
291 | else:
292 | device = torch.device('cuda')
293 |
294 | TEST_DATA_PATH = args.diar_data
295 |
296 | generator = setup()
297 | extract_and_diarize(generator, TEST_DATA_PATH, device)
298 |
--------------------------------------------------------------------------------
/figures/multitask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvqluu/MTL-Speaker-Embeddings/a1b892f788eecc3bb458bba547c62ec2dac3c661/figures/multitask.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import configparser
3 | import os
4 | import pickle
5 | import sys
6 | import h5py
7 | import json
8 |
9 | from collections import OrderedDict
10 | from glob import glob
11 | from math import floor, log10
12 |
13 | import numpy as np
14 |
15 | import kaldi_io
16 | import torch
17 | import uvloop
18 | from data_io import SpeakerTestDataset, odict_from_2_col, SpeakerDataset
19 | from kaldi_io import read_vec_flt
20 | from kaldiio import ReadHelper
21 | from models.extractors import ETDNN, FTDNN, XTDNN
22 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
23 | from sklearn.linear_model import LogisticRegression
24 | from sklearn.metrics import roc_curve
25 | from sklearn.metrics.pairwise import cosine_distances, cosine_similarity
26 | from sklearn.preprocessing import normalize
27 | from tqdm import tqdm
28 | from utils import SpeakerRecognitionMetrics
29 |
30 |
31 |
32 | def mtd(stuff, device):
33 | if isinstance(stuff, torch.Tensor):
34 | return stuff.to(device)
35 | else:
36 | return [mtd(s, device) for s in stuff]
37 |
38 | def parse_args():
39 | parser = argparse.ArgumentParser(description='Test SV Model')
40 | parser.add_argument('--cfg', type=str, default='./configs/example_speaker.cfg')
41 | parser.add_argument('--best', action='store_true', default=False, help='Use best model')
42 |     parser.add_argument('--checkpoint', type=int, default=-1,  # which model to use, overridden by '--best'
43 | help='Use model checkpoint, default -1 uses final model')
44 | args = parser.parse_args()
45 | assert os.path.isfile(args.cfg)
46 | return args
47 |
48 | def parse_config(args):
49 | assert os.path.isfile(args.cfg)
50 | config = configparser.ConfigParser()
51 | config.read(args.cfg)
52 |
53 | args.train_data = config['Datasets'].get('train')
54 | assert args.train_data
55 | args.test_data = config['Datasets'].get('test')
56 |
57 | args.model_type = config['Model'].get('model_type', fallback='XTDNN')
58 | assert args.model_type in ['XTDNN', 'ETDNN', 'FTDNN']
59 |
60 | args.classifier_heads = config['Model'].get('classifier_heads').split(',')
61 | assert len(args.classifier_heads) <= 3, 'Three options available'
62 | assert len(args.classifier_heads) == len(set(args.classifier_heads))
63 | for clf in args.classifier_heads:
64 | assert clf in ['speaker', 'nationality', 'gender', 'age', 'age_regression']
65 |
66 | args.classifier_types = config['Optim'].get('classifier_types').split(',')
67 | assert len(args.classifier_heads) == len(args.classifier_types)
68 |
69 | args.classifier_loss_weighting_type = config['Optim'].get('classifier_loss_weighting_type', fallback='none')
70 | assert args.classifier_loss_weighting_type in ['none', 'uncertainty_kendall', 'uncertainty_liebel', 'dwa']
71 |
72 | args.dwa_temperature = config['Optim'].getfloat('dwa_temperature', fallback=2.)
73 |
74 | args.classifier_loss_weights = np.array(json.loads(config['Optim'].get('classifier_loss_weights'))).astype(float)
75 | assert len(args.classifier_heads) == len(args.classifier_loss_weights)
76 |
77 | args.classifier_lr_mults = np.array(json.loads(config['Optim'].get('classifier_lr_mults'))).astype(float)
78 | assert len(args.classifier_heads) == len(args.classifier_lr_mults)
79 |
80 | # assert clf_type in ['l2softmax', 'adm', 'adacos', 'xvec', 'arcface', 'sphereface', 'softmax']
81 |
82 | args.classifier_smooth_types = config['Optim'].get('classifier_smooth_types').split(',')
83 | assert len(args.classifier_smooth_types) == len(args.classifier_heads)
84 | args.classifier_smooth_types = [s.strip() for s in args.classifier_smooth_types]
85 |
86 | args.label_smooth_type = config['Optim'].get('label_smooth_type', fallback='None')
87 | assert args.label_smooth_type in ['None', 'disturb', 'uniform']
88 | args.label_smooth_prob = config['Optim'].getfloat('label_smooth_prob', fallback=0.1)
89 |
90 | args.input_dim = config['Hyperparams'].getint('input_dim', fallback=30)
91 | args.embedding_dim = config['Hyperparams'].getint('embedding_dim', fallback=512)
92 | args.num_iterations = config['Hyperparams'].getint('num_iterations', fallback=50000)
93 |
94 |
95 | args.model_dir = config['Outputs']['model_dir']
96 | if not hasattr(args, 'basefolder'):
97 | args.basefolder = config['Outputs'].get('basefolder', fallback=None)
98 | args.log_file = os.path.join(args.model_dir, 'train.log')
99 | args.results_pkl = os.path.join(args.model_dir, 'results.p')
100 |
101 | args.num_age_bins = config['Misc'].getint('num_age_bins', fallback=10)
102 | args.age_label_smoothing = config['Misc'].getboolean('age_label_smoothing', fallback=False)
103 | return args
104 |
105 |
106 | def test(generator, ds_test, device, mindcf=False):
107 | generator.eval()
108 | all_embeds = []
109 | all_utts = []
110 | num_examples = len(ds_test.veri_utts)
111 |
112 | with torch.no_grad():
113 | for i in range(num_examples):
114 | feats, utt = ds_test.__getitem__(i)
115 | feats = feats.unsqueeze(0).to(device)
116 | embeds = generator(feats)
117 | all_embeds.append(embeds.cpu().numpy())
118 | all_utts.append(utt)
119 |
120 | metric = SpeakerRecognitionMetrics(distance_measure='cosine')
121 | all_embeds = np.vstack(all_embeds)
122 | all_embeds = normalize(all_embeds, axis=1)
123 | all_utts = np.array(all_utts)
124 |
125 | utt_embed = OrderedDict({k: v for k, v in zip(all_utts, all_embeds)})
126 |
127 | emb0 = np.array([utt_embed[utt] for utt in ds_test.veri_0])
128 | emb1 = np.array([utt_embed[utt] for utt in ds_test.veri_1])
129 |
130 | scores = metric.scores_from_pairs(emb0, emb1)
131 | eer, mindcf1 = metric.compute_min_cost(scores, 1 - ds_test.veri_labs)
132 | generator.train()
133 | if mindcf:
134 | return eer, mindcf1, None
135 | else:
136 | return eer
137 |
138 | def test_fromdata(generator, all_feats, all_utts, ds_test, device, mindcf=False):
139 | generator.eval()
140 | all_embeds = []
141 |
142 | with torch.no_grad():
143 | for feats in all_feats:
144 | feats = feats.unsqueeze(0).to(device)
145 | embeds = generator(feats)
146 | all_embeds.append(embeds.cpu().numpy())
147 |
148 | metric = SpeakerRecognitionMetrics(distance_measure='cosine')
149 | all_embeds = np.vstack(all_embeds)
150 | all_embeds = normalize(all_embeds, axis=1)
151 | all_utts = np.array(all_utts)
152 |
153 | utt_embed = OrderedDict({k: v for k, v in zip(all_utts, all_embeds)})
154 |
155 | emb0 = np.array([utt_embed[utt] for utt in ds_test.veri_0])
156 | emb1 = np.array([utt_embed[utt] for utt in ds_test.veri_1])
157 |
158 | scores = metric.scores_from_pairs(emb0, emb1)
159 | eer, mindcf1 = metric.compute_min_cost(scores, 1 - ds_test.veri_labs)
160 | generator.train()
161 | if mindcf:
162 | return eer, mindcf1, None
163 | else:
164 | return eer
165 |
166 |
167 | def test_all_factors_multids(model_dict, ds_test, dsl, label_types, device):
168 | assert ds_test.test_mode
169 |
170 | for m in model_dict:
171 | model_dict[m]['model'].eval()
172 |
173 | label_types = [l for l in label_types if l not in ['speaker', 'rec']]
174 |
175 | with torch.no_grad():
176 | feats, label_dict, all_utts = ds_test.get_test_items()
177 | all_embeds = []
178 | pred_dict = {m: [] for m in label_types}
179 | for feat in tqdm(feats):
180 | feat = feat.unsqueeze(0).to(device)
181 | embed = model_dict['generator']['model'](model_dict['{}_ilayer'.format(dsl)]['model'](feat))
182 | for m in label_types:
183 | dictkey = '{}_{}'.format(dsl, m)
184 | pred = torch.argmax(model_dict[dictkey]['model'](embed, label=None), dim=1)
185 | pred_dict[m].append(pred.cpu().numpy()[0])
186 | all_embeds.append(embed.cpu().numpy())
187 |
188 | accuracy_dict = {m: np.equal(label_dict[m], pred_dict[m]).sum() / len(all_utts) for m in label_types}
189 |
190 | metric = SpeakerRecognitionMetrics(distance_measure='cosine')
191 | all_embeds = np.vstack(all_embeds)
192 | all_embeds = normalize(all_embeds, axis=1)
193 | all_utts = np.array(all_utts)
194 |
195 | utt_embed = OrderedDict({k: v for k, v in zip(all_utts, all_embeds)})
196 | emb0 = np.array([utt_embed[utt] for utt in ds_test.veri_0])
197 | emb1 = np.array([utt_embed[utt] for utt in ds_test.veri_1])
198 |
199 | scores = metric.scores_from_pairs(emb0, emb1)
200 | print('Min score: {}, max score {}'.format(min(scores), max(scores)))
201 | eer, mindcf1 = metric.compute_min_cost(scores, 1 - ds_test.veri_labs)
202 |
203 | for m in model_dict:
204 | model_dict[m]['model'].train()
205 |
206 | return eer, mindcf1, accuracy_dict
207 |
208 |
209 | def test_all_factors(model_dict, ds_test, device):
210 | assert ds_test.test_mode
211 |
212 | for m in model_dict:
213 | model_dict[m]['model'].eval()
214 |
215 | label_types = [l for l in ds_test.label_types if l in model_dict]
216 |
217 | with torch.no_grad():
218 | feats, label_dict, all_utts = ds_test.get_test_items()
219 | all_embeds = []
220 | pred_dict = {m: [] for m in label_types}
221 | for feat in tqdm(feats):
222 | feat = feat.unsqueeze(0).to(device)
223 | embed = model_dict['generator']['model'](feat)
224 | for m in label_types:
225 | if m.endswith('regression'):
226 | pred = model_dict[m]['model'](embed, label=None)
227 | else:
228 | pred = torch.argmax(model_dict[m]['model'](embed, label=None), dim=1)
229 | pred_dict[m].append(pred.cpu().numpy()[0])
230 | all_embeds.append(embed.cpu().numpy())
231 |
232 | accuracy_dict = {m: np.equal(label_dict[m], pred_dict[m]).sum() / len(all_utts) for m in label_types if not m.endswith('regression')}
233 |
234 | for m in label_types:
235 | if m.endswith('regression'):
236 | accuracy_dict[m] = np.mean((label_dict[m] - pred_dict[m])**2)
237 |
238 | if ds_test.veripairs:
239 | metric = SpeakerRecognitionMetrics(distance_measure='cosine')
240 | all_embeds = np.vstack(all_embeds)
241 | all_embeds = normalize(all_embeds, axis=1)
242 | all_utts = np.array(all_utts)
243 |
244 | utt_embed = OrderedDict({k: v for k, v in zip(all_utts, all_embeds)})
245 | emb0 = np.array([utt_embed[utt] for utt in ds_test.veri_0])
246 | emb1 = np.array([utt_embed[utt] for utt in ds_test.veri_1])
247 |
248 | scores = metric.scores_from_pairs(emb0, emb1)
249 | print('Min score: {}, max score {}'.format(min(scores), max(scores)))
250 | eer, mindcf1 = metric.compute_min_cost(scores, 1 - ds_test.veri_labs)
251 | else:
252 | eer, mindcf1 = 0.0, 0.0
253 |
254 | for m in model_dict:
255 | model_dict[m]['model'].train()
256 |
257 | return eer, mindcf1, accuracy_dict
258 |
259 |
260 | def test_all_factors_ensemble(model_dict, ds_test, device, feats, all_utts, exclude=[], combine='sum'):
261 | assert ds_test.test_mode
262 |
263 | for m in model_dict:
264 | model_dict[m]['model'].eval()
265 |
266 | label_types = [l for l in ds_test.label_types if l in model_dict]
267 | set_veri_utts = set(list(ds_test.veri_0) + list(ds_test.veri_1))
268 | aux_embeds_dict = {m: [] for m in label_types}
269 |
270 | with torch.no_grad():
271 | veri_embeds = []
272 | veri_utts = []
273 | for feat, utt in tqdm(zip(feats, all_utts)):
274 | if utt in set_veri_utts:
275 | feat = feat.unsqueeze(0).to(device)
276 | embed = model_dict['generator']['model'](feat)
277 | veri_embeds.append(embed)
278 | veri_utts.append(utt)
279 |
280 | veri_embeds = torch.cat(veri_embeds)
281 | for m in label_types:
282 | task_embeds = model_dict[m]['model'](veri_embeds, label=None, transform=True)
283 | aux_embeds_dict[m] = normalize(task_embeds.cpu().numpy(), axis=1)
284 |
285 | veri_embeds = normalize(veri_embeds.cpu().numpy(), axis=1)
286 |
287 | aux_embeds_dict['base'] = veri_embeds
288 |
289 | metric = SpeakerRecognitionMetrics(distance_measure='cosine')
290 |
291 | total_scores = []
292 |
293 | for key in aux_embeds_dict:
294 | if key not in exclude:
295 | utt_embed = OrderedDict({k: v for k, v in zip(veri_utts, aux_embeds_dict[key])})
296 | emb0 = np.array([utt_embed[utt] for utt in ds_test.veri_0])
297 | emb1 = np.array([utt_embed[utt] for utt in ds_test.veri_1])
298 | scores = metric.scores_from_pairs(emb0, emb1)
299 | total_scores.append(scores)
300 |
301 | if combine == 'sum':
302 | total_scores = np.sum(np.array(total_scores), axis=0)
303 | eer, mindcf1 = metric.compute_min_cost(total_scores, 1. - ds_test.veri_labs)
304 | else:
305 | total_scores = np.array(total_scores).T
306 | lr_clf = LogisticRegression(solver='lbfgs')
307 | lr_clf.fit(total_scores, 1. - ds_test.veri_labs)
308 | weighted_scores = lr_clf.predict(total_scores)
309 | eer, mindcf1 = metric.compute_min_cost(weighted_scores, 1. - ds_test.veri_labs)
310 |
311 | for m in model_dict:
312 | model_dict[m]['model'].train()
313 |
314 | return eer, mindcf1
315 |
316 |
317 | def test_enrollment_models(generator, ds_test, device, return_scores=False, reduce_method='mean_embed'):
318 | assert reduce_method in ['mean_embed', 'mean_dist', 'max', 'min']
319 | generator.eval()
320 | all_embeds = []
321 | all_utts = []
322 | num_examples = len(ds_test)
323 |
324 | with torch.no_grad():
325 | for i in tqdm(range(num_examples)):
326 | feats, utt = ds_test.__getitem__(i)
327 | feats = feats.unsqueeze(0).to(device)
328 | embeds = generator(feats)
329 | all_embeds.append(embeds.cpu().numpy())
330 | all_utts.append(utt)
331 |
332 | metric = SpeakerRecognitionMetrics(distance_measure='cosine')
333 | all_embeds = np.vstack(all_embeds)
334 | all_embeds = normalize(all_embeds, axis=1)
335 | all_utts = np.array(all_utts)
336 |
337 | utt_embed = OrderedDict({k: v for k, v in zip(all_utts, all_embeds)})
338 | model_embeds = OrderedDict({})
339 | for i, model in enumerate(ds_test.models):
340 | model_embeds[model] = np.array([utt_embed[utt] for utt in ds_test.m_utts[i]])
341 |
342 | emb0s = np.array([model_embeds[u] for u in ds_test.models_eval])
343 | emb1 = np.array([utt_embed[u] for u in ds_test.eval_utts])
344 |
345 | mscore_means = []
346 | mscore_stds = []
347 | scores = []
348 | for model_utts, test_utt in zip(emb0s, emb1):
349 | if reduce_method == 'mean_embed':
350 | model_std = model_utts.std(0)
351 | model_mean = model_utts.mean(0)
352 | scores.append(np.linalg.norm(test_utt - model_mean))
353 | elif reduce_method == 'mean_dist':
354 | dist_means = np.mean(np.array([np.linalg.norm(test_utt - e) for e in model_utts]))
355 | scores.append(dist_means)
356 | elif reduce_method == 'min':
357 | scores.append(np.min(np.array([np.linalg.norm(test_utt - e) for e in model_utts])))
358 | elif reduce_method == 'max':
359 | scores.append(np.max(np.array([np.linalg.norm(test_utt - e) for e in model_utts])))
360 | else:
361 | print('do nothing')
362 |
363 | scores = np.array(scores)
364 | if return_scores:
365 | return scores
366 |
367 | eer, mindcf1 = metric.compute_min_cost(scores,
368 | 1 - ds_test.veri_labs)
369 | generator.train()
370 | return eer, mindcf1, scores
371 |
372 |
373 | def test_nosil(generator, ds_test, device, mindcf=False):
374 | generator.eval()
375 | all_embeds = []
376 | all_utts = []
377 | num_examples = len(ds_test.veri_utts)
378 |
379 | with torch.no_grad():
380 | with ReadHelper(
381 | 'ark:apply-cmvn-sliding --norm-vars=false --center=true --cmn-window=300 scp:{0}/feats_trimmed.scp '
382 | 'ark:- | select-voiced-frames ark:- scp:{0}/vad_trimmed.scp ark:- |'.format(
383 | ds_test.data_base_path)) as reader:
384 | for key, feat in tqdm(reader, total=num_examples):
385 | if key in ds_test.veri_utts:
386 | all_utts.append(key)
387 | feats = torch.FloatTensor(feat).unsqueeze(0).to(device)
388 | embeds = generator(feats)
389 | all_embeds.append(embeds.cpu().numpy())
390 |
391 | metric = SpeakerRecognitionMetrics(distance_measure='cosine')
392 | all_embeds = np.vstack(all_embeds)
393 | all_embeds = normalize(all_embeds, axis=1)
394 | all_utts = np.array(all_utts)
395 |
396 | print(all_embeds.shape, len(ds_test.veri_utts))
397 | utt_embed = OrderedDict({k: v for k, v in zip(all_utts, all_embeds)})
398 |
399 | emb0 = np.array([utt_embed[utt] for utt in ds_test.veri_0])
400 | emb1 = np.array([utt_embed[utt] for utt in ds_test.veri_1])
401 |
402 | scores = metric.scores_from_pairs(emb0, emb1)
403 | fpr, tpr, thresholds = roc_curve(1 - ds_test.veri_labs, scores, pos_label=1, drop_intermediate=False)
404 | eer = metric.eer_from_ers(fpr, tpr)
405 | generator.train()
406 | if mindcf:
407 | mindcf1 = metric.compute_min_dcf(fpr, tpr, thresholds, p_target=0.01)
408 | mindcf2 = metric.compute_min_dcf(fpr, tpr, thresholds, p_target=0.001)
409 | return eer, mindcf1, mindcf2
410 | else:
411 | return eer
412 |
413 |
414 | def evaluate_deepmine(generator, ds_eval, device, outfile_path='./exp'):
415 | os.makedirs(outfile_path, exist_ok=True)
416 | generator.eval()
417 |
418 | answer_col0 = []
419 | answer_col1 = []
420 | answer_col2 = []
421 |
422 | with torch.no_grad():
423 | for i in tqdm(range(len(ds_eval))):
424 | model, enrol_utts, enrol_feats, eval_utts, eval_feats = ds_eval.__getitem__(i)
425 | answer_col0.append([model for _ in range(len(eval_utts))])
426 | answer_col1.append(eval_utts)
427 |
428 | enrol_feats = mtd(enrol_feats, device)
429 | model_embed = torch.cat([generator(x.unsqueeze(0)) for x in enrol_feats]).cpu().numpy()
430 | model_embed = np.mean(normalize(model_embed, axis=1), axis=0).reshape(1, -1)
431 |
432 | del enrol_feats
433 | eval_feats = mtd(eval_feats, device)
434 | eval_embeds = torch.cat([generator(x.unsqueeze(0)) for x in eval_feats]).cpu().numpy()
435 | eval_embeds = normalize(eval_embeds, axis=1)
436 |
437 | scores = cosine_similarity(model_embed, eval_embeds).squeeze(0)
438 | assert len(scores) == len(eval_utts)
439 | answer_col2.append(scores)
440 | del eval_feats
441 |
442 | answer_col0 = np.concatenate(answer_col0)
443 | answer_col1 = np.concatenate(answer_col1)
444 | answer_col2 = np.concatenate(answer_col2)
445 |
446 | with open(os.path.join(outfile_path, 'answer_full.txt'), 'w+') as fp:
447 | for m, ev, s in zip(answer_col0, answer_col1, answer_col2):
448 | line = '{} {} {}\n'.format(m, ev, s)
449 | fp.write(line)
450 |
451 | with open(os.path.join(outfile_path, 'answer.txt'), 'w+') as fp:
452 | for s in answer_col2:
453 | line = '{}\n'.format(s)
454 | fp.write(line)
455 |
456 | if (answer_col0 == np.array(ds_eval.models_eval)).all():
457 | print('model ordering matched')
458 | else:
459 | print('model ordering was not correct, need to fix before submission')
460 |
461 | if (answer_col1 == np.array(ds_eval.eval_utts)).all():
462 | print('eval utt ordering matched')
463 | else:
464 | print('eval utt ordering was not correct, need to fix before submission')
465 |
466 |
467 | def dvec_compute(generator, ds_eval, device, num_jobs=20, outfolder='./exp/example_dvecs'):
468 | # naively compute the embeddings for each window
469 | # ds_len = len(ds_feats)
470 | all_utts = ds_eval.all_utts
471 | ds_len = len(all_utts)
472 | indices = np.arange(ds_len)
473 | job_split = np.array_split(indices, num_jobs)
474 | generator.eval().to(device)
475 | for job_num, job in enumerate(tqdm(job_split)):
476 | print('Starting job {}'.format(job_num))
477 | ark_scp_output = 'ark:| copy-vector ark:- ark,scp:{0}/xvector.{1}.ark,{0}/xvector.{1}.scp'.format(outfolder,
478 | job_num + 1)
479 | job_utts = all_utts[job]
480 | job_feats = ds_eval.get_batches(job_utts)
481 | job_feats = mtd(job_feats, device)
482 | with torch.no_grad():
483 | job_embeds = torch.cat([generator(x.unsqueeze(0)) for x in tqdm(job_feats)]).cpu().numpy()
484 | with kaldi_io.open_or_fd(ark_scp_output, 'wb') as f:
485 | for xvec, key in zip(job_embeds, job_utts):
486 | kaldi_io.write_vec_flt(f, xvec, key=key)
487 |
488 |
489 | def evaluate_deepmine_from_xvecs(ds_eval, outfolder='./exp/example_xvecs'):
490 | if not os.path.isfile(os.path.join(outfolder, 'xvector.scp')):
491 | xvec_scps = glob(os.path.join(outfolder, '*.scp'))
492 | assert len(xvec_scps) != 0, 'No xvector scps found'
493 | with open(os.path.join(outfolder, 'xvector.scp'), 'w+') as outfile:
494 | for fname in xvec_scps:
495 | with open(fname) as infile:
496 | for line in infile:
497 | outfile.write(line)
498 |
499 | xvec_dict = odict_from_2_col(os.path.join(outfolder, 'xvector.scp'))
500 | answer_col0 = []
501 | answer_col1 = []
502 | answer_col2 = []
503 |
504 | for i in tqdm(range(len(ds_eval))):
505 | model, enrol_utts, eval_utts, = ds_eval.get_item_utts(i)
506 | answer_col0.append([model for _ in range(len(eval_utts))])
507 | answer_col1.append(eval_utts)
508 |
509 | model_embeds = np.array([read_vec_flt(xvec_dict[u]) for u in enrol_utts])
510 | model_embed = np.mean(normalize(model_embeds, axis=1), axis=0).reshape(1, -1)
511 |
512 | eval_embeds = np.array([read_vec_flt(xvec_dict[u]) for u in eval_utts])
513 | eval_embeds = normalize(eval_embeds, axis=1)
514 |
515 | scores = cosine_similarity(model_embed, eval_embeds).squeeze(0)
516 | assert len(scores) == len(eval_utts)
517 | answer_col2.append(scores)
518 |
519 | answer_col0 = np.concatenate(answer_col0)
520 | answer_col1 = np.concatenate(answer_col1)
521 | answer_col2 = np.concatenate(answer_col2)
522 |
523 | print('Writing results to file...')
524 | with open(os.path.join(outfolder, 'answer_full.txt'), 'w+') as fp:
525 | for m, ev, s in tqdm(zip(answer_col0, answer_col1, answer_col2)):
526 | line = '{} {} {}\n'.format(m, ev, s)
527 | fp.write(line)
528 |
529 | with open(os.path.join(outfolder, 'answer.txt'), 'w+') as fp:
530 | for s in tqdm(answer_col2):
531 | line = '{}\n'.format(s)
532 | fp.write(line)
533 |
534 | if (answer_col0 == np.array(ds_eval.models_eval)).all():
535 | print('model ordering matched')
536 | else:
537 | print('model ordering was not correct, need to fix before submission')
538 |
539 | if (answer_col1 == np.array(ds_eval.eval_utts)).all():
540 | print('eval utt ordering matched')
541 | else:
542 | print('eval utt ordering was not correct, need to fix before submission')
543 |
544 |
545 | def round_sig(x, sig=2):
546 | return round(x, sig - int(floor(log10(abs(x)))) - 1)
547 |
548 |
549 | def get_eer_metrics(folder):
550 | rpkl_path = os.path.join(folder, 'wrec_results.p')
551 | rpkl = pickle.load(open(rpkl_path, 'rb'))
552 | iterations = list(rpkl.keys())
553 | eers = [rpkl[k]['test_eer'] for k in rpkl]
554 | return iterations, eers, np.min(eers)
555 |
556 |
557 | def extract_test_embeds(generator, ds_test, device):
558 | assert ds_test.test_mode
559 |
560 | with torch.no_grad():
561 | feats, label_dict, all_utts = ds_test.get_test_items()
562 | all_embeds = []
563 | for feat in tqdm(feats):
564 | feat = feat.unsqueeze(0).to(device)
565 | embed = generator(feat)
566 | all_embeds.append(embed.cpu().numpy())
567 |
568 | all_embeds = np.vstack(all_embeds)
569 | all_embeds = normalize(all_embeds, axis=1)
570 | all_utts = np.array(all_utts)
571 |
572 | return all_utts, all_embeds
573 |
574 | if __name__ == "__main__":
575 | args = parse_args()
576 | args = parse_config(args)
577 | uvloop.install()
578 |
579 | hf = h5py.File(os.path.join(args.model_dir, 'xvectors.h5'), 'w')
580 |
581 | use_cuda = torch.cuda.is_available()
582 | print('=' * 30)
583 | print('USE_CUDA SET TO: {}'.format(use_cuda))
584 | print('CUDA AVAILABLE?: {}'.format(torch.cuda.is_available()))
585 | print('=' * 30)
586 | device = torch.device("cuda" if use_cuda else "cpu")
587 |
588 | if args.checkpoint == -1:
589 | g_path = os.path.join(args.model_dir, "final_generator_{}.pt".format(args.num_iterations))
590 | else:
591 | g_path = os.path.join(args.model_dir, "generator_{}.pt".format(args.checkpoint))
592 |
593 | if args.model_type == 'XTDNN':
594 | generator = XTDNN(features_per_frame=args.input_dim, embed_features=args.embedding_dim)
595 | if args.model_type == 'ETDNN':
596 | generator = ETDNN(features_per_frame=args.input_dim, embed_features=args.embedding_dim)
597 | if args.model_type == 'FTDNN':
598 | generator = FTDNN(in_dim=args.input_dim, embedding_dim=args.embedding_dim)
599 |
600 | if args.best:
601 | iterations, eers, _ = get_eer_metrics(args.model_dir)
602 | best_it = iterations[np.argmin(eers)]
603 | g_path = os.path.join(args.model_dir, "generator_{}.pt".format(best_it))
604 |
605 | assert os.path.isfile(g_path), "Couldn't find {}".format(g_path)
606 |
607 | ds_train = SpeakerDataset(args.train_data, num_age_bins=args.num_age_bins)
608 | class_enc_dict = ds_train.get_class_encs()
609 | if args.test_data:
610 | ds_test = SpeakerDataset(args.test_data, test_mode=True,
611 | class_enc_dict=class_enc_dict, num_age_bins=args.num_age_bins)
612 |
613 | generator.load_state_dict(torch.load(g_path))
614 | model = generator
615 | model.eval().to(device)
616 |
617 | utts, embeds = extract_test_embeds(generator, ds_test, device)
618 |
619 | hf.create_dataset('embeds', data=embeds)
620 | dt = h5py.string_dtype()
621 | hf.create_dataset('utts', data=np.string_(np.array(utts)), dtype=dt)
622 | hf.close()
623 |
624 |
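# --- Editor's note: illustrative read-back sketch, not part of the original script. ---
# A minimal example of loading the embeddings written above for downstream scoring.
# The path 'exp/my_model/xvectors.h5' is a placeholder for os.path.join(args.model_dir, 'xvectors.h5'):
#
#     import h5py
#     import numpy as np
#     with h5py.File('exp/my_model/xvectors.h5', 'r') as hf:
#         embeds = np.array(hf['embeds'])               # (num_utts, embed_dim), already length-normalized
#         utts = [u.decode() for u in hf['utts'][()]]   # utterance ids, stored as byte strings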
--------------------------------------------------------------------------------
/models/classifiers.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import OrderedDict
3 |
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.autograd import Function
10 |
11 | '''
12 | AdaCos and additive margin losses taken from https://github.com/4uiiurz1/pytorch-adacos
13 | '''
14 |
15 | class DropAffine(nn.Module):
16 |
17 | def __init__(self, num_features, num_classes):
18 | super(DropAffine, self).__init__()
19 | self.fc = nn.Linear(num_features, num_classes)
20 | self.reset_parameters()
21 |
22 | def reset_parameters(self):
23 | self.fc.reset_parameters()
24 |
25 | def forward(self, input, label=None):
26 | W = self.fc.weight
27 | b = self.fc.bias
28 | logits = F.linear(input, W, b)
29 | return logits
30 |
31 |
32 | class GradientReversalFunction(Function):
33 | """
34 | Gradient Reversal Layer from:
35 | Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
36 | Forward pass is the identity function. In the backward pass,
37 | the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
38 | """
39 |
40 | @staticmethod
41 | def forward(ctx, x, lambda_):
42 | ctx.lambda_ = lambda_
43 | return x.clone()
44 |
45 | @staticmethod
46 | def backward(ctx, grads):
47 | lambda_ = ctx.lambda_
48 | lambda_ = grads.new_tensor(lambda_)
49 | dx = -lambda_ * grads
50 | return dx, None
51 |
52 |
53 | class GradientReversal(nn.Module):
54 | def __init__(self, lambda_=1):
55 | super(GradientReversal, self).__init__()
56 | self.lambda_ = lambda_
57 |
58 | def forward(self, x, **kwargs):
59 | return GradientReversalFunction.apply(x, self.lambda_)
60 |
61 |
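# Editor's note: an illustrative usage sketch, not part of the original file. GradientReversal
# is typically inserted between a shared embedding and an adversarial head so that the head is
# trained to predict a nuisance attribute while the embedding is pushed to discard it. The
# names 'nuisance_head' and 'nuisance_labels' below are hypothetical:
#
#     grl = GradientReversal(lambda_=1.0)
#     adv_logits = nuisance_head(grl(embedding))          # forward pass: identity
#     adv_loss = F.cross_entropy(adv_logits, nuisance_labels)
#     # on backward, gradients reaching 'embedding' from adv_loss are multiplied by -lambda_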
62 |
63 | class L2SoftMax(nn.Module):
64 |
65 | def __init__(self, num_features, num_classes):
66 | super(L2SoftMax, self).__init__()
67 | self.W = nn.Parameter(torch.FloatTensor(num_classes, num_features))
68 | self.reset_parameters()
69 |
70 | def reset_parameters(self):
71 | nn.init.xavier_uniform_(self.W)
72 |
73 | def forward(self, input, label=None):
74 | x = F.normalize(input)
75 | W = F.normalize(self.W)
76 | logits = F.linear(x, W)
77 | return logits
78 |
79 | class SoftMax(nn.Module):
80 |
81 | def __init__(self, num_features, num_classes):
82 | super(SoftMax, self).__init__()
83 | self.W = nn.Parameter(torch.FloatTensor(num_classes, num_features))
84 | self.reset_parameters()
85 |
86 | def reset_parameters(self):
87 | nn.init.xavier_uniform_(self.W)
88 |
89 | def forward(self, input, label=None):
90 | x = input
91 | W = self.W
92 | logits = F.linear(x, W)
93 | return logits
94 |
95 |
96 | class LinearUncertain(nn.Module):
97 |
98 | def __init__(self, in_features, out_features, bias=True):
99 | super(LinearUncertain, self).__init__()
100 | self.in_features = in_features
101 | self.out_features = out_features
102 | self.bias = bias
103 |
104 | self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features))
105 | self.weight_beta = nn.Parameter(torch.Tensor(out_features, in_features))
106 |
107 | if bias:
108 | self.bias = nn.Parameter(torch.Tensor(out_features))
109 | else:
110 | self.register_parameter('bias', None)
111 | self.reset_parameters()
112 |
113 | def reset_parameters(self):
114 | nn.init.xavier_uniform_(self.weight_mu, gain=1.0)
115 | init_beta = np.log(np.exp(0.5 * np.sqrt(6)/(self.in_features+self.out_features)) - 1)
116 | nn.init.constant_(self.weight_beta, init_beta)
117 | if self.bias is not None:
118 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_mu)
119 | bound = 1 / np.sqrt(fan_in)
120 | nn.init.uniform_(self.bias, -bound, bound)
121 |
122 |
123 | def forward(self, x):
124 | if self.training:
125 | eps = torch.randn(self.out_features, self.in_features).to(self.weight_mu.device)
126 | weights = self.weight_mu + torch.log(torch.exp(self.weight_beta) + 1) * eps
127 | else:
128 | weights = self.weight_mu
129 | return F.linear(x, weights, self.bias)
130 |
131 |
132 | class XVecHead(nn.Module):
133 |
134 | def __init__(self, num_features, num_classes, hidden_features=None):
135 | super(XVecHead, self).__init__()
136 | hidden_features = num_features if not hidden_features else hidden_features
137 | self.fc_hidden = nn.Linear(num_features, hidden_features)
138 | self.nl = nn.LeakyReLU()
139 | self.bn = nn.BatchNorm1d(hidden_features)
140 | self.fc = nn.Linear(hidden_features, num_classes)
141 | self.reset_parameters()
142 |
143 | def reset_parameters(self):
144 | self.fc.reset_parameters()
145 |
146 | def forward(self, input, label=None, transform=False):
147 | input = self.fc_hidden(input)
148 | input = self.nl(input)
149 | input = self.bn(input)
150 | if transform:
151 | return input
152 | W = self.fc.weight
153 | b = self.fc.bias
154 | logits = F.linear(input, W, b)
155 | if logits.shape[-1] == 1:
156 | logits = torch.squeeze(logits, dim=-1)
157 | return logits
158 |
159 | class XVecHeadUncertain(nn.Module):
160 |
161 | def __init__(self, num_features, num_classes, hidden_features=None):
162 | super().__init__()
163 | hidden_features = num_features if not hidden_features else hidden_features
164 | self.fc_hidden = LinearUncertain(num_features, hidden_features)
165 | self.nl = nn.LeakyReLU()
166 | self.bn = nn.BatchNorm1d(hidden_features)
167 | self.fc = LinearUncertain(hidden_features, num_classes)
168 | self.reset_parameters()
169 |
170 | def reset_parameters(self):
171 | self.fc.reset_parameters()
172 |
173 | def forward(self, input, label=None, transform=False):
174 | input = self.fc_hidden(input)
175 | input = self.nl(input)
176 | input = self.bn(input)
177 | if transform:
178 | return input
179 | logits = self.fc(input)
180 | return logits
181 |
182 |
183 | class AMSMLoss(nn.Module):
184 |
185 | def __init__(self, num_features, num_classes, s=30.0, m=0.4):
186 | super(AMSMLoss, self).__init__()
187 | self.num_features = num_features
188 | self.n_classes = num_classes
189 | self.s = s
190 | self.m = m
191 | self.W = nn.Parameter(torch.FloatTensor(num_classes, num_features))
192 | self.reset_parameters()
193 |
194 | def reset_parameters(self):
195 | nn.init.xavier_uniform_(self.W)
196 |
197 | def forward(self, input, label=None):
198 | # normalize features
199 | x = F.normalize(input)
200 | # normalize weights
201 | W = F.normalize(self.W)
202 | # dot product
203 | logits = F.linear(x, W)
204 | if label is None:
205 | return logits
206 | # add margin
207 | target_logits = logits - self.m
208 | one_hot = torch.zeros_like(logits)
209 | one_hot.scatter_(1, label.view(-1, 1).long(), 1)
210 | output = logits * (1 - one_hot) + target_logits * one_hot
211 | # feature re-scale
212 | output *= self.s
213 |
214 | return output
215 |
216 |
217 |
218 | class SphereFace(nn.Module):
219 |
220 | def __init__(self, num_features, num_classes, s=30.0, m=1.35):
221 | super(SphereFace, self).__init__()
222 | self.num_features = num_features
223 | self.n_classes = num_classes
224 | self.s = s
225 | self.m = m
226 | self.W = nn.Parameter(torch.FloatTensor(num_classes, num_features))
227 | self.reset_parameters()
228 |
229 | def reset_parameters(self):
230 | nn.init.xavier_uniform_(self.W)
231 |
232 | def forward(self, input, label=None):
233 | # normalize features
234 | x = F.normalize(input)
235 | # normalize weights
236 | W = F.normalize(self.W)
237 | # dot product
238 | logits = F.linear(x, W)
239 | if label is None:
240 | return logits
241 | # add margin
242 | theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7))
243 | target_logits = torch.cos(self.m * theta)
244 | one_hot = torch.zeros_like(logits)
245 | one_hot.scatter_(1, label.view(-1, 1).long(), 1)
246 | output = logits * (1 - one_hot) + target_logits * one_hot
247 | # feature re-scale
248 | output *= self.s
249 |
250 | return output
251 |
252 |
253 | class ArcFace(nn.Module):
254 |
255 | def __init__(self, num_features, num_classes, s=30.0, m=0.50):
256 | super(ArcFace, self).__init__()
257 | self.num_features = num_features
258 | self.n_classes = num_classes
259 | self.s = s
260 | self.m = m
261 | self.W = nn.Parameter(torch.FloatTensor(num_classes, num_features))
262 | self.reset_parameters()
263 |
264 | def reset_parameters(self):
265 | nn.init.xavier_uniform_(self.W)
266 |
267 | def forward(self, input, label=None):
268 | # normalize features
269 | x = F.normalize(input)
270 | # normalize weights
271 | W = F.normalize(self.W)
272 | # dot product
273 | logits = F.linear(x, W)
274 | if label is None:
275 | return logits
276 | # add margin
277 | theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7))
278 | target_logits = torch.cos(theta + self.m)
279 | one_hot = torch.zeros_like(logits)
280 | one_hot.scatter_(1, label.view(-1, 1).long(), 1)
281 | output = logits * (1 - one_hot) + target_logits * one_hot
282 | # feature re-scale
283 | output *= self.s
284 |
285 | return output
286 |
287 | class AdaCos(nn.Module):
288 |
289 | def __init__(self, num_features, num_classes, m=0.50):
290 | super(AdaCos, self).__init__()
291 | self.num_features = num_features
292 | self.n_classes = num_classes
293 | self.s = math.sqrt(2) * math.log(num_classes - 1)
294 | self.m = m
295 | self.W = nn.Parameter(torch.FloatTensor(num_classes, num_features))
296 | self.reset_parameters()
297 |
298 | def reset_parameters(self):
299 | nn.init.xavier_uniform_(self.W)
300 |
301 | def forward(self, input, label=None):
302 | # normalize features
303 | x = F.normalize(input)
304 | # normalize weights
305 | W = F.normalize(self.W)
306 | # dot product
307 | logits = F.linear(x, W)
308 | if label is None:
309 | return logits
310 | # feature re-scale
311 | theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7))
312 | one_hot = torch.zeros_like(logits)
313 | one_hot.scatter_(1, label.view(-1, 1).long(), 1)
314 | with torch.no_grad():
315 | B_avg = torch.where(one_hot < 1, torch.exp(self.s * logits), torch.zeros_like(logits))
316 | B_avg = torch.sum(B_avg) / input.size(0)
317 | theta_med = torch.median(theta[one_hot == 1])
318 | self.s = torch.log(B_avg) / torch.cos(torch.min(math.pi/4 * torch.ones_like(theta_med), theta_med))
319 | output = self.s * logits
320 |
321 | return output
322 |
323 |
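if __name__ == '__main__':
    # Editor's note: a minimal smoke test added for illustration; it is not part of the
    # original training code. It shows the calling convention shared by the heads above:
    # pass labels during training so the margin/scale is applied, omit them at inference
    # time to get plain (scaled) cosine logits. All sizes are arbitrary.
    embeds = torch.randn(8, 256)           # batch of 8 embeddings, 256-dim
    labels = torch.randint(0, 10, (8,))    # 10 hypothetical classes
    heads = [SoftMax(256, 10), L2SoftMax(256, 10), AMSMLoss(256, 10),
             ArcFace(256, 10), SphereFace(256, 10), AdaCos(256, 10)]
    for head in heads:
        train_logits = head(embeds, labels)    # (8, 10), margin applied where defined
        infer_logits = head(embeds)            # (8, 10), no margin when label is None
        print(type(head).__name__, tuple(train_logits.shape), tuple(infer_logits.shape))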
--------------------------------------------------------------------------------
/models/criteria.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class MultiTaskUncertaintyLossKendall(nn.Module):
5 |
6 | def __init__(self, num_tasks):
7 | """
8 |         Multi-task loss with uncertainty weighting (Kendall et al., 2018)
9 | \eta = 2*log(\sigma)
10 | """
11 | super().__init__()
12 | self.num_tasks = num_tasks
13 | self.eta = nn.Parameter(torch.zeros(num_tasks))
14 |
15 | def forward(self, losses):
16 | """
17 | Input: 1-d tensor of scalar losses, shape (num_tasks,)
18 | Output: Total weighted loss
19 | """
20 | assert len(losses) == self.num_tasks, 'Expected {} losses to weight, got {}'.format(self.num_tasks, len(losses))
21 | # factor = torch.pow(2*torch.exp(self.eta) - 2, -1)
22 | factor = torch.exp(-self.eta)/2.
23 | total_loss = (losses * factor + self.eta).sum()
24 | return total_loss/self.num_tasks
25 |
26 |
27 | class MultiTaskUncertaintyLossLiebel(nn.Module):
28 |
29 | def __init__(self, num_tasks):
30 | """
31 | Multi task loss with uncertainty weighting
32 | Liebel (2018) version
33 | Regularisation term ln(1 + sigma^2)
34 | """
35 | super().__init__()
36 | self.num_tasks = num_tasks
37 | self.sigma2 = nn.Parameter(0.25*torch.ones(num_tasks))
38 |
39 | def forward(self, losses):
40 | """
41 | Input: 1-d tensor of scalar losses, shape (num_tasks,)
42 | Output: Total weighted loss
43 | """
44 | assert len(losses) == self.num_tasks, 'Expected {} losses to weight, got {}'.format(self.num_tasks, len(losses))
45 | factor = 1./(2*self.sigma2)
46 | reg = torch.log(1. + self.sigma2) #regularisation term
47 | total_loss = (losses * factor + reg).sum()
48 | return total_loss/self.num_tasks
49 |
50 | class DisturbLabelLoss(nn.Module):
51 |
52 | def __init__(self, device, disturb_prob=0.1):
53 | super(DisturbLabelLoss, self).__init__()
54 | self.disturb_prob = disturb_prob
55 | self.ce = nn.CrossEntropyLoss()
56 | self.device = device
57 |
58 | def forward(self, pred, target):
59 | with torch.no_grad():
60 | disturb_indexes = torch.rand(len(pred)) < self.disturb_prob
61 | target[disturb_indexes] = torch.randint(pred.shape[-1], (int(disturb_indexes.sum()),)).to(self.device)
62 | return self.ce(pred, target)
63 |
64 |
65 | class LabelSmoothingLoss(nn.Module):
66 |
67 | def __init__(self, smoothing=0.1, dim=-1):
68 | super(LabelSmoothingLoss, self).__init__()
69 | self.confidence = 1.0 - smoothing
70 | self.smoothing = smoothing
71 | self.dim = dim
72 |
73 | def forward(self, pred, target):
74 | pred = pred.log_softmax(dim=self.dim)
75 | with torch.no_grad():
76 | true_dist = torch.zeros_like(pred)
77 | true_dist.fill_(self.smoothing / (pred.shape[-1] - 1))
78 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
79 | return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
80 |
81 |
82 | class TwoNeighbourSmoothingLoss(nn.Module):
83 |
84 | def __init__(self, smoothing=0.1, dim=-1):
85 | super().__init__()
86 | self.dim = dim
87 | self.smoothing = smoothing
88 | self.confidence = 1.0 - smoothing
89 |
90 |
91 | def forward(self, pred, target):
92 | pred = pred.log_softmax(dim=self.dim)
93 | num_classes = pred.shape[self.dim]
94 | with torch.no_grad():
95 | targets = target.data.unsqueeze(1)
96 | true_dist = torch.zeros_like(pred)
97 |
98 | up_labels = targets.add(1)
99 | up_labels[up_labels >= num_classes] = num_classes - 2
100 | down_labels = targets.add(-1)
101 | down_labels[down_labels < 0] = 1
102 |
103 | smooth_values = torch.zeros_like(targets).float().add_(self.smoothing/2)
104 | true_dist.scatter_add_(1, up_labels, smooth_values)
105 | true_dist.scatter_add_(1, down_labels, smooth_values)
106 |
107 | true_dist.scatter_(1, targets, self.confidence)
108 |
109 | return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
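

if __name__ == '__main__':
    # Editor's note: a small illustrative check, not part of the original file. It shows how
    # the uncertainty weighting combines per-task losses and how the ordinal label smoothing
    # behaves; the loss values, batch size and number of classes below are arbitrary.
    weighting = MultiTaskUncertaintyLossKendall(num_tasks=2)
    task_losses = torch.stack([torch.tensor(1.3), torch.tensor(0.4)])   # e.g. speaker CE, age loss
    total = weighting(task_losses)      # scalar; weighting.eta would be learned if given to the optimiser
    print('weighted multi-task loss:', total.item())

    smoother = TwoNeighbourSmoothingLoss(smoothing=0.1)
    logits = torch.randn(4, 10)                 # 4 examples, 10 ordinal classes (e.g. age bins)
    targets = torch.tensor([0, 3, 5, 9])        # smoothing mass goes to the two neighbouring bins
    print('two-neighbour smoothed CE:', smoother(logits, targets).item())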
--------------------------------------------------------------------------------
/models/extractors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Function
5 | from collections import OrderedDict
6 | import numpy as np
7 |
8 |
9 | class TDNN(nn.Module):
10 |
11 | def __init__(
12 | self,
13 | input_dim=23,
14 | output_dim=512,
15 | context_size=5,
16 | stride=1,
17 | dilation=1,
18 | batch_norm=True,
19 | dropout_p=0.0,
20 | padding=0
21 | ):
22 | super(TDNN, self).__init__()
23 | self.context_size = context_size
24 | self.stride = stride
25 | self.input_dim = input_dim
26 | self.output_dim = output_dim
27 | self.dilation = dilation
28 | self.dropout_p = dropout_p
29 | self.padding = padding
30 |
31 | self.kernel = nn.Conv1d(self.input_dim,
32 | self.output_dim,
33 | self.context_size,
34 | stride=self.stride,
35 | padding=self.padding,
36 | dilation=self.dilation)
37 |
38 | self.nonlinearity = nn.LeakyReLU()
39 | self.batch_norm = batch_norm
40 | if batch_norm:
41 | self.bn = nn.BatchNorm1d(output_dim)
42 | self.drop = nn.Dropout(p=self.dropout_p)
43 |
44 | def forward(self, x):
45 | '''
46 | input: size (batch, seq_len, input_features)
47 |         output: size (batch, new_seq_len, output_features)
48 | '''
49 |
50 | _, _, d = x.shape
51 | assert (d == self.input_dim), 'Input dimension was wrong. Expected ({}), got ({})'.format(self.input_dim, d)
52 |
53 | x = self.kernel(x.transpose(1,2))
54 | x = self.nonlinearity(x)
55 | x = self.drop(x)
56 |
57 | if self.batch_norm:
58 | x = self.bn(x)
59 | return x.transpose(1,2)
60 |
61 |
62 | class StatsPool(nn.Module):
63 |
64 | def __init__(self, floor=1e-10, bessel=False):
65 | super(StatsPool, self).__init__()
66 | self.floor = floor
67 | self.bessel = bessel
68 |
69 | def forward(self, x):
70 | means = torch.mean(x, dim=1)
71 | _, t, _ = x.shape
72 | if self.bessel:
73 | t = t - 1
74 | residuals = x - means.unsqueeze(1)
75 | numerator = torch.sum(residuals**2, dim=1)
76 | stds = torch.sqrt(torch.clamp(numerator, min=self.floor)/t)
77 | x = torch.cat([means, stds], dim=1)
78 | return x
79 |
80 |
81 | class ETDNN(nn.Module):
82 |
83 | def __init__(
84 | self,
85 | features_per_frame=80,
86 | hidden_features=1024,
87 | embed_features=256,
88 | dropout_p=0.0,
89 | batch_norm=True
90 | ):
91 | super(ETDNN, self).__init__()
92 | self.features_per_frame = features_per_frame
93 | self.hidden_features = hidden_features
94 | self.embed_features = embed_features
95 |
96 | self.dropout_p = dropout_p
97 | self.batch_norm = batch_norm
98 | tdnn_kwargs = {'dropout_p':dropout_p, 'batch_norm':self.batch_norm}
99 | self.nl = nn.LeakyReLU()
100 |
101 | self.frame1 = TDNN(input_dim=self.features_per_frame, output_dim=self.hidden_features, context_size=5, dilation=1, **tdnn_kwargs)
102 | self.frame2 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=1, dilation=1, **tdnn_kwargs)
103 | self.frame3 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=3, dilation=2, **tdnn_kwargs)
104 | self.frame4 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=1, dilation=1, **tdnn_kwargs)
105 | self.frame5 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=3, dilation=3, **tdnn_kwargs)
106 | self.frame6 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=1, dilation=1, **tdnn_kwargs)
107 | self.frame7 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=3, dilation=4, **tdnn_kwargs)
108 | self.frame8 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features, context_size=1, dilation=1, **tdnn_kwargs)
109 | self.frame9 = TDNN(input_dim=self.hidden_features, output_dim=self.hidden_features*3, context_size=1, dilation=1, **tdnn_kwargs)
110 |
111 | self.tdnn_list = nn.Sequential(self.frame1, self.frame2, self.frame3, self.frame4, self.frame5, self.frame6, self.frame7, self.frame8, self.frame9)
112 | self.statspool = StatsPool()
113 |
114 | self.fc_embed = nn.Linear(self.hidden_features*6, self.embed_features)
115 |
116 | def forward(self, x):
117 | x = self.tdnn_list(x)
118 | x = self.statspool(x)
119 | x = self.fc_embed(x)
120 | return x
121 |
122 |
123 | class XTDNN(nn.Module):
124 |
125 | def __init__(
126 | self,
127 | features_per_frame=30,
128 | final_features=1500,
129 | embed_features=512,
130 | dropout_p=0.0,
131 | batch_norm=True
132 | ):
133 | super(XTDNN, self).__init__()
134 | self.features_per_frame = features_per_frame
135 | self.final_features = final_features
136 | self.embed_features = embed_features
137 | self.dropout_p = dropout_p
138 | self.batch_norm = batch_norm
139 | tdnn_kwargs = {'dropout_p':dropout_p, 'batch_norm':self.batch_norm}
140 |
141 | self.frame1 = TDNN(input_dim=self.features_per_frame, output_dim=512, context_size=5, dilation=1, **tdnn_kwargs)
142 | self.frame2 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=2, **tdnn_kwargs)
143 | self.frame3 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=3, **tdnn_kwargs)
144 | self.frame4 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=1, **tdnn_kwargs)
145 | self.frame5 = TDNN(input_dim=512, output_dim=self.final_features, context_size=1, dilation=1, **tdnn_kwargs)
146 |
147 | self.tdnn_list = nn.Sequential(self.frame1, self.frame2, self.frame3, self.frame4, self.frame5)
148 | self.statspool = StatsPool()
149 |
150 | self.fc_embed = nn.Linear(self.final_features*2, self.embed_features)
151 | self.nl_embed = nn.LeakyReLU()
152 | self.bn_embed = nn.BatchNorm1d(self.embed_features)
153 | self.drop_embed = nn.Dropout(p=self.dropout_p)
154 |
155 | def forward(self, x):
156 | x = self.tdnn_list(x)
157 | x = self.statspool(x)
158 | x = self.fc_embed(x)
159 | x = self.nl_embed(x)
160 | x = self.bn_embed(x)
161 | x = self.drop_embed(x)
162 | return x
163 |
164 |
165 | class XTDNN_ILayer(nn.Module):
166 |
167 | def __init__(
168 | self,
169 | features_per_frame=30,
170 | dropout_p=0.0,
171 | batch_norm=True
172 | ):
173 | super().__init__()
174 | self.features_per_frame = features_per_frame
175 | self.dropout_p = dropout_p
176 | self.batch_norm = batch_norm
177 | tdnn_kwargs = {'dropout_p':dropout_p, 'batch_norm':self.batch_norm}
178 |
179 | self.frame1 = TDNN(input_dim=self.features_per_frame, output_dim=512, context_size=5, dilation=1, **tdnn_kwargs)
180 |
181 | def forward(self, x):
182 | x = self.frame1(x)
183 | return x
184 |
185 |
186 | class XTDNN_OLayer(nn.Module):
187 |
188 | def __init__(
189 | self,
190 | final_features=1500,
191 | embed_features=512,
192 | dropout_p=0.0,
193 | batch_norm=True
194 | ):
195 | super().__init__()
196 | self.final_features = final_features
197 | self.embed_features = embed_features
198 | self.dropout_p = dropout_p
199 | self.batch_norm = batch_norm
200 | tdnn_kwargs = {'dropout_p':dropout_p, 'batch_norm':self.batch_norm}
201 |
202 | self.frame2 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=2, **tdnn_kwargs)
203 | self.frame3 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=3, **tdnn_kwargs)
204 | self.frame4 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=1, **tdnn_kwargs)
205 | self.frame5 = TDNN(input_dim=512, output_dim=self.final_features, context_size=1, dilation=1, **tdnn_kwargs)
206 |
207 | self.statspool = StatsPool()
208 |
209 | self.fc_embed = nn.Linear(self.final_features*2, self.embed_features)
210 | self.nl_embed = nn.LeakyReLU()
211 | self.bn_embed = nn.BatchNorm1d(self.embed_features)
212 | self.drop_embed = nn.Dropout(p=self.dropout_p)
213 |
214 | def forward(self, x):
215 | x = self.frame2(x)
216 | x = self.frame3(x)
217 | x = self.frame4(x)
218 | x = self.frame5(x)
219 | x = self.statspool(x)
220 | x = self.fc_embed(x)
221 | x = self.nl_embed(x)
222 | x = self.bn_embed(x)
223 | x = self.drop_embed(x)
224 | return x
225 |
226 |
227 |
228 | class DenseReLU(nn.Module):
229 |
230 | def __init__(self, in_dim, out_dim):
231 | super(DenseReLU, self).__init__()
232 | self.fc = nn.Linear(in_dim, out_dim)
233 | self.bn = nn.BatchNorm1d(out_dim)
234 | self.nl = nn.LeakyReLU()
235 |
236 | def forward(self, x):
237 | x = self.fc(x)
238 | x = self.nl(x)
239 | if len(x.shape) > 2:
240 | x = self.bn(x.transpose(1,2)).transpose(1,2)
241 | else:
242 | x = self.bn(x)
243 | return x
244 |
245 |
246 | class FTDNNLayer(nn.Module):
247 |
248 | def __init__(self, in_dim, out_dim, bottleneck_dim, context_size=2, dilations=None, paddings=None, alpha=0.0):
249 | '''
250 | 3 stage factorised TDNN http://danielpovey.com/files/2018_interspeech_tdnnf.pdf
251 | '''
252 | super(FTDNNLayer, self).__init__()
253 | paddings = [1,1,1] if not paddings else paddings
254 | dilations = [2,2,2] if not dilations else dilations
255 | kwargs = {'bias':False}
256 | self.factor1 = nn.Conv1d(in_dim, bottleneck_dim, context_size, padding=paddings[0], dilation=dilations[0], **kwargs)
257 | self.factor2 = nn.Conv1d(bottleneck_dim, bottleneck_dim, context_size, padding=paddings[1], dilation=dilations[1], **kwargs)
258 | self.factor3 = nn.Conv1d(bottleneck_dim, out_dim, context_size, padding=paddings[2], dilation=dilations[2], **kwargs)
259 | self.reset_parameters()
260 | self.nl = nn.ReLU()
261 | self.bn = nn.BatchNorm1d(out_dim)
262 | self.dropout = SharedDimScaleDropout(alpha=alpha, dim=1)
263 |
264 | def forward(self, x):
265 | ''' input (batch_size, seq_len, in_dim) '''
266 | assert (x.shape[-1] == self.factor1.weight.shape[1])
267 | x = self.factor1(x.transpose(1,2))
268 | x = self.factor2(x)
269 | x = self.factor3(x)
270 | x = self.nl(x)
271 | x = self.bn(x).transpose(1,2)
272 | x = self.dropout(x)
273 | return x
274 |
275 | def step_semi_orth(self):
276 | with torch.no_grad():
277 | factor1_M = self.get_semi_orth_weight(self.factor1)
278 | factor2_M = self.get_semi_orth_weight(self.factor2)
279 | self.factor1.weight.copy_(factor1_M)
280 | self.factor2.weight.copy_(factor2_M)
281 |
282 | def reset_parameters(self):
283 | # Standard dev of M init values is inverse of sqrt of num cols
284 | nn.init._no_grad_normal_(self.factor1.weight, 0., self.get_M_shape(self.factor1.weight)[1]**-0.5)
285 | nn.init._no_grad_normal_(self.factor2.weight, 0., self.get_M_shape(self.factor2.weight)[1]**-0.5)
286 |
287 | def orth_error(self):
288 | factor1_err = self.get_semi_orth_error(self.factor1).item()
289 | factor2_err = self.get_semi_orth_error(self.factor2).item()
290 | return factor1_err + factor2_err
291 |
292 | @staticmethod
293 | def get_semi_orth_weight(conv1dlayer):
294 | # updates conv1 weight M using update rule to make it more semi orthogonal
295 | # based off ConstrainOrthonormalInternal in nnet-utils.cc in Kaldi src/nnet3
296 | # includes the tweaks related to slowing the update speed
297 | # only an implementation of the 'floating scale' case
298 | with torch.no_grad():
299 | update_speed = 0.125
300 | orig_shape = conv1dlayer.weight.shape
301 | # a conv weight differs slightly from TDNN formulation:
302 | # Conv weight: (out_filters, in_filters, kernel_width)
303 | # TDNN weight M is of shape: (in_dim, out_dim) or [rows, cols]
304 | # the in_dim of the TDNN weight is equivalent to in_filters * kernel_width of the Conv
305 | M = conv1dlayer.weight.reshape(orig_shape[0], orig_shape[1]*orig_shape[2]).T
306 | # M now has shape (in_dim[rows], out_dim[cols])
307 | mshape = M.shape
308 | if mshape[0] > mshape[1]: # semi orthogonal constraint for rows > cols
309 | M = M.T
310 | P = torch.mm(M, M.T)
311 | PP = torch.mm(P, P.T)
312 | trace_P = torch.trace(P)
313 | trace_PP = torch.trace(PP)
314 | ratio = trace_PP * P.shape[0] / (trace_P * trace_P)
315 |
316 | # the following is the tweak to avoid divergence (more info in Kaldi)
317 | assert ratio > 0.999
318 | if ratio > 1.02:
319 | update_speed *= 0.5
320 | if ratio > 1.1:
321 | update_speed *= 0.5
322 |
323 | scale2 = trace_PP/trace_P
324 | update = P - (torch.matrix_power(P, 0) * scale2)
325 | alpha = update_speed / scale2
326 | update = (-4.0 * alpha) * torch.mm(update, M)
327 | updated = M + update
328 | # updated has shape (cols, rows) if rows > cols, else has shape (rows, cols)
329 | # Transpose (or not) to shape (cols, rows) (IMPORTANT, s.t. correct dimensions are reshaped)
330 | # Then reshape to (cols, in_filters, kernel_width)
331 | return updated.reshape(*orig_shape) if mshape[0] > mshape[1] else updated.T.reshape(*orig_shape)
332 |
333 | @staticmethod
334 | def get_M_shape(conv_weight):
335 | orig_shape = conv_weight.shape
336 | return (orig_shape[1]*orig_shape[2], orig_shape[0])
337 |
338 | @staticmethod
339 | def get_semi_orth_error(conv1dlayer):
340 | with torch.no_grad():
341 | orig_shape = conv1dlayer.weight.shape
342 | M = conv1dlayer.weight.reshape(orig_shape[0], orig_shape[1]*orig_shape[2])
343 | mshape = M.shape
344 | if mshape[0] > mshape[1]: # semi orthogonal constraint for rows > cols
345 | M = M.T
346 | P = torch.mm(M, M.T)
347 | PP = torch.mm(P, P.T)
348 | trace_P = torch.trace(P)
349 | trace_PP = torch.trace(PP)
350 | ratio = trace_PP * P.shape[0] / (trace_P * trace_P)
351 | scale2 = torch.sqrt(trace_PP/trace_P) ** 2
352 | update = P - (torch.matrix_power(P, 0) * scale2)
353 | return torch.norm(update, p='fro')
354 |
355 |
356 | class FTDNN(nn.Module):
357 |
358 | def __init__(self, in_dim=30, embedding_dim=512):
359 | '''
360 | The FTDNN architecture from
361 | "State-of-the-art speaker recognition with neural network embeddings in
362 | NIST SRE18 and Speakers in the Wild evaluations"
363 | https://www.sciencedirect.com/science/article/pii/S0885230819302700
364 | '''
365 | super(FTDNN, self).__init__()
366 |
367 | self.layer01 = TDNN(input_dim=in_dim, output_dim=512, context_size=5, padding=2)
368 | self.layer02 = FTDNNLayer(512, 1024, 256, context_size=2, dilations=[2,2,2], paddings=[1,1,1])
369 | self.layer03 = FTDNNLayer(1024, 1024, 256, context_size=1, dilations=[1,1,1], paddings=[0,0,0])
370 | self.layer04 = FTDNNLayer(1024, 1024, 256, context_size=2, dilations=[3,3,2], paddings=[2,1,1])
371 | self.layer05 = FTDNNLayer(2048, 1024, 256, context_size=1, dilations=[1,1,1], paddings=[0,0,0])
372 | self.layer06 = FTDNNLayer(1024, 1024, 256, context_size=2, dilations=[3,3,2], paddings=[2,1,1])
373 | self.layer07 = FTDNNLayer(3072, 1024, 256, context_size=2, dilations=[3,3,2], paddings=[2,1,1])
374 | self.layer08 = FTDNNLayer(1024, 1024, 256, context_size=2, dilations=[3,3,2], paddings=[2,1,1])
375 | self.layer09 = FTDNNLayer(3072, 1024, 256, context_size=1, dilations=[1,1,1], paddings=[0,0,0])
376 | self.layer10 = DenseReLU(1024, 2048)
377 |
378 | self.layer11 = StatsPool()
379 |
380 | self.layer12 = DenseReLU(4096, embedding_dim)
381 |
382 |
383 | def forward(self, x):
384 | x = self.layer01(x)
385 | x_2 = self.layer02(x)
386 | x_3 = self.layer03(x_2)
387 | x_4 = self.layer04(x_3)
388 | skip_5 = torch.cat([x_4, x_3], dim=-1)
389 | x = self.layer05(skip_5)
390 | x_6 = self.layer06(x)
391 | skip_7 = torch.cat([x_6, x_4, x_2], dim=-1)
392 | x = self.layer07(skip_7)
393 | x_8 = self.layer08(x)
394 | skip_9 = torch.cat([x_8, x_6, x_4], dim=-1)
395 | x = self.layer09(skip_9)
396 | x = self.layer10(x)
397 | x = self.layer11(x)
398 | x = self.layer12(x)
399 | return x
400 |
401 | def step_ftdnn_layers(self):
402 | for layer in self.children():
403 | if isinstance(layer, FTDNNLayer):
404 | layer.step_semi_orth()
405 |
406 | def set_dropout_alpha(self, alpha):
407 | for layer in self.children():
408 | if isinstance(layer, FTDNNLayer):
409 | layer.dropout.alpha = alpha
410 |
411 | def get_orth_errors(self):
412 | errors = 0.
413 | with torch.no_grad():
414 | for layer in self.children():
415 | if isinstance(layer, FTDNNLayer):
416 | errors += layer.orth_error()
417 | return errors
418 |
419 |
420 | class SharedDimScaleDropout(nn.Module):
421 | def __init__(self, alpha: float = 0.5, dim=1):
422 | '''
423 | Continuous scaled dropout that is const over chosen dim (usually across time)
424 | Multiplies inputs by random mask taken from Uniform([1 - 2\alpha, 1 + 2\alpha])
425 | '''
426 | super(SharedDimScaleDropout, self).__init__()
427 | if alpha > 0.5 or alpha < 0:
428 | raise ValueError("alpha must be between 0 and 0.5")
429 | self.alpha = alpha
430 | self.dim = dim
431 | self.register_buffer('mask', torch.tensor(0.))
432 |
433 | def forward(self, X):
434 | if self.training:
435 | if self.alpha != 0.:
436 | # sample mask from uniform dist with dim of length 1 in self.dim and then repeat to match size
437 | tied_mask_shape = list(X.shape)
438 | tied_mask_shape[self.dim] = 1
439 | repeats = [1 if i != self.dim else X.shape[self.dim] for i in range(len(X.shape))]
440 | return X * self.mask.repeat(tied_mask_shape).uniform_(1 - 2*self.alpha, 1 + 2*self.alpha).repeat(repeats)
441 | # expected value of dropout mask is 1 so no need to scale outputs like vanilla dropout
442 | return X
443 |
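if __name__ == '__main__':
    # Editor's note: a minimal shape check added for illustration; it is not part of the
    # original file. The extractors take input of shape (batch, seq_len, features_per_frame)
    # and return one fixed-size embedding per utterance. All sizes below are arbitrary.
    feats = torch.randn(4, 300, 30)     # 4 utterances, 300 frames, 30-dim acoustic features
    xtdnn = XTDNN(features_per_frame=30, embed_features=512)
    xtdnn.eval()                        # eval mode so BatchNorm uses running statistics
    with torch.no_grad():
        print('XTDNN embedding:', tuple(xtdnn(feats).shape))    # (4, 512)

    ftdnn = FTDNN(in_dim=30, embedding_dim=512)
    ftdnn.eval()
    with torch.no_grad():
        print('FTDNN embedding:', tuple(ftdnn(feats).shape))    # (4, 512)
    ftdnn.step_ftdnn_layers()           # semi-orthogonal constraint step, normally called during training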
--------------------------------------------------------------------------------
/models/sim_predictors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | class LSTMSimilarity(nn.Module):
7 |
8 | def __init__(self, input_size=256, hidden_size=256, num_layers=2):
9 | super(LSTMSimilarity, self).__init__()
10 | self.lstm = nn.LSTM(input_size,
11 | hidden_size,
12 | num_layers=num_layers,
13 | bidirectional=True,
14 | batch_first=True)
15 | self.fc1 = nn.Linear(hidden_size*2, 64)
16 | self.nl = nn.ReLU(inplace=True)
17 | self.fc2 = nn.Linear(64, 1)
18 |
19 | def forward(self, x):
20 | x, _ = self.lstm(x)
21 | x = self.fc1(x)
22 | x = self.nl(x)
23 | x = self.fc2(x).squeeze(2)
24 | return x
25 |
26 | class LSTMSimilarityCosRes(nn.Module):
27 |
28 | def __init__(self, input_size=256, hidden_size=256, num_layers=2):
29 | '''
30 | Like the LSTM Model but the LSTM only has to learn a modification to the cosine sim:
31 | y = LSTM(x) + pwise_cos_sim(x)
32 | '''
33 | super(LSTMSimilarityCosRes, self).__init__()
34 | self.lstm = nn.LSTM(input_size,
35 | hidden_size,
36 | num_layers=num_layers,
37 | bidirectional=True,
38 | batch_first=True)
39 | self.fc1 = nn.Linear(hidden_size*2, 64)
40 | self.nl = nn.ReLU(inplace=True)
41 | self.fc2 = nn.Linear(64, 1)
42 |
43 | self.pdistlayer = pCosineSiamese()
44 |
45 | def forward(self, x):
46 | cs = self.pdistlayer(x)
47 | x, _ = self.lstm(x)
48 | x = self.fc1(x)
49 | x = self.nl(x)
50 | x = self.fc2(x).squeeze(2)
51 | x += cs
52 | return x
53 |
54 | class LSTMSimilarityCosWS(nn.Module):
55 |
56 | def __init__(self, input_size=256, hidden_size=256, num_layers=2):
57 | '''
58 | Like the LSTM Model but the LSTM only has to learn a weighted sum of it and the cosine sim:
59 | y = A*LSTM(x) + B*pwise_cos_sim(x)
60 | '''
61 | super(LSTMSimilarityCosWS, self).__init__()
62 | self.lstm = nn.LSTM(input_size,
63 | hidden_size,
64 | num_layers=num_layers,
65 | bidirectional=True,
66 | batch_first=True)
67 | self.fc1 = nn.Linear(hidden_size*2, 64)
68 | self.nl = nn.ReLU()
69 | self.fc2 = nn.Linear(64, 1)
70 | self.weightsum = nn.Linear(2,1)
71 | self.pdistlayer = pCosineSiamese()
72 |
73 | def forward(self, x):
74 | cs = self.pdistlayer(x)
75 | x, _ = self.lstm(x)
76 | x = self.fc1(x)
77 | x = self.nl(x)
78 | x = torch.sigmoid(self.fc2(x).squeeze(2))
79 | x = torch.stack([x, cs], dim=-1)
80 | return self.weightsum(x).squeeze(-1)
81 |
82 |
83 | class pCosineSim(nn.Module):
84 |
85 | def __init__(self):
86 | super(pCosineSim, self).__init__()
87 |
88 | def forward(self, x):
89 | cs = []
90 | for j in range(x.shape[0]):
91 | cs_row = torch.clamp(torch.mm(x[j].unsqueeze(1).transpose(0,1), x.transpose(0,1)) / (torch.norm(x[j]) * torch.norm(x, dim=1)), 1e-6)
92 | cs.append(cs_row)
93 | return torch.cat(cs)
94 |
95 | class pbCosineSim(nn.Module):
96 |
97 | def __init__(self):
98 | super(pbCosineSim, self).__init__()
99 |
100 | def forward(self, x):
101 | '''
102 | Batch pairwise cosine similarity:
103 |
104 | input (batch_size, seq_len, d)
105 | output (batch_size, seq_len, seq_len)
106 | '''
107 | cs = []
108 | for j in range(x.shape[1]):
109 | cs_row = torch.clamp(torch.bmm(x[:, j, :].unsqueeze(1), x.transpose(1,2)) / (torch.norm(x[:, j, :], dim=-1).unsqueeze(1) * torch.norm(x, dim=-1)).unsqueeze(1), 1e-6)
110 | cs.append(cs_row)
111 | return torch.cat(cs, dim=1)
112 |
113 | class pCosineSiamese(nn.Module):
114 |
115 | def __init__(self):
116 | super(pCosineSiamese, self).__init__()
117 |
118 | def forward(self, x):
119 | '''
120 | split every element in last dimension/2 and take cosine distance
121 | '''
122 | x1, x2 = torch.split(x, x.shape[-1]//2, dim=-1)
123 | return F.cosine_similarity(x1, x2, dim=-1)
124 |
125 | def subsequent_mask(size):
126 | "Mask out subsequent positions."
127 | attn_shape = (1, size, size)
128 | mask = np.triu(np.ones(attn_shape), k=1)
129 | mask[mask==1.0] = float('-inf')
130 | return torch.FloatTensor(mask).squeeze(0)
131 |
132 | class TransformerSim(nn.Module):
133 |
134 | def __init__(self, d_model=256, nhead=4, num_encoder_layers=2, dim_feedforward=1024):
135 | super(TransformerSim, self).__init__()
136 |
137 | self.tf = nn.Transformer(d_model=d_model,
138 | nhead=nhead,
139 | num_encoder_layers=num_encoder_layers,
140 | num_decoder_layers=num_encoder_layers,
141 | dim_feedforward=dim_feedforward)
142 | self.out_embed = nn.Embedding(3, d_model)
143 | self.generator = nn.Linear(d_model, 3)
144 |
145 | def forward(self, src, tgt, src_mask=None, tgt_mask=None):
146 | x = self.tf(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
147 | x = self.generator(x)
148 | return x
149 |
150 | def encode(self, src):
151 | x = self.tf.encoder(src)
152 | return x
153 |
154 | class XTransformerSim(nn.Module):
155 |
156 | def __init__(self, d_model=256, nhead=4, num_encoder_layers=4, dim_feedforward=512):
157 | super(XTransformerSim, self).__init__()
158 |
159 | self.tf = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward), num_encoder_layers)
160 | # self.pdistlayer = pCosineSiamese()
161 | self.fc1 = nn.Linear(d_model, 1)
162 | # self.weightsum = nn.Linear(2, 1)
163 |
164 | def forward(self, src):
165 | # cs = self.pdistlayer(src)
166 | x = self.tf(src)
167 | x = self.fc1(x).squeeze(-1)
168 | # x = torch.stack([x, cs], dim=-1)
169 | # x = self.weightsum(x).squeeze(-1)
170 | return x
171 |
172 |
173 | class XTransformerLSTMSim(nn.Module):
174 |
175 | def __init__(self, d_model=256, nhead=4, num_encoder_layers=2, dim_feedforward=1024):
176 | super(XTransformerLSTMSim, self).__init__()
177 | self.tf = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward), num_encoder_layers)
178 | self.lstm = nn.LSTM(d_model,
179 | d_model,
180 | num_layers=2,
181 | bidirectional=True)
182 | self.fc1 = nn.Linear(d_model*2, 64)
183 | self.fc2 = nn.Linear(64, 1)
184 |
185 | def forward(self, src):
186 | out = self.tf(src)
187 | out, _ = self.lstm(out)
188 | out = F.relu(self.fc1(out))
189 | out = self.fc2(out)
190 | return out
191 |
192 | class AttnDecoderRNN(nn.Module):
193 | def __init__(self, hidden_size=256, output_size=1, dropout_p=0.1, max_length=250):
194 | super(AttnDecoderRNN, self).__init__()
195 | self.hidden_size = hidden_size
196 | self.output_size = output_size
197 | self.dropout_p = dropout_p
198 | self.max_length = max_length
199 |
200 |         self.embedding = nn.Embedding(self.output_size, self.hidden_size)
201 | self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
202 | self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
203 | self.dropout = nn.Dropout(self.dropout_p)
204 | self.gru = nn.GRU(self.hidden_size, self.hidden_size)
205 | self.out = nn.Linear(self.hidden_size, self.output_size)
206 |
207 | def forward(self, input, hidden, encoder_outputs):
208 | embedded = self.embedding(input).view(1, 1, -1)
209 | embedded = self.dropout(embedded)
210 |
211 | attn_weights = F.softmax(
212 | self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
213 | attn_applied = torch.bmm(attn_weights.unsqueeze(0),
214 | encoder_outputs.unsqueeze(0))
215 |
216 | output = torch.cat((embedded[0], attn_applied[0]), 1)
217 | output = self.attn_combine(output).unsqueeze(0)
218 |
219 | output = F.relu(output)
220 | output, hidden = self.gru(output, hidden)
221 | output = self.out(output[0])
222 | return output, hidden, attn_weights
223 |
224 | def initHidden(self, device):
225 | return torch.zeros(1, 1, self.hidden_size, device=device)
226 |
227 |
228 | class XTransformer(nn.Module):
229 |
230 | def __init__(self, d_model=128, nhead=8, num_encoder_layers=6, dim_feedforward=1024):
231 | super(XTransformer, self).__init__()
232 |
233 | self.tf = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward), num_encoder_layers)
234 | # self.fc1 = nn.Linear(d_model, d_model)
235 | # self.nl = nn.ReLU(inplace=True)
236 | # self.fc2 = nn.Linear(d_model, d_model)
237 |
238 | self.pdist = pCosineSim()
239 |
240 | def forward(self, src):
241 | x = self.tf(src)
242 | x = x.squeeze(1)
243 | # x = self.fc1(x)
244 | # x = self.nl(x)
245 | # x = self.fc2(x)
246 | x = F.normalize(x, p=2, dim=-1)
247 | sim = self.pdist(x)
248 | return sim
249 |
250 |
251 | class XTransformerRes(nn.Module):
252 |
253 | def __init__(self, d_model=128, nhead=8, num_encoder_layers=6, dim_feedforward=1024):
254 | super(XTransformerRes, self).__init__()
255 | self.tf = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward), num_encoder_layers)
256 | self.pdist = pCosineSim()
257 |
258 | def forward(self, src):
259 | cs = self.pdist(src.squeeze(1))
260 | x = self.tf(src)
261 | x = x.squeeze(1)
262 | x = F.normalize(x, p=2, dim=-1)
263 | sim = self.pdist(x)
264 | sim += cs
265 | return torch.clamp(sim/2, 1e-16, 1.-1e-16)
266 |
267 |
268 | class XTransformerMask(nn.Module):
269 |
270 | def __init__(self, d_model=128, nhead=8, num_encoder_layers=6, dim_feedforward=1024):
271 |
272 | super(XTransformerMask, self).__init__()
273 |
274 | self.tf = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward), num_encoder_layers)
275 |
276 | self.pdist = pCosineSim()
277 |
278 | def forward(self, src):
279 | mask = self.tf(src)
280 | mask = mask.squeeze(1)
281 | mask = torch.sigmoid(mask)
282 | out = F.normalize(mask * src.squeeze(1), p=2, dim=-1)
283 | sim = self.pdist(out)
284 | return sim
285 |
286 |
287 | class XTransformerMaskRes(nn.Module):
288 |
289 | def __init__(self, d_model=128, nhead=8, num_encoder_layers=6, dim_feedforward=1024):
290 |
291 | super(XTransformerMaskRes, self).__init__()
292 | self.tf = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward), num_encoder_layers)
293 | self.pdist = pbCosineSim()
294 |
295 | def forward(self, src, src_mask=None):
296 | cs = self.pdist(src.transpose(0, 1))
297 | mask = self.tf(src, src_key_padding_mask=src_mask)
298 | mask = torch.sigmoid(mask)
299 | out = F.normalize(mask * src, p=2, dim=-1)
300 | sim = self.pdist(out.transpose(0, 1))
301 | sim += cs
302 | return torch.clamp(sim/2, 1e-1, 1.-1e-1)
303 |
304 |
305 |
306 | class LSTMTransform(nn.Module):
307 |
308 | def __init__(self, input_size=128, hidden_size=256, num_layers=2):
309 | super(LSTMTransform, self).__init__()
310 | self.lstm = nn.LSTM(input_size,
311 | hidden_size,
312 | num_layers=num_layers,
313 | bidirectional=True,
314 | batch_first=True)
315 |
316 | self.fc1 = nn.Linear(512, 256)
317 | self.nl = nn.ReLU(inplace=True)
318 | self.fc2 = nn.Linear(256, 256)
319 |
320 | self.pdist = pCosineSim()
321 |
322 | def forward(self, x):
323 | x, _ = self.lstm(x)
324 | x = x.squeeze(0)
325 | x = self.fc1(x)
326 | x = self.nl(x)
327 | x = self.fc2(x)
328 | sim = self.pdist(x)
329 | return 1. - sim
330 |
331 |
332 | class ConvSim(nn.Module):
333 |
334 | def __init__(self, input_dim=256):
335 | super(ConvSim, self).__init__()
336 | self.input_dim = input_dim
337 | self.layer1 = nn.Sequential(
338 | nn.Conv1d(input_dim, 32, kernel_size=3, stride=1, padding=1),
339 | nn.ReLU(),
340 | nn.BatchNorm1d(32))
341 | self.layer2 = nn.Sequential(
342 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1, dilation=1),
343 | nn.ReLU(),
344 | nn.BatchNorm1d(32))
345 | self.layer3 = nn.Sequential(
346 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1, dilation=2),
347 | nn.ReLU(),
348 | nn.BatchNorm1d(32))
349 | self.layer4 = nn.Sequential(
350 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=3, dilation=3),
351 | nn.ReLU(),
352 | nn.BatchNorm1d(32))
353 | self.layer5 = nn.Sequential(
354 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=3, dilation=3),
355 | nn.ReLU(),
356 | nn.BatchNorm1d(32))
357 | self.layer6 = nn.Sequential(
358 | nn.Conv1d(32, 1, kernel_size=1, stride=1, padding=1),
359 | nn.ReLU(),
360 | nn.BatchNorm1d(1))
361 |
362 | def forward(self, x):
363 | if x.shape[-1] == self.input_dim:
364 | x = x.permute(0,2,1)
365 | x = self.layer1(x)
366 | x = self.layer2(x)
367 | x = self.layer3(x)
368 | x = self.layer4(x)
369 | x = self.layer5(x)
370 | x = self.layer6(x)
371 | return x.squeeze(1)
372 |
373 | class ConvCosResSim(nn.Module):
374 |
375 | def __init__(self, input_dim=256):
376 | super(ConvCosResSim, self).__init__()
377 | self.pdistlayer = pCosineSiamese()
378 | self.input_dim = input_dim
379 | # self.drop1 = nn.Dropout()
380 | # self.drop2 = nn.Dropout()
381 | # self.drop3 = nn.Dropout()
382 | # self.drop4 = nn.Dropout()
383 | # self.drop5 = nn.Dropout()
384 | self.layer1 = nn.Sequential(
385 | nn.Conv1d(input_dim, 32, kernel_size=3, stride=1, padding=1),
386 | nn.ReLU(),
387 | nn.BatchNorm1d(32))
388 | self.layer2 = nn.Sequential(
389 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1, dilation=1),
390 | nn.ReLU(),
391 | nn.BatchNorm1d(32))
392 | self.layer3 = nn.Sequential(
393 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1, dilation=2),
394 | nn.ReLU(),
395 | nn.BatchNorm1d(32))
396 | self.layer4 = nn.Sequential(
397 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=3, dilation=3),
398 | nn.ReLU(),
399 | nn.BatchNorm1d(32))
400 | self.layer5 = nn.Sequential(
401 | nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=3, dilation=3),
402 | nn.ReLU(),
403 | nn.BatchNorm1d(32))
404 | self.layer6 = nn.Sequential(
405 | nn.Conv1d(32, 1, kernel_size=1, stride=1, padding=1),
406 | nn.ReLU(),
407 | nn.BatchNorm1d(1))
408 |
409 | def forward(self, x):
410 | cs = self.pdistlayer(x)
411 | if x.shape[-1] == self.input_dim:
412 | x = x.permute(0,2,1)
413 | x = self.layer1(x)
414 | # x = self.drop1(x)
415 | x = self.layer2(x)
416 | # x = self.drop2(x)
417 | x = self.layer3(x)
418 | # x = self.drop3(x)
419 | x = self.layer4(x)
420 | # x = self.drop4(x)
421 | x = self.layer5(x)
422 | # x = self.drop5(x)
423 | x = self.layer6(x).squeeze(1)
424 | x += cs
425 | return x
426 |
427 | def set_dropout_p(self, p):
428 | for layer in self.children():
429 | if isinstance(layer, nn.Dropout):
430 | layer.p = p
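

if __name__ == '__main__':
    # Editor's note: an illustrative shape check, not part of the original file. The similarity
    # predictors operate on sequences of segment embeddings: pCosineSim/pbCosineSim return full
    # pairwise similarity matrices, the LSTM models score each position of a window, and
    # pCosineSiamese expects two embeddings concatenated along the last dimension. Sizes are arbitrary.
    seq = torch.randn(2, 50, 256)       # 2 windows, 50 segments, 256-dim embeddings
    print('pairwise (batched):', tuple(pbCosineSim()(seq).shape))   # (2, 50, 50)
    print('pairwise (single):', tuple(pCosineSim()(seq[0]).shape))  # (50, 50)

    lstm_sim = LSTMSimilarity(input_size=256, hidden_size=256)
    print('LSTM similarity:', tuple(lstm_sim(seq).shape))           # (2, 50)

    pair = torch.randn(2, 50, 512)      # each vector holds two 256-dim embeddings side by side
    print('siamese cosine:', tuple(pCosineSiamese()(pair).shape))   # (2, 50)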
--------------------------------------------------------------------------------
/scotus_data_prep/filter_scp.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env perl
2 | # Copyright 2010-2012 Microsoft Corporation
3 | # Johns Hopkins University (author: Daniel Povey)
4 |
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
12 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
13 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
14 | # MERCHANTABLITY OR NON-INFRINGEMENT.
15 | # See the Apache 2 License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 |
19 | # This script takes a list of utterance-ids or any file whose first field
20 | # of each line is an utterance-id, and filters an scp
21 | # file (or any file whose "n-th" field is an utterance id), printing
22 | # out only those lines whose "n-th" field is in id_list. The index of
23 | # the "n-th" field is 1, by default, but can be changed by using
24 | # the -f switch
25 |
26 | $exclude = 0;
27 | $field = 1;
28 | $shifted = 0;
29 |
30 | do {
31 | $shifted=0;
32 | if ($ARGV[0] eq "--exclude") {
33 | $exclude = 1;
34 | shift @ARGV;
35 | $shifted=1;
36 | }
37 | if ($ARGV[0] eq "-f") {
38 | $field = $ARGV[1];
39 | shift @ARGV; shift @ARGV;
40 | $shifted=1
41 | }
42 | } while ($shifted);
43 |
44 | if(@ARGV < 1 || @ARGV > 2) {
45 |   die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
46 | "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
47 | "Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
48 | "only the lines that were *not* in id_list.\n" .
49 | "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
50 | "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
51 | "-f option, add 1 to the argument.\n" .
52 | "See also: utils/filter_scp.pl .\n";
53 | }
54 |
55 |
56 | $idlist = shift @ARGV;
57 | open(F, "<$idlist") || die "Could not open id-list file $idlist";
58 | while(<F>) {
59 | @A = split;
60 | @A>=1 || die "Invalid id-list file line $_";
61 | $seen{$A[0]} = 1;
62 | }
63 |
64 | if ($field == 1) { # Treat this as special case, since it is common.
65 | while(<>) {
66 | $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
67 | # $1 is what we filter on.
68 | if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
69 | print $_;
70 | }
71 | }
72 | } else {
73 | while(<>) {
74 | @A = split;
75 | @A > 0 || die "Invalid scp file line $_";
76 | @A >= $field || die "Invalid scp file line $_";
77 | if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
78 | print $_;
79 | }
80 | }
81 | }
82 |
83 | # tests:
84 | # the following should print "foo 1"
85 | # ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
86 | # the following should print "bar 2".
87 | # ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)
88 |
--------------------------------------------------------------------------------
/scotus_data_prep/local_info.py:
--------------------------------------------------------------------------------
1 | AVVO_API_ACCESS_TOKEN=''
2 | AVVO_CSE_ID=''
3 | GOOGLE_API_KEY=''
4 |
--------------------------------------------------------------------------------
/scotus_data_prep/step1_downloadmp3.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import ssl
5 | import sys
6 | import urllib
7 | from glob import glob
8 |
9 | import numpy as np
10 | import requests
11 | from tqdm import tqdm
12 |
13 | ssl._create_default_https_context = ssl._create_unverified_context
14 | from collections import OrderedDict
15 | from multiprocessing import Pool
16 |
17 | import dateutil.parser as dparser
18 |
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser(description='Download mp3s and process the case jsons into something usable')
22 | parser.add_argument('--case-folder', type=str, help='Location of the case json files')
23 | parser.add_argument('--base-outfolder', type=str, help='Location of the base outfolder')
24 | args = parser.parse_args()
25 | return args
26 |
27 | def get_mp3_and_transcript(c):
28 | '''
29 | Get the mp3 from case jsons
30 | '''
31 | js = json.load(open(c, encoding='utf-8'), object_pairs_hook=OrderedDict)
32 |
33 | rec_name = os.path.splitext(os.path.basename(c))[0]
34 | case_year = int(rec_name[:4])
35 |
36 | # Only want digital recordings, after October 2005
37 | if case_year <= 2004:
38 | return rec_name, False, False, None
39 |
40 | cutoff_date = dparser.parse('1 October 2005', fuzzy=True)
41 |
42 | if 'oral_argument_audio' in js:
43 | if js['oral_argument_audio']:
44 | if 'href' in js['oral_argument_audio']:
45 | t_url = js['oral_argument_audio']['href']
46 | resp = requests.get(t_url, timeout=20)
47 | if resp.ok:
48 | js = json.loads(resp.content, object_pairs_hook=OrderedDict)
49 | else:
50 | return rec_name, False, False, None
51 |
52 | elif len(js['oral_argument_audio']) == 1:
53 | if 'href' in js['oral_argument_audio'][0]:
54 | t_url = js['oral_argument_audio'][0]['href']
55 | resp = requests.get(t_url, timeout=20)
56 | if resp.ok:
57 | js = json.loads(resp.content, object_pairs_hook=OrderedDict)
58 | else:
59 | return rec_name, False, False, None
60 |
61 | if 'title' in js:
62 | dashind = js['title'].rfind('-')
63 | reduced_title = js['title'][dashind + 1:]
64 | start = reduced_title.find('(')
65 | end = reduced_title.find(')')
66 | if start != -1 and end != -1:
67 | reduced_title = reduced_title[:start - 1]
68 | try:
69 | case_date = dparser.parse(reduced_title, fuzzy=True)
70 | if case_date < cutoff_date:
71 | # print('Recording was too early')
72 | return rec_name, False, False, None
73 | except ValueError:
74 | # print('Couldnt figure out date {}'.format(reduced_title))
75 | return rec_name, False, js, None
76 |
77 | if 'media_file' in js:
78 | if 'transcript' in js:
79 | if js['media_file'] and js['transcript']:
80 | for obj in js['media_file']:
81 | if obj:
82 | if 'href' in obj:
83 | url = obj['href']
84 | if url.endswith('mp3'):
85 | return rec_name, url, js, case_date
86 |
87 | return rec_name, False, js, None
88 |
89 |
90 | def download_mp3(rec_name, url, outfile):
91 | filename = outfile
92 | if os.path.isfile(filename):
93 | return True
94 | try:
95 | request = urllib.request.urlopen(url, timeout=30)
96 | with open(filename+'_tmp', 'wb') as f:
97 | f.write(request.read())
98 | os.rename(filename+'_tmp', filename)
99 | return True
100 |     except Exception:
101 |         if os.path.isfile(filename+'_tmp'): os.remove(filename+'_tmp')
102 | return False
103 |
104 |
105 | def filter_and_process_case(js,
106 | valid_to_invalid_ratio=4,
107 | min_num_utterances=5):
108 | '''
109 |     Goes through a case json and removes invalid utterances.
110 |     The cleaned transcription is returned only if there are at least
111 |     min_num_utterances valid utterances and the valid utterances outnumber the
112 |     invalid ones by at least valid_to_invalid_ratio; otherwise False is returned.
113 | '''
114 | utts = []
115 | utts_spkr = []
116 | rec_speakers = {}
117 |
118 | invalid_utts = 0
119 | # Iterate through each speaker turn
120 | for sec in js['transcript']['sections']:
121 | for turn in sec['turns']:
122 | if not turn['speaker']:
123 | # ignore turns where no speaker is labelled
124 | invalid_utts += 1
125 | continue
126 |
127 | speaker = turn['speaker']
128 | speaker_id = speaker['identifier']
129 |
130 | if not speaker_id.strip():
131 | # ignore turns where no speaker is labelled
132 | invalid_utts += 1
133 | continue
134 |
135 | if speaker_id not in rec_speakers:
136 | rec_speakers[speaker_id] = speaker
137 |
138 | for utt in turn['text_blocks']:
139 | utt['speaker_id'] = speaker_id
140 | utt['start'] = float(utt['start'])
141 | utt['stop'] = float(utt['stop'])
142 |
143 | if utt['start'] >= utt['stop']:
144 | invalid_utts += 1
145 | continue
146 | else:
147 | utts.append(utt)
148 |
149 | transcription = {'utts': utts, 'rec_speakers': rec_speakers}
150 |
151 | if len(utts) >= min_num_utterances:
152 | if invalid_utts > 0:
153 | if len(utts)/invalid_utts >= valid_to_invalid_ratio:
154 | return transcription
155 | else:
156 | return False
157 | else:
158 | return transcription
159 | else:
160 | return False
161 |
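# Editor's note: an illustrative worked example of the filter above (not in the original file).
# With the defaults, a case with 20 valid and 4 invalid utterances gives 20/4 = 5 >= 4, so its
# transcription is kept; 20 valid and 6 invalid gives 20/6 (about 3.3) < 4 and the case is dropped,
# as is any case with fewer than min_num_utterances (5) valid utterances.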
162 | def process_case(c):
163 | '''
164 | Processes from the case json
165 | '''
166 | rec_name, url, js, case_date = get_mp3_and_transcript(c)
167 | if url: # If mp3 was found
168 | transcription = filter_and_process_case(js)
169 | if transcription: # If transcription is valid
170 | mp3_outfile = os.path.join(base_audio_folder, '{}.mp3'.format(rec_name))
171 | transcript_outfile = os.path.join(base_transcript_folder, '{}.json'.format(rec_name))
172 |
173 | transcription['case_date'] = case_date.strftime('%Y/%m/%d')
174 | download_success = download_mp3(rec_name, url, mp3_outfile)
175 |
176 | if download_success:
177 | with open(transcript_outfile, 'w', encoding='utf-8') as outfile:
178 | json.dump(transcription, outfile)
179 |
180 |
181 | def collate_speakers():
182 | '''
183 | Goes inside base_transcript_folder and collates all the speakers
184 | Writes to a dictionary of all speakers:
185 |
186 | speaker_id_dict = {speaker_id_key:{DICT}, etc.}
187 |
188 |     This will be useful later on for parsing their full names, instead of the speaker_id_key
189 | '''
190 | transcripts = glob(os.path.join(base_transcript_folder, '*.json'))
191 | assert len(transcripts) > 1, 'Could only find 1 or less transcription json files'
192 | all_speaker_dict = {}
193 |
194 | for t in transcripts:
195 | js = json.load(open(t, encoding='utf-8'), object_pairs_hook=OrderedDict)
196 | for s in js['rec_speakers']:
197 | if s not in all_speaker_dict:
198 | all_speaker_dict[s] = js['rec_speakers'][s]
199 |
200 | outfile = os.path.join(base_outfolder, 'speaker_ids.json')
201 | with open(outfile, 'w', encoding='utf-8') as outfile:
202 | json.dump(all_speaker_dict, outfile)
203 |
204 |
205 | def mp3_and_transcript_exist(c):
206 | mp3_exist = os.path.isfile(os.path.join(base_audio_folder, '{}.mp3'.format(os.path.splitext(os.path.basename(c))[0])))
207 | transcript_exist = os.path.isfile(os.path.join(base_transcript_folder, '{}.json'.format(os.path.splitext(os.path.basename(c))[0])))
208 | if mp3_exist and transcript_exist:
209 | return True
210 | else:
211 |         return False
212 |
213 |
214 | if __name__ == '__main__':
215 | args = parse_args()
216 | case_folder = args.case_folder
217 | base_outfolder = args.base_outfolder
218 | os.makedirs(base_outfolder, exist_ok=True)
219 |
220 | base_audio_folder = os.path.join(base_outfolder, 'audio')
221 | os.makedirs(base_audio_folder, exist_ok=True)
222 |
223 | base_transcript_folder = os.path.join(base_outfolder, 'transcripts')
224 | os.makedirs(base_transcript_folder, exist_ok=True)
225 |
226 | cases = glob(os.path.join(case_folder, '20*.json'))
227 | print('{} cases found'.format(len(cases)))
228 |
229 |     assert len(cases) > 1, "Could only find one or fewer case json files"
230 |
231 | trimmed_cases = [c for c in cases if not mp3_and_transcript_exist(c)]
232 |
233 | print('Processing {} cases...'.format(len(trimmed_cases)))
234 |
235 | with Pool(20) as p:
236 | for _ in tqdm(p.imap_unordered(process_case, trimmed_cases), total=len(trimmed_cases)):
237 | pass
238 |
239 | collate_speakers()
240 |
241 |
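242 | # Editor's sketch (not part of the original script): the acceptance rule used by
243 | # filter_and_process_case above, written out as a standalone predicate so the
244 | # control flow is explicit. The two thresholds mirror the values of the same
245 | # name used earlier in this file.
246 | def _transcript_is_valid(num_valid, num_invalid, min_num_utterances, valid_to_invalid_ratio):
247 |     if num_valid < min_num_utterances:
248 |         return False
249 |     return num_invalid == 0 or num_valid / num_invalid >= valid_to_invalid_ratio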
--------------------------------------------------------------------------------
/scotus_data_prep/step2_scrape_dob.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import os
5 | import pickle
6 | import re
7 | import ssl
8 | import string
9 | import sys
10 | import time
11 | import warnings
12 | from collections import OrderedDict
13 |
14 | import numpy as np
15 | import requests
16 | import wikipedia
17 | import wptools
18 | from bs4 import BeautifulSoup
19 | from dateutil.relativedelta import relativedelta
20 | from Levenshtein import distance as lev_dist
21 | from local_info import (AVVO_API_ACCESS_TOKEN, AVVO_CSE_ID, GOOGLE_API_KEY)
22 | from tqdm import tqdm
23 |
24 | ssl._create_default_https_context = ssl._create_unverified_context
25 |
26 |
27 | def parse_args():
28 | parser = argparse.ArgumentParser(description='Reads through speaker_ids.json and tries to webscrape the DOB of each lawyer')
29 | parser.add_argument('--base-outfolder', type=str, help='Location of the base outfolder')
30 | parser.add_argument('--overwrite', action='store_true', default=False, help='Overwrite dob pickle (default: False)')
31 | parser.add_argument('--skip-attempted', action='store_true', default=False, help='Skip names which have been attempted before')
32 | args = parser.parse_args()
33 | return args
34 |
35 | class LawyerDOBParser:
36 |
37 | def __init__(self, graduation_dob_offset=25, distance_threshold=4, minimum_age=18):
38 | self.graduation_dob_offset = graduation_dob_offset
39 | self.distance_threshold = distance_threshold
40 | self.scotus_clerks_populated = False
41 | self.minimum_age = minimum_age
42 | self.minimum_dob = datetime.datetime(2005, 10, 1) - relativedelta(years=minimum_age)
43 |
44 | self.google_api_key = GOOGLE_API_KEY
45 | self.avvo_cse_id = AVVO_CSE_ID
46 |
47 | def parse_name(self, name):
48 | '''
49 | Parse name looking at wikipedia, then SCOTUS clerks, then the JUSTIA website
50 |
51 | Input: name
52 | Output: datetime object for D.O.B
53 | '''
54 | if not self.scotus_clerks_populated:
55 | self.get_scotus_clerks()
56 |
57 | print('Searching for DOB of {}....'.format(name))
58 | # Search wikipedia for person
59 | wiki_dob, wiki_info = self.parse_wiki(name)
60 | if wiki_dob:
61 | if wiki_dob <= self.minimum_dob:
62 | return wiki_dob, wiki_info
63 |
64 | # Search through supreme court clerks for person
65 | scotus_dob, scotus_info = self.search_scotus_clerks(name)
66 | if scotus_dob:
67 | scotus_dob = datetime.datetime(scotus_dob, 7, 2)
68 | if scotus_dob <= self.minimum_dob:
69 | return scotus_dob, scotus_info
70 |
71 | # Search through JUSTIA website
72 | justia_dob, justia_info = self.parse_justia(name)
73 | if justia_dob:
74 | justia_dob = datetime.datetime(justia_dob, 7, 2)
75 | if justia_dob <= self.minimum_dob:
76 | return justia_dob, justia_info
77 |
78 | # Search through AVVO website
79 | avvo_dob, avvo_info = self.parse_avvo(name)
80 | if avvo_dob:
81 | time.sleep(1)
82 | avvo_dob = datetime.datetime(avvo_dob, 7, 2)
83 | if avvo_dob <= self.minimum_dob:
84 | return avvo_dob, avvo_info
85 |
86 | print("Couldn't find any age for {}".format(name))
87 | info_list = [wiki_info, scotus_info, justia_info, avvo_info]
88 | collated_info = {'info': {'type': None, 'error': 'no info found', 'collated_info': info_list}}
89 | return None, collated_info
90 |
91 | def parse_wiki(self, name):
92 | search = wikipedia.search(name)
93 | if search:
94 | if self.name_distance(name, search[0]) <= self.distance_threshold:
95 | name = search[0]
96 |
97 | wpage = wptools.page(name, silent=True)
98 | info = {'info': {'type': 'wiki', 'error': None, 'name': name}}
99 | try:
100 | page = wpage.get_parse()
101 | except:
102 | info['info']['error'] = 'page not found'
103 | return None, info
104 | try:
105 | if page.data:
106 | if 'infobox' in page.data:
107 | if 'birth_date' in page.data['infobox']:
108 | dob = page.data['infobox']['birth_date'].strip('{}').split('|')
109 | dinfo = []
110 | for d in dob:
111 | try:
112 | dinfo.append(int(d))
113 | except:
114 | continue
115 | if dinfo:
116 | if len(dinfo) > 3:
117 | dinfo = dinfo[-3:]
118 | if dinfo[0] > 1900: # simple check if 4-digit year recognised
119 | prelim_date = [1, 1, 1]
120 | for i, d in enumerate(dinfo):
121 | prelim_date[i] = d
122 | dob = datetime.datetime(*prelim_date)
123 | info['info']['links'] = page.data['iwlinks']
124 | return dob, info
125 |             info['info']['error'] = "page couldn't be parsed"
126 | return None, info
127 | except:
128 |             info['info']['error'] = "page couldn't be parsed"
129 | return None, info
130 |
131 | def parse_justia(self, name):
132 | searched_name, distance, justia_url = self.search_justia(name)
133 | info = {'info': {'type': 'justia', 'searched_name': searched_name, 'justia_url': justia_url, 'error': None}}
134 | if distance <= self.distance_threshold:
135 | grad_year = self.parse_justia_lawyer(justia_url)
136 | if grad_year:
137 | return grad_year - self.graduation_dob_offset, info
138 | else:
139 | info['info']['error'] = 'no year found'
140 | return None, info
141 | else:
142 | info['info']['error'] = 'distance threshold not met'
143 | return None, info
144 |
145 | def search_justia(self, name):
146 | """
147 | Input: Name to search, i.e. Anthony A. Yang (str,)
148 | Output: Matched name, Levenshtein distance to input, JUSTIA url
149 | """
150 | base_search_url = 'https://lawyers.justia.com/search?profile-id-field=&practice-id-field=&query={}&location='
151 | base_name = name.translate(str.maketrans('', '', string.punctuation)).lower()
152 | name_query = '+'.join(base_name.split())
153 | search_url = base_search_url.format(name_query)
154 |
156 | search_request = requests.get(search_url)
157 | soup = BeautifulSoup(search_request.content, 'lxml')
158 | lawyer_avatars = soup.findAll('a', attrs={'class': 'lawyer-avatar'})
159 |
160 | if lawyer_avatars:
161 | search_names = []
162 | search_urls = []
163 |
164 | for a in lawyer_avatars:
165 | search_names.append(a['title'])
166 | search_urls.append(a['href'])
167 |
168 | search_names = np.array(search_names)
169 | search_names_base = [n.translate(str.maketrans('', '', string.punctuation)).lower() for n in search_names]
170 |
171 | distances = np.array([self.name_distance(name, n) for n in search_names])
172 | search_urls = np.array(search_urls)
173 |
174 | dist_order = np.argsort(distances)
175 | distances = distances[dist_order]
176 | search_urls = search_urls[dist_order]
177 | search_names = search_names[dist_order]
178 |
179 | return search_names[0], distances[0], search_urls[0]
180 | else:
181 | return 'None', 100000, 'None'
182 |
183 | @staticmethod
184 | def parse_justia_lawyer(lawyer_url):
185 | """
186 | Input: Justia lawyer page url
187 | Output: Graduation year
188 | """
189 | r = requests.get(lawyer_url)
190 | soup = BeautifulSoup(r.content, 'lxml')
191 |
192 | jurisdictions = soup.find('div', attrs={'id': 'jurisdictions-block'})
193 | if jurisdictions:
194 | jd_admitted_year = []
195 | for j in jurisdictions:
196 | try:
197 | jd_admitted_year.append(int(j.find('time')['datetime']))
198 | except:
199 | continue
200 | if jd_admitted_year:
201 | return min(jd_admitted_year)
202 | else:
203 |             # look for professional associations if jurisdictions is empty
204 | prof_assoc = None
205 | education = None
206 | blocks = soup.findAll('div', attrs={'class': 'block'})
207 | for block in blocks:
208 | subdivs = block.findAll('div')
209 | for subdiv in subdivs:
210 | if subdiv.text == 'Professional Associations':
211 | prof_assoc = block
212 | break
213 | if subdiv.text == 'Education':
214 | education = block
215 | break
216 |
217 | if prof_assoc:
218 | prof_assoc_year = []
219 | professional_associations = prof_assoc.findAll('time')
220 | for p in professional_associations:
221 | try:
222 | prof_assoc_year.append(int(p['datetime']))
223 | except:
224 | continue
225 | if prof_assoc_year:
226 | return min(prof_assoc_year)
227 |
228 | if education:
229 | education_years = []
230 | education_history = education.findAll('dl')
231 | for e in education_history:
232 | degree_type = e.find('dd').text
233 | if degree_type.strip().translate(str.maketrans('', '', string.punctuation)).lower() == 'jd':
234 | try:
235 | return int(e.find('time')['datetime'])
236 | except:
237 | continue
238 |
239 | def search_scotus_clerks(self, query_name):
240 | assert self.clerk_dob_dict, 'get_scotus_clerks must be called before this function'
241 | # query_name = query_name.translate(str.maketrans('', '', string.punctuation)).lower()
242 | distances = np.array([self.name_distance(query_name, k) for k in self.scotus_clerks])
243 | closest_match = np.argmin(distances)
244 | info = {'info': {'type': 'clerk', 'closest_match': closest_match, 'error': None}}
245 | if distances[closest_match] <= self.distance_threshold:
246 | return self.clerk_dob_dict[self.scotus_clerks[closest_match]], info
247 | else:
248 | info['info']['error'] = 'distance threshold not met'
249 | return None, info
250 |
251 | def get_scotus_clerks(self):
252 | """
253 |         Populates self.clerk_dob_dict with estimated birth years for SCOTUS clerks
254 | """
255 | base_url = 'https://en.wikipedia.org/wiki/List_of_law_clerks_of_the_Supreme_Court_of_the_United_States_({})'
256 | seats = ['Chief_Justice', 'Seat_1', 'Seat_2',
257 | 'Seat_3', 'Seat_4', 'Seat_6', 'Seat_8',
258 | 'Seat_9', 'Seat_10']
259 | urls = [base_url.format(s) for s in seats]
260 |
261 | self.all_cdicts = []
262 | self.clerk_dob_dict = OrderedDict({})
263 |
264 | for url in urls:
265 | mini_clerk_dict = self.parse_clerk_wiki(url)
266 | self.all_cdicts.append(mini_clerk_dict)
267 |
268 | for cdict in self.all_cdicts:
269 | self.clerk_dob_dict = {**self.clerk_dob_dict, **cdict}
270 |
271 | self.scotus_clerks = np.array(list(self.clerk_dob_dict.keys()))
272 | self.scotus_clerks_populated = True
273 |
274 | def parse_clerk_wiki(self, url):
275 | r = requests.get(url)
276 | soup = BeautifulSoup(r.content, 'lxml')
277 | tables = soup.findAll('table', attrs={'class': 'wikitable'})
278 | clerk_dict = {}
279 | for table in tables:
280 | for tr in table.findAll('tr'):
281 | row_entries = tr.findAll('td')
282 | if len(row_entries) != 5:
283 | continue
284 | else:
285 | name = row_entries[0].text
286 | u = row_entries[3].text
287 | year_candidates = re.findall(r'\d{4}', u)
288 |
289 | if year_candidates:
290 | year = int(year_candidates[0])
291 | else:
292 | continue
293 |
294 | cleaned_name = re.sub(r'\([^)]*\)', '', name)
295 | cleaned_name = re.sub(r'\[[^)]*\]', '', cleaned_name).strip()
296 | clerk_dict[cleaned_name] = year - self.graduation_dob_offset
297 |
298 | return clerk_dict
299 |
300 | def parse_avvo(self, name):
301 | avvo_ids, response = self.get_avvo_ids_google(name)
302 | info = {'info': {'type': 'avvo', 'google_response': response, 'error': None}}
303 | if avvo_ids:
304 | for aid in avvo_ids:
305 | avvo_resp = self.get_from_avvo_api(aid)
306 | dob_estimate = self.parse_avvo_api_response(avvo_resp, name)
307 | if dob_estimate:
308 | info['info']['avvo_id'] = aid
309 | return dob_estimate, info
310 | else:
311 | continue
312 | info['info']['error'] = 'no avvo ids yielded a dob estimate'
313 | info['info']['avvo_ids_attempted'] = avvo_ids
314 | return None, info
315 | else:
316 | info['info']['error'] = 'no links found.. check response'
317 | return None, info
318 |
319 | def parse_avvo_api_response(self, r, name):
320 | # r: response as dict
321 | if r['lawyers']:
322 | lawyer = r['lawyers'][0]
323 | licensed_since = lawyer['licensed_since']
324 | if licensed_since:
325 | lawyer_name = '{} {} {}'.format(lawyer['firstname'], lawyer['middlename'], lawyer['lastname'])
326 | if self.name_distance(name, lawyer_name) <= self.distance_threshold:
327 | return licensed_since - self.graduation_dob_offset
328 |
329 | def search_google_avvo(self, query):
330 | r = requests.get('https://www.googleapis.com/customsearch/v1/siterestrict?key={}&cx={}&num=3&q="{}"'.format(
331 | self.google_api_key, self.avvo_cse_id, query))
332 | r = json.loads(r.content)
333 | if 'items' in r:
334 | links = [l['link'] for l in r['items']]
335 | return links, r
336 | else:
337 | return None, r
338 |
339 | def get_avvo_ids_google(self, name):
340 | links, response = self.search_google_avvo(name)
341 | if links:
342 | avvo_ids = [self.get_avvo_id_from_link(l) for l in links]
343 | avvo_ids = [i for i in avvo_ids if i]
344 | return avvo_ids, response
345 | else:
346 | return None, response
347 |
348 | @staticmethod
349 | def get_avvo_id_from_link(link):
350 | if link.startswith('https://www.avvo.com/attorneys/'):
351 | page_path = os.path.splitext(link.split('/')[-1])[0]
352 | avvo_id = page_path.split('-')[-1]
353 | if avvo_id.isnumeric():
354 | return avvo_id
355 |
356 | @staticmethod
357 | def get_from_avvo_api(avvo_id):
358 | headers = {'Authorization': 'Bearer {}'.format(AVVO_API_ACCESS_TOKEN)}
359 | url = 'https://api.avvo.com/api/4/lawyers/{}.json'.format(avvo_id)
360 | r = requests.get(url, headers=headers)
361 | return json.loads(r.content)
362 |
363 | @classmethod
364 | def name_distance(cls, string1, string2, wrong_initial_penalty=5):
365 | '''
366 |         Levenshtein distance that accommodates initials
367 | First and last initials must match
368 | TODO: allow for hyphenated names with last name partial match
369 | '''
370 |         name1 = string1.lower().translate(str.maketrans('', '', string.punctuation))
371 |         name2 = string2.lower().translate(str.maketrans('', '', string.punctuation))
372 | base_dist = lev_dist(name1, name2)
373 |
374 | if base_dist == 0:
375 | return 0
376 |
377 | if '-' in string1.split()[-1] or '-' in string2.split()[-1]:
378 | s1_perms = cls.hyphenation_perm(string1)
379 | s2_perms = cls.hyphenation_perm(string2)
380 | dists = []
381 | for s1 in s1_perms:
382 | for s2 in s2_perms:
383 | dists.append(cls.name_distance(s1, s2))
384 | return min(dists)
385 |
386 | name1_split = name1.split()
387 | name2_split = name2.split()
388 |
389 | if name1_split[0] == name2_split[0] and name1_split[-1] == name2_split[-1]:
390 | if len(name1_split) == 2 or len(name2_split) == 2:
391 | return lev_dist(' '.join([name1_split[0], name1_split[-1]]),
392 | ' '.join([name2_split[0], name2_split[-1]]))
393 |
394 | newname1 = ' '.join([n[0] if (1 <= i < len(name1_split) - 1) else n for i, n in enumerate(name1_split)])
395 | newname2 = ' '.join([n[0] if (1 <= i < len(name2_split) - 1) else n for i, n in enumerate(name2_split)])
396 | return lev_dist(newname1, newname2)
397 | else:
398 | return base_dist + wrong_initial_penalty
399 |
400 | @staticmethod
401 | def hyphenation_perm(name):
402 | splitup = name.split()
403 | lastname = splitup[-1]
404 | if '-' in lastname:
405 | lname_candidates = [' '.join(splitup[:-1] + [l]) for l in lastname.split('-')]
406 | return lname_candidates
407 | else:
408 | return [name]
409 |
410 |
411 | if __name__ == "__main__":
412 | assert AVVO_API_ACCESS_TOKEN, "AVVO_API_ACCESS_TOKEN in local_info.py needs to be filled out"
413 | assert AVVO_CSE_ID, "AVVO_CSE_ID in local_info.py needs to be filled out"
414 | assert GOOGLE_API_KEY, "GOOGLE_API_KEY in local_info.py needs to be filled out"
415 |
416 | args = parse_args()
417 | base_outfolder = args.base_outfolder
418 | assert os.path.isdir(base_outfolder)
419 |
420 | pickle_path = os.path.join(base_outfolder, 'dob.p')
421 | info_pickle_path = os.path.join(base_outfolder, 'dob_info.p')
422 |
423 | speaker_id_path = os.path.join(base_outfolder, 'speaker_ids.json')
424 | assert os.path.isfile(speaker_id_path), "Can't find speaker_ids.json"
425 |
426 | speaker_ids = json.load(open(speaker_id_path, encoding='utf-8'),
427 | object_pairs_hook=OrderedDict)
428 |
429 | parser = LawyerDOBParser()
430 | parser.get_scotus_clerks()
431 |
432 | if args.overwrite or not os.path.isfile(pickle_path):
433 | dobs = OrderedDict({})
434 | infos = OrderedDict({})
435 | speakers_to_scrape = sorted(speaker_ids.keys())
436 | else:
437 | infos = pickle.load(open(info_pickle_path, 'rb'))
438 | dobs = pickle.load(open(pickle_path, 'rb'))
439 | if args.skip_attempted:
440 | speakers_to_scrape = set(speaker_ids.keys()) - set(dobs.keys())
441 | else:
442 | speakers_to_scrape = set(speaker_ids.keys()) - set([s for s in dobs if dobs[s]])
443 |
444 | if speakers_to_scrape:
445 | speakers_to_scrape = sorted(list(speakers_to_scrape))
446 |
447 | for s in tqdm(speakers_to_scrape):
448 | query_name = speaker_ids[s]['name']
449 | parsed_dob, info = parser.parse_name(query_name)
450 | dobs[s] = parsed_dob
451 | infos[s] = info
452 | pickle.dump(dobs, open(pickle_path, 'wb'))
453 | pickle.dump(infos, open(info_pickle_path, 'wb'))
454 |
455 | num_dob_speakers = sum([1 for s in dobs if dobs[s]])
456 |
457 | print('Found DoB for {} out of {} speakers'.format(num_dob_speakers, len(speaker_ids)))
458 | print('Done!')
459 |
460 |
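461 | # Editor's sketch (not part of the original script): a tiny, never-called helper
462 | # illustrating how name_distance above behaves; the example names are illustrative only.
463 | def _name_distance_examples():
464 |     # Middle names collapse to initials once first and last tokens match, so this is 0.
465 |     assert LawyerDOBParser.name_distance('Anthony A. Yang', 'Anthony Yang') == 0
466 |     # Mismatched first tokens pay wrong_initial_penalty on top of the edit distance.
467 |     assert LawyerDOBParser.name_distance('Anthony Yang', 'Gregory Yang') > 5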
--------------------------------------------------------------------------------
/scotus_data_prep/step3_prepdata.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import os
5 | import pickle
6 | import sys
7 | import shutil
8 | from collections import OrderedDict
9 | from copy import copy
10 | from glob import glob
11 |
12 | import numpy as np
13 | from tqdm import tqdm
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser(description='Prep the data for verification, diarization and feature extraction')
18 | parser.add_argument('--base-outfolder', type=str, help='Location of the base outfolder')
19 | parser.add_argument('--subsegment-length', type=float, default=1.5, help='Length of diarization subsegments (default: 1.5s)')
20 | parser.add_argument('--subsegment-shift', type=float, default=0.75, help='Subsegment shift duration (default: 0.75s)')
21 | args = parser.parse_args()
22 | return args
23 |
24 | def write_json(outfile, outdict):
25 | with open(outfile, 'w', encoding='utf-8') as wp:
26 | json.dump(outdict, wp)
27 |
28 | def assign_newrecnames(caselist):
29 | """
30 |     Converts recording names to a Kaldi-friendly standard format of YEAR-XYZ
31 |
32 | input: recording json list
33 | output: (original to newrecname mapping, newrecname to original mapping)
34 | """
35 | recid_orec_mapping = OrderedDict({})
36 |
37 | for r in caselist:
38 | rec_name = os.path.splitext(os.path.basename(r))[0]
39 | year = rec_name[:4]
40 | assert len(year) == 4, 'Year length was not 4, something is wrong'
41 | index = 0
42 | new_recid = "{0}-{1:0=3d}".format(str(year), index)
43 | while True:
44 | if new_recid in recid_orec_mapping:
45 | index += 1
46 | new_recid = "{0}-{1:0=3d}".format(str(year), index)
47 | else:
48 | recid_orec_mapping[new_recid] = rec_name
49 | break
50 |
51 | orec_recid_mapping = OrderedDict({recid_orec_mapping[k]: k for k in recid_orec_mapping})
52 | return orec_recid_mapping, recid_orec_mapping
53 |
54 |
55 | def make_wavscp(base_folder, caselist, orec_recid_mapping):
56 | """
57 | Make wav.scp with new recnames
58 |
59 | inputs: (outfile, list of case jsons, original to new recording mapping)
60 | outputs: None [file written to outfile]
61 | """
62 | wavlines = []
63 | wavscp = os.path.join(base_folder, 'wav.scp')
64 | for r in caselist:
65 | orec_name = os.path.splitext(os.path.basename(r))[0]
66 | newrec_id = orec_recid_mapping[orec_name]
67 | mp3file = os.path.join(os.path.abspath(base_folder), 'audio/{}.mp3'.format(orec_name))
68 | assert os.path.isfile(mp3file), "Couldn't find {}".format(mp3file)
69 | wavline = '{} ffmpeg -v 8 -i {} -f wav -ar 16000 -acodec pcm_s16le -|\n'.format(newrec_id, mp3file)
70 | wavlines.append(wavline)
71 |
72 | with open(wavscp, 'w+') as wp:
73 | for line in wavlines:
74 | wp.write(line)
75 |
76 |
77 | def process_utts_dob(utts, dobs):
78 | '''
79 | Removes utterances from speakers with unknown dob
80 | '''
81 | new_utts = []
82 | for u in utts:
83 | if u['speaker_id'] in dobs:
84 | if dobs[u['speaker_id']]:
85 | new_utts.append(u)
86 | assert new_utts, "No more utts remain: this should not occur, check DoB pickle"
87 | return new_utts
88 |
89 |
90 | def join_up_utterances(utt_list,
91 | cutoff_dur=10000,
92 | soft_min_length=0.0):
93 | new_utt_list = [utt_list[0].copy()]
94 | for i, utt in enumerate(utt_list):
95 | if i == 0:
96 | continue
97 | if utt['start'] == new_utt_list[-1]['stop'] and utt['speaker_id'] == new_utt_list[-1]['speaker_id']:
98 |             if (new_utt_list[-1]['stop'] - new_utt_list[-1]['start'] < cutoff_dur
99 |                     or utt['stop'] - utt['start'] < soft_min_length):
100 | new_utt_list[-1]['stop'] = utt['stop']
101 | new_utt_list[-1]['text'] += ' {}'.format(utt['text'])
102 |
103 | else:
104 | new_utt_list.append(utt_list[i].copy())
105 | return new_utt_list
106 |
107 |
108 | def split_up_single_utterance(utt,
109 | target_utt_length=10.0,
110 | min_utt_length=4.0):
111 | duration = utt['stop'] - utt['start']
112 | if duration < min_utt_length + target_utt_length:
113 | return [utt]
114 | else:
115 | remaining_duration = copy(duration)
116 | new_utt_list = []
117 | iterations = 0
118 | while remaining_duration >= min_utt_length + target_utt_length:
119 | new_utt = OrderedDict({})
120 | new_utt['start'] = utt['start'] + (iterations * target_utt_length)
121 | new_utt['stop'] = new_utt['start'] + target_utt_length
122 | new_utt['text'] = utt['text']
123 | new_utt['speaker_id'] = utt['speaker_id']
124 |
125 | new_utt_list.append(new_utt)
126 |
127 | remaining_duration -= target_utt_length
128 | iterations += 1
129 |
130 | new_utt = OrderedDict({})
131 | new_utt['start'] = new_utt_list[-1]['stop']
132 | new_utt['stop'] = utt['stop']
133 | new_utt['text'] = utt['text']
134 | new_utt['speaker_id'] = utt['speaker_id']
135 | new_utt_list.append(new_utt)
136 |
137 | return new_utt_list
138 |
139 |
140 | def split_up_long_utterances(utt_list,
141 | target_utt_length=10.0,
142 | min_utt_length=4.0):
143 | new_utt_list = []
144 | for utt in utt_list:
145 | splitup = split_up_single_utterance(utt,
146 | target_utt_length=target_utt_length,
147 | min_utt_length=min_utt_length)
148 | new_utt_list.extend(splitup)
149 | return new_utt_list
150 |
151 |
152 | def make_segments(utts, recname, min_utt_len=2.0):
153 | """
154 |     Make Kaldi-friendly segments of the format recname-VWXYZ
155 |
156 |     Input: uttlist imported from transcription json, recname
157 |     Output: (segment lines, utterance ids, speaker ids)
158 |     """
159 | seglines = []
160 | speakers = []
161 | utt_ids = []
162 |
163 | i = 0
164 | for utt in utts:
165 | utt_id = '{0}-{1:0=5d}'.format(recname, i)
166 | start = float(utt['start'])
167 | stop = float(utt['stop'])
168 |
169 | if float(start) + min_utt_len > float(stop):
170 | # Discard too short snippets
171 | continue
172 |
173 | speaker_id = utt['speaker_id']
174 | if not speaker_id.strip():
175 | # Discard empty speaker_ids
176 | continue
177 |
178 | line = '{} {} {} {}\n'.format(utt_id, recname, start, stop)
179 |
180 | seglines.append(line)
181 | speakers.append(speaker_id)
182 | utt_ids.append(utt_id)
183 | i += 1
184 |
185 | return seglines, utt_ids, speakers
186 |
187 |
188 | def prep_utts(utts):
189 | # Converts start and stop to float
190 | for u in utts:
191 | u['start'] = float(u['start'])
192 | u['stop'] = float(u['stop'])
193 | return utts
194 |
195 |
196 | def calculate_speaker_ages(speakers, dobs, rec_date):
197 | """
198 | Calculate the age in days of a list of speakers based on recording date and DoB
199 |
200 | inputs: (list of speakers, DoB dictionary of datetimes, recording datetime)
201 | output: list of speaker ages
202 | """
203 | set_speakers = set(speakers)
204 | age_dict = {}
205 | for s in set_speakers:
206 | dob = dobs[s]
207 | delta_days = abs((rec_date - dob).days)
208 | age_dict[s] = delta_days
209 | speaker_ages = [age_dict[s] for s in speakers]
210 | return speaker_ages
211 |
212 |
213 | def make_verification_dataset(base_outfolder, caselist, orec_recid_mapping, dobs):
214 | """
215 | Makes a verification/training dataset: Long utterances split up
216 | """
217 | veri_data_path = os.path.join(base_outfolder, 'veri_data')
218 | os.makedirs(veri_data_path, exist_ok=True)
219 |     wavscp_path = os.path.join(base_outfolder, 'wav.scp')
220 | shutil.copy(wavscp_path, veri_data_path)
221 |
222 | all_seglines = []
223 | all_uttids = []
224 | all_speakers = []
225 | all_recs = []
226 | all_ages = []
227 |
228 | for case in tqdm(caselist):
229 | rec_name = os.path.splitext(os.path.basename(case))[0]
230 | new_recid = orec_recid_mapping[rec_name]
231 | js = json.load(open(case, encoding='utf-8'), object_pairs_hook=OrderedDict)
232 |
233 | utts = js['utts']
234 | utts = process_utts_dob(utts, dobs)
235 | utts = prep_utts(utts)
236 | utts = join_up_utterances(utts)
237 | utts = split_up_long_utterances(utts, target_utt_length=10.0, min_utt_length=4.0)
238 |
239 | seglines, utt_ids, speakers = make_segments(utts, new_recid)
240 |
241 | case_date = datetime.datetime.strptime(js['case_date'], '%Y/%m/%d')
242 | utt_ages = calculate_speaker_ages(speakers, dobs, case_date)
243 |
244 | all_recs.extend([new_recid for _ in utt_ids])
245 | all_seglines.extend(seglines)
246 | all_uttids.extend(utt_ids)
247 | all_speakers.extend(speakers)
248 | all_ages.extend(utt_ages)
249 |
250 | with open(os.path.join(veri_data_path, 'segments'), 'w+') as fp:
251 | for line in all_seglines:
252 | fp.write(line)
253 |
254 | with open(os.path.join(veri_data_path, 'real_utt2spk'), 'w+') as fp:
255 | for u, s in zip(all_uttids, all_speakers):
256 | line = '{} {}\n'.format(u, s)
257 | fp.write(line)
258 |
259 | with open(os.path.join(veri_data_path, 'utt2age'), 'w+') as fp:
260 | for u, a in zip(all_uttids, all_ages):
261 | line = '{} {}\n'.format(u, a)
262 | fp.write(line)
263 |
264 | with open(os.path.join(veri_data_path, 'utt2spk'), 'w+') as fp:
265 | for u, r in zip(all_uttids, all_recs):
266 | line = '{} {}\n'.format(u, r)
267 | fp.write(line)
268 |
269 | utt2spk_to_spk2utt(os.path.join(veri_data_path, 'utt2spk'),
270 | outfile=os.path.join(veri_data_path, 'spk2utt'))
271 |
272 | utt2spk_to_spk2utt(os.path.join(veri_data_path, 'real_utt2spk'),
273 | outfile=os.path.join(veri_data_path, 'real_spk2utt'))
274 |
275 |
276 |
277 | def split_up_single_utterance_subsegments(utt,
278 | target_utt_length=1.5,
279 | min_utt_length=1.4,
280 | shift=0.75):
281 | """
282 | Split up a single utterance into subsegments based on input variables
283 | """
284 | duration = utt['stop'] - utt['start']
285 | if duration < min_utt_length + target_utt_length - shift:
286 | return [utt]
287 | else:
288 | new_utt_list = []
289 | current_start = copy(utt['start'])
290 | while current_start <= utt['stop'] - min_utt_length:
291 | new_utt = OrderedDict({})
292 | new_utt['start'] = current_start
293 | new_utt['stop'] = new_utt['start'] + target_utt_length
294 | new_utt['byte_start'] = utt['byte_start'] # todo fix
295 | new_utt['byte_stop'] = utt['byte_stop'] # todo fix
296 | new_utt['text'] = 'n/a'
297 | new_utt['speaker_id'] = utt['speaker_id']
298 |
299 | new_utt_list.append(new_utt)
300 |
301 | current_start += shift
302 |
303 | new_utt = OrderedDict({})
304 | new_utt['start'] = current_start
305 | new_utt['stop'] = utt['stop']
306 | new_utt['byte_start'] = utt['byte_start']
307 | new_utt['byte_stop'] = utt['byte_stop']
308 | new_utt['text'] = 'n/a'
309 | new_utt['speaker_id'] = utt['speaker_id']
310 | new_utt_list.append(new_utt)
311 |
312 | return new_utt_list
313 |
314 |
315 | def split_up_uttlist_subsegments(utt_list,
316 | target_utt_length=1.5,
317 | min_utt_length=1.4,
318 | shift=0.75):
319 | """
320 |     Split up a list of utterances into subsegments based on input variables
321 | """
322 | new_utt_list = []
323 | for utt in utt_list:
324 | splitup = split_up_single_utterance_subsegments(utt,
325 | target_utt_length=target_utt_length,
326 | min_utt_length=min_utt_length,
327 | shift=shift)
328 | new_utt_list.extend(splitup)
329 | return new_utt_list
330 |
331 | def rttm_lines_from_uttlist(utts, rec_id):
332 | utts = join_up_utterances(utts)
333 | rttm_lines = []
334 |
335 | rttm_line = 'SPEAKER {} 0 {:.2f} {:.2f} {} \n'
336 |
337 | for u in utts:
338 | spk = u['speaker_id']
339 | start = float(u['start'])
340 | stop = float(u['stop'])
341 | offset = stop - start
342 | if offset < 0.0:
343 | continue
344 | rttm_lines.append(rttm_line.format(rec_id, start, offset, spk))
345 |
346 | return rttm_lines
347 |
348 |
349 |
350 | def make_diarization_dataset(base_outfolder, caselist, orec_recid_mapping, dobs,
351 | target_utt_length=1.5, min_utt_length=1.4, shift=0.75):
352 | """
353 | Makes a diarization dataset, utterances split into overlapping segments
354 | Makes rttms
355 | """
356 | diar_data_path = os.path.join(base_outfolder, 'diar_data')
357 | os.makedirs(diar_data_path, exist_ok=True)
358 |     wavscp_path = os.path.join(base_outfolder, 'wav.scp')
359 | shutil.copy(wavscp_path, diar_data_path)
360 |
361 | all_seglines = []
362 | all_uttids = []
363 | all_speakers = []
364 | all_recs = []
365 |
366 | all_rttm_lines = []
367 |
368 | for case in tqdm(caselist):
369 | rec_name = os.path.splitext(os.path.basename(case))[0]
370 | new_recid = orec_recid_mapping[rec_name]
371 | js = json.load(open(case, encoding='utf-8'), object_pairs_hook=OrderedDict)
372 |
373 | utts = js['utts']
374 | utts = prep_utts(utts)
375 | utts = join_up_utterances(utts, cutoff_dur=1e20)
376 |
377 | rec_rttm_lines = rttm_lines_from_uttlist(utts, new_recid)
378 | all_rttm_lines.extend(rec_rttm_lines)
379 |
380 |         utts = split_up_uttlist_subsegments(utts,
381 |                                             target_utt_length=target_utt_length,
382 |                                             min_utt_length=min_utt_length,
383 |                                             shift=shift)
384 |
385 | seglines, utt_ids, speakers = make_segments(utts, new_recid, min_utt_len=min_utt_length)
386 |
387 | all_recs.extend([new_recid for _ in utt_ids])
388 | all_seglines.extend(seglines)
389 | all_uttids.extend(utt_ids)
390 | all_speakers.extend(speakers)
391 |
392 | with open(os.path.join(diar_data_path, 'segments'), 'w+') as fp:
393 | for line in all_seglines:
394 | fp.write(line)
395 |
396 | with open(os.path.join(diar_data_path, 'ref.rttm'), 'w+') as fp:
397 | for line in all_rttm_lines:
398 | fp.write(line)
399 |
400 | with open(os.path.join(diar_data_path, 'real_utt2spk'), 'w+') as fp:
401 | for u, s in zip(all_uttids, all_speakers):
402 | line = '{} {}\n'.format(u, s)
403 | fp.write(line)
404 |
405 | with open(os.path.join(diar_data_path, 'utt2spk'), 'w+') as fp:
406 | for u, r in zip(all_uttids, all_recs):
407 | line = '{} {}\n'.format(u, r)
408 | fp.write(line)
409 |
410 | utt2spk_to_spk2utt(os.path.join(diar_data_path, 'utt2spk'),
411 | outfile=os.path.join(diar_data_path, 'spk2utt'))
412 |
413 | utt2spk_to_spk2utt(os.path.join(diar_data_path, 'real_utt2spk'),
414 | outfile=os.path.join(diar_data_path, 'real_spk2utt'))
415 |
416 |
417 | def utt2spk_to_spk2utt(utt2spk_path, outfile=None):
418 | utts = []
419 | spks = []
420 | with open(utt2spk_path) as fp:
421 | for line in fp:
422 | splitup = line.strip().split(' ')
423 |             assert len(splitup) == 2, 'Got a different number of columns than expected: (got {}, expected 2)'.format(
424 |                 len(splitup))
425 | utts.append(splitup[0])
426 | spks.append(splitup[1])
427 | set_spks = sorted(list(set(spks)))
428 | spk2utt_dict = OrderedDict({k: [] for k in set_spks})
429 | for u, s in zip(utts, spks):
430 | spk2utt_dict[s].append(u)
431 |
432 | if outfile:
433 | with open(outfile, 'w+') as fp:
434 | for spk in spk2utt_dict:
435 | line = '{} {}\n'.format(spk, ' '.join(spk2utt_dict[spk]))
436 | fp.write(line)
437 | return spk2utt_dict
438 |
439 |
440 |
441 | if __name__ == "__main__":
442 | args = parse_args()
443 | base_outfolder = args.base_outfolder
444 | assert os.path.isdir(base_outfolder), 'Outfolder does not exist'
445 | wavscp_path = os.path.join(args.base_outfolder, 'wav.scp')
446 | assert os.path.isfile(wavscp_path), "Can't find {}".format(wavscp_path)
447 |
448 | dobs_pkl_path = os.path.join(args.base_outfolder, 'dob.p')
449 | assert os.path.isfile(dobs_pkl_path), "Couldn't find {}".format(dobs_pkl_path)
450 | dobs = pickle.load(open(dobs_pkl_path, 'rb'))
451 |
452 | audio_folder = os.path.join(base_outfolder, 'audio')
453 | transcript_folder = os.path.join(base_outfolder, 'transcripts')
454 |
455 | caselist = glob(os.path.join(transcript_folder, '*.json'))
456 | caselist = sorted(caselist)
457 |
458 | print('Assigning new recording names...')
459 | orec_recid_mapping, recid_orec_mapping = assign_newrecnames(caselist)
460 |
461 | write_json(os.path.join(base_outfolder, 'orec_recid_mapping.json'), orec_recid_mapping)
462 | write_json(os.path.join(base_outfolder, 'recid_orec_mapping.json'), recid_orec_mapping)
463 |
464 | print('Making wav.scp in {}'.format(os.path.join(base_outfolder, 'wav.scp')))
465 | make_wavscp(base_outfolder, caselist, orec_recid_mapping)
466 |
467 | print('Making base verification/training dataset...')
468 | make_verification_dataset(base_outfolder, caselist, orec_recid_mapping, dobs)
469 |
470 | print('Making diarization dataset...')
471 | make_diarization_dataset(base_outfolder, caselist, orec_recid_mapping, dobs,
472 | target_utt_length=args.subsegment_length, min_utt_length=args.subsegment_length-0.1,
473 | shift=args.subsegment_shift)
474 |
475 |
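476 | # Editor's sketch (not part of the original pipeline): a never-called example of how
477 | # the default diarization parameters (1.5 s windows, 0.75 s shift, 1.4 s minimum)
478 | # tile a single 4-second turn.
479 | def _subsegment_example():
480 |     turn = OrderedDict(start=0.0, stop=4.0, byte_start=0, byte_stop=0,
481 |                        text='n/a', speaker_id='spk1')
482 |     subs = split_up_single_utterance_subsegments(turn)
483 |     # Yields the overlapping windows [0.0, 1.5], [0.75, 2.25], [1.5, 3.0], [2.25, 3.75]
484 |     # plus a final tail segment [3.0, 4.0], all carrying the original speaker_id.
485 |     return [(u['start'], u['stop']) for u in subs]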
--------------------------------------------------------------------------------
/scotus_data_prep/step4_extract_feats.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | step=0
4 | nj=20
5 | base_outfolder=/PATH/TO/BASE_OUTFOLDER
6 | scotus_veri_dir=$base_outfolder/veri_data
7 | scotus_diar_dir=$base_outfolder/diar_data
8 |
9 | if [ $step -le 0 ]; then
10 | #make the feats
11 | utils/fix_data_dir.sh $scotus_veri_dir
12 | steps/make_mfcc.sh --write-utt2num-frames true --mfcc-config conf/mfcc.conf --nj $nj \
13 | --cmd run.pl $scotus_veri_dir
14 |
15 | utils/fix_data_dir.sh $scotus_veri_dir
16 | sid/compute_vad_decision.sh --nj $nj --cmd run.pl $scotus_veri_dir
17 | utils/fix_data_dir.sh $scotus_veri_dir
18 | fi
19 |
20 | if [ $step -le 1 ]; then
21 | local/nnet3/xvector/prepare_feats_for_egs.sh --nj $nj --cmd run.pl \
22 | $scotus_veri_dir ${scotus_veri_dir}_nosil $scotus_veri_dir/nosil_feats
23 | utils/fix_data_dir.sh ${scotus_veri_dir}_nosil
24 | fi
25 |
26 | if [ $step -le 2 ]; then
27 | #make the feats
28 | utils/fix_data_dir.sh $scotus_diar_dir
29 | steps/make_mfcc.sh --write-utt2num-frames true --mfcc-config conf/mfcc.conf --nj $nj \
30 | --cmd run.pl $scotus_diar_dir
31 |
32 | utils/fix_data_dir.sh $scotus_diar_dir
33 | sid/compute_vad_decision.sh --nj $nj --cmd run.pl $scotus_diar_dir
34 | utils/fix_data_dir.sh $scotus_diar_dir
35 | fi
36 |
37 | if [ $step -le 3 ]; then
38 | local/nnet3/xvector/prepare_feats_for_egs.sh --nj $nj --cmd run.pl \
39 | $scotus_diar_dir ${scotus_diar_dir}_nosil $scotus_diar_dir/nosil_feats
40 | utils/fix_data_dir.sh ${scotus_diar_dir}_nosil
41 | fi
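42 |
43 | # Editor's note (assumption, not part of the original script): the relative paths
44 | # utils/, steps/, sid/ and local/nnet3/xvector/ are the standard helpers of a Kaldi
45 | # egs recipe (e.g. egs/sre16/v2), so this script is expected to be run from, or
46 | # symlinked into, such a recipe directory with a conf/mfcc.conf present.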
--------------------------------------------------------------------------------
/scotus_data_prep/step5_trim_split_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import json
4 | import os
5 | import pickle
6 | import random
7 | import shutil
8 | import sys
9 | from collections import OrderedDict
10 | from itertools import combinations
11 |
12 | import numpy as np
13 | from scipy.special import comb
14 | from tqdm import tqdm
15 |
16 |
17 | def parse_args():
18 |     parser = argparse.ArgumentParser(description='Trim and split the prepared data into train and test portions')
19 | parser.add_argument('--base-outfolder', type=str, help='Location of the base outfolder')
20 | parser.add_argument('--train-proportion', type=float, default=0.8, help='Train proportion (default: 0.8)')
21 | parser.add_argument('--pos-per-spk', type=int, default=10, help='Positive trials per speaker')
22 | args = parser.parse_args()
23 | return args
24 |
25 |
26 | def load_n_col(file):
27 | data = []
28 | with open(file) as fp:
29 | for line in fp:
30 | data.append(line.strip().split(' '))
31 | columns = list(zip(*data))
32 | columns = [list(i) for i in columns]
33 | return columns
34 |
35 |
36 | def write_lines(lines, file):
37 | with open(file, 'w+') as fp:
38 | for line in lines:
39 | fp.write(line)
40 |
41 |
42 | def fix_data_dir(data_dir):
43 | """
44 | Files: real_utt2spk, real_spk2utt, utt2age, wav.scp, utt2spk, spk2utt, segments
45 | Cleans and fixes all the right files to agree with kaldi data processing
46 | """
47 | backup_data_dir = os.path.join(data_dir, '.mybackup')
48 | os.makedirs(backup_data_dir, exist_ok=True)
49 |
50 | files = glob.glob(os.path.join(data_dir, '*'))
51 | files = [f for f in files if os.path.isfile(f)]
52 | _ = [shutil.copy(f, backup_data_dir) for f in files]
53 |
54 | utt2spk_dict = OrderedDict({k: v for k, v in zip(*load_n_col(os.path.join(data_dir, 'real_utt2spk')))})
55 | utt2fpath_dict = OrderedDict({k: v for k, v in zip(*load_n_col(os.path.join(data_dir, 'feats.scp')))})
56 |
57 | utt2age = os.path.isfile(os.path.join(data_dir, 'utt2age'))
58 | if utt2age:
59 | utt2age_dict = OrderedDict({k: v for k, v in zip(*load_n_col(os.path.join(data_dir, 'utt2age')))})
60 |
61 | complete_utts = set(utt2fpath_dict.keys()).intersection(set(utt2spk_dict.keys()))
62 | complete_utts = sorted(list(complete_utts))
63 |
64 | print('Reducing real utt2spk ({}) to {} utts...'.format(len(utt2spk_dict),
65 | len(complete_utts)))
66 |
67 | blank_utts = os.path.join(data_dir, 'utts')
68 | with open(blank_utts, 'w+') as fp:
69 | for u in complete_utts:
70 | fp.write('{}\n'.format(u))
71 |
72 | os.system('./filter_scp.pl {} {} > {}'.format(blank_utts, os.path.join(data_dir, 'real_utt2spk'),
73 | os.path.join(data_dir, 'utt2spk')))
74 |
75 | os.rename(os.path.join(data_dir, 'segments'), os.path.join(data_dir, 'segments_old'))
76 | os.system('./filter_scp.pl {} {} > {}'.format(blank_utts, os.path.join(data_dir, 'segments_old'),
77 | os.path.join(data_dir, 'segments')))
78 |
79 | if utt2age:
80 | os.rename(os.path.join(data_dir, 'utt2age'), os.path.join(data_dir, 'utt2age_old'))
81 | os.system('./filter_scp.pl {} {} > {}'.format(blank_utts, os.path.join(data_dir, 'utt2age_old'),
82 | os.path.join(data_dir, 'utt2age')))
83 |
84 |
85 | set_spks = sorted(list(set(utt2spk_dict.values())))
86 | spk2utt_dict = OrderedDict({k: [] for k in set_spks})
87 | for u in complete_utts:
88 | spk2utt_dict[utt2spk_dict[u]].append(u)
89 |
90 | with open(os.path.join(data_dir, 'spk2utt'), 'w+') as fp:
91 | for s in spk2utt_dict:
92 | if spk2utt_dict[s]:
93 | line = '{} {}\n'.format(s, ' '.join(spk2utt_dict[s]))
94 | fp.write(line)
95 |
96 |
97 | def load_one_tomany(file):
98 | one = []
99 | many = []
100 | with open(file) as fp:
101 | for line in fp:
102 | line = line.strip().split(' ', 1)
103 | one.append(line[0])
104 | m = line[1].split(' ')
105 | many.append(m)
106 | return one, many
107 |
108 |
109 | def utt2spk_to_spk2utt(utt2spk_path, outfile=None):
110 | utts = []
111 | spks = []
112 | with open(utt2spk_path) as fp:
113 | for line in fp:
114 | splitup = line.strip().split(' ')
115 |             assert len(splitup) == 2, 'Got a different number of columns than expected: (got {}, expected 2)'.format(
116 |                 len(splitup))
117 | utts.append(splitup[0])
118 | spks.append(splitup[1])
119 | set_spks = sorted(list(set(spks)))
120 | spk2utt_dict = OrderedDict({k: [] for k in set_spks})
121 | for u, s in zip(utts, spks):
122 | spk2utt_dict[s].append(u)
123 |
124 | if outfile:
125 | with open(outfile, 'w+') as fp:
126 | for spk in spk2utt_dict:
127 | line = '{} {}\n'.format(spk, ' '.join(spk2utt_dict[spk]))
128 | fp.write(line)
129 | return spk2utt_dict
130 |
131 |
132 | def split_recordings(data_dir, train_proportion=0.8):
133 | """
134 | Split the recordings based on train proportion
135 |
136 | returns train_recordings, test_recordings
137 | """
138 | np.random.seed(1234)
139 | random.seed(1234)
140 | segments_path = os.path.join(data_dir, 'segments')
141 | assert os.path.isfile(segments_path), "Couldn't find {}".format(segments_path)
142 |
143 | utts, urecs, _, _ = load_n_col(segments_path)
144 |
145 | setrecs = sorted(list(set(urecs)))
146 | num_train_recs = int(np.floor(train_proportion * len(setrecs)))
147 |
148 | train_recs = np.random.choice(setrecs, size=num_train_recs, replace=False)
149 | test_recs = [r for r in setrecs if r not in train_recs]
150 | return train_recs, test_recs
151 |
152 |
153 | def split_data_dir(data_dir, train_recs, test_recs):
154 | """
155 | Split recordings and files into train and test subfolders based on train_recs, test_recs
156 |
157 | Filters feats, utt2spk, spk2utt, segments, (rttm, utt2age)
158 | """
159 | segments_path = os.path.join(data_dir, 'segments')
160 | assert os.path.isfile(segments_path), "Couldn't find {}".format(segments_path)
161 |
162 | train_dir = os.path.join(data_dir, 'train')
163 | test_dir = os.path.join(data_dir, 'test')
164 |
165 | os.makedirs(train_dir, exist_ok=True)
166 | os.makedirs(test_dir, exist_ok=True)
167 |
168 | utts, urecs, _, _ = load_n_col(segments_path)
169 | utt2rec_dict = OrderedDict({k:v for k,v in zip(utts, urecs)})
170 |
171 | train_utts = [u for u in utts if utt2rec_dict[u] in train_recs]
172 | test_utts = [u for u in utts if utt2rec_dict[u] in test_recs]
173 |
174 | tr_u = os.path.join(train_dir, 'utts')
175 | with open(tr_u, 'w+') as fp:
176 | for u in train_utts:
177 | fp.write('{}\n'.format(u))
178 |
179 | te_u = os.path.join(test_dir, 'utts')
180 | with open(te_u, 'w+') as fp:
181 | for u in test_utts:
182 | fp.write('{}\n'.format(u))
183 |
184 | os.system('./filter_scp.pl {} {} > {}'.format(tr_u, os.path.join(data_dir, 'utt2spk'),
185 | os.path.join(train_dir, 'utt2spk')))
186 | os.system('./filter_scp.pl {} {} > {}'.format(te_u, os.path.join(data_dir, 'utt2spk'),
187 | os.path.join(test_dir, 'utt2spk')))
188 |
189 | os.system('./filter_scp.pl {} {} > {}'.format(tr_u, os.path.join(data_dir, 'feats.scp'),
190 | os.path.join(train_dir, 'feats.scp')))
191 | os.system('./filter_scp.pl {} {} > {}'.format(te_u, os.path.join(data_dir, 'feats.scp'),
192 | os.path.join(test_dir, 'feats.scp')))
193 |
194 | os.system('./filter_scp.pl {} {} > {}'.format(tr_u, os.path.join(data_dir, 'segments'),
195 | os.path.join(train_dir, 'segments')))
196 | os.system('./filter_scp.pl {} {} > {}'.format(te_u, os.path.join(data_dir, 'segments'),
197 | os.path.join(test_dir, 'segments')))
198 |
199 | utt2spk_to_spk2utt(os.path.join(train_dir, 'utt2spk'), outfile=os.path.join(train_dir, 'spk2utt'))
200 | utt2spk_to_spk2utt(os.path.join(test_dir, 'utt2spk'), outfile=os.path.join(test_dir, 'spk2utt'))
201 |
202 | if os.path.isfile(os.path.join(data_dir, 'utt2age')):
203 | os.system('./filter_scp.pl {} {} > {}'.format(tr_u, os.path.join(data_dir, 'utt2age'),
204 | os.path.join(train_dir, 'utt2age')))
205 | os.system('./filter_scp.pl {} {} > {}'.format(te_u, os.path.join(data_dir, 'utt2age'),
206 | os.path.join(test_dir, 'utt2age')))
207 |
208 | rttm_path = os.path.join(data_dir, 'ref.rttm')
209 | if os.path.isfile(rttm_path):
210 | with open(rttm_path) as fp:
211 | rttm_lines = [line for line in fp]
212 |
213 | rttm_recs = [line.split()[1].strip() for line in rttm_lines]
214 |
215 | tr_rttmlines = [l for l, r in zip(rttm_lines, rttm_recs) if r in train_recs]
216 | te_rttmlines = [l for l, r in zip(rttm_lines, rttm_recs) if r in test_recs]
217 |
218 | write_lines(tr_rttmlines, os.path.join(train_dir, 'ref.rttm'))
219 | write_lines(te_rttmlines, os.path.join(test_dir, 'ref.rttm'))
220 |
221 |
222 | def nonoverlapped_utts(veri_data_dir):
223 | '''
224 | Retrieve only the utterances that belong to speakers not seen in the training set
225 | '''
226 | train_dir = os.path.join(veri_data_dir, 'train')
227 | test_dir = os.path.join(veri_data_dir, 'test')
228 |
229 | train_utts, train_spkrs = load_n_col(os.path.join(train_dir, 'utt2spk'))
230 | test_utts, test_spkrs = load_n_col(os.path.join(test_dir, 'utt2spk'))
231 |
232 | set_tr_spkrs = set(train_spkrs)
233 | set_te_spkrs = set(test_spkrs)
234 |
235 | non_overlapped_speakers = set_te_spkrs - set_tr_spkrs
236 |     assert len(non_overlapped_speakers) >= 10, "Fewer than 10 non-overlapping speakers remain; something has gone wrong"
237 |
238 | valid_utts, valid_spkrs = [], []
239 | for u, s in zip(test_utts, test_spkrs):
240 | if s in non_overlapped_speakers:
241 | valid_utts.append(u)
242 | valid_spkrs.append(s)
243 |
244 | return valid_utts, valid_spkrs
245 |
246 |
247 | def generate_veri_pairs(utts, spkrs, pos_per_spk=15):
248 | # Randomly creates pairs for a verification list
249 |     # pos_per_spk utterances are drawn per speaker, giving comb(pos_per_spk, 2) same-speaker pairs, matched by an equal number of different-speaker pairs
250 | np.random.seed(1234)
251 | random.seed(1234)
252 | setspkrs = sorted(list(set(spkrs)))
253 | spk2utt_dict = OrderedDict({s: [] for s in setspkrs})
254 | for utt, s in zip(utts, spkrs):
255 | spk2utt_dict[s].append(utt)
256 |
257 | u0 = []
258 | u1 = []
259 | labs = []
260 |
261 | for s in setspkrs:
262 | random.shuffle(spk2utt_dict[s])
263 |
264 | for s in setspkrs:
265 | positives = [spk2utt_dict[s].pop() for _ in range(pos_per_spk)]
266 | num_neg_trials = int(comb(pos_per_spk, 2))
267 | negatives = [spk2utt_dict[np.random.choice(list(set(setspkrs) - set([s])))].pop() for _ in
268 | range(num_neg_trials)]
269 | for a, b in combinations(positives, 2):
270 | u0.append(a)
271 | u1.append(b)
272 | labs.append(1)
273 | for a, b in zip(np.random.choice(positives, size=num_neg_trials, replace=True), negatives):
274 | u0.append(a)
275 | u1.append(b)
276 | labs.append(0)
277 |
278 | return u0, u1, labs
279 |
280 |
281 | if __name__ == "__main__":
282 | args = parse_args()
283 |
284 | veri_data_dir = os.path.join(args.base_outfolder, 'veri_data_nosil')
285 | assert os.path.isdir(veri_data_dir), "Couldn't find {}".format(veri_data_dir)
286 | shutil.copy(os.path.join(args.base_outfolder, 'veri_data/segments'), veri_data_dir)
287 | shutil.copy(os.path.join(args.base_outfolder, 'veri_data/utt2age'), veri_data_dir)
288 | shutil.copy(os.path.join(args.base_outfolder, 'veri_data/real_utt2spk'), veri_data_dir)
289 |
290 | print('Fixing data dir: {}'.format(veri_data_dir))
291 | fix_data_dir(veri_data_dir)
292 | train_recs, test_recs = split_recordings(veri_data_dir, train_proportion=args.train_proportion)
293 |
294 | print('Making recording split...')
295 | with open(os.path.join(args.base_outfolder, 'recording_split.json'), 'w+', encoding='utf-8') as fp:
296 | recdict = {'train': list(train_recs), 'test': list(test_recs)}
297 | json.dump(recdict, fp)
298 |
299 | print('Splitting verification data...')
300 | split_data_dir(veri_data_dir, train_recs, test_recs)
301 |
302 | print('Making verification pairs for test portion of verification data...')
303 | valid_utts, valid_spkrs = nonoverlapped_utts(veri_data_dir)
304 | u0, u1, labs = generate_veri_pairs(valid_utts, valid_spkrs, pos_per_spk=args.pos_per_spk)
305 | veri_lines = ['{} {} {}\n'.format(l, a, b) for l, a, b in zip(labs, u0, u1)]
306 | write_lines(veri_lines, os.path.join(veri_data_dir, 'test/veri_pairs'))
307 |
308 | print('Now fixing diarization data...')
309 | diar_data_dir = os.path.join(args.base_outfolder, 'diar_data_nosil')
310 | assert os.path.isdir(diar_data_dir), "Couldn't find {}".format(diar_data_dir)
311 | shutil.copy(os.path.join(args.base_outfolder, 'diar_data/real_utt2spk'), diar_data_dir)
312 | shutil.copy(os.path.join(args.base_outfolder, 'diar_data/segments'), diar_data_dir)
313 | shutil.copy(os.path.join(args.base_outfolder, 'diar_data/ref.rttm'), diar_data_dir)
314 |
315 | fix_data_dir(diar_data_dir)
316 |
317 | print('Splitting diarization data...')
318 | split_data_dir(diar_data_dir, train_recs, test_recs)
319 |
320 |
321 | print('Done!!')
322 |
323 |
324 |
325 |
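326 | # Editor's sketch (not part of the original script): a never-called helper spelling
327 | # out the trial counts produced by generate_veri_pairs above.
328 | def _trials_per_speaker(pos_per_spk=10):
329 |     # comb(pos_per_spk, 2) same-speaker pairs are matched by an equal number of
330 |     # different-speaker pairs, so the default --pos-per-spk 10 gives 45 + 45 trials
331 |     # per speaker and a label-balanced veri_pairs file.
332 |     return int(comb(pos_per_spk, 2))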
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import configparser
3 | import glob
4 | import json
5 | import os
6 | import pickle
7 | import random
8 | import shutil
9 | import time
10 | from collections import OrderedDict
11 | from pprint import pprint
12 | from inference import test, test_all_factors, test_enrollment_models, test_nosil
13 |
14 | import numpy as np
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | import uvloop
20 | from data_io import SpeakerDataset, SpeakerModelTestDataset, SpeakerTestDataset
21 | from models.classifiers import (AdaCos, AMSMLoss, ArcFace, L2SoftMax, SoftMax,
22 | SphereFace, XVecHead, XVecHeadUncertain, GradientReversal)
23 | from models.criteria import (DisturbLabelLoss, LabelSmoothingLoss,
24 | TwoNeighbourSmoothingLoss, MultiTaskUncertaintyLossKendall,
25 | MultiTaskUncertaintyLossLiebel)
26 | from models.extractors import ETDNN, FTDNN, XTDNN
27 | from torch.utils.data import DataLoader
28 | from torch.utils.tensorboard import SummaryWriter
29 | from tqdm import tqdm
30 | from utils import SpeakerRecognitionMetrics, schedule_lr
31 |
32 |
33 | def get_lr(optimizer):
34 | for param_group in optimizer.param_groups:
35 | return param_group['lr']
36 |
37 |
38 | def parse_args():
39 | parser = argparse.ArgumentParser(description='Train SV model')
40 | parser.add_argument('--cfg', type=str, default='./configs/example_speaker.cfg')
41 | parser.add_argument('--transfer-learning', action='store_true', default=False,
42 | help='Start from g_start.pt in exp folder')
43 | parser.add_argument('--resume-checkpoint', type=int, default=0)
44 | args = parser.parse_args()
45 | assert os.path.isfile(args.cfg)
46 | args._start_time = time.ctime()
47 | return args
48 |
49 |
50 | def parse_config(args):
51 | assert os.path.isfile(args.cfg)
52 | config = configparser.ConfigParser()
53 | config.read(args.cfg)
54 |
55 | args.train_data = config['Datasets'].get('train')
56 | assert args.train_data
57 | args.test_data = config['Datasets'].get('test')
58 |
59 | args.model_type = config['Model'].get('model_type', fallback='XTDNN')
60 | assert args.model_type in ['XTDNN', 'ETDNN', 'FTDNN']
61 |
62 | args.classifier_heads = config['Model'].get('classifier_heads').split(',')
63 |     assert len(args.classifier_heads) <= 3, 'A maximum of three classifier heads is supported'
64 | assert len(args.classifier_heads) == len(set(args.classifier_heads))
65 | for clf in args.classifier_heads:
66 | assert clf in ['speaker', 'nationality', 'gender', 'age', 'age_regression', 'rec']
67 |
68 | args.classifier_types = config['Optim'].get('classifier_types').split(',')
69 | assert len(args.classifier_heads) == len(args.classifier_types)
70 |
71 | args.classifier_loss_weighting_type = config['Optim'].get('classifier_loss_weighting_type', fallback='none')
72 | assert args.classifier_loss_weighting_type in ['none', 'uncertainty_kendall', 'uncertainty_liebel', 'dwa']
73 |
74 | args.dwa_temperature = config['Optim'].getfloat('dwa_temperature', fallback=2.)
75 |
76 | args.classifier_loss_weights = np.array(json.loads(config['Optim'].get('classifier_loss_weights'))).astype(float)
77 | assert len(args.classifier_heads) == len(args.classifier_loss_weights)
78 |
79 | args.classifier_lr_mults = np.array(json.loads(config['Optim'].get('classifier_lr_mults'))).astype(float)
80 | assert len(args.classifier_heads) == len(args.classifier_lr_mults)
81 |
82 | # assert clf_type in ['l2softmax', 'adm', 'adacos', 'xvec', 'arcface', 'sphereface', 'softmax']
83 |
84 | args.classifier_smooth_types = config['Optim'].get('classifier_smooth_types').split(',')
85 | assert len(args.classifier_smooth_types) == len(args.classifier_heads)
86 | args.classifier_smooth_types = [s.strip() for s in args.classifier_smooth_types]
87 |
88 | args.label_smooth_type = config['Optim'].get('label_smooth_type', fallback='None')
89 | assert args.label_smooth_type in ['None', 'disturb', 'uniform']
90 | args.label_smooth_prob = config['Optim'].getfloat('label_smooth_prob', fallback=0.1)
91 |
92 | args.input_dim = config['Hyperparams'].getint('input_dim', fallback=30)
93 | args.embedding_dim = config['Hyperparams'].getint('embedding_dim', fallback=512)
94 | args.lr = config['Hyperparams'].getfloat('lr', fallback=0.2)
95 | args.batch_size = config['Hyperparams'].getint('batch_size', fallback=400)
96 | args.max_seq_len = config['Hyperparams'].getint('max_seq_len', fallback=400)
97 | args.no_cuda = config['Hyperparams'].getboolean('no_cuda', fallback=False)
98 | args.seed = config['Hyperparams'].getint('seed', fallback=123)
99 | args.num_iterations = config['Hyperparams'].getint('num_iterations', fallback=50000)
100 | args.momentum = config['Hyperparams'].getfloat('momentum', fallback=0.9)
101 | args.scheduler_steps = np.array(json.loads(config.get('Hyperparams', 'scheduler_steps'))).astype(int)
102 | args.scheduler_lambda = config['Hyperparams'].getfloat('scheduler_lambda', fallback=0.5)
103 | args.multi_gpu = config['Hyperparams'].getboolean('multi_gpu', fallback=False)
104 | args.classifier_lr_mult = config['Hyperparams'].getfloat('classifier_lr_mult', fallback=1.)
105 | args.dropout = config['Hyperparams'].getboolean('dropout', fallback=True)
106 |
107 | args.model_dir = config['Outputs']['model_dir']
108 | if not hasattr(args, 'basefolder'):
109 | args.basefolder = config['Outputs'].get('basefolder', fallback=None)
110 | args.log_file = os.path.join(args.model_dir, 'train.log')
111 | args.checkpoint_interval = config['Outputs'].getint('checkpoint_interval')
112 | args.results_pkl = os.path.join(args.model_dir, 'results.p')
113 |
114 | args.num_age_bins = config['Misc'].getint('num_age_bins', fallback=10)
115 | args.age_label_smoothing = config['Misc'].getboolean('age_label_smoothing', fallback=False)
116 | return args
117 |
118 |
119 | def train(ds_train):
120 | use_cuda = not args.no_cuda and torch.cuda.is_available()
121 | print('=' * 30)
122 | print('USE_CUDA SET TO: {}'.format(use_cuda))
123 | print('CUDA AVAILABLE?: {}'.format(torch.cuda.is_available()))
124 | print('=' * 30)
125 | device = torch.device("cuda" if use_cuda else "cpu")
126 |
127 | writer = SummaryWriter(comment=os.path.basename(args.cfg))
128 |
129 | if args.model_type == 'XTDNN':
130 | generator = XTDNN(features_per_frame=args.input_dim, embed_features=args.embedding_dim)
131 | if args.model_type == 'ETDNN':
132 | generator = ETDNN(features_per_frame=args.input_dim, embed_features=args.embedding_dim)
133 | if args.model_type == 'FTDNN':
134 | generator = FTDNN(in_dim=args.input_dim, embedding_dim=args.embedding_dim)
135 |
136 | generator.train()
137 | generator = generator.to(device)
138 |
139 | model_dict = {'generator': {'model': generator, 'lr_mult': 1., 'loss_weight': None}}
140 | clf_head_dict = {k: {'model': None, 'lr_mult': lr_mult, 'loss_weight': loss_weight} for k, lr_mult, loss_weight in
141 | zip(args.classifier_heads, args.classifier_lr_mults, args.classifier_loss_weights)}
142 |
143 | num_cls_per_task = [ds_train.num_classes[t] for t in args.classifier_heads]
144 |
145 | for clf_target, clf_type, num_classes, clf_smooth_type in zip(args.classifier_heads,
146 | args.classifier_types,
147 | num_cls_per_task,
148 | args.classifier_smooth_types):
149 | if clf_type == 'adm':
150 | clf = AMSMLoss(args.embedding_dim, num_classes)
151 | elif clf_type == 'adacos':
152 | clf = AdaCos(args.embedding_dim, num_classes)
153 | elif clf_type == 'l2softmax':
154 | clf = L2SoftMax(args.embedding_dim, num_classes)
155 | elif clf_type == 'softmax':
156 | clf = SoftMax(args.embedding_dim, num_classes)
157 | elif clf_type == 'xvec':
158 | clf = XVecHead(args.embedding_dim, num_classes)
159 | elif clf_type == 'xvec_regression':
160 | clf = XVecHead(args.embedding_dim, 1)
161 | elif clf_type == 'xvec_uncertain':
162 | clf = XVecHeadUncertain(args.embedding_dim, num_classes)
163 | elif clf_type == 'arcface':
164 | clf = ArcFace(args.embedding_dim, num_classes)
165 | elif clf_type == 'sphereface':
166 | clf = SphereFace(args.embedding_dim, num_classes)
167 | else:
168 |         assert False, 'Classifier type {} not found'.format(clf_type)
169 |
170 | if clf_head_dict[clf_target]['loss_weight'] >= 0.0:
171 | clf_head_dict[clf_target]['model'] = clf.train().to(device)
172 | else:
173 | # GRL for negative loss weight
174 | abs_lw = np.abs(clf_head_dict[clf_target]['loss_weight'])
175 | clf_head_dict[clf_target]['model'] = nn.Sequential(
176 | GradientReversal(lambda_=abs_lw),
177 | clf
178 | ).train().to(device)
179 | clf_head_dict[clf_target]['loss_weight'] = 1.0 # this is lambda_ in the GRL
180 |
181 | if clf_smooth_type == 'none':
182 | if clf_target.endswith('regression'):
183 | clf_smooth = nn.SmoothL1Loss()
184 | else:
185 | clf_smooth = nn.CrossEntropyLoss()
186 | elif clf_smooth_type == 'twoneighbour':
187 | clf_smooth = TwoNeighbourSmoothingLoss(smoothing=args.label_smooth_prob)
188 | elif clf_smooth_type == 'uniform':
189 | clf_smooth = LabelSmoothingLoss(smoothing=args.label_smooth_prob)
190 | elif clf_smooth_type == 'disturb':
191 | clf_smooth = DisturbLabelLoss(device, disturb_prob=args.label_smooth_prob)
192 | else:
193 |             raise ValueError('Smooth type not found: {}'.format(clf_smooth_type))
194 |
195 | clf_head_dict[clf_target]['criterion'] = clf_smooth
196 |
197 | model_dict.update(clf_head_dict)
198 |
199 | if args.classifier_loss_weighting_type == 'uncertainty_kendall':
200 | model_dict['loss_aggregator'] = {
201 | 'model': MultiTaskUncertaintyLossKendall(len(args.classifier_heads)).to(device),
202 | 'lr_mult': 1.,
203 | 'loss_weight': None
204 | }
205 | if args.classifier_loss_weighting_type == 'uncertainty_liebel':
206 | model_dict['loss_aggregator'] = {
207 | 'model': MultiTaskUncertaintyLossLiebel(len(args.classifier_heads)).to(device),
208 | 'lr_mult': 1.,
209 | 'loss_weight': None
210 | }
211 |
212 | if args.resume_checkpoint != 0:
213 | model_str = os.path.join(args.model_dir, '{}_{}.pt')
214 | for m in model_dict:
215 | model_dict[m]['model'].load_state_dict(torch.load(model_str.format(m, args.resume_checkpoint)))
216 |
217 | optimizer = torch.optim.SGD(
218 | [{'params': model_dict[m]['model'].parameters(), 'lr': args.lr * model_dict[m]['lr_mult']} for m in model_dict],
219 | momentum=args.momentum)
220 |
221 |
222 | iterations = 0
223 |
224 | total_loss = 0
225 | running_loss = [np.nan for _ in range(500)]
226 |
227 | non_spk_clf_heads = [a for a in args.classifier_heads if a != 'speaker']
228 |
229 | best_test_eer = (-1, 1.0)
230 | best_test_dcf = (-1, 1.0)
231 | best_acc = {k: (-1, 0.0) for k in non_spk_clf_heads}
232 |
233 | if os.path.isfile(args.results_pkl) and args.resume_checkpoint != 0:
234 | rpkl = pickle.load(open(args.results_pkl, "rb"))
235 | keylist = list(rpkl.keys())
236 |
237 | if args.test_data:
238 |             test_eers = [(rpkl[key]['test_eer'], key) for key in rpkl]
239 |             best_teer = min(test_eers)
240 |             best_test_eer = (best_teer[1], best_teer[0])
241 |
242 |             test_dcfs = [(rpkl[key]['test_dcf'], key) for key in rpkl]
243 |             best_tdcf = min(test_dcfs)
244 |             best_test_dcf = (best_tdcf[1], best_tdcf[0])
245 |
246 | else:
247 | rpkl = OrderedDict({})
248 |
249 | if args.multi_gpu:
250 | dpp_generator = nn.DataParallel(generator).to(device)
251 |
252 | data_generator = ds_train.get_batches(batch_size=args.batch_size, max_seq_len=args.max_seq_len)
253 |
254 | if args.model_type == 'FTDNN':
255 | drop_indexes = np.linspace(0, 1, args.num_iterations)
256 | drop_sch = ([0, 0.5, 1], [0, 0.5, 0])
257 |         drop_schedule = np.interp(drop_indexes, drop_sch[0], drop_sch[1])  # dropout alpha ramps linearly 0 -> 0.5 at mid-training, then back to 0
258 |
259 | for iterations in range(1, args.num_iterations + 1):
260 | if iterations > args.num_iterations:
261 | break
262 | if iterations in args.scheduler_steps:
263 | schedule_lr(optimizer, factor=args.scheduler_lambda)
264 | if iterations <= args.resume_checkpoint:
265 | print('Skipping iteration {}'.format(iterations), file=open(args.log_file, "a"))
266 | continue
267 |
268 | if args.model_type == 'FTDNN':
269 | if args.dropout:
270 | generator.set_dropout_alpha(drop_schedule[iterations - 1])
271 |
272 | feats, labels = next(data_generator)
273 | feats = feats.to(device)
274 |
275 | if args.multi_gpu:
276 | embeds = dpp_generator(feats)
277 | else:
278 | embeds = generator(feats)
279 |
280 | total_loss = 0
281 | losses = []
282 |
283 | loss_tensors = []
284 |
285 | for m in args.classifier_heads:
286 | lab = labels[m].to(device)
287 | if m == 'rec':
288 | preds = model_dict[m]['model'](embeds)
289 | else:
290 | preds = model_dict[m]['model'](embeds, lab)
291 | loss = model_dict[m]['criterion'](preds, lab)
292 | if args.classifier_loss_weighting_type == 'none':
293 | total_loss += loss * model_dict[m]['loss_weight']
294 | else:
295 | loss_tensors.append(loss)
296 | losses.append(round(loss.item(), 4))
297 |
298 | if args.classifier_loss_weighting_type.startswith('uncertainty'):
299 |             loss_tensors = torch.stack(loss_tensors)  # stack (rather than rebuild a FloatTensor) so the per-task losses stay in the graph and gradients reach the heads/generator
300 | total_loss = model_dict['loss_aggregator']['model'](loss_tensors)
301 |
302 |         if args.classifier_loss_weighting_type == 'dwa':
303 |             # Dynamic Weight Averaging: weight each task by the ratio of its last two losses
304 |             if iterations < 4:
305 |                 loss_t_1 = np.ones(len(loss_tensors))  # no loss history yet; sum the tasks unweighted
306 |                 for l in loss_tensors:
307 |                     total_loss += l
308 |             else:
309 |                 dwa_w = loss_t_1 / loss_t_2
310 |                 K = len(loss_tensors)
311 |                 per_task_weight = torch.FloatTensor(dwa_w / args.dwa_temperature)  # lambda_k
312 |                 per_task_weight = torch.nn.functional.softmax(per_task_weight, dim=0) * K
313 |                 per_task_weight = per_task_weight.numpy()
314 |                 for l, w in zip(loss_tensors, per_task_weight):
315 |                     total_loss += l * w
316 |
317 |             loss_t_2 = loss_t_1.copy()
318 |             loss_t_1 = np.array([l.item() for l in loss_tensors])  # record this step's losses for the next DWA update
319 |
320 |
321 | optimizer.zero_grad()
322 | total_loss.backward()
323 | optimizer.step()
324 |
325 | if args.model_type == 'FTDNN':
326 | generator.step_ftdnn_layers()
327 |
328 | running_loss.pop(0)
329 | running_loss.append(total_loss.item())
330 | rmean_loss = np.nanmean(np.array(running_loss))
331 |
332 | if iterations % 10 == 0:
333 | msg = "{}: {}: [{}/{}] \t C-Loss:{:.4f}, AvgLoss:{:.4f}, losses: {}, lr: {}, bs: {}".format(
334 | args.model_dir,
335 | time.ctime(),
336 | iterations,
337 | args.num_iterations,
338 | total_loss.item(),
339 | rmean_loss,
340 | losses,
341 | get_lr(optimizer),
342 | len(feats))
343 | print(msg)
344 | print(msg, file=open(args.log_file, "a"))
345 |
346 | writer.add_scalar('combined loss', total_loss.item(), iterations)
347 | writer.add_scalar('Avg loss', rmean_loss, iterations)
348 |
349 | if iterations % args.checkpoint_interval == 0:
350 | for m in model_dict:
351 | model_dict[m]['model'].eval().cpu()
352 | cp_filename = "{}_{}.pt".format(m, iterations)
353 | cp_model_path = os.path.join(args.model_dir, cp_filename)
354 | torch.save(model_dict[m]['model'].state_dict(), cp_model_path)
355 | model_dict[m]['model'].to(device).train()
356 |
357 | if args.test_data:
358 | rpkl, best_test_eer, best_test_dcf = eval_step(model_dict, device, ds_test, iterations, rpkl, writer,
359 | best_test_eer, best_test_dcf, best_acc)
360 |
361 | # ---- Final model saving -----
362 | for m in model_dict:
363 | model_dict[m]['model'].eval().cpu()
364 | cp_filename = "final_{}_{}.pt".format(m, iterations)
365 | cp_model_path = os.path.join(args.model_dir, cp_filename)
366 | torch.save(model_dict[m]['model'].state_dict(), cp_model_path)
367 |
368 |
369 | def eval_step(model_dict, device, ds_test, iterations, rpkl, writer, best_test_eer, best_test_dcf, best_acc):
370 | rpkl[iterations] = {}
370 |     print('Evaluating on test/validation data at iteration {}'.format(iterations))
372 | if args.test_data:
373 | test_eer, test_dcf, acc_dict = test_all_factors(model_dict, ds_test, device)
374 | print(acc_dict)
375 | for att in acc_dict:
376 | writer.add_scalar(att, acc_dict[att], iterations)
377 | if acc_dict[att] > best_acc[att][1]:
378 | best_acc[att] = (iterations, acc_dict[att])
379 | print('{} accuracy on Test Set: {}'.format(att, acc_dict[att]))
380 | print('{} accuracy on Test Set: {}'.format(att, acc_dict[att]), file=open(args.log_file, "a"))
381 | print('Best test {} acc: {}'.format(att, best_acc[att]))
382 | print('Best test {} acc: {}'.format(att, best_acc[att]), file=open(args.log_file, "a"))
383 | rpkl[iterations][att] = acc_dict[att]
384 |
385 | print('EER on Test Set: {} ({})'.format(test_eer, args.test_data))
386 | print('EER on Test Set: {} ({})'.format(test_eer, args.test_data), file=open(args.log_file, "a"))
387 | writer.add_scalar('test_eer', test_eer, iterations)
388 | if test_eer < best_test_eer[1]:
389 | best_test_eer = (iterations, test_eer)
390 | print('Best test EER: {}'.format(best_test_eer))
391 | print('Best test EER: {}'.format(best_test_eer), file=open(args.log_file, "a"))
392 | rpkl[iterations]['test_eer'] = test_eer
393 |
394 | print('minDCF on Test Set: {} ({})'.format(test_dcf, args.test_data))
395 | print('minDCF on Test Set: {} ({})'.format(test_dcf, args.test_data), file=open(args.log_file, "a"))
396 | writer.add_scalar('test_dcf', test_dcf, iterations)
397 | if test_dcf < best_test_dcf[1]:
398 | best_test_dcf = (iterations, test_dcf)
399 | print('Best test minDCF: {}'.format(best_test_dcf))
400 | print('Best test minDCF: {}'.format(best_test_dcf), file=open(args.log_file, "a"))
401 | rpkl[iterations]['test_dcf'] = test_dcf
402 |
403 | pickle.dump(rpkl, open(args.results_pkl, "wb"))
404 | return rpkl, best_test_eer, best_test_dcf
405 |
406 |
407 | if __name__ == "__main__":
408 | args = parse_args()
409 | args = parse_config(args)
410 | os.makedirs(args.model_dir, exist_ok=True)
411 | if args.resume_checkpoint == 0:
412 | shutil.copy(args.cfg, os.path.join(args.model_dir, 'experiment_settings.cfg'))
413 | else:
414 | shutil.copy(args.cfg, os.path.join(args.model_dir, 'experiment_settings_resume.cfg'))
415 | if os.path.exists(args.log_file):
416 | os.remove(args.log_file)
417 | pprint(vars(args))
418 | torch.manual_seed(args.seed)
419 | np.random.seed(seed=args.seed)
420 | random.seed(args.seed)
421 | uvloop.install()
422 | ds_train = SpeakerDataset(args.train_data, num_age_bins=args.num_age_bins)
423 | class_enc_dict = ds_train.get_class_encs()
424 | if args.test_data:
425 | ds_test = SpeakerDataset(args.test_data, test_mode=True,
426 | class_enc_dict=class_enc_dict, num_age_bins=args.num_age_bins)
427 | train(ds_train)
428 |
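429 | # Usage sketch (illustrative; the exact command-line flags are defined by parse_args earlier
430 | # in this file, and the config flag name below is an assumption):
431 | #   python train.py --cfg <path/to/experiment.cfg>
432 | # Each module in model_dict is checkpointed every checkpoint_interval iterations as
433 | # <module>_<iteration>.pt in model_dir, alongside train.log and the results.p pickle of
434 | # evaluation metrics written by eval_step.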
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from operator import itemgetter
2 | from scipy.interpolate import interp1d
3 | from scipy.optimize import brentq
4 | from sklearn.metrics import pairwise_distances, roc_curve, accuracy_score
5 | from sklearn.metrics.pairwise import paired_distances
6 | import numpy as np
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 |
12 | def mtd(stuff, device):
13 | if isinstance(stuff, torch.Tensor):
14 | return stuff.to(device)
15 | else:
16 | return [mtd(s, device) for s in stuff]
17 |
18 | class SpeakerRecognitionMetrics:
19 | '''
20 |     Helper metrics (EER and minimum DCF) for speaker verification trials.
21 |     This doesn't strictly need to be a class [remnant of old structuring]; to be reworked.
22 | '''
23 |
24 | def __init__(self, distance_measure=None):
25 | if not distance_measure:
26 | distance_measure = 'cosine'
27 | self.distance_measure = distance_measure
28 |
29 | def get_labels_scores(self, vectors, labels):
30 | labels = labels[:, np.newaxis]
31 | pair_labels = pairwise_distances(labels, metric='hamming').astype(int).flatten()
32 | pair_scores = pairwise_distances(vectors, metric=self.distance_measure).flatten()
33 | return pair_labels, pair_scores
34 |
35 | def get_roc(self, vectors, labels):
36 | pair_labels, pair_scores = self.get_labels_scores(vectors, labels)
37 | fpr, tpr, threshold = roc_curve(pair_labels, pair_scores, pos_label=1, drop_intermediate=False)
38 | # fnr = 1. - tpr
39 | return fpr, tpr, threshold
40 |
41 | def get_eer(self, vectors, labels):
42 | fpr, tpr, _ = self.get_roc(vectors, labels)
43 | # fnr = 1 - self.tpr
44 | # eer = self.fpr[np.nanargmin(np.absolute((fnr - self.fpr)))]
45 | eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
46 | return eer
47 |
48 | def eer_from_pairs(self, pair_labels, pair_scores):
49 | self.fpr, self.tpr, self.thresholds = roc_curve(pair_labels, pair_scores, pos_label=1, drop_intermediate=False)
50 | fnr = 1 - self.tpr
51 | eer = self.fpr[np.nanargmin(np.absolute((fnr - self.fpr)))]
52 | return eer
53 |
54 | def eer_from_ers(self, fpr, tpr):
55 | fnr = 1 - tpr
56 | eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
57 | return eer
58 |
59 | def scores_from_pairs(self, vecs0, vecs1):
60 | return paired_distances(vecs0, vecs1, metric=self.distance_measure)
61 |
62 | def compute_min_dcf(self, fpr, tpr, thresholds, p_target=0.01, c_miss=10, c_fa=1):
63 |         # adapted from compute_min_dcf.py in the Kaldi sid recipes
64 | # thresholds, fpr, tpr = list(zip(*sorted(zip(thresholds, fpr, tpr))))
65 | incr_score_indices = np.argsort(thresholds, kind="mergesort")
66 | thresholds = thresholds[incr_score_indices]
67 | fpr = fpr[incr_score_indices]
68 | tpr = tpr[incr_score_indices]
69 |
70 | fnr = 1. - tpr
71 | min_c_det = float("inf")
72 | for i in range(0, len(fnr)):
73 | c_det = c_miss * fnr[i] * p_target + c_fa * fpr[i] * (1 - p_target)
74 | if c_det < min_c_det:
75 | min_c_det = c_det
76 |
77 | c_def = min(c_miss * p_target, c_fa * (1 - p_target))
78 | min_dcf = min_c_det / c_def
79 | return min_dcf
80 |
81 | def compute_eer(self, fnr, fpr):
82 | """ computes the equal error rate (EER) given FNR and FPR values calculated
83 | for a range of operating points on the DET curve
84 | """
85 |
86 | diff_pm_fa = fnr - fpr
87 | x1 = np.flatnonzero(diff_pm_fa >= 0)[0]
88 | x2 = np.flatnonzero(diff_pm_fa < 0)[-1]
89 | a = (fnr[x1] - fpr[x1]) / (fpr[x2] - fpr[x1] - (fnr[x2] - fnr[x1]))
90 | return fnr[x1] + a * (fnr[x2] - fnr[x1])
91 |
92 | def compute_pmiss_pfa(self, scores, labels):
93 | """ computes false positive rate (FPR) and false negative rate (FNR)
94 | given trial scores and their labels. A weights option is also provided
95 | to equalize the counts over score partitions (if there is such
96 | partitioning).
97 | """
98 |
99 | sorted_ndx = np.argsort(scores)
100 | labels = labels[sorted_ndx]
101 |
102 | tgt = (labels == 1).astype('f8')
103 | imp = (labels == 0).astype('f8')
104 |
105 | fnr = np.cumsum(tgt) / np.sum(tgt)
106 | fpr = 1 - np.cumsum(imp) / np.sum(imp)
107 | return fnr, fpr
108 |
109 |
110 | def compute_min_cost(self, scores, labels, p_target=0.01):
111 | fnr, fpr = self.compute_pmiss_pfa(scores, labels)
112 | eer = self.compute_eer(fnr, fpr)
113 | min_c = self.compute_c_norm(fnr, fpr, p_target)
114 | return eer, min_c
115 |
116 | def compute_c_norm(self, fnr, fpr, p_target, c_miss=10, c_fa=1):
117 | """ computes normalized minimum detection cost function (DCF) given
118 | the costs for false accepts and false rejects as well as a priori
119 | probability for target speakers
120 | """
121 |
122 | dcf = c_miss * fnr * p_target + c_fa * fpr * (1 - p_target)
123 | c_det = np.min(dcf)
124 | c_def = min(c_miss * p_target, c_fa * (1 - p_target))
125 |
126 | return c_det/c_def
127 |
128 |
129 |
130 | def warm_up_lr(batch, num_batch_warm_up, init_lr, optimizer):
131 | for params in optimizer.param_groups:
132 | params['lr'] = batch * init_lr / num_batch_warm_up
133 |
134 | def schedule_lr(optimizer, factor=0.1):
135 | for params in optimizer.param_groups:
136 | params['lr'] *= factor
137 | print(optimizer)
138 |
139 | def set_lr(optimizer, lr):
140 | for params in optimizer.param_groups:
141 | params['lr'] = lr
142 | print(optimizer)
143 |
144 |
145 |
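146 | # Usage sketch (illustrative only; the variable names below are placeholders): given paired
147 | # embeddings for verification trials and binary labels with 1 marking a different-speaker
148 | # pair (the same convention as get_labels_scores above), EER and minDCF can be obtained as:
149 | #   metric = SpeakerRecognitionMetrics(distance_measure='cosine')
150 | #   scores = metric.scores_from_pairs(enrol_vecs, test_vecs)
151 | #   eer = metric.eer_from_pairs(pair_labels, scores)  # also populates metric.fpr/tpr/thresholds
152 | #   min_dcf = metric.compute_min_dcf(metric.fpr, metric.tpr, metric.thresholds, p_target=0.01)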
--------------------------------------------------------------------------------
/voxceleb_data_prep/data/nationality_to_country.tsv:
--------------------------------------------------------------------------------
1 | Nationality Country
2 | Abkhazia Abkhazia
3 | Abkhaz Abkhazia
4 | Abkhazian Abkhazia
5 | Afghanistan Afghanistan
6 | Afghan Afghanistan
7 | Åland Islands Finland
8 | Åland Island Finland
9 | Albania Albania
10 | Albanian Albania
11 | Algeria Algeria
12 | Algerian Algeria
13 | American Samoa American Samoa
14 | American Samoan American Samoa
15 | Andorra Andorra
16 | Andorran Andorra
17 | Angola Angola
18 | Angolan Angola
19 | Anguilla Anguilla
20 | Anguillan Anguilla
21 | Antigua and Barbuda Antigua and Barbuda
22 | Antiguan Antigua and Barbuda
23 | Barbudan Antigua and Barbuda
24 | Argentina Argentina
25 | Argentine Argentina
26 | Argentinian Argentina
27 | Armenia Armenia
28 | Armenian Armenia
29 | Aruba Aruba
30 | Aruban Aruba
31 | Australia Australia
32 | Australian Australia
33 | Austria Austria
34 | Austrian Austria
35 | Azerbaijan Azerbaijan
36 | Azerbaijani Azerbaijan
37 | Azeri Azerbaijan
38 | Bahamas Bahamas
39 | Bahamian Bahamas
40 | Bahrain Bahrain
41 | Bahraini Bahrain
42 | Bangladesh Bangladesh
43 | Bangladeshi Bangladesh
44 | Barbados Barbados
45 | Barbadian Barbados
46 | Belarus Belarus
47 | Belarusian Belarus
48 | Belgium Belgium
49 | Belgian Belgium
50 | Belize Belize
51 | Belizean Belize
52 | Benin Benin
53 | Beninese Benin
54 | Beninois Benin
55 | Bermuda Bermuda
56 | Bermudian Bermuda
57 | Bermudan Bermuda
58 | Bhutan Bhutan
59 | Bhutanese Bhutan
60 | Bolivia Bolivia
61 | Bolivian Bolivia
62 | Bonaire Curaçao
63 | Bosnia and Herzegovina Bosnia and Herzegovina
64 | Bosnian Bosnia and Herzegovina
65 | Herzegovinian Bosnia and Herzegovina
66 | Botswana Botswana
67 | Motswana Botswana
68 | Botswanan Botswana
69 | Batswana Botswana
70 | Bouvet Island Norway
71 | Bouvet Islander Norway
72 | Brazil Brazil
73 | Brazilian Brazil
74 | British Indian Ocean Territory United Kingdom
75 | BIOT United Kingdom
76 | Brunei Brunei
77 | Bruneian Brunei
78 | Bulgaria Bulgaria
79 | Bulgarian Bulgaria
80 | Burkina Faso Burkina Faso
81 | Burkinabé Burkina Faso
82 | Burkinabè/Burkinabé Burkina Faso
83 | Burma Myanmar
84 | Burmese Myanmar
85 | Bamar Myanmar
86 | Burundi Burundi
87 | Burundian Burundi
88 | Barundi Burundi
89 | Cabo Verde Cape Verde
90 | Cabo Verdean Cape Verde
91 | Cambodia Cambodia
92 | Cambodian Cambodia
93 | Cameroon Cameroon
94 | Cameroonian Cameroon
95 | Canada Canada
96 | Canadian Canada
97 | Cayman Islands Cayman Islands
98 | Caymanian Cayman Islands
99 | Central African Republic Central African Republic
100 | Central African Central African Republic
101 | Chad Chad
102 | Chadian Chad
103 | Chile Chile
104 | Chilean Chile
105 | China China
106 | Chinese China
107 | Taiwanese Taiwan
108 | Formosan Taiwan
109 | Christmas Island Christmas Island
110 | Christmas Islanders Christmas Island
111 | Cocos (Keeling) Islands Cocos Islands
112 | Cocos Island Cocos Islands
113 | Colombia Colombia
114 | Colombian Colombia
115 | Comoros Comoros
116 | Comoran Comoros
117 | Comorian Comoros
118 | Democratic Republic of the Congo Democratic Republic of the Congo
119 | Congolese Congo
120 | Congo Congo
121 | Republic of the Congo Congo
122 | Cook Islands Cook Islands
123 | Cook Island Cook Islands
124 | Costa Rica Costa Rica
125 | Costa Rican Costa Rica
126 | Croatia Croatia
127 | Croatian Croatia
128 | Croats Croatia
129 | Cuba Cuba
130 | Cuban Cuba
131 | Curaçao Curaçao
132 | Curaçaoan Curaçao
133 | Cyprus Cyprus
134 | Cypriot Cyprus
135 | Czech Republic Czech Republic
136 | Czech Czech Republic
137 | Denmark Denmark
138 | Danish Denmark
139 | Djibouti Djibouti
140 | Djiboutian Djibouti
141 | Dominica Dominica
142 | Dominican Dominican Republic
143 | Dominicans Dominican Republic
144 | Dominican Republic Dominican Republic
145 | East Timor East Timor
146 | Timorese East Timor
147 | Ecuador Ecuador
148 | Ecuadorian Ecuador
149 | Egypt Egypt
150 | Egyptian Egypt
151 | El Salvador El Salvador
152 | Salvadoran El Salvador
153 | Salvadorean El Salvador
154 | England United Kingdom
155 | English United Kingdom
156 | Equatorial Guinea Equatorial Guinea
157 | Equatorial Guinean Equatorial Guinea
158 | Equatoguinean Equatorial Guinea
159 | Eritrea Eritrea
160 | Eritrean Eritrea
161 | Estonia Estonia
162 | Estonian Estonia
163 | Eswatini (Swaziland) Swaziland
164 | Swazi Swaziland
165 | Swati Swaziland
166 | Ethiopia Ethiopia
167 | Ethiopian Ethiopia
168 | Habesha Ethiopia
169 | Falkland Islands Falkland Islands
170 | Falkland Island Falkland Islands
171 | Faroe Islands Faroe Islands
172 | Faroese Faroe Islands
173 | Fiji Fiji
174 | Fijian Fiji
175 | Finland Finland
176 | Finnish Finland
177 | France France
178 | French Guiana French Guiana
179 | French Guianese French Guiana
180 | French Polynesia French Polynesia
181 | French Polynesian French Polynesia
182 | French Southern Territories France
183 | French France
184 | Gabon Gabon
185 | Gabonese Gabon
186 | Gabonaise Gabon
187 | Gambia Gambia
188 | Gambian Gambia
189 | Gambians Gambia
190 | Georgia Georgia
191 | Georgian Georgia
192 | Germany Germany
193 | German Germany
194 | Ghana Ghana
195 | Ghanaian Ghana
196 | Gibraltar Gibraltar
197 | Gibraltarian Gibraltar
198 | Greece Greece
199 | Greek Greece
200 | Greenland Greenland
201 | Greenlandic Greenland
202 | Grenada Grenada
203 | Grenadian Grenada
204 | Guadeloupe Dominica
205 | Guadeloupians Dominica
206 | Guadeloupeans Dominica
207 | Guam Guam
208 | Guamanian Guam
209 | Guatemala Guatemala
210 | Guatemalan Guatemala
211 | Guernsey Guernsey
212 | Channel Island Guernsey
213 | Channel Islanders Guernsey
214 | Guinea Guinea
215 | Guinean Guinea
216 | Guinea-Bissau Guinea-Bissau
217 | Bissau-Guinean Guinea-Bissau
218 | Guyana Guyana
219 | Guyanese Guyana
220 | Haiti Haiti
221 | Haitian Haiti
222 | Honduras Honduras
223 | Honduran Honduras
224 | Hondurans Honduras
225 | Hong Kong Hong Kong
226 | Cantonese China
227 | Hong Konger Hong Kong
228 | Hongkongers Hong Kong
229 | Hong Kongese Hong Kong
230 | Hungary Hungary
231 | Hungarian Hungary
232 | Magyar Hungary
233 | Iceland Iceland
234 | Icelandic Iceland
235 | Icelander Iceland
236 | India India
237 | Indian India
238 | Indonesia Indonesia
239 | Indonesian Indonesia
240 | Iran Iran
241 | Iranian Iran
242 | Persian Iran
243 | Iraq Iraq
244 | Iraqi Iraq
245 | Ireland Ireland
246 | Irish Ireland
247 | Isle of Man Isle of Man
248 | Manx Isle of Man
249 | Israel Israel
250 | Israeli Israel
251 | Italy Italy
252 | Italian Italy
253 | Ivory Coast Ivory Coast
254 | Ivorian Ivory Coast
255 | Jamaica Jamaica
256 | Jamaican Jamaica
257 | Jan Mayen Norway
258 | Jan Mayen Norway
259 | Jan Mayen residents Norway
260 | Japan Japan
261 | Japanese Japan
262 | Jersey Jersey
263 | Channel Island Jersey
264 | Channel Islanders Jersey
265 | Jordan Jordan
266 | Jordanian Jordan
267 | Kazakhstan Kazakhstan
268 | Kazakhstani Kazakhstan
269 | Kazakh Kazakhstan
270 | Kenya Kenya
271 | Kenyan Kenya
272 | Kiribati Kiribati
273 | I-Kiribati Kiribati
274 | Korean South Korea
275 | North Korea North Korea
276 | South Korea South Korea
277 | North Korean North Korea
278 | South Korean South Korea
279 | Kosovo Kosovo
280 | Kosovar Kosovo
281 | Kosovan Kosovo
282 | Kuwait Kuwait
283 | Kuwaiti Kuwait
284 | Kyrgyzstan Kyrgyzstan
285 | Kyrgyzstani Kyrgyzstan
286 | Kyrgyz Kyrgyzstan
287 | Kirgiz Kyrgyzstan
288 | Kirghiz Kyrgyzstan
289 | Lao People's Democratic Republic Laos
290 | Lao Laos
291 | Laotian Laos
292 | Laos Laos
293 | Latvia Latvia
294 | Latvian Latvia
295 | Lettish Latvia
296 | Lebanon Lebanon
297 | Lebanese Lebanon
298 | Lesotho Lesotho
299 | Basotho Lesotho
300 | Liberia Liberia
301 | Liberian Liberia
302 | Libya Libya
303 | Libyan Libya
304 | Liechtenstein Liechtenstein
305 | Liechtensteiner Liechtenstein
306 | Lithuania Lithuania
307 | Lithuanian Lithuania
308 | Luxembourg Luxembourg
309 | Luxembourgish Luxembourg
310 | Luxembourger Luxembourg
311 | Macau Macau
312 | Macanese Macau
313 | Macedonia North Macedonia
314 | Macedonian North Macedonia
315 | Madagascar Madagascar
316 | Malagasy Madagascar
317 | Malawi Malawi
318 | Malawian Malawi
319 | Malaysia Malaysia
320 | Malaysian Malaysia
321 | Maldives Maldives
322 | Maldivian Maldives
323 | Mali Mali
324 | Malian Mali
325 | Malinese Mali
326 | Malta Malta
327 | Maltese Malta
328 | Marshall Islands Marshall Islands
329 | Marshallese Marshall Islands
330 | Martinique Saint Lucia
331 | Martiniquais Saint Lucia
332 | Martinican Saint Lucia
333 | Martiniquaises Saint Lucia
334 | Mauritania Mauritania
335 | Mauritanian Mauritania
336 | Mauritius Mauritius
337 | Mauritian Mauritius
338 | Mayotte Comoros
339 | Mahoran Comoros
340 | Mexico Mexico
341 | Mexican Mexico
342 | Micronesia Federated States of Micronesia
343 | Micronesian Federated States of Micronesia
344 | Moldova Moldova
345 | Moldovan Moldova
346 | Monaco Monaco
347 | Monégasque Monaco
348 | Monacan Monaco
349 | Mongolia Mongolia
350 | Mongolian Mongolia
351 | Mongol Mongolia
352 | Montenegro Montenegro
353 | Montenegrin Montenegro
354 | Montserrat Montserrat
355 | Montserratian Montserrat
356 | Morocco Morocco
357 | Moroccan Morocco
358 | Mozambique Mozambique
359 | Mozambican Mozambique
360 | Myanmar Myanmar
361 | Burmese Myanmar
362 | Bamar Myanmar
363 | Namibia Namibia
364 | Namibian Namibia
365 | Nauru Nauru
366 | Nauruan Nauru
367 | Nepal Nepal
368 | Nepali Nepal
369 | Nepalese Nepal
370 | Netherlands Netherlands
371 | Dutch Netherlands
372 | Netherlandic Netherlands
373 | New Caledonia New Caledonia
374 | New Caledonian New Caledonia
375 | New Zealand New Zealand
376 | New Zealandic New Zealand
377 | New Zealander New Zealand
378 | Nicaragua Nicaragua
379 | Nicaraguan Nicaragua
380 | Niger Niger
381 | Nigerien Niger
382 | Nigeria Nigeria
383 | Nigerian Nigeria
384 | Niue Niue
385 | Niuean Niue
386 | Norfolk Island Norfolk Island
387 | Norfolk Islanders Norfolk Island
388 | Northern Ireland United Kingdom
389 | Northern Irish United Kingdom
390 | Northern Mariana Islands Northern Mariana Islands
391 | Northern Marianan Northern Mariana Islands
392 | Norway Norway
393 | Norwegian Norway
394 | Oman Oman
395 | Omani Oman
396 | Pakistan Pakistan
397 | Pakistani Pakistan
398 | Palau Palau
399 | Palauan Palau
400 | Palestine Palestine
401 | Palestinian Palestine
402 | Panama Panama
403 | Panamanian Panama
404 | Papua New Guinea Papua New Guinea
405 | Papua New Guinean Papua New Guinea
406 | Paraguay Paraguay
407 | Paraguayan Paraguay
408 | Peru Peru
409 | Peruvian Peru
410 | Philippines Philippines
411 | Philippine Philippines
412 | Filipino Philippines
413 | Filipina Philippines
414 | Pitcairn Islands Pitcairn Islands
415 | Pitcairn Island Pitcairn Islands
416 | Pitcairn Islander Pitcairn Islands
417 | Poland Poland
418 | Polish Poland
419 | Portugal Portugal
420 | Portuguese Portugal
421 | Puerto Rico Puerto Rico
422 | Puerto Rican Puerto Rico
423 | Qatar Qatar
424 | Qatari Qatar
425 | Réunion Mauritius
426 | Réunionese Mauritius
427 | Réunionnais Mauritius
428 | Romania Romania
429 | Romanian Romania
430 | Russia Russia
431 | Russian Russia
432 | Rwanda Rwanda
433 | Rwandan Rwanda
434 | Banyarwanda Rwanda
435 | Saba Saint Kitts and Nevis
436 | Saba Dutch Saint Kitts and Nevis
437 | Saint Barthélemy Saint Barthélemy
438 | Barthélemois Saint Barthélemy
439 | Barthélemois/Barthélemoises Saint Barthélemy
440 | Saint Kitts and Nevis Saint Kitts and Nevis
441 | Kittitian Saint Kitts and Nevis
442 | Nevisian Saint Kitts and Nevis
443 | Saint Lucia Saint Lucia
444 | Saint Lucian Saint Lucia
445 | Saint Martin Saint Martin
446 | Saint-Martinoise Saint Martin
447 | Saint-Martinois/Saint-Martinoises Saint Martin
448 | Saint Pierre and Miquelon Saint Pierre and Miquelon
449 | Saint-Pierrais Saint Pierre and Miquelon
450 | Miquelonnais Saint Pierre and Miquelon
451 | Saint-Pierraises Saint Pierre and Miquelon
452 | Miquelonnaises Saint Pierre and Miquelon
453 | Saint Vincent and the Grenadines Saint Vincent and the Grenadines
454 | Saint Vincentian Saint Vincent and the Grenadines
455 | Vincentian Saint Vincent and the Grenadines
456 | Samoa Samoa
457 | Samoan Samoa
458 | San Marino San Marino
459 | Sammarinese San Marino
460 | São Tomé and Príncipe São Tomé and Príncipe
461 | São Toméan São Tomé and Príncipe
462 | Saudi Arabia Saudi Arabia
463 | Saudi Saudi Arabia
464 | Saudi Arabian Saudi Arabia
465 | Scotland United Kingdom
466 | Scottish United Kingdom
467 | Senegal Senegal
468 | Senegalese Senegal
469 | Serbia Serbia
470 | Serbian Serbia
471 | Seychelles Seychelles
472 | Seychellois Seychelles
473 | Seychellois/Seychelloises Seychelles
474 | Sierra Leone Sierra Leone
475 | Sierra Leonean Sierra Leone
476 | Singapore Singapore
477 | Singaporean Singapore
478 | Sint Eustatius Saint Kitts and Nevis
479 | Statian Saint Kitts and Nevis
480 | Sint Maarten Sint Maarten
481 | Sint Maartener Sint Maarten
482 | Slovakia Slovakia
483 | Slovak Slovakia
484 | Slovakian Slovakia
485 | Slovenia Slovenia
486 | Slovenian Slovenia
487 | Slovene Slovenia
488 | Solomon Islands Solomon Islands
489 | Solomon Island Solomon Islands
490 | Solomon Islander Solomon Islands
491 | Somalia Somalia
492 | Somali Somalia
493 | Somaliland Somalia
494 | Somalilander Somalia
495 | South Africa South Africa
496 | South Africa South Africa
497 | South Georgia and the South Sandwich Islands United Kingdom
498 | South Georgia Island United Kingdom
499 | South Sandwich Island United Kingdom
500 | South Ossetia South Ossetia
501 | South Ossetian South Ossetia
502 | South Sudan South Sudan
503 | South Sudanese South Sudan
504 | Spain Spain
505 | Spanish Spain
506 | Spaniard Spain
507 | Sri Lanka Sri Lanka
508 | Sri Lankan Sri Lanka
509 | Sudan Sudan
510 | Sudanese Sudan
511 | Suriname Suriname
512 | Surinamese Suriname
513 | Surinamer Suriname
514 | Svalbard Norway
515 | Svalbard resident Norway
516 | Swaziland Eswatini
517 | Swazi Eswatini
518 | Swati Eswatini
519 | Swazis Eswatini
520 | Sweden Sweden
521 | Swedish Sweden
522 | Switzerland Switzerland
523 | Swiss Switzerland
524 | Syria Syria
525 | Syrian Syria
526 | Tajikistan Tajikistan
527 | Tajikistani Tajikistan
528 | Tajiks Tajikistan
529 | Tanzania Tanzania
530 | Tanzanian Tanzania
531 | Thailand Thailand
532 | Thai Thailand
533 | Timor-Leste East Timor
534 | Timorese East Timor
535 | Togo Togo
536 | Togolese Togo
537 | Tokelau Tokelau
538 | Tokelauan Tokelau
539 | Tokelauans Tokelau
540 | Tonga Tonga
541 | Tongan Tonga
542 | Trinidad and Tobago Trinidad and Tobago
543 | Trinidadian Trinidad and Tobago
544 | Tobagonian Trinidad and Tobago
545 | Tunisia Tunisia
546 | Tunisian Tunisia
547 | Turkey Turkey
548 | Turkish Turkey
549 | Turkmenistan Turkmenistan
550 | Turkmen Turkmenistan
551 | Turks and Caicos Islands Turks and Caicos Islands
552 | Turks and Caicos Island Turks and Caicos Islands
553 | Turks and Caicos Islander Turks and Caicos Islands
554 | Tuvalu Tuvalu
555 | Tuvaluan Tuvalu
556 | Uganda Uganda
557 | Ugandan Uganda
558 | Ukraine Ukraine
559 | Ukrainian Ukraine
560 | United Arab Emirates United Arab Emirates
561 | Emirati United Arab Emirates
562 | Emirian United Arab Emirates
563 | Emiri United Arab Emirates
564 | Great Britain United Kingdom
565 | United Kingdom United Kingdom
566 | British United Kingdom
567 | UK United Kingdom
568 | U.K. United Kingdom
569 | U.K United Kingdom
570 | United States of America United States
571 | United States United States
572 | U.S. United States
573 | American United States
574 | U.S.A United States
575 | U.S.A. United States
576 | USA United States
577 | Uruguay Uruguay
578 | Uruguayan Uruguay
579 | Uzbekistan Uzbekistan
580 | Uzbekistani Uzbekistan
581 | Uzbek Uzbekistan
582 | Vanuatu Vanuatu
583 | Vanuatuan Vanuatu
584 | Ni-Vanuatu Vanuatu
585 | Vatican City Vatican City
586 | Vatican Vatican City
587 | Vatican citizen Vatican City
588 | Vatican City State Vatican City
589 | Venezuelan Venezuela
590 | Venezuela Venezuela
591 | Vietnam Vietnam
592 | Vietnamese Vietnam
593 | British Virgin Islands British Virgin Islands
594 | British Virgin Island British Virgin Islands
595 | British Virgin Islanders British Virgin Islands
596 | Virgin Islands U.S. Virgin Islands
597 | U.S. Virgin Island U.S. Virgin Islands
598 | U.S. Virgin Islanders U.S. Virgin Islands
599 | Wales United Kingdom
600 | Welsh United Kingdom
601 | Wallis and Futuna Wallis and Futuna
602 | Wallisian Wallis and Futuna
603 | Futunan Wallis and Futuna
604 | Western Sahara Western Sahara
605 | Sahrawi Western Sahara
606 | Sahrawian Western Sahara
607 | Sahraouian Western Sahara
608 | Yemen Yemen
609 | Yemeni Yemen
610 | Zambia Zambia
611 | Zambian Zambia
612 | Zimbabwe Zimbabwe
613 | Zimbabwean Zimbabwe
614 |
--------------------------------------------------------------------------------
/voxceleb_data_prep/data/us_states.csv:
--------------------------------------------------------------------------------
1 | States
2 | Alabama
3 | Alaska
4 | Arizona
5 | Arkansas
6 | California
7 | Colorado
8 | Connecticut
9 | Delaware
10 | Florida
11 | Georgia
12 | Hawaii
13 | Idaho
14 | Illinois
15 | Indiana
16 | Iowa
17 | Kansas
18 | Kentucky
19 | Louisiana
20 | Maine
21 | Maryland
22 | Massachusetts
23 | Michigan
24 | Minnesota
25 | Mississippi
26 | Missouri
27 | Montana
28 | Nebraska
29 | Nevada
30 | New Hampshire
31 | New Jersey
32 | New Mexico
33 | New York
34 | North Carolina
35 | North Dakota
36 | Ohio
37 | Oklahoma
38 | Oregon
39 | Pennsylvania
40 | Rhode Island
41 | South Carolina
42 | South Dakota
43 | Tennessee
44 | Texas
45 | Utah
46 | Vermont
47 | Virginia
48 | Washington
49 | West Virginia
50 | Wisconsin
51 | Wyoming
52 |
--------------------------------------------------------------------------------
/voxceleb_data_prep/scrape_nationalities.py:
--------------------------------------------------------------------------------
1 | from collections import Counter, OrderedDict
2 | from pprint import pprint
3 |
4 | import pandas as pd
5 | import spacy
6 | import wikipedia
7 | import wptools
8 | from tqdm import tqdm
9 | from wikipedia import DisambiguationError
10 |
11 |
12 | def wiki_infobox(text):
13 | try:
14 | page = wptools.page(text, silent=True).get_parse()
15 | infobox = page.data['infobox']
16 |     except Exception:  # page lookup/parse can fail in many ways; fall back to an empty infobox
17 | infobox = {}
18 | return infobox
19 |
20 |
21 | if __name__ == "__main__":
22 | # Meta information of vox1, same as original filename
23 | vox1 = pd.read_csv('./data/vox1_meta.csv', delimiter='\t')
24 |
25 | # Meta information of vox2, same as original filename
26 | vox2 = pd.read_csv('./data/vox2_meta.csv')
27 |
28 | # Meta information of vggface2, original filename meta/identity_meta.csv
29 | vgg2 = pd.read_csv('./data/vggface2_meta.csv', quotechar='"', skipinitialspace=True)
30 |
31 | us_states = set(pd.read_csv('./data/us_states.csv')['States'].str.lower().values)
32 |
33 | vgg_id_to_name = {k:v.strip() for k,v in zip(vgg2['Class_ID'].values, vgg2['Name'].values)}
34 | vox2_ids_dict = {k:v.strip() for k,v in zip(vox2['VoxCeleb2 ID '].values, vox2['VGGFace2 ID '])}
35 | vox2_id_to_name = {k:vgg_id_to_name[vox2_ids_dict[k]] for k in vox2_ids_dict}
36 | vox2_name_to_id = {k:v for v, k in vox2_id_to_name.items()}
37 |
38 | natcountry = pd.read_csv('./data/nationality_to_country.tsv', delimiter='\t')
39 | country_set = set(natcountry.Country.values)
40 | country_nat_dict = {k.lower():[] for k in country_set}
41 |
42 | for c, n in zip(natcountry.Country.values, natcountry.Nationality.values):
43 | country_nat_dict[c.lower()].append(n.lower())
44 |
45 | nat_country_dict = {n.lower():c.lower() for c, n in zip(natcountry.Country.values, natcountry.Nationality.values)}
46 |
47 |     # Reorder country_nat_dict based on the demographics of vox1:
48 |     # countries most common in vox1 first,
49 |     # then the remaining countries by descending length of name,
50 |     # so that names which are substrings of longer ones (e.g. 'niger' in 'nigeria') are not matched first
51 | common_nats = vox1.Nationality.str.lower().value_counts().keys()
52 |
53 | common_nats_keys = []
54 | for c in common_nats:
55 | common_nats_keys.append(nat_country_dict[c])
56 |
57 | ordered_country_nat_dict = OrderedDict({})
58 | for c in common_nats_keys:
59 | ordered_country_nat_dict[c] = country_nat_dict[c]
60 |
61 | countries_by_len = sorted(country_nat_dict.keys(), key=len, reverse=True)
62 |
63 | for c in countries_by_len:
64 | if c not in ordered_country_nat_dict:
65 | ordered_country_nat_dict[c] = country_nat_dict[c]
66 |
67 | vox2_names = list(vox2_id_to_name.values())
68 | vox2_nationalities = OrderedDict({k:[] for k in vox2_names})
69 |
70 | nlp = spacy.load("en_core_web_sm")
71 |
72 | for i, name in enumerate(tqdm(vox2_nationalities)):
73 | if vox2_nationalities[name]:
74 | continue
75 |
76 | # Get the wikipedia page and summary text
77 | qname = ' '.join(name.split('_'))
78 | try:
79 | text = wikipedia.summary(qname)
80 |         except Exception:
81 | search = wikipedia.search(qname, results=3)
82 | if len(search) == 0:
83 | print(name)
84 | continue
85 | else:
86 | index = 0
87 | while True:
88 | if index >= len(search):
89 | qname = 'nan'
90 | text = ''
91 | break
92 | try:
93 | qname = search[index]
94 | text = wikipedia.summary(qname, auto_suggest=False)
95 | break
96 | except DisambiguationError:
97 | index += 1
98 |
99 | if qname == 'nan':
100 |             # Couldn't find a good Wikipedia page
101 | print(name)
102 | continue
103 |
104 |         # Try the infobox first
105 | try:
106 | person_infobox = wiki_infobox(qname)
107 | if 'birth_place' in person_infobox:
108 | place = person_infobox['birth_place'].replace('[', '').replace(']', '').lower()
109 |
110 | for s in us_states:
111 | if s in place:
112 | vox2_nationalities[name] = ['united states']
113 | break
114 |
115 | if vox2_nationalities[name]:
116 | continue
117 |
118 | for c in ordered_country_nat_dict:
119 | if c in place:
120 | vox2_nationalities[name] = [c]
121 | break
122 |
123 | if vox2_nationalities[name]:
124 | continue
125 |
126 | place_infobox = wiki_infobox(place)
127 | if 'subdivision_name' in place_infobox:
128 | subd_country = place_infobox['subdivision_name'].lower()
129 | if subd_country in country_nat_dict:
130 | vox2_nationalities[name] = [subd_country]
131 | continue
132 |         except Exception:  # if anything goes wrong with the infobox route, fall back to the summary text
133 | pass
134 |
135 |
136 |         # Otherwise try the summary text
137 | doc = nlp(text)
138 | all_stopwords = nlp.Defaults.stop_words
139 |
140 | nat_candidates = []
141 | for j, tok in enumerate(doc):
142 | if tok.text.lower() in nat_country_dict:
143 | if doc[j-1].text.lower() == 'new':
144 | continue
145 | if tok.text.lower() == 'american' and doc[j+1].text.lower() == 'football':
146 | continue
147 | nat_candidates.append(nat_country_dict[tok.text.lower()])
148 | if doc[j+1].text.lower() == '-':
149 | if doc[j+2].text.lower() == 'born':
150 | c = doc[j+3].text.lower()
151 | if c in nat_country_dict:
152 | nat_candidates = [nat_country_dict[c]]
153 | break
154 | vox2_nationalities[name] = nat_candidates
155 |
156 | vox2_nats_final = OrderedDict({})
157 | for name in vox2_nationalities:
158 | # Take the most frequent nat in the list
159 |         # In case of ties, take the first occurring one
160 | if vox2_nationalities[name]:
161 | best_nat = Counter(vox2_nationalities[name]).most_common(1)[0][0]
162 | vox2_nats_final[name] = best_nat
163 | else:
164 | vox2_nats_final[name] = 'unk'
165 |
166 |
167 | with open('./spk2nat', 'w') as fp:
168 | for name in vox2_nats_final:
169 |         spk_id = vox2_name_to_id[name]  # avoid shadowing the builtin id()
170 |         nat = '_'.join(vox2_nats_final[name].split())
171 |         line = '{} {}\n'.format(spk_id, nat)
172 | fp.write(line)
173 |
174 |
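175 | # The resulting ./spk2nat file contains one "<speaker_id> <nationality>" pair per line, with
176 | # spaces in country names replaced by underscores and unresolved speakers written as "unk"
177 | # (a purely illustrative line: "id00012 united_states").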
--------------------------------------------------------------------------------