├── .gitattributes
├── .gitignore
├── .gitmodules
├── README.md
├── configs
│   └── baseline2.conf
├── dataloaders
│   └── backend_fusion.py
├── embeddings
│   ├── asv_embd_dev.pk
│   ├── asv_embd_eval.pk
│   ├── asv_embd_trn.pk
│   ├── cm_embd_dev.pk
│   ├── cm_embd_eval.pk
│   ├── cm_embd_trn.pk
│   ├── spk_model_dev.pk
│   └── spk_model_eval.pk
├── main.py
├── metrics.py
├── models
│   └── baseline2.py
├── pdfs
│   ├── 2022_SASV_evaluation_plan_v0.1.pdf
│   └── 2022_SASV_evaluation_plan_v0.2.pdf
├── protocols
│   ├── ASVspoof2019.LA.asv.dev.gi.trl.txt
│   ├── ASVspoof2019.LA.asv.eval.gi.trl.txt
│   ├── ASVspoof2019.LA.cm.dev.trl.txt
│   ├── ASVspoof2019.LA.cm.eval.trl.txt
│   └── ASVspoof2019.LA.cm.train.trn.txt
├── requirements.txt
├── save_embeddings.py
├── schedulers.py
├── spk_meta
│   ├── spk_meta_dev.pk
│   ├── spk_meta_eval.pk
│   └── spk_meta_trn.pk
├── systems
│   └── baseline2.py
└── utils.py

/.gitattributes:
--------------------------------------------------------------------------------
embeddings/cm_embd_trn.pk filter=lfs diff=lfs merge=lfs -text
embeddings/spk_model_dev.pk filter=lfs diff=lfs merge=lfs -text
embeddings/spk_model_eval.pk filter=lfs diff=lfs merge=lfs -text
embeddings/asv_embd_dev.pk filter=lfs diff=lfs merge=lfs -text
embeddings/asv_embd_eval.pk filter=lfs diff=lfs merge=lfs -text
embeddings/asv_embd_trn.pk filter=lfs diff=lfs merge=lfs -text
embeddings/cm_embd_dev.pk filter=lfs diff=lfs merge=lfs -text
embeddings/cm_embd_eval.pk filter=lfs diff=lfs merge=lfs -text
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.*
*pycache*
*.zip
*tmp*
LA*
exp_result
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "ECAPATDNN"]
    path = ECAPATDNN
    url = https://github.com/TaoRuijie/ECAPATDNN
[submodule "aasist"]
    path = aasist
    url = https://github.com/clovaai/aasist
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Introduction
This repository contains several materials that supplement the Spoofing-Aware Speaker Verification (SASV) Challenge 2022, including:
- calculating metrics;
- extracting speaker/spoofing embeddings from pre-trained models;
- training/evaluating Baseline2 in the evaluation plan.

More information can be found on the [webpage](https://sasv-challenge.github.io) and in the [evaluation plan](pdfs/2022_SASV_evaluation_plan_v0.2.pdf).

### Prerequisites
#### Load ECAPA-TDNN & AASIST repositories
```
git submodule init
git submodule update
```

#### Install requirements
```
pip install -r requirements.txt
```

### Data preparation
The ASVspoof2019 LA dataset [1] can be downloaded using the script in the AASIST [2] repository:
```
python ./aasist/download_dataset.py
```

### Speaker & spoofing embedding extraction
Speaker and spoofing embeddings can be extracted using the script below.
Extracted embeddings will be saved in `./embeddings`.
- Speaker embeddings are extracted using ECAPA-TDNN [3].
  - Implemented by https://github.com/TaoRuijie/ECAPATDNN
- Spoofing embeddings are extracted using AASIST [2].
```
python save_embeddings.py
```
- We also provide pre-extracted embeddings.
  - To use the prepared embeddings, git-lfs is required. Please refer to [https://git-lfs.github.com](https://git-lfs.github.com) for further instructions. After installing git-lfs, use the following commands to download the embeddings:
```
git-lfs install
git-lfs pull
```
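Each saved pickle is a plain Python dictionary mapping an utterance ID to a NumPy embedding vector (the `spk_model_*` files map a speaker ID to an averaged enrolment embedding). A minimal sketch for inspecting one of the files, assuming the embeddings have been extracted or pulled via git-lfs; in this baseline, ASV embeddings are 192-dimensional and CM embeddings are 160-dimensional:
```python
import pickle

# ASV embeddings: {utterance_id: np.ndarray of shape (192,)}
with open("embeddings/asv_embd_dev.pk", "rb") as f:
    asv_embd_dev = pickle.load(f)

utt_id, embedding = next(iter(asv_embd_dev.items()))
print(utt_id, embedding.shape)
```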
## Baseline 2 Training
Run the script below to train Baseline2.
- It reproduces **Baseline2** described in the evaluation plan.
```
python main.py --config ./configs/baseline2.conf
```

## Developing your own models
- Currently being added...

### Adding a custom DNN architecture
1. Create a new model file under `./models/` (a minimal skeleton is sketched after this list).
2. Create a new configuration file under `./configs`.
3. In the new configuration, modify `model_arch` and add the required arguments to `model_config`.
4. Run `python main.py --config {USER_CONFIG_FILE}`.
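The training system imports `models.{model_arch}`, instantiates its `Model` class with the `model_config` section, and calls it on (enrolment ASV, test ASV, test CM) embedding batches. A minimal skeleton under those assumptions — the file name and layer sizes below are placeholders, mirroring `models/baseline2.py`:
```python
# models/my_model.py (hypothetical example file)
import torch


class Model(torch.nn.Module):
    def __init__(self, model_config):
        super().__init__()
        # code_dim = 192 (ASV enrol) + 192 (ASV test) + 160 (CM test) = 544
        self.net = torch.nn.Sequential(
            torch.nn.Linear(model_config["code_dim"], 128),
            torch.nn.LeakyReLU(negative_slope=0.3),
            torch.nn.Linear(128, 2),
        )

    def forward(self, embd_asv_enr, embd_asv_tst, embd_cm):
        x = torch.cat([embd_asv_enr, embd_asv_tst, embd_cm], dim=1)
        return self.net(x)  # (batch, 2) logits; class 1 is 'target'
```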
### Using only metrics
Use `get_all_EERs` in `metrics.py` to calculate all three EERs (a usage sketch follows below).
- Prediction scores and keys should be prepared following one of the trial protocols:
  - `protocols/ASVspoof2019.LA.asv.dev.gi.trl.txt` or
  - `protocols/ASVspoof2019.LA.asv.eval.gi.trl.txt`
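A minimal sketch, assuming `preds` holds one score per trial and `keys` holds the corresponding labels (the toy values below are illustrative only):
```python
from metrics import get_all_EERs

# higher score = more confident that the trial is a bona fide target
preds = [0.9, 0.7, 0.1, 0.6, 0.4, 0.8]
keys = ["target", "nontarget", "spoof", "target", "spoof", "nontarget"]

sasv_eer, sv_eer, spf_eer = get_all_EERs(preds=preds, keys=keys)
print(f"SASV-EER={sasv_eer:.4f} SV-EER={sv_eer:.4f} SPF-EER={spf_eer:.4f}")
```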
## References
[1] ASVspoof 2019: A large-scale public database of synthesized, converted and replayed speech
```bibtex
@article{wang2020asvspoof,
  title={ASVspoof 2019: A large-scale public database of synthesized, converted and replayed speech},
  author={Wang, Xin and Yamagishi, Junichi and Todisco, Massimiliano and Delgado, H{\'e}ctor and Nautsch, Andreas and Evans, Nicholas and Sahidullah, Md and Vestman, Ville and Kinnunen, Tomi and Lee, Kong Aik and others},
  journal={Computer Speech \& Language},
  volume={64},
  pages={101114},
  year={2020},
  publisher={Elsevier}
}
```
[2] AASIST: Audio Anti-Spoofing using Integrated Spectro-Temporal Graph Attention Networks
```bibtex
@inproceedings{Jung2022AASIST,
  author={Jung, Jee-weon and Heo, Hee-Soo and Tak, Hemlata and Shim, Hye-jin and Chung, Joon Son and Lee, Bong-Jin and Yu, Ha-Jin and Evans, Nicholas},
  booktitle={Proc. ICASSP},
  title={AASIST: Audio Anti-Spoofing using Integrated Spectro-Temporal Graph Attention Networks},
  year={2022}
}
```
[3] ECAPA-TDNN: Emphasized Channel Attention, propagation and aggregation in TDNN based speaker verification
```bibtex
@inproceedings{desplanques2020ecapa,
  title={{ECAPA-TDNN: Emphasized Channel Attention, propagation and aggregation in TDNN based speaker verification}},
  author={Desplanques, Brecht and Thienpondt, Jenthe and Demuynck, Kris},
  booktitle={Proc. Interspeech 2020},
  pages={3830--3834},
  year={2020}
}
```
--------------------------------------------------------------------------------
/configs/baseline2.conf:
--------------------------------------------------------------------------------
{
    "batch_size": 24,
    "dataloader": "backend_fusion",
    "dirs": {
        "spk_meta": "spk_meta/",
        "embedding": "embeddings/",
        "sasv_dev_trial": "protocols/ASVspoof2019.LA.asv.dev.gi.trl.txt",
        "sasv_eval_trial": "protocols/ASVspoof2019.LA.asv.eval.gi.trl.txt",
        "cm_trn_list": "protocols/ASVspoof2019.LA.cm.train.trn.txt",
        "cm_dev_list": "protocols/ASVspoof2019.LA.cm.dev.trl.txt",
        "cm_eval_list": "protocols/ASVspoof2019.LA.cm.eval.trl.txt"
    },
    "epoch": 10,
    "fast_dev_run": false,
    "loader": {
        "n_workers": 6
    },
    "loss": "cce",
    "loss_weight": [0.1, 0.9],
    "model_arch": "baseline2",
    "model_config": {
        "code_dim": 544,
        "dnn_l_nodes": [256, 128, 64]
    },
    "ngpus": 1,
    "optimizer": "adam",
    "optim": {
        "lr": 0.0001,
        "scheduler": "keras",
        "wd": 0.001
    },
    "progbar_refresh": 10,
    "pl_system": "baseline2",
    "save_top_k": 3,
    "seed": 1234,
    "val_interval_epoch": 1
}
--------------------------------------------------------------------------------
/dataloaders/backend_fusion.py:
--------------------------------------------------------------------------------
import random
from typing import Dict, List

from torch.utils.data import Dataset


class SASV_Trainset(Dataset):
    def __init__(self, cm_embd, asv_embd, spk_meta):
        self.cm_embd = cm_embd
        self.asv_embd = asv_embd
        self.spk_meta = spk_meta

    def __len__(self):
        return len(self.cm_embd)

    def __getitem__(self, index):

        ans_type = random.randint(0, 1)
        if ans_type == 1:  # target
            spk = random.choice(list(self.spk_meta.keys()))
            enr, tst = random.sample(self.spk_meta[spk]["bonafide"], 2)

        elif ans_type == 0:  # nontarget
            nontarget_type = random.randint(1, 2)

            if nontarget_type == 1:  # zero-effort nontarget
                spk, ze_spk = random.sample(list(self.spk_meta.keys()), 2)
                enr = random.choice(self.spk_meta[spk]["bonafide"])
                tst = random.choice(self.spk_meta[ze_spk]["bonafide"])

            elif nontarget_type == 2:  # spoof nontarget
                # resample until a speaker with at least one spoofed utterance is drawn
                spk = random.choice(list(self.spk_meta.keys()))
                while len(self.spk_meta[spk]["spoof"]) == 0:
                    spk = random.choice(list(self.spk_meta.keys()))
                enr = random.choice(self.spk_meta[spk]["bonafide"])
                tst = random.choice(self.spk_meta[spk]["spoof"])

        return self.asv_embd[enr], self.asv_embd[tst], self.cm_embd[tst], ans_type


class SASV_DevEvalset(Dataset):
    def __init__(self, utt_list, spk_model, asv_embd, cm_embd):
        self.utt_list = utt_list
        self.spk_model = spk_model
        self.asv_embd = asv_embd
        self.cm_embd = cm_embd

    def __len__(self):
        return len(self.utt_list)

    def __getitem__(self, index):
        line = self.utt_list[index]
        spkmd, key, _, ans = line.strip().split(" ")

        return self.spk_model[spkmd], self.asv_embd[key], self.cm_embd[key], ans


def get_trnset(
    cm_embd_trn: Dict, asv_embd_trn: Dict, spk_meta_trn: Dict
) -> SASV_Trainset:
    return SASV_Trainset(
        cm_embd=cm_embd_trn, asv_embd=asv_embd_trn, spk_meta=spk_meta_trn
    )


def get_dev_evalset(
    utt_list: List, cm_embd: Dict, asv_embd: Dict, spk_model: Dict
) -> SASV_DevEvalset:
    return SASV_DevEvalset(
        utt_list=utt_list, cm_embd=cm_embd, asv_embd=asv_embd, spk_model=spk_model
    )
--------------------------------------------------------------------------------
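`SASV_Trainset` samples trials on the fly: each `__getitem__` draws a target trial (label 1) or, with equal probability, one of two nontarget types (zero-effort or spoofed, label 0), ignoring the index. A minimal sketch of driving the sampler with toy dictionaries; every name and value here is a placeholder:
```python
import numpy as np
from dataloaders.backend_fusion import SASV_Trainset

utts = ["u1", "u2", "u3", "u4", "u5"]
asv = {u: np.zeros(192, dtype=np.float32) for u in utts}  # toy ASV embeddings
cm = {u: np.zeros(160, dtype=np.float32) for u in utts}   # toy CM embeddings
spk_meta = {
    "spk_a": {"bonafide": ["u1", "u2"], "spoof": ["u3"]},
    "spk_b": {"bonafide": ["u4", "u5"], "spoof": []},
}

ds = SASV_Trainset(cm_embd=cm, asv_embd=asv, spk_meta=spk_meta)
enr, tst, cm_tst, label = ds[0]  # the index is unused; sampling is random
print(enr.shape, tst.shape, cm_tst.shape, label)
```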
/embeddings/asv_embd_dev.pk:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:b0037eca30f868a1ca76802af5f83b18152f97075883c1f7f9e23e7dc7cf0945
size 20275627
--------------------------------------------------------------------------------
/embeddings/asv_embd_eval.pk:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:166f1fc2206fb772cad646fe46ae10035688f0e29a627905a364ce3dbc1cd2ab
size 58137566
--------------------------------------------------------------------------------
/embeddings/asv_embd_trn.pk:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:fd0d83299e64dcf9238e56d6d07951ead149544f3d73e5650e5e25f72bc4e924
size 20713068
--------------------------------------------------------------------------------
/embeddings/cm_embd_dev.pk:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:c2ab34390033d03209cfdc503c05729f6b8db2373fad3c7cafe5b47281c59558
size 17095163
--------------------------------------------------------------------------------
/embeddings/cm_embd_eval.pk:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:09404cb93b89a981eb38db78eafb2881416a7448ef1dd55bbe98859b8fb00ed8
size 49017997
--------------------------------------------------------------------------------
/embeddings/cm_embd_trn.pk:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:8bce9eb413b4bec340778b5c56da8c5f68d68eabdce2da22b789229014d9d0d8
size 17463987
--------------------------------------------------------------------------------
/embeddings/spk_model_dev.pk:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:c14e4a4cf713d5662d08f618a835dad0b628b2a27bb94838d2488e69c063bec0
size 8231
--------------------------------------------------------------------------------
/embeddings/spk_model_eval.pk:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:6d84fed3550f1bdd1ed28c4a566b9c37d0677e4ac124a2f66ddda189a3f7096b
size 39049
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os
import warnings
from importlib import import_module
from pathlib import Path
from shutil import copy

import pytorch_lightning as pl
from omegaconf import OmegaConf

from utils import *

warnings.filterwarnings("ignore", category=FutureWarning)


def main(args):
    # load configuration and set seed
    config = OmegaConf.load(args.config)
    output_dir = Path(args.output_dir)
    pl.seed_everything(config.seed, workers=True)

    # generate speaker-utterance meta information
    if not (
        os.path.exists(config.dirs.spk_meta + "spk_meta_trn.pk")
        and os.path.exists(config.dirs.spk_meta + "spk_meta_dev.pk")
        and os.path.exists(config.dirs.spk_meta + "spk_meta_eval.pk")
    ):
        generate_spk_meta(config)

    # configure paths
    model_tag = os.path.splitext(os.path.basename(args.config))[0]
    model_tag = output_dir / model_tag
    model_save_path = model_tag / "weights"
    model_save_path.mkdir(parents=True, exist_ok=True)
    copy(args.config, model_tag / "config.conf")

    # instantiate the Lightning system specified in the configuration
    _system = import_module("systems.{}".format(config.pl_system))
    _system = getattr(_system, "System")
    system = _system(config)

    # configure logging and callbacks
    logger = [
        pl.loggers.TensorBoardLogger(save_dir=model_tag, version=1, name="tsbd_logs"),
        pl.loggers.csv_logs.CSVLogger(
            save_dir=model_tag,
            version=1,
            name="csv_logs",
            flush_logs_every_n_steps=config.progbar_refresh * 100,
        ),
    ]

    callbacks = [
        pl.callbacks.ModelSummary(max_depth=3),
        pl.callbacks.LearningRateMonitor(logging_interval="step"),
        pl.callbacks.ModelCheckpoint(
            dirpath=model_save_path,
            filename="{epoch}-{sasv_eer_dev:.5f}",
            monitor="sasv_eer_dev",
            mode="min",
            every_n_epochs=config.val_interval_epoch,
            save_top_k=config.save_top_k,
        ),
    ]

    # train / evaluate
    gpus = find_gpus(config.ngpus, min_req_mem=config.min_req_mem)
    if gpus == -1:
        raise ValueError("Required GPUs are not available")
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus
    trainer = pl.Trainer(
        accelerator="gpu",
        callbacks=callbacks,
        check_val_every_n_epoch=1,
        devices=config.ngpus,
        fast_dev_run=config.fast_dev_run,
        gradient_clip_val=config.gradient_clip
        if config.gradient_clip is not None
        else 0,
        limit_train_batches=1.0,
        limit_val_batches=1.0,
        logger=logger,
        max_epochs=config.epoch,
        num_sanity_val_steps=0,
        progress_bar_refresh_rate=config.progbar_refresh,  # 0 to disable
        reload_dataloaders_every_n_epochs=config.loader.reload_every_n_epoch
        if config.loader.reload_every_n_epoch is not None
        else config.epoch,
        strategy="ddp",
        sync_batchnorm=True,
        val_check_interval=1.0,  # 0.25 validates 4 times every epoch
    )

    trainer.fit(system)
    trainer.test(ckpt_path="best")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SASVC2022 Baseline framework.")
    parser.add_argument(
        "--config",
        dest="config",
        type=str,
        help="configuration file",
        required=True,
        default="configs/baseline2.conf",
    )
    parser.add_argument(
        "--output_dir",
        dest="output_dir",
        type=str,
        help="output directory for results",
        default="./exp_result",
    )
    main(parser.parse_args())
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
from typing import List, Tuple, Union

import numpy
import torch
from scipy.interpolate import interp1d
from scipy.optimize import brentq
from sklearn.metrics import roc_curve


def get_all_EERs(
    preds: Union[torch.Tensor, List, numpy.ndarray], keys: List
) -> Tuple[float, float, float]:
    """
    Calculate all three EERs used in the SASV Challenge 2022:
    SASV-EER, SV-EER, and SPF-EER.
    preds and keys should be ordered following the dev or eval protocol in
    either 'protocols/ASVspoof2019.LA.asv.dev.gi.trl.txt' or
    'protocols/ASVspoof2019.LA.asv.eval.gi.trl.txt'.

    :param preds: scores as a tensor, list, or numpy array
    :param keys: list of keys where each element should be one of
        ['target', 'nontarget', 'spoof']
    """
    sasv_labels, sv_labels, spf_labels = [], [], []
    sv_preds, spf_preds = [], []

    for pred, key in zip(preds, keys):
        if key == "target":
            sasv_labels.append(1)
            sv_labels.append(1)
            spf_labels.append(1)
            sv_preds.append(pred)
            spf_preds.append(pred)

        elif key == "nontarget":
            sasv_labels.append(0)
            sv_labels.append(0)
            sv_preds.append(pred)

        elif key == "spoof":
            sasv_labels.append(0)
            spf_labels.append(0)
            spf_preds.append(pred)
        else:
            raise ValueError(
                f"key should be one of 'target', 'nontarget', 'spoof', got: {key}"
            )

    # each EER is the operating point where the FPR equals 1 - TPR (i.e., the FNR)
    fpr, tpr, _ = roc_curve(sasv_labels, preds, pos_label=1)
    sasv_eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)

    fpr, tpr, _ = roc_curve(sv_labels, sv_preds, pos_label=1)
    sv_eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)

    fpr, tpr, _ = roc_curve(spf_labels, spf_preds, pos_label=1)
    spf_eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)

    return sasv_eer, sv_eer, spf_eer
--------------------------------------------------------------------------------
/models/baseline2.py:
--------------------------------------------------------------------------------
import torch


class Model(torch.nn.Module):
    def __init__(self, model_config):

        super().__init__()
        self.enh_DNN = self._make_layers(
            model_config["code_dim"], model_config["dnn_l_nodes"]
        )
        self.fc_out = torch.nn.Linear(model_config["dnn_l_nodes"][-1], 2, bias=False)

    def forward(self, embd_asv_enr, embd_asv_tst, embd_cm):

        asv_enr = torch.squeeze(embd_asv_enr, 1)  # shape: (bs, 192)
        asv_tst = torch.squeeze(embd_asv_tst, 1)  # shape: (bs, 192)
        cm_tst = torch.squeeze(embd_cm, 1)  # shape: (bs, 160)

        # shape: (bs, dnn_l_nodes[-1]), i.e. (bs, 64) with the baseline2 config
        x = self.enh_DNN(torch.cat([asv_enr, asv_tst, cm_tst], dim=1))
        x = self.fc_out(x)  # (bs, 2)

        return x

    def _make_layers(self, in_dim, l_nodes):
        l_fc = []
        for idx in range(len(l_nodes)):
            if idx == 0:
                l_fc.append(
                    torch.nn.Linear(in_features=in_dim, out_features=l_nodes[idx])
                )
            else:
                l_fc.append(
                    torch.nn.Linear(
                        in_features=l_nodes[idx - 1], out_features=l_nodes[idx]
                    )
                )
            l_fc.append(torch.nn.LeakyReLU(negative_slope=0.3))
        return torch.nn.Sequential(*l_fc)
--------------------------------------------------------------------------------
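The concatenated input dimension must match `code_dim` in the config: 192 + 192 + 160 = 544. A quick shape check with random tensors, assuming the baseline2 configuration:
```python
import torch
from models.baseline2 import Model

model = Model({"code_dim": 544, "dnn_l_nodes": [256, 128, 64]})
enr = torch.randn(8, 192)  # toy enrolment ASV embeddings
tst = torch.randn(8, 192)  # toy test ASV embeddings
cm = torch.randn(8, 160)   # toy test CM embeddings
out = model(enr, tst, cm)
print(out.shape)  # torch.Size([8, 2])
```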
/pdfs/2022_SASV_evaluation_plan_v0.1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sasv-challenge/SASVC2022_Baseline/1545f2bbc5860504cd24003e864fa06adcfdd26e/pdfs/2022_SASV_evaluation_plan_v0.1.pdf
--------------------------------------------------------------------------------
/pdfs/2022_SASV_evaluation_plan_v0.2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sasv-challenge/SASVC2022_Baseline/1545f2bbc5860504cd24003e864fa06adcfdd26e/pdfs/2022_SASV_evaluation_plan_v0.2.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
omegaconf==2.0.6
pytorch_lightning
torch
torchvision
torchcontrib
numpy
scipy
scikit-learn
pathlib
soundfile
--------------------------------------------------------------------------------
/save_embeddings.py:
--------------------------------------------------------------------------------
import argparse
import json
import os
import pickle as pk
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from aasist.data_utils import Dataset_ASVspoof2019_devNeval
from aasist.models.AASIST import Model as AASISTModel
from ECAPATDNN.model import ECAPA_TDNN
from utils import load_parameters

# list of dataset partitions
SET_PARTITION = ["trn", "dev", "eval"]

# countermeasure (CM) protocol of each partition
SET_CM_PROTOCOL = {
    "trn": "protocols/ASVspoof2019.LA.cm.train.trn.txt",
    "dev": "protocols/ASVspoof2019.LA.cm.dev.trl.txt",
    "eval": "protocols/ASVspoof2019.LA.cm.eval.trl.txt",
}

# directory of each dataset partition
SET_DIR = {
    "trn": "./LA/ASVspoof2019_LA_train/",
    "dev": "./LA/ASVspoof2019_LA_dev/",
    "eval": "./LA/ASVspoof2019_LA_eval/",
}

# enrolment data lists for speaker model calculation;
# each speaker model comprises multiple enrolment utterances
SET_TRN = {
    "dev": [
        "./LA/ASVspoof2019_LA_asv_protocols/ASVspoof2019.LA.asv.dev.female.trn.txt",
        "./LA/ASVspoof2019_LA_asv_protocols/ASVspoof2019.LA.asv.dev.male.trn.txt",
    ],
    "eval": [
        "./LA/ASVspoof2019_LA_asv_protocols/ASVspoof2019.LA.asv.eval.female.trn.txt",
        "./LA/ASVspoof2019_LA_asv_protocols/ASVspoof2019.LA.asv.eval.male.trn.txt",
    ],
}


def save_embeddings(set_name, cm_embd_ext, asv_embd_ext, device):
    meta_lines = open(SET_CM_PROTOCOL[set_name], "r").readlines()
    utt2spk = {}
    utt_list = []
    for line in meta_lines:
        tmp = line.strip().split(" ")

        spk = tmp[0]
        utt = tmp[1]

        if utt in utt2spk:
            print("Duplicated utt error", utt)

        utt2spk[utt] = spk
        utt_list.append(utt)

    base_dir = SET_DIR[set_name]
    dataset = Dataset_ASVspoof2019_devNeval(utt_list, Path(base_dir))
    loader = DataLoader(
        dataset, batch_size=30, shuffle=False, drop_last=False, pin_memory=True
    )

    cm_emb_dic = {}
    asv_emb_dic = {}

    print("Getting embeddings from set %s..." % set_name)

    for batch_x, key in tqdm(loader):
        batch_x = batch_x.to(device)
        with torch.no_grad():
            batch_cm_emb, _ = cm_embd_ext(batch_x)
            batch_cm_emb = batch_cm_emb.detach().cpu().numpy()
            batch_asv_emb = asv_embd_ext(batch_x, aug=False).detach().cpu().numpy()

        for k, cm_emb, asv_emb in zip(key, batch_cm_emb, batch_asv_emb):
            cm_emb_dic[k] = cm_emb
            asv_emb_dic[k] = asv_emb

    os.makedirs("embeddings", exist_ok=True)
    with open("embeddings/cm_embd_%s.pk" % set_name, "wb") as f:
        pk.dump(cm_emb_dic, f)
    with open("embeddings/asv_embd_%s.pk" % set_name, "wb") as f:
        pk.dump(asv_emb_dic, f)


def save_models(set_name, asv_embd_ext, device):
    utt2spk = {}
    utt_list = []

    for trn in SET_TRN[set_name]:
        meta_lines = open(trn, "r").readlines()

        for line in meta_lines:
            tmp = line.strip().split(" ")

            spk = tmp[0]
            utts = tmp[1].split(",")

            for utt in utts:
                if utt in utt2spk:
                    print("Duplicated utt error", utt)

                utt2spk[utt] = spk
                utt_list.append(utt)

    base_dir = SET_DIR[set_name]
    dataset = Dataset_ASVspoof2019_devNeval(utt_list, Path(base_dir))
    loader = DataLoader(
        dataset, batch_size=30, shuffle=False, drop_last=False, pin_memory=True
    )
    asv_emb_dic = {}

    print("Getting embeddings from set %s..." % set_name)

    for batch_x, key in tqdm(loader):
        batch_x = batch_x.to(device)
        with torch.no_grad():
            batch_asv_emb = asv_embd_ext(batch_x, aug=False).detach().cpu().numpy()

        for k, asv_emb in zip(key, batch_asv_emb):
            utt = k
            spk = utt2spk[utt]

            if spk not in asv_emb_dic:
                asv_emb_dic[spk] = []

            asv_emb_dic[spk].append(asv_emb)

    # average each speaker's enrolment embeddings into a single speaker model
    for spk in asv_emb_dic:
        asv_emb_dic[spk] = np.mean(asv_emb_dic[spk], axis=0)

    with open("embeddings/spk_model_%s.pk" % set_name, "wb") as f:
        pk.dump(asv_emb_dic, f)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-aasist_config", type=str, default="./aasist/config/AASIST.conf"
    )
    parser.add_argument(
        "-aasist_weight", type=str, default="./aasist/models/weights/AASIST.pth"
    )
    parser.add_argument(
        "-ecapa_weight", type=str, default="./ECAPATDNN/exps/pretrain.model"
    )

    return parser.parse_args()


def main():
    args = get_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Device: {}".format(device))

    with open(args.aasist_config, "r") as f_json:
        config = json.loads(f_json.read())

    model_config = config["model_config"]
    cm_embd_ext = AASISTModel(model_config).to(device)
    load_parameters(cm_embd_ext.state_dict(), args.aasist_weight)
    cm_embd_ext.to(device)
    cm_embd_ext.eval()

    asv_embd_ext = ECAPA_TDNN(C=1024)
    load_parameters(asv_embd_ext.state_dict(), args.ecapa_weight)
    asv_embd_ext.to(device)
    asv_embd_ext.eval()

    for set_name in SET_PARTITION:
        save_embeddings(
            set_name,
            cm_embd_ext,
            asv_embd_ext,
            device,
        )
        if set_name == "trn":
            # speaker models are only needed for the dev and eval partitions
            continue
        save_models(set_name, asv_embd_ext, device)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/schedulers.py:
--------------------------------------------------------------------------------
#! /usr/bin/python
# -*- encoding: utf-8 -*-
# pylint: disable=too-many-instance-attributes
"""
Original source:
https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup/blob/master/cosine_annealing_warmup/scheduler.py
"""

import math

import torch
from torch.optim.lr_scheduler import _LRScheduler


class CosineAnnealingWarmupRestarts(_LRScheduler):
    """
    optimizer (Optimizer): Wrapped optimizer.
    first_cycle_steps (int): First cycle step size.
    cycle_mult (float): Cycle steps magnification. Default: 1.0.
    max_lr (float): First cycle's max learning rate. Default: 0.1.
    min_lr (float): Min learning rate. Default: 0.001.
    warmup_steps (int): Linear warmup step size. Default: 0.
    gamma (float): Decrease rate of max learning rate by cycle. Default: 1.0.
    last_epoch (int): The index of the last epoch. Default: -1.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        first_cycle_steps: int,
        cycle_mult: float = 1.0,
        max_lr: float = 0.1,
        min_lr: float = 0.001,
        warmup_steps: int = 0,
        gamma: float = 1.0,
        last_epoch: int = -1,
    ):
        assert warmup_steps < first_cycle_steps

        self.first_cycle_steps = first_cycle_steps  # first cycle step size
        self.cycle_mult = cycle_mult  # cycle steps magnification
        self.base_max_lr = max_lr  # first max learning rate
        self.max_lr = max_lr  # max learning rate in the current cycle
        self.min_lr = min_lr  # min learning rate
        self.warmup_steps = warmup_steps  # warmup step size
        self.gamma = gamma  # decrease rate of max learning rate by cycle

        self.cur_cycle_steps = first_cycle_steps  # first cycle step size
        self.cycle = 0  # cycle count
        self.step_in_cycle = last_epoch  # step size of the current cycle

        super().__init__(optimizer, last_epoch)

        # set learning rate to min_lr
        self.init_lr()

    def init_lr(self):
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self):
        if self.step_in_cycle == -1:
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            return [
                (self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps
                + base_lr
                for base_lr in self.base_lrs
            ]
        else:
            return [
                base_lr
                + (self.max_lr - base_lr)
                * (
                    1
                    + math.cos(
                        math.pi
                        * (self.step_in_cycle - self.warmup_steps)
                        / (self.cur_cycle_steps - self.warmup_steps)
                    )
                )
                / 2
                for base_lr in self.base_lrs
            ]

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = (
                    int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult)
                    + self.warmup_steps
                )
        else:
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.0:
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    n = int(
                        math.log(
                            (
                                epoch / self.first_cycle_steps * (self.cycle_mult - 1)
                                + 1
                            ),
                            self.cycle_mult,
                        )
                    )
                    self.cycle = n
                    self.step_in_cycle = epoch - int(
                        self.first_cycle_steps
                        * (self.cycle_mult ** n - 1)
                        / (self.cycle_mult - 1)
                    )
                    self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** n
            else:
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch

        self.max_lr = self.base_max_lr * (self.gamma ** self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group["lr"] = lr
--------------------------------------------------------------------------------
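For reference, a minimal sketch of driving this scheduler outside of Lightning; the hyperparameter values below are placeholders:
```python
import torch
from schedulers import CosineAnnealingWarmupRestarts

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = CosineAnnealingWarmupRestarts(
    optimizer,
    first_cycle_steps=100,  # steps per cosine cycle
    cycle_mult=1.0,
    max_lr=0.001,
    min_lr=1e-6,
    warmup_steps=10,        # linear warmup at the start of each cycle
    gamma=0.5,              # halve max_lr after each cycle
)

for step in range(300):
    optimizer.step()  # a real loop would compute a loss and backprop first
    scheduler.step()  # stepped per batch, as in systems/baseline2.py

print(optimizer.param_groups[0]["lr"])
```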
/spk_meta/spk_meta_dev.pk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sasv-challenge/SASVC2022_Baseline/1545f2bbc5860504cd24003e864fa06adcfdd26e/spk_meta/spk_meta_dev.pk
--------------------------------------------------------------------------------
/spk_meta/spk_meta_eval.pk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sasv-challenge/SASVC2022_Baseline/1545f2bbc5860504cd24003e864fa06adcfdd26e/spk_meta/spk_meta_eval.pk
--------------------------------------------------------------------------------
/spk_meta/spk_meta_trn.pk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sasv-challenge/SASVC2022_Baseline/1545f2bbc5860504cd24003e864fa06adcfdd26e/spk_meta/spk_meta_trn.pk
--------------------------------------------------------------------------------
/systems/baseline2.py:
--------------------------------------------------------------------------------
import pickle as pk
from importlib import import_module
from typing import Any

import omegaconf
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import schedulers as lr_schedulers
from metrics import get_all_EERs
from utils import keras_decay


class System(pl.LightningModule):
    def __init__(
        self, config: omegaconf.dictconfig.DictConfig, *args: Any, **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)
        self.config = config
        _model = import_module("models.{}".format(config.model_arch))
        _model = getattr(_model, "Model")
        self.model = _model(config.model_config)
        self.configure_loss()
        self.save_hyperparameters()

    def forward(self, embd_asv_enr, embd_asv_tst, embd_cm) -> torch.Tensor:
        return self.model(embd_asv_enr, embd_asv_tst, embd_cm)

    def training_step(self, batch, batch_idx, dataloader_idx=-1):
        embd_asv_enrol, embd_asv_test, embd_cm_test, label = batch
        pred = self.model(embd_asv_enrol, embd_asv_test, embd_cm_test)
        loss = self.loss(pred, label)
        self.log(
            "trn_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=-1):
        embd_asv_enrol, embd_asv_test, embd_cm_test, key = batch
        pred = self.model(embd_asv_enrol, embd_asv_test, embd_cm_test)
        pred = torch.softmax(pred, dim=-1)

        return {"pred": pred, "key": key}

    def validation_epoch_end(self, outputs):
        log_dict = {}
        preds, keys = [], []
        for output in outputs:
            preds.append(output["pred"])
            keys.extend(list(output["key"]))

        # use the softmax probability of the 'target' class as the SASV score
        preds = torch.cat(preds, dim=0)[:, 1].detach().cpu().numpy()
        sasv_eer, sv_eer, spf_eer = get_all_EERs(preds=preds, keys=keys)

        log_dict["sasv_eer_dev"] = sasv_eer
        log_dict["sv_eer_dev"] = sv_eer
        log_dict["spf_eer_dev"] = spf_eer

        self.log_dict(log_dict)

    def test_step(self, batch, batch_idx, dataloader_idx=-1):
        return self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx)

    def test_epoch_end(self, outputs):
        log_dict = {}
        preds, keys = [], []
        for output in outputs:
            preds.append(output["pred"])
            keys.extend(list(output["key"]))

        preds = torch.cat(preds, dim=0)[:, 1].detach().cpu().numpy()
        sasv_eer, sv_eer, spf_eer = get_all_EERs(preds=preds, keys=keys)

        log_dict["sasv_eer_eval"] = sasv_eer
        log_dict["sv_eer_eval"] = sv_eer
        log_dict["spf_eer_eval"] = spf_eer

        self.log_dict(log_dict)

    def configure_optimizers(self):
        if self.config.optimizer.lower() == "adam":
            optimizer = torch.optim.Adam(
                params=self.parameters(),
                lr=self.config.optim.lr,
                weight_decay=self.config.optim.wd,
            )
        elif self.config.optimizer.lower() == "sgd":
            optimizer = torch.optim.SGD(
                params=self.parameters(),
                lr=self.config.optim.lr,
                momentum=self.config.optim.momentum,
                weight_decay=self.config.optim.wd,
            )
        else:
            raise NotImplementedError(
                "optimizer '{}' is not supported".format(self.config.optimizer)
            )

        if self.config.optim.scheduler.lower() == "sgdr_cos_anl":
            assert (
                self.config.optim.n_epoch_per_cycle is not None
                and self.config.optim.min_lr is not None
                and self.config.optim.warmup_steps is not None
                and self.config.optim.lr_mult_after_cycle is not None
            )
            lr_scheduler = lr_schedulers.CosineAnnealingWarmupRestarts(
                optimizer,
                first_cycle_steps=len(self.train_dataloader())
                // self.config.ngpus
                * self.config.optim.n_epoch_per_cycle,
                cycle_mult=1.0,
                max_lr=self.config.optim.lr,
                min_lr=self.config.optim.min_lr,
                warmup_steps=self.config.optim.warmup_steps,
                gamma=self.config.optim.lr_mult_after_cycle,
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": lr_scheduler,
                    "interval": "step",
                    "frequency": 1,
                },
            }

        elif self.config.optim.scheduler.lower() == "reduce_on_plateau":
            assert (
                self.config.optim.lr is not None
                and self.config.optim.min_lr is not None
                and self.config.optim.factor is not None
                and self.config.optim.patience is not None
            )
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=self.config.optim.factor,
                patience=self.config.optim.patience,
                min_lr=self.config.optim.min_lr,
                verbose=True,
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": lr_scheduler,
                    "interval": "epoch",
                    "frequency": 1,
                    "strict": True,
                    "monitor": "sasv_eer_dev",
                },
            }
        elif self.config.optim.scheduler.lower() == "keras":
            lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer, lr_lambda=lambda step: keras_decay(step)
            )
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": lr_scheduler,
                    "interval": "step",
                    "frequency": 1,
                    "strict": True,
                },
            }

        else:
            raise NotImplementedError(
                "scheduler '{}' is not supported".format(self.config.optim.scheduler)
            )

    def setup(self, stage=None):
        """
        Configures dataloaders.

        Args:
            stage: one among ["fit", "validate", "test", "predict"]
        """
        self.load_meta_information()
        self.load_embeddings()

        if stage == "fit" or stage is None:
            module = import_module("dataloaders." + self.config.dataloader)
            self.ds_func_trn = getattr(module, "get_trnset")
            self.ds_func_dev = getattr(module, "get_dev_evalset")
        elif stage == "validate":
            module = import_module("dataloaders." + self.config.dataloader)
            self.ds_func_dev = getattr(module, "get_dev_evalset")
        elif stage == "test":
            module = import_module("dataloaders." + self.config.dataloader)
            self.ds_func_eval = getattr(module, "get_dev_evalset")
        else:
            raise NotImplementedError("stage '{}' is not supported".format(stage))

    def train_dataloader(self):
        self.train_ds = self.ds_func_trn(
            self.cm_embd_trn, self.asv_embd_trn, self.spk_meta_trn
        )
        return DataLoader(
            self.train_ds,
            batch_size=self.config.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.config.loader.n_workers,
        )

    def val_dataloader(self):
        with open(self.config.dirs.sasv_dev_trial, "r") as f:
            sasv_dev_trial = f.readlines()
        self.dev_ds = self.ds_func_dev(
            sasv_dev_trial, self.cm_embd_dev, self.asv_embd_dev, self.spk_model_dev
        )
        return DataLoader(
            self.dev_ds,
            batch_size=self.config.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.config.loader.n_workers,
        )

    def test_dataloader(self):
        with open(self.config.dirs.sasv_eval_trial, "r") as f:
            sasv_eval_trial = f.readlines()
        self.eval_ds = self.ds_func_eval(
            sasv_eval_trial, self.cm_embd_eval, self.asv_embd_eval, self.spk_model_eval
        )
        return DataLoader(
            self.eval_ds,
            batch_size=self.config.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.config.loader.n_workers,
        )

    def configure_loss(self):
        if self.config.loss.lower() == "bce":
            self.loss = F.binary_cross_entropy_with_logits
        elif self.config.loss.lower() == "cce":
            self.loss = torch.nn.CrossEntropyLoss(
                weight=torch.FloatTensor(self.config.loss_weight)
            )
        else:
            raise NotImplementedError(
                "loss '{}' is not supported".format(self.config.loss)
            )

    def load_meta_information(self):
        with open(self.config.dirs.spk_meta + "spk_meta_trn.pk", "rb") as f:
            self.spk_meta_trn = pk.load(f)
        with open(self.config.dirs.spk_meta + "spk_meta_dev.pk", "rb") as f:
            self.spk_meta_dev = pk.load(f)
        with open(self.config.dirs.spk_meta + "spk_meta_eval.pk", "rb") as f:
            self.spk_meta_eval = pk.load(f)

    def load_embeddings(self):
        # load saved countermeasure (CM) related preparations
        with open(self.config.dirs.embedding + "cm_embd_trn.pk", "rb") as f:
            self.cm_embd_trn = pk.load(f)
        with open(self.config.dirs.embedding + "cm_embd_dev.pk", "rb") as f:
            self.cm_embd_dev = pk.load(f)
        with open(self.config.dirs.embedding + "cm_embd_eval.pk", "rb") as f:
            self.cm_embd_eval = pk.load(f)

        # load saved automatic speaker verification (ASV) related preparations
        with open(self.config.dirs.embedding + "asv_embd_trn.pk", "rb") as f:
            self.asv_embd_trn = pk.load(f)
        with open(self.config.dirs.embedding + "asv_embd_dev.pk", "rb") as f:
            self.asv_embd_dev = pk.load(f)
        with open(self.config.dirs.embedding + "asv_embd_eval.pk", "rb") as f:
            self.asv_embd_eval = pk.load(f)

        # load speaker models for development and evaluation sets
        with open(self.config.dirs.embedding + "spk_model_dev.pk", "rb") as f:
            self.spk_model_dev = pk.load(f)
        with open(self.config.dirs.embedding + "spk_model_eval.pk", "rb") as f:
            self.spk_model_eval = pk.load(f)
--------------------------------------------------------------------------------
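With the baseline2 config, `configure_loss` builds a weighted cross-entropy: `loss_weight` `[0.1, 0.9]` puts a weight of 0.1 on the nontarget class (0) and 0.9 on the target class (1). A minimal standalone sketch of the same loss with toy logits:
```python
import torch

loss_fn = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 0.9]))

logits = torch.tensor([[2.0, -1.0], [0.5, 1.5]])  # (batch, 2) model outputs
labels = torch.tensor([0, 1])  # 0 = nontarget, 1 = target
print(loss_fn(logits, labels))
```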
/utils.py:
--------------------------------------------------------------------------------
import os
import pickle as pk
import random
from typing import Dict

import numpy as np
import torch
import torch.nn as nn


def str_to_bool(val):
    """Convert a string representation of truth to True or False.
    Adapted from the Python implementation distutils.util.strtobool

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
    are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
    'val' is anything else.
    >>> str_to_bool('YES')
    True
    >>> str_to_bool('FALSE')
    False
    """
    val = val.lower()
    if val in ("y", "yes", "t", "true", "on", "1"):
        return True
    if val in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value {}".format(val))


def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Cosine annealing learning rate decay schedule"""
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))


def keras_decay(step, decay=0.0001):
    """Keras-style learning rate decay: multiply the base lr by 1 / (1 + decay * step)"""
    return 1.0 / (1.0 + decay * step)


def set_seed(args):
    """
    Set the initial seed for reproducibility
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.deterministic = args.cudnn_deterministic_toggle
        torch.backends.cudnn.benchmark = args.cudnn_benchmark_toggle


def set_init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        try:
            m.bias.data.fill_(0.0001)
        except:
            pass
    elif isinstance(m, nn.BatchNorm1d):
        pass
    else:
        try:
            torch.nn.init.kaiming_normal_(m.weight, a=0.01)
        except:
            pass


def load_parameters(trg_state, path):
    loaded_state = torch.load(path, map_location=lambda storage, loc: storage)
    for name, param in loaded_state.items():
        origname = name
        if name not in trg_state:
            name = name.replace("module.", "")
            name = name.replace("speaker_encoder.", "")
            if name not in trg_state:
                print("%s is not in the model." % origname)
                continue
        if trg_state[name].size() != loaded_state[origname].size():
            print(
                "Wrong parameter length: %s, model: %s, loaded: %s"
                % (origname, trg_state[name].size(), loaded_state[origname].size())
            )
            continue
        trg_state[name].copy_(param)


def find_gpus(nums=4, min_req_mem=None) -> str:
    """
    Allocates 'nums' GPUs that have the most free memory.
    Original source:
    https://discuss.pytorch.org/t/it-there-anyway-to-let-program-select-free-gpu-automatically/17560/10

    :param nums: number of GPUs to find
    :param min_req_mem: minimum required free GPU memory (in MB)
    :return: string of GPU indices separated with commas,
        or -1 if 'min_req_mem' is given and cannot be satisfied
    """
    os.system("nvidia-smi -q -d Memory | grep -A4 GPU | grep Free > tmp_free_gpus")
    with open("tmp_free_gpus", "r", encoding="utf-8") as lines_txt:
        frees = lines_txt.readlines()
        idx_freememory_pair = [(idx, int(x.split()[2])) for idx, x in enumerate(frees)]
    idx_freememory_pair.sort(key=lambda pair: pair[1], reverse=True)
    using_gpus = [
        str(idx_memory_pair[0]) for idx_memory_pair in idx_freememory_pair[:nums]
    ]

    # return error signal if a minimum required memory is given and the
    # allocated GPU with the least free memory does not satisfy it
    if min_req_mem is not None and int(idx_freememory_pair[nums - 1][1]) < min_req_mem:
        return -1

    using_gpus = ",".join(using_gpus)
    print("using GPU idx: #", using_gpus)
    return using_gpus


def get_spkdic(cm_meta: str) -> Dict:
    l_cm_meta = open(cm_meta, "r").readlines()

    d_spk = {}
    # dictionary of speakers
    # d_spk : {
    #     'spk_id1': {
    #         'bonafide': [utt1, utt2],
    #         'spoof': [utt5]
    #     },
    #     'spk_id2': {
    #         'bonafide': [utt3, utt4, utt8],
    #         'spoof': [utt6, utt7]
    #     } ...
    # }

    for line in l_cm_meta:
        spk, filename, _, _, ans = line.strip().split(" ")
        if spk not in d_spk:
            d_spk[spk] = {}
            d_spk[spk]["bonafide"] = []
            d_spk[spk]["spoof"] = []

        if ans == "bonafide":
            d_spk[spk]["bonafide"].append(filename)
        elif ans == "spoof":
            d_spk[spk]["spoof"].append(filename)

    return d_spk


def generate_spk_meta(config) -> None:
    d_spk_train = get_spkdic(config.dirs.cm_trn_list)
    d_spk_dev = get_spkdic(config.dirs.cm_dev_list)
    d_spk_eval = get_spkdic(config.dirs.cm_eval_list)
    os.makedirs(config.dirs.spk_meta, exist_ok=True)

    # save speaker dictionaries
    with open(config.dirs.spk_meta + "spk_meta_trn.pk", "wb") as f:
        pk.dump(d_spk_train, f)
    with open(config.dirs.spk_meta + "spk_meta_dev.pk", "wb") as f:
        pk.dump(d_spk_dev, f)
    with open(config.dirs.spk_meta + "spk_meta_eval.pk", "wb") as f:
        pk.dump(d_spk_eval, f)
--------------------------------------------------------------------------------
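`keras_decay` is used with `LambdaLR` in `systems/baseline2.py`: at step `t`, the base learning rate is multiplied by `1 / (1 + decay * t)`. A quick check of the multiplier with the default `decay=0.0001`:
```python
from utils import keras_decay

for step in (0, 1000, 10000, 100000):
    print(step, round(keras_decay(step), 4))
# 0 1.0
# 1000 0.9091
# 10000 0.5
# 100000 0.0909
```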