├── LICENSE ├── loss.py ├── README.md ├── preprocess.py ├── feature_extraction.py ├── evaluate_tDCF_asvspoof19.py ├── utils_dsp.py ├── eval_metrics.py ├── raw_dataset.py ├── dataset.py ├── test.py ├── model.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 You Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd.function import Function 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | 8 | class OCSoftmax(nn.Module): 9 | def __init__(self, feat_dim=2, r_real=0.9, r_fake=0.5, alpha=20.0): 10 | super(OCSoftmax, self).__init__() 11 | self.feat_dim = feat_dim 12 | self.r_real = r_real 13 | self.r_fake = r_fake 14 | self.alpha = alpha 15 | self.center = nn.Parameter(torch.randn(1, self.feat_dim)) 16 | nn.init.kaiming_uniform_(self.center, 0.25) 17 | self.softplus = nn.Softplus() 18 | 19 | def forward(self, x, labels): 20 | """ 21 | Args: 22 | x: feature matrix with shape (batch_size, feat_dim). 23 | labels: ground truth labels with shape (batch_size). 
24 | """ 25 | w = F.normalize(self.center, p=2, dim=1) 26 | x = F.normalize(x, p=2, dim=1) 27 | 28 | scores = x @ w.transpose(0,1) 29 | output_scores = scores.clone() 30 | 31 | scores[labels == 0] = self.r_real - scores[labels == 0] 32 | scores[labels == 1] = scores[labels == 1] - self.r_fake 33 | 34 | loss = self.softplus(self.alpha * scores).mean() 35 | 36 | return loss, -output_scores.squeeze(1) 37 | 38 | 39 | class AMSoftmax(nn.Module): 40 | def __init__(self, num_classes, enc_dim, s=20, m=0.9): 41 | super(AMSoftmax, self).__init__() 42 | self.enc_dim = enc_dim 43 | self.num_classes = num_classes 44 | self.s = s 45 | self.m = m 46 | self.centers = nn.Parameter(torch.randn(num_classes, enc_dim)) 47 | 48 | def forward(self, feat, label): 49 | batch_size = feat.shape[0] 50 | norms = torch.norm(feat, p=2, dim=-1, keepdim=True) 51 | nfeat = torch.div(feat, norms) 52 | 53 | norms_c = torch.norm(self.centers, p=2, dim=-1, keepdim=True) 54 | ncenters = torch.div(self.centers, norms_c) 55 | logits = torch.matmul(nfeat, torch.transpose(ncenters, 0, 1)) 56 | 57 | y_onehot = torch.FloatTensor(batch_size, self.num_classes) 58 | y_onehot.zero_() 59 | y_onehot = Variable(y_onehot).cuda() 60 | y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.m) 61 | margin_logits = self.s * (logits - y_onehot) 62 | 63 | return logits, margin_logits 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Empirical-Channel-CM 2 | 3 | ## An Empirical Study on Channel Effects for Synthetic Voice Spoofing Countermeasure Systems 4 | This repository contains our implementation of the paper, "An Empirical Study on Channel Effects for Synthetic Voice Spoofing Countermeasure Systems". 
5 | [[Paper link](https://www.isca-speech.org/archive/interspeech_2021/zhang21ea_interspeech.html)] [[arXiv](https://arxiv.org/pdf/2104.01320.pdf)] [[Video presentation](https://www.youtube.com/watch?v=t6qtehKer6w)]
6 | 
7 | ### Cross-Dataset Studies
8 | Existing datasets:
9 | [ASVspoof2019LA](https://datashare.ed.ac.uk/handle/10283/3336),
10 | [ASVspoof2015](https://datashare.ed.ac.uk/handle/10283/853),
11 | [VCC2020 training data](https://zenodo.org/record/4345689#.YVp3UlNKgt0),
12 | [VCC2020 submissions](https://zenodo.org/record/4433173)
13 | 
14 | 
15 | Augmented data:
16 | 
17 | Training + Development: [ASVspoof2019LA-Sim v1.0](https://zenodo.org/record/5548622)
18 | 
19 | Evaluation: [ASVspoof2019LA-Sim v1.1](https://zenodo.org/record/5794671)
20 | 
21 | ### Channel Robust Strategies
22 | 
23 | #### Run the training code
24 | ```
25 | python3 train.py -o /path/to/output/the/model
26 | ```
27 | The options:
28 | 
29 | --AUG use the plain augmentation
30 | 
31 | --MT_AUG use the multitask augmentation
32 | 
33 | --ADV_AUG use the adversarial augmentation
34 | 
35 | #### Run the testing code
36 | ```
37 | python3 test.py -m /path/to/the/trained/model --task ASVspoof2019LA
38 | ```
39 | The options for testing on different datasets:
40 | 
41 | ASVspoof2019LA, ASVspoof2015, VCC2020, ASVspoof2019LASim
42 | 
43 | The code is based on our previous work "One-class Learning Towards Synthetic Voice Spoofing Detection" [[code link](https://github.com/yzyouzhang/AIR-ASVspoof)] [[paper link](https://ieeexplore.ieee.org/document/9417604)]
44 | 
45 | 
46 | ### Citation
47 | ```
48 | @inproceedings{zhang21ea_interspeech,
49 |   author={You Zhang and Ge Zhu and Fei Jiang and Zhiyao Duan},
50 |   title={{An Empirical Study on Channel Effects for Synthetic Voice Spoofing Countermeasure Systems}},
51 |   year=2021,
52 |   booktitle={Proc. Interspeech 2021},
53 |   pages={4309--4313},
54 |   doi={10.21437/Interspeech.2021-1820}
55 | }
56 | ```
57 | 
58 | Please also feel free to check out our follow-up work:
59 | 
60 | [1] Chen, X., Zhang, Y., Zhu, G., Duan, Z. (2021) UR Channel-Robust Synthetic Speech Detection System for ASVspoof 2021. Proc.
2021 Edition of the Automatic Speaker Verification and Spoofing Countermeasures Challenge, 75-82, doi: 10.21437/ASVSPOOF.2021-12 [[link](https://www.isca-speech.org/archive/pdfs/asvspoof_2021/chen21_asvspoof.pdf)] [[code](https://github.com/yzyouzhang/ASVspoof2021_AIR)] [[video](https://www.youtube.com/watch?v=-wKMOTp8Tt0)] 61 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import raw_dataset as dataset 2 | from feature_extraction import LFCC 3 | import os 4 | import torch 5 | from tqdm import tqdm 6 | 7 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 8 | 9 | cuda = torch.cuda.is_available() 10 | print('Cuda device available: ', cuda) 11 | device = torch.device("cuda" if cuda else "cpu") 12 | 13 | from scipy.fftpack import fft, ifft, fftshift, ifftshift, next_fast_len 14 | import numpy as np 15 | import copy 16 | 17 | # for part_ in ["train", "dev", "eval"]: 18 | # asvspoof_raw = dataset.ASVspoof2019Raw("LA", "/data/neil/DS_10283_3336/", "/data/neil/DS_10283_3336/LA/ASVspoof2019_LA_cm_protocols/", part=part_) 19 | # target_dir = os.path.join("/data2/neil/ASVspoof2019LASW", part_, "LFCC") 20 | # if not os.path.exists(target_dir): 21 | # os.makedirs(target_dir) 22 | # lfcc = LFCC(320, 160, 512, 16000, 20, with_energy=False) 23 | # lfcc = lfcc.to(device) 24 | # for idx in tqdm(range(len(asvspoof_raw))): 25 | # waveform, filename, tag, label = asvspoof_raw[idx] 26 | # waveform = spectral_whitening(waveform.squeeze(0).numpy()) 27 | # waveform = torch.from_numpy(np.expand_dims(waveform, axis=0)) 28 | # waveform = waveform.to(device) 29 | # lfccOfWav = lfcc(waveform) 30 | # torch.save(lfccOfWav, os.path.join(target_dir, "%05d_%s_%s_%s.pt" % (idx, filename, tag, label))) 31 | # print("Done!") 32 | 33 | # vcc2020 = dataset.VCC2020Raw() 34 | # print(len(vcc2020)) 35 | # target_dir = "/data2/neil/VCC2020/LFCC/" 36 | # lfcc = LFCC(320, 160, 512, 16000, 20, with_energy=False) 37 | # lfcc = lfcc.to(device) 38 | # for idx in range(len(vcc2020)): 39 | # print("Processing", idx) 40 | # waveform, filename, tag, label = vcc2020[idx] 41 | # waveform = waveform.to(device) 42 | # lfccOfWav = lfcc(waveform) 43 | # torch.save(lfccOfWav, os.path.join(target_dir, "%04d_%s_%s_%s.pt" %(idx, filename, tag, label))) 44 | # print("Done!") 45 | 46 | # for part_ in ["train", "dev", "eval"]: 47 | # asvspoof_raw = dataset.ASVspoof2015Raw("/data/neil/ASVspoof2015/wav", "/data/neil/ASVspoof2015/CM_protocol", part=part_) 48 | # target_dir = os.path.join("/data2/neil/ASVspoof2015", part_, "LFCC") 49 | # lfcc = LFCC(320, 160, 512, 16000, 20, with_energy=False) 50 | # lfcc = lfcc.to(device) 51 | # for idx in range(len(asvspoof_raw)): 52 | # print("Processing", idx) 53 | # waveform, filename, tag, label = asvspoof_raw[idx] 54 | # waveform = waveform.to(device) 55 | # lfccOfWav = lfcc(waveform) 56 | # torch.save(lfccOfWav, os.path.join(target_dir, "%05d_%s_%s_%s.pt" % (idx, filename, tag, label))) 57 | # print("Done!") 58 | 59 | asvspoof2019channel = dataset.ASVspoof2019LARaw_withDevice() 60 | print(len(asvspoof2019channel)) 61 | target_dir = "/dataNVME/neil/ASVspoof2019LADeviceEval" 62 | lfcc = LFCC(320, 160, 512, 16000, 20, with_energy=False) 63 | lfcc = lfcc.to(device) 64 | for idx in tqdm(range(len(asvspoof2019channel))): 65 | waveform, filename, tag, label, channel = asvspoof2019channel[idx] 66 | waveform = waveform.to(device) 67 | lfccOfWav = lfcc(waveform) 68 | if not 
os.path.exists(os.path.join(target_dir, channel)):
69 |         os.makedirs(os.path.join(target_dir, channel))
70 |     torch.save(lfccOfWav, os.path.join(target_dir, channel, "%06d_%s_%s_%s_%s.pt" %(idx, filename, tag, label, channel)))
71 | print("Done!")
72 | 
--------------------------------------------------------------------------------
/feature_extraction.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as torch_nn
3 | import torch.nn.functional as torch_nn_func
4 | import numpy as np
5 | from utils_dsp import LinearDCT
6 | import librosa
7 | import pickle
8 | import sys  # required by the error handling in trimf() below
9 | __author__ = "Xin Wang"
10 | __email__ = "wangxin@nii.ac.jp"
11 | __copyright__ = "Copyright 2020, Xin Wang"
12 | 
13 | ##################
14 | ## other utilities
15 | ##################
16 | def trimf(x, params):
17 |     """
18 |     trimf: similar to Matlab definition
19 |     https://www.mathworks.com/help/fuzzy/trimf.html?s_tid=srchtitle
20 | 
21 |     """
22 |     if len(params) != 3:
23 |         print("trimf requires params to be a list of 3 elements")
24 |         sys.exit(1)
25 |     a = params[0]
26 |     b = params[1]
27 |     c = params[2]
28 |     if a > b or b > c:
29 |         print("trimf(x, [a, b, c]) requires a<=b<=c")
30 |         sys.exit(1)
31 |     y = torch.zeros_like(x, dtype=torch.float32)
32 |     if a < b:
33 |         index = np.logical_and(a < x, x < b)
34 |         y[index] = (x[index] - a) / (b - a)
35 |     if b < c:
36 |         index = np.logical_and(b < x, x < c)
37 |         y[index] = (c - x[index]) / (c - b)
38 |     y[x == b] = 1
39 |     return y
40 | 
41 | def delta(x):
42 |     """ By default
43 |     input
44 |     -----
45 |     x (batch, Length, dim)
46 | 
47 |     output
48 |     ------
49 |     output (batch, Length, dim)
50 | 
51 |     Delta is calculated along Length
52 |     """
53 |     length = x.shape[1]
54 |     output = torch.zeros_like(x)
55 |     x_temp = torch_nn_func.pad(x.unsqueeze(1), (0, 0, 1, 1),
56 |                                'replicate').squeeze(1)
57 |     output = -1 * x_temp[:, 0:length] + x_temp[:, 2:]
58 |     return output
59 | 
60 | 
61 | class LFCC(torch_nn.Module):
62 |     """ Based on asvspoof.org baseline Matlab code.
63 | Difference: with_energy is added to set the first dimension as energy 64 | 65 | """ 66 | 67 | def __init__(self, fl, fs, fn, sr, filter_num, 68 | with_energy=False, with_emphasis=True, 69 | with_delta=True): 70 | super(LFCC, self).__init__() 71 | self.fl = fl 72 | self.fs = fs 73 | self.fn = fn 74 | self.sr = sr 75 | self.filter_num = filter_num 76 | 77 | f = (sr / 2) * torch.linspace(0, 1, fn // 2 + 1) 78 | filter_bands = torch.linspace(min(f), max(f), filter_num + 2) 79 | 80 | filter_bank = torch.zeros([fn // 2 + 1, filter_num]) 81 | for idx in range(filter_num): 82 | filter_bank[:, idx] = trimf( 83 | f, [filter_bands[idx], 84 | filter_bands[idx + 1], 85 | filter_bands[idx + 2]]) 86 | self.lfcc_fb = torch_nn.Parameter(filter_bank, requires_grad=False) 87 | self.l_dct = LinearDCT(filter_num, 'dct', norm='ortho') 88 | self.with_energy = with_energy 89 | self.with_emphasis = with_emphasis 90 | self.with_delta = with_delta 91 | return 92 | 93 | def forward(self, x): 94 | """ 95 | 96 | input: 97 | ------ 98 | x: tensor(batch, length), where length is waveform length 99 | 100 | output: 101 | ------- 102 | lfcc_output: tensor(batch, frame_num, dim_num) 103 | """ 104 | # pre-emphasis 105 | if self.with_emphasis: 106 | x[:, 1:] = x[:, 1:] - 0.97 * x[:, 0:-1] 107 | 108 | # STFT 109 | x_stft = torch.stft(x, self.fn, self.fs, self.fl, 110 | window=torch.hamming_window(self.fl), 111 | onesided=True, pad_mode="constant") 112 | # amplitude 113 | sp_amp = torch.norm(x_stft, 2, -1).pow(2).permute(0, 2, 1).contiguous() 114 | 115 | # filter bank 116 | fb_feature = torch.log10(torch.matmul(sp_amp, self.lfcc_fb) + 117 | torch.finfo(torch.float32).eps) 118 | 119 | # DCT 120 | lfcc = self.l_dct(fb_feature) 121 | 122 | # Add energy 123 | if self.with_energy: 124 | power_spec = sp_amp / self.fn 125 | energy = torch.log10(power_spec.sum(axis=2) + 126 | torch.finfo(torch.float32).eps) 127 | lfcc[:, :, 0] = energy 128 | 129 | # Add delta coefficients 130 | if self.with_delta: 131 | lfcc_delta = delta(lfcc) 132 | lfcc_delta_delta = delta(lfcc_delta) 133 | lfcc_output = torch.cat((lfcc, lfcc_delta, lfcc_delta_delta), 2) 134 | else: 135 | lfcc_output = lfcc 136 | 137 | # done 138 | return lfcc_output 139 | 140 | if __name__ == "__main__": 141 | lfcc = LFCC(320, 160, 512, 16000, 20, with_energy=False) 142 | wav, sr = librosa.load("/data/neil/DS_10283_3336/LA/ASVspoof2019_LA_train/flac/LA_T_3727749.flac", sr=16000) 143 | # wav = torch.randn(1, 32456) 144 | wav = torch.Tensor(np.expand_dims(wav, axis=0)) 145 | wav_lfcc = lfcc(wav) 146 | with open('/dataNVME/neil/ASVspoof2019LAFeatures/train' + '/' + "LA_T_3727749" + "LFCC" + '.pkl', 'rb') as feature_handle: 147 | ref_lfcc = pickle.load(feature_handle) 148 | print(ref_lfcc.shape) 149 | print(ref_lfcc[0:3,0:3]) 150 | print(wav_lfcc.shape) 151 | print(wav_lfcc[0,0:3,0:3]) 152 | 153 | -------------------------------------------------------------------------------- /evaluate_tDCF_asvspoof19.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import eval_metrics as em 4 | import matplotlib.pyplot as plt 5 | 6 | def compute_eer_and_tdcf(cm_score_file, path_to_database): 7 | asv_score_file = os.path.join(path_to_database, 'LA/ASVspoof2019_LA_asv_scores/ASVspoof2019.LA.asv.eval.gi.trl.scores.txt') 8 | 9 | # Fix tandem detection cost function (t-DCF) parameters 10 | Pspoof = 0.05 11 | cost_model = { 12 | 'Pspoof': Pspoof, # Prior probability of a spoofing attack 13 | 'Ptar': (1 - Pspoof) * 0.99, # Prior 
probability of target speaker
14 |         'Pnon': (1 - Pspoof) * 0.01,  # Prior probability of nontarget speaker
15 |         'Cmiss_asv': 1,  # Cost of ASV system falsely rejecting target speaker
16 |         'Cfa_asv': 10,  # Cost of ASV system falsely accepting nontarget speaker
17 |         'Cmiss_cm': 1,  # Cost of CM system falsely rejecting target speaker
18 |         'Cfa_cm': 10,  # Cost of CM system falsely accepting spoof
19 |     }
20 | 
21 |     # Load organizers' ASV scores
22 |     asv_data = np.genfromtxt(asv_score_file, dtype=str)
23 |     asv_sources = asv_data[:, 0]
24 |     asv_keys = asv_data[:, 1]
25 |     asv_scores = asv_data[:, 2].astype(float)  # np.float was removed in NumPy >= 1.24; plain float is equivalent
26 | 
27 |     # Load CM scores
28 |     cm_data = np.genfromtxt(cm_score_file, dtype=str)
29 |     cm_utt_id = cm_data[:, 0]
30 |     cm_sources = cm_data[:, 1]
31 |     cm_keys = cm_data[:, 2]
32 |     cm_scores = cm_data[:, 3].astype(float)
33 | 
34 |     other_cm_scores = -cm_scores
35 | 
36 |     # Extract target, nontarget, and spoof scores from the ASV scores
37 |     tar_asv = asv_scores[asv_keys == 'target']
38 |     non_asv = asv_scores[asv_keys == 'nontarget']
39 |     spoof_asv = asv_scores[asv_keys == 'spoof']
40 | 
41 |     # Extract bona fide (real human) and spoof scores from the CM scores
42 |     bona_cm = cm_scores[cm_keys == 'bonafide']
43 |     spoof_cm = cm_scores[cm_keys == 'spoof']
44 | 
45 |     # EERs of the standalone systems and fix ASV operating point to EER threshold
46 |     eer_asv, asv_threshold = em.compute_eer(tar_asv, non_asv)
47 |     eer_cm = em.compute_eer(bona_cm, spoof_cm)[0]
48 | 
49 |     other_eer_cm = em.compute_eer(other_cm_scores[cm_keys == 'bonafide'], other_cm_scores[cm_keys == 'spoof'])[0]
50 | 
51 |     [Pfa_asv, Pmiss_asv, Pmiss_spoof_asv] = em.obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold)
52 | 
53 |     if eer_cm < other_eer_cm:
54 |         # Compute t-DCF
55 |         tDCF_curve, CM_thresholds = em.compute_tDCF(bona_cm, spoof_cm, Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, cost_model, True)
56 | 
57 |         # Minimum t-DCF
58 |         min_tDCF_index = np.argmin(tDCF_curve)
59 |         min_tDCF = tDCF_curve[min_tDCF_index]
60 | 
61 |     else:
62 |         tDCF_curve, CM_thresholds = em.compute_tDCF(other_cm_scores[cm_keys == 'bonafide'], other_cm_scores[cm_keys == 'spoof'],
63 |                                                     Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, cost_model, True)
64 | 
65 |         # Minimum t-DCF
66 |         min_tDCF_index = np.argmin(tDCF_curve)
67 |         min_tDCF = tDCF_curve[min_tDCF_index]
68 | 
69 | 
70 |     # print('ASV SYSTEM')
71 |     # print(' EER = {:8.5f} % (Equal error rate (target vs.
nontarget discrimination)'.format(eer_asv * 100)) 72 | # print(' Pfa = {:8.5f} % (False acceptance rate of nontargets)'.format(Pfa_asv * 100)) 73 | # print(' Pmiss = {:8.5f} % (False rejection rate of targets)'.format(Pmiss_asv * 100)) 74 | # print(' 1-Pmiss,spoof = {:8.5f} % (Spoof false acceptance rate)'.format((1 - Pmiss_spoof_asv) * 100)) 75 | 76 | print('\nCM SYSTEM') 77 | print(' EER = {:8.5f} % (Equal error rate for countermeasure)'.format(min(eer_cm, other_eer_cm) * 100)) 78 | 79 | print('\nTANDEM') 80 | print(' min-tDCF = {:8.5f}'.format(min_tDCF)) 81 | 82 | 83 | # Visualize ASV scores and CM scores 84 | plt.figure() 85 | ax = plt.subplot(121) 86 | plt.hist(tar_asv, histtype='step', density=True, bins=50, label='Target') 87 | plt.hist(non_asv, histtype='step', density=True, bins=50, label='Nontarget') 88 | plt.hist(spoof_asv, histtype='step', density=True, bins=50, label='Spoof') 89 | plt.plot(asv_threshold, 0, 'o', markersize=10, mfc='none', mew=2, clip_on=False, label='EER threshold') 90 | plt.legend() 91 | plt.xlabel('ASV score') 92 | plt.ylabel('Density') 93 | plt.title('ASV score histogram') 94 | 95 | ax = plt.subplot(122) 96 | plt.hist(bona_cm, histtype='step', density=True, bins=50, label='Bona fide') 97 | plt.hist(spoof_cm, histtype='step', density=True, bins=50, label='Spoof') 98 | plt.legend() 99 | plt.xlabel('CM score') 100 | # plt.ylabel('Density') 101 | plt.title('CM score histogram') 102 | plt.savefig(cm_score_file[:-4]+'1.png') 103 | 104 | 105 | # Plot t-DCF as function of the CM threshold. 106 | plt.figure() 107 | plt.plot(CM_thresholds, tDCF_curve) 108 | plt.plot(CM_thresholds[min_tDCF_index], min_tDCF, 'o', markersize=10, mfc='none', mew=2) 109 | plt.xlabel('CM threshold index (operating point)') 110 | plt.ylabel('Norm t-DCF') 111 | plt.title('Normalized tandem t-DCF') 112 | plt.plot([np.min(CM_thresholds), np.max(CM_thresholds)], [1, 1], '--', color='black') 113 | plt.legend(('t-DCF', 'min t-DCF ({:.5f})'.format(min_tDCF), 'Arbitrarily bad CM (Norm t-DCF=1)')) 114 | plt.xlim([np.min(CM_thresholds), np.max(CM_thresholds)]) 115 | plt.ylim([0, 1.5]) 116 | plt.savefig(cm_score_file[:-4]+'2.png') 117 | 118 | plt.show() 119 | 120 | return min(eer_cm, other_eer_cm), min_tDCF 121 | 122 | 123 | if __name__ == "__main__": 124 | # Replace CM scores with your own scores or provide score file as the first argument. 125 | cm_score_file = 'cm_score_ocnn.txt' 126 | # Replace ASV scores with organizers' scores or provide score file as the second argument. 
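# Note on score-file layout (inferred from the np.genfromtxt parsing in
# compute_eer_and_tdcf above; the example row below is illustrative only):
#   CM file,  one trial per line:  <utt_id> <attack_tag> <bonafide|spoof> <score>
#       e.g.  LA_E_1000001 A07 spoof -3.21
#   ASV file, one trial per line:  <source> <target|nontarget|spoof> <score>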
127 | # path_to_database = '/home/yzh298/Downloads/DS_10283_3336/' 128 | path_to_database = '/data/neil/DS_10283_3336/' # if run on GPU 129 | 130 | # args = sys.argv 131 | # if len(args) > 1: 132 | # if len(args) != 3: 133 | # print('USAGE: python evaluate_tDCF_asvspoof19.py ') 134 | # exit() 135 | # else: 136 | # cm_score_file = args[1] 137 | # asv_score_file = args[2] 138 | 139 | compute_eer_and_tdcf(cm_score_file, path_to_database) 140 | 141 | -------------------------------------------------------------------------------- /utils_dsp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | ## Adapted from https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/blob/newfunctions/ 4 | 5 | 6 | """ 7 | util_dsp.py 8 | Utilities for signal processing 9 | MuLaw Code adapted from 10 | https://github.com/fatchord/WaveRNN/blob/master/utils/distribution.py 11 | DCT code adapted from 12 | https://github.com/zh217/torch-dct 13 | """ 14 | 15 | from __future__ import absolute_import 16 | from __future__ import print_function 17 | 18 | import sys 19 | import numpy as np 20 | 21 | import torch 22 | import torch.nn as torch_nn 23 | import torch.nn.functional as torch_nn_func 24 | 25 | __author__ = "Xin Wang" 26 | __email__ = "wangxin@nii.ac.jp" 27 | __copyright__ = "Copyright 2020, Xin Wang" 28 | 29 | 30 | ###################### 31 | ### WaveForm utilities 32 | ###################### 33 | 34 | def label_2_float(x, bits): 35 | """Convert integer numbers to float values 36 | 37 | Note: dtype conversion is not handled 38 | Args: 39 | ----- 40 | x: data to be converted Tensor.long or int, any shape. 41 | bits: number of bits, int 42 | 43 | Return: 44 | ------- 45 | tensor.float 46 | 47 | """ 48 | return 2 * x / (2 ** bits - 1.) - 1. 49 | 50 | 51 | def float_2_label(x, bits): 52 | """Convert float wavs back to integer (quantization) 53 | 54 | Note: dtype conversion is not handled 55 | Args: 56 | ----- 57 | x: data to be converted Tensor.float, any shape. 58 | bits: number of bits, int 59 | 60 | Return: 61 | ------- 62 | tensor.float 63 | 64 | """ 65 | # assert abs(x).max() <= 1.0 66 | peak = torch.abs(x).max() 67 | if peak > 1.0: 68 | x /= peak 69 | x = (x + 1.) 
* (2 ** bits - 1) / 2 70 | return torch.clamp(x, 0, 2 ** bits - 1) 71 | 72 | 73 | def mulaw_encode(x, quantization_channels, scale_to_int=True): 74 | """Adapted from torchaudio 75 | https://pytorch.org/audio/functional.html mu_law_encoding 76 | Args: 77 | x (Tensor): Input tensor, float-valued waveforms in (-1, 1) 78 | quantization_channels (int): Number of channels 79 | scale_to_int: Bool 80 | True: scale mu-law companded to int 81 | False: return mu-law in (-1, 1) 82 | 83 | Returns: 84 | Tensor: Input after mu-law encoding 85 | """ 86 | # mu 87 | mu = quantization_channels - 1.0 88 | 89 | # no check on the value of x 90 | if not x.is_floating_point(): 91 | x = x.to(torch.float) 92 | mu = torch.tensor(mu, dtype=x.dtype, device=x.device) 93 | x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu) 94 | if scale_to_int: 95 | x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64) 96 | return x_mu 97 | 98 | 99 | def mulaw_decode(x_mu, quantization_channels, input_int=True): 100 | """Adapted from torchaudio 101 | https://pytorch.org/audio/functional.html mu_law_encoding 102 | Args: 103 | x_mu (Tensor): Input tensor 104 | quantization_channels (int): Number of channels 105 | Returns: 106 | Tensor: Input after mu-law decoding (float-value waveform (-1, 1)) 107 | """ 108 | mu = quantization_channels - 1.0 109 | if not x_mu.is_floating_point(): 110 | x_mu = x_mu.to(torch.float) 111 | mu = torch.tensor(mu, dtype=x_mu.dtype, device=x_mu.device) 112 | if input_int: 113 | x = ((x_mu) / mu) * 2 - 1.0 114 | else: 115 | x = x_mu 116 | x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu 117 | return x 118 | 119 | 120 | ###################### 121 | ### DCT utilities 122 | ### https://github.com/zh217/torch-dct 123 | ### LICENSE: MIT 124 | ### 125 | ###################### 126 | 127 | def dct1(x): 128 | """ 129 | Discrete Cosine Transform, Type I 130 | :param x: the input signal 131 | :return: the DCT-I of the signal over the last dimension 132 | """ 133 | x_shape = x.shape 134 | x = x.view(-1, x_shape[-1]) 135 | 136 | return torch.rfft( 137 | torch.cat([x, x.flip([1])[:, 1:-1]], dim=1), 1)[:, :, 0].view(*x_shape) 138 | 139 | 140 | def idct1(X): 141 | """ 142 | The inverse of DCT-I, which is just a scaled DCT-I 143 | Our definition if idct1 is such that idct1(dct1(x)) == x 144 | :param X: the input signal 145 | :return: the inverse DCT-I of the signal over the last dimension 146 | """ 147 | n = X.shape[-1] 148 | return dct1(X) / (2 * (n - 1)) 149 | 150 | 151 | def dct(x, norm=None): 152 | """ 153 | Discrete Cosine Transform, Type II (a.k.a. 
the DCT) 154 | For the meaning of the parameter `norm`, see: 155 | https://docs.scipy.org/doc/ scipy.fftpack.dct.html 156 | :param x: the input signal 157 | :param norm: the normalization, None or 'ortho' 158 | :return: the DCT-II of the signal over the last dimension 159 | """ 160 | x_shape = x.shape 161 | N = x_shape[-1] 162 | x = x.contiguous().view(-1, N) 163 | 164 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) 165 | 166 | Vc = torch.rfft(v, 1, onesided=False) 167 | 168 | k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N) 169 | W_r = torch.cos(k) 170 | W_i = torch.sin(k) 171 | 172 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i 173 | 174 | if norm == 'ortho': 175 | V[:, 0] /= np.sqrt(N) * 2 176 | V[:, 1:] /= np.sqrt(N / 2) * 2 177 | 178 | V = 2 * V.view(*x_shape) 179 | 180 | return V 181 | 182 | 183 | def idct(X, norm=None): 184 | """ 185 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III 186 | Our definition of idct is that idct(dct(x)) == x 187 | For the meaning of the parameter `norm`, see: 188 | https://docs.scipy.org/doc/ scipy.fftpack.dct.html 189 | :param X: the input signal 190 | :param norm: the normalization, None or 'ortho' 191 | :return: the inverse DCT-II of the signal over the last dimension 192 | """ 193 | 194 | x_shape = X.shape 195 | N = x_shape[-1] 196 | 197 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2 198 | 199 | if norm == 'ortho': 200 | X_v[:, 0] *= np.sqrt(N) * 2 201 | X_v[:, 1:] *= np.sqrt(N / 2) * 2 202 | 203 | k = torch.arange(x_shape[-1], dtype=X.dtype, 204 | device=X.device)[None, :] * np.pi / (2 * N) 205 | W_r = torch.cos(k) 206 | W_i = torch.sin(k) 207 | 208 | V_t_r = X_v 209 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) 210 | 211 | V_r = V_t_r * W_r - V_t_i * W_i 212 | V_i = V_t_r * W_i + V_t_i * W_r 213 | 214 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) 215 | 216 | v = torch.irfft(V, 1, onesided=False) 217 | x = v.new_zeros(v.shape) 218 | x[:, ::2] += v[:, :N - (N // 2)] 219 | x[:, 1::2] += v.flip([1])[:, :N // 2] 220 | 221 | return x.view(*x_shape) 222 | 223 | 224 | class LinearDCT(torch_nn.Linear): 225 | """Implement any DCT as a linear layer; in practice this executes around 226 | 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will 227 | increase memory usage. 228 | :param in_features: size of expected input 229 | :param type: which dct function in this file to use""" 230 | 231 | def __init__(self, in_features, type, norm=None, bias=False): 232 | self.type = type 233 | self.N = in_features 234 | self.norm = norm 235 | super(LinearDCT, self).__init__(in_features, in_features, bias=bias) 236 | 237 | def reset_parameters(self): 238 | # initialise using dct function 239 | I = torch.eye(self.N) 240 | if self.type == 'dct1': 241 | self.weight.data = dct1(I).data.t() 242 | elif self.type == 'idct1': 243 | self.weight.data = idct1(I).data.t() 244 | elif self.type == 'dct': 245 | self.weight.data = dct(I, norm=self.norm).data.t() 246 | elif self.type == 'idct': 247 | self.weight.data = idct(I, norm=self.norm).data.t() 248 | self.weight.requires_grad = False # don't learn this! 
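# A quick equivalence self-check for LinearDCT (a sketch; assumes a pre-1.8
# torch build, since dct()/idct() above rely on torch.rfft/torch.irfft,
# which were removed in later releases):
#
#   x = torch.randn(4, 20)
#   linear_dct = LinearDCT(20, 'dct', norm='ortho')
#   assert torch.allclose(linear_dct(x), dct(x, norm='ortho'), atol=1e-5)
#
# LinearDCT trades memory (it stores the N x N DCT matrix as a frozen Linear
# weight) for speed, which is why feature_extraction.LFCC applies the DCT
# through this layer.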
249 | 
250 | 
251 | if __name__ == "__main__":
252 |     print("util_dsp.py")
253 | 
--------------------------------------------------------------------------------
/eval_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 | 
4 | def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold):
5 | 
6 |     # False alarm and miss rates for ASV
7 |     Pfa_asv = sum(non_asv >= asv_threshold) / non_asv.size
8 |     Pmiss_asv = sum(tar_asv < asv_threshold) / tar_asv.size
9 | 
10 |     # Rate of rejecting spoofs in ASV
11 |     if spoof_asv.size == 0:
12 |         Pmiss_spoof_asv = None
13 |     else:
14 |         Pmiss_spoof_asv = np.sum(spoof_asv < asv_threshold) / spoof_asv.size
15 | 
16 |     return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv
17 | 
18 | 
19 | def compute_det_curve(target_scores, nontarget_scores):
20 | 
21 |     n_scores = target_scores.size + nontarget_scores.size
22 |     all_scores = np.concatenate((target_scores, nontarget_scores))
23 |     labels = np.concatenate((np.ones(target_scores.size), np.zeros(nontarget_scores.size)))
24 | 
25 |     # Sort labels based on scores
26 |     indices = np.argsort(all_scores, kind='mergesort')
27 |     labels = labels[indices]
28 | 
29 |     # Compute false rejection and false acceptance rates
30 |     tar_trial_sums = np.cumsum(labels)
31 |     nontarget_trial_sums = nontarget_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums)
32 | 
33 |     frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / target_scores.size))  # false rejection rates
34 |     far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / nontarget_scores.size))  # false acceptance rates
35 |     thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))  # Thresholds are the sorted scores
36 | 
37 |     return frr, far, thresholds
38 | 
39 | 
40 | def compute_eer(target_scores, nontarget_scores):
41 |     """ Returns equal error rate (EER) and the corresponding threshold. """
42 |     frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
43 |     abs_diffs = np.abs(frr - far)
44 |     min_index = np.argmin(abs_diffs)
45 |     eer = np.mean((frr[min_index], far[min_index]))
46 |     return eer, thresholds[min_index]
47 | 
48 | 
49 | def compute_tDCF(bonafide_score_cm, spoof_score_cm, Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, cost_model, print_cost):
50 |     """
51 |     Compute Tandem Detection Cost Function (t-DCF) [1] for a fixed ASV system.
52 |     In brief, t-DCF returns a detection cost of a cascaded system of this form,
53 | 
54 |         Speech waveform -> [CM] -> [ASV] -> decision
55 | 
56 |     where CM stands for countermeasure and ASV for automatic speaker
57 |     verification. The CM is therefore used as a 'gate' to decide whether or
58 |     not the input speech sample should be passed onwards to the ASV system.
59 |     Generally, both CM and ASV can make detection errors. Not all those errors
60 |     are necessarily equally costly, and not all types of users are necessarily
61 |     equally likely. The tandem t-DCF gives a principled way to compare
62 |     different spoofing countermeasures under a detection cost function
63 |     framework that takes that information into account.
64 | 
65 |     INPUTS:
66 | 
67 |     bonafide_score_cm   A vector of POSITIVE CLASS (bona fide or human)
68 |                         detection scores obtained by executing a spoofing
69 |                         countermeasure (CM) on some positive evaluation trials.
70 |                         Each trial represents a bona fide case.
71 |     spoof_score_cm      A vector of NEGATIVE CLASS (spoofing attack)
72 |                         detection scores obtained by executing a spoofing
73 |                         CM on some negative evaluation trials.
74 |     Pfa_asv             False alarm (false acceptance) rate of the ASV
75 |                         system that is evaluated in tandem with the CM.
76 |                         Assumed to be in fractions, not percentages.
77 |     Pmiss_asv           Miss (false rejection) rate of the ASV system that
78 |                         is evaluated in tandem with the spoofing CM.
79 |                         Assumed to be in fractions, not percentages.
80 |     Pmiss_spoof_asv     Miss rate of spoof samples of the ASV system that
81 |                         is evaluated in tandem with the spoofing CM. That
82 |                         is, the fraction of spoof samples that were
83 |                         rejected by the ASV system.
84 |     cost_model          A struct that contains the parameters of t-DCF,
85 |                         with the following fields.
86 | 
87 |                         Ptar        Prior probability of target speaker.
88 |                         Pnon        Prior probability of nontarget speaker (zero-effort impostor)
89 |                         Pspoof      Prior probability of spoofing attack.
90 |                         Cmiss_asv   Cost of ASV falsely rejecting target.
91 |                         Cfa_asv     Cost of ASV falsely accepting nontarget.
92 |                         Cmiss_cm    Cost of CM falsely rejecting target.
93 |                         Cfa_cm      Cost of CM falsely accepting spoof.
94 | 
95 |     print_cost          Print a summary of the cost parameters and the
96 |                         implied t-DCF cost function?
97 | 
98 |     OUTPUTS:
99 | 
100 |     tDCF_norm           Normalized t-DCF curve across the different CM
101 |                         system operating points; see [2] for more details.
102 |                         Normalized t-DCF > 1 indicates a useless
103 |                         countermeasure (as the tandem system would do
104 |                         better without it). min(tDCF_norm) will be the
105 |                         minimum t-DCF used in ASVspoof 2019 [2].
106 |     CM_thresholds       Vector of same size as tDCF_norm corresponding to
107 |                         the CM threshold (operating point).
108 | 
109 |     NOTE:
110 |     o   In relative terms, higher detection score values are assumed to
111 |         indicate stronger support for the bona fide hypothesis.
112 |     o   You should provide real-valued soft scores, NOT hard decisions. The
113 |         recommendation is that the scores are log-likelihood ratios (LLRs)
114 |         from a bonafide-vs-spoof hypothesis based on some statistical model.
115 |         This, however, is NOT required. The scores can have arbitrary range
116 |         and scaling.
117 |     o   Pfa_asv, Pmiss_asv, Pmiss_spoof_asv are in fractions, not percentages.
118 | 
119 |     References:
120 | 
121 |       [1] T. Kinnunen, K.-A. Lee, H. Delgado, N. Evans, M. Todisco,
122 |           M. Sahidullah, J. Yamagishi, D.A. Reynolds: "t-DCF: a Detection
123 |           Cost Function for the Tandem Assessment of Spoofing Countermeasures
124 |           and Automatic Speaker Verification", Proc. Odyssey 2018: the
125 |           Speaker and Language Recognition Workshop, pp. 312--319, Les Sables d'Olonne,
126 |           France, June 2018 (https://www.isca-speech.org/archive/Odyssey_2018/pdfs/68.pdf)
127 | 
128 |       [2] ASVspoof 2019 challenge evaluation plan
129 |           TODO:
130 |     """
131 | 
132 | 
133 |     # Sanity check of cost parameters
134 |     if cost_model['Cfa_asv'] < 0 or cost_model['Cmiss_asv'] < 0 or \
135 |             cost_model['Cfa_cm'] < 0 or cost_model['Cmiss_cm'] < 0:
136 |         print('WARNING: Usually the cost values should be positive!')
137 | 
138 |     if cost_model['Ptar'] < 0 or cost_model['Pnon'] < 0 or cost_model['Pspoof'] < 0 or \
139 |             np.abs(cost_model['Ptar'] + cost_model['Pnon'] + cost_model['Pspoof'] - 1) > 1e-10:
140 |         sys.exit('ERROR: Your prior probabilities should be positive and sum up to one.')
141 | 
142 |     # Unless we evaluate worst-case model, we need to have some spoof tests against asv
143 |     if Pmiss_spoof_asv is None:
144 |         sys.exit('ERROR: you should provide miss rate of spoof tests against your ASV system.')
145 | 
146 |     # Sanity check of scores
147 |     combined_scores = np.concatenate((bonafide_score_cm, spoof_score_cm))
148 |     if np.isnan(combined_scores).any() or np.isinf(combined_scores).any():
149 |         sys.exit('ERROR: Your scores contain nan or inf.')
150 | 
151 |     # Sanity check that inputs are scores and not decisions
152 |     n_uniq = np.unique(combined_scores).size
153 |     if n_uniq < 3:
154 |         sys.exit('ERROR: You should provide soft CM scores - not binary decisions')
155 | 
156 |     # Obtain miss and false alarm rates of CM
157 |     Pmiss_cm, Pfa_cm, CM_thresholds = compute_det_curve(bonafide_score_cm, spoof_score_cm)
158 | 
159 |     # Constants - see ASVspoof 2019 evaluation plan
160 |     C1 = cost_model['Ptar'] * (cost_model['Cmiss_cm'] - cost_model['Cmiss_asv'] * Pmiss_asv) - \
161 |         cost_model['Pnon'] * cost_model['Cfa_asv'] * Pfa_asv
162 |     C2 = cost_model['Cfa_cm'] * cost_model['Pspoof'] * (1 - Pmiss_spoof_asv)
163 | 
164 |     # Sanity check of the weights
165 |     if C1 < 0 or C2 < 0:
166 |         sys.exit('You should never see this error but I cannot evaluate tDCF with negative weights - please check whether your ASV error rates are correctly computed?')
167 | 
168 |     # Obtain t-DCF curve for all thresholds
169 |     tDCF = C1 * Pmiss_cm + C2 * Pfa_cm
170 | 
171 |     # Normalized t-DCF
172 |     tDCF_norm = tDCF / np.minimum(C1, C2)
173 | 
174 |     # Everything should be fine if reaching here.
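# At this point tDCF_norm(s) = (C1 * Pmiss_cm(s) + C2 * Pfa_cm(s)) / min(C1, C2),
# so in the printout below one coefficient is exactly 1 and the other is the
# cost ratio; min(tDCF_norm) over CM_thresholds is the reported min t-DCF.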
175 | if print_cost: 176 | 177 | print('t-DCF evaluation from [Nbona={}, Nspoof={}] trials\n'.format(bonafide_score_cm.size, spoof_score_cm.size)) 178 | # print('t-DCF MODEL') 179 | # print(' Ptar = {:8.5f} (Prior probability of target user)'.format(cost_model['Ptar'])) 180 | # print(' Pnon = {:8.5f} (Prior probability of nontarget user)'.format(cost_model['Pnon'])) 181 | # print(' Pspoof = {:8.5f} (Prior probability of spoofing attack)'.format(cost_model['Pspoof'])) 182 | # print(' Cfa_asv = {:8.5f} (Cost of ASV falsely accepting a nontarget)'.format(cost_model['Cfa_asv'])) 183 | # print(' Cmiss_asv = {:8.5f} (Cost of ASV falsely rejecting target speaker)'.format(cost_model['Cmiss_asv'])) 184 | # print(' Cfa_cm = {:8.5f} (Cost of CM falsely passing a spoof to ASV system)'.format(cost_model['Cfa_cm'])) 185 | # print(' Cmiss_cm = {:8.5f} (Cost of CM falsely blocking target utterance which never reaches ASV)'.format(cost_model['Cmiss_cm'])) 186 | # print('\n Implied normalized t-DCF function (depends on t-DCF parameters and ASV errors), s=CM threshold)') 187 | 188 | if C2 == np.minimum(C1, C2): 189 | print(' tDCF_norm(s) = {:8.5f} x Pmiss_cm(s) + Pfa_cm(s)\n'.format(C1 / C2)) 190 | else: 191 | print(' tDCF_norm(s) = Pmiss_cm(s) + {:8.5f} x Pfa_cm(s)\n'.format(C2 / C1)) 192 | 193 | return tDCF_norm, CM_thresholds 194 | -------------------------------------------------------------------------------- /raw_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor 6 | from torch.utils.data import Dataset, DataLoader 7 | import scipy.io as sio 8 | import pickle 9 | import os 10 | import librosa 11 | from torch.utils.data.dataloader import default_collate 12 | import warnings 13 | from typing import Any, Tuple, Union 14 | from pathlib import Path 15 | 16 | 17 | torch.set_default_tensor_type(torch.FloatTensor) 18 | 19 | SampleType = Tuple[Tensor, int, str, str, str] 20 | 21 | def torchaudio_load(filepath): 22 | wave, sr = librosa.load(filepath, sr=16000) 23 | wave = librosa.util.normalize(wave) 24 | waveform = torch.Tensor(np.expand_dims(wave, axis=0)) 25 | return [waveform, sr] 26 | 27 | class ASVspoof2019Raw(Dataset): 28 | def __init__(self, access_type, path_to_database, path_to_protocol, part='train'): 29 | super(ASVspoof2019Raw, self).__init__() 30 | self.access_type = access_type 31 | self.ptd = path_to_database 32 | self.part = part 33 | self.path_to_audio = os.path.join(self.ptd, access_type, 'ASVspoof2019_'+access_type+'_'+ self.part +'/flac/') 34 | self.path_to_protocol = path_to_protocol 35 | protocol = os.path.join(self.path_to_protocol, 'ASVspoof2019.'+access_type+'.cm.'+ self.part + '.trl.txt') 36 | if self.part == "eval": 37 | protocol = os.path.join(self.ptd, access_type, 'ASVspoof2019_' + access_type + 38 | '_cm_protocols/ASVspoof2019.' + access_type + '.cm.' 
+ self.part + '.trl.txt') 39 | if self.access_type == 'LA': 40 | self.tag = {"-": 0, "A01": 1, "A02": 2, "A03": 3, "A04": 4, "A05": 5, "A06": 6, "A07": 7, "A08": 8, "A09": 9, 41 | "A10": 10, "A11": 11, "A12": 12, "A13": 13, "A14": 14, "A15": 15, "A16": 16, "A17": 17, "A18": 18, 42 | "A19": 19} 43 | else: 44 | self.tag = {"-": 0, "AA": 1, "AB": 2, "AC": 3, "BA": 4, "BB": 5, "BC": 6, "CA": 7, "CB": 8, "CC": 9} 45 | self.label = {"spoof": 1, "bonafide": 0} 46 | 47 | # # would not work if change data split but this csv is only for feat_len 48 | # self.csv = pd.read_csv(self.ptf + "Set_csv.csv") 49 | 50 | with open(protocol, 'r') as f: 51 | audio_info = [info.strip().split() for info in f.readlines()] 52 | self.all_info = audio_info 53 | 54 | def __len__(self): 55 | return len(self.all_info) 56 | 57 | def __getitem__(self, idx): 58 | speaker, filename, _, tag, label = self.all_info[idx] 59 | filepath = os.path.join(self.path_to_audio, filename + ".flac") 60 | waveform, sr = torchaudio_load(filepath) 61 | 62 | return waveform, filename, tag, label 63 | 64 | def collate_fn(self, samples): 65 | return default_collate(samples) 66 | 67 | 68 | class VCC2020Raw(Dataset): 69 | def __init__(self, path_to_spoof="/data2/neil/nii-yamagishilab-VCC2020-listeningtest-31f913c", path_to_bonafide="/data2/neil/nii-yamagishilab-VCC2020-database-0b2fb2e"): 70 | super(VCC2020Raw, self).__init__() 71 | self.all_spoof = librosa.util.find_files(path_to_spoof, ext="wav") 72 | self.all_bonafide = librosa.util.find_files(path_to_bonafide, ext="wav") 73 | 74 | def __len__(self): 75 | # print(len(self.all_spoof), len(self.all_bonafide)) 76 | return len(self.all_spoof) + len(self.all_bonafide) 77 | 78 | def __getitem__(self, idx): 79 | if idx < len(self.all_bonafide): 80 | filepath = self.all_bonafide[idx] 81 | label = "bonafide" 82 | filename = "_".join(filepath.split("/")[-3:])[:-4] 83 | tag = "-" 84 | else: 85 | filepath = self.all_spoof[idx - len(self.all_bonafide)] 86 | filename = os.path.basename(filepath)[:-4] 87 | label = "spoof" 88 | tag = filepath.split("/")[-3] 89 | waveform, sr = torchaudio_load(filepath) 90 | 91 | return waveform, filename, tag, label 92 | 93 | def collate_fn(self, samples): 94 | return default_collate(samples) 95 | 96 | 97 | class ASVspoof2015Raw(Dataset): 98 | def __init__(self, path_to_database="/data/neil/ASVspoof2015/wav", path_to_protocol="/data/neil/ASVspoof2015/CM_protocol", part='train'): 99 | super(ASVspoof2015Raw, self).__init__() 100 | self.ptd = path_to_database 101 | self.part = part 102 | self.path_to_audio = os.path.join(self.ptd, self.part) 103 | self.path_to_protocol = path_to_protocol 104 | cm_pro_dict = {"train": "cm_train.trn", "dev": "cm_develop.ndx", "eval": "cm_evaluation.ndx"} 105 | protocol = os.path.join(self.path_to_protocol, cm_pro_dict[self.part]) 106 | self.tag = {"human": 0, "S1": 1, "S2": 2, "S3": 3, "S4": 4, "S5": 5, 107 | "S6": 6, "S7": 7, "S8": 8, "S9": 9, "S10": 10} 108 | self.label = {"spoof": 1, "human": 0} 109 | 110 | with open(protocol, 'r') as f: 111 | audio_info = [info.strip().split() for info in f.readlines()] 112 | self.all_info = audio_info 113 | 114 | def __len__(self): 115 | return len(self.all_info) 116 | 117 | def __getitem__(self, idx): 118 | speaker, filename, tag, label = self.all_info[idx] 119 | filepath = os.path.join(self.path_to_audio, speaker, filename + ".wav") 120 | waveform, sr = torchaudio_load(filepath) 121 | filename = filename.replace("_", "-") 122 | return waveform, filename, tag, label 123 | 124 | def collate_fn(self, samples): 
125 | return default_collate(samples) 126 | 127 | 128 | class ASVspoof2019LARaw_withChannel(Dataset): 129 | def __init__(self, access_type="LA", path_to_database="/data/shared/ASVspoof2019Channel", path_to_protocol="/data/neil/DS_10283_3336/LA/ASVspoof2019_LA_cm_protocols/", part='train'): 130 | super(ASVspoof2019LARaw_withChannel, self).__init__() 131 | self.access_type = access_type 132 | self.ptd = path_to_database 133 | self.part = part 134 | self.path_to_audio = path_to_database 135 | self.path_to_protocol = path_to_protocol 136 | protocol = os.path.join(self.path_to_protocol, 137 | 'ASVspoof2019.' + access_type + '.cm.' + self.part + '.trl.txt') 138 | if self.part == "eval": 139 | protocol = os.path.join(self.ptd, access_type, 'ASVspoof2019_' + access_type + 140 | '_cm_protocols/ASVspoof2019.' + access_type + '.cm.' + self.part + '.trl.txt') 141 | self.tag = {"-": 0, "A01": 1, "A02": 2, "A03": 3, "A04": 4, "A05": 5, "A06": 6, "A07": 7, "A08": 8, 142 | "A09": 9, 143 | "A10": 10, "A11": 11, "A12": 12, "A13": 13, "A14": 14, "A15": 15, "A16": 16, "A17": 17, 144 | "A18": 18, 145 | "A19": 19} 146 | self.label = {"spoof": 1, "bonafide": 0} 147 | self.channel = ['amr[br=5k15]', 'amrwb[br=15k85]', 'g711[law=u]', 'g722[br=56k]', 148 | 'g722[br=64k]', 'g726[law=a,br=16k]', 'g728', 'g729a', 'gsmfr', 149 | 'silk[br=20k]', 'silk[br=5k]', 'silkwb[br=10k,loss=5]', 'silkwb[br=30k]'] 150 | 151 | with open(protocol, 'r') as f: 152 | audio_info = [info.strip().split() for info in f.readlines()] 153 | self.all_info = audio_info 154 | 155 | def __len__(self): 156 | return len(self.all_info) * len(self.channel) 157 | 158 | def __getitem__(self, idx): 159 | file_idx = idx // len(self.channel) 160 | channel_idx = idx % len(self.channel) 161 | speaker, filename, _, tag, label = self.all_info[file_idx] 162 | channel = self.channel[channel_idx] 163 | filepath = os.path.join(self.path_to_audio, filename + "_" + channel + ".wav") 164 | waveform, sr = torchaudio_load(filepath) 165 | 166 | return waveform, filename, tag, label, channel 167 | 168 | def collate_fn(self, samples): 169 | return default_collate(samples) 170 | 171 | 172 | class ASVspoof2019LARaw_withDevice(Dataset): 173 | def __init__(self, access_type="LA", path_to_database="/data/shared/antispoofying2019-eval", path_to_protocol="/data/neil/DS_10283_3336/LA/ASVspoof2019_LA_cm_protocols/", part='eval'): 174 | super(ASVspoof2019LARaw_withDevice, self).__init__() 175 | self.access_type = access_type 176 | self.ptd = path_to_database 177 | self.part = part 178 | self.path_to_audio = path_to_database 179 | self.path_to_protocol = path_to_protocol 180 | protocol = os.path.join(self.path_to_protocol, 181 | 'ASVspoof2019.' + access_type + '.cm.' + self.part + '.trl.txt') 182 | # if self.part == "eval": 183 | # protocol = os.path.join(self.ptd, access_type, 'ASVspoof2019_' + access_type + 184 | # '_cm_protocols/ASVspoof2019.' + access_type + '.cm.' 
+ self.part + '.trl.txt') 185 | self.tag = {"-": 0, "A01": 1, "A02": 2, "A03": 3, "A04": 4, "A05": 5, "A06": 6, "A07": 7, "A08": 8, 186 | "A09": 9, 187 | "A10": 10, "A11": 11, "A12": 12, "A13": 13, "A14": 14, "A15": 15, "A16": 16, "A17": 17, 188 | "A18": 18, 189 | "A19": 19} 190 | self.label = {"spoof": 1, "bonafide": 0} 191 | self.devices = ['AKSPKRS80sUk002-16000', 'AKSPKRSVinUk002-16000', 'Doremi-16000', 'RCAPB90-16000', 192 | 'ResloRBRedLabel-16000', 'AKSPKRSSpeaker002-16000', 'BehritoneirRecording-16000', 193 | 'OktavaML19-16000', 'ResloRB250-16000', 'SonyC37Fet-16000', 'iPadirRecording-16000', 'iPhoneirRecording-16000'] 194 | 195 | with open(protocol, 'r') as f: 196 | audio_info = [info.strip().split() for info in f.readlines()] 197 | self.all_info = audio_info 198 | 199 | def __len__(self): 200 | return len(self.all_info) * len(self.devices) 201 | 202 | def __getitem__(self, idx): 203 | file_idx = idx // len(self.devices) 204 | device_idx = idx % len(self.devices) 205 | speaker, filename, _, tag, label = self.all_info[file_idx] 206 | device = self.devices[device_idx] 207 | filepath = os.path.join(self.path_to_audio, device, filename + ".wav") 208 | waveform, sr = torchaudio_load(filepath) 209 | 210 | return waveform, filename, tag, label, device 211 | 212 | def collate_fn(self, samples): 213 | return default_collate(samples) 214 | 215 | if __name__ == "__main__": 216 | # vctk = VCTK_092(root="/data/neil/VCTK", download=False) 217 | # print(len(vctk)) 218 | # waveform, sample_rate, utterance, speaker_id, utterance_id = vctk[124] 219 | # print(waveform.shape) 220 | # print(sample_rate) 221 | # print(utterance) 222 | # print(speaker_id) 223 | # print(utterance_id) 224 | # 225 | # librispeech = LIBRISPEECH(root="/data/neil") 226 | # print(len(librispeech)) 227 | # waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id = librispeech[164] 228 | # print(waveform.shape) 229 | # print(sample_rate) 230 | # print(utterance) 231 | # print(speaker_id) 232 | # print(chapter_id) 233 | # print(utterance_id) 234 | # 235 | # libriGen = LibriGenuine("/dataNVME/neil/libriSpeech/", feature='LFCC', feat_len=750, pad_chop=True, padding='repeat') 236 | # print(len(libriGen)) 237 | # featTensor, tag, label = libriGen[123] 238 | # print(featTensor.shape) 239 | # print(tag) 240 | # print(label) 241 | # 242 | # asvspoof_raw = ASVspoof2019Raw("LA", "/data/neil/DS_10283_3336/", "/data/neil/DS_10283_3336/LA/ASVspoof2019_LA_cm_protocols/", part="eval") 243 | # print(len(asvspoof_raw)) 244 | # waveform, filename, tag, label = asvspoof_raw[123] 245 | # print(waveform.shape) 246 | # print(filename) 247 | # print(tag) 248 | # print(label) 249 | 250 | # vcc2020_raw = VCC2020Raw() 251 | # print(len(vcc2020_raw)) 252 | # waveform, filename, tag, label = vcc2020_raw[123] 253 | # print(waveform.shape) 254 | # print(filename) 255 | # print(tag) 256 | # print(label) 257 | 258 | asvspoof2019channel = ASVspoof2019LARaw_withChannel() 259 | print(len(asvspoof2019channel)) 260 | waveform, filename, tag, label, channel = asvspoof2019channel[123] 261 | print(waveform.shape) 262 | print(filename) 263 | print(tag) 264 | print(label) 265 | print(channel) 266 | pass 267 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor 6 | from torch.utils.data import Dataset, DataLoader 7 | import pickle 8 | import 
os 9 | import librosa 10 | from feature_extraction import LFCC 11 | from torch.utils.data.dataloader import default_collate 12 | 13 | lfcc = LFCC(320, 160, 512, 16000, 20, with_energy=False) 14 | wavform = torch.Tensor(np.expand_dims([0]*3200, axis=0)) 15 | lfcc_silence = lfcc(wavform) 16 | silence_pad_value = lfcc_silence[:,0,:].unsqueeze(0) 17 | 18 | class ASVspoof2019(Dataset): 19 | def __init__(self, access_type, path_to_features, part='train', feature='LFCC', 20 | feat_len=750, padding='repeat', genuine_only=False): 21 | super(ASVspoof2019, self).__init__() 22 | self.access_type = access_type 23 | self.path_to_features = path_to_features 24 | self.part = part 25 | self.ptf = os.path.join(path_to_features, self.part) 26 | self.feat_len = feat_len 27 | self.feature = feature 28 | self.padding = padding 29 | self.genuine_only = genuine_only 30 | if self.access_type == 'LA': 31 | self.tag = {"-": 0, "A01": 1, "A02": 2, "A03": 3, "A04": 4, "A05": 5, "A06": 6, "A07": 7, "A08": 8, "A09": 9, 32 | "A10": 10, "A11": 11, "A12": 12, "A13": 13, "A14": 14, "A15": 15, "A16": 16, "A17": 17, "A18": 18, 33 | "A19": 19} 34 | elif self.access_type == 'PA': 35 | self.tag = {"-": 0, "AA": 1, "AB": 2, "AC": 3, "BA": 4, "BB": 5, "BC": 6, "CA": 7, "CB": 8, "CC": 9} 36 | else: 37 | raise ValueError("Access type should be LA or PA!") 38 | self.label = {"spoof": 1, "bonafide": 0} 39 | self.all_files = librosa.util.find_files(os.path.join(self.ptf, self.feature), ext="pt") 40 | if self.genuine_only: 41 | assert self.access_type == "LA" 42 | if self.part in ["train", "dev"]: 43 | num_bonafide = {"train": 2580, "dev": 2548} 44 | self.all_files = self.all_files[:num_bonafide[self.part]] 45 | else: 46 | res = [] 47 | for item in self.all_files: 48 | if "bonafide" in item: 49 | res.append(item) 50 | self.all_files = res 51 | assert len(self.all_files) == 7355 52 | 53 | def __len__(self): 54 | return len(self.all_files) 55 | 56 | def __getitem__(self, idx): 57 | filepath = self.all_files[idx] 58 | basename = os.path.basename(filepath) 59 | all_info = basename.split(".")[0].split("_") 60 | # assert len(all_info) == 6 61 | featureTensor = torch.load(filepath) 62 | this_feat_len = featureTensor.shape[1] 63 | if this_feat_len > self.feat_len: 64 | startp = np.random.randint(this_feat_len - self.feat_len) 65 | featureTensor = featureTensor[:, startp:startp + self.feat_len, :] 66 | if this_feat_len < self.feat_len: 67 | if self.padding == 'zero': 68 | featureTensor = padding_Tensor(featureTensor, self.feat_len) 69 | elif self.padding == 'repeat': 70 | featureTensor = repeat_padding_Tensor(featureTensor, self.feat_len) 71 | elif self.padding == 'silence': 72 | featureTensor = silence_padding_Tensor(featureTensor, self.feat_len) 73 | else: 74 | raise ValueError('Padding should be zero or repeat!') 75 | filename = "_".join(all_info[1:4]) 76 | tag = self.tag[all_info[4]] 77 | label = self.label[all_info[5]] 78 | return featureTensor, filename, tag, label, 2019 79 | 80 | def collate_fn(self, samples): 81 | return default_collate(samples) 82 | 83 | 84 | class VCC2020(Dataset): 85 | def __init__(self, path_to_features="/data2/neil/VCC2020/", feature='LFCC', 86 | feat_len=750, padding='repeat', genuine_only=False): 87 | super(VCC2020, self).__init__() 88 | self.ptf = path_to_features 89 | self.feat_len = feat_len 90 | self.feature = feature 91 | self.padding = padding 92 | self.tag = {"-": 0, "SOU": 20, "T01": 21, "T02": 22, "T03": 23, "T04": 24, "T05": 25, "T06": 26, "T07": 27, "T08": 28, "T09": 29, 93 | "T10": 30, "T11": 31, 
"T12": 32, "T13": 33, "T14": 34, "T15": 35, "T16": 36, "T17": 37, "T18": 38, "T19": 39, 94 | "T20": 40, "T21": 41, "T22": 42, "T23": 43, "T24": 44, "T25": 45, "T26": 46, "T27": 47, "T28": 48, "T29": 49, 95 | "T30": 50, "T31": 51, "T32": 52, "T33": 53, "TAR": 54} 96 | self.label = {"spoof": 1, "bonafide": 0} 97 | self.genuine_only = genuine_only 98 | self.all_files = librosa.util.find_files(os.path.join(self.ptf, self.feature), ext="pt") 99 | 100 | def __len__(self): 101 | if self.genuine_only: 102 | return 220 103 | return len(self.all_files) 104 | 105 | def __getitem__(self, idx): 106 | filepath = self.all_files[idx] 107 | basename = os.path.basename(filepath) 108 | all_info = basename.split(".")[0].split("_") 109 | featureTensor = torch.load(filepath) 110 | this_feat_len = featureTensor.shape[1] 111 | if this_feat_len > self.feat_len: 112 | startp = np.random.randint(this_feat_len - self.feat_len) 113 | featureTensor = featureTensor[:, startp:startp + self.feat_len, :] 114 | if this_feat_len < self.feat_len: 115 | if self.padding == 'zero': 116 | featureTensor = padding_Tensor(featureTensor, self.feat_len) 117 | elif self.padding == 'repeat': 118 | featureTensor = repeat_padding_Tensor(featureTensor, self.feat_len) 119 | elif self.padding == 'silence': 120 | featureTensor = silence_padding_Tensor(featureTensor, self.feat_len) 121 | else: 122 | raise ValueError('Padding should be zero or repeat!') 123 | tag = self.tag[all_info[-2]] 124 | label = self.label[all_info[-1]] 125 | return featureTensor, basename, tag, label, 2020 126 | 127 | def collate_fn(self, samples): 128 | return default_collate(samples) 129 | 130 | 131 | class ASVspoof2015(Dataset): 132 | def __init__(self, path_to_features, part='train', feature='LFCC', feat_len=750, 133 | padding='repeat', genuine_only=False): 134 | super(ASVspoof2015, self).__init__() 135 | self.path_to_features = path_to_features 136 | self.part = part 137 | self.ptf = os.path.join(path_to_features, self.part) 138 | self.feat_len = feat_len 139 | self.feature = feature 140 | self.padding = padding 141 | self.tag = {"human": 0, "S1": 1, "S2": 2, "S3": 3, "S4": 4, "S5": 5, 142 | "S6": 6, "S7": 7, "S8": 8, "S9": 9, "S10": 10} 143 | self.label = {"spoof": 1, "human": 0} 144 | self.all_files = librosa.util.find_files(os.path.join(self.ptf, self.feature), ext="pt") 145 | 146 | def __len__(self): 147 | return len(self.all_files) 148 | 149 | def __getitem__(self, idx): 150 | filepath = self.all_files[idx] 151 | basename = os.path.basename(filepath) 152 | all_info = basename.split(".")[0].split("_") 153 | assert len(all_info) == 4 154 | featureTensor = torch.load(filepath) 155 | this_feat_len = featureTensor.shape[1] 156 | if this_feat_len > self.feat_len: 157 | startp = np.random.randint(this_feat_len - self.feat_len) 158 | featureTensor = featureTensor[:, startp:startp + self.feat_len, :] 159 | if this_feat_len < self.feat_len: 160 | if self.padding == 'zero': 161 | featureTensor = padding_Tensor(featureTensor, self.feat_len) 162 | elif self.padding == 'repeat': 163 | featureTensor = repeat_padding_Tensor(featureTensor, self.feat_len) 164 | elif self.padding == 'silence': 165 | featureTensor = silence_padding_Tensor(featureTensor, self.feat_len) 166 | else: 167 | raise ValueError('Padding should be zero or repeat!') 168 | filename = all_info[1] 169 | tag = self.tag[all_info[-2]] 170 | label = self.label[all_info[-1]] 171 | return featureTensor, filename, tag, label, 2015 172 | 173 | def collate_fn(self, samples): 174 | return default_collate(samples) 175 | 
176 | 177 | class ASVspoof2019LA_DeviceAdversarial(Dataset): 178 | def __init__(self, path_to_features="/data2/neil/ASVspoof2019LA/", path_to_deviced="/dataNVME/neil/ASVspoof2019LADevice", 179 | part="train", feature='LFCC', feat_len=750, padding='repeat'): 180 | super(ASVspoof2019LA_DeviceAdversarial, self).__init__() 181 | self.path_to_features = path_to_features 182 | suffix = {"train" : "", "dev":"Dev", "eval": "Eval"} 183 | self.path_to_deviced = path_to_deviced + suffix[part] 184 | 185 | self.ptf = os.path.join(path_to_features, part) 186 | self.feat_len = feat_len 187 | self.feature = feature 188 | self.padding = padding 189 | self.tag = {"-": 0, "A01": 1, "A02": 2, "A03": 3, "A04": 4, "A05": 5, "A06": 6, "A07": 7, "A08": 8, "A09": 9, 190 | "A10": 10, "A11": 11, "A12": 12, "A13": 13, "A14": 14, "A15": 15, "A16": 16, "A17": 17, "A18": 18, 191 | "A19": 19} 192 | self.label = {"spoof": 1, "bonafide": 0} 193 | self.devices = ['AKSPKRS80sUk002-16000', 'AKSPKRSVinUk002-16000', 'Doremi-16000', 'RCAPB90-16000', 194 | 'ResloRBRedLabel-16000', 'AKSPKRSSpeaker002-16000', 'BehritoneirRecording-16000', 195 | 'OktavaML19-16000', 'ResloRB250-16000', 'SonyC37Fet-16000'] 196 | if part == "eval": 197 | self.devices = ['AKSPKRS80sUk002-16000', 'AKSPKRSVinUk002-16000', 'Doremi-16000', 'RCAPB90-16000', 198 | 'ResloRBRedLabel-16000', 'AKSPKRSSpeaker002-16000', 'BehritoneirRecording-16000', 199 | 'OktavaML19-16000', 'ResloRB250-16000', 'SonyC37Fet-16000', 200 | 'iPadirRecording-16000', 'iPhoneirRecording-16000'] 201 | self.original_all_files = librosa.util.find_files(os.path.join(self.ptf, self.feature), ext="pt") 202 | self.deviced_all_files = [librosa.util.find_files(os.path.join(self.path_to_deviced, devicex), ext="pt") for devicex in self.devices] 203 | 204 | def __len__(self): 205 | return len(self.original_all_files) * (len(self.devices) + 1) 206 | 207 | def __getitem__(self, idx): 208 | device_idx = idx % (len(self.devices) + 1) 209 | filename_idx = idx // (len(self.devices) + 1) 210 | if device_idx == 0: 211 | filepath = self.original_all_files[filename_idx] 212 | else: 213 | filepath = self.deviced_all_files[device_idx-1][filename_idx] 214 | basename = os.path.basename(filepath) 215 | all_info = basename.split(".")[0].split("_") 216 | featureTensor = torch.load(filepath) 217 | this_feat_len = featureTensor.shape[1] 218 | 219 | if this_feat_len > self.feat_len: 220 | startp = np.random.randint(this_feat_len - self.feat_len) 221 | featureTensor = featureTensor[:, startp:startp + self.feat_len, :] 222 | if this_feat_len < self.feat_len: 223 | if self.padding == 'zero': 224 | featureTensor = padding_Tensor(featureTensor, self.feat_len) 225 | elif self.padding == 'repeat': 226 | featureTensor = repeat_padding_Tensor(featureTensor, self.feat_len) 227 | elif self.padding == 'silence': 228 | featureTensor = silence_padding_Tensor(featureTensor, self.feat_len) 229 | else: 230 | raise ValueError('Padding should be zero, repeat, or silence!') 231 | filename = "_".join(all_info[1:4]) 232 | tag = self.tag[all_info[4]] 233 | label = self.label[all_info[5]] 234 | return featureTensor, filename, tag, label, device_idx 235 | 236 | def collate_fn(self, samples): 237 | return default_collate(samples) 238 | 239 | 240 | def padding_Tensor(spec, ref_len): 241 | _, cur_len, width = spec.shape 242 | assert ref_len > cur_len 243 | padd_len = ref_len - cur_len 244 | return torch.cat((spec, torch.zeros((1, padd_len, width), dtype=spec.dtype)), 1) 245 | 246 | def repeat_padding_Tensor(spec, ref_len): 247 | mul = int(np.ceil(ref_len / spec.shape[1])) 248 | spec = spec.repeat(1, mul, 1)[:, :ref_len, :] 249 | return spec 250 | 251 | def silence_padding_Tensor(spec, ref_len): 252 | _, cur_len, width = spec.shape 253 | assert ref_len > cur_len 254 | padd_len = ref_len - cur_len 255 | return torch.cat((silence_pad_value.repeat(1, padd_len, 1).to(spec.device), spec), 1) 256 | 257 |
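# Editor's sketch (not part of the original file): quick sanity check of the three
# padding helpers; each yields a (1, ref_len, n_coeffs) tensor. 'zero' appends zero
# frames, 'repeat' tiles the utterance, and 'silence' prepends silence LFCC frames.
if __name__ == "__main__":
    demo = torch.randn(1, 500, silence_pad_value.shape[-1])
    assert padding_Tensor(demo, 750).shape[1] == 750
    assert repeat_padding_Tensor(demo, 750).shape[1] == 750
    assert silence_padding_Tensor(demo, 750).shape[1] == 750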
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from tqdm import tqdm 5 | import random 6 | import numpy as np 7 | from dataset import * 8 | from torch.utils.data import DataLoader 9 | import torch.nn.functional as F 10 | import eval_metrics as em 11 | from evaluate_tDCF_asvspoof19 import compute_eer_and_tdcf 12 | import time 13 | import argparse 14 | 15 | 16 | ## Adapted from https://github.com/pytorch/audio/tree/master/torchaudio 17 | ## https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/blob/newfunctions/ 18 | 19 | def str2bool(v): 20 | return str(v).lower() in ("yes", "true", "t", "1")  # drop-in for distutils.util.strtobool (distutils is removed in Python 3.12) 21 | 22 | def setup_seed(random_seed, cudnn_deterministic=True): 23 | """ set_random_seed(random_seed, cudnn_deterministic=True) 24 | 25 | Set the random_seed for numpy, python, and cudnn 26 | 27 | input 28 | ----- 29 | random_seed: integer random seed 30 | cudnn_deterministic: for torch.backends.cudnn.deterministic 31 | 32 | Note: this default configuration may result in RuntimeError 33 | see https://pytorch.org/docs/stable/notes/randomness.html 34 | """ 35 | 36 | # # initialization 37 | # torch.manual_seed(random_seed) 38 | random.seed(random_seed) 39 | np.random.seed(random_seed) 40 | os.environ['PYTHONHASHSEED'] = str(random_seed) 41 | 42 | if torch.cuda.is_available(): 43 | torch.cuda.manual_seed_all(random_seed) 44 | torch.backends.cudnn.deterministic = cudnn_deterministic 45 | torch.backends.cudnn.benchmark = False 46 | 47 | def init(): 48 | parser = argparse.ArgumentParser("load model scores") 49 | parser.add_argument('-m', '--model_dir', type=str, help="directory for pretrained model", required=True, 50 | default='/data3/neil/chan/adv1010') 51 | parser.add_argument("-t", "--task", type=str, help="which dataset you would like to test on", 52 | required=True, default='ASVspoof2019LA', 53 | choices=["ASVspoof2019LA", "ASVspoof2015", "VCC2020", "ASVspoof2019LASim"]) 54 | parser.add_argument('-l', '--loss', help='loss for scoring', default="ocsoftmax", 55 | required=False, choices=[None, "ocsoftmax", "amsoftmax", "p2sgrad"]) 56 | parser.add_argument("--gpu", type=str, help="GPU index", default="0") 57 | args = parser.parse_args() 58 | 59 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 60 | args.cuda = torch.cuda.is_available() 61 | args.device = torch.device("cuda" if args.cuda else "cpu") 62 | 63 | return args 64 |
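# Editor's sketch (not in the original file): typical invocation, with the model
# directory as a placeholder.
#
#     python test.py -m ./models/try/ -t ASVspoof2019LA -l ocsoftmax --gpu 0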
"LFCC", feat_len=750, padding="repeat") 77 | testDataLoader = DataLoader(test_set, batch_size=16, shuffle=False, num_workers=0) 78 | model.eval() 79 | score_loader, idx_loader = [], [] 80 | 81 | with open(os.path.join(dir_path, 'checkpoint_cm_score.txt'), 'w') as cm_score_file: 82 | for i, (lfcc, audio_fn, tags, labels, _) in enumerate(tqdm(testDataLoader)): 83 | lfcc = lfcc.transpose(2,3).to(device) 84 | # print(lfcc.shape) 85 | tags = tags.to(device) 86 | labels = labels.to(device) 87 | 88 | feats, lfcc_outputs = model(lfcc) 89 | 90 | score = F.softmax(lfcc_outputs)[:, 0] 91 | # print(score) 92 | 93 | if add_loss == "ocsoftmax": 94 | ang_isoloss, score = loss_model(feats, labels) 95 | elif add_loss == "amsoftmax": 96 | outputs, moutputs = loss_model(feats, labels) 97 | score = F.softmax(outputs, dim=1)[:, 0] 98 | else: pass 99 | 100 | for j in range(labels.size(0)): 101 | cm_score_file.write( 102 | 'A%02d %s %s\n' % (tags[j].data, 103 | "spoof" if labels[j].data.cpu().numpy() else "bonafide", 104 | score[j].item())) 105 | 106 | score_loader.append(score.detach().cpu()) 107 | idx_loader.append(labels.detach().cpu()) 108 | 109 | scores = torch.cat(score_loader, 0).data.cpu().numpy() 110 | labels = torch.cat(idx_loader, 0).data.cpu().numpy() 111 | eer = em.compute_eer(scores[labels == 0], scores[labels == 1])[0] 112 | other_eer = em.compute_eer(-scores[labels == 0], -scores[labels == 1])[0] 113 | eer = min(eer, other_eer) 114 | 115 | return eer 116 | 117 | def test_on_VCC(feat_model_path, loss_model_path, part, add_loss): 118 | dirname = os.path.dirname 119 | basename = os.path.splitext(os.path.basename(feat_model_path))[0] 120 | if "checkpoint" in dirname(feat_model_path): 121 | dir_path = dirname(dirname(feat_model_path)) 122 | else: 123 | dir_path = dirname(feat_model_path) 124 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 125 | model = torch.load(feat_model_path) 126 | # model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))) # for multiple GPUs 127 | loss_model = torch.load(loss_model_path) if add_loss is not None else None 128 | test_set_VCC = VCC2020("/data2/neil/VCC2020/", "LFCC", feat_len=750, padding="repeat") 129 | testDataLoader = DataLoader(test_set_VCC, batch_size=4, shuffle=False, num_workers=0) 130 | model.eval() 131 | score_loader, idx_loader = [], [] 132 | 133 | with open(os.path.join(dir_path, 'checkpoint_cm_score_VCC.txt'), 'w') as cm_score_file: 134 | for i, (lfcc, _, tags, labels, _) in enumerate(tqdm(testDataLoader)): 135 | lfcc = lfcc.transpose(2,3).to(device) 136 | 137 | tags = tags.to(device) 138 | labels = labels.to(device) 139 | 140 | feats, lfcc_outputs = model(lfcc) 141 | 142 | score = F.softmax(lfcc_outputs)[:, 0] 143 | 144 | if add_loss == "ocsoftmax": 145 | ang_isoloss, score = loss_model(feats, labels) 146 | elif add_loss == "amsoftmax": 147 | outputs, moutputs = loss_model(feats, labels) 148 | score = F.softmax(outputs, dim=1)[:, 0] 149 | else: pass 150 | 151 | for j in range(labels.size(0)): 152 | cm_score_file.write( 153 | 'A%02d %s %s\n' % (tags[j].data, 154 | "spoof" if labels[j].data.cpu().numpy() else "bonafide", 155 | score[j].item())) 156 | 157 | score_loader.append(score.detach().cpu()) 158 | idx_loader.append(labels.detach().cpu()) 159 | 160 | scores = torch.cat(score_loader, 0).data.cpu().numpy() 161 | labels = torch.cat(idx_loader, 0).data.cpu().numpy() 162 | eer = em.compute_eer(scores[labels == 0], scores[labels == 1])[0] 163 | other_eer = em.compute_eer(-scores[labels == 0], -scores[labels == 
117 | def test_on_VCC(feat_model_path, loss_model_path, part, add_loss): 118 | dirname = os.path.dirname 119 | basename = os.path.splitext(os.path.basename(feat_model_path))[0] 120 | if "checkpoint" in dirname(feat_model_path): 121 | dir_path = dirname(dirname(feat_model_path)) 122 | else: 123 | dir_path = dirname(feat_model_path) 124 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 125 | model = torch.load(feat_model_path) 126 | # model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))) # for multiple GPUs 127 | loss_model = torch.load(loss_model_path) if add_loss is not None else None 128 | test_set_VCC = VCC2020("/data2/neil/VCC2020/", "LFCC", feat_len=750, padding="repeat") 129 | testDataLoader = DataLoader(test_set_VCC, batch_size=4, shuffle=False, num_workers=0) 130 | model.eval() 131 | score_loader, idx_loader = [], [] 132 | 133 | with open(os.path.join(dir_path, 'checkpoint_cm_score_VCC.txt'), 'w') as cm_score_file: 134 | for i, (lfcc, _, tags, labels, _) in enumerate(tqdm(testDataLoader)): 135 | lfcc = lfcc.transpose(2,3).to(device) 136 | 137 | tags = tags.to(device) 138 | labels = labels.to(device) 139 | 140 | feats, lfcc_outputs = model(lfcc) 141 | 142 | score = F.softmax(lfcc_outputs, dim=1)[:, 0] 143 | 144 | if add_loss == "ocsoftmax": 145 | ang_isoloss, score = loss_model(feats, labels) 146 | elif add_loss == "amsoftmax": 147 | outputs, moutputs = loss_model(feats, labels) 148 | score = F.softmax(outputs, dim=1)[:, 0] 149 | else: pass 150 | 151 | for j in range(labels.size(0)): 152 | cm_score_file.write( 153 | 'A%02d %s %s\n' % (tags[j].data, 154 | "spoof" if labels[j].data.cpu().numpy() else "bonafide", 155 | score[j].item())) 156 | 157 | score_loader.append(score.detach().cpu()) 158 | idx_loader.append(labels.detach().cpu()) 159 | 160 | scores = torch.cat(score_loader, 0).data.cpu().numpy() 161 | labels = torch.cat(idx_loader, 0).data.cpu().numpy() 162 | eer = em.compute_eer(scores[labels == 0], scores[labels == 1])[0] 163 | other_eer = em.compute_eer(-scores[labels == 0], -scores[labels == 1])[0] 164 | eer = min(eer, other_eer) 165 | 166 | return eer 167 | 168 | def test_on_ASVspoof2015(feat_model_path, loss_model_path, part, add_loss): 169 | dirname = os.path.dirname 170 | basename = os.path.splitext(os.path.basename(feat_model_path))[0] 171 | if "checkpoint" in dirname(feat_model_path): 172 | dir_path = dirname(dirname(feat_model_path)) 173 | else: 174 | dir_path = dirname(feat_model_path) 175 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 176 | model = torch.load(feat_model_path) 177 | # model = torch.nn.DataParallel(model, list(range(torch.cuda.device_count()))) # for multiple GPUs 178 | loss_model = torch.load(loss_model_path) if add_loss is not None else None 179 | test_set_2015 = ASVspoof2015("/data2/neil/ASVspoof2015/", part="eval", feature="LFCC", feat_len=750, padding="repeat") 180 | print(len(test_set_2015)) 181 | testDataLoader = DataLoader(test_set_2015, batch_size=4, shuffle=False, num_workers=0) 182 | model.eval() 183 | score_loader, idx_loader = [], [] 184 | 185 | with open(os.path.join(dir_path, 'checkpoint_cm_score_2015.txt'), 'w') as cm_score_file: 186 | for i, (lfcc, audio_fn, tags, labels, _) in enumerate(tqdm(testDataLoader)): 187 | lfcc = lfcc.transpose(2,3).to(device) 188 | tags = tags.to(device) 189 | labels = labels.to(device) 190 | 191 | feats, lfcc_outputs = model(lfcc) 192 | 193 | score = F.softmax(lfcc_outputs, dim=1)[:, 0] 194 | # print(score) 195 | 196 | if add_loss == "ocsoftmax": 197 | ang_isoloss, score = loss_model(feats, labels) 198 | elif add_loss == "amsoftmax": 199 | outputs, moutputs = loss_model(feats, labels) 200 | score = F.softmax(outputs, dim=1)[:, 0] 201 | else: pass 202 | 203 | for j in range(labels.size(0)): 204 | cm_score_file.write( 205 | '%s A%02d %s %s\n' % (audio_fn[j], tags[j].data, 206 | "spoof" if labels[j].data.cpu().numpy() else "bonafide", 207 | score[j].item())) 208 | 209 | score_loader.append(score.detach().cpu()) 210 | idx_loader.append(labels.detach().cpu()) 211 | 212 | scores = torch.cat(score_loader, 0).data.cpu().numpy() 213 | labels = torch.cat(idx_loader, 0).data.cpu().numpy() 214 | eer = em.compute_eer(scores[labels == 0], scores[labels == 1])[0] 215 | other_eer = em.compute_eer(-scores[labels == 0], -scores[labels == 1])[0] 216 | eer = min(eer, other_eer) 217 | 218 | return eer 219 | 220 | def test_individual_attacks(cm_score_file): 221 | # Load CM scores (expects four columns: filename, source, key, score) 222 | cm_data = np.genfromtxt(cm_score_file, dtype=str) 223 | cm_sources = cm_data[:, 1] 224 | cm_keys = cm_data[:, 2] 225 | cm_scores = cm_data[:, 3].astype(float) 226 | 227 | other_cm_scores = -cm_scores 228 | 229 | eer_cm_lst, min_tDCF_lst = [], [] 230 | for attack_idx in range(0, 55): 231 | # Per-attack CM EER only; ASV scores and t-DCF are not used here 232 | 233 | # Extract bona fide (real human) and spoof scores from the CM scores 234 | bona_cm = cm_scores[cm_keys == 'bonafide'] 235 | spoof_cm = cm_scores[cm_sources == 'A%02d' % attack_idx] 236 | 237 | # EER of the standalone CM system for this attack 238 | eer_cm = em.compute_eer(bona_cm, spoof_cm)[0] 239 | 240 | other_eer_cm = em.compute_eer(other_cm_scores[cm_keys == 'bonafide'], other_cm_scores[cm_sources == 'A%02d' % attack_idx])[0] 241 | 242 | eer_cm_lst.append(min(eer_cm, other_eer_cm)) 243 | 244 | return eer_cm_lst 245 |
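# Editor's note (not in the original file): test_individual_attacks expects the
# four-column lines written by test_on_ASVspoof2015 (filename, source, key, score),
# e.g. "D_1000001 A03 spoof -0.8731"; test_model and test_on_VCC write only three
# columns (no filename), so their score files would need a filename column added
# before per-attack analysis.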
246 | def test_on_ASVspoof2019LASim(feat_model_path, loss_model_path, part, add_loss): 247 | dirname = os.path.dirname 248 | basename = os.path.splitext(os.path.basename(feat_model_path))[0] 249 | if "checkpoint" in dirname(feat_model_path): 250 | dir_path = dirname(dirname(feat_model_path)) 251 | else: 252 | dir_path = dirname(feat_model_path) 253 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 254 | model = torch.load(feat_model_path) 255 | loss_model = torch.load(loss_model_path) if add_loss is not None else None 256 | test_set = ASVspoof2019LA_DeviceAdversarial(path_to_features="/data2/neil/ASVspoof2019LA/", 257 | path_to_deviced="/dataNVME/neil/ASVspoof2019LADevice", 258 | part="eval", 259 | feature="LFCC", feat_len=750, 260 | padding="repeat") 261 | testDataLoader = DataLoader(test_set, batch_size=16, shuffle=False, num_workers=0) 262 | model.eval() 263 | score_loader, idx_loader = [], [] 264 | 265 | with open(os.path.join(dir_path, 'checkpoint_cm_score.txt'), 'w') as cm_score_file: 266 | for i, (lfcc, audio_fn, tags, labels, _) in enumerate(tqdm(testDataLoader)): 267 | lfcc = lfcc.transpose(2,3).to(device) 268 | # print(lfcc.shape) 269 | tags = tags.to(device) 270 | labels = labels.to(device) 271 | 272 | feats, lfcc_outputs = model(lfcc) 273 | 274 | score = F.softmax(lfcc_outputs, dim=1)[:, 0] 275 | # print(score) 276 | 277 | if add_loss == "ocsoftmax": 278 | ang_isoloss, score = loss_model(feats, labels) 279 | elif add_loss == "amsoftmax": 280 | outputs, moutputs = loss_model(feats, labels) 281 | score = F.softmax(outputs, dim=1)[:, 0] 282 | else: pass 283 | 284 | for j in range(labels.size(0)): 285 | cm_score_file.write( 286 | 'A%02d %s %s\n' % (tags[j].data, 287 | "spoof" if labels[j].data.cpu().numpy() else "bonafide", 288 | score[j].item())) 289 | 290 | score_loader.append(score.detach().cpu()) 291 | idx_loader.append(labels.detach().cpu()) 292 | 293 | scores = torch.cat(score_loader, 0).data.cpu().numpy() 294 | labels = torch.cat(idx_loader, 0).data.cpu().numpy() 295 | eer = em.compute_eer(scores[labels == 0], scores[labels == 1])[0] 296 | other_eer = em.compute_eer(-scores[labels == 0], -scores[labels == 1])[0] 297 | eer = min(eer, other_eer) 298 | 299 | return eer 300 | 301 | 302 | if __name__ == "__main__": 303 | # GPU selection is handled inside init() via the --gpu flag 304 | 305 | 306 | args = init() 307 | 308 | model_path = os.path.join(args.model_dir, "anti-spoofing_feat_model.pt")  # filename saved by train.py 309 | loss_model_path = os.path.join(args.model_dir, "anti-spoofing_loss_model.pt") 310 | 311 | if args.task == "ASVspoof2019LA": 312 | eer = test_model(model_path, loss_model_path, "eval", args.loss) 313 | elif args.task == "ASVspoof2015": 314 | eer = test_on_ASVspoof2015(model_path, loss_model_path, "eval", args.loss) 315 | elif args.task == "VCC2020": 316 | eer = test_on_VCC(model_path, loss_model_path, "eval", args.loss) 317 | elif args.task == "ASVspoof2019LASim": 318 | eer = test_on_ASVspoof2019LASim(model_path, loss_model_path, "eval", args.loss) 319 | else: 320 | raise ValueError("Evaluation task unknown!") 321 | print(eer) 322 | 323 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch.autograd import Function 6 | import os 7 | import random 8 | import numpy as np 9 | import sys  # used by MaxFeatureMap2D below 10 | ## Adapted from https://github.com/joaomonteirof/e2e_antispoofing 11 | ## https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/blob/newfunctions/ 12 | 13 | 14 | class SelfAttention(nn.Module): 15 | def __init__(self, hidden_size, mean_only=False):
16 | super(SelfAttention, self).__init__() 17 | 18 | #self.output_size = output_size 19 | self.hidden_size = hidden_size 20 | self.att_weights = nn.Parameter(torch.Tensor(1, hidden_size),requires_grad=True) 21 | 22 | self.mean_only = mean_only 23 | 24 | init.kaiming_uniform_(self.att_weights) 25 | 26 | def forward(self, inputs): 27 | batch_size = inputs.size(0) 28 | weights = torch.bmm(inputs, self.att_weights.permute(1, 0).unsqueeze(0).repeat(batch_size, 1, 1)) 29 | 30 | if inputs.size(0)==1: 31 | attentions = F.softmax(torch.tanh(weights),dim=1) 32 | weighted = torch.mul(inputs, attentions.expand_as(inputs)) 33 | else: 34 | attentions = F.softmax(torch.tanh(weights.squeeze()),dim=1) 35 | weighted = torch.mul(inputs, attentions.unsqueeze(2).expand_as(inputs)) 36 | 37 | if self.mean_only: 38 | return weighted.sum(1) 39 | else: 40 | noise = 1e-5*torch.randn(weighted.size()) 41 | 42 | if inputs.is_cuda: 43 | noise = noise.to(inputs.device) 44 | avg_repr, std_repr = weighted.sum(1), (weighted+noise).std(1) 45 | 46 | representations = torch.cat((avg_repr,std_repr),1) 47 | 48 | return representations 49 | 50 | 51 | class PreActBlock(nn.Module): 52 | '''Pre-activation version of the BasicBlock.''' 53 | expansion = 1 54 | 55 | def __init__(self, in_planes, planes, stride, *args, **kwargs): 56 | super(PreActBlock, self).__init__() 57 | self.bn1 = nn.BatchNorm2d(in_planes) 58 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 59 | self.bn2 = nn.BatchNorm2d(planes) 60 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 61 | 62 | if stride != 1 or in_planes != self.expansion*planes: 63 | self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)) 64 | 65 | def forward(self, x): 66 | out = F.relu(self.bn1(x)) 67 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 68 | out = self.conv1(out) 69 | out = self.conv2(F.relu(self.bn2(out))) 70 | out += shortcut 71 | return out 72 | 73 | 74 | class PreActBottleneck(nn.Module): 75 | '''Pre-activation version of the original Bottleneck module.''' 76 | expansion = 4 77 | 78 | def __init__(self, in_planes, planes, stride, *args, **kwargs): 79 | super(PreActBottleneck, self).__init__() 80 | self.bn1 = nn.BatchNorm2d(in_planes) 81 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 82 | self.bn2 = nn.BatchNorm2d(planes) 83 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 84 | self.bn3 = nn.BatchNorm2d(planes) 85 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 86 | 87 | if stride != 1 or in_planes != self.expansion*planes: 88 | self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)) 89 | 90 | def forward(self, x): 91 | out = F.relu(self.bn1(x)) 92 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 93 | out = self.conv1(out) 94 | out = self.conv2(F.relu(self.bn2(out))) 95 | out = self.conv3(F.relu(self.bn3(out))) 96 | out += shortcut 97 | return out 98 | 99 | def conv3x3(in_planes, out_planes, stride=1): 100 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 101 | 102 | def conv1x1(in_planes, out_planes, stride=1): 103 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 104 | 105 | RESNET_CONFIGS = {'18': [[2, 2, 2, 2], PreActBlock], 106 | '28': [[3, 4, 6, 3], 
PreActBlock], 107 | '34': [[3, 4, 6, 3], PreActBlock], 108 | '50': [[3, 4, 6, 3], PreActBottleneck], 109 | '101': [[3, 4, 23, 3], PreActBottleneck] 110 | } 111 | 112 | class ResNet(nn.Module): 113 | def __init__(self, num_nodes, enc_dim, resnet_type='18', nclasses=2): 114 | self.in_planes = 16 115 | super(ResNet, self).__init__() 116 | 117 | layers, block = RESNET_CONFIGS[resnet_type] 118 | 119 | self._norm_layer = nn.BatchNorm2d 120 | 121 | self.conv1 = nn.Conv2d(1, 16, kernel_size=(9, 3), stride=(3, 1), padding=(1, 1), bias=False) 122 | self.bn1 = nn.BatchNorm2d(16) 123 | self.activation = nn.ReLU() 124 | 125 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 126 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 127 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 128 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 129 | 130 | self.conv5 = nn.Conv2d(512 * block.expansion, 256, kernel_size=(num_nodes, 3), stride=(1, 1), padding=(0, 1), 131 | bias=False) 132 | self.bn5 = nn.BatchNorm2d(256) 133 | self.fc = nn.Linear(256 * 2, enc_dim) 134 | self.fc_mu = nn.Linear(enc_dim, nclasses) if nclasses >= 2 else nn.Linear(enc_dim, 1) 135 | 136 | self.initialize_params() 137 | self.attention = SelfAttention(256) 138 | 139 | def initialize_params(self): 140 | for layer in self.modules(): 141 | if isinstance(layer, torch.nn.Conv2d): 142 | init.kaiming_normal_(layer.weight, a=0, mode='fan_out') 143 | elif isinstance(layer, torch.nn.Linear): 144 | init.kaiming_uniform_(layer.weight) 145 | elif isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.BatchNorm1d): 146 | layer.weight.data.fill_(1) 147 | layer.bias.data.zero_() 148 | 149 | def _make_layer(self, block, planes, num_blocks, stride=1): 150 | norm_layer = self._norm_layer 151 | downsample = None 152 | if stride != 1 or self.in_planes != planes * block.expansion: 153 | downsample = nn.Sequential(conv1x1(self.in_planes, planes * block.expansion, stride), 154 | norm_layer(planes * block.expansion)) 155 | layers = [] 156 | layers.append(block(self.in_planes, planes, stride, downsample, 1, 64, 1, norm_layer)) 157 | self.in_planes = planes * block.expansion 158 | for _ in range(1, num_blocks): 159 | layers.append( 160 | block(self.in_planes, planes, 1, groups=1, base_width=64, dilation=False, norm_layer=norm_layer)) 161 | 162 | return nn.Sequential(*layers) 163 | 164 | def forward(self, x): 165 | 166 | x = self.conv1(x) 167 | x = self.activation(self.bn1(x)) 168 | x = self.layer1(x) 169 | x = self.layer2(x) 170 | x = self.layer3(x) 171 | x = self.layer4(x) 172 | x = self.conv5(x) 173 | x = self.activation(self.bn5(x)).squeeze(2) 174 | 175 | stats = self.attention(x.permute(0, 2, 1).contiguous()) 176 | 177 | feat = self.fc(stats) 178 | 179 | mu = self.fc_mu(feat) 180 | 181 | return feat, mu 182 | 183 | 184 | class MaxFeatureMap2D(nn.Module): 185 | """ Max feature map (along 2D) 186 | 187 | MaxFeatureMap2D(max_dim=1) 188 | 189 | l_conv2d = MaxFeatureMap2D(1) 190 | data_in = torch.rand([1, 4, 5, 5]) 191 | data_out = l_conv2d(data_in) 192 | 193 | Input: 194 | ------ 195 | data_in: tensor of shape (batch, channel, ...) 196 | 197 | Output: 198 | ------- 199 | data_out: tensor of shape (batch, channel//2, ...) 200 | 201 | Note 202 | ---- 203 | By default, Max-feature-map is on channel dimension, 204 | and maxout is used on (channel ...) 
205 | """ 206 | 207 | def __init__(self, max_dim=1): 208 | super(MaxFeatureMap2D, self).__init__() 209 | self.max_dim = max_dim 210 | 211 | def forward(self, inputs): 212 | # suppose inputs (batchsize, channel, length, dim) 213 | 214 | shape = list(inputs.size()) 215 | 216 | if self.max_dim >= len(shape): 217 | print("MaxFeatureMap: maximize on %d dim" % (self.max_dim)) 218 | print("But input has %d dimensions" % (len(shape))) 219 | sys.exit(1) 220 | if shape[self.max_dim] // 2 * 2 != shape[self.max_dim]: 221 | print("MaxFeatureMap: maximize on %d dim" % (self.max_dim)) 222 | print("But this dimension has an odd number of data") 223 | sys.exit(1) 224 | shape[self.max_dim] = shape[self.max_dim] // 2 225 | shape.insert(self.max_dim, 2) 226 | 227 | # view to (batchsize, 2, channel//2, ...) 228 | # maximize on the 2nd dim 229 | m, i = inputs.view(*shape).max(self.max_dim) 230 | return m 231 | 232 | 233 | class LCNN(nn.Module): 234 | def __init__(self, num_nodes, enc_dim, nclasses=2): 235 | super(LCNN, self).__init__() 236 | self.num_nodes = num_nodes 237 | self.enc_dim = enc_dim 238 | self.nclasses = nclasses 239 | self.conv1 = nn.Sequential(nn.Conv2d(1, 64, (5, 5), 1, padding=(2, 2)), 240 | MaxFeatureMap2D(), 241 | nn.MaxPool2d((2, 2), (2, 2))) 242 | self.conv2 = nn.Sequential(nn.Conv2d(32, 64, (1, 1), 1, padding=(0, 0)), 243 | MaxFeatureMap2D(), 244 | nn.BatchNorm2d(32, affine=False)) 245 | self.conv3 = nn.Sequential(nn.Conv2d(32, 96, (3, 3), 1, padding=(1, 1)), 246 | MaxFeatureMap2D(), 247 | nn.MaxPool2d((2, 2), (2, 2)), 248 | nn.BatchNorm2d(48, affine=False)) 249 | self.conv4 = nn.Sequential(nn.Conv2d(48, 96, (1, 1), 1, padding=(0, 0)), 250 | MaxFeatureMap2D(), 251 | nn.BatchNorm2d(48, affine=False)) 252 | self.conv5 = nn.Sequential(nn.Conv2d(48, 128, (3, 3), 1, padding=(1, 1)), 253 | MaxFeatureMap2D(), 254 | nn.MaxPool2d((2, 2), (2, 2))) 255 | self.conv6 = nn.Sequential(nn.Conv2d(64, 128, (1, 1), 1, padding=(0, 0)), 256 | MaxFeatureMap2D(), 257 | nn.BatchNorm2d(64, affine=False)) 258 | self.conv7 = nn.Sequential(nn.Conv2d(64, 64, (3, 3), 1, padding=(1, 1)), 259 | MaxFeatureMap2D(), 260 | nn.BatchNorm2d(32, affine=False)) 261 | self.conv8 = nn.Sequential(nn.Conv2d(32, 64, (1, 1), 1, padding=(0, 0)), 262 | MaxFeatureMap2D(), 263 | nn.BatchNorm2d(32, affine=False)) 264 | self.conv9 = nn.Sequential(nn.Conv2d(32, 64, (3, 3), 1, padding=[1, 1]), 265 | MaxFeatureMap2D(), 266 | nn.MaxPool2d((2, 2), (2, 2))) 267 | self.out = nn.Sequential(nn.Dropout(0.7), 268 | nn.Linear((750 // 16) * (60 // 16) * 32, 160), 269 | MaxFeatureMap2D(), 270 | nn.Linear(80, self.enc_dim)) 271 | self.fc_mu = nn.Linear(enc_dim, nclasses) if nclasses >= 2 else nn.Linear(enc_dim, 1) 272 | 273 | def forward(self, x): 274 | 275 | x = self.conv1(x) 276 | x = self.conv2(x) 277 | x = self.conv3(x) 278 | x = self.conv4(x) 279 | x = self.conv5(x) 280 | x = self.conv6(x) 281 | x = self.conv7(x) 282 | x = self.conv8(x) 283 | x = self.conv9(x) 284 | feat = torch.flatten(x, 1) 285 | feat = self.out(feat) 286 | out = self.fc_mu(feat) 287 | 288 | return feat, out 289 | 290 | class GradientReversalFunction(Function): 291 | """ 292 | Gradient Reversal Layer from: 293 | Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015) 294 | Forward pass is the identity function. In the backward pass, 295 | the upstream gradients are multiplied by -lambda (i.e. 
gradient is reversed) 296 | """ 297 | 298 | @staticmethod 299 | def forward(ctx, x, lambda_): 300 | ctx.lambda_ = lambda_ 301 | return x.clone() 302 | 303 | @staticmethod 304 | def backward(ctx, grads): 305 | lambda_ = ctx.lambda_ 306 | lambda_ = grads.new_tensor(lambda_) 307 | dx = -lambda_ * grads 308 | return dx, None 309 | 310 | 311 | class GradientReversal(nn.Module): 312 | def __init__(self, lambda_=1): 313 | super(GradientReversal, self).__init__() 314 | self.lambda_ = lambda_ 315 | 316 | def forward(self, x): 317 | return GradientReversalFunction.apply(x, self.lambda_) 318 | 319 | 320 | class ChannelClassifier(nn.Module): 321 | def __init__(self, enc_dim, nclasses, lambda_=0.05, ADV=True): 322 | super(ChannelClassifier, self).__init__() 323 | self.adv = ADV 324 | if self.adv: 325 | self.grl = GradientReversal(lambda_) 326 | self.classifier = nn.Sequential(nn.Linear(enc_dim, enc_dim // 2), 327 | nn.Dropout(0.3), 328 | nn.ReLU(), 329 | nn.Linear(enc_dim // 2, nclasses), 330 | nn.ReLU()) 331 | 332 | def initialize_params(self): 333 | for layer in self.modules(): 334 | if isinstance(layer, torch.nn.Linear): 335 | init.kaiming_uniform_(layer.weight) 336 | 337 | def forward(self, x): 338 | if self.adv: 339 | x = self.grl(x) 340 | return self.classifier(x) 341 | 342 | 343 | 344 | if __name__ == "__main__": 345 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 346 | 347 | # cqcc = torch.randn((32,1,90,788)).cuda() 348 | # resnet = ResNet(4, 2, resnet_type='18', nclasses=2).cuda() 349 | # _, output = resnet(cqcc) 350 | # print(output.shape) 351 | lfcc = torch.randn((1, 1, 60, 750)).cuda() 352 | lcnn = LCNN(4, 2, nclasses=2).cuda() 353 | feat, output = lcnn(lfcc) 354 | print(output.shape) 355 | # cnn = ConvNet(num_classes = 2, num_nodes = 47232, enc_dim = 256).cuda() 356 | # _, output = cnn(lfcc) 357 | # print(output.shape) 358 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import argparse 4 | import os 5 | import json 6 | import shutil 7 | import numpy as np 8 | from model import * 9 | from dataset import ASVspoof2019 10 | from torch.utils.data import DataLoader 11 | from evaluate_tDCF_asvspoof19 import compute_eer_and_tdcf 12 | from loss import * 13 | from collections import defaultdict 14 | from tqdm import tqdm, trange 15 | import random 16 | from test import * 17 | import eval_metrics as em 18 | 19 | torch.set_default_tensor_type(torch.FloatTensor) 20 | 21 | def initParams(): 22 | parser = argparse.ArgumentParser(description=__doc__) 23 | 24 | parser.add_argument('--seed', type=int, help="random number seed", default=688) 25 | 26 | # Data folder prepare 27 | parser.add_argument("-a", "--access_type", type=str, help="LA or PA", default='LA') 28 | parser.add_argument("-d", "--path_to_database", type=str, help="dataset path", default='/data/neil/DS_10283_3336/') 29 | parser.add_argument("-f", "--path_to_features", type=str, help="features path", 30 | default='/data2/neil/ASVspoof2019LA/') 31 | parser.add_argument("-p", "--path_to_protocol", type=str, help="protocol path", 32 | default='/data/neil/DS_10283_3336/LA/ASVspoof2019_LA_cm_protocols/') 33 | parser.add_argument("-o", "--out_fold", type=str, help="output folder", required=True, default='./models/try/') 34 | 35 | # Dataset prepare 36 | parser.add_argument("--feat", type=str, help="which feature to use", default='LFCC') 37 | parser.add_argument("--feat_len", type=int, 
help="features length", default=750) 38 | parser.add_argument('--padding', type=str, default='repeat', choices=['zero', 'repeat', 'silence'], 39 | help="how to pad short utterance") 40 | parser.add_argument("--enc_dim", type=int, help="encoding dimension", default=256) 41 | 42 | parser.add_argument('-m', '--model', help='Model arch', default='resnet', 43 | choices=['resnet', 'lcnn']) 44 | 45 | # Training hyperparameters 46 | parser.add_argument('--num_epochs', type=int, default=1000, help="Number of epochs for training") 47 | parser.add_argument('--batch_size', type=int, default=64, help="Mini batch size for training") 48 | parser.add_argument('--lr', type=float, default=0.0005, help="learning rate") 49 | parser.add_argument('--lr_decay', type=float, default=0.5, help="decay learning rate") 50 | parser.add_argument('--interval', type=int, default=100, help="interval to decay lr") 51 | 52 | parser.add_argument('--beta_1', type=float, default=0.9, help="beta_1 for Adam") 53 | parser.add_argument('--beta_2', type=float, default=0.999, help="beta_2 for Adam") 54 | parser.add_argument('--eps', type=float, default=1e-8, help="epsilon for Adam") 55 | parser.add_argument("--gpu", type=str, help="GPU index", default="1") 56 | parser.add_argument('--num_workers', type=int, default=0, help="number of workers") 57 | 58 | parser.add_argument('--add_loss', type=str, default="ocsoftmax", 59 | choices=[None, 'ocsoftmax'], help="add other loss for one-class training") 60 | parser.add_argument('--weight_loss', type=float, default=1, help="weight for other loss") 61 | parser.add_argument('--r_real', type=float, default=0.9, help="r_real for ocsoftmax loss") 62 | parser.add_argument('--r_fake', type=float, default=0.2, help="r_fake for ocsoftmax loss") 63 | parser.add_argument('--alpha', type=float, default=20, help="scale factor for angular isolate loss") 64 | 65 | parser.add_argument('--test_only', action='store_true', help="only test a trained model (e.g., if a previous test crashed or to rerun evaluation)") 66 | parser.add_argument('--continue_training', action='store_true', help="continue training with trained model") 67 | 68 | parser.add_argument('--AUG', type=str2bool, nargs='?', const=True, default=False, 69 | help="whether to use device_augmentation in training") 70 | parser.add_argument('--MT_AUG', type=str2bool, nargs='?', const=True, default=False, 71 | help="whether to use device_multitask_augmentation in training") 72 | parser.add_argument('--ADV_AUG', type=str2bool, nargs='?', const=True, default=False, 73 | help="whether to use device_adversarial_augmentation in training") 74 | parser.add_argument('--lambda_', type=float, default=0.05, help="lambda for gradient reversal layer") 75 | parser.add_argument('--lr_d', type=float, default=0.0001, help="learning rate for the channel classifier") 76 | 77 | parser.add_argument('--pre_train', action='store_true', help="whether to pretrain the model") 78 | parser.add_argument('--test_on_eval', action='store_true', 79 | help="whether to run EER on the evaluation set") 80 | 81 | args = parser.parse_args() 82 | 83 | # Change this to specify GPU 84 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 85 | 86 | # Set seeds 87 | setup_seed(args.seed) 88 | 89 | if args.test_only or args.continue_training: 90 | pass 91 | else: 92 | # Path for output data 93 | if not os.path.exists(args.out_fold): 94 | os.makedirs(args.out_fold) 95 | else: 96 | shutil.rmtree(args.out_fold) 97 | os.mkdir(args.out_fold) 98 | 99 | # Folder for intermediate results 100 | if not os.path.exists(os.path.join(args.out_fold,
'checkpoint')): 101 | os.makedirs(os.path.join(args.out_fold, 'checkpoint')) 102 | else: 103 | shutil.rmtree(os.path.join(args.out_fold, 'checkpoint')) 104 | os.mkdir(os.path.join(args.out_fold, 'checkpoint')) 105 | 106 | # Path for input data 107 | # assert os.path.exists(args.path_to_database) 108 | assert os.path.exists(args.path_to_features) 109 | 110 | # Save training arguments 111 | with open(os.path.join(args.out_fold, 'args.json'), 'w') as file: 112 | file.write(json.dumps(vars(args), sort_keys=True, separators=('\n', ':'))) 113 | 114 | with open(os.path.join(args.out_fold, 'train_loss.log'), 'w') as file: 115 | file.write("Start recording training loss ...\n") 116 | with open(os.path.join(args.out_fold, 'dev_loss.log'), 'w') as file: 117 | file.write("Start recording validation loss ...\n") 118 | with open(os.path.join(args.out_fold, 'test_loss.log'), 'w') as file: 119 | file.write("Start recording test loss ...\n") 120 | 121 | args.cuda = torch.cuda.is_available() 122 | print('Cuda device available: ', args.cuda) 123 | args.device = torch.device("cuda" if args.cuda else "cpu") 124 | 125 | return args 126 | 127 | def adjust_learning_rate(args, lr, optimizer, epoch_num): 128 | lr = lr * (args.lr_decay ** (epoch_num // args.interval)) 129 | for param_group in optimizer.param_groups: 130 | param_group['lr'] = lr 131 | 132 | def shuffle(feat, tags, labels): 133 | shuffle_index = torch.randperm(labels.shape[0]) 134 | feat = feat[shuffle_index] 135 | tags = tags[shuffle_index] 136 | labels = labels[shuffle_index] 137 | # this_len = this_len[shuffle_index] 138 | return feat, tags, labels 139 | 140 | def train(args): 141 | torch.set_default_tensor_type(torch.FloatTensor) 142 | 143 | # initialize model 144 | if args.model == 'resnet': 145 | node_dict = {"CQCC": 4, "LFCC": 3} 146 | feat_model = ResNet(node_dict[args.feat], args.enc_dim, resnet_type='18', nclasses=2).to(args.device) 147 | elif args.model == 'lcnn': 148 | feat_model = LCNN(4, args.enc_dim, nclasses=2).to(args.device) 149 | 150 | if args.continue_training: 151 | feat_model = torch.load(os.path.join(args.out_fold, 'anti-spoofing_feat_model.pt')).to(args.device) 152 | # feat_model = nn.DataParallel(feat_model, list(range(torch.cuda.device_count()))) # for multiple GPUs 153 | feat_optimizer = torch.optim.Adam(feat_model.parameters(), lr=args.lr, 154 | betas=(args.beta_1, args.beta_2), eps=args.eps, weight_decay=0.0005) 155 | 156 | training_set = ASVspoof2019(args.access_type, args.path_to_features, 'train', 157 | args.feat, feat_len=args.feat_len, padding=args.padding) 158 | validation_set = ASVspoof2019(args.access_type, args.path_to_features, 'dev', 159 | args.feat, feat_len=args.feat_len, padding=args.padding) 160 | if args.AUG or args.MT_AUG or args.ADV_AUG: 161 | training_set = ASVspoof2019LA_DeviceAdversarial(path_to_features="/data2/neil/ASVspoof2019LA/", 162 | path_to_deviced="/dataNVME/neil/ASVspoof2019LADevice", 163 | part="train", 164 | feature=args.feat, feat_len=args.feat_len, 165 | padding=args.padding) 166 | validation_set = ASVspoof2019LA_DeviceAdversarial(path_to_features="/data2/neil/ASVspoof2019LA/", 167 | path_to_deviced="/dataNVME/neil/ASVspoof2019LADevice", 168 | part="dev", 169 | feature=args.feat, feat_len=args.feat_len, 170 | padding=args.padding) 171 | if args.MT_AUG or args.ADV_AUG: 172 | classifier = ChannelClassifier(args.enc_dim, len(training_set.devices)+1, args.lambda_, ADV=args.ADV_AUG).to(args.device) 173 | classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=args.lr_d, 174 | 
betas=(args.beta_1, args.beta_2), eps=args.eps, weight_decay=0.0005) 175 | 176 | trainDataLoader = DataLoader(training_set, batch_size=args.batch_size, 177 | shuffle=True, num_workers=args.num_workers, collate_fn=training_set.collate_fn) 178 | valDataLoader = DataLoader(validation_set, batch_size=args.batch_size, 179 | shuffle=True, num_workers=args.num_workers, collate_fn=validation_set.collate_fn) 180 | 181 | test_set = ASVspoof2019(args.access_type, args.path_to_features, "eval", args.feat, 182 | feat_len=args.feat_len, padding=args.padding) 183 | testDataLoader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=test_set.collate_fn) 184 | 185 | feat, _, _, _, _ = training_set[23] 186 | print("Feature shape", feat.shape) 187 | 188 | criterion = nn.CrossEntropyLoss() 189 | 190 | if args.add_loss == "ocsoftmax": 191 | ocsoftmax = OCSoftmax(args.enc_dim, r_real=args.r_real, r_fake=args.r_fake, alpha=args.alpha).to(args.device) 192 | ocsoftmax.train() 193 | ocsoftmax_optimizer = torch.optim.SGD(ocsoftmax.parameters(), lr=args.lr) 194 | 195 | early_stop_cnt = 0 196 | prev_loss = 1e8 197 | 198 | if args.add_loss is None: 199 | monitor_loss = 'base_loss' 200 | else: 201 | monitor_loss = args.add_loss 202 | 203 | for epoch_num in tqdm(range(args.num_epochs)): 204 | genuine_feats, ip1_loader, tag_loader, idx_loader = [], [], [], [] 205 | feat_model.train() 206 | trainlossDict = defaultdict(list) 207 | devlossDict = defaultdict(list) 208 | testlossDict = defaultdict(list) 209 | adjust_learning_rate(args, args.lr, feat_optimizer, epoch_num) 210 | if args.add_loss == "ocsoftmax": 211 | adjust_learning_rate(args, args.lr, ocsoftmax_optimizer, epoch_num) 212 | if args.MT_AUG or args.ADV_AUG: 213 | adjust_learning_rate(args, args.lr_d, classifier_optimizer, epoch_num) 214 | print('\nEpoch: %d ' % (epoch_num + 1)) 215 | correct_m, total_m, correct_c, total_c, correct_v, total_v = 0, 0, 0, 0, 0, 0 216 | 217 | for i, (feat, audio_fn, tags, labels, channel) in enumerate(tqdm(trainDataLoader)): 218 | if args.AUG or args.MT_AUG or args.ADV_AUG: 219 | if i > int(len(training_set) / args.batch_size / (len(training_set.devices) + 1)): break 220 | feat = feat.transpose(2,3).to(args.device) 221 | tags = tags.to(args.device) 222 | labels = labels.to(args.device) 223 | feats, feat_outputs = feat_model(feat) 224 | feat_loss = criterion(feat_outputs, labels) 225 | trainlossDict['base_loss'].append(feat_loss.item()) 226 | 227 | if args.add_loss is None: 228 | feat_optimizer.zero_grad() 229 | feat_loss.backward() 230 | feat_optimizer.step() 231 | 232 | if args.add_loss == "ocsoftmax": 233 | ocsoftmaxloss, _ = ocsoftmax(feats, labels) 234 | feat_loss = ocsoftmaxloss * args.weight_loss 235 | if epoch_num > 0 and (args.MT_AUG or args.ADV_AUG): 236 | channel = channel.to(args.device) 237 | classifier_out = classifier(feats) 238 | _, predicted = torch.max(classifier_out.data, 1) 239 | total_m += channel.size(0) 240 | correct_m += (predicted == channel).sum().item() 241 | device_loss = criterion(classifier_out, channel) 242 | feat_loss += device_loss 243 | trainlossDict["adv_loss"].append(device_loss.item()) 244 | feat_optimizer.zero_grad() 245 | ocsoftmax_optimizer.zero_grad() 246 | trainlossDict[args.add_loss].append(ocsoftmaxloss.item()) 247 | feat_loss.backward() 248 | feat_optimizer.step() 249 | ocsoftmax_optimizer.step() 250 | 251 | if (args.MT_AUG or args.ADV_AUG): 252 | channel = channel.to(args.device) 253 | feats, _ = feat_model(feat) 254 | feats = feats.detach()
255 | classifier_out = classifier(feats) 256 | _, predicted = torch.max(classifier_out.data, 1) 257 | total_c += channel.size(0) 258 | correct_c += (predicted == channel).sum().item() 259 | device_loss_c = criterion(classifier_out, channel) 260 | classifier_optimizer.zero_grad() 261 | device_loss_c.backward() 262 | classifier_optimizer.step() 263 | 264 | ip1_loader.append(feats) 265 | idx_loader.append((labels)) 266 | tag_loader.append((tags)) 267 | 268 | 269 | if epoch_num > 0 and (args.MT_AUG or args.ADV_AUG): 270 | with open(os.path.join(args.out_fold, "train_loss.log"), "a") as log: 271 | log.write(str(epoch_num) + "\t" + str(i) + "\t" + 272 | str(trainlossDict["adv_loss"][-1]) + "\t" + 273 | str(100 * correct_m / total_m) + "\t" + 274 | str(100 * correct_c / total_c) + "\t" + 275 | str(trainlossDict[monitor_loss][-1]) + "\n") 276 | else: 277 | with open(os.path.join(args.out_fold, "train_loss.log"), "a") as log: 278 | log.write(str(epoch_num) + "\t" + str(i) + "\t" + 279 | str(trainlossDict[monitor_loss][-1]) + "\n") 280 | 281 | 282 | # Val the model 283 | # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance) 284 | feat_model.eval() 285 | with torch.no_grad(): 286 | ip1_loader, tag_loader, idx_loader, score_loader = [], [], [], [] 287 | # with trange(2) as v: 288 | # with trange(len(valDataLoader)) as v: 289 | # for i in v: 290 | for i, (feat, audio_fn, tags, labels, channel) in enumerate(tqdm(valDataLoader)): 291 | if args.AUG or args.MT_AUG or args.ADV_AUG: 292 | if i > int(len(validation_set) / args.batch_size / (len(validation_set.devices) + 1)): break 293 | feat = feat.transpose(2,3).to(args.device) 294 | 295 | tags = tags.to(args.device) 296 | labels = labels.to(args.device) 297 | 298 | feat, tags, labels = shuffle(feat, tags, labels) 299 | 300 | feats, feat_outputs = feat_model(feat) 301 | 302 | feat_loss = criterion(feat_outputs, labels) 303 | score = F.softmax(feat_outputs, dim=1)[:, 0] 304 | 305 | ip1_loader.append(feats) 306 | idx_loader.append((labels)) 307 | tag_loader.append((tags)) 308 | 309 | if args.add_loss == "ocsoftmax": 310 | ocsoftmaxloss, score = ocsoftmax(feats, labels) 311 | devlossDict[args.add_loss].append(ocsoftmaxloss.item()) 312 | if epoch_num > 0 and (args.MT_AUG or args.ADV_AUG): 313 | channel = channel.to(args.device) 314 | classifier_out = classifier(feats) 315 | _, predicted = torch.max(classifier_out.data, 1) 316 | total_v += channel.size(0) 317 | correct_v += (predicted == channel).sum().item() 318 | device_loss = criterion(classifier_out, channel) 319 | devlossDict["adv_loss"].append(device_loss.item()) 320 | 321 | score_loader.append(score) 322 | 323 | scores = torch.cat(score_loader, 0).data.cpu().numpy() 324 | labels = torch.cat(idx_loader, 0).data.cpu().numpy() 325 | eer = em.compute_eer(scores[labels == 0], scores[labels == 1])[0] 326 | other_eer = em.compute_eer(-scores[labels == 0], -scores[labels == 1])[0] 327 | eer = min(eer, other_eer) 328 | 329 | if epoch_num > 0 and (args.MT_AUG or args.ADV_AUG): 330 | with open(os.path.join(args.out_fold, "dev_loss.log"), "a") as log: 331 | log.write(str(epoch_num) + "\t"+ "\t" + 332 | str(np.nanmean(devlossDict["adv_loss"])) + "\t" + 333 | str(100 * correct_v / total_v) + "\t" + 334 | str(np.nanmean(devlossDict[monitor_loss])) + "\t" + 335 | str(eer) + "\n") 336 | else: 337 | with open(os.path.join(args.out_fold, "dev_loss.log"), "a") as log: 338 | log.write(str(epoch_num) + "\t" + 339 | str(np.nanmean(devlossDict[monitor_loss])) + "\t" + 340 | str(eer) +"\n") 341 | 
print("Val EER: {}".format(eer)) 342 | 343 | 344 | if args.test_on_eval: 345 | with torch.no_grad(): 346 | ip1_loader, tag_loader, idx_loader, score_loader = [], [], [], [] 347 | for i, (feat, audio_fn, tags, labels, channel) in enumerate(tqdm(testDataLoader)): 348 | feat = feat.transpose(2,3).to(args.device) 349 | tags = tags.to(args.device) 350 | labels = labels.to(args.device) 351 | feats, feat_outputs = feat_model(feat) 352 | feat_loss = criterion(feat_outputs, labels) 353 | score = F.softmax(feat_outputs, dim=1)[:, 0] 354 | 355 | ip1_loader.append(feats) 356 | idx_loader.append((labels)) 357 | tag_loader.append((tags)) 358 | 359 | if args.add_loss == "ocsoftmax": 360 | ocsoftmaxloss, score = ocsoftmax(feats, labels) 361 | testlossDict[args.add_loss].append(ocsoftmaxloss.item()) 362 | score_loader.append(score) 363 | 364 | scores = torch.cat(score_loader, 0).data.cpu().numpy() 365 | labels = torch.cat(idx_loader, 0).data.cpu().numpy() 366 | eer = em.compute_eer(scores[labels == 0], scores[labels == 1])[0] 367 | other_eer = em.compute_eer(-scores[labels == 0], -scores[labels == 1])[0] 368 | eer = min(eer, other_eer) 369 | 370 | with open(os.path.join(args.out_fold, "test_loss.log"), "a") as log: 371 | log.write(str(epoch_num) + "\t" + str(np.nanmean(testlossDict[monitor_loss])) + "\t" + str(eer) + "\n") 372 | print("Test EER: {}".format(eer)) 373 | 374 | 375 | valLoss = np.nanmean(devlossDict[monitor_loss]) 376 | # if args.add_loss == "isolate": 377 | # print("isolate center: ", iso_loss.center.data) 378 | if (epoch_num + 1) % 1 == 0: 379 | torch.save(feat_model, os.path.join(args.out_fold, 'checkpoint', 380 | 'anti-spoofing_feat_model_%d.pt' % (epoch_num + 1))) 381 | if args.add_loss == "ocsoftmax": 382 | loss_model = ocsoftmax 383 | torch.save(loss_model, os.path.join(args.out_fold, 'checkpoint', 384 | 'anti-spoofing_loss_model_%d.pt' % (epoch_num + 1))) 385 | else: 386 | loss_model = None 387 | 388 | if valLoss < prev_loss: 389 | # Save the model checkpoint 390 | torch.save(feat_model, os.path.join(args.out_fold, 'anti-spoofing_feat_model.pt')) 391 | if args.add_loss == "ocsoftmax": 392 | loss_model = ocsoftmax 393 | torch.save(loss_model, os.path.join(args.out_fold, 'anti-spoofing_loss_model.pt')) 394 | else: 395 | loss_model = None 396 | prev_loss = valLoss 397 | early_stop_cnt = 0 398 | else: 399 | early_stop_cnt += 1 400 | 401 | if early_stop_cnt == 500: 402 | with open(os.path.join(args.out_fold, 'args.json'), 'a') as res_file: 403 | res_file.write('\nTrained Epochs: %d\n' % (epoch_num - 499)) 404 | break 405 | # if early_stop_cnt == 1: 406 | # torch.save(feat_model, os.path.join(args.out_fold, 'anti-spoofing_feat_model.pt') 407 | 408 | # print('Dev Accuracy of the model on the val features: {} % '.format(100 * feat_correct / total)) 409 | 410 | return feat_model, loss_model 411 | 412 | 413 | 414 | if __name__ == "__main__": 415 | args = initParams() 416 | if not args.test_only: 417 | _, _ = train(args) 418 | # model = torch.load(os.path.join(args.out_fold, 'anti-spoofing_feat_model.pt')) 419 | # if args.add_loss is None: 420 | # loss_model = None 421 | # else: 422 | # loss_model = torch.load(os.path.join(args.out_fold, 'anti-spoofing_loss_model.pt')) 423 | # # TReer_cm, TRmin_tDCF = test(args, model, loss_model, "train") 424 | # # VAeer_cm, VAmin_tDCF = test(args, model, loss_model, "dev") 425 | # TEeer_cm, TEmin_tDCF = test(args, model, loss_model) 426 | # with open(os.path.join(args.out_fold, 'args.json'), 'a') as res_file: 427 | # # res_file.write('\nTrain EER: %8.5f 
min-tDCF: %8.5f\n' % (TReer_cm, TRmin_tDCF)) 428 | # # res_file.write('\nVal EER: %8.5f min-tDCF: %8.5f\n' % (VAeer_cm, VAmin_tDCF)) 429 | # res_file.write('\nTest EER: %8.5f min-tDCF: %8.5f\n' % (TEeer_cm, TEmin_tDCF)) 430 | 431 | 432 | # # Test a checkpoint model 433 | # args = initParams() 434 | # model = torch.load(os.path.join(args.out_fold, 'checkpoint', 'anti-spoofing_feat_model_19.pt')) 435 | # loss_model = torch.load(os.path.join(args.out_fold, 'checkpoint', 'anti-spoofing_loss_model_19.pt')) 436 | # VAeer_cm, VAmin_tDCF = test(args, model, loss_model, "dev") 437 | --------------------------------------------------------------------------------