├── .gitignore
├── README.md
├── data
│   ├── BPMProcess.py
│   ├── RawBoost.py
│   └── dataloader.py
├── figures
│   ├── singfake_sota.jpg
│   └── singgraph.png
├── main.py
├── model
│   └── SingGraph.py
├── requirements.txt
├── run.sh
└── utils
    ├── SingGraph.conf
    ├── eval_metrics.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | dataset/
3 | cache/w2v2/xlsr2_300m.pt
4 | cache/models--m-a-p--MERT-v1-330M/
5 | cache/hub/
6 | cache/modules/
7 | exp_result/
8 | fairseq/
9 | fairseq1/
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SingGraph [](https://arxiv.org/abs/2406.03111)
2 |
3 | This is the official repository for the **[SingGraph](https://arxiv.org/abs/2406.03111)** model. The paper has been accepted by Interspeech 2024.
4 |
5 | The official code has been released. A write-up of the data pre-processing steps will be added to this `README.md` soon.
6 | ## Abstract
7 |
8 | Existing speech deepfake detection models struggle to adapt to unseen attacks in the unique singing voice domain of human vocalization. To bridge this gap, we present the SingGraph model, which synergizes the MERT acoustic music understanding model for pitch and rhythm analysis with the wav2vec2.0 model for linguistic analysis of lyrics. Additionally, we advocate for using RawBoost and beat-matching techniques grounded in music domain knowledge for singing voice augmentation, thereby enhancing SingFake detection performance.
9 | Our proposed method achieves new state-of-the-art (SOTA) results on the SingFake dataset, surpassing the previous SOTA model across three distinct scenarios: relative EER improvements of 13.2% for seen singers, 24.3% for unseen singers, and 37.1% for unseen singers over different codecs.
10 |
11 | 
12 |
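Until the data pre-processing write-up lands, here is a minimal sketch of how the two-branch model is driven. It mirrors `main.py` and the `__main__` stub in `model/SingGraph.py`; the 64600-sample (~4 s, 16 kHz) inputs follow `data/dataloader.py`, and the random tensors below are stand-ins for real separated vocals and accompaniment:

```
import json
import torch
from model.SingGraph import Wav2Vec2Model

with open("utils/SingGraph.conf", "r") as f_json:
    config = json.load(f_json)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Wav2Vec2Model(config["model_config"], device).to(device).eval()

vocals = torch.randn(1, 64600, device=device)   # separated singing voice
accomp = torch.randn(1, 64600, device=device)   # (beat-matched) accompaniment
with torch.no_grad():
    out = model(vocals, accomp)  # (batch, 2); main.py uses out[:, 1] as the score
```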
13 | ## Datasets
14 | The dataset is based on the paper "SingFake: Singing Voice Deepfake Detection," which was accepted at ICASSP 2024. [[Project Webpage](https://singfake.org/)]
15 |
16 | Due to copyright issues, the dataset is not open-sourced. Please follow the instructions in the paper above to download the dataset yourself.
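Based on the paths hard-coded in `main.py`, `data/dataloader.py`, `data/BPMProcess.py`, and `utils/SingGraph.conf`, the pre-processed data is expected to live under `./dataset/split_dump_flac/`, roughly as sketched below (an assumption-level sketch; the exact layout will be covered by the forthcoming pre-processing write-up):

```
dataset/split_dump_flac/
├── train/
│   ├── vocals/          # separated singing voice, <label>_<song>_<clip>.flac
│   ├── non_vocals/      # separated accompaniment, *.wav
│   ├── mixtures/        # un-separated audio (used when vocals_only is False)
│   ├── bpm2json.json    # BPM -> candidate clips, for beat matching
│   └── json2bpm.json    # clip -> BPM / beats / downbeats metadata
├── dev/
├── T01/  T02/  T04/     # SingFake evaluation conditions
└── codec_test/
    └── {mp3_128k, ogg_64k, adts_64k, opus_64k}/
```

Each split follows the same `vocals/` + `non_vocals/` structure, and file names carry the label as their first underscore-separated field (`0_*` / `1_*`).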
17 |
18 |
19 |
20 |
21 |
22 | ## Comparison with the SOTA model
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 | ## Citation
31 | If you find our work useful, please consider citing:
32 | ```
33 | @inproceedings{chen24o_interspeech,
34 | title = {Singing Voice Graph Modeling for SingFake Detection},
35 | author = {Xuanjun Chen and Haibin Wu and Roger Jang and Hung-yi Lee},
36 | year = {2024},
37 | booktitle = {Interspeech 2024},
38 | pages = {4843--4847},
39 | doi = {10.21437/Interspeech.2024-1185},
40 | issn = {2958-1796},
41 | }
42 |
43 | @article{chen2025how,
44 | title={How Does Instrumental Music Help SingFake Detection?},
45 | author={Chen, Xuanjun and Hu, Chia-Yu and Lin, I-Ming and Lin, Yi-Cheng and Chiu, I-Hsiang and Zhang, You and Huang, Sung-Feng and Yang, Yi-Hsuan and Wu, Haibin and Lee, Hung-yi and Jang, Jyh-Shing Roger},
46 | journal={arXiv preprint arXiv:2509.14675},
47 | year={2025}
48 | }
49 | ```
50 | ## Contact
51 | If you have any questions, please feel free to contact me by email at d12942018@ntu.edu.tw.
52 |
--------------------------------------------------------------------------------
/data/BPMProcess.py:
--------------------------------------------------------------------------------
1 | import os, math, json
2 | import numpy as np
3 | import librosa
4 |
5 | class BpmProcessor:
6 | def __init__(self, train_acc_path,
7 | json2bpm_path,
8 | bpm2json_path,
9 | sample_rate,
10 | threshold):
11 | with open(json2bpm_path, "r") as f:
12 | j2b_dict = json.load(f)
13 |
14 | with open(bpm2json_path, "r") as f:
15 | b2j_dict = json.load(f)
16 |
17 | self.train_acc_path = train_acc_path
18 | self.j2b_dict = j2b_dict
19 | self.b2j_dict = b2j_dict
20 |
21 | self.sr = sample_rate
22 | self.thr = threshold / self.sr  # clip length in seconds (e.g., 64600 / 16000 = 4.0375 s)
23 | self.waveform_thr = threshold   # clip length in samples (e.g., 64600)
24 |
25 | def load_audio_by_json(self, sel_json_name):
26 | sel_wav_path = os.path.join(self.train_acc_path, sel_json_name[:-5] + ".wav")
27 | wav, _ = librosa.load(sel_wav_path, sr=self.sr)
28 | return wav
29 |
30 | def sel_accom_from_bpm_group(self, bpm_num, y):
31 | count, downbeats_duration = 0, 0
32 | sel_json = None
33 | candidate_jsons = self.b2j_dict[str(bpm_num)]
34 | filtered_gen = filter(lambda x: x.startswith(str(y)), candidate_jsons)
35 | candidate_list = list(filtered_gen)
36 |
37 | if len(candidate_list) != 0:
38 | while True:
39 | sel_json = np.random.choice(candidate_list)
40 | bpm_beat = self.j2b_dict[sel_json]
41 | count += 1
42 | pos1 = np.where(np.array(bpm_beat["beat_positions"]) == 1)[0]
43 | pos1_len = pos1.shape[0]
44 |
45 | # At least one bar period
46 | if pos1_len > 2:
47 | first_pos, last_pos = pos1[0], pos1[-1]
48 | downbeats_duration = bpm_beat["downbeats"][-1] - bpm_beat["downbeats"][0]
49 | break
50 |
51 | # Stop because not found
52 | if count >= 5:
53 | break
54 |
55 | return sel_json, downbeats_duration
56 |
57 | def accom_beat_padding(self, waveform,
58 | s_name,
59 | dbs_duration):
60 | content = self.j2b_dict[s_name]
61 | start, end = content["downbeats"][0], content["downbeats"][-1]
62 | sel_waveform = waveform[int(start * self.sr) : int(end * self.sr)]
63 |
64 | if dbs_duration < self.thr:
65 | # print(f"start: {start}, end: {end}")
66 | cp_num = math.ceil(self.thr / dbs_duration)
67 | sel_waveform = np.concatenate([sel_waveform] * cp_num)
68 |
69 | waveform_thr = int(self.thr * self.sr) + 1
70 | return sel_waveform[:waveform_thr]
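# Worked example of the padding arithmetic above (values assumed): with the
# default threshold of 64600 samples at 16 kHz, self.thr = 4.0375 s. A
# bar-aligned segment with dbs_duration = 2.0 s gives
# cp_num = ceil(4.0375 / 2.0) = 3 tiled copies, truncated to 64601 samples.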
71 |
72 | def sv_beat_align(self, wav_sv, sel_json):
73 | downbeats = self.j2b_dict[sel_json + ".json"]["downbeats"]
74 | wav_seg = wav_sv[int(downbeats[0] * self.sr):int(downbeats[-1] * self.sr)]
75 | # print(f"downbeats[0:-1]: {downbeats}")
76 | rand_start = np.random.choice(downbeats[0:-1])
77 | rand_start_seg = wav_sv[int(rand_start * self.sr):int(downbeats[-1] * self.sr)]
78 |
79 | if rand_start_seg.shape[0] >= self.waveform_thr:
80 | output = rand_start_seg
81 | else:
82 | remain_len = self.waveform_thr - rand_start_seg.shape[0] - wav_seg.shape[0]
83 | if remain_len >= 0:
84 | padded_len = math.ceil(remain_len / wav_seg.shape[0]) + 1
85 | else:
86 | padded_len = 1
87 | padded_wav_seg = np.concatenate([wav_seg] * padded_len)
88 | output = np.concatenate((rand_start_seg, padded_wav_seg))
89 |
90 | output = output[:self.waveform_thr]
91 | return output
92 |
93 |
94 |
95 | if __name__ == "__main__":
96 | train_acc_path = "./dataset/split_dump_flac/train/non_vocals/"
97 | train_vocal_path = "./dataset/split_dump_flac/train/vocals/"
98 | b2j_path = "./dataset/split_dump_flac/train/bpm2json.json"
99 | j2b_path = "./dataset/split_dump_flac/train/json2bpm.json"
100 |
101 | with open(j2b_path, "r") as d:
102 | json2bpm_dict = json.load(d)
103 |
104 | BpmProCls = BpmProcessor(train_acc_path=train_acc_path,
105 | json2bpm_path=j2b_path,
106 | bpm2json_path=b2j_path,
107 | sample_rate=16000,
108 | threshold=64600)
109 |
110 | wav_path = "./dataset/split_dump_flac/train/non_vocals/0_0212_3.wav"
111 | json_name = os.path.basename(wav_path)[:-4] + ".json"
112 | wav, sample_rate = librosa.load(wav_path, sr=16000)
113 |
114 | bpm_n = json2bpm_dict[json_name]["bpm"]
115 |
116 | sel_json, dbs_duration = BpmProCls.sel_accom_from_bpm_group(bpm_n, "1")
117 | sel_wav = BpmProCls.load_audio_by_json(sel_json)
118 | padded_waveform = BpmProCls.accom_beat_padding(sel_wav, sel_json, dbs_duration)
119 | print(f"padded_waveform: {padded_waveform.shape}")
120 |
121 | bpm_res = json2bpm_dict["1_1344_11.json"]["bpm"]
122 | print(f"bpm_res: {bpm_res}")
123 |
124 | wav_path = os.path.join(train_vocal_path, "0_0321_4.flac")
125 | wav11, sample_rate = librosa.load(wav_path, sr=16000)
126 |
127 | wav_sv = BpmProCls.sv_beat_align(wav11, "0_0321_4")
128 | print(f"wav_sv: {wav_sv.shape}")
129 |
--------------------------------------------------------------------------------
/data/RawBoost.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import numpy as np
5 | from scipy import signal
6 | import copy
7 |
8 |
9 | """
10 | __author__ = "Massimiliano Todisco, Hemlata Tak"
11 | __email__ = "{todisco,tak}@eurecom.fr"
12 | """
13 |
14 | '''
15 | Hemlata Tak, Madhu Kamble, Jose Patino, Massimiliano Todisco, Nicholas Evans.
16 | RawBoost: A Raw Data Boosting and Augmentation Method applied to Automatic Speaker Verification Anti-Spoofing.
17 | In Proc. ICASSP 2022, pp:6382--6386.
18 | '''
19 |
20 | def randRange(x1, x2, integer):
21 | y = np.random.uniform(low=x1, high=x2, size=(1,))
22 | if integer:
23 | y = int(y)
24 | return y
25 |
26 | def normWav(x,always):
27 | if always:
28 | x = x/np.amax(abs(x))
29 | elif np.amax(abs(x)) > 1:
30 | x = x/np.amax(abs(x))
31 | return x
32 |
33 |
34 |
35 | def genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs):
36 | b = 1
37 | for i in range(0, nBands):
38 | fc = randRange(minF,maxF,0);
39 | bw = randRange(minBW,maxBW,0);
40 | c = randRange(minCoeff,maxCoeff,1);
41 |
42 | if c/2 == int(c/2):
43 | c = c + 1
44 | f1 = fc - bw/2
45 | f2 = fc + bw/2
46 | if f1 <= 0:
47 | f1 = 1/1000
48 | if f2 >= fs/2:
49 | f2 = fs/2-1/1000
50 | b = np.convolve(signal.firwin(c, [float(f1), float(f2)], window='hamming', fs=fs),b)
51 |
52 | G = randRange(minG,maxG,0);
53 | _, h = signal.freqz(b, 1, fs=fs)
54 | b = pow(10, G/20)*b/np.amax(abs(h))
55 | return b
56 |
57 |
58 | def filterFIR(x,b):
59 | N = b.shape[0] + 1
60 | xpad = np.pad(x, (0, N), 'constant')
61 | y = signal.lfilter(b, 1, xpad)
62 | y = y[int(N/2):int(y.shape[0]-N/2)]
63 | return y
64 |
65 | # Linear and non-linear convolutive noise
66 | def LnL_convolutive_noise(x,N_f,nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,minBiasLinNonLin,maxBiasLinNonLin,fs):
67 | y = [0] * x.shape[0]
68 | for i in range(0, N_f):
69 | if i == 1:
70 | minG = minG-minBiasLinNonLin;
71 | maxG = maxG-maxBiasLinNonLin;
72 | b = genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs)
73 | y = y + filterFIR(np.power(x, (i+1)), b)
74 | y = y - np.mean(y)
75 | y = normWav(y,0)
76 | return y
77 |
78 |
79 | # Impulsive signal dependent noise
80 | def ISD_additive_noise(x, P, g_sd):
81 | beta = randRange(0, P, 0)
82 |
83 | y = copy.deepcopy(x)
84 | x_len = x.shape[0]
85 | n = int(x_len*(beta/100))
86 | p = np.random.permutation(x_len)[:n]
87 | f_r= np.multiply(((2*np.random.rand(p.shape[0]))-1),((2*np.random.rand(p.shape[0]))-1))
88 | r = g_sd * x[p] * f_r
89 | y[p] = x[p] + r
90 | y = normWav(y,0)
91 | return y
92 |
93 |
94 | # Stationary signal independent noise
95 |
96 | def SSI_additive_noise(x,SNRmin,SNRmax,nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs):
97 | noise = np.random.normal(0, 1, x.shape[0])
98 | b = genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs)
99 | noise = filterFIR(noise, b)
100 | noise = normWav(noise,1)
101 | SNR = randRange(SNRmin, SNRmax, 0)
102 | noise = noise / np.linalg.norm(noise,2) * np.linalg.norm(x,2) / 10.0**(0.05 * SNR)
103 | x = x + noise
104 | return x
105 |
106 | def process_Rawboost_feature(feature, sr,args,algo):
107 |
108 | # Data process by Convolutive noise (1st algo)
109 | if algo==1:
110 |
111 | feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr)
112 |
113 | # Data process by Impulsive noise (2nd algo)
114 | elif algo==2:
115 |
116 | feature=ISD_additive_noise(feature, args.P, args.g_sd)
117 |
118 | # Data process by coloured additive noise (3rd algo)
119 | elif algo==3:
120 |
121 | feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr)
122 |
123 | # Data process by all 3 algo. together in series (1+2+3)
124 | elif algo==4:
125 |
126 | feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,
127 | args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr)
128 | feature=ISD_additive_noise(feature, args.P, args.g_sd)
129 | feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,
130 | args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr)
131 |
132 | # Data process by 1st two algo. together in series (1+2)
133 | elif algo==5:
134 |
135 | feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,
136 | args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr)
137 | feature=ISD_additive_noise(feature, args.P, args.g_sd)
138 |
139 |
140 | # Data process by 1st and 3rd algo. together in series (1+3)
141 | elif algo==6:
142 |
143 | feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,
144 | args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr)
145 | feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr)
146 |
147 | # Data process by 2nd and 3rd algo. together in series (2+3)
148 | elif algo==7:
149 |
150 | feature=ISD_additive_noise(feature, args.P, args.g_sd)
151 | feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr)
152 |
153 | # Data process by 1st two algo. together in Parallel (1||2)
154 | elif algo==8:
155 |
156 | feature1 =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,
157 | args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr)
158 | feature2=ISD_additive_noise(feature, args.P, args.g_sd)
159 |
160 | feature_para=feature1+feature2
161 | feature=normWav(feature_para,0) #normalized resultant waveform
162 |
163 | # original data without Rawboost processing
164 | else:
165 |
166 | feature=feature
167 |
168 | return feature
169 |
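# Minimal smoke test (an editor-level sketch, not part of the original
# module): the parameter values below follow the published RawBoost defaults
# and are assumptions here; the authoritative values are the argparse
# defaults in this repo's main.py.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(N_f=5, nBands=5, minF=20, maxF=8000, minBW=100,
                           maxBW=1000, minCoeff=10, maxCoeff=100, minG=0,
                           maxG=0, minBiasLinNonLin=5, maxBiasLinNonLin=20,
                           P=10, g_sd=2, SNRmin=10, SNRmax=40)
    x = np.random.randn(64600)                            # ~4 s at 16 kHz
    y = process_Rawboost_feature(x, 16000, args, algo=4)  # series 1+2+3
    print(y.shape)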
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | import soundfile as sf
6 | import torch
7 | from torch import Tensor
8 | from torch.utils.data import Dataset
9 | from utils.utils import str_to_bool
10 |
11 | from data.RawBoost import process_Rawboost_feature
12 | from data.BPMProcess import BpmProcessor
13 |
14 | __author__ = "Xuanjun Chen"
15 | __email__ = "d12942018@ntu.edu.tw"
16 |
17 | def pad(x, max_len=64600):
18 | x_len = x.shape[0]
19 | if x_len >= max_len:
20 | return x[:max_len]
21 | # need to pad
22 | num_repeats = int(max_len / x_len) + 1
23 | padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
24 | return padded_x
25 |
26 | def pad_random(x: np.ndarray, start: int, end: int):
27 | x_len = x.shape[0]
28 |
29 | # If the interval is within the length of x
30 | if end <= x_len:
31 | return x[start:end]
32 |
33 | # If the selected interval is longer than x
34 | padded_x = np.tile(x, (end // x_len + 1)) # Repeat x to ensure it covers the interval
35 | return padded_x[start:end]
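# Worked example (hypothetical values): pad_random(np.arange(5), 3, 8) tiles
# x to [0 1 2 3 4 0 1 2 3 4] and slices [3:8] -> [3 4 0 1 2], so the result
# always has exactly end - start samples.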
36 |
37 | class Dataset_SingFake(Dataset):
38 | def __init__(self, args, base_dir, algo, state, is_mixture=False, target_sr=16000):
39 | """
40 | base_dir should contain mixtures/ and vocals/ folders
41 | """
42 | self.base_dir = base_dir
43 | self.is_mixture = is_mixture
44 | self.target_sr = target_sr
45 | self.cut = 64600 # take ~4 sec audio (64600 samples)
46 | self.args = args
47 | self.algo = algo
48 | self.state = state
49 |
50 | # get file list
51 | self.file_list = []
52 | if self.is_mixture:
53 | self.target_path = os.path.join(self.base_dir, "mixtures")
54 | else:
55 | self.target_path = os.path.join(self.base_dir, "vocals")
56 |
57 | print(self.target_path)
58 |
59 | assert os.path.exists(self.target_path), f"{self.target_path} does not exist!"
60 |
61 | for file in os.listdir(self.target_path):
62 | if file.endswith(".flac"):
63 | self.file_list.append(file[:-5])
64 |
65 | def __len__(self):
66 | return len(self.file_list)
67 |
68 | def __getitem__(self, index):
69 | key = self.file_list[index]
70 | file_path = os.path.join(self.target_path, key + ".flac")
71 | # X, _ = sf.read(file_path, samplerate=self.target_sr)
72 | try:
73 | X, fs = librosa.load(file_path, sr=self.target_sr, mono=False)
74 | except Exception:
75 | return self.__getitem__(np.random.randint(len(self.file_list)))
76 | if X.ndim > 1:
77 | # if not mono, take a random channel
78 | channel_id = np.random.randint(X.shape[0])
79 | X = X[channel_id]
80 |
81 | # RawBoost Augmentation
82 | if self.state == "train":
83 | X = process_Rawboost_feature(X, fs, self.args, self.algo)
84 |
85 | X_pad = pad_random(X, 0, self.cut)  # pad_random expects (x, start, end)
86 | X_pad = X_pad / np.max(np.abs(X_pad))
87 | x_inp = Tensor(X_pad)
88 | y = int(key.split("_")[0])
89 | return x_inp, y
90 |
91 | class Dataset_SingFake_mert_w2v(Dataset):
92 | def __init__(self, args, config, base_dir, algo, state,
93 | target_sr=16000, target_sr2=24000):
94 | """
95 | base_dir should contain mixtures/ and vocals/ folders
96 | """
97 | self.base_dir = base_dir
98 | self.is_mixture = not str_to_bool(config["vocals_only"])
99 | self.is_sep = str_to_bool(config["is_sep"])
100 | self.is_rawboost = str_to_bool(config["is_rawboost"])
101 | self.is_beat_matching = str_to_bool(config["is_beat_matching"])
102 |
103 | self.target_sr = target_sr
104 | self.target_sr2 = target_sr2
105 | self.cut16 = 64600 # take ~4 sec audio (64600 samples)
106 | self.duration = 4.0375
107 | self.cut24 = 96900 # take ~4 sec audio (96900 samples)
108 | self.args = args
109 | self.algo = algo
110 | self.state = state
111 |
112 | # get file list
113 | self.file_list = []
114 | if self.is_mixture:
115 | self.target_path = os.path.join(self.base_dir, "mixtures")
116 | self.tgt_v2_path = self.target_path
117 | elif not self.is_mixture and self.is_sep:
118 | self.target_path = os.path.join(self.base_dir, "vocals")
119 | self.tgt_v2_path = os.path.join(self.base_dir, "non_vocals")
120 | else:
121 | self.target_path = os.path.join(self.base_dir, "vocals")
122 | self.tgt_v2_path = self.target_path  # no separate accompaniment stream
123 |
124 | self.beat_file_path = os.path.join(self.base_dir, "beats")
125 |
126 | # For beat matching
127 | self.BpmProCls = BpmProcessor(train_acc_path=config["train_acc_path"],
128 | json2bpm_path=config["j2b_path"],
129 | bpm2json_path=config["b2j_path"],
130 | sample_rate=16000,
131 | threshold=self.cut16)
132 |
133 | assert os.path.exists(self.target_path), f"{self.target_path} does not exist!"
134 |
135 | for file in os.listdir(self.target_path):
136 | if file.endswith(".flac"):
137 | self.file_list.append(file[:-5])
138 |
139 | # Filter 0 or 1
140 |
141 |
142 | def __len__(self):
143 | return len(self.file_list)
144 |
145 | def __getitem__(self, index):
146 | key = self.file_list[index]
147 | y = int(key.split("_")[0])
148 |
149 | file_path = os.path.join(self.target_path, key + ".flac")
150 | file2_path = os.path.join(self.tgt_v2_path, key + ".wav")
151 | try:
152 | X, fs = librosa.load(file_path, sr=self.target_sr, mono=False)
153 | X2, _ = librosa.load(file2_path, sr=self.target_sr, mono=False)
154 | except Exception:
155 | return self.__getitem__(np.random.randint(len(self.file_list)))
156 |
157 | if X.ndim > 1:  # if not mono, take a random channel
158 | X = X[np.random.randint(X.shape[0])]
159 | if X2.ndim > 1:
160 | X2 = X2[np.random.randint(X2.shape[0])]
161 |
162 | if self.state == "train" and self.is_beat_matching:
163 | bpm_n = self.BpmProCls.j2b_dict[key + ".json"]["bpm"]
164 | db_num = len(self.BpmProCls.j2b_dict[key + ".json"]["downbeats"])
165 | if bpm_n is not None and db_num > 2:
166 | X = self.BpmProCls.sv_beat_align(X, key)
167 | sel_json, dbs_duration = self.BpmProCls.sel_accom_from_bpm_group(bpm_n, y)
168 | if dbs_duration != 0:
169 | sel_wav = self.BpmProCls.load_audio_by_json(sel_json)
170 | X2 = self.BpmProCls.accom_beat_padding(sel_wav, sel_json, dbs_duration)
171 |
172 | # RawBoost Augmentation
173 | if self.state == "train" and self.is_rawboost:
174 | X = process_Rawboost_feature(X, fs, self.args, self.algo)
175 |
176 | waveform_shift = X.shape[0] - self.cut16
177 | if waveform_shift > 0:
178 | x_start = np.random.randint(0, waveform_shift)
179 | else:
180 | x_start = 0
181 |
182 | x_end = x_start + self.cut16
183 | X_pad, X2_pad = pad_random(X, x_start, x_end), pad_random(X2, x_start, x_end)
184 | X_pad, X2_pad = X_pad / np.max(np.abs(X_pad)), X2_pad / np.max(np.abs(X2_pad))
185 |
186 | x_inp, x2_inp = Tensor(X_pad), Tensor(X2_pad)
187 |
188 |
189 | return x_inp, x2_inp, y
190 |
--------------------------------------------------------------------------------
/figures/singfake_sota.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjchenGit/SingGraph/f4809cd946ed13670431a0400d434ad3b6fbba26/figures/singfake_sota.jpg
--------------------------------------------------------------------------------
/figures/singgraph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xjchenGit/SingGraph/f4809cd946ed13670431a0400d434ad3b6fbba26/figures/singgraph.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Main script that trains, validates, and evaluates
3 | various models including AASIST.
4 |
5 | AASIST
6 | Copyright (c) 2021-present NAVER Corp.
7 | MIT license
8 | """
9 | import argparse
10 | import json
11 | import os
12 | import sys
13 | import warnings
14 | from importlib import import_module
15 | from pathlib import Path
16 | from shutil import copy
17 | from typing import Dict, List, Union
18 |
19 | import torch
20 | import torch.nn as nn
21 | from torch.utils.data import DataLoader
22 | from torch.utils.tensorboard import SummaryWriter
23 | from torchcontrib.optim import SWA
24 | from tqdm import tqdm
25 |
26 | from data.dataloader import Dataset_SingFake, Dataset_SingFake_mert_w2v
27 | from utils.eval_metrics import compute_eer
28 | from utils.utils import create_optimizer, seed_worker, set_seed, str_to_bool
29 | from model.SingGraph import Wav2Vec2Model
30 |
31 | warnings.filterwarnings("ignore", category=UserWarning)
32 | warnings.filterwarnings("ignore", category=FutureWarning)
33 |
34 | def main(args: argparse.Namespace) -> None:
35 | """
36 | Main function.
37 | Trains, validates, and evaluates the ASVspoof detection model.
38 | """
39 | # load experiment configurations
40 | with open(args.config, "r") as f_json:
41 | config = json.loads(f_json.read())
42 | model_config = config["model_config"]
43 | optim_config = config["optim_config"]
44 | optim_config["epochs"] = config["num_epochs"]
45 | track = config["track"]
46 | assert track in ["LA", "PA", "DF"], "Invalid track given"
47 | if "eval_all_best" not in config:
48 | config["eval_all_best"] = "True"
49 | if "freq_aug" not in config:
50 | config["freq_aug"] = "False"
51 |
52 | # make experiment reproducible
53 | set_seed(args.seed, config)
54 |
55 | # define database related paths
56 | output_dir = Path(args.output_dir)
57 |
58 | # define model related paths
59 | model_tag = "{}_{}_ep{}_bs{}".format(track,
60 | os.path.splitext(os.path.basename(args.config))[0],
61 | config["num_epochs"], config["batch_size"])
62 | if args.comment:
63 | model_tag = model_tag + "_{}".format(args.comment)
64 | model_tag = output_dir / model_tag
65 | model_save_path = model_tag / "weights"
66 | eval_score_path = model_tag / config["eval_output"]
67 | writer = SummaryWriter(model_tag)
68 | os.makedirs(model_save_path, exist_ok=True)
69 | copy(args.config, model_tag / "config.conf")
70 |
71 | # set device
72 | gpu_id = args.gpu
73 | device = torch.device("cuda:{}".format(gpu_id) if torch.cuda.is_available() else "cpu")
74 | # device = "cuda" if torch.cuda.is_available() else "cpu"
75 | print("Device: {}".format(device))
76 | if device == "cpu":
77 | raise ValueError("GPU not detected!")
78 |
79 | # define model architecture
80 | # model = get_model(model_config, device)
81 | model = get_wav2vec2_model(model_config, device).to(device)
82 |
83 | # define dataloaders
84 | # trn_loader, dev_loader, eval_loader, additional_loader, persian_loader, mp3_loader, ogg_loader, aac_loader, opus_loader = get_singfake_loaders(args.seed, args, config)
85 | dataset_loaders = get_singfake_loaders(args.seed, args, config)
86 |
87 | # evaluates pretrained model and exit script
88 | if args.eval:
89 | model.load_state_dict(
90 | torch.load(config["model_path"], map_location=device))
91 | print("Model loaded : {}".format(config["model_path"]))
92 | print("Evaluating on SingFake Eval Set...")
93 |
94 | eer_results = {}
95 | for data_key in dataset_loaders:
96 | eer = evaluate(dataset_loaders[data_key], model, device)
97 | eer_results[data_key] = eer
98 |
99 | avg_codec_eer = (eer_results["codec_test/mp3_128k"] +
100 | eer_results["codec_test/ogg_64k"] +
101 | eer_results["codec_test/adts_64k"] +
102 | eer_results["codec_test/opus_64k"]) / 4.
103 |
104 | print("Done. train_eer: {:.2f} %, test_set_eer: {:.2f} %, additional_test_eer: {:.2f} %, persian_eer: {:.2f} %".format(eer_results["train"] * 100,
105 | eer_results["T01"] * 100,
106 | eer_results["T02"] * 100,
107 | eer_results["T04"] * 100))
108 |
109 | print("Codec testing: Average EER: {:.2f} %, MP3 EER: {:.2f} %, OGG EER: {:.2f} %, AAC EER: {:.2f} %, OPUS EER: {:.2f} %".format(avg_codec_eer * 100,
110 | eer_results["codec_test/mp3_128k"] * 100,
111 | eer_results["codec_test/ogg_64k"] * 100,
112 | eer_results["codec_test/adts_64k"] * 100,
113 | eer_results["codec_test/opus_64k"] * 100))
114 | sys.exit(0)
115 |
116 | # get optimizer and scheduler
117 | optim_config["steps_per_epoch"] = len(dataset_loaders["train"])
118 | optimizer, scheduler = create_optimizer(model.parameters(), optim_config)
119 | optimizer_swa = SWA(optimizer)
120 |
121 | best_dev_eer = 1.
122 | best_eval_eer = 100.
123 | n_swa_update = 0 # number of snapshots of model to use in SWA
124 | f_log = open(model_tag / "metric_log.txt", "a")
125 | f_log.write("=" * 5 + "\n")
126 |
127 | # make directory for metric logging
128 | metric_path = model_tag / "metrics"
129 | os.makedirs(metric_path, exist_ok=True)
130 |
131 | # Training
132 | for epoch in range(config["num_epochs"]):
133 | print("Start training epoch{:03d}".format(epoch))
134 | running_loss = train_epoch(dataset_loaders["train"], model, optimizer, device,
135 | scheduler, config)
136 |
137 | dev_eer = evaluate(dataset_loaders["dev"], model, device)
138 | additional_eer = evaluate(dataset_loaders["T02"], model, device)
139 |
140 | print("DONE.\nLoss:{:.5f}, dev_eer: {:.2f} %, additional_test_eer: {:.2f} %".format(
141 | running_loss, dev_eer * 100, additional_eer * 100))
142 | writer.add_scalar("loss", running_loss, epoch)
143 | writer.add_scalar("dev_eer", dev_eer, epoch)
144 | writer.add_scalar("additional_eer", additional_eer, epoch)
145 |
146 | if best_dev_eer >= dev_eer:
147 | print("best model find at epoch", epoch)
148 | best_dev_eer = dev_eer
149 | torch.save(model.state_dict(),
150 | model_save_path / "epoch_{}_{:03.3f}.pth".format(epoch, dev_eer))
151 |
152 | # do evaluation whenever best model is renewed
153 | if str_to_bool(config["eval_all_best"]):
154 | eval_eer = evaluate(dataset_loaders["T01"], model, device)
155 | log_text = "epoch{:03d}, ".format(epoch)
156 | if eval_eer < best_eval_eer:
157 | log_text += "best eer, {:.4f}%".format(eval_eer)
158 | best_eval_eer = eval_eer
159 | torch.save(model.state_dict(), model_save_path / "best.pth")
160 |
161 | if len(log_text) > 0:
162 | print(log_text)
163 | f_log.write(log_text + "\n")
164 |
165 | print("Saving epoch {} for swa".format(epoch))
166 | optimizer_swa.update_swa()
167 | n_swa_update += 1
168 | writer.add_scalar("best_dev_eer", best_dev_eer, epoch)
169 |
170 | print("Start final evaluation")
171 | epoch += 1
172 | if n_swa_update > 0:
173 | optimizer_swa.swap_swa_sgd()
174 | optimizer_swa.bn_update(dataset_loaders["train"], model, device=device)
175 | eval_eer = evaluate(dataset_loaders["T01"], model, device)
176 | f_log = open(model_tag / "metric_log.txt", "a")
177 | f_log.write("=" * 5 + "\n")
178 | f_log.write("EER: {:.3f} %".format(eval_eer * 100))
179 | f_log.close()
180 |
181 | torch.save(model.state_dict(),
182 | model_save_path / "swa.pth")
183 |
184 | if eval_eer <= best_eval_eer:
185 | best_eval_eer = eval_eer
186 | torch.save(model.state_dict(), model_save_path / "best.pth")
187 |
188 | print("Exp FIN. EER: {:.3f} %".format(best_eval_eer * 100))
189 |
190 |
191 | def get_model(model_config: Dict, device: torch.device):
192 | """Define DNN model architecture"""
193 | module = import_module("models.{}".format(model_config["architecture"]))
194 | _model = getattr(module, "Model")
195 | model = _model(model_config).to(device)
196 | nb_params = sum([param.view(-1).size()[0] for param in model.parameters()])
197 | print("no. model params:{}".format(nb_params))
198 |
199 | return model
200 |
201 | def get_wav2vec2_model(model_config: Dict, device: torch.device):
202 | model = Wav2Vec2Model(model_config, device)
203 | return model
204 |
205 | def get_singfake_loaders(seed: int, args: argparse.Namespace, config: dict) -> Dict[str, DataLoader]:
206 | # base_dir = "../../dataset/split_dump_flac/"
207 | base_dir = "./dataset/split_dump_flac/"
208 | target_sr = float(config["target_sr"])
209 |
210 | # Define dataset keys and their corresponding paths
211 | dataset_keys = ["train", "dev", "T01", "T02", "T04",
212 | "codec_test/mp3_128k", "codec_test/ogg_64k",
213 | "codec_test/adts_64k", "codec_test/opus_64k"]
214 | datasets = {}
215 |
216 | # Common DataLoader settings
217 | common_settings = {
218 | "batch_size": config["batch_size"],
219 | "num_workers": 4,
220 | "pin_memory": True
221 | }
222 |
223 | # Initialize the generator for reproducibility
224 | gen = torch.Generator()
225 | gen.manual_seed(seed)
226 |
227 | # Creating DataLoaders for each dataset
228 | for key in dataset_keys:
229 | dataset_path = os.path.join(base_dir, key)
230 | shuffle = True if key == "train" else False
231 | drop_last = True if key == "train" else False
232 | worker_init_fn = seed_worker if key == "train" else None
233 | generator = gen if key == "train" else None
234 |
235 | dataset = Dataset_SingFake_mert_w2v(args, config, base_dir=dataset_path, algo=args.algo,
236 | state="train" if key == "train" else "test", target_sr=target_sr)
237 |
238 | datasets[key] = DataLoader(dataset, shuffle=shuffle, drop_last=drop_last, worker_init_fn=worker_init_fn,
239 | generator=generator, **common_settings)
240 |
241 | # Return the DataLoaders keyed by dataset split
242 | return datasets
243 |
244 | def evaluate(loader, model, device: torch.device):
245 | """
246 | Evaluate the model on the given loader, then return EER.
247 | """
248 | model.eval()
249 | # we save target (1) scores to target_scores, and non target (0) scores to nontarget_scores.
250 | target_scores = []
251 | nontarget_scores = []
252 | debug = False
253 | count = 0
254 | with torch.no_grad():
255 | for batch_x, batch_x2, batch_y in tqdm(loader, total=len(loader)):
256 | batch_x, batch_x2 = batch_x.to(device), batch_x2.to(device)
257 | batch_out = model(batch_x, batch_x2)
258 | batch_score = (batch_out[:, 1]).data.cpu().numpy().ravel()
259 | batch_y = batch_y.data.cpu().numpy().ravel()
260 | for i in range(len(batch_y)):
261 | if batch_y[i] == 1:
262 | target_scores.append(batch_score[i])
263 | else:
264 | nontarget_scores.append(batch_score[i])
265 | count += 1
266 | if count == 10 and debug:
267 | break
268 |
269 | eer, _ = compute_eer(target_scores, nontarget_scores)
270 | return eer
271 |
272 |
273 | def train_epoch(
274 | trn_loader: DataLoader,
275 | model,
276 | optim: Union[torch.optim.SGD, torch.optim.Adam],
277 | device: torch.device,
278 | scheduler: torch.optim.lr_scheduler,
279 | config: argparse.Namespace):
280 | """Train the model for one epoch"""
281 | running_loss = 0
282 | num_total = 0.0
283 | ii = 0
284 | model.train()
285 |
286 | # set objective (Loss) functions
287 | weight = torch.FloatTensor([0.1, 0.9]).to(device)
288 | criterion = nn.CrossEntropyLoss(weight=weight)
289 | pbar = tqdm(trn_loader, total=len(trn_loader))
290 | for batch_x, batch_x2, batch_y in pbar:
291 | batch_size = batch_x.size(0)
292 | num_total += batch_size
293 | ii += 1
294 | batch_x, batch_x2 = batch_x.to(device), batch_x2.to(device)
295 | batch_y = batch_y.view(-1).type(torch.int64).to(device)
296 | batch_out = model(batch_x, batch_x2)
297 | batch_loss = criterion(batch_out, batch_y)
298 | running_loss += batch_loss.item() * batch_size
299 | pbar.set_description("loss: {:.5f}, running loss: {:.5f}".format(
300 | batch_loss.item(), running_loss / num_total))
301 | optim.zero_grad()
302 | batch_loss.backward()
303 | optim.step()
304 |
305 | if config["optim_config"]["scheduler"] in ["cosine", "keras_decay"]:
306 | scheduler.step()
307 | elif scheduler is None:
308 | pass
309 | else:
310 | raise ValueError("scheduler error, got:{}".format(scheduler))
311 |
312 | running_loss /= num_total
313 | return running_loss
314 |
315 |
316 | if __name__ == "__main__":
317 | parser = argparse.ArgumentParser(description="ASVspoof detection system")
318 | parser.add_argument("--config",
319 | dest="config",
320 | type=str,
321 | help="configuration file",
322 | required=True)
323 | parser.add_argument(
324 | "--output_dir",
325 | dest="output_dir",
326 | type=str,
327 | help="output directory for results",
328 | default="./exp_result",
329 | )
330 | parser.add_argument("--seed",
331 | type=int,
332 | default=1234,
333 | help="random seed (default: 1234)")
334 | parser.add_argument(
335 | "--eval",
336 | action="store_true",
337 | help="when this flag is given, evaluates given model and exit")
338 | parser.add_argument("--comment",
339 | type=str,
340 | default=None,
341 | help="comment to describe the saved model")
342 | parser.add_argument("--eval_model_weights",
343 | type=str,
344 | default=None,
345 | help="directory to the model weight file (can be also given in the config file)")
346 | parser.add_argument("--gpu",
347 | type=int,
348 | default=0,
349 | help="gpu id to use (default: 0)")
350 |
351 | ##===================================================Rawboost data augmentation ===============================================================#
352 |
353 | parser.add_argument('--algo', type=int, default=3,
354 | help='RawBoost algorithm selection. 0: no augmentation, 1: LnL_convolutive_noise, 2: ISD_additive_noise, 3: SSI_additive_noise, 4: series algo (1+2+3), \
355 | 5: series algo (1+2), 6: series algo (1+3), 7: series algo (2+3), 8: parallel algo (1||2). [default=3]')
356 |
357 | # LnL_convolutive_noise parameters
358 | parser.add_argument('--nBands', type=int, default=5,
359 | help='number of notch filters. The higher the number of bands, the more aggressive the distortion is. [default=5]')
360 | parser.add_argument('--minF', type=int, default=20,
361 | help='minimum centre frequency [Hz] of notch filter.[default=20] ')
362 | parser.add_argument('--maxF', type=int, default=8000,
363 | help='maximum centre frequency [Hz] (<sr/2) of notch filter.[default=8000]')
--------------------------------------------------------------------------------
/model/SingGraph.py:
--------------------------------------------------------------------------------
344 | self.drop = nn.Dropout(p) if p > 0 else nn.Identity()
345 | self.in_dim = in_dim
346 |
347 | def forward(self, h):
348 | Z = self.drop(h)
349 | weights = self.proj(Z)
350 | scores = self.sigmoid(weights)
351 | new_h = self.top_k_graph(scores, h, self.k)
352 |
353 | return new_h
354 |
355 | def top_k_graph(self, scores, h, k):
356 | """
357 | args
358 | =====
359 | scores: attention-based weights (#bs, #node, 1)
360 | h: graph data (#bs, #node, #dim)
361 | k: ratio of remaining nodes, (float)
362 | returns
363 | =====
364 | h: graph pool applied data (#bs, #node', #dim)
365 | """
366 | _, n_nodes, n_feat = h.size()
367 | n_nodes = max(int(n_nodes * k), 1)
368 | _, idx = torch.topk(scores, n_nodes, dim=1)
369 | idx = idx.expand(-1, -1, n_feat)
370 |
371 | h = h * scores
372 | h = torch.gather(h, 1, idx)
373 |
374 | return h
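# e.g., with k=0.5 (the pool_ratios used by Wav2Vec2Model below), a 42-node
# spectral graph keeps its int(42 * 0.5) = 21 highest-scoring nodes.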
375 |
376 | class Residual_block(nn.Module):
377 | def __init__(self, nb_filts, first=False):
378 | super().__init__()
379 | self.first = first
380 |
381 | if not self.first:
382 | self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
383 | self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
384 | out_channels=nb_filts[1],
385 | kernel_size=(2, 3),
386 | padding=(1, 1),
387 | stride=1)
388 | self.selu = nn.SELU(inplace=True)
389 |
390 | self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
391 | self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
392 | out_channels=nb_filts[1],
393 | kernel_size=(2, 3),
394 | padding=(0, 1),
395 | stride=1)
396 |
397 | if nb_filts[0] != nb_filts[1]:
398 | self.downsample = True
399 | self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
400 | out_channels=nb_filts[1],
401 | padding=(0, 1),
402 | kernel_size=(1, 3),
403 | stride=1)
404 |
405 | else:
406 | self.downsample = False
407 |
408 |
409 | def forward(self, x):
410 | identity = x
411 | if not self.first:
412 | out = self.bn1(x)
413 | out = self.selu(out)
414 | else:
415 | out = x
416 |
417 | out = self.conv1(x)
418 |
419 | out = self.bn2(out)
420 | out = self.selu(out)
421 | out = self.conv2(out)
422 |
423 | if self.downsample:
424 | identity = self.conv_downsample(identity)
425 |
426 | out += identity
427 | return out
428 |
429 | class Wav2Vec2Model(nn.Module):
430 | def __init__(self, args,device):
431 | super().__init__()
432 | self.device = device
433 |
434 | # AASIST parameters
435 | filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
436 | gat_dims = [64, 32]
437 | pool_ratios = [0.5, 0.5, 0.5, 0.5]
438 | temperatures = [2.0, 2.0, 100.0, 100.0]
439 |
440 |
441 | ####
442 | # create network wav2vec 2.0
443 | ####
444 | self.ssl_model = SSLModel(self.device)
445 | self.LL = nn.Linear(self.ssl_model.out_dim, 128)
446 |
447 | self.first_bn = nn.BatchNorm2d(num_features=1)
448 | self.first_bn1 = nn.BatchNorm2d(num_features=64)
449 | self.drop = nn.Dropout(0.5, inplace=True)
450 | self.drop_way = nn.Dropout(0.2, inplace=True)
451 | self.selu = nn.SELU(inplace=True)
452 |
453 | # RawNet2 encoder
454 | self.encoder = nn.Sequential(
455 | nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
456 | nn.Sequential(Residual_block(nb_filts=filts[2])),
457 | nn.Sequential(Residual_block(nb_filts=filts[3])),
458 | nn.Sequential(Residual_block(nb_filts=filts[4])),
459 | nn.Sequential(Residual_block(nb_filts=filts[4])),
460 | nn.Sequential(Residual_block(nb_filts=filts[4])))
461 |
462 | self.attention = nn.Sequential(
463 | nn.Conv2d(64, 128, kernel_size=(1,1)),
464 | nn.SELU(inplace=True),
465 | nn.BatchNorm2d(128),
466 | nn.Conv2d(128, 64, kernel_size=(1,1)),
467 |
468 | )
469 | # position encoding
470 | self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1]))
471 |
472 | self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
473 | self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
474 |
475 | # Graph module
476 | self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
477 | gat_dims[0],
478 | temperature=temperatures[0])
479 | self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
480 | gat_dims[0],
481 | temperature=temperatures[1])
482 | # HS-GAL layer
483 | self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
484 | gat_dims[0], gat_dims[1], temperature=temperatures[2])
485 | self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
486 | gat_dims[1], gat_dims[1], temperature=temperatures[2])
487 | self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
488 | gat_dims[0], gat_dims[1], temperature=temperatures[2])
489 | self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
490 | gat_dims[1], gat_dims[1], temperature=temperatures[2])
491 |
492 | # Graph pooling layers
493 | self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
494 | self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
495 | self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
496 | self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
497 |
498 | self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
499 | self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
500 |
501 | self.out_layer = nn.Linear(5 * gat_dims[1], 2)
502 |
503 | def forward(self, x, x2):
504 | #-------pre-trained Wav2vec model fine tunning ------------------------##
505 | x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1), x2.squeeze(-1))
506 | x = self.LL(x_ssl_feat) #(bs,frame_number,feat_out_dim)
507 |
508 | # post-processing on front-end features
509 | x = x.transpose(1, 2) #(bs,feat_out_dim,frame_number)
510 | x = x.unsqueeze(dim=1) # add channel
511 | x = F.max_pool2d(x, (3, 3))
512 | x = self.first_bn(x)
513 | x = self.selu(x)
514 |
515 | # RawNet2-based encoder
516 | x = self.encoder(x)
517 | x = self.first_bn1(x)
518 | x = self.selu(x)
519 |
520 | w = self.attention(x)
521 |
522 | #------------SA for spectral feature-------------#
523 | w1 = F.softmax(w,dim=-1)
524 | m = torch.sum(x * w1, dim=-1)
525 | e_S = m.transpose(1, 2) + self.pos_S
526 |
527 | # graph module layer
528 | gat_S = self.GAT_layer_S(e_S)
529 | out_S = self.pool_S(gat_S) # (#bs, #node, #dim)
530 |
531 | #------------SA for temporal feature-------------#
532 | w2 = F.softmax(w,dim=-2)
533 | m1 = torch.sum(x * w2, dim=-2)
534 |
535 | e_T = m1.transpose(1, 2)
536 |
537 | # graph module layer
538 | gat_T = self.GAT_layer_T(e_T)
539 | out_T = self.pool_T(gat_T)
540 |
541 | # learnable master node
542 | master1 = self.master1.expand(x.size(0), -1, -1)
543 | master2 = self.master2.expand(x.size(0), -1, -1)
544 |
545 | # inference 1
546 | out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
547 | out_T, out_S, master=self.master1)
548 |
549 | out_S1 = self.pool_hS1(out_S1)
550 | out_T1 = self.pool_hT1(out_T1)
551 |
552 | out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
553 | out_T1, out_S1, master=master1)
554 | out_T1 = out_T1 + out_T_aug
555 | out_S1 = out_S1 + out_S_aug
556 | master1 = master1 + master_aug
557 |
558 | # inference 2
559 | out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
560 | out_T, out_S, master=self.master2)
561 | out_S2 = self.pool_hS2(out_S2)
562 | out_T2 = self.pool_hT2(out_T2)
563 |
564 | out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
565 | out_T2, out_S2, master=master2)
566 | out_T2 = out_T2 + out_T_aug
567 | out_S2 = out_S2 + out_S_aug
568 | master2 = master2 + master_aug
569 |
570 | out_T1 = self.drop_way(out_T1)
571 | out_T2 = self.drop_way(out_T2)
572 | out_S1 = self.drop_way(out_S1)
573 | out_S2 = self.drop_way(out_S2)
574 | master1 = self.drop_way(master1)
575 | master2 = self.drop_way(master2)
576 |
577 | out_T = torch.max(out_T1, out_T2)
578 | out_S = torch.max(out_S1, out_S2)
579 | master = torch.max(master1, master2)
580 |
581 | # Readout operation
582 | T_max, _ = torch.max(torch.abs(out_T), dim=1)
583 | T_avg = torch.mean(out_T, dim=1)
584 |
585 | S_max, _ = torch.max(torch.abs(out_S), dim=1)
586 | S_avg = torch.mean(out_S, dim=1)
587 |
588 | last_hidden = torch.cat(
589 | [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
590 |
591 | last_hidden = self.drop(last_hidden)
592 | output = self.out_layer(last_hidden)
593 |
594 | return output
595 |
596 | if __name__ == "__main__":
597 | import json
598 | with open("AASIST.conf", "r") as f_json:
599 | config = json.loads(f_json.read())
600 | with torch.no_grad():
601 | wav = torch.randn((1, 32000))
602 | w2v2_model = Wav2Vec2Model(config["model_config"], "cuda")
603 | output = w2v2_model(wav, wav)
604 | print(f"w2v2_model: {w2v2_model}")
605 | print(f"output: {output.shape}")
606 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | librosa==0.9.1
2 | transformers
3 | pandas
4 | numpy==1.21.6
5 | numba==0.56.4
6 | tensorboardX
7 | protobuf==3.20.*
8 | pexpect>4.3
9 | dcase-util>=0.2.4
10 | tensorboard
11 | torchcontrib
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # export TRANSFORMERS_CACHE=./cache
3 | export HF_HOME=./cache
4 |
5 | python3 main.py \
6 | --config utils/SingGraph.conf \
7 | --output_dir ./exp_result \
8 | --eval
9 |
--------------------------------------------------------------------------------
/utils/SingGraph.conf:
--------------------------------------------------------------------------------
1 | {
2 | "database_path": "./LA/",
3 | "asv_score_path": "ASVspoof2019_LA_asv_scores/ASVspoof2019.LA.asv.eval.gi.trl.scores.txt",
4 | "b2j_path": "./dataset/split_dump_flac/train/bpm2json.json",
5 | "j2b_path": "./dataset/split_dump_flac/train/json2bpm.json",
6 | "train_acc_path": "./dataset/split_dump_flac/train/non_vocals/",
7 | "train_vocal_path": "./dataset/split_dump_flac/train/vocals/",
8 | "model_path": "./exp_result/LA_SingGraph_ep80_bs12/weights/epoch_39_0.046.pth",
9 | "batch_size": 12,
10 | "num_epochs": 80,
11 | "target_sr": 16000,
12 | "vocals_only": "True",
13 | "is_sep": "True",
14 | "is_rawboost": "True",
15 | "is_beat_matching": "True",
16 | "loss": "CCE",
17 | "track": "LA",
18 | "eval_all_best": "True",
19 | "eval_output": "eval_scores_using_best_dev_model.txt",
20 | "cudnn_deterministic_toggle": "True",
21 | "cudnn_benchmark_toggle": "False",
22 | "model_config": {
23 | "architecture": "AASIST",
24 | "nb_samp": 64000,
25 | "first_conv": 128,
26 | "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
27 | "gat_dims": [64, 32],
28 | "pool_ratios": [0.5, 0.7, 0.5, 0.5],
29 | "temperatures": [2.0, 2.0, 100.0, 100.0]
30 | },
31 | "optim_config": {
32 | "optimizer": "adam",
33 | "amsgrad": "False",
34 | "base_lr": 0.000001,
35 | "lr_min": 0.000005,
36 | "betas": [0.9, 0.999],
37 | "weight_decay": 0.0001,
38 | "scheduler": "cosine"
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/utils/eval_metrics.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import numpy as np
4 |
5 |
6 | def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_threshold):
7 |
8 | # False alarm and miss rates for ASV
9 | Pfa_asv = sum(non_asv >= asv_threshold) / non_asv.size
10 | Pmiss_asv = sum(tar_asv < asv_threshold) / tar_asv.size
11 |
12 | # Rate of rejecting spoofs in ASV
13 | if spoof_asv.size == 0:
14 | Pmiss_spoof_asv = None
15 | else:
16 | Pmiss_spoof_asv = np.sum(spoof_asv < asv_threshold) / spoof_asv.size
17 |
18 | return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv
19 |
20 |
21 | def compute_det_curve(target_scores, nontarget_scores):
22 | target_scores, nontarget_scores = np.array(target_scores), np.array(nontarget_scores)
23 | n_scores = target_scores.size + nontarget_scores.size
24 | all_scores = np.concatenate((target_scores, nontarget_scores))
25 | labels = np.concatenate((np.ones(target_scores.size), np.zeros(nontarget_scores.size)))
26 |
27 | # Sort labels based on scores
28 | indices = np.argsort(all_scores, kind='mergesort')
29 | labels = labels[indices]
30 |
31 | # Compute false rejection and false acceptance rates
32 | tar_trial_sums = np.cumsum(labels)
33 | nontarget_trial_sums = nontarget_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums)
34 |
35 | frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / target_scores.size)) # false rejection rates
36 | far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / nontarget_scores.size)) # false acceptance rates
37 | thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices])) # Thresholds are the sorted scores
38 |
39 | return frr, far, thresholds
40 |
41 |
42 | def compute_eer(target_scores, nontarget_scores):
43 | """ Returns equal error rate (EER) and the corresponding threshold. """
44 | frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
45 | abs_diffs = np.abs(frr - far)
46 | min_index = np.argmin(abs_diffs)
47 | eer = np.mean((frr[min_index], far[min_index]))
48 | return eer, thresholds[min_index]
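# Worked example: compute_eer([0.9, 0.8, 0.7], [0.1, 0.2, 0.3]) returns
# (0.0, 0.3) - every bona fide score exceeds every spoof score, so FRR and
# FAR cross at 0, with the threshold at the highest spoof score.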
49 |
50 |
51 | def compute_tDCF(bonafide_score_cm, spoof_score_cm, Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, cost_model, print_cost=False):
52 | """
53 | Compute Tandem Detection Cost Function (t-DCF) [1] for a fixed ASV system.
54 | In brief, t-DCF returns a detection cost of a cascaded system of this form,
55 |
56 | Speech waveform -> [CM] -> [ASV] -> decision
57 |
58 | where CM stands for countermeasure and ASV for automatic speaker
59 | verification. The CM is therefore used as a 'gate' to decide whether or
60 | not the input speech sample should be passed onwards to the ASV system.
61 | Generally, both the CM and the ASV can make detection errors. Not all those errors
62 | are necessarily equally costly, and not all types of users are necessarily
63 | equally likely. The tandem t-DCF gives a principled way to compare
64 | different spoofing countermeasures under a detection cost function
65 | framework that takes that information into account.
66 |
67 | INPUTS:
68 |
69 | bonafide_score_cm A vector of POSITIVE CLASS (bona fide or human)
70 | detection scores obtained by executing a spoofing
71 | countermeasure (CM) on some positive evaluation trials.
72 | Each trial represents a bona fide case.
73 | spoof_score_cm A vector of NEGATIVE CLASS (spoofing attack)
74 | detection scores obtained by executing a spoofing
75 | CM on some negative evaluation trials.
76 | Pfa_asv False alarm (false acceptance) rate of the ASV
77 | system that is evaluated in tandem with the CM.
78 | Assumed to be in fractions, not percentages.
79 | Pmiss_asv Miss (false rejection) rate of the ASV system that
80 | is evaluated in tandem with the spoofing CM.
81 | Assumed to be in fractions, not percentages.
82 | Pmiss_spoof_asv Miss rate of spoof samples of the ASV system that
83 | is evaluated in tandem with the spoofing CM. That
84 | is, the fraction of spoof samples that were
85 | rejected by the ASV system.
86 | cost_model A struct that contains the parameters of t-DCF,
87 | with the following fields.
88 |
89 | Ptar Prior probability of target speaker.
90 | Pnon Prior probability of nontarget speaker (zero-effort impostor)
91 | Pspoof Prior probability of spoofing attack.
92 | Cmiss_asv Cost of ASV falsely rejecting target.
93 | Cfa_asv Cost of ASV falsely accepting nontarget.
94 | Cmiss_cm Cost of CM falsely rejecting target.
95 | Cfa_cm Cost of CM falsely accepting spoof.
96 |
97 | print_cost Print a summary of the cost parameters and the
98 | implied t-DCF cost function?
99 |
100 | OUTPUTS:
101 |
102 | tDCF_norm Normalized t-DCF curve across the different CM
103 | system operating points; see [2] for more details.
104 | Normalized t-DCF > 1 indicates a useless
105 | countermeasure (as the tandem system would do
106 | better without it). min(tDCF_norm) will be the
107 | minimum t-DCF used in ASVspoof 2019 [2].
108 | CM_thresholds Vector of same size as tDCF_norm corresponding to
109 | the CM threshold (operating point).
110 |
111 | NOTE:
112 | o In relative terms, higher detection scores values are assumed to
113 | indicate stronger support for the bona fide hypothesis.
114 | o You should provide real-valued soft scores, NOT hard decisions. The
115 | recommendation is that the scores are log-likelihood ratios (LLRs)
116 | from a bonafide-vs-spoof hypothesis based on some statistical model.
117 | This, however, is NOT required. The scores can have arbitrary range
118 | and scaling.
119 | o Pfa_asv, Pmiss_asv, Pmiss_spoof_asv are in fractions, not percentages.
120 |
121 | References:
122 |
123 | [1] T. Kinnunen, K.-A. Lee, H. Delgado, N. Evans, M. Todisco,
124 | M. Sahidullah, J. Yamagishi, D.A. Reynolds: "t-DCF: a Detection
125 | Cost Function for the Tandem Assessment of Spoofing Countermeasures
126 | and Automatic Speaker Verification", Proc. Odyssey 2018: the
127 | Speaker and Language Recognition Workshop, pp. 312--319, Les Sables d'Olonne,
128 | France, June 2018 (https://www.isca-speech.org/archive/Odyssey_2018/pdfs/68.pdf)
129 |
130 | [2] ASVspoof 2019 challenge evaluation plan
131 | TODO:
132 | """
133 |
134 |
135 | # Sanity check of cost parameters
136 | if cost_model['Cfa_asv'] < 0 or cost_model['Cmiss_asv'] < 0 or \
137 | cost_model['Cfa_cm'] < 0 or cost_model['Cmiss_cm'] < 0:
138 | print('WARNING: Usually the cost values should be positive!')
139 |
140 | if cost_model['Ptar'] < 0 or cost_model['Pnon'] < 0 or cost_model['Pspoof'] < 0 or \
141 | np.abs(cost_model['Ptar'] + cost_model['Pnon'] + cost_model['Pspoof'] - 1) > 1e-10:
142 | sys.exit('ERROR: Your prior probabilities should be positive and sum up to one.')
143 |
144 | # Unless we evaluate worst-case model, we need to have some spoof tests against asv
145 | if Pmiss_spoof_asv is None:
146 | sys.exit('ERROR: you should provide miss rate of spoof tests against your ASV system.')
147 |
148 | # Sanity check of scores
149 | combined_scores = np.concatenate((bonafide_score_cm, spoof_score_cm))
150 | if np.isnan(combined_scores).any() or np.isinf(combined_scores).any():
151 | sys.exit('ERROR: Your scores contain nan or inf.')
152 |
153 | # Sanity check that inputs are scores and not decisions
154 | n_uniq = np.unique(combined_scores).size
155 | if n_uniq < 3:
156 | sys.exit('ERROR: You should provide soft CM scores - not binary decisions')
157 |
158 | # Obtain miss and false alarm rates of CM
159 | Pmiss_cm, Pfa_cm, CM_thresholds = compute_det_curve(bonafide_score_cm, spoof_score_cm)
160 |
161 | # Constants - see ASVspoof 2019 evaluation plan
162 | C1 = cost_model['Ptar'] * (cost_model['Cmiss_cm'] - cost_model['Cmiss_asv'] * Pmiss_asv) - \
163 | cost_model['Pnon'] * cost_model['Cfa_asv'] * Pfa_asv
164 | C2 = cost_model['Cfa_cm'] * cost_model['Pspoof'] * (1 - Pmiss_spoof_asv)
165 |
166 | # Sanity check of the weights
167 | if C1 < 0 or C2 < 0:
168 | sys.exit('You should never see this error, but t-DCF cannot be evaluated with negative weights - please check that your ASV error rates are correctly computed.')
169 |
170 | # Obtain t-DCF curve for all thresholds
171 | tDCF = C1 * Pmiss_cm + C2 * Pfa_cm
172 |
173 | # Normalized t-DCF
174 | tDCF_norm = tDCF / np.minimum(C1, C2)
175 |
176 | # Everything should be fine if reaching here.
177 | if print_cost:
178 |
179 | print('t-DCF evaluation from [Nbona={}, Nspoof={}] trials\n'.format(bonafide_score_cm.size, spoof_score_cm.size))
180 | print('t-DCF MODEL')
181 | print(' Ptar = {:8.5f} (Prior probability of target user)'.format(cost_model['Ptar']))
182 | print(' Pnon = {:8.5f} (Prior probability of nontarget user)'.format(cost_model['Pnon']))
183 | print(' Pspoof = {:8.5f} (Prior probability of spoofing attack)'.format(cost_model['Pspoof']))
184 | print(' Cfa_asv = {:8.5f} (Cost of ASV falsely accepting a nontarget)'.format(cost_model['Cfa_asv']))
185 | print(' Cmiss_asv = {:8.5f} (Cost of ASV falsely rejecting target speaker)'.format(cost_model['Cmiss_asv']))
186 | print(' Cfa_cm = {:8.5f} (Cost of CM falsely passing a spoof to ASV system)'.format(cost_model['Cfa_cm']))
187 | print(' Cmiss_cm = {:8.5f} (Cost of CM falsely blocking target utterance which never reaches ASV)'.format(cost_model['Cmiss_cm']))
188 | print('\n Implied normalized t-DCF function (depends on t-DCF parameters and ASV errors), s=CM threshold)')
189 |
190 | if C2 == np.minimum(C1, C2):
191 | print(' tDCF_norm(s) = {:8.5f} x Pmiss_cm(s) + Pfa_cm(s)\n'.format(C1 / C2))
192 | else:
193 | print(' tDCF_norm(s) = Pmiss_cm(s) + {:8.5f} x Pfa_cm(s)\n'.format(C2 / C1))
194 |
195 | return tDCF_norm, CM_thresholds
196 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilization functions
3 | """
4 |
5 | import os
6 | import random
7 | import sys
8 |
9 | import numpy as np
10 | import torch
11 |
12 | # from warmup_scheduler import GradualWarmupScheduler
13 |
14 |
15 | def str_to_bool(val):
16 | """Convert a string representation of truth to true (1) or false (0).
17 | Copied from the python implementation distutils.utils.strtobool
18 |
19 | True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
20 | are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
21 | 'val' is anything else.
22 | >>> str_to_bool('YES')
23 | 1
24 | >>> str_to_bool('FALSE')
25 | 0
26 | """
27 | val = val.lower()
28 | if val in ('y', 'yes', 't', 'true', 'on', '1'):
29 | return True
30 | if val in ('n', 'no', 'f', 'false', 'off', '0'):
31 | return False
32 | raise ValueError('invalid truth value {}'.format(val))
33 |
34 |
35 | def cosine_annealing(step, total_steps, lr_max, lr_min):
36 | """Cosine Annealing for learning rate decay scheduler"""
37 | return lr_min + (lr_max -
38 | lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))
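# Endpoints: cosine_annealing(0, T, lr_max, lr_min) == lr_max (cos(0) = 1),
# and cosine_annealing(T, T, lr_max, lr_min) == lr_min (cos(pi) = -1).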
39 |
40 |
41 | def keras_decay(step, decay=0.0001):
42 | """Learning rate decay in Keras-style"""
43 | return 1. / (1. + decay * step)
44 |
45 |
46 | class SGDRScheduler(torch.optim.lr_scheduler._LRScheduler):
47 | """SGD with restarts scheduler"""
48 | def __init__(self, optimizer, T0, T_mul, eta_min, last_epoch=-1):
49 | self.Ti = T0
50 | self.T_mul = T_mul
51 | self.eta_min = eta_min
52 |
53 | self.last_restart = 0
54 |
55 | super().__init__(optimizer, last_epoch)
56 |
57 | def get_lr(self):
58 | T_cur = self.last_epoch - self.last_restart
59 | if T_cur >= self.Ti:
60 | self.last_restart = self.last_epoch
61 | self.Ti = self.Ti * self.T_mul
62 | T_cur = 0
63 |
64 | return [
65 | self.eta_min + (base_lr - self.eta_min) *
66 | (1 + np.cos(np.pi * T_cur / self.Ti)) / 2
67 | for base_lr in self.base_lrs
68 | ]
69 |
70 |
71 | def _get_optimizer(model_parameters, optim_config):
72 | """Defines optimizer according to the given config"""
73 | optimizer_name = optim_config['optimizer']
74 |
75 | if optimizer_name == 'sgd':
76 | optimizer = torch.optim.SGD(model_parameters,
77 | lr=optim_config['base_lr'],
78 | momentum=optim_config['momentum'],
79 | weight_decay=optim_config['weight_decay'],
80 | nesterov=optim_config['nesterov'])
81 | elif optimizer_name == 'adam':
82 | optimizer = torch.optim.Adam(model_parameters,
83 | lr=optim_config['base_lr'],
84 | betas=optim_config['betas'],
85 | weight_decay=optim_config['weight_decay'],
86 | amsgrad=str_to_bool(
87 | optim_config['amsgrad']))
88 | else:
89 | print('Un-known optimizer', optimizer_name)
90 | sys.exit()
91 |
92 | return optimizer
93 |
94 |
95 | def _get_scheduler(optimizer, optim_config):
96 | """
97 | Defines learning rate scheduler according to the given config
98 | """
99 | if optim_config['scheduler'] == 'multistep':
100 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
101 | optimizer,
102 | milestones=optim_config['milestones'],
103 | gamma=optim_config['lr_decay'])
104 |
105 | elif optim_config['scheduler'] == 'sgdr':
106 | scheduler = SGDRScheduler(optimizer, optim_config['T0'],
107 | optim_config['Tmult'],
108 | optim_config['lr_min'])
109 |
110 | elif optim_config['scheduler'] == 'cosine':
111 | total_steps = optim_config['epochs'] * \
112 | optim_config['steps_per_epoch']
113 |
114 | scheduler = torch.optim.lr_scheduler.LambdaLR(
115 | optimizer,
116 | lr_lambda=lambda step: cosine_annealing(
117 | step,
118 | total_steps,
119 | 1, # since lr_lambda computes multiplicative factor
120 | optim_config['lr_min'] / optim_config['base_lr']))
121 |
122 | elif optim_config['scheduler'] == 'keras_decay':
123 | scheduler = torch.optim.lr_scheduler.LambdaLR(
124 | optimizer, lr_lambda=lambda step: keras_decay(step))
125 | else:
126 | scheduler = None
127 | return scheduler
128 |
129 |
130 | def create_optimizer(model_parameters, optim_config):
131 | """Defines an optimizer and a scheduler"""
132 | optimizer = _get_optimizer(model_parameters, optim_config)
133 | scheduler = _get_scheduler(optimizer, optim_config)
134 | return optimizer, scheduler
135 |
136 |
137 | def seed_worker(worker_id):
138 | """
139 | Used in generating seed for the worker of torch.utils.data.Dataloader
140 | """
141 | worker_seed = torch.initial_seed() % 2**32
142 | np.random.seed(worker_seed)
143 | random.seed(worker_seed)
144 |
145 |
146 | def set_seed(seed, config = None):
147 | """
148 | set initial seed for reproduction
149 | """
150 | if config is None:
151 | raise ValueError("config should not be None")
152 |
153 | random.seed(seed)
154 | np.random.seed(seed)
155 | torch.manual_seed(seed)
156 | if torch.cuda.is_available():
157 | torch.cuda.manual_seed_all(seed)
158 | torch.backends.cudnn.deterministic = str_to_bool(config["cudnn_deterministic_toggle"])
159 | torch.backends.cudnn.benchmark = str_to_bool(config["cudnn_benchmark_toggle"])
160 |
--------------------------------------------------------------------------------