├── data
│   └── toysets
│       ├── test.txt
│       ├── score.txt
│       ├── 0_seg.wav
│       ├── 1_seg.wav
│       ├── 2_seg.wav
│       ├── 3_seg.wav
│       ├── sample.txt
│       └── trail.txt
├── .gitignore
├── requirement.txt
├── infer.sh
├── toyset_infer.sh
├── z_cul_eer.sh
├── run.sh
├── toyset_run.sh
├── utils
│   ├── tools
│   │   ├── tools.py
│   │   ├── evaluation.py
│   │   └── cul_eer.py
│   ├── wrapper
│   │   ├── optim_wrapper.py
│   │   ├── schedule_wrapper.py
│   │   └── loss_wrapper.py
│   ├── ideas
│   │   ├── MoEF
│   │   │   ├── w2v2_moe_fz24_aasist.py
│   │   │   ├── aasist.py
│   │   │   └── moef.py
│   │   └── reweight_learner.py
│   ├── loadData
│   │   ├── RawBoost.py
│   │   ├── toyset_dm.py
│   │   ├── asvspoof_data_DA_still_process.py
│   │   └── asvspoof_data_DA.py
│   └── arg_parse.py
├── LICENSE
├── models
│   ├── wav2vec
│   │   ├── l5_aasist_step_stable.py
│   │   └── aasist.py
│   ├── tl_model.py
│   ├── LEMAAS
│   │   └── lemaas_v6_1.py
│   └── tl_model_postft_loss.py
├── README.md
└── main.py

/data/toysets/test.txt:
--------------------------------------------------------------------------------
1 | 0_seg
2 | 1_seg
3 | 2_seg
4 | 3_seg
5 | 
--------------------------------------------------------------------------------
/data/toysets/score.txt:
--------------------------------------------------------------------------------
1 | A01 spoof -11.60762
2 | - bonafide 4.84932
3 | A01 spoof -5.262832
4 | - bonafide 4.84932
--------------------------------------------------------------------------------
/data/toysets/0_seg.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/john852517791/pytorch_lightning_FAD/HEAD/data/toysets/0_seg.wav
--------------------------------------------------------------------------------
/data/toysets/1_seg.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/john852517791/pytorch_lightning_FAD/HEAD/data/toysets/1_seg.wav
--------------------------------------------------------------------------------
/data/toysets/2_seg.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/john852517791/pytorch_lightning_FAD/HEAD/data/toysets/2_seg.wav
--------------------------------------------------------------------------------
/data/toysets/3_seg.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/john852517791/pytorch_lightning_FAD/HEAD/data/toysets/3_seg.wav
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | a_train_log/*
2 | b_gpu_log/*
3 | __pycache__
4 | */__pycache__
5 | */*/__pycache__
6 | */*/*/__pycache__
7 | ref
8 | 
--------------------------------------------------------------------------------
/data/toysets/sample.txt:
--------------------------------------------------------------------------------
1 | LA_0077 0_seg - A01 spoof
2 | LA_0069 1_seg - - bonafide
3 | LA_0077 2_seg - A01 spoof
4 | LA_0069 3_seg - - bonafide
5 | 
--------------------------------------------------------------------------------
/data/toysets/trail.txt:
--------------------------------------------------------------------------------
1 | - 0_seg - - - spoof - eval
2 | - 1_seg - - - bonafide - eval
3 | - 2_seg - - - spoof - eval
4 | - 3_seg - - - bonafide - eval
5 | 
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | tensorboardX
2 | 
tensorboard==2.12.0 3 | soundfile 4 | pillow 5 | pytorch_model_summary 6 | timm 7 | matplotlib 8 | torchvision 9 | torchaudio 10 | scipy 11 | pandas 12 | lightning 13 | pyyaml 14 | transformers 15 | -------------------------------------------------------------------------------- /infer.sh: -------------------------------------------------------------------------------- 1 | # --module-model models.pure.stable_aasist 2 | ckpt=$1 3 | gpu=$2 4 | 5 | 6 | # --module-model models.rawformer.stable_base 7 | line=" 8 | nohup python main.py --inference 9 | --trained_model ${ckpt} 10 | --batch_size 100 11 | --gpuid ${gpu} 12 | > ${ckpt}/z_infer.log 13 | &" 14 | # --truncate 96000 15 | 16 | # --colour 2 17 | 18 | 19 | 20 | echo ${line} 21 | eval ${line} -------------------------------------------------------------------------------- /toyset_infer.sh: -------------------------------------------------------------------------------- 1 | # --module-model models.pure.stable_aasist 2 | ckpt="a_train_log/aasist/version_0" 3 | gpu=0 4 | 5 | 6 | # --module-model models.rawformer.stable_base 7 | # nohup 8 | line=" 9 | python main.py --inference 10 | --trained_model ${ckpt} 11 | --batch_size 1 12 | --gpuid ${gpu} 13 | > ${ckpt}/z_infer.log 14 | " 15 | # & 16 | # --truncate 96000 17 | 18 | # --colour 2 19 | 20 | 21 | 22 | echo ${line} 23 | eval ${line} -------------------------------------------------------------------------------- /z_cul_eer.sh: -------------------------------------------------------------------------------- 1 | scoreFilepath=$1 2 | eval_name=$2 3 | 4 | line="python utils/tools/cul_eer.py --pos 2 --scoreFile $scoreFilepath/infer_19.log > $scoreFilepath/eer_19" 5 | # line="python utils/tools/cul_eer19.py --scoreFile $scoreFilepath/infer/infer_19.log " 6 | echo $line 7 | eval $line 8 | line="python utils/tools/cul_eer21.py --scoreFile $scoreFilepath/infer_LA21.log > $scoreFilepath/eer_21" 9 | echo $line 10 | eval $line 11 | line="python utils/tools/cul_eer21df.py --scoreFile $scoreFilepath/infer_DF21.log > $scoreFilepath/eer_df21" 12 | echo $line 13 | eval $line 14 | line="python utils/tools/cul_itw.py --scoreFile $scoreFilepath/infer_ITW.log > $scoreFilepath/eer_itw" 15 | echo $line 16 | eval $line -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # model=$1 2 | # tlmodel=$2 3 | # dmmodule=$3 4 | train_log_dir=$1 5 | lr=$2 6 | gpu=$3 7 | # loss=$5 8 | 9 | 10 | model=$(echo "$model" | sed 's/\//./g') 11 | model=$(echo "$model" | sed 's/\.py//g') 12 | 13 | 14 | if [ -z "$loss" ] || [ "$loss" = "-" ]; then 15 | loss="CE" 16 | fi 17 | 18 | # CUDA_VISIBLE_DEVICES=${gpu} 19 | line=" 20 | nohup python main.py 21 | --seed 1234 22 | --module_model models.aasist.AASIST 23 | --tl_model models.tl_model 24 | --data_module utils.loadData.asvspoof_data_DA 25 | --savedir ${train_log_dir} 26 | --optim_lr ${lr} 27 | --gpuid ${gpu} 28 | --batch_size 24 29 | --epochs 100 30 | --no_best_epochs 100 31 | --optim adam 32 | --weight_decay 0.0001 33 | --loss WCE 34 | --scheduler cosAnneal 35 | --truncate 64600 36 | > b_gpu_log/test_${gpu}.log 37 | &" 38 | # --usingDA 39 | # --da_prob 0.7 40 | echo ${line} 41 | eval ${line} -------------------------------------------------------------------------------- /toyset_run.sh: -------------------------------------------------------------------------------- 1 | # model=$1 2 | # tlmodel=$2 3 | # dmmodule=$3 4 | train_log_dir=a_train_log/aasist 5 | lr=0.01 6 | 
gpu=0 7 | # loss=$5 8 | 9 | 10 | model=$(echo "$model" | sed 's/\//./g') 11 | model=$(echo "$model" | sed 's/\.py//g') 12 | 13 | 14 | if [ -z "$loss" ] || [ "$loss" = "-" ]; then 15 | loss="CE" 16 | fi 17 | 18 | # CUDA_VISIBLE_DEVICES=${gpu} 19 | # nohup 20 | line=" 21 | python main.py 22 | --seed 1234 23 | --module_model models.aasist.AASIST 24 | --tl_model models.tl_model 25 | --data_module utils.loadData.toyset_dm 26 | --savedir ${train_log_dir} 27 | --optim_lr ${lr} 28 | --gpuid ${gpu} 29 | --batch_size 2 30 | --epochs 3 31 | --no_best_epochs 2 32 | --optim adam 33 | --weight_decay 0.0001 34 | --loss WCE 35 | --scheduler cosAnneal 36 | --truncate 64600 37 | > b_gpu_log/test_${gpu}.log 38 | " 39 | # & 40 | # --usingDA 41 | # --da_prob 0.7 42 | echo ${line} 43 | eval ${line} -------------------------------------------------------------------------------- /utils/tools/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def norm(X_pad): 4 | mean_x = X_pad.mean() 5 | var_x = X_pad.var() 6 | return np.array([(x - mean_x) / np.sqrt(var_x + 1e-7) for x in X_pad]) 7 | 8 | 9 | def pad(x, max_len=64600): 10 | x_len = x.shape[0] 11 | if x_len >= max_len: 12 | return x[:max_len] 13 | # need to pad 14 | num_repeats = int(max_len / x_len) + 1 15 | padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0] 16 | return padded_x 17 | 18 | 19 | def pad_random(x: np.ndarray, max_len: int = 64600): 20 | x_len = x.shape[0] 21 | # if duration is already long enough 22 | if x_len > max_len: 23 | stt = np.random.randint(x_len - max_len) 24 | return x[stt:stt + max_len] 25 | 26 | # if too short 27 | num_repeats = int(max_len / x_len) + 1 28 | padded_x = np.tile(x, (num_repeats))[:max_len] 29 | return padded_x -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Zhiyong Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /utils/wrapper/optim_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class opti_conf(): 5 | def __init__(self): 6 | self.optim = "adam" 7 | 8 | self.optim_lr = 0.0001 9 | self.weight_decay = 0.5 10 | # for sgd 11 | self.momentum = 0.9 12 | 13 | class optimizer_wrap(): 14 | def __init__(self,cfg:opti_conf,model): 15 | super(optimizer_wrap).__init__() 16 | self.cfg = cfg 17 | self.model = model 18 | 19 | 20 | def get_optim(self): 21 | optim = None 22 | 23 | if self.cfg.optim == "adam": 24 | optim = torch.optim.Adam( 25 | self.model.parameters(), 26 | lr=self.cfg.optim_lr, 27 | weight_decay=self.cfg.weight_decay, 28 | ) 29 | elif self.cfg.optim == "adamw": 30 | optim = torch.optim.AdamW( 31 | self.model.parameters(), 32 | lr=self.cfg.optim_lr, 33 | weight_decay=self.cfg.weight_decay 34 | ) 35 | elif self.cfg.optim == "sgd": 36 | optim = torch.optim.SGD( 37 | self.model.parameters(), 38 | lr=self.cfg.optim_lr, 39 | momentum = self.cfg.momentum 40 | ) 41 | 42 | else: 43 | raise Exception(f"no optim named {self.cfg.optim}") 44 | 45 | 46 | return optim 47 | 48 | -------------------------------------------------------------------------------- /models/wav2vec/l5_aasist_step_stable.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("./") 3 | from transformers import Wav2Vec2Model,AutoConfig 4 | import torch.nn as nn 5 | import torch 6 | from models.wav2vec.aasist import W2VAASIST 7 | 8 | 9 | 10 | 11 | class this_arg(): 12 | stage = 1 13 | # if stage == 1, means that the pretrained model is freezed 14 | # if stage == 2, means that training will finetune the pretrained model 15 | 16 | class Model(nn.Module): 17 | def __init__(self, args = this_arg()): 18 | super().__init__() 19 | pretrain_cfg = AutoConfig.from_pretrained("datasets/pretrained_model/facebook/wav2vec2-xls-r-300m/config.json") 20 | pretrain_cfg.num_hidden_layers = 6 21 | self.pretrain_model = Wav2Vec2Model.from_pretrained( 22 | "datasets/pretrained_model/facebook/wav2vec2-xls-r-300m", 23 | config=pretrain_cfg) 24 | self.classifier = W2VAASIST() 25 | # self.freeze_parameters() 26 | # self.register_buffer('pre_features', torch.zeros(args.batch_size,160)) 27 | self.register_buffer('pre_features', torch.zeros(32,160)) 28 | # self.register_buffer('pre_weight1', torch.ones(args.batch_size, 1)) 29 | self.register_buffer('pre_weight1', torch.ones(32, 1)) 30 | 31 | def freeze_parameters(self): 32 | print("freeze") 33 | for param in self.pretrain_model.parameters(): 34 | # print(param.requires_grad) 35 | param.requires_grad = False 36 | 37 | def unfreeze_parameters(self): 38 | print("unfreeze") 39 | for param in self.pretrain_model.parameters(): 40 | # print(param.requires_grad) 41 | param.requires_grad = True 42 | 43 | def forward(self, x): 44 | with torch.no_grad(): 45 | x = self.pretrain_model( 46 | x, 47 | output_hidden_states = True 48 | ).hidden_states[5] 49 | pred , hidden_state = self.classifier(x) 50 | 51 | return pred, hidden_state 52 | 53 | 54 | 55 | 56 | if __name__ == "__main__": 57 | md = Model() 58 | # md.freeze_parameters() 59 | # md.unfreeze_parameters() 60 | op, hd = md( torch.randn((8,64600))) 61 | print(op.shape) -------------------------------------------------------------------------------- /utils/ideas/MoEF/w2v2_moe_fz24_aasist.py: 
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("./")
3 | from transformers import Wav2Vec2Model,AutoConfig
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch
7 | from utils.ideas.MoEF.aasist import W2VAASIST
8 | from utils.ideas.MoLE import MoELocal, MoE24fusion
9 | 
10 | 
11 | class this_arg():
12 |     stage = 1
13 |     moe_topk = 2
14 |     # experts per feature
15 |     moe_experts = 4
16 |     moe_exp_hid = 128
17 |     # if stage == 1, the pretrained model is frozen
18 |     # if stage == 2, training also fine-tunes the pretrained model
19 | 
20 | class Model(nn.Module):
21 |     def __init__(self, args = this_arg()):
22 |         super().__init__()
23 |         self.pretrain_model = Wav2Vec2Model.from_pretrained(
24 |             "datasets/pretrained_model/facebook/wav2vec2-xls-r-300m")
25 |         self.classifier = W2VAASIST()
26 |         self.moe_l = MoE24fusion(
27 |             ds_inputsize=1024,
28 |             input_size=1024,
29 |             output_size=1024,
30 |             num_experts=24*args.moe_experts,
31 |             hidden_size=args.moe_exp_hid,
32 |             noisy_gating=True,
33 |             k = args.moe_topk,
34 |             trainingmode=True
35 |         )
36 |         # for param in self.pretrain_model.parameters():
37 |         #     # print(param.requires_grad)
38 |         #     param.requires_grad = False
39 | 
40 |     def forward(self, x,train = False):
41 |         with torch.no_grad():
42 |         # if True:
43 |             x = self.pretrain_model(
44 |                 x,
45 |                 output_hidden_states = True,
46 |                 output_attentions = True
47 |             )
48 |         bs,t,sp = x[0].shape
49 |         hidden_ones = []
50 |         for i in range(24):
51 |             hidden_ones.append(x.hidden_states[i].view(bs*t, sp))
52 | 
53 |         fusion_x = self.moe_l(
54 |             x.last_hidden_state.view(bs*t, sp),
55 |             hidden_ones,
56 |             training = train
57 |         )
58 | 
59 |         pred , hidden_state = self.classifier(fusion_x[0].view(bs,t, sp))
60 |         return pred , (x.hidden_states,x.attentions), fusion_x[1]
61 | 
62 | 
63 | if __name__ == "__main__":
64 |     md = Model()
65 |     # md.freeze_parameters()
66 |     # md.unfreeze_parameters()
67 |     # op, hd,_ = md( torch.randn((2,64600)))
68 |     # print(op.shape)
69 |     print(sum(p.numel() for p in md.pretrain_model.parameters() if p.requires_grad)/1000000)
70 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch_lightning_FAD
2 | 
3 | This is a general framework for fake audio detection (FAD) built on PyTorch Lightning.
4 | The dataset used here is ASVspoof 2019.
5 | 
6 | # env
7 | 
8 | python 3.9
9 | 
10 | ```
11 | pip install -r requirement.txt
12 | ```
13 | 
14 | # run sample
15 | 
16 | ## toyset
17 | By default the first GPU is used; edit the scripts if you need a different one.
18 | 
19 | ```
20 | bash toyset_run.sh
21 | 
22 | bash toyset_infer.sh
23 | ```
24 | 
25 | 
26 | ## whole datasets
27 | First things first: change the dataset directories in "utils/loadData/asvspoof_data_DA.py".
28 | 
29 | Then run:
30 | 
31 | ```
32 | bash run.sh a_train_log/aasist 0.01 6
33 | ```
34 | 
35 | # usage
36 | 
37 | ## 1. data module
38 | 
39 | If you want to use another data input format, please use 'utils/loadData/asvspoof_data_DA.py' as a reference when writing your datamodule.
40 | 
41 | If you won't change anything in "models/tl_model.py", please **make sure that the train set returns three elements (tensor, label, filename), and that the dev/test set returns two elements (tensor, filename)**. A minimal sketch of such a datamodule is shown below.
42 | 
43 | Then change the **"--data_module"** config when you run "run.sh".
44 | 
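45 | For illustration only, a minimal datamodule satisfying this contract might look like the following sketch. The toy dataset classes are hypothetical stand-ins for the real loaders in 'utils/loadData/'; the class name `asvspoof_dataModule` is the one "main.py" imports.
46 | 
47 | ```
48 | import torch
49 | import lightning as L
50 | from torch.utils.data import Dataset, DataLoader
51 | 
52 | class _ToyTrainSet(Dataset):
53 |     def __len__(self): return 4
54 |     def __getitem__(self, i):
55 |         # train set: (tensor, label, filename)
56 |         return torch.zeros(64600), 1, f"{i}_seg"
57 | 
58 | class _ToyEvalSet(Dataset):
59 |     def __len__(self): return 4
60 |     def __getitem__(self, i):
61 |         # dev/test set: (tensor, filename)
62 |         return torch.zeros(64600), f"{i}_seg"
63 | 
64 | class asvspoof_dataModule(L.LightningDataModule):
65 |     def __init__(self, args):
66 |         super().__init__()
67 |         self.args = args
68 |     def train_dataloader(self):
69 |         return DataLoader(_ToyTrainSet(), batch_size=self.args.batch_size, shuffle=True)
70 |     def val_dataloader(self):
71 |         return DataLoader(_ToyEvalSet(), batch_size=self.args.batch_size)
72 |     def test_dataloader(self):
73 |         return DataLoader(_ToyEvalSet(), batch_size=self.args.batch_size)
74 | ```
75 | 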
76 | ## 2. model
77 | 
78 | If you want to use another model architecture, add it to the folder "models".
79 | 
80 | If you won't change anything in "models/tl_model.py", please **make sure the model you create returns at least two elements** (prediction and hidden state) and name the model class "Model". A minimal skeleton is sketched in the appendix at the end of this README.
81 | 
82 | Then change the **"--module_model"** config when you run "run.sh".
83 | 
84 | ## 3. tl_model
85 | 
86 | If you want to modify anything in the train/eval/test/inference stage (such as how the loss is calculated), create a new file using "models/tl_model.py" as a reference.
87 | 
88 | Then change the **"--tl_model"** config when you run "run.sh".
89 | 
90 | # Generalized Fake Audio Detection via Deep Stable Learning [arxiv](https://arxiv.org/pdf/2406.03237)
91 | The reweight learner is in utils/ideas/.
92 | Check its usage in the tl_model file (models/tl_model_postft_loss.py) and the model file (models/wav2vec/l5_aasist_step_stable.py),
93 | and follow the usage of this framework described above.
94 | 
95 | [This method offers only slight performance improvements for small models, and only after tuning many, many hyperparameters; honestly, if you want to try it, do not expect it to work on the very first run with random hyperparameters unless you are lucky.]
96 | 
97 | # Mixture of Experts Fusion for Fake Audio Detection Using Frozen wav2vec 2.0
98 | The model file is utils/ideas/MoEF/w2v2_moe_fz24_aasist.py,
99 | and the MoEF module is in utils/ideas/MoEF/moef.py.
100 | 
101 | ✨ Check out the branch icassp
102 | 
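103 | ## appendix: minimal model skeleton
104 | 
105 | An untested sketch with toy layer sizes, not a real FAD model; it only illustrates the (prediction, hidden state) contract that "models/tl_model.py" expects:
106 | 
107 | ```
108 | import torch
109 | import torch.nn as nn
110 | 
111 | class Model(nn.Module):
112 |     def __init__(self, args=None):
113 |         super().__init__()
114 |         self.encoder = nn.Linear(64600, 160)  # toy encoder, not a real front end
115 |         self.head = nn.Linear(160, 2)         # two classes: bonafide / spoof
116 | 
117 |     def forward(self, x):
118 |         hidden_state = self.encoder(x)
119 |         pred = self.head(hidden_state)
120 |         # tl_model.py reads output[0] as the logits; the hidden state is kept
121 |         # for ideas like the stable-learning reweighting
122 |         return pred, hidden_state
123 | ```
124 | 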
--------------------------------------------------------------------------------
/utils/wrapper/schedule_wrapper.py:
--------------------------------------------------------------------------------
1 | from torch import optim, nn, utils, Tensor
2 | import torch,transformers
3 | import numpy as np
4 | 
5 | 
6 | class schdule_conf():
7 |     def __init__(self):
8 |         self.scheduler = "cosWarmup"
9 |         self.epochs = 100
10 |         # for cosWarmup
11 |         self.num_warmup_steps = 5
12 |         # self.num_training_steps = self.epochs - self.num_warmup_steps
13 |         # for cosanneal
14 |         self.total_step = 1057 # (25380//24) * 100
15 |         # for step
16 |         self.step_size = 5
17 |         self.gamma = 0.1
18 | 
19 |         self.optim_lr = 1
20 | 
21 | 
22 | class scheduler_wrap():
23 |     """ Wrapper over different types of learning rate Scheduler
24 | 
25 |     """
26 |     def __init__(self, optimizer, args:schdule_conf):
27 |         self.optimizer = optimizer
28 |         self.args = args
29 | 
30 |     def get_scheduler(self):
31 | 
32 |         # other config or none
33 |         scheduler = None
34 | 
35 |         if self.args.scheduler == "cosWarmup":
36 |             scheduler = transformers.get_cosine_schedule_with_warmup(
37 |                 optimizer = self.optimizer,
38 |                 num_warmup_steps=self.args.num_warmup_steps,
39 |                 num_training_steps = self.args.epochs
40 |             )
41 |         elif self.args.scheduler == "cosAnneal":
42 |             scheduler = torch.optim.lr_scheduler.LambdaLR(
43 |                 self.optimizer,
44 |                 lr_lambda=lambda step: cosine_annealing(
45 |                     step,
46 |                     105700,
47 |                     1,  # since lr_lambda computes multiplicative factor
48 |                     0.000005 / 0.0001))
49 |         elif self.args.scheduler == "normal_cosAnneal":
50 |             scheduler = optim.lr_scheduler.CosineAnnealingLR(
51 |                 self.optimizer,
52 |                 T_max=self.args.epochs,
53 |                 # eta_min=0.000005 / 0.0001
54 |                 eta_min=0.05*self.args.optim_lr
55 |             )
56 |         elif self.args.scheduler == "step":
57 |             scheduler = optim.lr_scheduler.StepLR(
58 |                 self.optimizer,
59 |                 step_size= self.args.step_size,
60 |                 gamma=self.args.gamma
61 |             )
62 |         else:
63 |             print("no scheduler is used")
64 | 
65 | 
66 |         return scheduler
67 | 
68 | 
69 | def cosine_annealing(step, total_steps, lr_max, lr_min):
70 |     """Cosine Annealing for learning rate decay scheduler"""
71 |     return lr_min + (lr_max -
72 |                      lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))
73 | 
74 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import lightning as L
2 | from utils.arg_parse import f_args_parsed,set_random_seed
3 | import importlib
4 | import os,yaml,shutil
5 | from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint,LearningRateMonitor
6 | from lightning.pytorch import loggers as pl_loggers
7 | from datetime import datetime
8 | # arguments initialization
9 | args = f_args_parsed()
10 | 
11 | ### temporal config
12 | #
13 | # args.stage = 1
14 | #
15 | # ###
16 | 
17 | 
18 | # config gpu
19 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpuid
20 | 
21 | # random seed initialization and gpu seed
22 | set_random_seed(args.seed, args)
23 | 
24 | # config the base model containing the train, eval, test and inference functions
25 | tl_model = importlib.import_module(args.tl_model)
26 | 
27 | # config the data module containing the train set, dev set and test set
28 | dm_module = importlib.import_module(args.data_module)
29 | asvspoof_dm = dm_module.asvspoof_dataModule(args=args)
30 | 
31 | if True:
32 |     # ⭐train
33 |     if not args.inference:
34 |         # import model.py
35 |         prj_model = importlib.import_module(args.module_model)
36 | 
37 |         # model
38 |         model = prj_model.Model(args)
39 | 
40 |         # init model, including loss func and optim
41 |         customed_model_wrapper = tl_model.base_model(
42 |             model=model,
43 |             args=args
44 |         )
45 | 
46 |         # config logdir
47 |         tb_logger = pl_loggers.TensorBoardLogger(args.savedir,name="")
48 | 
49 |         # model initialization
50 |         trainer = L.Trainer(
51 |             max_epochs=args.epochs,
52 |             strategy='ddp_find_unused_parameters_true',
53 |             log_every_n_steps = 1,
54 |             callbacks=[
55 |                 # early stopping when dev_eer stops improving
56 |                 EarlyStopping('dev_eer',patience=args.no_best_epochs,mode="min",verbose=True,log_rank_zero_only=True),
57 |                 # save the checkpoint with the lowest dev_eer
58 |                 ModelCheckpoint(monitor='dev_eer',
59 |                                 save_top_k=1,
60 |                                 save_weights_only=True,mode="min",filename='best_model-{epoch:02d}-{dev_eer:.4f}'),
61 |                 LearningRateMonitor(logging_interval='epoch',log_weight_decay=True),
62 |             ],
63 |             check_val_every_n_epoch=1,
64 |             logger=tb_logger,
65 |             enable_progress_bar=False
66 |         )
67 |         trainer.fit(
68 |             model=customed_model_wrapper,
69 |             datamodule=asvspoof_dm
70 |         )
71 | 
72 |         # # test 19 default
73 |         trainer.test(
74 |             model=customed_model_wrapper,
75 |             datamodule=asvspoof_dm
76 |         )
77 |     else:
78 |         checkpointpath=args.trained_model
79 |         # checkpointpath=trainer.log_dir
80 |         args.savedir = checkpointpath
81 | 
82 |         # load model
83 |         ymlconf = os.path.join(checkpointpath,"hparams.yaml")
84 |         with open(ymlconf,"r") as f_yaml:
85 |             parser1 = yaml.safe_load(f_yaml)
86 |         infer_m = importlib.import_module(parser1["module_model"])
87 |         test_dm_module = importlib.import_module(parser1["data_module"])
88 |         test_asvspoof_dm = test_dm_module.asvspoof_dataModule(args=args)
89 | 
90 |         infer_model = infer_m.Model(args)
91 | 
92 |         print(parser1)
93 | 
94 |         # print(args.savedir)
95 |         ckpt_files = [file for file in os.listdir(checkpointpath+"/checkpoints/") if file.endswith(".ckpt")]
96 |         # customed_model=model_wrapper.base_model(model=model)
97 |         customed_model=tl_model.base_model.load_from_checkpoint(
98 |             checkpoint_path=os.path.join(f"{checkpointpath}/checkpoints/",ckpt_files[0]),
99 |             model=infer_model,
100 |             args = args,
101 |             strict=False)
102 |         inferer = L.Trainer(logger=pl_loggers.TensorBoardLogger(args.savedir,name=""))
103 | 
104 |         # la19
105 |         inferer.test(
106 |             model=customed_model,
107 |             datamodule=test_asvspoof_dm
108 |         )
109 |         # la21
110 |         inferer.predict(
111 |             model=customed_model,
112 |             datamodule=test_asvspoof_dm
113 |         )
114 |         # df21
115 |         inferer.model.args.testset = "DF21"
116 |         test_asvspoof_dm = test_dm_module.asvspoof_dataModule(args=args)
117 |         inferer.predict(
118 |             model=customed_model,
119 |             datamodule=test_asvspoof_dm
120 |         )
121 | 
122 |         # ITW
123 |         inferer.model.args.testset = "ITW"
124 |         test_asvspoof_dm = test_dm_module.asvspoof_dataModule(args=args)
125 |         inferer.predict(
126 |             model=customed_model,
127 |             datamodule=test_asvspoof_dm
128 |         )
129 | 
130 | 
131 |         # change the version_0 to infer, and delete useless files
132 |         current_time = datetime.now()
133 |         time_str = current_time.strftime("%Y_%m_%d_%H_%M_%S")
134 |         inferfolder = os.path.join(checkpointpath,f"infer_{time_str}")
135 |         if not os.path.exists(inferfolder):
136 |             os.makedirs(inferfolder)
137 |         folder_a = os.path.join(checkpointpath,"version_0")
138 |         for filename in os.listdir(folder_a):
139 |             if filename.endswith('.log'):
140 |                 original_path = os.path.join(folder_a, filename)
141 |                 destination_path = os.path.join(inferfolder, filename)
142 |                 shutil.move(original_path, destination_path)
143 |         shutil.rmtree(folder_a)
144 | 
145 | # print(args)
--------------------------------------------------------------------------------
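The StableNet reweighting module below is consumed inside a Lightning training step. A condensed, illustrative excerpt of the call pattern used in models/tl_model_postft_loss.py (not a standalone file; `self` is the Lightning wrapper defined there, and the loss is assumed to use reduction "none"):

```
def training_step(self, batch, batch_idx):
    output = self.forward(batch[0])                        # (logits, hidden_state)
    batch_loss = self.loss_criterion(output[0], batch[1])  # per-sample loss (reduction "none")

    loss_weight, pre_features, pre_weight1 = weight_learner(
        output[1],                   # features of the current batch
        self.model.pre_features,     # running feature buffer (a registered buffer)
        self.model.pre_weight1,      # running sample-weight buffer
        args=self.stable_conf,
        global_epoch=self.current_epoch,
        iter=batch_idx)
    self.model.pre_features.data.copy_(pre_features)
    self.model.pre_weight1.data.copy_(pre_weight1)

    # weighted scalar loss: (1, B) @ (B, 1) -> (1,)
    return batch_loss.view(1, -1).mm(loss_weight).view(1)
```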
/utils/ideas/reweight_learner.py:
--------------------------------------------------------------------------------
1 | # code from https://github.com/xxgege/StableNet
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | 
6 | # coding:utf-8
7 | import numpy as np
8 | import torch.nn.functional as F
9 | import math
10 | 
11 | 
12 | 
13 | 
14 | 
15 | 
16 | def sd(x):
17 |     return np.std(x, axis=0, ddof=1)
18 | 
19 | 
20 | def sd_gpu(x):
21 |     return torch.std(x, dim=0)
22 | 
23 | 
24 | def normalize_gpu(x):
25 |     x = F.normalize(x, p=1, dim=1)
26 |     return x
27 | 
28 | 
29 | def normalize(x):
30 |     mean = np.mean(x, axis=0)
31 |     std = sd(x)
32 |     std[std == 0] = 1
33 |     x = (x - mean) / std
34 |     return x
35 | 
36 | 
37 | def random_fourier_features_gpu(x, w=None, b=None, num_f=None, sum=True, sigma=None, seed=None):
38 |     if num_f is None:
39 |         num_f = 1
40 |     n = x.size(0)
41 |     r = x.size(1)
42 |     x = x.view(n, r, 1)
43 |     c = x.size(2)
44 |     if sigma is None or sigma == 0:
45 |         sigma = 1
46 |     if w is None:
47 |         w = 1 / sigma * (torch.randn(size=(num_f, c)))
48 |         b = 2 * np.pi * torch.rand(size=(r, num_f))
49 |         b = b.repeat((n, 1, 1))
50 | 
51 |     Z = torch.sqrt(torch.tensor(2.0 / num_f).cuda())
52 | 
53 |     mid = torch.matmul(x.cuda(), w.t().cuda())
54 | 
55 |     mid = mid + b.cuda()
56 |     mid -= mid.min(dim=1, keepdim=True)[0]
57 |     mid /= mid.max(dim=1, keepdim=True)[0].cuda()
58 |     mid *= np.pi / 2.0
59 | 
60 |     if sum:
61 |         Z = Z * (torch.cos(mid).cuda() + torch.sin(mid).cuda())
62 |     else:
63 |         Z = Z * torch.cat((torch.cos(mid).cuda(), torch.sin(mid).cuda()), dim=-1)
64 | 
65 |     return Z
66 | 
67 | 
68 | def lossc(inputs, target, weight):
69 |     loss = nn.NLLLoss(reduction='none')
70 |     return loss(inputs, target).view(1, -1).mm(weight).view(1)
71 | 
72 | 
73 | def cov(x, w=None):
74 |     if w is None:
75 |         n = x.shape[0]
76 |         cov = 
torch.matmul(x.t(), x) / n 77 | e = torch.mean(x, dim=0).view(-1, 1) 78 | res = cov - torch.matmul(e, e.t()) 79 | else: 80 | w = w.view(-1, 1) 81 | cov = torch.matmul((w * x).t(), x) 82 | e = torch.sum(w * x, dim=0).view(-1, 1) 83 | res = cov - torch.matmul(e, e.t()) 84 | 85 | return res 86 | 87 | 88 | def lossb_expect(cfeaturec, weight, num_f, sum=True): 89 | cfeaturecs = random_fourier_features_gpu(cfeaturec, num_f=num_f, sum=sum).cuda() 90 | loss = Variable(torch.FloatTensor([0]).cuda()) 91 | weight = weight.cuda() 92 | for i in range(cfeaturecs.size()[-1]): 93 | cfeaturec = cfeaturecs[:, :, i] 94 | 95 | cov1 = cov(cfeaturec, weight) 96 | cov_matrix = cov1 * cov1 97 | loss += torch.sum(cov_matrix) - torch.trace(cov_matrix) 98 | 99 | return loss 100 | 101 | 102 | def lr_setter(optimizer, epoch, args, bl=False): 103 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 104 | 105 | lr = args.lr 106 | if bl: 107 | lr = args.lrbl * (0.1 ** (epoch // (args.epochb * 0.5))) 108 | else: 109 | if args.cos: 110 | lr *= ((0.01 + math.cos(0.5 * (math.pi * epoch / args.epochs))) / 1.01) 111 | else: 112 | if epoch >= args.epochs_decay[0]: 113 | lr *= 0.1 114 | if epoch >= args.epochs_decay[1]: 115 | lr *= 0.1 116 | for param_group in optimizer.param_groups: 117 | param_group['lr'] = lr 118 | 119 | class args(): 120 | def __init__(self ): 121 | self.lrbl = 0.9 122 | self.epochb = 20 123 | self.num_f=20 124 | self.sum=True 125 | self.decay_pow=2 126 | self.lambda_decay_rate=1 127 | self.min_lambda_times=0.01 128 | self.lambdap = 70.0 129 | self.lambda_decay_epoch=5 130 | self.first_step_cons=0.9 131 | self.presave_ratio=0.9 # 0.9 132 | self.lr = 0.01 133 | self.epochs_decay=[24, 30] 134 | self.optim = "sgd" 135 | 136 | def weight_learner(cfeatures, pre_features, pre_weight1, args=args(), global_epoch=0, iter=0): 137 | softmax = nn.Softmax(0) 138 | weight = Variable(torch.ones(cfeatures.size()[0], 1).cuda()) 139 | weight.requires_grad = True 140 | cfeaturec = Variable(torch.FloatTensor(cfeatures.size()).cuda()) 141 | cfeaturec.data.copy_(cfeatures.data) 142 | all_feature = torch.cat([cfeaturec, pre_features.detach()], dim=0) 143 | # optimizerbl = torch.optim.SGD([weight], lr=args.lrbl, momentum=0.9) 144 | if args.optim == "adamw": 145 | optimizerbl = torch.optim.AdamW([weight],lr=args.lrbl) 146 | # print(args.optim) 147 | elif args.optim == "adam": 148 | # print(args.optim) 149 | optimizerbl = torch.optim.Adam([weight],lr=args.lrbl) 150 | elif args.optim == "sgd": 151 | # print(args.optim) 152 | optimizerbl = torch.optim.SGD([weight], lr=args.lrbl, momentum=0.9) 153 | 154 | 155 | for epoch in range(args.epochb): 156 | lr_setter(optimizerbl, epoch, args, bl=True) 157 | # 上个batch的preweight的作用是什么 158 | all_weight = torch.cat((weight, pre_weight1.detach()), dim=0) 159 | optimizerbl.zero_grad() 160 | 161 | lossb = lossb_expect(all_feature, softmax(all_weight), args.num_f, args.sum) 162 | lossp = softmax(weight).pow(args.decay_pow).sum() 163 | lambdap = args.lambdap * max((args.lambda_decay_rate ** (global_epoch // args.lambda_decay_epoch)), 164 | args.min_lambda_times) 165 | lossg = lossb / lambdap + lossp 166 | if global_epoch == 0: 167 | lossg = lossg * args.first_step_cons 168 | 169 | lossg.backward(retain_graph=True) 170 | optimizerbl.step() 171 | 172 | if global_epoch == 0 and iter < 10: 173 | pre_features = (pre_features * iter + cfeatures) / (iter + 1) 174 | pre_weight1 = (pre_weight1 * iter + weight) / (iter + 1) 175 | 176 | elif cfeatures.size()[0] < pre_features.size()[0]: 
177 | pre_features[:cfeatures.size()[0]] = pre_features[:cfeatures.size()[0]] * args.presave_ratio + cfeatures * ( 178 | 1 - args.presave_ratio) 179 | pre_weight1[:cfeatures.size()[0]] = pre_weight1[:cfeatures.size()[0]] * args.presave_ratio + weight * ( 180 | 1 - args.presave_ratio) 181 | 182 | else: 183 | pre_features = pre_features * args.presave_ratio + cfeatures * (1 - args.presave_ratio) 184 | pre_weight1 = pre_weight1 * args.presave_ratio + weight * (1 - args.presave_ratio) 185 | 186 | softmax_weight = softmax(weight) 187 | 188 | return softmax_weight, pre_features, pre_weight1 189 | 190 | if __name__ == '__main__': 191 | pass 192 | -------------------------------------------------------------------------------- /models/tl_model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import lightning as L 3 | import torch 4 | import logging,os 5 | from utils.wrapper import loss_wrapper, optim_wrapper,schedule_wrapper 6 | from utils.tools import cul_eer 7 | import os 8 | 9 | 10 | class base_model(L.LightningModule): 11 | def __init__(self, 12 | model, 13 | args, 14 | ) -> None: 15 | super().__init__() 16 | self.args = args 17 | self.model = model 18 | 19 | self.save_hyperparameters(self.args) 20 | 21 | self.model_optimizer = optim_wrapper.optimizer_wrap(self.args, self.model).get_optim() 22 | self.LRScheduler = schedule_wrapper.scheduler_wrap(self.model_optimizer,self.args).get_scheduler() 23 | # for loss 24 | self.args.model = model 25 | self.args.samloss_optim = self.model_optimizer 26 | self.loss_criterion,self.loss_optimizer,self.minimizor = loss_wrapper.loss_wrap(self.args).get_loss() 27 | 28 | 29 | self.logging_test = None 30 | self.logging_predict = None 31 | 32 | def forward(self,x): 33 | return self.model(x) 34 | 35 | def training_step(self, batch, batch_idx): 36 | 37 | # batch[0] -- tensor 38 | # batch[1] -- label 39 | # batch[2] -- filename 40 | 41 | 42 | # model output, better return 2 elements, prediction and any other thing 43 | output = self.forward(batch[0]) 44 | batch_loss = self.loss_criterion(output[0], batch[1]) 45 | 46 | batch_loss = batch_loss.mean() 47 | self.log_dict({ 48 | "loss": batch_loss, 49 | },on_step=True, 50 | on_epoch=True,prog_bar=True, logger=True, 51 | # prevent from saving wrong ckp based on the eval_loss from different gpus 52 | sync_dist=True, 53 | ) 54 | return batch_loss 55 | 56 | def validation_step(self,batch): 57 | # batch[0] -- tensor 58 | # batch[1] -- filename 59 | 60 | # model output 61 | output = self.forward(batch[0]) 62 | 63 | softmax_pred = torch.nn.functional.softmax(output[0],dim=1) 64 | 65 | # log the prediction for cul eer 66 | with open(os.path.join(self.logger.log_dir,"dev.log"), 'a') as file: 67 | for i in range(len(softmax_pred)): 68 | file.write(f"{batch[1][i]} {str(softmax_pred.cpu().numpy()[i][1])}\n") 69 | 70 | # batch_loss = self.loss_criterion(data_predict, data_label).mean() 71 | # # Logging to TensorBoard (if installed) by default 72 | # self.log("val_loss", batch_loss, batch_size=len(data_in),sync_dist=True) 73 | 74 | def on_validation_epoch_end(self) -> None: 75 | # culculate the dev eer 76 | dev_eer = 0. 77 | dev_tdcf = 0. 
78 | with open(os.path.join(self.logger.log_dir,"dev.log"), 'r') as file: 79 | lines = file.readlines() 80 | 81 | if len(lines) > 10000: 82 | dev_eer, dev_tdcf = cul_eer.eerandtdcf( 83 | os.path.join(self.logger.log_dir,"dev.log"), 84 | "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/asvspoof2019/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt", 85 | "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/asvspoof2019/LA/ASVspoof2019_LA_asv_scores/ASVspoof2019.LA.asv.dev.gi.trl.scores.txt" 86 | ) 87 | with open(os.path.join(self.logger.log_dir,"dev.log"), 'w') as file: 88 | pass 89 | self.log_dict({ 90 | "dev_eer": (dev_eer), 91 | "dev_tdcf": dev_tdcf, 92 | },on_step=False, 93 | on_epoch=True,prog_bar=False, logger=True, 94 | # prevent from saving wrong ckp based on the eval_loss from different gpus 95 | sync_dist=True, 96 | ) 97 | 98 | def on_test_start(self): 99 | # logging.basicConfig(filename=os.path.join(self.logger.log_dir,f"infer_test.log"),level=logging.INFO,format="") 100 | self.logging_test = logging.getLogger("logging_test") 101 | self.logging_test.setLevel(logging.INFO) 102 | hdl=logging.FileHandler(os.path.join(self.logger.log_dir,f"infer_19.log")) 103 | hdl.setFormatter("") 104 | self.logging_test.addHandler(hdl) 105 | 106 | def test_step(self, batch,) -> Any: 107 | # batch[0] -- tensor 108 | # batch[1] -- filename 109 | 110 | # model output 111 | output = self.forward(batch[0]) 112 | 113 | data_predict = torch.nn.functional.softmax(output[0],dim=1) 114 | 115 | for i in range(len(batch[1])): 116 | self.logging_test.info(f"{batch[1][i]} {str(data_predict.cpu().numpy()[i][0])} {str(data_predict.cpu().numpy()[i][1])}") 117 | # return data_info[0],data_predict.cpu().numpy() 118 | return {'loss': 0, 'y_pred': data_predict} 119 | 120 | def on_predict_start(self): 121 | # logging.basicConfig(filename=os.path.join(self.args.savedir,f"infer_predict.log"),level=logging.INFO,format="") 122 | self.logging_predict = logging.getLogger(f"logging_predict_{self.args.testset}") 123 | self.logging_predict.setLevel(logging.INFO) 124 | hdlx = logging.FileHandler(os.path.join(self.logger.log_dir,f"infer_{self.args.testset}.log")) 125 | hdlx.setFormatter("") 126 | self.logging_predict.addHandler(hdlx) 127 | 128 | def predict_step(self, batch, batch_idx): 129 | # batch[0] -- tensor 130 | # batch[1] -- filename 131 | 132 | # model output 133 | output = self.forward(batch[0]) 134 | 135 | data_predict = torch.nn.functional.softmax(output[0],dim=1) 136 | 137 | # self.logging_predict.info(f"{data_info[0]} {str(data_predict.cpu().numpy()[0][1])} {str(data_predict.cpu().numpy()[0][0])}") 138 | for i in range(len(batch[1])): 139 | self.logging_predict.info(f"{batch[1][i]} {str(data_predict.cpu().numpy()[i][1])}") 140 | # return data_info[0],data_predict.cpu().numpy() 141 | return 142 | 143 | def configure_optimizers(self): 144 | configure = None 145 | if self.LRScheduler is not None: 146 | configure = { 147 | "optimizer":self.model_optimizer, 148 | 'lr_scheduler': self.LRScheduler, 149 | 'monitor': 'dev_eer' 150 | } 151 | else: 152 | configure = { 153 | "optimizer":self.model_optimizer, 154 | } 155 | 156 | return configure -------------------------------------------------------------------------------- /utils/loadData/RawBoost.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: 
utf-8 -*- 3 | 4 | import numpy as np 5 | from scipy import signal 6 | import copy 7 | 8 | ''' 9 | Hemlata Tak, Madhu Kamble, Jose Patino, Massimiliano Todisco, Nicholas Evans. 10 | RawBoost: A Raw Data Boosting and Augmentation Method applied to Automatic Speaker Verification Anti-Spoofing. 11 | In Proc. ICASSP 2022, pp:6382--6386. 12 | ''' 13 | 14 | def randRange(x1, x2, integer): 15 | y = np.random.uniform(low=x1, high=x2, size=(1,)) 16 | if integer: 17 | y = int(y) 18 | return y 19 | 20 | def normWav(x,always): 21 | if always: 22 | x = x/np.amax(abs(x)) 23 | elif np.amax(abs(x)) > 1: 24 | x = x/np.amax(abs(x)) 25 | return x 26 | 27 | 28 | def genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs): 29 | b = 1 30 | for i in range(0, nBands): 31 | fc = randRange(minF,maxF,0); 32 | bw = randRange(minBW,maxBW,0); 33 | c = randRange(minCoeff,maxCoeff,1); 34 | 35 | if c/2 == int(c/2): 36 | c = c + 1 37 | f1 = fc - bw/2 38 | f2 = fc + bw/2 39 | if f1 <= 0: 40 | f1 = 1/1000 41 | if f2 >= fs/2: 42 | f2 = fs/2-1/1000 43 | b = np.convolve(signal.firwin(c, [float(f1), float(f2)], window='hamming', fs=fs),b) 44 | 45 | G = randRange(minG,maxG,0); 46 | _, h = signal.freqz(b, 1, fs=fs) 47 | b = pow(10, G/20)*b/np.amax(abs(h)) 48 | return b 49 | 50 | 51 | def filterFIR(x,b): 52 | N = b.shape[0] + 1 53 | xpad = np.pad(x, (0, N), 'constant') 54 | y = signal.lfilter(b, 1, xpad) 55 | y = y[int(N/2):int(y.shape[0]-N/2)] 56 | return y 57 | 58 | # Linear and non-linear convolutive noise 59 | def LnL_convolutive_noise(x,N_f,nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,minBiasLinNonLin,maxBiasLinNonLin,fs): 60 | y = [0] * x.shape[0] 61 | for i in range(0, N_f): 62 | if i == 1: 63 | minG = minG-minBiasLinNonLin; 64 | maxG = maxG-maxBiasLinNonLin; 65 | b = genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs) 66 | y = y + filterFIR(np.power(x, (i+1)), b) 67 | y = y - np.mean(y) 68 | y = normWav(y,0) 69 | return y 70 | 71 | 72 | # Impulsive signal dependent noise 73 | def ISD_additive_noise(x, P, g_sd): 74 | beta = randRange(0, P, 0) 75 | 76 | y = copy.deepcopy(x) 77 | x_len = x.shape[0] 78 | n = int(x_len*(beta/100)) 79 | p = np.random.permutation(x_len)[:n] 80 | f_r= np.multiply(((2*np.random.rand(p.shape[0]))-1),((2*np.random.rand(p.shape[0]))-1)) 81 | r = g_sd * x[p] * f_r 82 | y[p] = x[p] + r 83 | y = normWav(y,0) 84 | return y 85 | 86 | 87 | # Stationary signal independent noise 88 | 89 | def SSI_additive_noise(x,SNRmin,SNRmax,nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs): 90 | noise = np.random.normal(0, 1, x.shape[0]) 91 | b = genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs) 92 | noise = filterFIR(noise, b) 93 | noise = normWav(noise,1) 94 | SNR = randRange(SNRmin, SNRmax, 0) 95 | noise = noise / np.linalg.norm(noise,2) * np.linalg.norm(x,2) / 10.0**(0.05 * SNR) 96 | x = x + noise 97 | return x 98 | 99 | 100 | 101 | 102 | 103 | #--------------RawBoost data augmentation algorithms---------------------------## 104 | 105 | def process_Rawboost_feature(feature, sr,args,algo): 106 | 107 | # Data process by Convolutive noise (1st algo) 108 | if algo==1: 109 | 110 | feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) 111 | 112 | # Data process by Impulsive noise (2nd algo) 113 | elif algo==2: 114 | 115 | feature=ISD_additive_noise(feature, args.P, args.g_sd) 116 | 117 | # Data 
process by coloured additive noise (3rd algo) 118 | elif algo==3: 119 | 120 | feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) 121 | 122 | # Data process by all 3 algo. together in series (1+2+3) 123 | elif algo==4: 124 | 125 | feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, 126 | args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) 127 | feature=ISD_additive_noise(feature, args.P, args.g_sd) 128 | feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF, 129 | args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) 130 | 131 | # Data process by 1st two algo. together in series (1+2) 132 | elif algo==5: 133 | 134 | feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, 135 | args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) 136 | feature=ISD_additive_noise(feature, args.P, args.g_sd) 137 | 138 | 139 | # Data process by 1st and 3rd algo. together in series (1+3) 140 | elif algo==6: 141 | 142 | feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, 143 | args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) 144 | feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) 145 | 146 | # Data process by 2nd and 3rd algo. together in series (2+3) 147 | elif algo==7: 148 | 149 | feature=ISD_additive_noise(feature, args.P, args.g_sd) 150 | feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) 151 | 152 | # Data process by 1st two algo. 
together in Parallel (1||2)
153 |     elif algo==8:
154 | 
155 |         feature1 =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,
156 |                  args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr)
157 |         feature2=ISD_additive_noise(feature, args.P, args.g_sd)
158 | 
159 |         feature_para=feature1+feature2
160 |         feature=normWav(feature_para,0)  #normalized resultant waveform
161 | 
162 |     # original data without Rawboost processing
163 |     else:
164 | 
165 |         feature=feature
166 | 
167 |     return feature
168 | 
169 | 
--------------------------------------------------------------------------------
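For orientation: `process_Rawboost_feature` above expects an `args` object carrying all the noise parameters. A hypothetical call is sketched below; the parameter values follow the original RawBoost release and are assumptions here, so check utils/arg_parse.py for this repo's actual defaults.

```
import numpy as np
from argparse import Namespace
# assumes process_Rawboost_feature from utils/loadData/RawBoost.py above
from utils.loadData.RawBoost import process_Rawboost_feature

rb_args = Namespace(
    N_f=5, nBands=5, minF=20, maxF=8000, minBW=100, maxBW=1000,
    minCoeff=10, maxCoeff=100, minG=0, maxG=0,
    minBiasLinNonLin=5, maxBiasLinNonLin=20,  # algo 1: LnL convolutive noise
    P=10, g_sd=2,                             # algo 2: impulsive noise
    SNRmin=10, SNRmax=40)                     # algo 3: stationary additive noise

wav = np.random.randn(64600)
augmented = process_Rawboost_feature(wav, 16000, rb_args, algo=5)  # algos 1+2 in series
```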
/models/LEMAAS/lemaas_v6_1.py:
--------------------------------------------------------------------------------
1 | """
2 | lemaas
3 | """
4 | import numpy as np
5 | import torch,math
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from timm.models.layers import DropPath
9 | 
10 | def nearest_odd_number(c):
11 |     # evaluate the expression
12 |     value = 0.5 * math.log2(c) + 0.5
13 |     # round and convert to the nearest odd number
14 |     nearest_odd = round(value)
15 |     if nearest_odd % 2 == 0:
16 |         # if the rounded value is even, move up to the next odd number
17 |         nearest_odd += 1
18 |     return nearest_odd
19 | 
20 | 
21 | class eca_layer(nn.Module):
22 |     def __init__(self, channel):
23 |         super(eca_layer, self).__init__()
24 |         self.avg_pool = nn.AdaptiveAvgPool1d(1)
25 |         self.k_size = nearest_odd_number(channel)
26 |         self.conv = nn.Conv1d(channel, channel, kernel_size=self.k_size,padding=(self.k_size-1)//2, bias=True, groups=channel)
27 |         self.sigmoid = nn.Sigmoid()
28 | 
29 | 
30 |     def forward(self, x):
31 |         # b, c, _, _ = x.size()
32 |         y = self.avg_pool(x)
33 |         y = self.conv(y)
34 |         y = self.sigmoid(y)
35 |         x = x * y.expand_as(x)
36 |         return x
37 | 
38 | class modified_Block(nn.Module):
39 |     r""" ConvNeXt-style block with a Res2Net split and ECA channel attention.
40 |     planes: dimension of the input tensor
41 |     """
42 |     def __init__(self,planes, stride=1, scales=4, groups=1, norm_layer=True,drop_path = 0):
43 |         super().__init__()
44 |         if planes % scales != 0:  # the number of channels must be a multiple of scales (4)
45 |             raise ValueError('Planes must be divisible by scales')
46 |         if norm_layer:  # BN layer
47 |             norm_layer = nn.BatchNorm1d
48 | 
49 |         self.scales = scales
50 |         self.stride = stride
51 |         self.relu = nn.ReLU()
52 |         # 3x3 conv layers: scales-1 conv layers and scales-1 BN layers in total
53 |         self.res2net_conv1n = nn.ModuleList([nn.Conv1d(planes // scales, planes // scales,
54 |                 kernel_size=3, stride=1, padding=1, groups=groups) for _ in range(scales-1)])
55 |         self.res2net_bn = nn.ModuleList([norm_layer(planes // scales) for _ in range(scales-1)])
56 | 
57 |         self.pwconv1 = nn.Linear(planes, 4 * planes)  # pointwise/1x1 convs, implemented with linear layers
58 |         self.act = nn.SELU()
59 |         self.pwconv2 = nn.Linear(4 * planes, planes)
60 |         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
61 | 
62 |         self.cam1 = eca_layer(planes)
63 | 
64 |     def forward(self, x):
65 |         input = x
66 | 
67 |         # hierarchical residual architecture with `scales` (1x3) branches
68 |         xs = torch.chunk(x, self.scales, 1)  # split x into `scales` chunks
69 |         ys = []
70 |         for s in range(self.scales):
71 |             if s == 0:
72 |                 ys.append(xs[s])
73 |             elif s == 1:
74 |                 ys.append(self.relu(self.res2net_bn[s-1](self.res2net_conv1n[s-1](xs[s]))))
75 |             else:
76 |                 ys.append(self.relu(self.res2net_bn[s-1](self.res2net_conv1n[s-1](xs[s] + ys[-1]))))
77 |         out = torch.cat(ys, 1)
78 | 
79 |         out = out.permute(0, 2, 1)  # (N, C, W) -> (N, W, C)
80 |         # linear layer in in*4
81 |         x = self.pwconv1(out)
82 |         # selu
83 |         x = self.act(x)
84 |         # linear layer in in
85 |         x = self.pwconv2(x)
86 |         # channel attention module
87 |         x = x.permute(0, 2,1)  # (N, W, C) -> (N, C, W)
88 |         x = self.cam1(x)
89 | 
90 |         x = input + self.drop_path(x)
91 |         return x
92 | 
93 | class Model(nn.Module):
94 |     def __init__(self, args=None):
95 |         super().__init__()
96 |         op1 = 16
97 |         op2 = 32
98 |         op3 = 64
99 |         op4 = 128
100 |         self.conv1 = nn.Conv1d(1, op1, kernel_size=(7,), stride=(1,),padding=3)
101 |         self.bn1_1 = nn.BatchNorm1d(op1)
102 |         self.bn1_2 = nn.BatchNorm1d(op1)
103 |         self.block1_1 = modified_Block(op1,)
104 | 
105 |         self.conv2 = nn.Conv1d(op1, op2, kernel_size=(7,), stride=(1,),padding=3)
106 |         self.bn2 = nn.BatchNorm1d(op2)
107 |         self.block2_1 = modified_Block(op2,)
108 |         self.block2_2 = modified_Block(op2,)
109 | 
110 |         self.conv3 = nn.Conv1d(op2, op3, kernel_size=(7,), stride=(1,),padding=3)
111 |         self.bn3 = nn.BatchNorm1d(op3)
112 |         self.block3_1 = modified_Block(op3,)
113 |         self.block3_2 = modified_Block(op3,)
114 |         self.block3_3 = modified_Block(op3,)
115 | 
116 |         self.conv4 = nn.Conv1d(op3, op4, kernel_size=(7,), stride=(1,),padding=3)
117 |         # self.conv4 = nn.Conv1d(op3, op4, kernel_size=(7,), stride=(1,),padding=3)
118 |         self.bn4 = nn.BatchNorm1d(op4)
119 |         self.block4 = modified_Block(op4,)
120 | 
121 |         self.maxpool = nn.MaxPool1d(kernel_size=9)
122 |         self.flatten = nn.Flatten()
123 |         self.selu_act = nn.SELU()
124 | 
125 |         self.linear1 = nn.Linear(1792,64)
126 |         self.linear2 = nn.Linear(64,16)
127 |         self.linear3 = nn.Linear(16,2)
128 |         self.softmax = nn.Softmax(1)
129 | 
130 |     def forward(self, x):
131 |         x = x.unsqueeze(1)
132 |         # print(x.shape)
133 |         x = self.conv1(x)
134 |         # print(x.shape)
135 |         x = self.bn1_1(x)
136 |         x = self.block1_1(x)
137 |         x = self.bn1_2(x)
138 |         x = self.conv2(x)
139 |         x = self.maxpool(x)
140 |         # print(x.shape)
141 |         x = self.block2_1(x)
142 |         x = self.block2_2(x)
143 |         x = self.bn2(x)
144 |         x = self.conv3(x)
145 |         x = self.maxpool(x)
146 |         # print(x.shape)
147 |         x = self.block3_1(x)
148 |         x = self.block3_2(x)
149 |         x = self.block3_3(x)
150 |         x = self.bn3(x)
151 |         x = self.conv4(x)
152 |         x = self.maxpool(x)
153 |         # print(x.shape)
154 |         x = self.block4(x)
155 |         x = self.bn4(x)
156 |         x = self.selu_act(x)
157 |         x = self.maxpool(x)
158 |         x = self.flatten(x)
159 | 
160 |         x = self.linear1(x)
161 |         x = self.selu_act(x)
162 |         x = self.linear2(x)
163 |         x = self.selu_act(x)
164 |         x = self.linear3(x)
165 |         x = self.softmax(x)
166 | 
167 |         return x,None
168 | 
169 | 
170 | 
171 | 
172 | if __name__ == "__main__":
173 |     md = Model()
174 |     # print(summary(md, torch.randn((8,64600)), show_input=False))
175 |     op,res = md( torch.randn((4,96000)))
176 |     print(op.shape)
177 |     # # print(res.shape)
178 |     print(sum(i.numel() for i in md.parameters() if i.requires_grad)/1000)  # 0.97M
179 | 
180 |     # test = eca_layer(16)
181 |     # z = test(torch.randn((8,16,95994)))
182 |     # print(z.shape)
183 |     # 
test = eca_layer(32) 184 | # z = test(torch.randn((8,32))) 185 | # print(z.shape) 186 | # test = eca_layer(64) 187 | # z = test(torch.randn((8,64))) 188 | # print(z.shape) 189 | # test = eca_layer(128) 190 | # z = test(torch.randn((8,128))) 191 | # print(z.shape) 192 | 193 | # mblock = modified_Block(16) 194 | # print(mblock(torch.randn((32, 16, 95994))).shape) 195 | # print(sum(i.numel() for i in mblock.parameters() if i.requires_grad)/1000) # 0.97M 196 | -------------------------------------------------------------------------------- /models/tl_model_postft_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import lightning as L 3 | import torch 4 | import logging,os 5 | from utils.wrapper import loss_wrapper, optim_wrapper,schedule_wrapper 6 | from utils.tools import cul_eer 7 | # from models.wav2vec.l5_aasist_step import Model 8 | from utils.ideas.reweight_learner import weight_learner 9 | from utils.ideas.reweight_learner import args as stable_arg 10 | 11 | 12 | 13 | class base_model(L.LightningModule): 14 | def __init__(self, 15 | model, 16 | args, 17 | ) -> None: 18 | super().__init__() 19 | self.args = args 20 | self.model = model 21 | self.stable_conf = stable_arg() 22 | self.args.stable_conf = str(vars(self.stable_conf)) 23 | self.save_hyperparameters(self.args) 24 | 25 | self.model_optimizer = optim_wrapper.optimizer_wrap(self.args, self.model).get_optim() 26 | self.LRScheduler = schedule_wrapper.scheduler_wrap(self.model_optimizer,self.args).get_scheduler() 27 | # for loss 28 | self.args.model = model 29 | self.args.samloss_optim = self.model_optimizer 30 | self.loss_criterion,self.loss_optimizer,self.minimizor = loss_wrapper.loss_wrap(self.args).get_loss() 31 | 32 | 33 | self.logging_test = None 34 | self.logging_predict = None 35 | 36 | 37 | def forward(self,x): 38 | return self.model(x) 39 | 40 | # def on_train_epoch_start(self): 41 | # if self.args.start_ft != 0: 42 | # print(self.current_epoch) 43 | # if self.current_epoch >= self.args.start_ft: 44 | # self.model.unfreeze_parameters() 45 | # else: 46 | # self.model.freeze_parameters() 47 | # self.model.unfreeze_parameters() 48 | 49 | 50 | def training_step(self, batch, batch_idx): 51 | 52 | # batch[0] -- tensor 53 | # batch[1] -- label 54 | # batch[2] -- filename 55 | 56 | 57 | # model output, better return 2 elements, prediction and any other thing 58 | output = self.forward(batch[0]) 59 | batch_loss = self.loss_criterion(output[0], batch[1]) 60 | 61 | # stable 62 | pre_features = self.model.pre_features 63 | pre_weight1 = self.model.pre_weight1 64 | 65 | loss_weight, pre_features, pre_weight1 = weight_learner( 66 | output[1], 67 | pre_features, 68 | pre_weight1, 69 | args=self.stable_conf, 70 | global_epoch = self.current_epoch, iter = batch_idx) 71 | self.model.pre_features.data.copy_(pre_features) 72 | self.model.pre_weight1.data.copy_(pre_weight1) 73 | batch_loss = batch_loss.view(1, -1).mm(loss_weight).view(1) 74 | 75 | 76 | 77 | batch_loss = batch_loss.mean() 78 | self.log_dict({ 79 | "loss": batch_loss, 80 | },on_step=True, 81 | on_epoch=True,prog_bar=True, logger=True, 82 | # prevent from saving wrong ckp based on the eval_loss from different gpus 83 | sync_dist=True, 84 | ) 85 | return batch_loss 86 | 87 | def validation_step(self,batch): 88 | # batch[0] -- tensor 89 | # batch[1] -- label 90 | # batch[2] -- filename 91 | 92 | # model output 93 | output = self.forward(batch[0]) 94 | 95 | softmax_pred = 
torch.nn.functional.softmax(output[0],dim=1) 96 | 97 | # log the prediction for cul eer 98 | with open(os.path.join(self.logger.log_dir,"dev.log"), 'a') as file: 99 | for i in range(len(softmax_pred)): 100 | file.write(f"{batch[2][i]} {str(softmax_pred.cpu().numpy()[i][1])}\n") 101 | 102 | # batch_loss = self.loss_criterion(data_predict, data_label).mean() 103 | # # Logging to TensorBoard (if installed) by default 104 | # self.log("val_loss", batch_loss, batch_size=len(data_in),sync_dist=True) 105 | 106 | def on_validation_epoch_end(self) -> None: 107 | # culculate the dev eer 108 | dev_eer = 0. 109 | dev_tdcf = 0. 110 | with open(os.path.join(self.logger.log_dir,"dev.log"), 'r') as file: 111 | lines = file.readlines() 112 | 113 | if len(lines) > 10000: 114 | if "singfake" in self.args.data_module: 115 | dev_eer = cul_eer.eeronly( 116 | os.path.join(self.logger.log_dir,"dev.log"), 117 | "/data8/wangzhiyong/project/fakeAudioDetection/pytorch_lightning_FAD/datasets/sing_fsd/dataset/label/dev.txt", 118 | ) 119 | else: 120 | dev_eer, dev_tdcf = cul_eer.eerandtdcf( 121 | os.path.join(self.logger.log_dir,"dev.log"), 122 | "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/asvspoof2019/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt", 123 | "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/asvspoof2019/LA/ASVspoof2019_LA_asv_scores/ASVspoof2019.LA.asv.dev.gi.trl.scores.txt" 124 | ) 125 | with open(os.path.join(self.logger.log_dir,"dev.log"), 'w') as file: 126 | pass 127 | self.log_dict({ 128 | "dev_eer": (dev_eer), 129 | "dev_tdcf": dev_tdcf, 130 | },on_step=False, 131 | on_epoch=True,prog_bar=False, logger=True, 132 | # prevent from saving wrong ckp based on the eval_loss from different gpus 133 | sync_dist=True, 134 | ) 135 | 136 | def on_test_start(self): 137 | # logging.basicConfig(filename=os.path.join(self.logger.log_dir,f"infer_test.log"),level=logging.INFO,format="") 138 | self.logging_test = logging.getLogger("logging_test") 139 | self.logging_test.setLevel(logging.INFO) 140 | hdl=logging.FileHandler(os.path.join(self.logger.log_dir,f"infer_19.log")) 141 | hdl.setFormatter("") 142 | self.logging_test.addHandler(hdl) 143 | 144 | def test_step(self, batch,) -> Any: 145 | # batch[0] -- tensor 146 | # batch[1] -- filename 147 | 148 | # model output 149 | output = self.forward(batch[0]) 150 | 151 | data_predict = torch.nn.functional.softmax(output[0],dim=1) 152 | 153 | for i in range(len(batch[1])): 154 | self.logging_test.info(f"{batch[1][i]} {str(data_predict.cpu().numpy()[i][0])} {str(data_predict.cpu().numpy()[i][1])}") 155 | # return data_info[0],data_predict.cpu().numpy() 156 | return {'loss': 0, 'y_pred': data_predict} 157 | 158 | def on_predict_start(self): 159 | # logging.basicConfig(filename=os.path.join(self.args.savedir,f"infer_predict.log"),level=logging.INFO,format="") 160 | self.logging_predict = logging.getLogger(f"logging_predict_{self.args.testset}") 161 | self.logging_predict.setLevel(logging.INFO) 162 | hdlx = logging.FileHandler(os.path.join(self.logger.log_dir,f"infer_{self.args.testset}.log")) 163 | hdlx.setFormatter("") 164 | self.logging_predict.addHandler(hdlx) 165 | 166 | def predict_step(self, batch, batch_idx): 167 | # batch[0] -- tensor 168 | # batch[1] -- filename 169 | 170 | # model output 171 | output = self.forward(batch[0]) 172 | 173 | data_predict = torch.nn.functional.softmax(output[0],dim=1) 174 | 175 
/utils/wrapper/loss_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd.function import Function
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from collections import defaultdict
7 | from typing import Tuple
8 |
9 | from argparse import Namespace
10 |
11 |
12 |
13 | class loss_config():
14 | def __init__(self):
15 | self.loss = "WCE"
16 | self.reduce = 1
17 | # for WCE
18 | # self.loss_weight = torch.FloatTensor([0.1, 0.9])
19 | self.loss_lr = 0.01
20 | # for SAM ASAM
21 | self.samloss_optim = None
22 | self.model = None
23 | self.rho = 0.5
24 | self.eta = 0.0
25 |
26 | class loss_wrap():
27 | def __init__(self,cfg:loss_config):
28 | super().__init__()
29 | self.cfg = cfg
30 | if cfg.reduce == 1:
31 | self.reduce = "mean"
32 | else:
33 | self.reduce = "none"
34 |
35 |
36 | def get_loss(self):
37 | final_Loss = None
38 | loss_optim = None
39 | minimizer = None
40 | if self.cfg.loss == "CE":
41 | final_Loss = nn.CrossEntropyLoss(reduction=self.reduce)
42 | elif self.cfg.loss == "FOCAL":
43 | final_Loss = FocalLoss(gamma=2, alpha=2580/22800)
44 | elif self.cfg.loss == "WCE":
45 | final_Loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 0.9]), reduction=self.reduce)
46 | elif self.cfg.loss == "WCEsf":
47 | final_Loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.15, 0.85]), reduction=self.reduce)
48 | elif self.cfg.loss == "AM":
49 | final_Loss=AMSoftmax()
50 | loss_optim = torch.optim.SGD(final_Loss.parameters(), lr=self.cfg.loss_lr) # 0.01
51 | elif self.cfg.loss == "OC":
52 | final_Loss=OCSoftmax()
53 | loss_optim = torch.optim.SGD(final_Loss.parameters(), lr=self.cfg.loss_lr) # 0.0003
54 |
55 | elif self.cfg.loss == "SAM":
56 | minimizer = SAM(self.cfg.samloss_optim, self.cfg.model, rho=self.cfg.rho, eta=self.cfg.eta)
57 |
58 | elif self.cfg.loss == "ASAM":
59 | minimizer = ASAM(self.cfg.samloss_optim, self.cfg.model, rho=self.cfg.rho, eta=self.cfg.eta)
60 |
61 | else:
62 | raise Exception(f"no loss named {self.cfg.loss}")
63 |
64 | return final_Loss, loss_optim, minimizer
65 |
66 |
67 |
68 |
69 |
70 | class FocalLoss(nn.Module):
71 | def __init__(self, gamma=0, alpha=None, size_average=True):
72 | super(FocalLoss, self).__init__()
73 | self.gamma = gamma
74 | self.alpha = alpha
75 | if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha])
76 | if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
77 | self.size_average = size_average
78 |
79 | def forward(self, input, target):
80 | if input.dim()>2:
81 | input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W
82 | input = input.transpose(1,2) # N,C,H*W => N,H*W,C
83 | input = 
input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C 84 | target = target.view(-1,1) 85 | 86 | logpt = F.log_softmax(input) 87 | logpt = logpt.gather(1,target) 88 | logpt = logpt.view(-1) 89 | pt = Variable(logpt.data.exp()) 90 | 91 | if self.alpha is not None: 92 | if self.alpha.type()!=input.data.type(): 93 | self.alpha = self.alpha.type_as(input.data) 94 | at = self.alpha.gather(0,target.data.view(-1)) 95 | logpt = logpt * Variable(at) 96 | 97 | loss = -1 * (1-pt)**self.gamma * logpt 98 | if self.size_average: return loss.mean() 99 | else: return loss.sum() 100 | 101 | 102 | class OCSoftmax(nn.Module): 103 | def __init__(self, feat_dim=2, r_real=0.9, r_fake=0.5, alpha=20.0,reduce = True): 104 | super(OCSoftmax, self).__init__() 105 | self.feat_dim = feat_dim 106 | self.r_real = r_real 107 | self.r_fake = r_fake 108 | self.alpha = alpha 109 | self.center = nn.Parameter(torch.randn(1, self.feat_dim)) 110 | self.reduce = reduce 111 | nn.init.kaiming_uniform_(self.center, 0.25) 112 | self.softplus = nn.Softplus() 113 | 114 | def forward(self, x, labels): 115 | """ 116 | Args: 117 | x: feature matrix with shape (batch_size, feat_dim). 118 | labels: ground truth labels with shape (batch_size). 119 | """ 120 | w = F.normalize(self.center, p=2, dim=1) 121 | x = F.normalize(x, p=2, dim=1) 122 | 123 | scores = x @ w.transpose(0,1) 124 | output_scores = scores.clone() 125 | 126 | scores[labels == 0] = self.r_real - scores[labels == 0] 127 | scores[labels == 1] = scores[labels == 1] - self.r_fake 128 | 129 | if self.reduce: 130 | loss = self.softplus(self.alpha * scores).mean() 131 | else: 132 | loss = self.softplus(self.alpha * scores) 133 | # print(output_scores.squeeze(1).shape) 134 | # return loss, -output_scores.squeeze(1) 135 | return loss 136 | 137 | class AMSoftmax(nn.Module): 138 | def __init__(self, num_classes=2, enc_dim=2, s=20, m=0.9): 139 | super(AMSoftmax, self).__init__() 140 | self.enc_dim = enc_dim 141 | self.num_classes = num_classes 142 | self.s = s 143 | self.m = m 144 | self.centers = nn.Parameter(torch.randn(num_classes, enc_dim)) 145 | 146 | def forward(self, feat, label): 147 | batch_size = feat.shape[0] 148 | norms = torch.norm(feat, p=2, dim=-1, keepdim=True) 149 | nfeat = torch.div(feat, norms) 150 | 151 | norms_c = torch.norm(self.centers, p=2, dim=-1, keepdim=True) 152 | ncenters = torch.div(self.centers, norms_c) 153 | logits = torch.matmul(nfeat, torch.transpose(ncenters, 0, 1)) 154 | 155 | y_onehot = torch.FloatTensor(batch_size, self.num_classes) 156 | y_onehot.zero_() 157 | y_onehot = Variable(y_onehot).cuda() 158 | y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.m) 159 | margin_logits = self.s * (logits - y_onehot) 160 | # print(margin_logits.shape) 161 | 162 | # return logits, margin_logits 163 | return logits 164 | 165 | 166 | 167 | class ASAM: 168 | def __init__(self, optimizer, model, rho=0.5, eta=0.01): 169 | self.optimizer = optimizer 170 | self.model = model 171 | self.rho = rho 172 | self.eta = eta 173 | self.state = defaultdict(dict) 174 | 175 | @torch.no_grad() 176 | def ascent_step(self): 177 | wgrads = [] 178 | for n, p in self.model.named_parameters(): 179 | if p.grad is None: 180 | continue 181 | t_w = self.state[p].get("eps") 182 | if t_w is None: 183 | t_w = torch.clone(p).detach() 184 | self.state[p]["eps"] = t_w 185 | if 'weight' in n: 186 | t_w[...] = p[...] 
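# NOTE (editor): this is the "adaptive" part of ASAM -- for weight tensors the
# gradient is rescaled element-wise by |w| + eta (t_w below), so the ascent
# perturbation is proportional to each weight's magnitude and therefore less
# sensitive to layer re-scaling; the plain SAM subclass further down skips this
# rescaling. ascent_step() then climbs by rho * (scaled grad) / ||scaled grad||,
# and descent_step() removes the perturbation before the real optimizer.step().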
187 | t_w.abs_().add_(self.eta)
188 | p.grad.mul_(t_w)
189 | wgrads.append(torch.norm(p.grad, p=2))
190 | wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1.e-16
191 | for n, p in self.model.named_parameters():
192 | if p.grad is None:
193 | continue
194 | t_w = self.state[p].get("eps")
195 | if 'weight' in n:
196 | p.grad.mul_(t_w)
197 | eps = t_w
198 | eps[...] = p.grad[...]
199 | eps.mul_(self.rho / wgrad_norm)
200 | p.add_(eps)
201 | self.optimizer.zero_grad()
202 |
203 | @torch.no_grad()
204 | def descent_step(self):
205 | for n, p in self.model.named_parameters():
206 | if p.grad is None:
207 | continue
208 | p.sub_(self.state[p]["eps"])
209 | self.optimizer.step()
210 | self.optimizer.zero_grad()
211 |
212 |
213 | class SAM(ASAM):
214 | @torch.no_grad()
215 | def ascent_step(self):
216 | grads = []
217 | for n, p in self.model.named_parameters():
218 | if p.grad is None:
219 | continue
220 | grads.append(torch.norm(p.grad, p=2))
221 | grad_norm = torch.norm(torch.stack(grads), p=2) + 1.e-16
222 | for n, p in self.model.named_parameters():
223 | if p.grad is None:
224 | continue
225 | eps = self.state[p].get("eps")
226 | if eps is None:
227 | eps = torch.clone(p).detach()
228 | self.state[p]["eps"] = eps
229 | eps[...] = p.grad[...]
230 | eps.mul_(self.rho / grad_norm)
231 | p.add_(eps)
232 | self.optimizer.zero_grad()
233 |
234 |
235 | if __name__ == "__main__":
236 |
237 | cfg = loss_config()
238 | cfg.loss = "wzy" # an unknown loss name, so this exercises the error path in get_loss()
239 | wrap = loss_wrap(cfg)
240 | wrap.get_loss()
--------------------------------------------------------------------------------
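A minimal usage sketch for the wrapper above (editor's addition, not part of the repository): loss_wrap.get_loss() returns a (criterion, loss_optim, minimizer) triple, and the SAM/ASAM minimizers need two forward/backward passes per step. The model, batch, and criterion below are hypothetical placeholders for illustration.

import torch
import torch.nn as nn
from utils.wrapper.loss_wrapper import loss_config, loss_wrap

model = nn.Linear(10, 2)                       # stand-in for a real detector
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

cfg = loss_config()
cfg.loss = "SAM"                               # or "ASAM"
cfg.samloss_optim = optimizer
cfg.model = model
_criterion, _loss_optim, minimizer = loss_wrap(cfg).get_loss()

criterion = nn.CrossEntropyLoss()              # SAM/ASAM leave the criterion to the caller
x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))

criterion(model(x), y).backward()              # 1st pass: gradients at the current weights
minimizer.ascent_step()                        # perturb weights towards higher loss
criterion(model(x), y).backward()              # 2nd pass: gradients at the perturbed weights
minimizer.descent_step()                       # restore weights, then optimizer.step()

For the AM/OC options, get_loss() instead returns a loss_optim (an SGD over the loss's own trainable parameters) that has to be stepped alongside the model optimizer.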
/utils/arg_parse.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import argparse
3 | """
4 | startup_config
5 |
6 | Startup configuration utilities
7 |
8 | """
9 | import os
10 | import sys
11 | import torch
12 | import importlib
13 | import random
14 | import numpy as np
15 |
16 | __author__ = "Xin Wang"
17 | __email__ = "wangxin@nii.ac.jp"
18 | __copyright__ = "Copyright 2020, Xin Wang"
19 |
20 |
21 | def set_random_seed(random_seed, args=None):
22 | """ set_random_seed(random_seed, args=None)
23 |
24 | Set the random_seed for numpy, python, and cudnn
25 |
26 | input
27 | -----
28 | random_seed: integer random seed
29 | args: argument parser
30 | """
31 |
32 | # initialization
33 | torch.manual_seed(random_seed)
34 | random.seed(random_seed)
35 | np.random.seed(random_seed)
36 | os.environ['PYTHONHASHSEED'] = str(random_seed)
37 |
38 | # For torch.backends.cudnn.deterministic
39 | # Note: this default configuration may result in RuntimeError
40 | # see https://pytorch.org/docs/stable/notes/randomness.html
41 | if torch.cuda.is_available():
42 | torch.cuda.manual_seed_all(random_seed)
43 | return
44 |
45 |
46 |
47 |
48 |
49 |
50 | """ Arg_parse"""
51 | # common args: batch size, epochs, ...
52 | def f_args_parsed(argument_input = None):
53 |
54 | parser = argparse.ArgumentParser(
55 | description='General argument parser'
56 | )
57 | mes=""
58 | # random seed
59 | parser.add_argument('--seed', type=int, default=1234, help='random seed (default: 1234)')
60 | # for DDP
61 | parser.add_argument('--gpuid', type=str, default="0")
62 |
63 | # ⭐⭐⭐ training model filename and its data config filename
64 | # 'module of model definition (default model, model.py will be loaded)'
65 | parser.add_argument('--module_model', type=str, default="model")
66 | # module of torch lightning model train step file
67 | parser.add_argument('--tl_model', type=str, default="models.tl_model")
68 | # datamodule python file
69 | parser.add_argument('--data_module', type=str, default="utils.loadData.asvspoof_data_DA")
70 |
71 | # pretrained model config(hugging face config)
72 | # parser.add_argument('--pretrained-model-config', type=str, default="")
73 | ######
74 | # ⭐⭐Training settings
75 | # inference or train
76 | parser.add_argument('--inference', action='store_true', default=False, help=mes)
77 | # 'batch size for training/inference (default: 8)'
78 | parser.add_argument('--batch_size', type=int, default=8)
79 | # 'number of epochs to train (default: 100)'
80 | parser.add_argument('--epochs', type=int, default=100)
81 | # 'number of no-best epochs for early stopping (default: 5)'
82 | parser.add_argument('--no_best_epochs', type=int, default=5)
83 |
84 |
85 | ######
86 | # ⭐options to save model
87 | # checkpoint dir
88 | parser.add_argument('--savedir', type=str, default="./a_train_log", help='save model to this directory (default: ./a_train_log)')
89 |
90 | #######
91 | # ⭐options to load model
92 | # 'a trained model for inference or resume training '
93 | parser.add_argument('--trained_model', type=str, default="", help=mes + "(default: '')")
94 | # infer dataset
95 | parser.add_argument('--testset', type=str, default="LA21", help=mes + "(one of: LA21, DF21, ITW; default: LA21)")
96 | parser.add_argument('--truncate', type=int, default=64600)
97 |
98 |
99 | # ⭐for loss selection
100 | # (CE, WCE, WCEsf, FOCAL, AM, OC, SAM, ASAM; any other string raises in loss_wrap.get_loss())
101 | parser.add_argument('--loss', type=str, default="WCE")
102 | # 1 for reduction ("mean"), 0 for no reduction
103 | parser.add_argument('--reduce', type=int, default=0)
104 | parser.add_argument('--loss_lr', type=float, default=0.01)
105 | parser.add_argument('--rho', type=float, default=0.5)
106 | parser.add_argument('--eta', type=float, default=0.0)
107 |
108 |
109 |
110 | # # ⭐optimizer setting
111 | # for optimizer selection, (adam, adamw, sgd )
112 | parser.add_argument('--optim', type=str, default="adam")
113 | # # learning rate
114 | parser.add_argument('--optim_lr', type=float, default=0.0001,help='learning rate (default: 0.0001)')
115 | # # weight_decay / l2 penalty
116 | parser.add_argument('--weight_decay', type=float, default=0.0001)
117 | # for SGD
118 | parser.add_argument('--momentum', type=float, default=0.9)
119 |
120 |
121 | # (cosWarmup, cosAnneal, step)
122 | parser.add_argument('--scheduler', type=str, default="")
123 | # warm up settings, upper stage, default 3
124 | parser.add_argument('--num_warmup_steps', type=int, default=3)
125 | # for cosAnneal: num of train samples // batch size * epochs
126 | parser.add_argument('--total_step', type=int, default=1057)
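# NOTE (editor): the default of 1057 is consistent with one epoch over the
# ASVspoof2019 LA train partition: 25,380 utterances (2,580 bona fide +
# 22,800 spoofed -- the same counts behind FOCAL's alpha=2580/22800 above)
# at, e.g., a batch size of 24 give 25380 // 24 = 1057 optimizer steps;
# multiply by the number of epochs for the full cosine-annealing schedule.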
127 | # scheduler
128 | parser.add_argument('--step_size', type=int, default=5)
129 | parser.add_argument('--gamma', type=float, default=0.1)
130 |
131 | # applying data augmentation
132 | parser.add_argument('--usingDA', action='store_true', default=False)
133 | parser.add_argument('--da_prob', type=float, default=2)
134 |
135 |
136 | args_main = parser.parse_args()
137 |
138 | if not args_main.usingDA:
139 | return args_main
140 | else:
141 | ##===================================================Rawboost data augmentation ======================================================================#
142 |
143 | parser.add_argument('--algo', type=int, default=5,
144 | help='Rawboost algorithm descriptions. 0: No augmentation 1: LnL_convolutive_noise, 2: ISD_additive_noise, 3: SSI_additive_noise, 4: series algo (1+2+3), \
145 | 5: series algo (1+2), 6: series algo (1+3), 7: series algo(2+3), 8: parallel algo(1,2) .[default=5]')
146 |
147 | # LnL_convolutive_noise parameters
148 | parser.add_argument('--nBands', type=int, default=5,
149 | help='number of notch filters. The higher the number of bands, the more aggressive the distortion is.[default=5]')
150 | parser.add_argument('--minF', type=int, default=20,
151 | help='minimum centre frequency [Hz] of notch filter.[default=20] ')
152 | parser.add_argument('--maxF', type=int, default=8000,
153 | help='maximum centre frequency [Hz] (= max_len:
227 | return x[:max_len]
228 | # need to pad
229 | num_repeats = int(max_len / x_len) + 1
230 | padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
231 | return padded_x
232 |
233 |
234 | def pad_random(x: np.ndarray, max_len: int = 64600):
235 | x_len = x.shape[0]
236 | # if duration is already long enough
237 | if x_len >= max_len:
238 | stt = np.random.randint(x_len - max_len + 1) # +1 so an exactly max_len-long clip does not crash randint
239 | return x[stt:stt + max_len]
240 |
241 | # if too short
242 | num_repeats = int(max_len / x_len) + 1
243 | padded_x = np.tile(x, (num_repeats))[:max_len]
244 | return padded_x
245 |
246 |
247 | class Dataset_ASVspoof2019_train(Dataset):
248 | def __init__(self, list_IDs, labels, base_dir,args,cut = 64600):
249 | """self.list_IDs : list of strings (each string: utt key),
250 | self.labels : dictionary (key: utt key, value: label integer)"""
251 | self.list_IDs = list_IDs
252 | self.labels = labels
253 | self.base_dir = base_dir
254 | self.args = args
255 | self.cut = cut # take ~4 sec audio (64600 samples)
256 |
257 | def __len__(self):
258 | return len(self.list_IDs)
259 |
260 | def __getitem__(self, index):
261 | key = self.list_IDs[index]
262 | X, fs = sf.read(os.path.join(self.base_dir,f"{key}.wav"))
263 |
264 | if self.args.usingDA and (np.random.rand() < self.args.da_prob):
265 | X=process_Rawboost_feature(X,fs,self.args,self.args.algo)
266 | if self.cut == 0:
267 | X_pad = X
268 | else:
269 | X_pad = pad_random(X, self.cut)
270 | x_inp = Tensor(X_pad)
271 | y = self.labels[key]
272 | # 1.
tensor 2.label 3.filename 273 | return x_inp, y, key 274 | 275 | 276 | class Dataset_ASVspoof2019_devNeval(Dataset): 277 | def __init__(self, list_IDs, base_dir,args=None,cut = 64600): 278 | """self.list_IDs : list of strings (each string: utt key), 279 | """ 280 | self.list_IDs = list_IDs 281 | self.base_dir = base_dir 282 | self.cut = cut # take ~4 sec audio (64600 samples) 283 | self.args = args 284 | 285 | def __len__(self): 286 | return len(self.list_IDs) 287 | 288 | def __getitem__(self, index): 289 | key = self.list_IDs[index] 290 | X, fs = sf.read(os.path.join(self.base_dir,f"{key}.wav")) 291 | if self.args.usingDA and ("ASVspoof2019_LA_dev" in self.base_dir): 292 | X=process_Rawboost_feature(X,fs,self.args,self.args.algo) 293 | if self.cut == 0: 294 | X_pad = X 295 | else: 296 | X_pad = pad(X, self.cut) 297 | x_inp = Tensor(X_pad) 298 | # 1.tensor 2.filename 299 | return x_inp, key 300 | 301 | 302 | class Dataset_ASVspoof2019_evaltest(Dataset): 303 | def __init__(self, list_IDs, base_dir,args=None,cut = 64600): 304 | """self.list_IDs : list of strings (each string: utt key), 305 | """ 306 | self.list_IDs = list_IDs 307 | self.base_dir = base_dir 308 | self.cut = cut # take ~4 sec audio (64600 samples) 309 | self.args = args 310 | 311 | def __len__(self): 312 | return len(self.list_IDs) 313 | 314 | def __getitem__(self, index): 315 | key = self.list_IDs[index] 316 | X, fs = sf.read(os.path.join(self.base_dir,f"{key}.wav")) 317 | if self.cut == 0: 318 | X_pad = X 319 | else: 320 | X_pad = pad(X, self.cut) 321 | x_inp = Tensor(X_pad) 322 | return x_inp, key 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | -------------------------------------------------------------------------------- /utils/loadData/asvspoof_data_DA_still_process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import soundfile as sf 3 | import torch,os 4 | from torch import Tensor 5 | from torch.utils.data import Dataset,DataLoader,DistributedSampler 6 | from .RawBoost import process_Rawboost_feature 7 | import lightning as L 8 | from transformers import Wav2Vec2FeatureExtractor 9 | from utils.tools.tools import pad,pad_random 10 | 11 | class asvspoof_dataModule(L.LightningDataModule): 12 | def __init__(self,args): 13 | super().__init__() 14 | self.args = args 15 | 16 | # TODO: change the dir to your own data dir 17 | # label file 18 | self.protocols_path = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/asvspoof2019/LA/ASVspoof2019_LA_cm_protocols/" 19 | self.train_protocols_file = self.protocols_path + "ASVspoof2019.LA.cm.train.trl.txt" 20 | self.dev_protocols_file = self.protocols_path + "ASVspoof2019.LA.cm.dev.trl.txt" 21 | # flac file dir 22 | self.dataset_base_path="/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/asvspoof2019/LA/" 23 | self.train_set=self.dataset_base_path+"ASVspoof2019_LA_train/" 24 | self.dev_set=self.dataset_base_path+"ASVspoof2019_LA_dev/" 25 | # test set 26 | self.eval_protocols_file_19 = self.protocols_path + "ASVspoof2019.LA.cm.eval.trl.txt" 27 | self.eval_set_19 = self.dataset_base_path+"ASVspoof2019_LA_eval/" 28 | self.eval_protocols_file_21 = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/ASVspoof2021_LA_eval/eval_file/ASVspoof2021.LA.cm.eval.trl.txt" 29 | self.eval_set_21 = 
"/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/ASVspoof2021_LA_eval/" 30 | 31 | 32 | self.LA21 = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/ASVspoof2021_LA_eval/eval_file/ASVspoof2021.LA.cm.eval.trl.txt" 33 | self.LA21FLAC = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/ASVspoof2021_LA_eval/" 34 | self.LA21TRIAL = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/ASVspoof2021_LA_eval/eval_file/CM_trial_metadata.txt" 35 | 36 | self.DF21 = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/ASVspoof2021_DF_eval/ASVspoof2021.DF.cm.eval.trl.txt" 37 | self.DF21FLAC = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/ASVspoof2021_DF_eval/" 38 | self.DF21TRIAL = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/ASVspoof2021_DF_eval/trial_metadata.txt" 39 | 40 | self.ITWTXT = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/release_in_the_wild/label.txt" 41 | self.ITWDIR = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/release_in_the_wild/wav" 42 | 43 | self.truncate = args.truncate 44 | self.predict = args.testset # LA21, DF21, ITW 45 | 46 | def setup(self, stage: str): 47 | # Assign train/val datasets for use in dataloaders 48 | if stage == "fit": 49 | d_label_trn,file_train = genSpoof_list( 50 | dir_meta=self.train_protocols_file, 51 | is_train=True, 52 | is_eval=False 53 | ) 54 | 55 | self.asvspoof19_trn_set = Dataset_ASVspoof2019_train( 56 | list_IDs=file_train, 57 | labels=d_label_trn, 58 | base_dir=self.train_set, 59 | cut=self.truncate, 60 | args= self.args, 61 | ) 62 | 63 | 64 | label_dev, file_dev = genSpoof_list( 65 | dir_meta=self.dev_protocols_file, 66 | is_train=True, 67 | is_eval=False) 68 | 69 | self.asvspoof19_val_set = Dataset_ASVspoof2019_train( 70 | list_IDs=file_dev, 71 | labels=label_dev, 72 | base_dir=self.dev_set, 73 | cut=self.truncate, 74 | args= self.args, 75 | ) 76 | 77 | # Assign test dataset for use in dataloader(s) 78 | if stage == "test": 79 | file_eval = genSpoof_list( 80 | dir_meta=self.eval_protocols_file_19, 81 | is_train=False, 82 | is_eval=True 83 | ) 84 | self.asvspoof19_test_set = Dataset_ASVspoof2019_evaltest( 85 | list_IDs=file_eval, 86 | base_dir=self.eval_set_19, 87 | cut=self.truncate 88 | ) 89 | 90 | if stage == "predict": 91 | if self.predict == "LA21": 92 | file_list=[] 93 | with open(self.LA21, 'r') as f: 94 | l_meta = f.readlines() 95 | for line in l_meta: 96 | key= line.strip() 97 | file_list.append(key) 98 | print(f"no.{(len(file_list))} of eval trials") 99 | self.predict_set = Dataset_ASVspoof2019_evaltest( 100 | list_IDs=file_list, 101 | base_dir=self.LA21FLAC, 102 | cut=self.truncate) 103 | 104 | elif self.predict == "DF21": 105 | file_list=[] 106 | with open(self.DF21, 'r') as f: 107 | l_meta = f.readlines() 108 | for line in l_meta: 
109 | key= line.strip()
110 | file_list.append(key)
111 | print(f"no.{(len(file_list))} of eval trials")
112 | self.predict_set = Dataset_ASVspoof2019_evaltest(
113 | list_IDs=file_list,
114 | base_dir=self.DF21FLAC,
115 | cut=self.truncate)
116 |
117 | elif self.predict == "ITW":
118 | file_list=[]
119 | # open the label file
120 | with open(self.ITWTXT, 'r') as file:
121 | lines = file.readlines()
122 | for line in lines:
123 | columns = line.split()
124 | file_list.append(columns[1])
125 | self.predict_set = dataset_itw(
126 | list_IDs=file_list,
127 | base_dir=self.ITWDIR,
128 | cut=self.truncate)
129 |
130 |
131 |
132 |
133 |
134 | def train_dataloader(self):
135 | return DataLoader(self.asvspoof19_trn_set, batch_size=self.args.batch_size, shuffle=True,drop_last = True,num_workers=8)
136 |
137 | def val_dataloader(self):
138 | return DataLoader(self.asvspoof19_val_set, batch_size=self.args.batch_size, shuffle=False,drop_last = False,num_workers=8)
139 |
140 | def test_dataloader(self):
141 | datald = DataLoader(
142 | self.asvspoof19_test_set,batch_size=self.args.batch_size,
143 | shuffle=False,num_workers=8
144 | )
145 | if "," in self.args.gpuid:
146 | datald = DataLoader(
147 | self.asvspoof19_test_set,batch_size=self.args.batch_size,
148 | shuffle=False,num_workers=8,
149 | sampler=DistributedSampler(self.asvspoof19_test_set)
150 | )
151 | return datald
152 |
153 | def predict_dataloader(self):
154 | predict_loader = DataLoader(
155 | self.predict_set,
156 | batch_size= self.args.batch_size,
157 | shuffle=False,
158 | drop_last=False,
159 | pin_memory=True,
160 | num_workers=8)
161 | if "," in self.args.gpuid:
162 | predict_loader = DataLoader(
163 | self.predict_set,
164 | batch_size= self.args.batch_size,
165 | shuffle=False,
166 | drop_last=False,
167 | pin_memory=True,
168 | sampler=DistributedSampler(self.predict_set),
169 | num_workers=8
170 | )
171 | return predict_loader
172 |
173 |
174 | def norm(X_pad):
175 | mean_x = X_pad.mean()
176 | var_x = X_pad.var()
177 | return np.array([(x - mean_x) / np.sqrt(var_x + 1e-7) for x in X_pad])
178 |
179 |
180 | class dataset_itw(Dataset):
181 | def __init__(self, list_IDs, base_dir,cut = 64600):
182 | self.list_IDs = list_IDs
183 | self.base_dir = base_dir
184 | self.cut = cut # take ~4 sec audio (64600 samples)
185 |
186 | def __len__(self):
187 | return len(self.list_IDs)
188 |
189 | def __getitem__(self, index):
190 | key = self.list_IDs[index]
191 | X, _ = sf.read(os.path.join(self.base_dir,f"{key}.wav"))
192 | if self.cut == 0:
193 | X_pad = X
194 | else:
195 | X_pad = pad(X, self.cut)
196 | X_pad = norm(X_pad)
197 | x_inp = Tensor(X_pad)
198 | return x_inp, key
199 |
200 |
201 |
202 | def genSpoof_list(dir_meta, is_train=False, is_eval=False):
203 |
204 | d_meta = {}
205 | file_list = []
206 | with open(dir_meta, "r") as f:
207 | l_meta = f.readlines()
208 |
209 | if is_train:
210 | for line in l_meta:
211 | _, key, _, _, label = line.strip().split(" ")
212 | file_list.append(key)
213 | d_meta[key] = 1 if label == "bonafide" else 0
214 | return d_meta, file_list
215 |
216 | elif is_eval:
217 | for line in l_meta:
218 | _, key, _, _, _ = line.strip().split(" ")
219 | #key = line.strip()
220 | file_list.append(key)
221 | return file_list
222 | else:
223 | for line in l_meta:
224 | _, key, _, _, label = line.strip().split(" ")
225 | file_list.append(key)
226 | d_meta[key] = 1 if label == "bonafide" else 0
227 | return d_meta, file_list
228 |
229 |
230 | class Dataset_ASVspoof2019_train(Dataset):
231 | def __init__(self, list_IDs, labels, 
base_dir,args,cut = 64600): 232 | """self.list_IDs : list of strings (each string: utt key), 233 | self.labels : dictionary (key: utt key, value: label integer)""" 234 | self.list_IDs = list_IDs 235 | self.labels = labels 236 | self.base_dir = base_dir 237 | self.args = args 238 | self.cut = cut # take ~4 sec audio (64600 samples) 239 | 240 | def __len__(self): 241 | return len(self.list_IDs) 242 | 243 | def __getitem__(self, index): 244 | key = self.list_IDs[index] 245 | X, fs = sf.read(os.path.join(self.base_dir , f"flac/{key}.flac")) 246 | if self.args.usingDA and (np.random.rand() < self.args.da_prob): 247 | X=process_Rawboost_feature(X,fs,self.args,self.args.algo) 248 | if self.cut == 0: 249 | X_pad = X 250 | else: 251 | X_pad = pad(X, self.cut) 252 | X_pad = norm(X_pad) 253 | x_inp = Tensor(X_pad) 254 | y = self.labels[key] 255 | # 1. tensor 2.label 3.filename 256 | return x_inp, y, key 257 | 258 | 259 | 260 | class Dataset_ASVspoof2019_evaltest(Dataset): 261 | def __init__(self, list_IDs, base_dir,args=None,cut = 64600): 262 | """self.list_IDs : list of strings (each string: utt key), 263 | """ 264 | self.list_IDs = list_IDs 265 | self.base_dir = base_dir 266 | self.cut = cut # take ~4 sec audio (64600 samples) 267 | self.args = args 268 | 269 | def __len__(self): 270 | return len(self.list_IDs) 271 | 272 | def __getitem__(self, index): 273 | key = self.list_IDs[index] 274 | X, fs = sf.read(os.path.join(self.base_dir,f"flac/{key}.flac")) 275 | if self.cut == 0: 276 | X_pad = X 277 | else: 278 | X_pad = pad(X, self.cut) 279 | X_pad = norm(X_pad) 280 | x_inp = Tensor(X_pad) 281 | return x_inp, key 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | -------------------------------------------------------------------------------- /utils/loadData/asvspoof_data_DA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import soundfile as sf 3 | import torch,os 4 | from torch import Tensor 5 | from torch.utils.data import Dataset,DataLoader,DistributedSampler 6 | from .RawBoost import process_Rawboost_feature 7 | import lightning as L 8 | 9 | class asvspoof_dataModule(L.LightningDataModule): 10 | def __init__(self,args): 11 | super().__init__() 12 | self.args = args 13 | 14 | # TODO: change the dir to your own data dir 15 | # label file 16 | self.protocols_path = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/asvspoof2019/LA/ASVspoof2019_LA_cm_protocols/" 17 | self.train_protocols_file = self.protocols_path + "ASVspoof2019.LA.cm.train.trl.txt" 18 | self.dev_protocols_file = self.protocols_path + "ASVspoof2019.LA.cm.dev.trl.txt" 19 | # flac file dir 20 | self.dataset_base_path="/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/asvspoof2019/LA/" 21 | self.train_set=self.dataset_base_path+"ASVspoof2019_LA_train/" 22 | self.dev_set=self.dataset_base_path+"ASVspoof2019_LA_dev/" 23 | # test set 24 | self.eval_protocols_file_19 = self.protocols_path + "ASVspoof2019.LA.cm.eval.trl.txt" 25 | self.eval_set_19 = self.dataset_base_path+"ASVspoof2019_LA_eval/" 26 | self.eval_protocols_file_21 = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/ASVspoof2021_LA_eval/eval_file/ASVspoof2021.LA.cm.eval.trl.txt" 27 | self.eval_set_21 = 
"/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/ASVspoof2021_LA_eval/" 28 | 29 | 30 | self.LA21 = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/ASVspoof2021_LA_eval/eval_file/ASVspoof2021.LA.cm.eval.trl.txt" 31 | self.LA21FLAC = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/ASVspoof2021_LA_eval/" 32 | self.LA21TRIAL = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/ASVspoof2021_LA_eval/eval_file/CM_trial_metadata.txt" 33 | 34 | self.DF21 = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/ASVspoof2021_DF_eval/ASVspoof2021.DF.cm.eval.trl.txt" 35 | self.DF21FLAC = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/ASVspoof2021_DF_eval/" 36 | self.DF21TRIAL = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/ASVspoof2021_DF_eval/trial_metadata.txt" 37 | 38 | self.ITWTXT = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/release_in_the_wild/label.txt" 39 | self.ITWDIR = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/reference/fad/aasist/datasets/release_in_the_wild/wav" 40 | 41 | 42 | 43 | self.truncate = args.truncate 44 | self.predict = args.testset # LA21, DF21, ITW 45 | 46 | def setup(self, stage: str): 47 | # Assign train/val datasets for use in dataloaders 48 | if stage == "fit": 49 | d_label_trn,file_train = genSpoof_list( 50 | dir_meta=self.train_protocols_file, 51 | is_train=True, 52 | is_eval=False 53 | ) 54 | 55 | self.asvspoof19_trn_set = Dataset_ASVspoof2019_train( 56 | list_IDs=file_train, 57 | labels=d_label_trn, 58 | base_dir=self.train_set, 59 | cut=self.truncate, 60 | args= self.args 61 | ) 62 | 63 | _, file_dev = genSpoof_list( 64 | dir_meta=self.dev_protocols_file, 65 | is_train=False, 66 | is_eval=False) 67 | 68 | self.asvspoof19_val_set = Dataset_ASVspoof2019_devNeval( 69 | list_IDs=file_dev, 70 | base_dir=self.dev_set, 71 | args= self.args, 72 | cut=self.truncate 73 | ) 74 | 75 | # Assign test dataset for use in dataloader(s) 76 | if stage == "test": 77 | file_eval = genSpoof_list( 78 | dir_meta=self.eval_protocols_file_19, 79 | is_train=False, 80 | is_eval=True 81 | ) 82 | self.asvspoof19_test_set = Dataset_ASVspoof2019_evaltest( 83 | list_IDs=file_eval, 84 | base_dir=self.eval_set_19, 85 | cut=self.truncate 86 | ) 87 | 88 | if stage == "predict": 89 | if self.predict == "LA21": 90 | file_list=[] 91 | with open(self.LA21, 'r') as f: 92 | l_meta = f.readlines() 93 | for line in l_meta: 94 | key= line.strip() 95 | file_list.append(key) 96 | print(f"no.{(len(file_list))} of eval trials") 97 | self.predict_set = Dataset_ASVspoof2019_evaltest( 98 | list_IDs=file_list, 99 | base_dir=self.LA21FLAC, 100 | cut=self.truncate) 101 | 102 | elif self.predict == "DF21": 103 | file_list=[] 104 | with open(self.DF21, 'r') as f: 105 | l_meta = f.readlines() 106 | for line in l_meta: 107 | key= line.strip() 108 
| file_list.append(key)
109 | print(f"no.{(len(file_list))} of eval trials")
110 | self.predict_set = Dataset_ASVspoof2019_evaltest(
111 | list_IDs=file_list,
112 | base_dir=self.DF21FLAC,
113 | cut=self.truncate)
114 |
115 | elif self.predict == "ITW":
116 | file_list=[]
117 | # open the label file
118 | with open(self.ITWTXT, 'r') as file:
119 | lines = file.readlines()
120 | for line in lines:
121 | columns = line.split()
122 | file_list.append(columns[1])
123 | self.predict_set = dataset_itw(
124 | list_IDs=file_list,
125 | base_dir=self.ITWDIR,
126 | cut=self.truncate)
127 |
128 |
129 |
130 |
131 |
132 | def train_dataloader(self):
133 | return DataLoader(self.asvspoof19_trn_set, batch_size=self.args.batch_size, shuffle=True,drop_last = True,num_workers=4)
134 |
135 | def val_dataloader(self):
136 | return DataLoader(self.asvspoof19_val_set, batch_size=self.args.batch_size, shuffle=False,drop_last = False,num_workers=4)
137 |
138 | def test_dataloader(self):
139 | datald = DataLoader(
140 | self.asvspoof19_test_set,batch_size=self.args.batch_size,
141 | shuffle=False,num_workers=4
142 | )
143 | if "," in self.args.gpuid:
144 | datald = DataLoader(
145 | self.asvspoof19_test_set,batch_size=self.args.batch_size,
146 | shuffle=False,num_workers=4,
147 | sampler=DistributedSampler(self.asvspoof19_test_set)
148 | )
149 | return datald
150 |
151 | def predict_dataloader(self):
152 | predict_loader = DataLoader(
153 | self.predict_set,
154 | batch_size= self.args.batch_size,
155 | shuffle=False,
156 | drop_last=False,
157 | pin_memory=True,
158 | num_workers=4)
159 | if "," in self.args.gpuid:
160 | predict_loader = DataLoader(
161 | self.predict_set,
162 | batch_size= self.args.batch_size,
163 | shuffle=False,
164 | drop_last=False,
165 | pin_memory=True,
166 | sampler=DistributedSampler(self.predict_set),
167 | num_workers=4
168 | )
169 | return predict_loader
170 |
171 |
172 |
173 |
174 |
175 | class dataset_itw(Dataset):
176 | def __init__(self, list_IDs, base_dir,cut = 64600):
177 | self.list_IDs = list_IDs
178 | self.base_dir = base_dir
179 | self.cut = cut # take ~4 sec audio (64600 samples)
180 |
181 | def __len__(self):
182 | return len(self.list_IDs)
183 |
184 | def __getitem__(self, index):
185 | key = self.list_IDs[index]
186 | X, _ = sf.read(os.path.join(self.base_dir,f"{key}.wav"))
187 | if self.cut == 0:
188 | X_pad = X
189 | else:
190 | X_pad = pad(X, self.cut)
191 | x_inp = Tensor(X_pad)
192 | return x_inp, key
193 |
194 |
195 |
196 | def genSpoof_list(dir_meta, is_train=False, is_eval=False):
197 |
198 | d_meta = {}
199 | file_list = []
200 | with open(dir_meta, "r") as f:
201 | l_meta = f.readlines()
202 |
203 | if is_train:
204 | for line in l_meta:
205 | _, key, _, _, label = line.strip().split(" ")
206 | file_list.append(key)
207 | d_meta[key] = 1 if label == "bonafide" else 0
208 | return d_meta, file_list
209 |
210 | elif is_eval:
211 | for line in l_meta:
212 | _, key, _, _, _ = line.strip().split(" ")
213 | #key = line.strip()
214 | file_list.append(key)
215 | return file_list
216 | else:
217 | for line in l_meta:
218 | _, key, _, _, label = line.strip().split(" ")
219 | file_list.append(key)
220 | d_meta[key] = 1 if label == "bonafide" else 0
221 | return d_meta, file_list
222 |
223 |
224 | def pad(x, max_len=64600):
225 | x_len = x.shape[0]
226 | if x_len >= max_len:
227 | return x[:max_len]
228 | # need to pad
229 | num_repeats = int(max_len / x_len) + 1
230 | padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
231 | return padded_x
232 |
233 |
234 | def pad_random(x: 
np.ndarray, max_len: int = 64600):
235 | x_len = x.shape[0]
236 | # if duration is already long enough
237 | if x_len >= max_len:
238 | stt = np.random.randint(x_len - max_len + 1) # +1 so an exactly max_len-long clip does not crash randint
239 | return x[stt:stt + max_len]
240 |
241 | # if too short
242 | num_repeats = int(max_len / x_len) + 1
243 | padded_x = np.tile(x, (num_repeats))[:max_len]
244 | return padded_x
245 |
246 |
247 | class Dataset_ASVspoof2019_train(Dataset):
248 | def __init__(self, list_IDs, labels, base_dir,args,cut = 64600):
249 | """self.list_IDs : list of strings (each string: utt key),
250 | self.labels : dictionary (key: utt key, value: label integer)"""
251 | self.list_IDs = list_IDs
252 | self.labels = labels
253 | self.base_dir = base_dir
254 | self.args = args
255 | self.cut = cut # take ~4 sec audio (64600 samples)
256 |
257 | def __len__(self):
258 | return len(self.list_IDs)
259 |
260 | def __getitem__(self, index):
261 | key = self.list_IDs[index]
262 | X, fs = sf.read(os.path.join(self.base_dir , f"flac/{key}.flac"))
263 | if self.args.usingDA and (np.random.rand() < self.args.da_prob):
264 | X=process_Rawboost_feature(X,fs,self.args,self.args.algo)
265 | if self.cut == 0:
266 | X_pad = X
267 | else:
268 | X_pad = pad_random(X, self.cut)
269 | x_inp = Tensor(X_pad)
270 | y = self.labels[key]
271 | # 1. tensor 2.label 3.filename
272 | return x_inp, y, key
273 |
274 |
275 | class Dataset_ASVspoof2019_devNeval(Dataset):
276 | def __init__(self, list_IDs, base_dir,args=None,cut = 64600):
277 | """self.list_IDs : list of strings (each string: utt key),
278 | """
279 | self.list_IDs = list_IDs
280 | self.base_dir = base_dir
281 | self.cut = cut # take ~4 sec audio (64600 samples)
282 | self.args = args
283 |
284 | def __len__(self):
285 | return len(self.list_IDs)
286 |
287 | def __getitem__(self, index):
288 | key = self.list_IDs[index]
289 | X, fs = sf.read(os.path.join(self.base_dir,f"flac/{key}.flac"))
290 | if self.args.usingDA and ("ASVspoof2019_LA_dev" in self.base_dir):
291 | X=process_Rawboost_feature(X,fs,self.args,self.args.algo)
292 | if self.cut == 0:
293 | X_pad = X
294 | else:
295 | X_pad = pad(X, self.cut)
296 | x_inp = Tensor(X_pad)
297 | # 1.tensor 2.filename
298 | return x_inp, key
299 |
300 |
301 | class Dataset_ASVspoof2019_evaltest(Dataset):
302 | def __init__(self, list_IDs, base_dir,args=None,cut = 64600):
303 | """self.list_IDs : list of strings (each string: utt key),
304 | """
305 | self.list_IDs = list_IDs
306 | self.base_dir = base_dir
307 | self.cut = cut # take ~4 sec audio (64600 samples)
308 | self.args = args
309 |
310 | def __len__(self):
311 | return len(self.list_IDs)
312 |
313 | def __getitem__(self, index):
314 | key = self.list_IDs[index]
315 | X, fs = sf.read(os.path.join(self.base_dir,f"flac/{key}.flac"))
316 | if self.cut == 0:
317 | X_pad = X
318 | else:
319 | X_pad = pad(X, self.cut)
320 | x_inp = Tensor(X_pad)
321 | return x_inp, key
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
--------------------------------------------------------------------------------
/utils/tools/evaluation.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | import numpy as np
5 |
6 |
7 | def calculate_tDCF_EER(cm_scores_file,
8 | asv_score_file,
9 | output_file,
10 | printout=True):
11 | # Replace CM scores with your own scores or provide score file as the
12 | # first argument.
13 | # cm_scores_file = 'score_cm.txt'
14 | # Replace ASV scores with organizers' scores or provide score file as
15 | # the second argument. 
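# NOTE (editor): as parsed below, the CM score file is expected to contain five
# whitespace-separated columns per line -- utterance id, attack/source id (A07-A19
# or '-'), key ('bonafide'/'spoof'), a placeholder column, and the CM score in
# column 5 -- while the organizers' ASV score file carries the key
# ('target'/'nontarget'/'spoof') in column 2 and the ASV score in column 3.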
16 | # asv_score_file = 'ASVspoof2019.LA.asv.eval.gi.trl.scores.txt' 17 | 18 | # Fix tandem detection cost function (t-DCF) parameters 19 | Pspoof = 0.05 20 | cost_model = { 21 | 'Pspoof': Pspoof, # Prior probability of a spoofing attack 22 | 'Ptar': (1 - Pspoof) * 0.99, # Prior probability of target speaker 23 | 'Pnon': (1 - Pspoof) * 0.01, # Prior probability of nontarget speaker 24 | 'Cmiss': 1, # Cost of ASV system falsely rejecting target speaker 25 | 'Cfa': 10, # Cost of ASV system falsely accepting nontarget speaker 26 | 'Cmiss_asv': 1, # Cost of ASV system falsely rejecting target speaker 27 | 'Cfa_asv': 28 | 10, # Cost of ASV system falsely accepting nontarget speaker 29 | 'Cmiss_cm': 1, # Cost of CM system falsely rejecting target speaker 30 | 'Cfa_cm': 10, # Cost of CM system falsely accepting spoof 31 | } 32 | 33 | # Load organizers' ASV scores 34 | asv_data = np.genfromtxt(asv_score_file, dtype=str) 35 | # asv_sources = asv_data[:, 0] 36 | asv_keys = asv_data[:, 1] 37 | asv_scores = asv_data[:, 2].astype(float) 38 | 39 | # Load CM scores 40 | cm_data = np.genfromtxt(cm_scores_file, dtype=str) 41 | # cm_utt_id = cm_data[:, 0] 42 | cm_sources = cm_data[:, 1] 43 | cm_keys = cm_data[:, 2] 44 | cm_scores = cm_data[:, 4].astype(float) 45 | 46 | # Extract target, nontarget, and spoof scores from the ASV scores 47 | tar_asv = asv_scores[asv_keys == 'target'] 48 | non_asv = asv_scores[asv_keys == 'nontarget'] 49 | spoof_asv = asv_scores[asv_keys == 'spoof'] 50 | 51 | # Extract bona fide (real human) and spoof scores from the CM scores 52 | bona_cm = cm_scores[cm_keys == 'bonafide'] 53 | spoof_cm = cm_scores[cm_keys == 'spoof'] 54 | 55 | # EERs of the standalone systems and fix ASV operating point to 56 | # EER threshold 57 | eer_asv, asv_threshold = compute_eer(tar_asv, non_asv) 58 | eer_cm = compute_eer(bona_cm, spoof_cm)[0] 59 | 60 | attack_types = [f'A{_id:02d}' for _id in range(7, 20)] 61 | if printout: 62 | spoof_cm_breakdown = { 63 | attack_type: cm_scores[cm_sources == attack_type] 64 | for attack_type in attack_types 65 | } 66 | 67 | eer_cm_breakdown = { 68 | attack_type: compute_eer(bona_cm, 69 | spoof_cm_breakdown[attack_type])[0] 70 | for attack_type in attack_types 71 | } 72 | 73 | [Pfa_asv, Pmiss_asv, 74 | Pmiss_spoof_asv] = obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, 75 | asv_threshold) 76 | 77 | # Compute t-DCF 78 | tDCF_curve, CM_thresholds = compute_tDCF(bona_cm, 79 | spoof_cm, 80 | Pfa_asv, 81 | Pmiss_asv, 82 | Pmiss_spoof_asv, 83 | cost_model, 84 | print_cost=False) 85 | 86 | # Minimum t-DCF 87 | min_tDCF_index = np.argmin(tDCF_curve) 88 | min_tDCF = tDCF_curve[min_tDCF_index] 89 | 90 | if printout: 91 | with open(output_file, "w") as f_res: 92 | f_res.write('\nCM SYSTEM\n') 93 | f_res.write('\tEER\t\t= {:8.9f} % ' 94 | '(Equal error rate for countermeasure)\n'.format( 95 | eer_cm * 100)) 96 | 97 | f_res.write('\nTANDEM\n') 98 | f_res.write('\tmin-tDCF\t\t= {:8.9f}\n'.format(min_tDCF)) 99 | 100 | f_res.write('\nBREAKDOWN CM SYSTEM\n') 101 | for attack_type in attack_types: 102 | _eer = eer_cm_breakdown[attack_type] * 100 103 | f_res.write( 104 | f'\tEER {attack_type}\t\t= {_eer:8.9f} % (Equal error rate for {attack_type}\n' 105 | ) 106 | os.system(f"cat {output_file}") 107 | 108 | return eer_cm * 100, min_tDCF 109 | 110 | 111 | def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold): 112 | 113 | # False alarm and miss rates for ASV 114 | Pfa_asv = sum(non_asv >= asv_threshold) / non_asv.size 115 | Pmiss_asv = sum(tar_asv < asv_threshold) / 
tar_asv.size
116 |
117 | # Rate of rejecting spoofs in ASV
118 | if spoof_asv.size == 0:
119 | Pmiss_spoof_asv = None
120 | else:
121 | Pmiss_spoof_asv = np.sum(spoof_asv < asv_threshold) / spoof_asv.size
122 |
123 | return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv
124 |
125 |
126 | def compute_det_curve(target_scores, nontarget_scores):
127 |
128 | n_scores = target_scores.size + nontarget_scores.size
129 | all_scores = np.concatenate((target_scores, nontarget_scores))
130 | labels = np.concatenate(
131 | (np.ones(target_scores.size), np.zeros(nontarget_scores.size)))
132 |
133 | # Sort labels based on scores
134 | indices = np.argsort(all_scores, kind='mergesort')
135 | labels = labels[indices]
136 |
137 | # Compute false rejection and false acceptance rates
138 | tar_trial_sums = np.cumsum(labels)
139 | nontarget_trial_sums = nontarget_scores.size - \
140 | (np.arange(1, n_scores + 1) - tar_trial_sums)
141 |
142 | # false rejection rates
143 | frr = np.concatenate(
144 | (np.atleast_1d(0), tar_trial_sums / target_scores.size))
145 | far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums /
146 | nontarget_scores.size)) # false acceptance rates
147 | # Thresholds are the sorted scores
148 | thresholds = np.concatenate(
149 | (np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))
150 |
151 | return frr, far, thresholds
152 |
153 |
154 | def compute_eer(target_scores, nontarget_scores):
155 | """ Returns equal error rate (EER) and the corresponding threshold. """
156 | frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
157 | abs_diffs = np.abs(frr - far)
158 | min_index = np.argmin(abs_diffs)
159 | eer = np.mean((frr[min_index], far[min_index]))
160 | return eer, thresholds[min_index]
161 |
162 |
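# NOTE (editor): a tiny illustrative self-check of compute_eer, not part of the
# original file. With one bona fide score below and one spoof score above the
# crossing point, FRR and FAR meet at 1/4:
def _demo_compute_eer():
    bona = np.array([0.9, 0.8, 0.7, 0.3])   # target (bona fide) scores
    spoof = np.array([0.6, 0.4, 0.2, 0.1])  # nontarget (spoof) scores
    eer, threshold = compute_eer(bona, spoof)
    assert abs(eer - 0.25) < 1e-9           # one error out of four on each side
    return eer, threshold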
163 | def compute_tDCF(bonafide_score_cm, spoof_score_cm, Pfa_asv, Pmiss_asv,
164 | Pmiss_spoof_asv, cost_model, print_cost):
165 | """
166 | Compute Tandem Detection Cost Function (t-DCF) [1] for a fixed ASV system.
167 | In brief, t-DCF returns a detection cost of a cascaded system of this form,
168 |
169 | Speech waveform -> [CM] -> [ASV] -> decision
170 |
171 | where CM stands for countermeasure and ASV for automatic speaker
172 | verification. The CM is therefore used as a 'gate' to decide whether or
173 | not the input speech sample should be passed onwards to the ASV system.
174 | Generally, both the CM and the ASV can make detection errors. Not all of
175 | those errors are equally costly, and not all types of users are equally
176 | likely. The tandem t-DCF gives a principled way to compare
177 | different spoofing countermeasures under a detection cost function
178 | framework that takes that information into account.
179 |
180 | INPUTS:
181 |
182 | bonafide_score_cm A vector of POSITIVE CLASS (bona fide or human)
183 | detection scores obtained by executing a spoofing
184 | countermeasure (CM) on some positive evaluation trials.
185 | Each trial represents a bona fide case.
186 | spoof_score_cm A vector of NEGATIVE CLASS (spoofing attack)
187 | detection scores obtained by executing a spoofing
188 | CM on some negative evaluation trials.
189 | Pfa_asv False alarm (false acceptance) rate of the ASV
190 | system that is evaluated in tandem with the CM.
191 | Assumed to be in fractions, not percentages.
192 | Pmiss_asv Miss (false rejection) rate of the ASV system that
193 | is evaluated in tandem with the spoofing CM.
194 | Assumed to be in fractions, not percentages.
195 | Pmiss_spoof_asv Miss rate of spoof samples of the ASV system that
196 | is evaluated in tandem with the spoofing CM. That
197 | is, the fraction of spoof samples that were
198 | rejected by the ASV system.
199 | cost_model A struct that contains the parameters of t-DCF,
200 | with the following fields.
201 |
202 | Ptar Prior probability of target speaker.
203 | Pnon Prior probability of nontarget speaker (zero-effort impostor)
204 | Pspoof Prior probability of spoofing attack.
205 | Cmiss_asv Cost of ASV falsely rejecting target.
206 | Cfa_asv Cost of ASV falsely accepting nontarget.
207 | Cmiss_cm Cost of CM falsely rejecting target.
208 | Cfa_cm Cost of CM falsely accepting spoof.
209 |
210 | print_cost Print a summary of the cost parameters and the
211 | implied t-DCF cost function?
212 |
213 | OUTPUTS:
214 |
215 | tDCF_norm Normalized t-DCF curve across the different CM
216 | system operating points; see [2] for more details.
217 | Normalized t-DCF > 1 indicates a useless
218 | countermeasure (as the tandem system would do
219 | better without it). min(tDCF_norm) will be the
220 | minimum t-DCF used in ASVspoof 2019 [2].
221 | CM_thresholds Vector of same size as tDCF_norm corresponding to
222 | the CM threshold (operating point).
223 |
224 | NOTE:
225 | o In relative terms, higher detection score values are assumed to
226 | indicate stronger support for the bona fide hypothesis.
227 | o You should provide real-valued soft scores, NOT hard decisions. The
228 | recommendation is that the scores are log-likelihood ratios (LLRs)
229 | from a bonafide-vs-spoof hypothesis based on some statistical model.
230 | This, however, is NOT required. The scores can have arbitrary range
231 | and scaling.
232 | o Pfa_asv, Pmiss_asv, Pmiss_spoof_asv are in fractions, not percentages.
233 |
234 | References:
235 |
236 | [1] T. Kinnunen, K.-A. Lee, H. Delgado, N. Evans, M. Todisco,
237 | M. Sahidullah, J. Yamagishi, D.A. Reynolds: "t-DCF: a Detection
238 | Cost Function for the Tandem Assessment of Spoofing Countermeasures
239 | and Automatic Speaker Verification", Proc. Odyssey 2018: the
240 | Speaker and Language Recognition Workshop, pp. 312--319, Les Sables d'Olonne,
241 | France, June 2018 (https://www.isca-speech.org/archive/Odyssey_2018/pdfs/68.pdf)
242 |
243 | [2] ASVspoof 2019 challenge evaluation plan
244 | TODO:
245 | """
246 |
247 | # Sanity check of cost parameters
248 | if cost_model['Cfa_asv'] < 0 or cost_model['Cmiss_asv'] < 0 or \
249 | cost_model['Cfa_cm'] < 0 or cost_model['Cmiss_cm'] < 0:
250 | print('WARNING: Usually the cost values should be positive!')
251 |
252 | if cost_model['Ptar'] < 0 or cost_model['Pnon'] < 0 or cost_model['Pspoof'] < 0 or \
253 | np.abs(cost_model['Ptar'] + cost_model['Pnon'] + cost_model['Pspoof'] - 1) > 1e-10:
254 | sys.exit(
255 | 'ERROR: Your prior probabilities should be positive and sum up to one.'
256 | )
257 |
258 | # Unless we evaluate the worst-case model, we need to have some spoof tests against the ASV
259 | if Pmiss_spoof_asv is None:
260 | sys.exit(
261 | 'ERROR: you should provide miss rate of spoof tests against your ASV system.' 
262 | )
263 |
264 | # Sanity check of scores
265 | combined_scores = np.concatenate((bonafide_score_cm, spoof_score_cm))
266 | if np.isnan(combined_scores).any() or np.isinf(combined_scores).any():
267 | sys.exit('ERROR: Your scores contain nan or inf.')
268 |
269 | # Sanity check that inputs are scores and not decisions
270 | n_uniq = np.unique(combined_scores).size
271 | if n_uniq < 3:
272 | sys.exit(
273 | 'ERROR: You should provide soft CM scores - not binary decisions')
274 |
275 | # Obtain miss and false alarm rates of CM
276 | Pmiss_cm, Pfa_cm, CM_thresholds = compute_det_curve(
277 | bonafide_score_cm, spoof_score_cm)
278 |
279 | # Constants - see ASVspoof 2019 evaluation plan
280 | C1 = cost_model['Ptar'] * (cost_model['Cmiss_cm'] - cost_model['Cmiss_asv'] * Pmiss_asv) - \
281 | cost_model['Pnon'] * cost_model['Cfa_asv'] * Pfa_asv
282 | C2 = cost_model['Cfa_cm'] * cost_model['Pspoof'] * (1 - Pmiss_spoof_asv)
283 |
284 | # Sanity check of the weights
285 | if C1 < 0 or C2 < 0:
286 | sys.exit(
287 | 'You should never see this error, but t-DCF cannot be evaluated with negative weights - please check whether your ASV error rates are correctly computed'
288 | )
289 |
290 | # Obtain t-DCF curve for all thresholds
291 | tDCF = C1 * Pmiss_cm + C2 * Pfa_cm
292 |
293 | # Normalized t-DCF
294 | tDCF_norm = tDCF / np.minimum(C1, C2)
295 |
296 | # Everything should be fine if we reach this point.
297 | if print_cost:
298 |
299 | print('t-DCF evaluation from [Nbona={}, Nspoof={}] trials\n'.format(
300 | bonafide_score_cm.size, spoof_score_cm.size))
301 | print('t-DCF MODEL')
302 | print(' Ptar = {:8.5f} (Prior probability of target user)'.
303 | format(cost_model['Ptar']))
304 | print(
305 | ' Pnon = {:8.5f} (Prior probability of nontarget user)'.
306 | format(cost_model['Pnon']))
307 | print(
308 | ' Pspoof = {:8.5f} (Prior probability of spoofing attack)'. 
309 | format(cost_model['Pspoof'])) 310 | print( 311 | ' Cfa_asv = {:8.5f} (Cost of ASV falsely accepting a nontarget)' 312 | .format(cost_model['Cfa_asv'])) 313 | print( 314 | ' Cmiss_asv = {:8.5f} (Cost of ASV falsely rejecting target speaker)' 315 | .format(cost_model['Cmiss_asv'])) 316 | print( 317 | ' Cfa_cm = {:8.5f} (Cost of CM falsely passing a spoof to ASV system)' 318 | .format(cost_model['Cfa_cm'])) 319 | print( 320 | ' Cmiss_cm = {:8.5f} (Cost of CM falsely blocking target utterance which never reaches ASV)' 321 | .format(cost_model['Cmiss_cm'])) 322 | print( 323 | '\n Implied normalized t-DCF function (depends on t-DCF parameters and ASV errors), s=CM threshold)' 324 | ) 325 | 326 | if C2 == np.minimum(C1, C2): 327 | print( 328 | ' tDCF_norm(s) = {:8.5f} x Pmiss_cm(s) + Pfa_cm(s)\n'.format( 329 | C1 / C2)) 330 | else: 331 | print( 332 | ' tDCF_norm(s) = Pmiss_cm(s) + {:8.5f} x Pfa_cm(s)\n'.format( 333 | C2 / C1)) 334 | 335 | return tDCF_norm, CM_thresholds 336 | -------------------------------------------------------------------------------- /models/wav2vec/aasist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch.autograd import Function 6 | import os 7 | from torch import Tensor 8 | import numpy as np 9 | from torch.utils import data 10 | from collections import OrderedDict 11 | from torch.nn.parameter import Parameter 12 | from pytorch_model_summary import summary 13 | import math 14 | from typing import Union 15 | 16 | 17 | class GraphAttentionLayer(nn.Module): 18 | def __init__(self, in_dim, out_dim, **kwargs): 19 | super().__init__() 20 | 21 | # attention map 22 | self.att_proj = nn.Linear(in_dim, out_dim) 23 | self.att_weight = self._init_new_params(out_dim, 1) 24 | 25 | # project 26 | self.proj_with_att = nn.Linear(in_dim, out_dim) 27 | self.proj_without_att = nn.Linear(in_dim, out_dim) 28 | 29 | # batch norm 30 | self.bn = nn.BatchNorm1d(out_dim) 31 | 32 | # dropout for inputs 33 | self.input_drop = nn.Dropout(p=0.2) 34 | 35 | # activate 36 | self.act = nn.SELU(inplace=True) 37 | 38 | # temperature 39 | self.temp = 1. 40 | if "temperature" in kwargs: 41 | self.temp = kwargs["temperature"] 42 | 43 | def forward(self, x): 44 | ''' 45 | x :(#bs, #node, #dim) 46 | ''' 47 | # apply input dropout 48 | x = self.input_drop(x) 49 | 50 | # derive attention map 51 | att_map = self._derive_att_map(x) 52 | 53 | # projection 54 | x = self._project(x, att_map) 55 | 56 | # apply batch norm 57 | x = self._apply_BN(x) 58 | x = self.act(x) 59 | return x 60 | 61 | def _pairwise_mul_nodes(self, x): 62 | ''' 63 | Calculates pairwise multiplication of nodes. 
64 | - for attention map 65 | x :(#bs, #node, #dim) 66 | out_shape :(#bs, #node, #node, #dim) 67 | ''' 68 | 69 | nb_nodes = x.size(1) 70 | x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1) 71 | x_mirror = x.transpose(1, 2) 72 | 73 | return x * x_mirror 74 | 75 | def _derive_att_map(self, x): 76 | ''' 77 | x :(#bs, #node, #dim) 78 | out_shape :(#bs, #node, #node, 1) 79 | ''' 80 | att_map = self._pairwise_mul_nodes(x) 81 | # size: (#bs, #node, #node, #dim_out) 82 | att_map = torch.tanh(self.att_proj(att_map)) 83 | # size: (#bs, #node, #node, 1) 84 | att_map = torch.matmul(att_map, self.att_weight) 85 | 86 | # apply temperature 87 | att_map = att_map / self.temp 88 | 89 | att_map = F.softmax(att_map, dim=-2) 90 | 91 | return att_map 92 | 93 | def _project(self, x, att_map): 94 | x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x)) 95 | x2 = self.proj_without_att(x) 96 | 97 | return x1 + x2 98 | 99 | def _apply_BN(self, x): 100 | org_size = x.size() 101 | x = x.view(-1, org_size[-1]) 102 | x = self.bn(x) 103 | x = x.view(org_size) 104 | 105 | return x 106 | 107 | def _init_new_params(self, *size): 108 | out = nn.Parameter(torch.FloatTensor(*size)) 109 | nn.init.xavier_normal_(out) 110 | return out 111 | 112 | 113 | class HtrgGraphAttentionLayer(nn.Module): 114 | def __init__(self, in_dim, out_dim, **kwargs): 115 | super().__init__() 116 | 117 | self.proj_type1 = nn.Linear(in_dim, in_dim) 118 | self.proj_type2 = nn.Linear(in_dim, in_dim) 119 | 120 | # attention map 121 | self.att_proj = nn.Linear(in_dim, out_dim) 122 | self.att_projM = nn.Linear(in_dim, out_dim) 123 | 124 | self.att_weight11 = self._init_new_params(out_dim, 1) 125 | self.att_weight22 = self._init_new_params(out_dim, 1) 126 | self.att_weight12 = self._init_new_params(out_dim, 1) 127 | self.att_weightM = self._init_new_params(out_dim, 1) 128 | 129 | # project 130 | self.proj_with_att = nn.Linear(in_dim, out_dim) 131 | self.proj_without_att = nn.Linear(in_dim, out_dim) 132 | 133 | self.proj_with_attM = nn.Linear(in_dim, out_dim) 134 | self.proj_without_attM = nn.Linear(in_dim, out_dim) 135 | 136 | # batch norm 137 | self.bn = nn.BatchNorm1d(out_dim) 138 | 139 | # dropout for inputs 140 | self.input_drop = nn.Dropout(p=0.2) 141 | 142 | # activate 143 | self.act = nn.SELU(inplace=True) 144 | 145 | # temperature 146 | self.temp = 1. 
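# NOTE (editor): temp divides the attention logits in _derive_att_map /
# _derive_att_map_master below before the softmax, so temperatures above 1
# flatten the attention distribution; W2VAASIST at the end of this file
# instantiates these layers with temperatures [2.0, 2.0, 100.0, 100.0].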
147 | if "temperature" in kwargs: 148 | self.temp = kwargs["temperature"] 149 | 150 | def forward(self, x1, x2, master=None): 151 | ''' 152 | x1 :(#bs, #node, #dim) 153 | x2 :(#bs, #node, #dim) 154 | ''' 155 | # print('x1',x1.shape) 156 | # print('x2',x2.shape) 157 | num_type1 = x1.size(1) 158 | num_type2 = x2.size(1) 159 | # print('num_type1',num_type1) 160 | # print('num_type2',num_type2) 161 | x1 = self.proj_type1(x1) 162 | # print('proj_type1',x1.shape) 163 | x2 = self.proj_type2(x2) 164 | # print('proj_type2',x2.shape) 165 | x = torch.cat([x1, x2], dim=1) 166 | # print('Concat x1 and x2',x.shape) 167 | 168 | if master is None: 169 | master = torch.mean(x, dim=1, keepdim=True) 170 | # print('master',master.shape) 171 | # apply input dropout 172 | x = self.input_drop(x) 173 | 174 | # derive attention map 175 | att_map = self._derive_att_map(x, num_type1, num_type2) 176 | # print('master',master.shape) 177 | # directional edge for master node 178 | master = self._update_master(x, master) 179 | # print('master',master.shape) 180 | # projection 181 | x = self._project(x, att_map) 182 | # print('proj x',x.shape) 183 | # apply batch norm 184 | x = self._apply_BN(x) 185 | x = self.act(x) 186 | 187 | x1 = x.narrow(1, 0, num_type1) 188 | # print('x1',x1.shape) 189 | x2 = x.narrow(1, num_type1, num_type2) 190 | # print('x2',x2.shape) 191 | return x1, x2, master 192 | 193 | def _update_master(self, x, master): 194 | 195 | att_map = self._derive_att_map_master(x, master) 196 | master = self._project_master(x, master, att_map) 197 | 198 | return master 199 | 200 | def _pairwise_mul_nodes(self, x): 201 | ''' 202 | Calculates pairwise multiplication of nodes. 203 | - for attention map 204 | x :(#bs, #node, #dim) 205 | out_shape :(#bs, #node, #node, #dim) 206 | ''' 207 | 208 | nb_nodes = x.size(1) 209 | x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1) 210 | x_mirror = x.transpose(1, 2) 211 | 212 | return x * x_mirror 213 | 214 | def _derive_att_map_master(self, x, master): 215 | ''' 216 | x :(#bs, #node, #dim) 217 | out_shape :(#bs, #node, #node, 1) 218 | ''' 219 | att_map = x * master 220 | att_map = torch.tanh(self.att_projM(att_map)) 221 | 222 | att_map = torch.matmul(att_map, self.att_weightM) 223 | 224 | # apply temperature 225 | att_map = att_map / self.temp 226 | 227 | att_map = F.softmax(att_map, dim=-2) 228 | 229 | return att_map 230 | 231 | def _derive_att_map(self, x, num_type1, num_type2): 232 | ''' 233 | x :(#bs, #node, #dim) 234 | out_shape :(#bs, #node, #node, 1) 235 | ''' 236 | att_map = self._pairwise_mul_nodes(x) 237 | # size: (#bs, #node, #node, #dim_out) 238 | att_map = torch.tanh(self.att_proj(att_map)) 239 | # size: (#bs, #node, #node, 1) 240 | 241 | att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1) 242 | 243 | att_board[:, :num_type1, :num_type1, :] = torch.matmul( 244 | att_map[:, :num_type1, :num_type1, :], self.att_weight11) 245 | att_board[:, num_type1:, num_type1:, :] = torch.matmul( 246 | att_map[:, num_type1:, num_type1:, :], self.att_weight22) 247 | att_board[:, :num_type1, num_type1:, :] = torch.matmul( 248 | att_map[:, :num_type1, num_type1:, :], self.att_weight12) 249 | att_board[:, num_type1:, :num_type1, :] = torch.matmul( 250 | att_map[:, num_type1:, :num_type1, :], self.att_weight12) 251 | 252 | att_map = att_board 253 | 254 | # apply temperature 255 | att_map = att_map / self.temp 256 | 257 | att_map = F.softmax(att_map, dim=-2) 258 | 259 | return att_map 260 | 261 | def _project(self, x, att_map): 262 | x1 = 
self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
263 | x2 = self.proj_without_att(x)
264 | 
265 | return x1 + x2
266 | 
267 | def _project_master(self, x, master, att_map):
268 | 
269 | x1 = self.proj_with_attM(torch.matmul(
270 | att_map.squeeze(-1).unsqueeze(1), x))
271 | x2 = self.proj_without_attM(master)
272 | 
273 | return x1 + x2
274 | 
275 | def _apply_BN(self, x):
276 | org_size = x.size()
277 | x = x.view(-1, org_size[-1])
278 | x = self.bn(x)
279 | x = x.view(org_size)
280 | 
281 | return x
282 | 
283 | def _init_new_params(self, *size):
284 | out = nn.Parameter(torch.FloatTensor(*size))
285 | nn.init.xavier_normal_(out)
286 | return out
287 | 
288 | 
289 | class GraphPool(nn.Module):
290 | def __init__(self, k: float, in_dim: int, p: Union[float, int]):
291 | super().__init__()
292 | self.k = k
293 | self.sigmoid = nn.Sigmoid()
294 | self.proj = nn.Linear(in_dim, 1)
295 | self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
296 | self.in_dim = in_dim
297 | 
298 | def forward(self, h):
299 | Z = self.drop(h)
300 | weights = self.proj(Z)
301 | scores = self.sigmoid(weights)
302 | new_h = self.top_k_graph(scores, h, self.k)
303 | 
304 | return new_h
305 | 
306 | def top_k_graph(self, scores, h, k):
307 | """
308 | args
309 | =====
310 | scores: attention-based weights (#bs, #node, 1)
311 | h: graph data (#bs, #node, #dim)
312 | k: ratio of nodes to keep (float)
313 | returns
314 | =====
315 | h: graph-pooled data (#bs, #node', #dim)
316 | """
317 | _, n_nodes, n_feat = h.size()
318 | n_nodes = max(int(n_nodes * k), 1)
319 | _, idx = torch.topk(scores, n_nodes, dim=1)
320 | idx = idx.expand(-1, -1, n_feat)
321 | 
322 | h = h * scores
323 | h = torch.gather(h, 1, idx)
324 | 
325 | return h
326 | 
327 | 
328 | class Residual_block(nn.Module):
329 | def __init__(self, nb_filts, first=False):
330 | super().__init__()
331 | self.first = first
332 | 
333 | if not self.first:
334 | self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
335 | self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
336 | out_channels=nb_filts[1],
337 | kernel_size=(2, 3),
338 | padding=(1, 1),
339 | stride=1)
340 | self.selu = nn.SELU(inplace=True)
341 | 
342 | self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
343 | self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
344 | out_channels=nb_filts[1],
345 | kernel_size=(2, 3),
346 | padding=(0, 1),
347 | stride=1)
348 | 
349 | if nb_filts[0] != nb_filts[1]:
350 | self.downsample = True
351 | self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
352 | out_channels=nb_filts[1],
353 | padding=(0, 1),
354 | kernel_size=(1, 3),
355 | stride=1)
356 | 
357 | else:
358 | self.downsample = False
359 | 
360 | def forward(self, x):
361 | identity = x
362 | if not self.first:
363 | out = self.bn1(x)
364 | out = self.selu(out)
365 | else:
366 | out = x
367 | 
368 | # print('out',out.shape)
369 | out = self.conv1(out)  # convolve the pre-activated features so bn1/selu are not bypassed
370 | 
371 | # print('aft conv1 out',out.shape)
372 | out = self.bn2(out)
373 | out = self.selu(out)
374 | # print('out',out.shape)
375 | out = self.conv2(out)
376 | # print('conv2 out',out.shape)
377 | 
378 | if self.downsample:
379 | identity = self.conv_downsample(identity)
380 | 
381 | out += identity
382 | # out = self.mp(out)
383 | return out
384 | 
385 | 
386 | class W2VAASIST(nn.Module):
387 | def __init__(self):
388 | super().__init__()
389 | 
390 | # AASIST parameters
391 | filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
392 | gat_dims = [64, 32]
393 | pool_ratios = [0.5, 0.5, 0.5, 0.5]
394 | temperatures = [2.0, 2.0, 100.0, 100.0]
395 | 
396 | ####
397 | # create 
network wav2vec 2.0 398 | #### 399 | 400 | self.first_bn = nn.BatchNorm2d(num_features=1) 401 | self.first_bn1 = nn.BatchNorm2d(num_features=64) 402 | self.drop = nn.Dropout(0.5, inplace=True) 403 | self.drop_way = nn.Dropout(0.2, inplace=True) 404 | self.selu = nn.SELU(inplace=True) 405 | 406 | # RawNet2 encoder 407 | self.encoder = nn.Sequential( 408 | nn.Sequential(Residual_block(nb_filts=filts[1], first=True)), 409 | nn.Sequential(Residual_block(nb_filts=filts[2])), 410 | nn.Sequential(Residual_block(nb_filts=filts[3])), 411 | nn.Sequential(Residual_block(nb_filts=filts[4])), 412 | nn.Sequential(Residual_block(nb_filts=filts[4])), 413 | nn.Sequential(Residual_block(nb_filts=filts[4]))) 414 | 415 | 416 | self.LL = nn.Linear(1024, 128) 417 | 418 | self.attention = nn.Sequential( 419 | nn.Conv2d(64, 128, kernel_size=(1, 1)), 420 | nn.SELU(inplace=True), 421 | nn.BatchNorm2d(128), 422 | nn.Conv2d(128, 64, kernel_size=(1, 1)), 423 | 424 | ) 425 | # position encoding 426 | self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1])) 427 | 428 | self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0])) 429 | self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0])) 430 | 431 | # Graph module 432 | self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1], 433 | gat_dims[0], 434 | temperature=temperatures[0]) 435 | self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1], 436 | gat_dims[0], 437 | temperature=temperatures[1]) 438 | # HS-GAL layer 439 | self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer( 440 | gat_dims[0], gat_dims[1], temperature=temperatures[2]) 441 | self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer( 442 | gat_dims[1], gat_dims[1], temperature=temperatures[2]) 443 | self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer( 444 | gat_dims[0], gat_dims[1], temperature=temperatures[2]) 445 | self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer( 446 | gat_dims[1], gat_dims[1], temperature=temperatures[2]) 447 | 448 | # Graph pooling layers 449 | self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3) 450 | self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3) 451 | self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 452 | self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 453 | 454 | self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 455 | self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 456 | 457 | self.out_layer = nn.Linear(5 * gat_dims[1], 2) 458 | 459 | def forward(self, x): 460 | # -------pre-trained Wav2vec model fine tunning ------------------------## 461 | x = self.LL(x) 462 | x = x.transpose(1, 2) # (bs,feat_out_dim,frame_number) 463 | x = x.unsqueeze(dim=1) # add channel 464 | # print(x.shape) 465 | x = F.max_pool2d(x, (3, 3)) 466 | x = self.first_bn(x) 467 | x = self.selu(x) 468 | 469 | # RawNet2-based encoder 470 | x = self.encoder(x) 471 | x = self.first_bn1(x) 472 | x = self.selu(x) 473 | 474 | # print(x.shape) 475 | w = self.attention(x) 476 | # print(w.shape) 477 | 478 | # ------------SA for spectral feature-------------# 479 | w1 = F.softmax(w, dim=-1) 480 | m = torch.sum(x * w1, dim=-1) 481 | e_S = m.transpose(1, 2) + self.pos_S 482 | 483 | # graph module layer 484 | gat_S = self.GAT_layer_S(e_S) 485 | out_S = self.pool_S(gat_S) # (#bs, #node, #dim) 486 | 487 | # ------------SA for temporal feature-------------# 488 | w2 = F.softmax(w, dim=-2) 489 | m1 = torch.sum(x * w2, dim=-2) 490 | 491 | e_T = m1.transpose(1, 2) 492 | 493 | # graph module layer 494 | gat_T = self.GAT_layer_T(e_T) 495 | out_T = self.pool_T(gat_T) 496 | 497 | # 
learnable master node 498 | master1 = self.master1.expand(x.size(0), -1, -1) 499 | master2 = self.master2.expand(x.size(0), -1, -1) 500 | 501 | # inference 1 502 | out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11( 503 | out_T, out_S, master=self.master1) 504 | 505 | out_S1 = self.pool_hS1(out_S1) 506 | out_T1 = self.pool_hT1(out_T1) 507 | 508 | out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12( 509 | out_T1, out_S1, master=master1) 510 | out_T1 = out_T1 + out_T_aug 511 | out_S1 = out_S1 + out_S_aug 512 | master1 = master1 + master_aug 513 | 514 | # inference 2 515 | out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21( 516 | out_T, out_S, master=self.master2) 517 | out_S2 = self.pool_hS2(out_S2) 518 | out_T2 = self.pool_hT2(out_T2) 519 | 520 | out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22( 521 | out_T2, out_S2, master=master2) 522 | out_T2 = out_T2 + out_T_aug 523 | out_S2 = out_S2 + out_S_aug 524 | master2 = master2 + master_aug 525 | 526 | out_T1 = self.drop_way(out_T1) 527 | out_T2 = self.drop_way(out_T2) 528 | out_S1 = self.drop_way(out_S1) 529 | out_S2 = self.drop_way(out_S2) 530 | master1 = self.drop_way(master1) 531 | master2 = self.drop_way(master2) 532 | 533 | out_T = torch.max(out_T1, out_T2) 534 | out_S = torch.max(out_S1, out_S2) 535 | master = torch.max(master1, master2) 536 | 537 | # Readout operation 538 | T_max, _ = torch.max(torch.abs(out_T), dim=1) 539 | T_avg = torch.mean(out_T, dim=1) 540 | 541 | S_max, _ = torch.max(torch.abs(out_S), dim=1) 542 | S_avg = torch.mean(out_S, dim=1) 543 | 544 | last_hidden = torch.cat( 545 | [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1) 546 | 547 | last_hidden = self.drop(last_hidden) 548 | output = self.out_layer(last_hidden) 549 | 550 | return output, last_hidden 551 | 552 | 553 | if __name__ == "__main__": 554 | # os.environ["CUDA_VISIBLE_DEVICES"] = "5" 555 | print(summary(W2VAASIST(), torch.randn((16,201,1024)), show_input=False)) 556 | # model = W2VAASIST() 557 | # op,hid = model(torch.randn((16,201,1024))) 558 | # print(op.shape) 559 | # print(hid.shape) 560 | -------------------------------------------------------------------------------- /utils/ideas/MoEF/aasist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch.autograd import Function 6 | import os 7 | from torch import Tensor 8 | import numpy as np 9 | from torch.utils import data 10 | from collections import OrderedDict 11 | from torch.nn.parameter import Parameter 12 | from pytorch_model_summary import summary 13 | import math 14 | from typing import Union 15 | 16 | 17 | class GraphAttentionLayer(nn.Module): 18 | def __init__(self, in_dim, out_dim, **kwargs): 19 | super().__init__() 20 | 21 | # attention map 22 | self.att_proj = nn.Linear(in_dim, out_dim) 23 | self.att_weight = self._init_new_params(out_dim, 1) 24 | 25 | # project 26 | self.proj_with_att = nn.Linear(in_dim, out_dim) 27 | self.proj_without_att = nn.Linear(in_dim, out_dim) 28 | 29 | # batch norm 30 | self.bn = nn.BatchNorm1d(out_dim) 31 | 32 | # dropout for inputs 33 | self.input_drop = nn.Dropout(p=0.2) 34 | 35 | # activate 36 | self.act = nn.SELU(inplace=True) 37 | 38 | # temperature 39 | self.temp = 1. 
40 | if "temperature" in kwargs: 41 | self.temp = kwargs["temperature"] 42 | 43 | def forward(self, x): 44 | ''' 45 | x :(#bs, #node, #dim) 46 | ''' 47 | # apply input dropout 48 | x = self.input_drop(x) 49 | 50 | # derive attention map 51 | att_map = self._derive_att_map(x) 52 | 53 | # projection 54 | x = self._project(x, att_map) 55 | 56 | # apply batch norm 57 | x = self._apply_BN(x) 58 | x = self.act(x) 59 | return x 60 | 61 | def _pairwise_mul_nodes(self, x): 62 | ''' 63 | Calculates pairwise multiplication of nodes. 64 | - for attention map 65 | x :(#bs, #node, #dim) 66 | out_shape :(#bs, #node, #node, #dim) 67 | ''' 68 | 69 | nb_nodes = x.size(1) 70 | x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1) 71 | x_mirror = x.transpose(1, 2) 72 | 73 | return x * x_mirror 74 | 75 | def _derive_att_map(self, x): 76 | ''' 77 | x :(#bs, #node, #dim) 78 | out_shape :(#bs, #node, #node, 1) 79 | ''' 80 | att_map = self._pairwise_mul_nodes(x) 81 | # size: (#bs, #node, #node, #dim_out) 82 | att_map = torch.tanh(self.att_proj(att_map)) 83 | # size: (#bs, #node, #node, 1) 84 | att_map = torch.matmul(att_map, self.att_weight) 85 | 86 | # apply temperature 87 | att_map = att_map / self.temp 88 | 89 | att_map = F.softmax(att_map, dim=-2) 90 | 91 | return att_map 92 | 93 | def _project(self, x, att_map): 94 | x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x)) 95 | x2 = self.proj_without_att(x) 96 | 97 | return x1 + x2 98 | 99 | def _apply_BN(self, x): 100 | org_size = x.size() 101 | x = x.view(-1, org_size[-1]) 102 | x = self.bn(x) 103 | x = x.view(org_size) 104 | 105 | return x 106 | 107 | def _init_new_params(self, *size): 108 | out = nn.Parameter(torch.FloatTensor(*size)) 109 | nn.init.xavier_normal_(out) 110 | return out 111 | 112 | 113 | class HtrgGraphAttentionLayer(nn.Module): 114 | def __init__(self, in_dim, out_dim, **kwargs): 115 | super().__init__() 116 | 117 | self.proj_type1 = nn.Linear(in_dim, in_dim) 118 | self.proj_type2 = nn.Linear(in_dim, in_dim) 119 | 120 | # attention map 121 | self.att_proj = nn.Linear(in_dim, out_dim) 122 | self.att_projM = nn.Linear(in_dim, out_dim) 123 | 124 | self.att_weight11 = self._init_new_params(out_dim, 1) 125 | self.att_weight22 = self._init_new_params(out_dim, 1) 126 | self.att_weight12 = self._init_new_params(out_dim, 1) 127 | self.att_weightM = self._init_new_params(out_dim, 1) 128 | 129 | # project 130 | self.proj_with_att = nn.Linear(in_dim, out_dim) 131 | self.proj_without_att = nn.Linear(in_dim, out_dim) 132 | 133 | self.proj_with_attM = nn.Linear(in_dim, out_dim) 134 | self.proj_without_attM = nn.Linear(in_dim, out_dim) 135 | 136 | # batch norm 137 | self.bn = nn.BatchNorm1d(out_dim) 138 | 139 | # dropout for inputs 140 | self.input_drop = nn.Dropout(p=0.2) 141 | 142 | # activate 143 | self.act = nn.SELU(inplace=True) 144 | 145 | # temperature 146 | self.temp = 1. 
147 | if "temperature" in kwargs: 148 | self.temp = kwargs["temperature"] 149 | 150 | def forward(self, x1, x2, master=None): 151 | ''' 152 | x1 :(#bs, #node, #dim) 153 | x2 :(#bs, #node, #dim) 154 | ''' 155 | # print('x1',x1.shape) 156 | # print('x2',x2.shape) 157 | num_type1 = x1.size(1) 158 | num_type2 = x2.size(1) 159 | # print('num_type1',num_type1) 160 | # print('num_type2',num_type2) 161 | x1 = self.proj_type1(x1) 162 | # print('proj_type1',x1.shape) 163 | x2 = self.proj_type2(x2) 164 | # print('proj_type2',x2.shape) 165 | x = torch.cat([x1, x2], dim=1) 166 | # print('Concat x1 and x2',x.shape) 167 | 168 | if master is None: 169 | master = torch.mean(x, dim=1, keepdim=True) 170 | # print('master',master.shape) 171 | # apply input dropout 172 | x = self.input_drop(x) 173 | 174 | # derive attention map 175 | att_map = self._derive_att_map(x, num_type1, num_type2) 176 | # print('master',master.shape) 177 | # directional edge for master node 178 | master = self._update_master(x, master) 179 | # print('master',master.shape) 180 | # projection 181 | x = self._project(x, att_map) 182 | # print('proj x',x.shape) 183 | # apply batch norm 184 | x = self._apply_BN(x) 185 | x = self.act(x) 186 | 187 | x1 = x.narrow(1, 0, num_type1) 188 | # print('x1',x1.shape) 189 | x2 = x.narrow(1, num_type1, num_type2) 190 | # print('x2',x2.shape) 191 | return x1, x2, master 192 | 193 | def _update_master(self, x, master): 194 | 195 | att_map = self._derive_att_map_master(x, master) 196 | master = self._project_master(x, master, att_map) 197 | 198 | return master 199 | 200 | def _pairwise_mul_nodes(self, x): 201 | ''' 202 | Calculates pairwise multiplication of nodes. 203 | - for attention map 204 | x :(#bs, #node, #dim) 205 | out_shape :(#bs, #node, #node, #dim) 206 | ''' 207 | 208 | nb_nodes = x.size(1) 209 | x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1) 210 | x_mirror = x.transpose(1, 2) 211 | 212 | return x * x_mirror 213 | 214 | def _derive_att_map_master(self, x, master): 215 | ''' 216 | x :(#bs, #node, #dim) 217 | out_shape :(#bs, #node, #node, 1) 218 | ''' 219 | att_map = x * master 220 | att_map = torch.tanh(self.att_projM(att_map)) 221 | 222 | att_map = torch.matmul(att_map, self.att_weightM) 223 | 224 | # apply temperature 225 | att_map = att_map / self.temp 226 | 227 | att_map = F.softmax(att_map, dim=-2) 228 | 229 | return att_map 230 | 231 | def _derive_att_map(self, x, num_type1, num_type2): 232 | ''' 233 | x :(#bs, #node, #dim) 234 | out_shape :(#bs, #node, #node, 1) 235 | ''' 236 | att_map = self._pairwise_mul_nodes(x) 237 | # size: (#bs, #node, #node, #dim_out) 238 | att_map = torch.tanh(self.att_proj(att_map)) 239 | # size: (#bs, #node, #node, 1) 240 | 241 | att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1) 242 | 243 | att_board[:, :num_type1, :num_type1, :] = torch.matmul( 244 | att_map[:, :num_type1, :num_type1, :], self.att_weight11) 245 | att_board[:, num_type1:, num_type1:, :] = torch.matmul( 246 | att_map[:, num_type1:, num_type1:, :], self.att_weight22) 247 | att_board[:, :num_type1, num_type1:, :] = torch.matmul( 248 | att_map[:, :num_type1, num_type1:, :], self.att_weight12) 249 | att_board[:, num_type1:, :num_type1, :] = torch.matmul( 250 | att_map[:, num_type1:, :num_type1, :], self.att_weight12) 251 | 252 | att_map = att_board 253 | 254 | # apply temperature 255 | att_map = att_map / self.temp 256 | 257 | att_map = F.softmax(att_map, dim=-2) 258 | 259 | return att_map 260 | 261 | def _project(self, x, att_map): 262 | x1 = 
self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
263 | x2 = self.proj_without_att(x)
264 | 
265 | return x1 + x2
266 | 
267 | def _project_master(self, x, master, att_map):
268 | 
269 | x1 = self.proj_with_attM(torch.matmul(
270 | att_map.squeeze(-1).unsqueeze(1), x))
271 | x2 = self.proj_without_attM(master)
272 | 
273 | return x1 + x2
274 | 
275 | def _apply_BN(self, x):
276 | org_size = x.size()
277 | x = x.view(-1, org_size[-1])
278 | x = self.bn(x)
279 | x = x.view(org_size)
280 | 
281 | return x
282 | 
283 | def _init_new_params(self, *size):
284 | out = nn.Parameter(torch.FloatTensor(*size))
285 | nn.init.xavier_normal_(out)
286 | return out
287 | 
288 | 
289 | class GraphPool(nn.Module):
290 | def __init__(self, k: float, in_dim: int, p: Union[float, int]):
291 | super().__init__()
292 | self.k = k
293 | self.sigmoid = nn.Sigmoid()
294 | self.proj = nn.Linear(in_dim, 1)
295 | self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
296 | self.in_dim = in_dim
297 | 
298 | def forward(self, h):
299 | Z = self.drop(h)
300 | weights = self.proj(Z)
301 | scores = self.sigmoid(weights)
302 | new_h = self.top_k_graph(scores, h, self.k)
303 | 
304 | return new_h
305 | 
306 | def top_k_graph(self, scores, h, k):
307 | """
308 | args
309 | =====
310 | scores: attention-based weights (#bs, #node, 1)
311 | h: graph data (#bs, #node, #dim)
312 | k: ratio of nodes to keep (float)
313 | returns
314 | =====
315 | h: graph-pooled data (#bs, #node', #dim)
316 | """
317 | _, n_nodes, n_feat = h.size()
318 | n_nodes = max(int(n_nodes * k), 1)
319 | _, idx = torch.topk(scores, n_nodes, dim=1)
320 | idx = idx.expand(-1, -1, n_feat)
321 | 
322 | h = h * scores
323 | h = torch.gather(h, 1, idx)
324 | 
325 | return h
326 | 
327 | 
328 | class Residual_block(nn.Module):
329 | def __init__(self, nb_filts, first=False):
330 | super().__init__()
331 | self.first = first
332 | 
333 | if not self.first:
334 | self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
335 | self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
336 | out_channels=nb_filts[1],
337 | kernel_size=(2, 3),
338 | padding=(1, 1),
339 | stride=1)
340 | self.selu = nn.SELU(inplace=True)
341 | 
342 | self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
343 | self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
344 | out_channels=nb_filts[1],
345 | kernel_size=(2, 3),
346 | padding=(0, 1),
347 | stride=1)
348 | 
349 | if nb_filts[0] != nb_filts[1]:
350 | self.downsample = True
351 | self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
352 | out_channels=nb_filts[1],
353 | padding=(0, 1),
354 | kernel_size=(1, 3),
355 | stride=1)
356 | 
357 | else:
358 | self.downsample = False
359 | 
360 | def forward(self, x):
361 | identity = x
362 | if not self.first:
363 | out = self.bn1(x)
364 | out = self.selu(out)
365 | else:
366 | out = x
367 | 
368 | # print('out',out.shape)
369 | out = self.conv1(out)  # convolve the pre-activated features so bn1/selu are not bypassed
370 | 
371 | # print('aft conv1 out',out.shape)
372 | out = self.bn2(out)
373 | out = self.selu(out)
374 | # print('out',out.shape)
375 | out = self.conv2(out)
376 | # print('conv2 out',out.shape)
377 | 
378 | if self.downsample:
379 | identity = self.conv_downsample(identity)
380 | 
381 | out += identity
382 | # out = self.mp(out)
383 | return out
384 | 
385 | 
386 | class W2VAASIST(nn.Module):
387 | def __init__(self):
388 | super().__init__()
389 | 
390 | # AASIST parameters
391 | filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
392 | gat_dims = [64, 32]
393 | pool_ratios = [0.5, 0.5, 0.5, 0.5]
394 | temperatures = [2.0, 2.0, 100.0, 100.0]
395 | 
396 | ####
397 | # create 
network wav2vec 2.0 398 | #### 399 | 400 | self.first_bn = nn.BatchNorm2d(num_features=1) 401 | self.first_bn1 = nn.BatchNorm2d(num_features=64) 402 | self.drop = nn.Dropout(0.5, inplace=True) 403 | self.drop_way = nn.Dropout(0.2, inplace=True) 404 | self.selu = nn.SELU(inplace=True) 405 | 406 | # RawNet2 encoder 407 | self.encoder = nn.Sequential( 408 | nn.Sequential(Residual_block(nb_filts=filts[1], first=True)), 409 | nn.Sequential(Residual_block(nb_filts=filts[2])), 410 | nn.Sequential(Residual_block(nb_filts=filts[3])), 411 | nn.Sequential(Residual_block(nb_filts=filts[4])), 412 | nn.Sequential(Residual_block(nb_filts=filts[4])), 413 | nn.Sequential(Residual_block(nb_filts=filts[4]))) 414 | 415 | 416 | self.LL = nn.Linear(1024, 128) 417 | 418 | self.attention = nn.Sequential( 419 | nn.Conv2d(64, 128, kernel_size=(1, 1)), 420 | nn.SELU(inplace=True), 421 | nn.BatchNorm2d(128), 422 | nn.Conv2d(128, 64, kernel_size=(1, 1)), 423 | 424 | ) 425 | # position encoding 426 | self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1])) 427 | 428 | self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0])) 429 | self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0])) 430 | 431 | # Graph module 432 | self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1], 433 | gat_dims[0], 434 | temperature=temperatures[0]) 435 | self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1], 436 | gat_dims[0], 437 | temperature=temperatures[1]) 438 | # HS-GAL layer 439 | self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer( 440 | gat_dims[0], gat_dims[1], temperature=temperatures[2]) 441 | self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer( 442 | gat_dims[1], gat_dims[1], temperature=temperatures[2]) 443 | self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer( 444 | gat_dims[0], gat_dims[1], temperature=temperatures[2]) 445 | self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer( 446 | gat_dims[1], gat_dims[1], temperature=temperatures[2]) 447 | 448 | # Graph pooling layers 449 | self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3) 450 | self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3) 451 | self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 452 | self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 453 | 454 | self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 455 | self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 456 | 457 | self.out_layer = nn.Linear(5 * gat_dims[1], 2) 458 | 459 | def forward(self, x): 460 | # -------pre-trained Wav2vec model fine tunning ------------------------## 461 | x = self.LL(x) 462 | x = x.transpose(1, 2) # (bs,feat_out_dim,frame_number) 463 | x = x.unsqueeze(dim=1) # add channel 464 | # print(x.shape) 465 | x = F.max_pool2d(x, (3, 3)) 466 | x = self.first_bn(x) 467 | x = self.selu(x) 468 | 469 | # RawNet2-based encoder 470 | x = self.encoder(x) 471 | x = self.first_bn1(x) 472 | x = self.selu(x) 473 | 474 | # print(x.shape) 475 | w = self.attention(x) 476 | # print(w.shape) 477 | 478 | # ------------SA for spectral feature-------------# 479 | w1 = F.softmax(w, dim=-1) 480 | m = torch.sum(x * w1, dim=-1) 481 | e_S = m.transpose(1, 2) + self.pos_S 482 | 483 | # graph module layer 484 | gat_S = self.GAT_layer_S(e_S) 485 | out_S = self.pool_S(gat_S) # (#bs, #node, #dim) 486 | 487 | # ------------SA for temporal feature-------------# 488 | w2 = F.softmax(w, dim=-2) 489 | m1 = torch.sum(x * w2, dim=-2) 490 | 491 | e_T = m1.transpose(1, 2) 492 | 493 | # graph module layer 494 | gat_T = self.GAT_layer_T(e_T) 495 | out_T = self.pool_T(gat_T) 496 | 497 | # 
learnable master node 498 | master1 = self.master1.expand(x.size(0), -1, -1) 499 | master2 = self.master2.expand(x.size(0), -1, -1) 500 | 501 | # inference 1 502 | out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11( 503 | out_T, out_S, master=self.master1) 504 | 505 | out_S1 = self.pool_hS1(out_S1) 506 | out_T1 = self.pool_hT1(out_T1) 507 | 508 | out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12( 509 | out_T1, out_S1, master=master1) 510 | out_T1 = out_T1 + out_T_aug 511 | out_S1 = out_S1 + out_S_aug 512 | master1 = master1 + master_aug 513 | 514 | # inference 2 515 | out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21( 516 | out_T, out_S, master=self.master2) 517 | out_S2 = self.pool_hS2(out_S2) 518 | out_T2 = self.pool_hT2(out_T2) 519 | 520 | out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22( 521 | out_T2, out_S2, master=master2) 522 | out_T2 = out_T2 + out_T_aug 523 | out_S2 = out_S2 + out_S_aug 524 | master2 = master2 + master_aug 525 | 526 | out_T1 = self.drop_way(out_T1) 527 | out_T2 = self.drop_way(out_T2) 528 | out_S1 = self.drop_way(out_S1) 529 | out_S2 = self.drop_way(out_S2) 530 | master1 = self.drop_way(master1) 531 | master2 = self.drop_way(master2) 532 | 533 | out_T = torch.max(out_T1, out_T2) 534 | out_S = torch.max(out_S1, out_S2) 535 | master = torch.max(master1, master2) 536 | 537 | # Readout operation 538 | T_max, _ = torch.max(torch.abs(out_T), dim=1) 539 | T_avg = torch.mean(out_T, dim=1) 540 | 541 | S_max, _ = torch.max(torch.abs(out_S), dim=1) 542 | S_avg = torch.mean(out_S, dim=1) 543 | 544 | last_hidden = torch.cat( 545 | [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1) 546 | 547 | last_hidden = self.drop(last_hidden) 548 | output = self.out_layer(last_hidden) 549 | 550 | return output, last_hidden 551 | 552 | 553 | if __name__ == "__main__": 554 | # os.environ["CUDA_VISIBLE_DEVICES"] = "5" 555 | print(summary(W2VAASIST(), torch.randn((16,201,1024)), show_input=False)) 556 | # model = W2VAASIST() 557 | # op,hid = model(torch.randn((16,201,1024))) 558 | # print(op.shape) 559 | # print(hid.shape) 560 | 561 | print(sum(p.numel() for p in W2VAASIST().parameters() if p.requires_grad)/1000000) 562 | 563 | 564 | -------------------------------------------------------------------------------- /utils/ideas/MoEF/moef.py: -------------------------------------------------------------------------------- 1 | # Sparsely-Gated Mixture-of-Experts Layers. 2 | # See "Outrageously Large Neural Networks" 3 | # https://arxiv.org/abs/1701.06538 4 | # 5 | # Author: David Rau 6 | # 7 | # The code is based on the TensorFlow implementation: 8 | # https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py 9 | 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.distributions.normal import Normal 14 | import numpy as np 15 | 16 | 17 | class SparseDispatcher(object): 18 | """Helper for implementing a mixture of experts. 19 | The purpose of this class is to create input minibatches for the 20 | experts and to combine the results of the experts to form a unified 21 | output tensor. 22 | There are two functions: 23 | dispatch - take an input Tensor and create input Tensors for each expert. 24 | combine - take output Tensors from each expert and form a combined output 25 | Tensor. Outputs from different experts for the same batch element are 26 | summed together, weighted by the provided "gates". 
27 | The class is initialized with a "gates" Tensor, which specifies which 28 | batch elements go to which experts, and the weights to use when combining 29 | the outputs. Batch element b is sent to expert e iff gates[b, e] != 0. 30 | The inputs and outputs are all two-dimensional [batch, depth]. 31 | Caller is responsible for collapsing additional dimensions prior to 32 | calling this class and reshaping the output to the original shape. 33 | See common_layers.reshape_like(). 34 | Example use: 35 | gates: a float32 `Tensor` with shape `[batch_size, num_experts]` 36 | inputs: a float32 `Tensor` with shape `[batch_size, input_size]` 37 | experts: a list of length `num_experts` containing sub-networks. 38 | dispatcher = SparseDispatcher(num_experts, gates) 39 | expert_inputs = dispatcher.dispatch(inputs) 40 | expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)] 41 | outputs = dispatcher.combine(expert_outputs) 42 | The preceding code sets the output for a particular example b to: 43 | output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])) 44 | This class takes advantage of sparsity in the gate matrix by including in the 45 | `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. 46 | """ 47 | 48 | def __init__(self, num_experts, gates): 49 | """Create a SparseDispatcher.""" 50 | 51 | self._gates = gates 52 | self._num_experts = num_experts 53 | # sort experts 54 | sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0) 55 | # drop indices 56 | _, self._expert_index = sorted_experts.split(1, dim=1) 57 | # get according batch index for each expert 58 | self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0] 59 | # calculate num samples that each expert gets 60 | self._part_sizes = (gates > 0).sum(0).tolist() 61 | # expand gates to match with self._batch_index 62 | gates_exp = gates[self._batch_index.flatten()] 63 | self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index) 64 | 65 | def dispatch(self, inp): 66 | """Create one input Tensor for each expert. 67 | The `Tensor` for a expert `i` contains the slices of `inp` corresponding 68 | to the batch elements `b` where `gates[b, i] > 0`. 69 | Args: 70 | inp: a `Tensor` of shape "[batch_size, ]` 71 | Returns: 72 | a list of `num_experts` `Tensor`s with shapes 73 | `[expert_batch_size_i, ]`. 74 | """ 75 | 76 | # assigns samples to experts whose gate is nonzero 77 | 78 | # expand according to batch index so we can just split by _part_sizes 79 | inp_exp = inp[self._batch_index].squeeze(1) 80 | return torch.split(inp_exp, self._part_sizes, dim=0) 81 | 82 | def combine(self, expert_out, expert_out_tir, expert_c, expert_d, multiply_by_gates=True): 83 | """Sum together the expert output, weighted by the gates. 84 | The slice corresponding to a particular batch element `b` is computed 85 | as the sum over all experts `i` of the expert output, weighted by the 86 | corresponding gate values. If `multiply_by_gates` is set to False, the 87 | gate values are ignored. 88 | Args: 89 | expert_out: a list of `num_experts` `Tensor`s, each with shape 90 | `[expert_batch_size_i, ]`. 91 | multiply_by_gates: a boolean 92 | Returns: 93 | a `Tensor` with shape `[batch_size, ]`. 
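(The weighted sum is carried out in linear space and the result is
returned in log space: output[b] = log(sum_i gates[b, i] * exp(expert_out_i[b])).)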
94 | """ 95 | # apply exp to expert outputs, so we are not longer in log space 96 | stitched_rgb = torch.cat(expert_out, 0).exp() 97 | stitched_tir = torch.cat(expert_out_tir, 0).exp() 98 | expert_c = torch.cat(expert_c, 0).exp() 99 | expert_d = torch.cat(expert_d, 0).exp() 100 | stitched = torch.cat((stitched_rgb, stitched_tir,expert_c,expert_d), 0) 101 | 102 | if multiply_by_gates: 103 | # stitched = stitched.mul(self._nonzero_gates.unsqueeze(1).unsqueeze(1)) 104 | stitched = self._nonzero_gates.mul(stitched) 105 | # zeros = torch.zeros((self._gates.size(0), expert_out[-1].size(1), expert_out[-1].size(2), expert_out[-1].size(3)), requires_grad=True, device=stitched.device) 106 | zeros = torch.zeros((self._gates.size(0), expert_out[-1].size(1)), requires_grad=True, device=stitched.device) 107 | 108 | # combine samples that have been processed by the same k experts 109 | combined = zeros.index_add(0, self._batch_index, stitched.float()) 110 | # add eps to all zero values in order to avoid nans when going back to log space 111 | combined[combined == 0] = np.finfo(float).eps 112 | # back to log space 113 | return combined.log() 114 | 115 | def combine_tuple(self, expert_op24, multiply_by_gates=True): 116 | """Sum together the expert output, weighted by the gates. 117 | The slice corresponding to a particular batch element `b` is computed 118 | as the sum over all experts `i` of the expert output, weighted by the 119 | corresponding gate values. If `multiply_by_gates` is set to False, the 120 | gate values are ignored. 121 | Args: 122 | expert_out: a list of `num_experts` `Tensor`s, each with shape 123 | `[expert_batch_size_i, ]`. 124 | multiply_by_gates: a boolean 125 | Returns: 126 | a `Tensor` with shape `[batch_size, ]`. 127 | """ 128 | # apply exp to expert outputs, so we are not longer in log space 129 | stitched_all = [] 130 | for i in range(24): 131 | stitched_all.append(torch.cat(expert_op24[i], 0).exp()) 132 | stitched = torch.cat(stitched_all, 0) 133 | 134 | if multiply_by_gates: 135 | # stitched = stitched.mul(self._nonzero_gates.unsqueeze(1).unsqueeze(1)) 136 | stitched = self._nonzero_gates.mul(stitched) 137 | # zeros = torch.zeros((self._gates.size(0), expert_out[-1].size(1), expert_out[-1].size(2), expert_out[-1].size(3)), requires_grad=True, device=stitched.device) 138 | zeros = torch.zeros((self._gates.size(0), expert_op24[0][-1].size(1)), requires_grad=True, device=stitched.device) 139 | 140 | # combine samples that have been processed by the same k experts 141 | combined = zeros.index_add(0, self._batch_index, stitched.float()) 142 | # add eps to all zero values in order to avoid nans when going back to log space 143 | combined[combined == 0] = np.finfo(float).eps 144 | # back to log space 145 | return combined.log() 146 | 147 | def expert_to_gates(self): 148 | """Gate values corresponding to the examples in the per-expert `Tensor`s. 
149 | Returns:
150 | a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32`
151 | and shapes `[expert_batch_size_i]`
152 | """
153 | # split nonzero gates for each expert
154 | return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
155 | 
156 | class ConvLayer(nn.Module):
157 | def __init__(self, input_size, output_size, hidden_size, kernel_size=3, stride=1,
158 | padding=1):
159 | super(ConvLayer, self).__init__()
160 | self.conv1 = nn.Linear(input_size, hidden_size)
161 | self.conv2 = nn.Linear(hidden_size, output_size)
162 | # self.conv1 = nn.Conv2d(input_size, hidden_size, kernel_size, stride, padding)
163 | # self.conv2 = nn.Conv2d(hidden_size, output_size, kernel_size, stride, padding)
164 | self.relu = nn.ReLU()
165 | 
166 | def forward(self, x):
167 | out = self.conv1(x)
168 | out = self.relu(out)
169 | out = self.conv2(out)
170 | return out
171 | 
172 | class MoELocal(nn.Module):
173 | 
174 | """Call a sparsely gated mixture-of-experts layer with one-hidden-layer feed-forward networks as experts.
175 | Args:
176 | input_size: integer - size of the input
177 | output_size: integer - size of the output
178 | num_experts: an integer - number of experts
179 | hidden_size: an integer - hidden size of the experts
180 | noisy_gating: a boolean
181 | k: an integer - how many experts to use for each batch element
182 | """
183 | 
184 | def __init__(self, ds_inputsize, input_size, output_size, num_experts, hidden_size, noisy_gating=True, k=4, trainingmode=True):
185 | super(MoELocal, self).__init__()
186 | self.noisy_gating = noisy_gating
187 | self.num_experts = num_experts
188 | self.output_size = output_size
189 | self.input_size = input_size
190 | self.hidden_size = hidden_size
191 | self.training = trainingmode  # note: shadows nn.Module.training
192 | self.k = k
193 | # instantiate experts
194 | self.experts = nn.ModuleList([ConvLayer(self.input_size, self.output_size, self.hidden_size) for i in range(self.num_experts)])
195 | self.w_gate = nn.Parameter(torch.zeros(ds_inputsize, num_experts), requires_grad=True)
196 | self.w_noise = nn.Parameter(torch.zeros(ds_inputsize, num_experts), requires_grad=True)
197 | 
198 | self.softplus = nn.Softplus()
199 | self.softmax = nn.Softmax(1)
200 | self.register_buffer("mean", torch.tensor([0.0]))
201 | self.register_buffer("std", torch.tensor([1.0]))
202 | assert self.k <= self.num_experts
203 | 
204 | def cv_squared(self, x):
205 | """The squared coefficient of variation of a sample.
206 | Useful as a loss to encourage a positive distribution to be more uniform.
207 | Epsilons added for numerical stability.
208 | Returns 0 for an empty Tensor.
209 | Args:
210 | x: a `Tensor`.
211 | Returns:
212 | a `Scalar`.
213 | """
214 | eps = 1e-10
215 | # a single element has no variance, so return 0 (e.g. when num_experts = 1)
216 | 
217 | if x.shape[0] == 1:
218 | return torch.tensor([0], device=x.device, dtype=x.dtype)
219 | return x.float().var() / (x.float().mean()**2 + eps)
220 | 
221 | def _gates_to_load(self, gates):
222 | """Compute the true load per expert, given the gates.
223 | The load is the number of examples for which the corresponding gate is >0.
224 | Args:
225 | gates: a `Tensor` of shape [batch_size, n]
226 | Returns:
227 | a float32 `Tensor` of shape [n]
228 | """
229 | return (gates > 0).sum(0)
230 | 
231 | def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
232 | """Helper function to NoisyTopKGating.
233 | Computes the probability that value is in top k, given different random noise. 
234 | This gives us a way of backpropagating from a loss that balances the number 235 | of times each expert is in the top k experts per example. 236 | In the case of no noise, pass in None for noise_stddev, and the result will 237 | not be differentiable. 238 | Args: 239 | clean_values: a `Tensor` of shape [batch, n]. 240 | noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus 241 | normally distributed noise with standard deviation noise_stddev. 242 | noise_stddev: a `Tensor` of shape [batch, n], or None 243 | noisy_top_values: a `Tensor` of shape [batch, m]. 244 | "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 245 | Returns: 246 | a `Tensor` of shape [batch, n]. 247 | """ 248 | batch = clean_values.size(0) 249 | m = noisy_top_values.size(1) 250 | top_values_flat = noisy_top_values.flatten() 251 | 252 | threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k 253 | threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) 254 | is_in = torch.gt(noisy_values, threshold_if_in) 255 | threshold_positions_if_out = threshold_positions_if_in - 1 256 | threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) 257 | # is each value currently in the top k. 258 | normal = Normal(self.mean, self.std) 259 | prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev) 260 | prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev) 261 | prob = torch.where(is_in, prob_if_in, prob_if_out) 262 | return prob 263 | 264 | def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2): 265 | """Noisy top-k gating. 266 | See paper: https://arxiv.org/abs/1701.06538. 267 | Args: 268 | x: input Tensor with shape [batch_size, input_size] 269 | train: a boolean - we only add noise at training time. 270 | noise_epsilon: a float 271 | Returns: 272 | gates: a Tensor with shape [batch_size, num_experts] 273 | load: a Tensor with shape [num_experts] 274 | """ 275 | clean_logits = x @ self.w_gate 276 | if self.noisy_gating and train: 277 | raw_noise_stddev = x @ self.w_noise 278 | noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon)) 279 | noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) 280 | logits = noisy_logits 281 | else: 282 | logits = clean_logits 283 | # calculate topk + 1 that will be needed for the noisy gates 284 | top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1) 285 | top_k_logits = top_logits[:, :self.k] 286 | top_k_indices = top_indices[:, :self.k] 287 | top_k_gates = self.softmax(top_k_logits) 288 | 289 | zeros = torch.zeros_like(logits, requires_grad=True) 290 | gates = zeros.scatter(1, top_k_indices, top_k_gates) 291 | 292 | if self.noisy_gating and self.k < self.num_experts and train: 293 | load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0) 294 | else: 295 | load = self._gates_to_load(gates) 296 | return gates, load 297 | 298 | def forward(self, x_ds, x_a, x_b, x_c, x_d, training = False): 299 | """Args: 300 | x: tensor shape [batch_size, input_size] # x.shape torch.Size([6, 32, 512, 640]) 301 | train: a boolean scalar. # x_ds.shape torch.Size([6, 163840]) 302 | loss_coef: a scalar - multiplier on load-balancing losses 303 | 304 | Returns: 305 | y: a tensor with shape [batch_size, output_size]. 306 | extra_training_loss: a scalar. This should be added into the overall 307 | training loss of the model. 
The backpropagation of this loss 308 | encourages all experts to be approximately equally used across a batch. 309 | """ 310 | gates, load = self.noisy_top_k_gating(x_ds, training) 311 | # calculate importance loss 312 | importance = gates.sum(0) 313 | 314 | loss = self.cv_squared(importance) + self.cv_squared(load) 315 | # loss *= loss_coef 316 | 317 | dispatcher = SparseDispatcher(self.num_experts, gates) 318 | 319 | expert_inputs_a = dispatcher.dispatch(x_a) 320 | expert_inputs_b = dispatcher.dispatch(x_b) 321 | expert_inputs_c = dispatcher.dispatch(x_c) 322 | expert_inputs_d = dispatcher.dispatch(x_d) 323 | 324 | gates = dispatcher.expert_to_gates() 325 | expert_outputs_a = [self.experts[i](expert_inputs_a[i]) for i in range(self.num_experts//4)] 326 | expert_outputs_b = [self.experts[i](expert_inputs_b[i]) for i in range(self.num_experts//4, self.num_experts//4 *2)] 327 | expert_outputs_c = [self.experts[i](expert_inputs_c[i]) for i in range(self.num_experts//4 * 2, self.num_experts//4 *3)] 328 | expert_outputs_d = [self.experts[i](expert_inputs_d[i]) for i in range(self.num_experts//4 * 3, self.num_experts)] 329 | 330 | y = dispatcher.combine(expert_outputs_a, expert_outputs_b,expert_outputs_c,expert_outputs_d) 331 | return y, loss 332 | 333 | 334 | 335 | class MoE24fusion(MoELocal): 336 | def __init__(self, ds_inputsize, input_size, output_size, num_experts, hidden_size, noisy_gating=True, k=4, trainingmode=True): 337 | super(MoE24fusion, self).__init__(ds_inputsize, input_size, output_size, num_experts, hidden_size, noisy_gating=noisy_gating, k=k, trainingmode=trainingmode) 338 | def forward(self, x_ds, x_tuple24, training = False): 339 | """Args: 340 | x: tensor shape [batch_size, input_size] # x.shape torch.Size([6, 32, 512, 640]) 341 | train: a boolean scalar. # x_ds.shape torch.Size([6, 163840]) 342 | loss_coef: a scalar - multiplier on load-balancing losses 343 | 344 | Returns: 345 | y: a tensor with shape [batch_size, output_size]. 346 | extra_training_loss: a scalar. This should be added into the overall 347 | training loss of the model. The backpropagation of this loss 348 | encourages all experts to be approximately equally used across a batch. 
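Example (illustrative; the 0.01 weight is an assumption, not a value
fixed by this code):
    y, aux_loss = moe(x_ds, x_tuple24, training=True)
    total_loss = task_loss + 0.01 * aux_loss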
349 | """ 350 | gates, load = self.noisy_top_k_gating(x_ds, training) 351 | # calculate importance loss 352 | importance = gates.sum(0) 353 | 354 | loss = self.cv_squared(importance) + self.cv_squared(load) 355 | # loss *= loss_coef 356 | 357 | dispatcher = SparseDispatcher(self.num_experts, gates) 358 | expert_inputs_all = [] 359 | for layer in range(24): 360 | expert_inputs_all.append(dispatcher.dispatch(x_tuple24[layer])) 361 | 362 | gates = dispatcher.expert_to_gates() 363 | expert_outputs_all = [] 364 | for layer in range(24): 365 | expert_outputs_all.append([self.experts[i](expert_inputs_all[layer][i]) for i in range(self.num_experts//24 * (layer), self.num_experts//24 * (layer + 1))]) 366 | 367 | y = dispatcher.combine_tuple(expert_outputs_all) 368 | return y, loss 369 | 370 | 371 | 372 | 373 | if __name__ == "__main__": 374 | moe_l = MoELocal( 375 | ds_inputsize=32,input_size=32,output_size=32, 376 | num_experts=16,hidden_size=64, noisy_gating=True, 377 | k = 2,trainingmode=True 378 | ) 379 | aa = torch.randn((8,32)) 380 | bb = torch.randn((8,32)) 381 | cc = torch.randn((8,32)) 382 | bs,sp = aa.shape 383 | print(bs) 384 | dd = moe_l(aa,bb,cc,bb,cc,True) 385 | # print(512*640) 386 | print(dd[0].view(8, 201, sp).shape) 387 | print(dd[1]) 388 | print(sum(i.numel() for i in moe_l.parameters() if i.requires_grad)/1000000) # 0.97M 389 | 390 | aa = torch.randn((8*201,1024)) 391 | bb = torch.randn((8*201,1024)) 392 | cc = torch.randn((8*201,1024)) 393 | moe_l = MoE24fusion( 394 | ds_inputsize=1024,input_size=1024,output_size=1024, 395 | num_experts=24*4,hidden_size=128, noisy_gating=True, 396 | k = 2,trainingmode=True 397 | ) 398 | dd = moe_l(aa,[bb,cc,bb,cc,bb,cc,bb,cc,bb,cc,bb,cc,bb,cc,bb,cc,bb,cc,bb,cc,bb,cc,bb,cc]) 399 | print(dd[0].view(8, 201, sp).shape) 400 | 401 | -------------------------------------------------------------------------------- /utils/tools/cul_eer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import sys 4 | 5 | 6 | 7 | def eer19only(score_file, label_file,pos=1): 8 | target=[] 9 | nontarget=[] 10 | target_score=[] 11 | nontarget_score=[] 12 | wav_lists=[] 13 | score={} 14 | lable_list={} 15 | with open(label_file, 'r', encoding="utf-8") as f: 16 | lines = f.readlines() 17 | for line in lines: 18 | line = line.strip().split() 19 | if len(line) > 1: 20 | wav_id = line[1] 21 | label=line[4] 22 | if label=="spoof": 23 | nontarget.append(wav_id) 24 | else: 25 | target.append(wav_id) 26 | 27 | with open(score_file, 'r', encoding="utf-8") as f: 28 | lines = f.readlines() 29 | for line in lines: 30 | line = line.strip().split() 31 | if len(line) > 1: 32 | wav_id = line[0] 33 | score[wav_id]=(line[pos]).replace("[","").replace("]","") 34 | for wav_id in target: 35 | target_score.append(float(score[wav_id])) 36 | for wav_id in nontarget: 37 | nontarget_score.append(float(score[wav_id])) 38 | target_score=np.array(target_score) 39 | nontarget_score=np.array(nontarget_score) 40 | eer_cm, _=compute_eer(target_score, nontarget_score) 41 | return eer_cm * 100 42 | 43 | 44 | 45 | def eeronly(score_file, label_file,pos=1): 46 | target=[] 47 | nontarget=[] 48 | target_score=[] 49 | nontarget_score=[] 50 | wav_lists=[] 51 | score={} 52 | lable_list={} 53 | with open(label_file, 'r', encoding="utf-8") as f: 54 | lines = f.readlines() 55 | for line in lines: 56 | line = line.strip().split() 57 | if len(line) > 1: 58 | wav_id = line[2] 59 | label=line[5] 60 | if label=="deepfake": 61 | 
61 | nontarget.append(wav_id)
62 | else:
63 | target.append(wav_id)
64 | 
65 | with open(score_file, 'r', encoding="utf-8") as f:
66 | lines = f.readlines()
67 | for line in lines:
68 | line = line.strip().split()
69 | if len(line) > 1:
70 | wav_id = line[0]
71 | score[wav_id]=(line[pos]).replace("[","").replace("]","")
72 | for wav_id in target:
73 | target_score.append(float(score[wav_id]))
74 | for wav_id in nontarget:
75 | nontarget_score.append(float(score[wav_id]))
76 | target_score=np.array(target_score)
77 | nontarget_score=np.array(nontarget_score)
78 | eer_cm, _=compute_eer(target_score, nontarget_score)
79 | return eer_cm * 100
80 | 
81 | 
82 | 
83 | def get_alltrn_data_kv():
84 | # initialize an empty dict
85 | key_value_dict = {}
86 | 
87 | # open the protocol txt file for reading
88 | with open('/data8/wangzhiyong/project/fakeAudioDetection/lightning_FAD/datasets/asvspoof2019/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt', 'r') as file:
89 | # read the file line by line
90 | for line in file:
91 | # split each line on whitespace and take the second column as the key
92 | key = line.split()[1]
93 | # initialize the value to 0 for every key
94 | value = 0
95 | # add the key-value pair to the dict
96 | key_value_dict[key] = value
97 | return key_value_dict
98 | 
99 | 
100 | def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold):
101 | 
102 | # False alarm and miss rates for ASV
103 | Pfa_asv = sum(non_asv >= asv_threshold) / non_asv.size
104 | Pmiss_asv = sum(tar_asv < asv_threshold) / tar_asv.size
105 | 
106 | # Rate of rejecting spoofs in ASV
107 | if spoof_asv.size == 0:
108 | Pmiss_spoof_asv = None
109 | else:
110 | Pmiss_spoof_asv = np.sum(spoof_asv < asv_threshold) / spoof_asv.size
111 | 
112 | return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv
113 | 
114 | 
115 | def compute_det_curve(target_scores, nontarget_scores):
116 | 
117 | n_scores = target_scores.size + nontarget_scores.size
118 | all_scores = np.concatenate((target_scores, nontarget_scores))
119 | labels = np.concatenate((np.ones(target_scores.size), np.zeros(nontarget_scores.size)))
120 | 
121 | # Sort labels based on scores
122 | indices = np.argsort(all_scores, kind='mergesort')
123 | labels = labels[indices]
124 | 
125 | # Compute false rejection and false acceptance rates
126 | tar_trial_sums = np.cumsum(labels)
127 | nontarget_trial_sums = nontarget_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums)
128 | 
129 | frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / target_scores.size))  # false rejection rates
130 | far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / nontarget_scores.size))  # false acceptance rates
131 | thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))  # Thresholds are the sorted scores
132 | 
133 | return frr, far, thresholds
134 | 
135 | 
136 | def compute_eer(target_scores, nontarget_scores):
137 | """ Returns equal error rate (EER) and the corresponding threshold. """
138 | frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
139 | abs_diffs = np.abs(frr - far)
140 | min_index = np.argmin(abs_diffs)
141 | eer = np.mean((frr[min_index], far[min_index]))
142 | return eer, thresholds[min_index]
143 | 
144 | 
145 | def compute_tDCF(bonafide_score_cm, spoof_score_cm, Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, cost_model, print_cost):
146 | """
147 | Compute Tandem Detection Cost Function (t-DCF) [1] for a fixed ASV system.
148 | In brief, t-DCF returns a detection cost of a cascaded system of this form,
149 | Speech waveform -> [CM] -> [ASV] -> decision
150 | where CM stands for countermeasure and ASV for automatic speaker
151 | verification. The CM is therefore used as a 'gate' to decide whether or 
The CM is therefore used as a 'gate' to decided whether or 152 | not the input speech sample should be passed onwards to the ASV system. 153 | Generally, both CM and ASV can do detection errors. Not all those errors 154 | are necessarily equally cost, and not all types of users are necessarily 155 | equally likely. The tandem t-DCF gives a principled with to compare 156 | different spoofing countermeasures under a detection cost function 157 | framework that takes that information into account. 158 | INPUTS: 159 | bonafide_score_cm A vector of POSITIVE CPASS (bona fide or human) 160 | detection scores obtained by executing a spoofing 161 | countermeasure (CM) on some positive evaluation trials. 162 | trial represents a bona fide case. 163 | spoof_score_cm A vector of NEGATIVE CPASS (spoofing attack) 164 | detection scores obtained by executing a spoofing 165 | CM on some negative evaluation trials. 166 | Pfa_asv False alarm (false acceptance) rate of the ASV 167 | system that is evaluated in tandem with the CM. 168 | Assumed to be in fractions, not percentages. 169 | Pmiss_asv Miss (false rejection) rate of the ASV system that 170 | is evaluated in tandem with the spoofing CM. 171 | Assumed to be in fractions, not percentages. 172 | Pmiss_spoof_asv Miss rate of spoof samples of the ASV system that 173 | is evaluated in tandem with the spoofing CM. That 174 | is, the fraction of spoof samples that were 175 | rejected by the ASV system. 176 | cost_model A struct that contains the parameters of t-DCF, 177 | with the following fields. 178 | Ptar Prior probability of target speaker. 179 | Pnon Prior probability of nontarget speaker (zero-effort impostor) 180 | Psoof Prior probability of spoofing attack. 181 | Cmiss_asv Cost of ASV falsely rejecting target. 182 | Cfa_asv Cost of ASV falsely accepting nontarget. 183 | Cmiss_cm Cost of CM falsely rejecting target. 184 | Cfa_cm Cost of CM falsely accepting spoof. 185 | print_cost Print a summary of the cost parameters and the 186 | implied t-DCF cost function? 187 | OUTPUTS: 188 | tDCF_norm Normalized t-DCF curve across the different CM 189 | system operating points; see [2] for more details. 190 | Normalized t-DCF > 1 indicates a useless 191 | countermeasure (as the tandem system would do 192 | better without it). min(tDCF_norm) will be the 193 | minimum t-DCF used in ASVspoof 2019 [2]. 194 | CM_thresholds Vector of same size as tDCF_norm corresponding to 195 | the CM threshold (operating point). 196 | NOTE: 197 | o In relative terms, higher detection scores values are assumed to 198 | indicate stronger support for the bona fide hypothesis. 199 | o You should provide real-valued soft scores, NOT hard decisions. The 200 | recommendation is that the scores are log-likelihood ratios (LLRs) 201 | from a bonafide-vs-spoof hypothesis based on some statistical model. 202 | This, however, is NOT required. The scores can have arbitrary range 203 | and scaling. 204 | o Pfa_asv, Pmiss_asv, Pmiss_spoof_asv are in fractions, not percentages. 205 | References: 206 | [1] T. Kinnunen, K.-A. Lee, H. Delgado, N. Evans, M. Todisco, 207 | M. Sahidullah, J. Yamagishi, D.A. Reynolds: "t-DCF: a Detection 208 | Cost Function for the Tandem Assessment of Spoofing Countermeasures 209 | and Automatic Speaker Verification", Proc. Odyssey 2018: the 210 | Speaker and Language Recognition Workshop, pp. 
312--319, Les Sables d'Olonne,
211 | France, June 2018 (https://www.isca-speech.org/archive/Odyssey_2018/pdfs/68.pdf)
212 | [2] ASVspoof 2019 challenge evaluation plan
213 | TODO:
214 | """
215 | 
216 | 
217 | # Sanity check of cost parameters
218 | if cost_model['Cfa_asv'] < 0 or cost_model['Cmiss_asv'] < 0 or \
219 | cost_model['Cfa_cm'] < 0 or cost_model['Cmiss_cm'] < 0:
220 | print('WARNING: Usually the cost values should be positive!')
221 | 
222 | if cost_model['Ptar'] < 0 or cost_model['Pnon'] < 0 or cost_model['Pspoof'] < 0 or \
223 | np.abs(cost_model['Ptar'] + cost_model['Pnon'] + cost_model['Pspoof'] - 1) > 1e-10:
224 | sys.exit('ERROR: Your prior probabilities should be positive and sum up to one.')
225 | 
226 | # Unless we evaluate the worst-case model, we need to have some spoof tests against the ASV
227 | if Pmiss_spoof_asv is None:
228 | sys.exit('ERROR: you should provide the miss rate of spoof tests against your ASV system.')
229 | 
230 | # Sanity check of scores
231 | combined_scores = np.concatenate((bonafide_score_cm, spoof_score_cm))
232 | if np.isnan(combined_scores).any() or np.isinf(combined_scores).any():
233 | sys.exit('ERROR: Your scores contain nan or inf.')
234 | 
235 | # Sanity check that inputs are scores and not decisions
236 | n_uniq = np.unique(combined_scores).size
237 | if n_uniq < 3:
238 | sys.exit('ERROR: You should provide soft CM scores - not binary decisions')
239 | 
240 | # Obtain miss and false alarm rates of CM
241 | Pmiss_cm, Pfa_cm, CM_thresholds = compute_det_curve(bonafide_score_cm, spoof_score_cm)
242 | 
243 | # Constants - see ASVspoof 2019 evaluation plan
244 | C1 = cost_model['Ptar'] * (cost_model['Cmiss_cm'] - cost_model['Cmiss_asv'] * Pmiss_asv) - \
245 | cost_model['Pnon'] * cost_model['Cfa_asv'] * Pfa_asv
246 | C2 = cost_model['Cfa_cm'] * cost_model['Pspoof'] * (1 - Pmiss_spoof_asv)
247 | 
248 | # Sanity check of the weights
249 | if C1 < 0 or C2 < 0:
250 | sys.exit('You should never see this error, but t-DCF cannot be evaluated with negative weights - please check whether your ASV error rates are correctly computed.')
251 | 
252 | # Obtain t-DCF curve for all thresholds
253 | tDCF = C1 * Pmiss_cm + C2 * Pfa_cm
254 | 
255 | # Normalized t-DCF
256 | tDCF_norm = tDCF / np.minimum(C1, C2)
257 | 
258 | # Everything should be fine if we reach this point. 
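# The headline metric is the minimum of the normalized curve; a minimal
# sketch for callers (variable names are illustrative):
#   min_tDCF = np.min(tDCF_norm)
#   best_threshold = CM_thresholds[np.argmin(tDCF_norm)]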

def eerandtdcf(score_file, label_file, asv_label, pos=1):

    # Fix tandem detection cost function (t-DCF) parameters
    Pspoof = 0.05
    cost_model = {
        'Pspoof': Pspoof,             # Prior probability of a spoofing attack
        'Ptar': (1 - Pspoof) * 0.99,  # Prior probability of target speaker
        'Pnon': (1 - Pspoof) * 0.01,  # Prior probability of nontarget speaker
        'Cmiss_asv': 1,               # Cost of ASV system falsely rejecting target speaker
        'Cfa_asv': 10,                # Cost of ASV system falsely accepting nontarget speaker
        'Cmiss_cm': 1,                # Cost of CM system falsely rejecting target speaker
        'Cfa_cm': 10,                 # Cost of CM system falsely accepting spoof
    }

    asv_data = np.genfromtxt(asv_label, dtype=str)
    asv_sources = asv_data[:, 0]
    asv_keys = asv_data[:, 1]
    asv_scores = asv_data[:, 2].astype(np.float64)

    tar_asv = asv_scores[asv_keys == 'target']
    non_asv = asv_scores[asv_keys == 'nontarget']
    spoof_asv = asv_scores[asv_keys == 'spoof']

    eer_asv, asv_threshold = compute_eer(tar_asv, non_asv)
    [Pfa_asv, Pmiss_asv, Pmiss_spoof_asv] = obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold)

    target = []
    nontarget = []
    target_score = []
    nontarget_score = []
    wav_lists = []
    score = {}
    label_list = {}
    wrong = 0  # only used by the (commented-out) accuracy check below

    with open(label_file, 'r', encoding="utf-8") as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip().split()
            if len(line) > 1:
                wav_id = line[1]
                label = line[4]
                label_list[wav_id] = label
                if label == "spoof":
                    nontarget.append(wav_id)
                else:
                    target.append(wav_id)

    with open(score_file, 'r', encoding="utf-8") as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip().split()
            if len(line) > 1:
                wav_id = line[0]
                wav_lists.append(wav_id)
                score[wav_id] = line[pos].replace("[", "").replace("]", "")

    for wav_id in target:
        target_score.append(float(score[wav_id]))
    for wav_id in nontarget:
        nontarget_score.append(float(score[wav_id]))
    target_score = np.array(target_score)
    nontarget_score = np.array(nontarget_score)

    eer_cm, threshold = compute_eer(target_score, nontarget_score)
    '''
    print("EER={}, Threshold={}".format(eer_cm, threshold))
    for wav_id in wav_lists:
        if float(score[wav_id]) > threshold and label_list[wav_id] == "spoof":
            wrong += 1

    acc = (len(score) - wrong) / len(score)
    print("Acc={}".format(acc))
    '''

    tDCF_curve, CM_thresholds = compute_tDCF(target_score, nontarget_score, Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, cost_model, True)

    # Minimum t-DCF
    min_tDCF_index = np.argmin(tDCF_curve)
    min_tDCF = tDCF_curve[min_tDCF_index]

    print('ASV SYSTEM')
    print('   EER            = {:8.5f} % (Equal error rate, target vs. nontarget discrimination)'.format(eer_asv * 100))
    print('   Pfa            = {:8.5f} % (False acceptance rate of nontargets)'.format(Pfa_asv * 100))
    print('   Pmiss          = {:8.5f} % (False rejection rate of targets)'.format(Pmiss_asv * 100))
    print('   1-Pmiss,spoof  = {:8.5f} % (Spoof false acceptance rate)'.format((1 - Pmiss_spoof_asv) * 100))

    print('\nCM SYSTEM')
    print('   EER            = {:8.5f} % (Equal error rate for countermeasure)'.format(eer_cm * 100))

    print('\nTANDEM')
    print('   min-tDCF       = {:8.5f}'.format(min_tDCF))

    return eer_cm * 100, min_tDCF
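
# --- Input format notes (added commentary; the example line is illustrative) ---
# Inferred from the parsing in eerandtdcf above, not from separate docs:
#   label_file : ASVspoof CM protocol lines such as
#                "LA_0069 LA_E_0000001 - - bonafide"
#                (utterance id in column 2, bonafide/spoof key in column 5)
#   score_file : lines of the form "utt_id score ...", where --pos selects
#                the score column (square brackets are stripped)
#   asv_label  : ASV score lines "source key score", with key one of
#                target / nontarget / spoof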

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Dataset
    # parser.add_argument('--layer', type=int, default=0)
    # parser.add_argument('--type', type=str, default="")
    parser.add_argument('--scoreFile', type=str, default="")
    parser.add_argument('--pos', type=int, default=1)
    args = parser.parse_args()

    # partial_spoof
    # labelFile="/data8/wangzhiyong/project/fakeAudioDetection/vocoderReWavFAD/datasets/partial_spoof/protocol/PartialSpoof.LA.cm.eval.trl.txt"
    # asvlabel="/data8/wangzhiyong/project/fakeAudioDetection/vocoderReWavFAD/datasets/partial_spoof/protocol/ASVspoof2019.LA.asv.eval.gi.trl.scores.txt"

    # asvspoof 2019
    labelFile = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/asvspoof2019/LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.eval.trl.txt"
    asvlabel = "/data8/wangzhiyong/project/fakeAudioDetection/investigating_partial_pre-trained_model_for_fake_audio_detection/datasets/asvspoof2019/LA/ASVspoof2019_LA_asv_scores/ASVspoof2019.LA.asv.eval.gi.trl.scores.txt"

    # inthewild
    # labelFile="/data8/wangzhiyong/project/fakeAudioDetection/FAD_research/datasets/release_in_the_wild/inthewild_protocol.txt"

    # scoreFile=f"{args.type}log_eval_partialspoof_{args.layer}_score.txt"

    def remove_duplicate_lines_inplace(file_path):
        # Keep track of the first token of every line seen so far
        seen_first_strings = set()

        # Open the file for reading and writing
        with open(file_path, 'r+') as file:
            lines = file.readlines()  # read all lines

            # Move the file pointer back to the start and clear the contents
            file.seek(0)
            file.truncate()

            for line in lines:
                # Extract the first token of the line
                first_string = line.split()[0]

                # Write the line back only if its first token is new,
                # then remember that token
                if first_string not in seen_first_strings:
                    file.write(line)
                    seen_first_strings.add(first_string)

    # Drop duplicated score lines (same utterance id) before evaluation
    remove_duplicate_lines_inplace(args.scoreFile)

    eerandtdcf(args.scoreFile, labelFile, asvlabel, pos=args.pos)
--------------------------------------------------------------------------------