├── .gitignore
├── README.md
├── data
│   ├── BPMProcess.py
│   ├── RawBoost.py
│   └── dataloader.py
├── figures
│   ├── singfake_sota.jpg
│   └── singgraph.png
├── main.py
├── model
│   └── SingGraph.py
├── requirements.txt
├── run.sh
└── utils
    ├── SingGraph.conf
    ├── eval_metrics.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | dataset/
3 | cache/w2v2/xlsr2_300m.pt
4 | cache/models--m-a-p--MERT-v1-330M/
5 | cache/hub/
6 | cache/modules/
7 | exp_result/
8 | fairseq/
9 | fairseq1/
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SingGraph [![arXiv](https://img.shields.io/badge/arXiv-2406.03111-b31b1b.svg)](https://arxiv.org/abs/2406.03111)
2 |
3 | This is the official repository for the **[SingGraph](https://arxiv.org/abs/2406.03111)** model. The paper has been accepted by Interspeech 2024.
4 |
5 | The official code has been released; documentation of the data pre-processing steps is still in progress and will be added to this `README.md`.
6 | ## Abstract
7 |
8 | Existing models for speech deepfake detection have struggled to adapt to unseen attacks in this unique singing voice domain of human vocalization. To bridge the gap, we present a groundbreaking SingGraph model. The model synergizes the capabilities of the MERT acoustic music understanding model for pitch and rhythm analysis with the wav2vec2.0 model for linguistic analysis of lyrics. Additionally, we advocate for using RawBoost and beat matching techniques grounded in music domain knowledge for singing voice augmentation, thereby enhancing SingFake detection performance.
9 | Our proposed method achieves new state-of-the-art (SOTA) results on the SingFake dataset, surpassing the previous SOTA model across three distinct scenarios: it improves EER relatively for seen singers by 13.2\%, for unseen singers by 24.3\%, and for unseen singers using different codecs by 37.1\%.
10 |
11 | ![](figures/singgraph.png)
12 |
13 | ## Datasets
14 | The dataset is based on the paper "SingFake: Singing Voice Deepfake Detection," which was accepted at ICASSP 2024. [[Project Webpage](https://singfake.org/)]
15 |
16 | Due to copyright restrictions, the dataset is not open-sourced. Please follow the instructions in the paper above to obtain the dataset yourself.
17 |
18 |
19 |
20 |
21 |
22 | ## Comparison with the SOTA model
23 |
24 |
25 | ![](figures/singfake_sota.jpg)
26 |
27 |
28 |
29 |
30 | ## Citation
31 | If you find our work useful, please consider citing:
32 | ```
33 | @inproceedings{chen24o_interspeech,
34 |   title     = {Singing Voice Graph Modeling for SingFake Detection},
35 |   author    = {Xuanjun Chen and Haibin Wu and Roger Jang and Hung-yi Lee},
36 |   year      = {2024},
37 |   booktitle = {Interspeech 2024},
38 |   pages     = {4843--4847},
39 |   doi       = {10.21437/Interspeech.2024-1185},
40 |   issn      = {2958-1796},
41 | }
42 |
43 | @article{chen2025how,
44 |   title   = {How Does Instrumental Music Help SingFake Detection?},
45 |   author  = {Chen, Xuanjun and Hu, Chia-Yu and Lin, I-Ming and Lin, Yi-Cheng and Chiu, I-Hsiang and Zhang, You and Huang, Sung-Feng and Yang, Yi-Hsuan and Wu, Haibin and Lee, Hung-yi and Jang, Jyh-Shing Roger},
46 |   journal = {arXiv preprint arXiv:2509.14675},
47 |   year    = {2025}
48 | }
49 | ```
50 | ## Acknowledgement
51 | If you have any questions, please feel free to contact me by email at d12942018@ntu.edu.tw.
52 |
--------------------------------------------------------------------------------
/data/BPMProcess.py:
--------------------------------------------------------------------------------
1 | import os, math, json
2 | import numpy as np
3 | import librosa
4 |
5 | class BpmProcessor:
6 |     def __init__(self, train_acc_path,
7 |                  json2bpm_path,
8 |                  bpm2json_path,
9 |                  sample_rate,
10 |                  threshold):
11 |         with open(json2bpm_path, "r") as f:
12 |             j2b_dict = json.load(f)
13 |
14 |         with open(bpm2json_path, "r") as f:
15 |             b2j_dict = json.load(f)
16 |
17 |         self.train_acc_path = train_acc_path
18 |         self.j2b_dict = j2b_dict
19 |         self.b2j_dict = b2j_dict
20 |
21 |         self.sr = sample_rate
22 |         self.thr = threshold / self.sr  # clip threshold in seconds
23 |         self.waveform_thr = threshold   # clip threshold in samples (64600 by default)
24 |
25 |     def load_audio_by_json(self, sel_json_name):
26 |         sel_wav_path = os.path.join(self.train_acc_path, sel_json_name[:-5] + ".wav")
27 |         wav, _ = librosa.load(sel_wav_path, sr=self.sr)
28 |         return wav
29 |
30 |     def sel_accom_from_bpm_group(self, bpm_num, y):  # pick a same-BPM, same-label accompaniment
31 |         count, downbeats_duration = 0, 0
32 |         sel_json = None
33 |         candidate_jsons = self.b2j_dict[str(bpm_num)]
34 |         filtered_gen = filter(lambda x: x.startswith(str(y)), candidate_jsons)
35 |         candidate_list = list(filtered_gen)
36 |
37 |         if len(candidate_list) != 0:
38 |             while True:
39 |                 sel_json = np.random.choice(candidate_list)
40 |                 bpm_beat = self.j2b_dict[sel_json]
41 |                 count += 1
42 |                 pos1 = np.where(np.array(bpm_beat["beat_positions"]) == 1)[0]
43 |                 pos1_len = pos1.shape[0]
44 |
45 |                 # At least one bar period
46 |                 if pos1_len > 2:
47 |                     first_pos, last_pos = pos1[0], pos1[-1]
48 |                     downbeats_duration = bpm_beat["downbeats"][-1] - bpm_beat["downbeats"][0]
49 |                     break
50 |
51 |                 # Stop because not found
52 |                 if count >= 5:
53 |                     break
54 |
55 |         return sel_json, downbeats_duration
56 |
57 |     def accom_beat_padding(self, waveform,
58 |                            s_name,
59 |                            dbs_duration):
60 |         content = self.j2b_dict[s_name]
61 |         start, end = content["downbeats"][0], content["downbeats"][-1]
62 |         sel_waveform = waveform[int(start * self.sr) : int(end * self.sr)]
63 |
64 |         if dbs_duration < self.thr:
65 |             # print(f"start: {start}, end: {end}")
66 |             cp_num = math.ceil(self.thr / dbs_duration)
67 |             sel_waveform = np.concatenate([sel_waveform] * cp_num)
68 |
69 |         waveform_thr = int(self.thr * self.sr) + 1
70 |         return sel_waveform[:waveform_thr]
71 |
72 |     def sv_beat_align(self, wav_sv, sel_json):
73 |         downbeats = self.j2b_dict[sel_json + ".json"]["downbeats"]
74 |         wav_seg = wav_sv[int(downbeats[0] * self.sr):int(downbeats[-1] * self.sr)]
75 |         # print(f"downbeats[0:-1]: {downbeats}")
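        # Randomly re-anchor the vocal at one of its downbeats, then loop the
        # full downbeat span until the 64600-sample training window is filled.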
76 |         rand_start = np.random.choice(downbeats[0:-1])
77 |         rand_start_seg = wav_sv[int(rand_start * self.sr):int(downbeats[-1] * self.sr)]
78 |
79 |         if rand_start_seg.shape[0] >= self.waveform_thr:
80 |             output = rand_start_seg
81 |         else:
82 |             remain_len = self.waveform_thr - rand_start_seg.shape[0] - wav_seg.shape[0]
83 |             if remain_len >= 0:
84 |                 padded_len = math.ceil(remain_len / wav_seg.shape[0]) + 1  # copies of the span needed to fill the window
85 |             else:
86 |                 padded_len = 1
87 |             padded_wav_seg = np.concatenate([wav_seg] * padded_len)
88 |             output = np.concatenate((rand_start_seg, padded_wav_seg))
89 |
90 |         output = output[:self.waveform_thr]
91 |         return output
92 |
93 |
94 |
95 | if __name__ == "__main__":
96 |     train_acc_path = "./dataset/split_dump_flac/train/non_vocals/"
97 |     train_vocal_path = "./dataset/split_dump_flac/train/vocals/"
98 |     b2j_path = "./dataset/split_dump_flac/train/bpm2json.json"
99 |     j2b_path = "./dataset/split_dump_flac/train/json2bpm.json"
100 |
101 |     with open(j2b_path, "r") as d:
102 |         json2bpm_dict = json.load(d)
103 |
104 |     BpmProCls = BpmProcessor(train_acc_path=train_acc_path,
105 |                              json2bpm_path=j2b_path,
106 |                              bpm2json_path=b2j_path,
107 |                              sample_rate=16000,
108 |                              threshold=64600)
109 |
110 |     wav_path = "./dataset/split_dump_flac/train/non_vocals/0_0212_3.wav"
111 |     json_name = os.path.basename(wav_path)[:-4] + ".json"
112 |     wav, sample_rate = librosa.load(wav_path, sr=16000)
113 |
114 |     bpm_n = json2bpm_dict[json_name]["bpm"]
115 |
116 |     sel_json, dbs_duration = BpmProCls.sel_accom_from_bpm_group(bpm_n, "1")
117 |     sel_wav = BpmProCls.load_audio_by_json(sel_json)
118 |     padded_waveform = BpmProCls.accom_beat_padding(sel_wav, sel_json, dbs_duration)
119 |     print(f"padded_waveform: {padded_waveform.shape}")
120 |
121 |     bpm_res = json2bpm_dict["1_1344_11.json"]["bpm"]
122 |     print(f"bpm_res: {bpm_res}")
123 |
124 |     wav_path = os.path.join(train_vocal_path, "0_0321_4.flac")
125 |     wav11, sample_rate = librosa.load(wav_path, sr=16000)
126 |
127 |     wav_sv = BpmProCls.sv_beat_align(wav11, "0_0321_4")
128 |     print(f"wav_sv: {wav_sv.shape}")
129 |
--------------------------------------------------------------------------------
/data/RawBoost.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import numpy as np
5 | from scipy import signal
6 | import copy
7 |
8 |
9 | """
10 | __author__ = "Massimiliano Todisco, Hemlata Tak"
11 | __email__ = "{todisco,tak}@eurecom.fr"
12 | """
13 |
14 | '''
15 | Hemlata Tak, Madhu Kamble, Jose Patino, Massimiliano Todisco, Nicholas Evans.
16 | RawBoost: A Raw Data Boosting and Augmentation Method applied to Automatic Speaker Verification Anti-Spoofing.
17 | In Proc. ICASSP 2022, pp:6382--6386.
18 | ''' 19 | 20 | def randRange(x1, x2, integer): 21 | y = np.random.uniform(low=x1, high=x2, size=(1,)) 22 | if integer: 23 | y = int(y) 24 | return y 25 | 26 | def normWav(x,always): 27 | if always: 28 | x = x/np.amax(abs(x)) 29 | elif np.amax(abs(x)) > 1: 30 | x = x/np.amax(abs(x)) 31 | return x 32 | 33 | 34 | 35 | def genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs): 36 | b = 1 37 | for i in range(0, nBands): 38 | fc = randRange(minF,maxF,0); 39 | bw = randRange(minBW,maxBW,0); 40 | c = randRange(minCoeff,maxCoeff,1); 41 | 42 | if c/2 == int(c/2): 43 | c = c + 1 44 | f1 = fc - bw/2 45 | f2 = fc + bw/2 46 | if f1 <= 0: 47 | f1 = 1/1000 48 | if f2 >= fs/2: 49 | f2 = fs/2-1/1000 50 | b = np.convolve(signal.firwin(c, [float(f1), float(f2)], window='hamming', fs=fs),b) 51 | 52 | G = randRange(minG,maxG,0); 53 | _, h = signal.freqz(b, 1, fs=fs) 54 | b = pow(10, G/20)*b/np.amax(abs(h)) 55 | return b 56 | 57 | 58 | def filterFIR(x,b): 59 | N = b.shape[0] + 1 60 | xpad = np.pad(x, (0, N), 'constant') 61 | y = signal.lfilter(b, 1, xpad) 62 | y = y[int(N/2):int(y.shape[0]-N/2)] 63 | return y 64 | 65 | # Linear and non-linear convolutive noise 66 | def LnL_convolutive_noise(x,N_f,nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,minBiasLinNonLin,maxBiasLinNonLin,fs): 67 | y = [0] * x.shape[0] 68 | for i in range(0, N_f): 69 | if i == 1: 70 | minG = minG-minBiasLinNonLin; 71 | maxG = maxG-maxBiasLinNonLin; 72 | b = genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs) 73 | y = y + filterFIR(np.power(x, (i+1)), b) 74 | y = y - np.mean(y) 75 | y = normWav(y,0) 76 | return y 77 | 78 | 79 | # Impulsive signal dependent noise 80 | def ISD_additive_noise(x, P, g_sd): 81 | beta = randRange(0, P, 0) 82 | 83 | y = copy.deepcopy(x) 84 | x_len = x.shape[0] 85 | n = int(x_len*(beta/100)) 86 | p = np.random.permutation(x_len)[:n] 87 | f_r= np.multiply(((2*np.random.rand(p.shape[0]))-1),((2*np.random.rand(p.shape[0]))-1)) 88 | r = g_sd * x[p] * f_r 89 | y[p] = x[p] + r 90 | y = normWav(y,0) 91 | return y 92 | 93 | 94 | # Stationary signal independent noise 95 | 96 | def SSI_additive_noise(x,SNRmin,SNRmax,nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs): 97 | noise = np.random.normal(0, 1, x.shape[0]) 98 | b = genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs) 99 | noise = filterFIR(noise, b) 100 | noise = normWav(noise,1) 101 | SNR = randRange(SNRmin, SNRmax, 0) 102 | noise = noise / np.linalg.norm(noise,2) * np.linalg.norm(x,2) / 10.0**(0.05 * SNR) 103 | x = x + noise 104 | return x 105 | 106 | def process_Rawboost_feature(feature, sr,args,algo): 107 | 108 | # Data process by Convolutive noise (1st algo) 109 | if algo==1: 110 | 111 | feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) 112 | 113 | # Data process by Impulsive noise (2nd algo) 114 | elif algo==2: 115 | 116 | feature=ISD_additive_noise(feature, args.P, args.g_sd) 117 | 118 | # Data process by coloured additive noise (3rd algo) 119 | elif algo==3: 120 | 121 | feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) 122 | 123 | # Data process by all 3 algo. 
together in series (1+2+3) 124 | elif algo==4: 125 | 126 | feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, 127 | args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) 128 | feature=ISD_additive_noise(feature, args.P, args.g_sd) 129 | feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF, 130 | args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) 131 | 132 | # Data process by 1st two algo. together in series (1+2) 133 | elif algo==5: 134 | 135 | feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, 136 | args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) 137 | feature=ISD_additive_noise(feature, args.P, args.g_sd) 138 | 139 | 140 | # Data process by 1st and 3rd algo. together in series (1+3) 141 | elif algo==6: 142 | 143 | feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, 144 | args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) 145 | feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) 146 | 147 | # Data process by 2nd and 3rd algo. together in series (2+3) 148 | elif algo==7: 149 | 150 | feature=ISD_additive_noise(feature, args.P, args.g_sd) 151 | feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) 152 | 153 | # Data process by 1st two algo. together in Parallel (1||2) 154 | elif algo==8: 155 | 156 | feature1 =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, 157 | args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) 158 | feature2=ISD_additive_noise(feature, args.P, args.g_sd) 159 | 160 | feature_para=feature1+feature2 161 | feature=normWav(feature_para,0) #normalized resultant waveform 162 | 163 | # original data without Rawboost processing 164 | else: 165 | 166 | feature=feature 167 | 168 | return feature 169 | -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | import soundfile as sf 6 | import torch 7 | from torch import Tensor 8 | from torch.utils.data import Dataset 9 | from utils.utils import str_to_bool 10 | 11 | from data.RawBoost import process_Rawboost_feature 12 | from data.BPMProcess import BpmProcessor 13 | 14 | ___author__ = "Xuanjun Chen" 15 | __email__ = "d12942018@ntu.edu.tw" 16 | 17 | def pad(x, max_len=64600): 18 | x_len = x.shape[0] 19 | if x_len >= max_len: 20 | return x[:max_len] 21 | # need to pad 22 | num_repeats = int(max_len / x_len) + 1 23 | padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0] 24 | return padded_x 25 | 26 | def pad_random(x: np.ndarray, start: int, end: int): 27 | x_len = x.shape[0] 28 | 29 | # If the interval is within the length of x 30 | if end <= x_len: 31 | return x[start:end] 32 | 33 | # If the selected interval is longer than x 34 | padded_x = np.tile(x, (end // x_len + 1)) # Repeat x to ensure it covers the interval 35 | return padded_x[start:end] 36 | 37 | class Dataset_SingFake(Dataset): 38 | 
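    # Single-branch loader: yields one ~4 s (64600-sample) clip and its label.
    # The dual-branch loader used for SingGraph training is
    # Dataset_SingFake_mert_w2v below.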
    def __init__(self, args, base_dir, algo, state, is_mixture=False, target_sr=16000):
39 |         """
40 |         base_dir should contain mixtures/ and vocals/ folders
41 |         """
42 |         self.base_dir = base_dir
43 |         self.is_mixture = is_mixture
44 |         self.target_sr = target_sr
45 |         self.cut = 64600  # take ~4 sec audio (64600 samples)
46 |         self.args = args
47 |         self.algo = algo
48 |         self.state = state
49 |
50 |         # get file list
51 |         self.file_list = []
52 |         if self.is_mixture:
53 |             self.target_path = os.path.join(self.base_dir, "mixtures")
54 |         else:
55 |             self.target_path = os.path.join(self.base_dir, "vocals")
56 |
57 |         print(self.target_path)
58 |
59 |         assert os.path.exists(self.target_path), f"{self.target_path} does not exist!"
60 |
61 |         for file in os.listdir(self.target_path):
62 |             if file.endswith(".flac"):
63 |                 self.file_list.append(file[:-5])
64 |
65 |     def __len__(self):
66 |         return len(self.file_list)
67 |
68 |     def __getitem__(self, index):
69 |         key = self.file_list[index]
70 |         file_path = os.path.join(self.target_path, key + ".flac")
71 |         # X, _ = sf.read(file_path, samplerate=self.target_sr)
72 |         try:
73 |             X, fs = librosa.load(file_path, sr=self.target_sr, mono=False)
74 |         except:
75 |             return self.__getitem__(np.random.randint(len(self.file_list)))
76 |         if X.ndim > 1:
77 |             # if not mono, take random channel
78 |             channel_id = np.random.randint(X.shape[0])
79 |             X = X[channel_id]
80 |
81 |         # RawBoost Augmentation
82 |         if self.state == "train":
83 |             X = process_Rawboost_feature(X, fs, self.args, self.algo)
84 |         x_start = np.random.randint(0, max(X.shape[0] - self.cut, 0) + 1)  # random crop start
85 |         X_pad = pad_random(X, x_start, x_start + self.cut)
86 |         X_pad = X_pad / np.max(np.abs(X_pad))
87 |         x_inp = Tensor(X_pad)
88 |         y = int(key.split("_")[0])
89 |         return x_inp, y
90 |
91 | class Dataset_SingFake_mert_w2v(Dataset):
92 |     def __init__(self, args, config, base_dir, algo, state,
93 |                  target_sr=16000, target_sr2=24000):
94 |         """
95 |         base_dir should contain mixtures/ and vocals/ folders
96 |         """
97 |         self.base_dir = base_dir
98 |         self.is_mixture = not str_to_bool(config["vocals_only"])
99 |         self.is_sep = str_to_bool(config["is_sep"])
100 |         self.is_rawboost = str_to_bool(config["is_rawboost"])
101 |         self.is_beat_matching = str_to_bool(config["is_beat_matching"])
102 |
103 |         self.target_sr = target_sr
104 |         self.target_sr2 = target_sr2
105 |         self.cut16 = 64600  # take ~4 sec audio at 16 kHz (64600 samples)
106 |         self.duration = 4.0375
107 |         self.cut24 = 96900  # take ~4 sec audio at 24 kHz (96900 samples)
108 |         self.args = args
109 |         self.algo = algo
110 |         self.state = state
111 |
112 |         # get file list
113 |         self.file_list = []
114 |         if self.is_mixture:
115 |             self.target_path = os.path.join(self.base_dir, "mixtures")
116 |             self.tgt_v2_path = self.target_path
117 |         elif not self.is_mixture and self.is_sep:
118 |             self.target_path = os.path.join(self.base_dir, "vocals")
119 |             self.tgt_v2_path = os.path.join(self.base_dir, "non_vocals")
120 |         else:
121 |             self.target_path = os.path.join(self.base_dir, "vocals")
122 |             self.tgt_v2_path = self.target_path  # no separated accompaniment: reuse the vocals path
123 |
124 |         self.beat_file_path = os.path.join(self.base_dir, "beats")
125 |
126 |         # For beat matching
127 |         self.BpmProCls = BpmProcessor(train_acc_path=config["train_acc_path"],
128 |                                       json2bpm_path=config["j2b_path"],
129 |                                       bpm2json_path=config["b2j_path"],
130 |                                       sample_rate=16000,
131 |                                       threshold=self.cut16)
132 |
133 |         assert os.path.exists(self.target_path), f"{self.target_path} does not exist!"
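        # File stems follow "<label>_<song>_<segment>" (e.g. "0_0212_3"); the
        # leading 0/1 is reused as the class label in __getitem__ and as the
        # prefix that restricts beat-matched accompaniments to the same label.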
134 |
135 |         for file in os.listdir(self.target_path):
136 |             if file.endswith(".flac"):
137 |                 self.file_list.append(file[:-5])
138 |
139 |         # (optional) filter self.file_list down to a single label (0 or 1) here
140 |
141 |
142 |     def __len__(self):
143 |         return len(self.file_list)
144 |
145 |     def __getitem__(self, index):
146 |         key = self.file_list[index]
147 |         y = int(key.split("_")[0])
148 |
149 |         file_path = os.path.join(self.target_path, key + ".flac")
150 |         file2_path = os.path.join(self.tgt_v2_path, key + ".wav")
151 |         try:
152 |             X, fs = librosa.load(file_path, sr=self.target_sr, mono=False)
153 |             X2, _ = librosa.load(file2_path, sr=self.target_sr, mono=False)
154 |         except:
155 |             return self.__getitem__(np.random.randint(len(self.file_list)))
156 |
157 |         if X.ndim > 1:
158 |             # If not mono, take a random channel (assumes X and X2 share the same layout)
159 |             channel_id = np.random.randint(X.shape[0])
160 |             X, X2 = X[channel_id], X2[channel_id]
161 |
162 |         if self.state == "train" and self.is_beat_matching:
163 |             bpm_n = self.BpmProCls.j2b_dict[key + ".json"]["bpm"]
164 |             db_num = len(self.BpmProCls.j2b_dict[key + ".json"]["downbeats"])
165 |             if bpm_n is not None and db_num > 2:
166 |                 X = self.BpmProCls.sv_beat_align(X, key)
167 |                 sel_json, dbs_duration = self.BpmProCls.sel_accom_from_bpm_group(bpm_n, y)
168 |                 if dbs_duration != 0:
169 |                     sel_wav = self.BpmProCls.load_audio_by_json(sel_json)
170 |                     X2 = self.BpmProCls.accom_beat_padding(sel_wav, sel_json, dbs_duration)
171 |
172 |         # RawBoost Augmentation
173 |         if self.state == "train" and self.is_rawboost:
174 |             X = process_Rawboost_feature(X, fs, self.args, self.algo)
175 |
176 |         waveform_shift = X.shape[0] - self.cut16
177 |         if waveform_shift > 0:
178 |             x_start = np.random.randint(0, waveform_shift)
179 |         else:
180 |             x_start = 0
181 |
182 |         x_end = x_start + self.cut16
183 |         X_pad, X2_pad = pad_random(X, x_start, x_end), pad_random(X2, x_start, x_end)
184 |         X_pad, X2_pad = X_pad / np.max(np.abs(X_pad)), X2_pad / np.max(np.abs(X2_pad))
185 |
186 |         x_inp, x2_inp = Tensor(X_pad), Tensor(X2_pad)
187 |
188 |
189 |         return x_inp, x2_inp, y
--------------------------------------------------------------------------------
/figures/singfake_sota.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjchenGit/SingGraph/f4809cd946ed13670431a0400d434ad3b6fbba26/figures/singfake_sota.jpg
--------------------------------------------------------------------------------
/figures/singgraph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjchenGit/SingGraph/f4809cd946ed13670431a0400d434ad3b6fbba26/figures/singgraph.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Main script that trains, validates, and evaluates
3 | various models including AASIST.
4 |
5 | AASIST
6 | Copyright (c) 2021-present NAVER Corp.
7 | MIT license 8 | """ 9 | import argparse 10 | import json 11 | import os 12 | import sys 13 | import warnings 14 | from importlib import import_module 15 | from pathlib import Path 16 | from shutil import copy 17 | from typing import Dict, List, Union 18 | 19 | import torch 20 | import torch.nn as nn 21 | from torch.utils.data import DataLoader 22 | from torch.utils.tensorboard import SummaryWriter 23 | from torchcontrib.optim import SWA 24 | from tqdm import tqdm 25 | 26 | from data.dataloader import Dataset_SingFake, Dataset_SingFake_mert_w2v 27 | from utils.eval_metrics import compute_eer 28 | from utils.utils import create_optimizer, seed_worker, set_seed, str_to_bool 29 | from model.SingGraph import Wav2Vec2Model 30 | 31 | warnings.filterwarnings("ignore", category=UserWarning) 32 | warnings.filterwarnings("ignore", category=FutureWarning) 33 | 34 | def main(args: argparse.Namespace) -> None: 35 | """ 36 | Main function. 37 | Trains, validates, and evaluates the ASVspoof detection model. 38 | """ 39 | # load experiment configurations 40 | with open(args.config, "r") as f_json: 41 | config = json.loads(f_json.read()) 42 | model_config = config["model_config"] 43 | optim_config = config["optim_config"] 44 | optim_config["epochs"] = config["num_epochs"] 45 | track = config["track"] 46 | assert track in ["LA", "PA", "DF"], "Invalid track given" 47 | if "eval_all_best" not in config: 48 | config["eval_all_best"] = "True" 49 | if "freq_aug" not in config: 50 | config["freq_aug"] = "False" 51 | 52 | # make experiment reproducible 53 | set_seed(args.seed, config) 54 | 55 | # define database related paths 56 | output_dir = Path(args.output_dir) 57 | 58 | # define model related paths 59 | model_tag = "{}_{}_ep{}_bs{}".format(track, 60 | os.path.splitext(os.path.basename(args.config))[0], 61 | config["num_epochs"], config["batch_size"]) 62 | if args.comment: 63 | model_tag = model_tag + "_{}".format(args.comment) 64 | model_tag = output_dir / model_tag 65 | model_save_path = model_tag / "weights" 66 | eval_score_path = model_tag / config["eval_output"] 67 | writer = SummaryWriter(model_tag) 68 | os.makedirs(model_save_path, exist_ok=True) 69 | copy(args.config, model_tag / "config.conf") 70 | 71 | # set device 72 | gpu_id = args.gpu 73 | device = torch.device("cuda:{}".format(gpu_id) if torch.cuda.is_available() else "cpu") 74 | # device = "cuda" if torch.cuda.is_available() else "cpu" 75 | print("Device: {}".format(device)) 76 | if device == "cpu": 77 | raise ValueError("GPU not detected!") 78 | 79 | # define model architecture 80 | # model = get_model(model_config, device) 81 | model = get_wav2vec2_model(model_config, device).to(device) 82 | 83 | # define dataloaders 84 | # trn_loader, dev_loader, eval_loader, additional_loader, persian_loader, mp3_loader, ogg_loader, aac_loader, opus_loader = get_singfake_loaders(args.seed, args, config) 85 | dataset_loaders = get_singfake_loaders(args.seed, args, config) 86 | 87 | # evaluates pretrained model and exit script 88 | if args.eval: 89 | model.load_state_dict( 90 | torch.load(config["model_path"], map_location=device)) 91 | print("Model loaded : {}".format(config["model_path"])) 92 | print("Evaluating on SingFake Eval Set...") 93 | 94 | eer_results = {} 95 | for data_key in dataset_loaders: 96 | eer = evaluate(dataset_loaders[data_key], model, device) 97 | eer_results[data_key] = eer 98 | 99 | avg_codec_eer = (eer_results["codec_test/mp3_128k"] + 100 | eer_results["codec_test/ogg_64k"] + 101 | eer_results["codec_test/adts_64k"] + 102 | 
eer_results["codec_test/opus_64k"]) / 4. 103 | 104 | print("Done. train_eer: {:.2f} %, test_set_eer: {:.2f} %, additional_test_eer: {:.2f} %, persian_eer: {:.2f} %".format(eer_results["train"] * 100, 105 | eer_results["T01"] * 100, 106 | eer_results["T02"] * 100, 107 | eer_results["T04"] * 100)) 108 | 109 | print("Codec testing: Average EER: {:.2f} %, MP3 EER: {:.2f} %, OGG EER: {:.2f} %, AAC EER: {:.2f} %, OPUS EER: {:.2f} %".format(avg_codec_eer * 100, 110 | eer_results["codec_test/mp3_128k"] * 100, 111 | eer_results["codec_test/ogg_64k"] * 100, 112 | eer_results["codec_test/adts_64k"] * 100, 113 | eer_results["codec_test/opus_64k"] * 100)) 114 | sys.exit(0) 115 | 116 | # get optimizer and scheduler 117 | optim_config["steps_per_epoch"] = len(dataset_loaders["train"]) 118 | optimizer, scheduler = create_optimizer(model.parameters(), optim_config) 119 | optimizer_swa = SWA(optimizer) 120 | 121 | best_dev_eer = 1. 122 | best_eval_eer = 100. 123 | n_swa_update = 0 # number of snapshots of model to use in SWA 124 | f_log = open(model_tag / "metric_log.txt", "a") 125 | f_log.write("=" * 5 + "\n") 126 | 127 | # make directory for metric logging 128 | metric_path = model_tag / "metrics" 129 | os.makedirs(metric_path, exist_ok=True) 130 | 131 | # Training 132 | for epoch in range(config["num_epochs"]): 133 | print("Start training epoch{:03d}".format(epoch)) 134 | running_loss = train_epoch(dataset_loaders["train"], model, optimizer, device, 135 | scheduler, config) 136 | 137 | dev_eer = evaluate(dataset_loaders["dev"], model, device) 138 | additional_eer = evaluate(dataset_loaders["T02"], model, device) 139 | 140 | print("DONE.\nLoss:{:.5f}, dev_eer: {:.2f} %, additional_test_eer: {:.2f} %".format( 141 | running_loss, dev_eer * 100, additional_eer * 100)) 142 | writer.add_scalar("loss", running_loss, epoch) 143 | writer.add_scalar("dev_eer", dev_eer, epoch) 144 | writer.add_scalar("additional_eer", additional_eer, epoch) 145 | 146 | if best_dev_eer >= dev_eer: 147 | print("best model find at epoch", epoch) 148 | best_dev_eer = dev_eer 149 | torch.save(model.state_dict(), 150 | model_save_path / "epoch_{}_{:03.3f}.pth".format(epoch, dev_eer)) 151 | 152 | # do evaluation whenever best model is renewed 153 | if str_to_bool(config["eval_all_best"]): 154 | eval_eer = evaluate(dataset_loaders["T01"], model, device) 155 | log_text = "epoch{:03d}, ".format(epoch) 156 | if eval_eer < best_eval_eer: 157 | log_text += "best eer, {:.4f}%".format(eval_eer) 158 | best_eval_eer = eval_eer 159 | torch.save(model.state_dict(), model_save_path / "best.pth") 160 | 161 | if len(log_text) > 0: 162 | print(log_text) 163 | f_log.write(log_text + "\n") 164 | 165 | print("Saving epoch {} for swa".format(epoch)) 166 | optimizer_swa.update_swa() 167 | n_swa_update += 1 168 | writer.add_scalar("best_dev_eer", best_dev_eer, epoch) 169 | 170 | print("Start final evaluation") 171 | epoch += 1 172 | if n_swa_update > 0: 173 | optimizer_swa.swap_swa_sgd() 174 | optimizer_swa.bn_update(dataset_loaders["train"], model, device=device) 175 | eval_eer = evaluate(dataset_loaders["T01"], model, device) 176 | f_log = open(model_tag / "metric_log.txt", "a") 177 | f_log.write("=" * 5 + "\n") 178 | f_log.write("EER: {:.3f} %".format(eval_eer * 100)) 179 | f_log.close() 180 | 181 | torch.save(model.state_dict(), 182 | model_save_path / "swa.pth") 183 | 184 | if eval_eer <= best_eval_eer: 185 | best_eval_eer = eval_eer 186 | torch.save(model.state_dict(), model_save_path / "best.pth") 187 | 188 | print("Exp FIN. 
EER: {:.3f} %".format(best_eval_eer * 100)) 189 | 190 | 191 | def get_model(model_config: Dict, device: torch.device): 192 | """Define DNN model architecture""" 193 | module = import_module("models.{}".format(model_config["architecture"])) 194 | _model = getattr(module, "Model") 195 | model = _model(model_config).to(device) 196 | nb_params = sum([param.view(-1).size()[0] for param in model.parameters()]) 197 | print("no. model params:{}".format(nb_params)) 198 | 199 | return model 200 | 201 | def get_wav2vec2_model(model_config: Dict, device: torch.device): 202 | model = Wav2Vec2Model(model_config, device) 203 | return model 204 | 205 | def get_singfake_loaders(seed: int, args: argparse.Namespace, config: dict) -> List[torch.utils.data.DataLoader]: 206 | # base_dir = "../../dataset/split_dump_flac/" 207 | base_dir = "./dataset/split_dump_flac/" 208 | target_sr = float(config["target_sr"]) 209 | 210 | # Define dataset keys and their corresponding paths 211 | dataset_keys = ["train", "dev", "T01", "T02", "T04", 212 | "codec_test/mp3_128k", "codec_test/ogg_64k", 213 | "codec_test/adts_64k", "codec_test/opus_64k"] 214 | datasets = {} 215 | 216 | # Common DataLoader settings 217 | common_settings = { 218 | "batch_size": config["batch_size"], 219 | "num_workers": 4, 220 | "pin_memory": True 221 | } 222 | 223 | # Initialize the generator for reproducibility 224 | gen = torch.Generator() 225 | gen.manual_seed(seed) 226 | 227 | # Creating DataLoaders for each dataset 228 | for key in dataset_keys: 229 | dataset_path = os.path.join(base_dir, key) 230 | shuffle = True if key == "train" else False 231 | drop_last = True if key == "train" else False 232 | worker_init_fn = seed_worker if key == "train" else None 233 | generator = gen if key == "train" else None 234 | 235 | dataset = Dataset_SingFake_mert_w2v(args, config, base_dir=dataset_path, algo=args.algo, 236 | state="train" if key == "train" else "test", target_sr=target_sr) 237 | 238 | datasets[key] = DataLoader(dataset, shuffle=shuffle, drop_last=drop_last, worker_init_fn=worker_init_fn, 239 | generator=generator, **common_settings) 240 | 241 | # Return DataLoaders in the specified order 242 | return datasets 243 | 244 | def evaluate(loader, model, device: torch.device): 245 | """ 246 | Evaluate the model on the given loader, then return EER. 247 | """ 248 | model.eval() 249 | # we save target (1) scores to target_scores, and non target (0) scores to nontarget_scores. 
250 | target_scores = [] 251 | nontarget_scores = [] 252 | debug = False 253 | count = 0 254 | with torch.no_grad(): 255 | for batch_x, batch_x2, batch_y in tqdm(loader, total=len(loader)): 256 | batch_x, batch_x2 = batch_x.to(device), batch_x2.to(device) 257 | batch_out = model(batch_x, batch_x2) 258 | batch_score = (batch_out[:, 1]).data.cpu().numpy().ravel() 259 | batch_y = batch_y.data.cpu().numpy().ravel() 260 | for i in range(len(batch_y)): 261 | if batch_y[i] == 1: 262 | target_scores.append(batch_score[i]) 263 | else: 264 | nontarget_scores.append(batch_score[i]) 265 | count += 1 266 | if count == 10 and debug: 267 | break 268 | 269 | eer, _ = compute_eer(target_scores, nontarget_scores) 270 | return eer 271 | 272 | 273 | def train_epoch( 274 | trn_loader: DataLoader, 275 | model, 276 | optim: Union[torch.optim.SGD, torch.optim.Adam], 277 | device: torch.device, 278 | scheduler: torch.optim.lr_scheduler, 279 | config: argparse.Namespace): 280 | """Train the model for one epoch""" 281 | running_loss = 0 282 | num_total = 0.0 283 | ii = 0 284 | model.train() 285 | 286 | # set objective (Loss) functions 287 | weight = torch.FloatTensor([0.1, 0.9]).to(device) 288 | criterion = nn.CrossEntropyLoss(weight=weight) 289 | pbar = tqdm(trn_loader, total=len(trn_loader)) 290 | for batch_x, batch_x2, batch_y in pbar: 291 | batch_size = batch_x.size(0) 292 | num_total += batch_size 293 | ii += 1 294 | batch_x, batch_x2 = batch_x.to(device), batch_x2.to(device) 295 | batch_y = batch_y.view(-1).type(torch.int64).to(device) 296 | batch_out = model(batch_x, batch_x2) 297 | batch_loss = criterion(batch_out, batch_y) 298 | running_loss += batch_loss.item() * batch_size 299 | pbar.set_description("loss: {:.5f}, running loss: {:.5f}".format( 300 | batch_loss.item(), running_loss / num_total)) 301 | optim.zero_grad() 302 | batch_loss.backward() 303 | optim.step() 304 | 305 | if config["optim_config"]["scheduler"] in ["cosine", "keras_decay"]: 306 | scheduler.step() 307 | elif scheduler is None: 308 | pass 309 | else: 310 | raise ValueError("scheduler error, got:{}".format(scheduler)) 311 | 312 | running_loss /= num_total 313 | return running_loss 314 | 315 | 316 | if __name__ == "__main__": 317 | parser = argparse.ArgumentParser(description="ASVspoof detection system") 318 | parser.add_argument("--config", 319 | dest="config", 320 | type=str, 321 | help="configuration file", 322 | required=True) 323 | parser.add_argument( 324 | "--output_dir", 325 | dest="output_dir", 326 | type=str, 327 | help="output directory for results", 328 | default="./exp_result", 329 | ) 330 | parser.add_argument("--seed", 331 | type=int, 332 | default=1234, 333 | help="random seed (default: 1234)") 334 | parser.add_argument( 335 | "--eval", 336 | action="store_true", 337 | help="when this flag is given, evaluates given model and exit") 338 | parser.add_argument("--comment", 339 | type=str, 340 | default=None, 341 | help="comment to describe the saved model") 342 | parser.add_argument("--eval_model_weights", 343 | type=str, 344 | default=None, 345 | help="directory to the model weight file (can be also given in the config file)") 346 | parser.add_argument("--gpu", 347 | type=int, 348 | default=0, 349 | help="gpu id to use (default: 0)") 350 | 351 | ##===================================================Rawboost data augmentation ===============================================================# 352 | 353 | parser.add_argument('--algo', type=int, default=3, 354 | help='Rawboost algos discriptions. 
0: No augmentation 1: LnL_convolutive_noise, 2: ISD_additive_noise, 3: SSI_additive_noise, 4: series algo (1+2+3), \ 355 | 5: series algo (1+2), 6: series algo (1+3), 7: series algo(2+3), 8: parallel algo(1,2) .[default=0]') 356 | 357 | # LnL_convolutive_noise parameters 358 | parser.add_argument('--nBands', type=int, default=5, 359 | help='number of notch filters.The higher the number of bands, the more aggresive the distortions is.[default=5]') 360 | parser.add_argument('--minF', type=int, default=20, 361 | help='minimum centre frequency [Hz] of notch filter.[default=20] ') 362 | parser.add_argument('--maxF', type=int, default=8000, 363 | help='maximum centre frequency [Hz] ( 0 else nn.Identity() 345 | self.in_dim = in_dim 346 | 347 | def forward(self, h): 348 | Z = self.drop(h) 349 | weights = self.proj(Z) 350 | scores = self.sigmoid(weights) 351 | new_h = self.top_k_graph(scores, h, self.k) 352 | 353 | return new_h 354 | 355 | def top_k_graph(self, scores, h, k): 356 | """ 357 | args 358 | ===== 359 | scores: attention-based weights (#bs, #node, 1) 360 | h: graph data (#bs, #node, #dim) 361 | k: ratio of remaining nodes, (float) 362 | returns 363 | ===== 364 | h: graph pool applied data (#bs, #node', #dim) 365 | """ 366 | _, n_nodes, n_feat = h.size() 367 | n_nodes = max(int(n_nodes * k), 1) 368 | _, idx = torch.topk(scores, n_nodes, dim=1) 369 | idx = idx.expand(-1, -1, n_feat) 370 | 371 | h = h * scores 372 | h = torch.gather(h, 1, idx) 373 | 374 | return h 375 | 376 | class Residual_block(nn.Module): 377 | def __init__(self, nb_filts, first=False): 378 | super().__init__() 379 | self.first = first 380 | 381 | if not self.first: 382 | self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0]) 383 | self.conv1 = nn.Conv2d(in_channels=nb_filts[0], 384 | out_channels=nb_filts[1], 385 | kernel_size=(2, 3), 386 | padding=(1, 1), 387 | stride=1) 388 | self.selu = nn.SELU(inplace=True) 389 | 390 | self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1]) 391 | self.conv2 = nn.Conv2d(in_channels=nb_filts[1], 392 | out_channels=nb_filts[1], 393 | kernel_size=(2, 3), 394 | padding=(0, 1), 395 | stride=1) 396 | 397 | if nb_filts[0] != nb_filts[1]: 398 | self.downsample = True 399 | self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0], 400 | out_channels=nb_filts[1], 401 | padding=(0, 1), 402 | kernel_size=(1, 3), 403 | stride=1) 404 | 405 | else: 406 | self.downsample = False 407 | 408 | 409 | def forward(self, x): 410 | identity = x 411 | if not self.first: 412 | out = self.bn1(x) 413 | out = self.selu(out) 414 | else: 415 | out = x 416 | 417 | out = self.conv1(x) 418 | 419 | out = self.bn2(out) 420 | out = self.selu(out) 421 | out = self.conv2(out) 422 | 423 | if self.downsample: 424 | identity = self.conv_downsample(identity) 425 | 426 | out += identity 427 | return out 428 | 429 | class Wav2Vec2Model(nn.Module): 430 | def __init__(self, args,device): 431 | super().__init__() 432 | self.device = device 433 | 434 | # AASIST parameters 435 | filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]] 436 | gat_dims = [64, 32] 437 | pool_ratios = [0.5, 0.5, 0.5, 0.5] 438 | temperatures = [2.0, 2.0, 100.0, 100.0] 439 | 440 | 441 | #### 442 | # create network wav2vec 2.0 443 | #### 444 | self.ssl_model = SSLModel(self.device) 445 | self.LL = nn.Linear(self.ssl_model.out_dim, 128) 446 | 447 | self.first_bn = nn.BatchNorm2d(num_features=1) 448 | self.first_bn1 = nn.BatchNorm2d(num_features=64) 449 | self.drop = nn.Dropout(0.5, inplace=True) 450 | self.drop_way = nn.Dropout(0.2, inplace=True) 451 | self.selu = 
nn.SELU(inplace=True) 452 | 453 | # RawNet2 encoder 454 | self.encoder = nn.Sequential( 455 | nn.Sequential(Residual_block(nb_filts=filts[1], first=True)), 456 | nn.Sequential(Residual_block(nb_filts=filts[2])), 457 | nn.Sequential(Residual_block(nb_filts=filts[3])), 458 | nn.Sequential(Residual_block(nb_filts=filts[4])), 459 | nn.Sequential(Residual_block(nb_filts=filts[4])), 460 | nn.Sequential(Residual_block(nb_filts=filts[4]))) 461 | 462 | self.attention = nn.Sequential( 463 | nn.Conv2d(64, 128, kernel_size=(1,1)), 464 | nn.SELU(inplace=True), 465 | nn.BatchNorm2d(128), 466 | nn.Conv2d(128, 64, kernel_size=(1,1)), 467 | 468 | ) 469 | # position encoding 470 | self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1])) 471 | 472 | self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0])) 473 | self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0])) 474 | 475 | # Graph module 476 | self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1], 477 | gat_dims[0], 478 | temperature=temperatures[0]) 479 | self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1], 480 | gat_dims[0], 481 | temperature=temperatures[1]) 482 | # HS-GAL layer 483 | self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer( 484 | gat_dims[0], gat_dims[1], temperature=temperatures[2]) 485 | self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer( 486 | gat_dims[1], gat_dims[1], temperature=temperatures[2]) 487 | self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer( 488 | gat_dims[0], gat_dims[1], temperature=temperatures[2]) 489 | self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer( 490 | gat_dims[1], gat_dims[1], temperature=temperatures[2]) 491 | 492 | # Graph pooling layers 493 | self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3) 494 | self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3) 495 | self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 496 | self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 497 | 498 | self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 499 | self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) 500 | 501 | self.out_layer = nn.Linear(5 * gat_dims[1], 2) 502 | 503 | def forward(self, x, x2): 504 | #-------pre-trained Wav2vec model fine tunning ------------------------## 505 | x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1), x2.squeeze(-1)) 506 | x = self.LL(x_ssl_feat) #(bs,frame_number,feat_out_dim) 507 | 508 | # post-processing on front-end features 509 | x = x.transpose(1, 2) #(bs,feat_out_dim,frame_number) 510 | x = x.unsqueeze(dim=1) # add channel 511 | x = F.max_pool2d(x, (3, 3)) 512 | x = self.first_bn(x) 513 | x = self.selu(x) 514 | 515 | # RawNet2-based encoder 516 | x = self.encoder(x) 517 | x = self.first_bn1(x) 518 | x = self.selu(x) 519 | 520 | w = self.attention(x) 521 | 522 | #------------SA for spectral feature-------------# 523 | w1 = F.softmax(w,dim=-1) 524 | m = torch.sum(x * w1, dim=-1) 525 | e_S = m.transpose(1, 2) + self.pos_S 526 | 527 | # graph module layer 528 | gat_S = self.GAT_layer_S(e_S) 529 | out_S = self.pool_S(gat_S) # (#bs, #node, #dim) 530 | 531 | #------------SA for temporal feature-------------# 532 | w2 = F.softmax(w,dim=-2) 533 | m1 = torch.sum(x * w2, dim=-2) 534 | 535 | e_T = m1.transpose(1, 2) 536 | 537 | # graph module layer 538 | gat_T = self.GAT_layer_T(e_T) 539 | out_T = self.pool_T(gat_T) 540 | 541 | # learnable master node 542 | master1 = self.master1.expand(x.size(0), -1, -1) 543 | master2 = self.master2.expand(x.size(0), -1, -1) 544 | 545 | # inference 1 546 | out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11( 547 
| out_T, out_S, master=self.master1) 548 | 549 | out_S1 = self.pool_hS1(out_S1) 550 | out_T1 = self.pool_hT1(out_T1) 551 | 552 | out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12( 553 | out_T1, out_S1, master=master1) 554 | out_T1 = out_T1 + out_T_aug 555 | out_S1 = out_S1 + out_S_aug 556 | master1 = master1 + master_aug 557 | 558 | # inference 2 559 | out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21( 560 | out_T, out_S, master=self.master2) 561 | out_S2 = self.pool_hS2(out_S2) 562 | out_T2 = self.pool_hT2(out_T2) 563 | 564 | out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22( 565 | out_T2, out_S2, master=master2) 566 | out_T2 = out_T2 + out_T_aug 567 | out_S2 = out_S2 + out_S_aug 568 | master2 = master2 + master_aug 569 | 570 | out_T1 = self.drop_way(out_T1) 571 | out_T2 = self.drop_way(out_T2) 572 | out_S1 = self.drop_way(out_S1) 573 | out_S2 = self.drop_way(out_S2) 574 | master1 = self.drop_way(master1) 575 | master2 = self.drop_way(master2) 576 | 577 | out_T = torch.max(out_T1, out_T2) 578 | out_S = torch.max(out_S1, out_S2) 579 | master = torch.max(master1, master2) 580 | 581 | # Readout operation 582 | T_max, _ = torch.max(torch.abs(out_T), dim=1) 583 | T_avg = torch.mean(out_T, dim=1) 584 | 585 | S_max, _ = torch.max(torch.abs(out_S), dim=1) 586 | S_avg = torch.mean(out_S, dim=1) 587 | 588 | last_hidden = torch.cat( 589 | [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1) 590 | 591 | last_hidden = self.drop(last_hidden) 592 | output = self.out_layer(last_hidden) 593 | 594 | return output 595 | 596 | if __name__ == "__main__": 597 | import json 598 | with open("AASIST.conf", "r") as f_json: 599 | config = json.loads(f_json.read()) 600 | with torch.no_grad(): 601 | wav = torch.randn((1, 32000)) 602 | w2v2_model = Wav2Vec2Model(config["model_config"], "cuda") 603 | output = w2v2_model(wav, wav) 604 | print(f"w2v2_model: {w2v2_model}") 605 | print(f"output: {output.shape}") 606 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.9.1 2 | transformers 3 | pandas 4 | numpy==1.21.6 5 | numba==0.56.4 6 | tensorboardX 7 | protobuf==3.20.* 8 | pexpect>4.3 9 | numba>=0.53 10 | dcase-util>=0.2.4 11 | tensorboard 12 | torchcontrib -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # export TRANSFORMERS_CACHE=./cache 2 | export HF_HOME=./cache 3 | 4 | python3 main.py \ 5 | --config utils/SingGraph.conf \ 6 | --output_dir ./exp_result \ 7 | --eval 8 | -------------------------------------------------------------------------------- /utils/SingGraph.conf: -------------------------------------------------------------------------------- 1 | { 2 | "database_path": "./LA/", 3 | "asv_score_path": "ASVspoof2019_LA_asv_scores/ASVspoof2019.LA.asv.eval.gi.trl.scores.txt", 4 | "b2j_path": "./dataset/split_dump_flac/train/bpm2json.json", 5 | "j2b_path": "./dataset/split_dump_flac/train/json2bpm.json", 6 | "train_acc_path": "./dataset/split_dump_flac/train/non_vocals/", 7 | "train_vocal_path": "./dataset/split_dump_flac/train/vocals/", 8 | "model_path": "./exp_result/LA_SingGraph_ep80_bs12/weights/epoch_39_0.046.pth", 9 | "batch_size": 12, 10 | "num_epochs": 80, 11 | "target_sr": 16000, 12 | "vocals_only": "True", 13 | "is_sep": "True", 14 | "is_rawboost": "True", 15 | "is_beat_matching": "True", 16 | "loss": "CCE", 17 | 
"track": "LA", 18 | "eval_all_best": "True", 19 | "eval_output": "eval_scores_using_best_dev_model.txt", 20 | "cudnn_deterministic_toggle": "True", 21 | "cudnn_benchmark_toggle": "False", 22 | "model_config": { 23 | "architecture": "AASIST", 24 | "nb_samp": 64000, 25 | "first_conv": 128, 26 | "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]], 27 | "gat_dims": [64, 32], 28 | "pool_ratios": [0.5, 0.7, 0.5, 0.5], 29 | "temperatures": [2.0, 2.0, 100.0, 100.0] 30 | }, 31 | "optim_config": { 32 | "optimizer": "adam", 33 | "amsgrad": "False", 34 | "base_lr": 0.000001, 35 | "lr_min": 0.000005, 36 | "betas": [0.9, 0.999], 37 | "weight_decay": 0.0001, 38 | "scheduler": "cosine" 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /utils/eval_metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | 5 | 6 | def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold): 7 | 8 | # False alarm and miss rates for ASV 9 | Pfa_asv = sum(non_asv >= asv_threshold) / non_asv.size 10 | Pmiss_asv = sum(tar_asv < asv_threshold) / tar_asv.size 11 | 12 | # Rate of rejecting spoofs in ASV 13 | if spoof_asv.size == 0: 14 | Pmiss_spoof_asv = None 15 | else: 16 | Pmiss_spoof_asv = np.sum(spoof_asv < asv_threshold) / spoof_asv.size 17 | 18 | return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv 19 | 20 | 21 | def compute_det_curve(target_scores, nontarget_scores): 22 | target_scores, nontarget_scores = np.array(target_scores), np.array(nontarget_scores) 23 | n_scores = target_scores.size + nontarget_scores.size 24 | all_scores = np.concatenate((target_scores, nontarget_scores)) 25 | labels = np.concatenate((np.ones(target_scores.size), np.zeros(nontarget_scores.size))) 26 | 27 | # Sort labels based on scores 28 | indices = np.argsort(all_scores, kind='mergesort') 29 | labels = labels[indices] 30 | 31 | # Compute false rejection and false acceptance rates 32 | tar_trial_sums = np.cumsum(labels) 33 | nontarget_trial_sums = nontarget_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums) 34 | 35 | frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / target_scores.size)) # false rejection rates 36 | far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / nontarget_scores.size)) # false acceptance rates 37 | thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices])) # Thresholds are the sorted scores 38 | 39 | return frr, far, thresholds 40 | 41 | 42 | def compute_eer(target_scores, nontarget_scores): 43 | """ Returns equal error rate (EER) and the corresponding threshold. """ 44 | frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores) 45 | abs_diffs = np.abs(frr - far) 46 | min_index = np.argmin(abs_diffs) 47 | eer = np.mean((frr[min_index], far[min_index])) 48 | return eer, thresholds[min_index] 49 | 50 | 51 | def compute_tDCF(bonafide_score_cm, spoof_score_cm, Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, cost_model, print_cost=False): 52 | """ 53 | Compute Tandem Detection Cost Function (t-DCF) [1] for a fixed ASV system. 54 | In brief, t-DCF returns a detection cost of a cascaded system of this form, 55 | 56 | Speech waveform -> [CM] -> [ASV] -> decision 57 | 58 | where CM stands for countermeasure and ASV for automatic speaker 59 | verification. The CM is therefore used as a 'gate' to decided whether or 60 | not the input speech sample should be passed onwards to the ASV system. 
61 | Generally, both CM and ASV can do detection errors. Not all those errors 62 | are necessarily equally cost, and not all types of users are necessarily 63 | equally likely. The tandem t-DCF gives a principled with to compare 64 | different spoofing countermeasures under a detection cost function 65 | framework that takes that information into account. 66 | 67 | INPUTS: 68 | 69 | bonafide_score_cm A vector of POSITIVE CLASS (bona fide or human) 70 | detection scores obtained by executing a spoofing 71 | countermeasure (CM) on some positive evaluation trials. 72 | trial represents a bona fide case. 73 | spoof_score_cm A vector of NEGATIVE CLASS (spoofing attack) 74 | detection scores obtained by executing a spoofing 75 | CM on some negative evaluation trials. 76 | Pfa_asv False alarm (false acceptance) rate of the ASV 77 | system that is evaluated in tandem with the CM. 78 | Assumed to be in fractions, not percentages. 79 | Pmiss_asv Miss (false rejection) rate of the ASV system that 80 | is evaluated in tandem with the spoofing CM. 81 | Assumed to be in fractions, not percentages. 82 | Pmiss_spoof_asv Miss rate of spoof samples of the ASV system that 83 | is evaluated in tandem with the spoofing CM. That 84 | is, the fraction of spoof samples that were 85 | rejected by the ASV system. 86 | cost_model A struct that contains the parameters of t-DCF, 87 | with the following fields. 88 | 89 | Ptar Prior probability of target speaker. 90 | Pnon Prior probability of nontarget speaker (zero-effort impostor) 91 | Psoof Prior probability of spoofing attack. 92 | Cmiss_asv Cost of ASV falsely rejecting target. 93 | Cfa_asv Cost of ASV falsely accepting nontarget. 94 | Cmiss_cm Cost of CM falsely rejecting target. 95 | Cfa_cm Cost of CM falsely accepting spoof. 96 | 97 | print_cost Print a summary of the cost parameters and the 98 | implied t-DCF cost function? 99 | 100 | OUTPUTS: 101 | 102 | tDCF_norm Normalized t-DCF curve across the different CM 103 | system operating points; see [2] for more details. 104 | Normalized t-DCF > 1 indicates a useless 105 | countermeasure (as the tandem system would do 106 | better without it). min(tDCF_norm) will be the 107 | minimum t-DCF used in ASVspoof 2019 [2]. 108 | CM_thresholds Vector of same size as tDCF_norm corresponding to 109 | the CM threshold (operating point). 110 | 111 | NOTE: 112 | o In relative terms, higher detection scores values are assumed to 113 | indicate stronger support for the bona fide hypothesis. 114 | o You should provide real-valued soft scores, NOT hard decisions. The 115 | recommendation is that the scores are log-likelihood ratios (LLRs) 116 | from a bonafide-vs-spoof hypothesis based on some statistical model. 117 | This, however, is NOT required. The scores can have arbitrary range 118 | and scaling. 119 | o Pfa_asv, Pmiss_asv, Pmiss_spoof_asv are in fractions, not percentages. 120 | 121 | References: 122 | 123 | [1] T. Kinnunen, K.-A. Lee, H. Delgado, N. Evans, M. Todisco, 124 | M. Sahidullah, J. Yamagishi, D.A. Reynolds: "t-DCF: a Detection 125 | Cost Function for the Tandem Assessment of Spoofing Countermeasures 126 | and Automatic Speaker Verification", Proc. Odyssey 2018: the 127 | Speaker and Language Recognition Workshop, pp. 
312--319, Les Sables d'Olonne, 128 | France, June 2018 (https://www.isca-speech.org/archive/Odyssey_2018/pdfs/68.pdf) 129 | 130 | [2] ASVspoof 2019 challenge evaluation plan 131 | TODO: 132 | """ 133 | 134 | 135 | # Sanity check of cost parameters 136 | if cost_model['Cfa_asv'] < 0 or cost_model['Cmiss_asv'] < 0 or \ 137 | cost_model['Cfa_cm'] < 0 or cost_model['Cmiss_cm'] < 0: 138 | print('WARNING: Usually the cost values should be positive!') 139 | 140 | if cost_model['Ptar'] < 0 or cost_model['Pnon'] < 0 or cost_model['Pspoof'] < 0 or \ 141 | np.abs(cost_model['Ptar'] + cost_model['Pnon'] + cost_model['Pspoof'] - 1) > 1e-10: 142 | sys.exit('ERROR: Your prior probabilities should be positive and sum up to one.') 143 | 144 | # Unless we evaluate worst-case model, we need to have some spoof tests against asv 145 | if Pmiss_spoof_asv is None: 146 | sys.exit('ERROR: you should provide miss rate of spoof tests against your ASV system.') 147 | 148 | # Sanity check of scores 149 | combined_scores = np.concatenate((bonafide_score_cm, spoof_score_cm)) 150 | if np.isnan(combined_scores).any() or np.isinf(combined_scores).any(): 151 | sys.exit('ERROR: Your scores contain nan or inf.') 152 | 153 | # Sanity check that inputs are scores and not decisions 154 | n_uniq = np.unique(combined_scores).size 155 | if n_uniq < 3: 156 | sys.exit('ERROR: You should provide soft CM scores - not binary decisions') 157 | 158 | # Obtain miss and false alarm rates of CM 159 | Pmiss_cm, Pfa_cm, CM_thresholds = compute_det_curve(bonafide_score_cm, spoof_score_cm) 160 | 161 | # Constants - see ASVspoof 2019 evaluation plan 162 | C1 = cost_model['Ptar'] * (cost_model['Cmiss_cm'] - cost_model['Cmiss_asv'] * Pmiss_asv) - \ 163 | cost_model['Pnon'] * cost_model['Cfa_asv'] * Pfa_asv 164 | C2 = cost_model['Cfa_cm'] * cost_model['Pspoof'] * (1 - Pmiss_spoof_asv) 165 | 166 | # Sanity check of the weights 167 | if C1 < 0 or C2 < 0: 168 | sys.exit('You should never see this error but I cannot evalute tDCF with negative weights - please check whether your ASV error rates are correctly computed?') 169 | 170 | # Obtain t-DCF curve for all thresholds 171 | tDCF = C1 * Pmiss_cm + C2 * Pfa_cm 172 | 173 | # Normalized t-DCF 174 | tDCF_norm = tDCF / np.minimum(C1, C2) 175 | 176 | # Everything should be fine if reaching here. 
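    # min(tDCF_norm) over CM_thresholds is the minimum normalized t-DCF
    # figure reported for ASVspoof 2019 (see the docstring above).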
177 | if print_cost: 178 | 179 | print('t-DCF evaluation from [Nbona={}, Nspoof={}] trials\n'.format(bonafide_score_cm.size, spoof_score_cm.size)) 180 | print('t-DCF MODEL') 181 | print(' Ptar = {:8.5f} (Prior probability of target user)'.format(cost_model['Ptar'])) 182 | print(' Pnon = {:8.5f} (Prior probability of nontarget user)'.format(cost_model['Pnon'])) 183 | print(' Pspoof = {:8.5f} (Prior probability of spoofing attack)'.format(cost_model['Pspoof'])) 184 | print(' Cfa_asv = {:8.5f} (Cost of ASV falsely accepting a nontarget)'.format(cost_model['Cfa_asv'])) 185 | print(' Cmiss_asv = {:8.5f} (Cost of ASV falsely rejecting target speaker)'.format(cost_model['Cmiss_asv'])) 186 | print(' Cfa_cm = {:8.5f} (Cost of CM falsely passing a spoof to ASV system)'.format(cost_model['Cfa_cm'])) 187 | print(' Cmiss_cm = {:8.5f} (Cost of CM falsely blocking target utterance which never reaches ASV)'.format(cost_model['Cmiss_cm'])) 188 | print('\n Implied normalized t-DCF function (depends on t-DCF parameters and ASV errors), s=CM threshold)') 189 | 190 | if C2 == np.minimum(C1, C2): 191 | print(' tDCF_norm(s) = {:8.5f} x Pmiss_cm(s) + Pfa_cm(s)\n'.format(C1 / C2)) 192 | else: 193 | print(' tDCF_norm(s) = Pmiss_cm(s) + {:8.5f} x Pfa_cm(s)\n'.format(C2 / C1)) 194 | 195 | return tDCF_norm, CM_thresholds 196 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilization functions 3 | """ 4 | 5 | import os 6 | import random 7 | import sys 8 | 9 | import numpy as np 10 | import torch 11 | 12 | # from warmup_scheduler import GradualWarmupScheduler 13 | 14 | 15 | def str_to_bool(val): 16 | """Convert a string representation of truth to true (1) or false (0). 17 | Copied from the python implementation distutils.utils.strtobool 18 | 19 | True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values 20 | are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if 21 | 'val' is anything else. 22 | >>> str_to_bool('YES') 23 | 1 24 | >>> str_to_bool('FALSE') 25 | 0 26 | """ 27 | val = val.lower() 28 | if val in ('y', 'yes', 't', 'true', 'on', '1'): 29 | return True 30 | if val in ('n', 'no', 'f', 'false', 'off', '0'): 31 | return False 32 | raise ValueError('invalid truth value {}'.format(val)) 33 | 34 | 35 | def cosine_annealing(step, total_steps, lr_max, lr_min): 36 | """Cosine Annealing for learning rate decay scheduler""" 37 | return lr_min + (lr_max - 38 | lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi)) 39 | 40 | 41 | def keras_decay(step, decay=0.0001): 42 | """Learning rate decay in Keras-style""" 43 | return 1. / (1. 
+ decay * step) 44 | 45 | 46 | class SGDRScheduler(torch.optim.lr_scheduler._LRScheduler): 47 | """SGD with restarts scheduler""" 48 | def __init__(self, optimizer, T0, T_mul, eta_min, last_epoch=-1): 49 | self.Ti = T0 50 | self.T_mul = T_mul 51 | self.eta_min = eta_min 52 | 53 | self.last_restart = 0 54 | 55 | super().__init__(optimizer, last_epoch) 56 | 57 | def get_lr(self): 58 | T_cur = self.last_epoch - self.last_restart 59 | if T_cur >= self.Ti: 60 | self.last_restart = self.last_epoch 61 | self.Ti = self.Ti * self.T_mul 62 | T_cur = 0 63 | 64 | return [ 65 | self.eta_min + (base_lr - self.eta_min) * 66 | (1 + np.cos(np.pi * T_cur / self.Ti)) / 2 67 | for base_lr in self.base_lrs 68 | ] 69 | 70 | 71 | def _get_optimizer(model_parameters, optim_config): 72 | """Defines optimizer according to the given config""" 73 | optimizer_name = optim_config['optimizer'] 74 | 75 | if optimizer_name == 'sgd': 76 | optimizer = torch.optim.SGD(model_parameters, 77 | lr=optim_config['base_lr'], 78 | momentum=optim_config['momentum'], 79 | weight_decay=optim_config['weight_decay'], 80 | nesterov=optim_config['nesterov']) 81 | elif optimizer_name == 'adam': 82 | optimizer = torch.optim.Adam(model_parameters, 83 | lr=optim_config['base_lr'], 84 | betas=optim_config['betas'], 85 | weight_decay=optim_config['weight_decay'], 86 | amsgrad=str_to_bool( 87 | optim_config['amsgrad'])) 88 | else: 89 | print('Un-known optimizer', optimizer_name) 90 | sys.exit() 91 | 92 | return optimizer 93 | 94 | 95 | def _get_scheduler(optimizer, optim_config): 96 | """ 97 | Defines learning rate scheduler according to the given config 98 | """ 99 | if optim_config['scheduler'] == 'multistep': 100 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 101 | optimizer, 102 | milestones=optim_config['milestones'], 103 | gamma=optim_config['lr_decay']) 104 | 105 | elif optim_config['scheduler'] == 'sgdr': 106 | scheduler = SGDRScheduler(optimizer, optim_config['T0'], 107 | optim_config['Tmult'], 108 | optim_config['lr_min']) 109 | 110 | elif optim_config['scheduler'] == 'cosine': 111 | total_steps = optim_config['epochs'] * \ 112 | optim_config['steps_per_epoch'] 113 | 114 | scheduler = torch.optim.lr_scheduler.LambdaLR( 115 | optimizer, 116 | lr_lambda=lambda step: cosine_annealing( 117 | step, 118 | total_steps, 119 | 1, # since lr_lambda computes multiplicative factor 120 | optim_config['lr_min'] / optim_config['base_lr'])) 121 | 122 | elif optim_config['scheduler'] == 'keras_decay': 123 | scheduler = torch.optim.lr_scheduler.LambdaLR( 124 | optimizer, lr_lambda=lambda step: keras_decay(step)) 125 | else: 126 | scheduler = None 127 | return scheduler 128 | 129 | 130 | def create_optimizer(model_parameters, optim_config): 131 | """Defines an optimizer and a scheduler""" 132 | optimizer = _get_optimizer(model_parameters, optim_config) 133 | scheduler = _get_scheduler(optimizer, optim_config) 134 | return optimizer, scheduler 135 | 136 | 137 | def seed_worker(worker_id): 138 | """ 139 | Used in generating seed for the worker of torch.utils.data.Dataloader 140 | """ 141 | worker_seed = torch.initial_seed() % 2**32 142 | np.random.seed(worker_seed) 143 | random.seed(worker_seed) 144 | 145 | 146 | def set_seed(seed, config = None): 147 | """ 148 | set initial seed for reproduction 149 | """ 150 | if config is None: 151 | raise ValueError("config should not be None") 152 | 153 | random.seed(seed) 154 | np.random.seed(seed) 155 | torch.manual_seed(seed) 156 | if torch.cuda.is_available(): 157 | torch.cuda.manual_seed_all(seed) 158 | 
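        # The two cudnn flags below are driven by "cudnn_deterministic_toggle"
        # and "cudnn_benchmark_toggle" in utils/SingGraph.conf.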
torch.backends.cudnn.deterministic = str_to_bool(config["cudnn_deterministic_toggle"]) 159 | torch.backends.cudnn.benchmark = str_to_bool(config["cudnn_benchmark_toggle"]) 160 | --------------------------------------------------------------------------------
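For a quick sanity check of the scoring pipeline, `compute_eer` from `utils/eval_metrics.py` can be exercised on its own. A minimal sketch with made-up scores (higher score = more bona fide, matching how `evaluate()` in `main.py` uses `batch_out[:, 1]`):

```python
from utils.eval_metrics import compute_eer

# Toy scores (fabricated for illustration): bona fide (label 1) trials should
# score higher than deepfake (label 0) trials under a working detector.
target_scores = [0.91, 0.85, 0.74, 0.62]
nontarget_scores = [0.12, 0.33, 0.48, 0.70]

eer, threshold = compute_eer(target_scores, nontarget_scores)
print(f"EER: {eer * 100:.2f} % at threshold {threshold:.3f}")
```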