├── ims └── Figs_framework.jpg ├── indices_gen.py ├── LICENSE ├── data_preprocess ├── preprocess_VGGsound50K_2.py ├── preprocess_RAVDESS_0.py ├── preprocess_VGGsound50K_1.py ├── preprocess_RAVDESS_2.py ├── preprocess_VGGsound50K_3.py ├── preprocess_CrisisMMD_1.py ├── preprocess_RAVDESS_1.py ├── preprocess_VGGsound50K_0.py └── preprocess_CrisisMMD_0.py ├── KD_methods ├── KD.py ├── MLLD.py ├── MGDFR.py ├── RKD.py ├── C2KD.py ├── OFA.py └── CRD.py ├── models ├── LeNet5.py ├── MLPs.py ├── CNNs.py └── SeqNets.py ├── test-T.py ├── test-S.py ├── Dataset.py ├── .gitignore ├── README.md ├── main-S.py ├── main-T.py ├── main-KD-UU.py ├── main-MLLD-UU.py ├── main-RKD-UU.py ├── main-RKD.py ├── main-OFA.py └── main-KD.py /ims/Figs_framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gray-OREO/MST-Distill/HEAD/ims/Figs_framework.jpg -------------------------------------------------------------------------------- /indices_gen.py: -------------------------------------------------------------------------------- 1 | from utils import get_data, seed_all 2 | import numpy as np 3 | import pandas as pd 4 | import os 5 | 6 | 7 | def generate_random_indices(data_length, num_groups=20, output_file="random_indices.csv"): 8 | # 创建一个空的DataFrame来存放随机索引 9 | random_indices = pd.DataFrame() 10 | 11 | # 生成 num_groups 组随机索引 12 | for i in range(num_groups): 13 | indices = np.random.permutation(data_length) # 生成一个乱序的索引数组 14 | random_indices[f'group_{i}'] = indices # 将该数组作为新列添加到DataFrame 15 | 16 | # 保存为csv文件 17 | random_indices.to_csv(output_file, index=False) 18 | print(f"{output_file} generated successfully!") 19 | 20 | 21 | if __name__ == "__main__": 22 | dataset_name = 'CMMD-V2' 23 | save_path = 'metadata/' 24 | os.makedirs(save_path, exist_ok=True) 25 | seed_all(19980427) # 设置随机种子 26 | data1, data2, labels = get_data(dataset_name) 27 | num_data = len(data1) 28 | generate_random_indices(num_data, num_groups=20, output_file=f"{save_path}{dataset_name}_indices.csv") -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Hui Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/data_preprocess/preprocess_VGGsound50K_2.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import os
4 | from tqdm import tqdm
5 | 
6 | '''
7 | Script for processing the segment-level labels.
8 | '''
9 | 
10 | output_path = 'E:/Gray/Database/VGGSound-AVEL50k/seg_label/'
11 | if not os.path.exists(output_path):
12 |     os.makedirs(output_path)
13 | 
14 | # Read the JSON file contents
15 | with open('E:/Gray/Database/VGGSound-AVEL50k/vggsound-avel50k_labels.json', 'r') as file:
16 |     data = json.load(file)
17 | 
18 | with open('D:/Gray/Database/VGGS50K/VGGS50k_metadata.txt', 'r') as txt_file:
19 |     txt_info = txt_file.read().splitlines()
20 | 
21 | vid2vinfo = {}
22 | for item in txt_info:
23 |     vid = item[1:12]
24 |     sample_name = item.split('&')[0]
25 |     vid2vinfo.update({vid: sample_name})
26 | 
27 | for item in tqdm(data, total=len(data), desc='Processing...'):
28 |     for video_id, info in item.items():
29 |         # Get the label data and convert it to an ndarray
30 |         label = np.array(info['label']).reshape(1, 10)
31 |         # Save the ndarray as a .npy file
32 |         try:
33 |             sample_name = vid2vinfo[video_id]
34 |             np.save(f'{output_path}{sample_name}_sLabel.npy', label)
35 |         except KeyError:  # skip video ids without a matching metadata entry
36 |             pass
37 | 
--------------------------------------------------------------------------------
/KD_methods/KD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from types import SimpleNamespace
4 | 
5 | def distillation_loss(args, student_outputs, teacher_outputs, T=4.0):
6 |     if args.database == 'NYU-Depth-V2':
7 |         pred = student_outputs
8 |         tar = teacher_outputs.detach()
9 |         student_probs_T = F.log_softmax(pred / T, dim=1)
10 |         teacher_probs_T_0 = F.softmax(tar / T, dim=1)
11 |         kl_loss = F.kl_div(student_probs_T, teacher_probs_T_0, reduction='batchmean') * (T * T) * 1/(pred.shape[2] * pred.shape[3])
12 | 
13 |     elif args.database == 'VGGSound-50k' and args.Tmodel == 'CPSP':
14 |         teacher_outputs = teacher_outputs[1].detach()
15 |         student_probs_T = F.log_softmax(student_outputs / T, dim=1)
16 |         teacher_probs_T = F.softmax(teacher_outputs / T, dim=1)
17 |         kl_loss = F.kl_div(student_probs_T, teacher_probs_T, reduction='batchmean') * (T * T)
18 | 
19 |     else:
20 |         teacher_outputs = teacher_outputs.detach()
21 |         student_probs_T = F.log_softmax(student_outputs / T, dim=1)
22 |         teacher_probs_T = F.softmax(teacher_outputs / T, dim=1)
23 |         kl_loss = F.kl_div(student_probs_T, teacher_probs_T, reduction='batchmean') * (T * T)
24 | 
25 |     return kl_loss
26 | 
27 | 
28 | if __name__ == '__main__':
29 |     pred = torch.randn(32, 10)
30 |     target = torch.randn(32, 10)
31 |     args = SimpleNamespace(database='AV-MNIST', Tmodel=None)  # minimal stand-in for the parsed CLI args; any database without a special case exercises the generic branch
32 |     kl_loss = distillation_loss(args, pred, target)
33 |     print(kl_loss)
--------------------------------------------------------------------------------
/data_preprocess/preprocess_RAVDESS_0.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import os
3 | import soundfile as sf
4 | import numpy as np
5 | from tqdm import tqdm
6 | 
7 | # audiofile = 'E://OpenDR_datasets//RAVDESS//Actor_19//03-01-07-02-01-02-19.wav'
8 | ## This file preprocesses the audio files so that they all have the same length: clips shorter than 3.6 seconds are zero-padded at the end;
9 | ## otherwise they are cropped equally from both sides.
10 | 
11 | root = 'E:/Gray/Database/RAVDESS/Audio_Speech'
12 | tar_root = 'E:/Gray/Database/RAVDESS-preprocessed'
13 | target_time = 3.6  # sec
14 | for actor in tqdm(os.listdir(root), total=len(os.listdir(root)), desc='Processing audio data...',):
15 |     for audiofile in os.listdir(os.path.join(root, actor)):
16 | 
17 |         if not audiofile.endswith('.wav') or 'croppad' in audiofile:
18 |             continue
19 | 
20 |         audios = librosa.core.load(os.path.join(root, actor, audiofile), sr=22050)
21 | 
22 |         y = audios[0]
23 |         sr = audios[1]
24 |         target_length = int(sr * target_time)
25 |         if len(y) < target_length:
26 |             y = np.array(list(y) + [0 for i in range(target_length - len(y))])
27 |         else:
28 |             remain = len(y) - target_length
29 |             y = y[remain // 2:-(remain - remain // 2)]
30 | 
31 |         os.makedirs(os.path.join(tar_root, actor), exist_ok=True)
32 |         sf.write(os.path.join(tar_root, actor, audiofile[:-4] + '_croppad.wav'), y, sr)
33 | 
--------------------------------------------------------------------------------
/data_preprocess/preprocess_VGGsound50K_1.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import os
3 | import tqdm
4 | import subprocess
5 | import logging
6 | import numpy as np
7 | 
8 | 
9 | def read_txt_to_list(file_path):
10 |     with open(file_path, 'r', encoding='utf-8') as file:
11 |         lines = file.readlines()  # read all lines into a list
12 |     lines = [line.strip() for line in lines]  # strip the trailing newline from each line
13 |     return lines
14 | 
15 | 
16 | if __name__ == '__main__':
17 |     # data_a = np.load('E:/Gray/Project/feature_extractor/VGGS50K_features/audio_features/v6bPsB1h3wvE_110_120_out_aFeature.npy')
18 |     # data_v = np.load('E:/Gray/Project/feature_extractor/VGGS50K_features/visual_features/v6bPsB1h3wvE_110_120_out_vFeature.npy')
19 |     # print(data_a.shape)
20 |     # print(data_v.shape)
21 | 
22 |     with open('D:/Gray/Database/VGGS50K/VGGS50K_videos.txt', 'r', encoding='utf-8') as file:
23 |         lines = file.readlines()  # read all lines into a list
24 |     data = [line.strip() for line in lines]  # strip the trailing newline from each line
25 | 
26 |     feature_path = 'E:/Gray/Project/feature_extractor/VGGS50K_features/'
27 |     v_features, a_features = [], []
28 |     for root, dirs, files in os.walk(feature_path+'visual_features'):
29 |         for file in files:
30 |             v_features.append(file[1:12])
31 |     for root, dirs, files in os.walk(feature_path+'audio_features'):
32 |         for file in files:
33 |             a_features.append(file[1:12])
34 | 
35 |     for line in data:
36 |         sample_id = line.split('/')[6].split('.')[0][1:12]
37 |         file_name = line.split('/')[6].split('.')[0]
38 |         label = line.split('&')[-1]
39 |         if sample_id in v_features and sample_id in a_features:
40 |             with open('D:/Gray/Database/VGGS50K/VGGS50k_metadata.txt', 'a', encoding='utf-8') as f:
41 |                 f.write(f'{file_name}&{label}\n')
42 | 
43 | 
44 | 
45 | 
46 | 
--------------------------------------------------------------------------------
/data_preprocess/preprocess_RAVDESS_2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import librosa
4 | from collections import Counter
5 | 
6 | 
7 | def video_loader(video_dir_path):
8 |     video = np.load(video_dir_path)  # [N, H, W, C]
9 |     video_data = video.transpose(0, 3, 1, 2)
10 |     return video_data
11 | 
12 | 
13 | def load_audio(audiofile, sr=22050):
14 |     audios, sr = librosa.core.load(audiofile, sr=sr)
15 |     mfcc = librosa.feature.mfcc(y=audios, sr=sr, n_mfcc=15)
16 |     return mfcc
17 | 
18 | 
19 | if __name__ == '__main__':
20 |     # data1 = video_loader('E:/Gray/Database/RAVDESS-Speech/Actor_01/01-01-01-01-01-01-01_facecroppad.npy')
21 |     # data2 = load_audio('E:/Gray/Database/RAVDESS-Speech/Actor_01/03-01-01-01-01-01-01_croppad.wav')
22 | 
23 |     video_list = []
24 |     tar_root = 'E:/Gray/Database/RAVDESS_preprocessed_npy/'
25 |     os.makedirs(tar_root, exist_ok=True)
26 |     for root, dirs, files in os.walk('E:/Gray/Database/RAVDESS-preprocessed/'):
27 |         for file in files:
28 |             if file.endswith('.npy'):
29 |                 video_list.append(root+'/'+file)
30 | 
31 |     video_data, audio_data, label_data = [], [], []
32 |     for i in range(len(video_list)):
33 |         """
34 |         root = 'E:/Gray/Database/RAVDESS-Speech/Actor_01/'
35 |         sample_name = '01-01-01-01-01-01-01'
36 |         'facecroppad.npy'
37 |         """
38 |         root = video_list[i].split('_')[0]+'_'+video_list[i].split('_')[1].split('/')[0]+'/'
39 |         sample_name = video_list[i].split('_')[1].split('/')[1]
40 |         no_mode_name = sample_name[2:]
41 |         label = int(sample_name.split('-')[2]) - 1
42 |         video_dir = root+'01'+no_mode_name+'_facecroppad.npy'
43 |         audio_dir = root+'03'+no_mode_name+'_croppad.wav'
44 |         video_data.append(video_loader(video_dir))
45 |         audio_data.append(load_audio(audio_dir))
46 |         label_data.append(label)
47 |     np.save(f'{tar_root}video_data.npy', np.array(video_data))
48 |     np.save(f'{tar_root}audio_data.npy', np.array(audio_data))
49 |     np.save(f'{tar_root}label_data.npy', np.array(label_data))
50 |     # # Use Counter to count how often each label occurs
51 |     # counter = Counter(label_data)
52 |     #
53 |     # # Print the result
54 |     # print(counter)
55 | 
--------------------------------------------------------------------------------
/models/LeNet5.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class LeNet5(nn.Module):
7 |     def __init__(self, cls_num=10):
8 |         super(LeNet5, self).__init__()
9 |         self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
10 |         self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
11 |         self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
12 |         self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
13 |         self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
14 |         self.fc1 = nn.Linear(in_features=120, out_features=84)
15 |         self.fc2 = nn.Linear(in_features=84, out_features=cls_num)
16 |         self.flatten_identity = nn.Identity()
17 |         self.hook_names = ['conv1', 'conv2', 'conv3', 'fc1', 'fc2', 'flatten_identity']
18 | 
19 |     def forward(self, x):
20 |         # Conv layer C1 + activation + pooling layer S2
21 |         x = self.pool1(F.relu(self.conv1(x)))
22 |         # Conv layer C3 + activation + pooling layer S4
23 |         x = self.pool2(F.relu(self.conv2(x)))
24 |         # Conv layer C5 + activation
25 |         x = F.relu(self.conv3(x))
26 |         # Flatten the feature map into a 1-D vector
27 |         x = x.view(-1, 120)
28 |         x = self.flatten_identity(x)
29 |         # Fully connected layer F6 + activation
30 |         x = F.relu(self.fc1(x))
31 |         # Output layer
32 |         x = self.fc2(x)
33 |         return x
34 | 
35 | 
36 | class LeNet5_woClsHead(nn.Module):
37 |     def __init__(self,):
38 |         super(LeNet5_woClsHead, self).__init__()
39 |         self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
40 |         self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
41 |         self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
42 |         self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
43 |         self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
44 |         self.fc1 = nn.Linear(in_features=120, out_features=84)
45 |         self.hook_names = ['conv1', 'conv2', 'fc1']
46 | 
47 |     def forward(self, x):
48 |         x = self.pool1(F.relu(self.conv1(x)))
49 |         x = self.pool2(F.relu(self.conv2(x)))
50 |         x = F.relu(self.conv3(x))
51 |         x = x.view(-1, 120)
52 |         x = F.relu(self.fc1(x))
53 |         return x
54 | 
55 | 
56 | if __name__ == '__main__':
57 |     model = LeNet5()
58 |     print(model)
59 | 
60 |     input_tensor = torch.randn(1, 1, 28, 28)
61 |     output = model(input_tensor)
62 |     print(output)
63 | 
--------------------------------------------------------------------------------
/KD_methods/MLLD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | def kd_loss(logits_student, logits_teacher, temperature):
7 |     """
8 |     Instance-level alignment loss.
9 |     Standard KL-divergence loss; the T^2 factor must be kept.
10 |     """
11 |     loss_kd = F.kl_div(
12 |         F.log_softmax(logits_student / temperature, dim=1),
13 |         F.softmax(logits_teacher / temperature, dim=1),
14 |         reduction='batchmean'
15 |     )
16 |     return loss_kd * (temperature ** 2)
17 | 
18 | 
19 | def cc_loss(logits_student, logits_teacher, temperature):
20 |     """
21 |     Class-level alignment loss.
22 |     Improved normalization: MSE instead of sum() / C.
23 |     """
24 |     pred_student = F.softmax(logits_student / temperature, dim=1)
25 |     pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
26 | 
27 |     # Class-correlation matrix: M = P^T @ P, shape [C, C]
28 |     student_matrix = torch.mm(pred_student.t(), pred_student)
29 |     teacher_matrix = torch.mm(pred_teacher.t(), pred_teacher)
30 | 
31 |     # Improvement: MSE normalization is mathematically better behaved
32 |     return F.mse_loss(student_matrix, teacher_matrix, reduction='mean')
33 | 
34 | 
35 | def bc_loss(logits_student, logits_teacher, temperature):
36 |     """
37 |     Batch-level alignment loss.
38 |     Improved normalization: MSE instead of sum() / B.
39 |     """
40 |     pred_student = F.softmax(logits_student / temperature, dim=1)
41 |     pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
42 | 
43 |     # Gram matrix: G = P @ P^T, shape [B, B]
44 |     student_matrix = torch.mm(pred_student, pred_student.t())
45 |     teacher_matrix = torch.mm(pred_teacher, pred_teacher.t())
46 | 
47 |     # Improvement: MSE normalization is mathematically better behaved
48 |     return F.mse_loss(student_matrix, teacher_matrix, reduction='mean')
49 | 
50 | 
51 | class MultiLevelLogitDistillation(nn.Module):
52 |     """
53 |     Improved Multi-level Logit Distillation:
54 |     - keeps equal weights (following the original paper)
55 |     - improved CC/BC normalization (avoids numerical issues)
56 |     - normalizes over the number of temperatures (improves stability)
57 |     """
58 | 
59 |     def __init__(self, temperatures=[2.0, 3.0, 4.0, 5.0, 6.0]):
60 |         super().__init__()
61 |         self.temperatures = temperatures
62 | 
63 |     def forward(self, logits_student, logits_teacher):
64 |         """
65 |         Compute the multi-level logit distillation loss.
66 |         """
67 |         total_kd_loss = 0.0
68 |         total_cc_loss = 0.0
69 |         total_bc_loss = 0.0
70 | 
71 |         logits_teacher = logits_teacher.detach()  # make sure the teacher logits receive no gradients
72 | 
73 |         # Accumulate the losses for every temperature
74 |         for temp in self.temperatures:
75 |             total_kd_loss += kd_loss(logits_student, logits_teacher, temp)
76 |             total_cc_loss += cc_loss(logits_student, logits_teacher, temp)
77 |             total_bc_loss += bc_loss(logits_student, logits_teacher, temp)
78 | 
79 |         # Equal-weight sum, normalized by the number of temperatures
80 |         num_temps = len(self.temperatures)
81 |         total_loss = (total_kd_loss + total_cc_loss + total_bc_loss) / num_temps
82 | 
83 |         return total_loss / 3
--------------------------------------------------------------------------------
/data_preprocess/preprocess_VGGsound50K_3.py:
--------------------------------------------------------------------------------
1 | from utils import _obtain_avel_label
2 | import platform
3 | from tqdm import tqdm
4 | import numpy as np
5 | import torch
6 | import os
7 | 
8 | def get_vggsound(rootpath):
9 |     root = rootpath + 'VGGS50K/' if platform.system() == 'Linux' else 'D:/Gray/Database/VGGS50K/'
10 | 
vf_list, af_list, labels = [], [], [] 11 | with open(root + 'VGGS50k_metadata.txt', 'r', encoding='utf-8') as file: 12 | lines = file.readlines() # 读取所有行并保存为列表 13 | data = [line.strip() for line in lines] # 去掉每行末尾的换行符 14 | # data = data[:1000] # For debug(failed for contrastive learning) 15 | for line in tqdm(data, total=len(data), desc='Data loading...'): 16 | sample_name = line.split('&')[0] 17 | label = int(line.split('&')[1]) 18 | vf = root + 'VGGS50K_features/visual_features/' + f'{sample_name}_vFeature.npy' 19 | af = root + 'VGGS50K_features/audio_features/' + f'{sample_name}_aFeature.npy' 20 | avc_label = np.load(root + 'seg_labels/' + f'{sample_name}_sLabel.npy') 21 | slabel = _obtain_avel_label(avc_label, label) 22 | vf_list.append(torch.from_numpy(np.load(vf).astype(np.float32))) 23 | af_list.append(torch.from_numpy(np.load(af))) 24 | labels.append(torch.from_numpy(slabel.astype(np.float32))) 25 | return vf_list, af_list, labels 26 | 27 | 28 | def get_max_min(data_list): 29 | all_values = np.concatenate([tensor.flatten() for tensor in data_list]) 30 | global_min = all_values.min() 31 | global_max = all_values.max() 32 | return global_min, global_max 33 | 34 | 35 | if __name__ == '__main__': 36 | root = 'D:/Gray/Database/' if platform.system() == 'Windows' else '/home/gray/Database/' 37 | save_path_v = root + 'VGGS50K/VGGS50K_features_normed/visual_features/' 38 | save_path_a = root + 'VGGS50K/VGGS50K_features_normed/audio_features/' 39 | 40 | os.makedirs(save_path_v, exist_ok=True) 41 | os.makedirs(save_path_a, exist_ok=True) 42 | 43 | vf_list, af_list, _ = get_vggsound(root) 44 | v_min, v_max = get_max_min(vf_list) 45 | a_min, a_max = get_max_min(af_list) 46 | 47 | with open(root + 'VGGS50K/VGGS50k_metadata.txt', 'r', encoding='utf-8') as file: 48 | lines = file.readlines() # 读取所有行并保存为列表 49 | data = [line.strip() for line in lines] # 去掉每行末尾的换行符 50 | for line in tqdm(data, total=len(data), desc='Data norm & save...'): 51 | sample_name = line.split('&')[0] 52 | label = int(line.split('&')[1]) 53 | vf = root + 'VGGS50K/VGGS50K_features/visual_features/' + f'{sample_name}_vFeature.npy' 54 | af = root + 'VGGS50K/VGGS50K_features/audio_features/' + f'{sample_name}_aFeature.npy' 55 | vf = np.load(vf) 56 | af = np.load(af) 57 | vf = (vf - v_min) / (v_max - v_min) 58 | af = (af - a_min) / (a_max - a_min) 59 | np.save(save_path_v + f'{sample_name}_vFeature.npy', vf) 60 | np.save(save_path_a + f'{sample_name}_aFeature.npy', af) 61 | 62 | -------------------------------------------------------------------------------- /models/MLPs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class visualMLP(nn.Module): 7 | def __init__(self, cls_num=8): 8 | super(visualMLP, self).__init__() 9 | self.fc1 = nn.Linear(2048, 1024) 10 | self.fc2 = nn.Linear(1024, 512) 11 | self.fc3 = nn.Linear(512, cls_num) 12 | self.hook_names = ['fc1', 'fc2', 'fc3'] 13 | 14 | def forward(self, x): 15 | x = F.relu(self.fc1(x)) 16 | x = F.relu(self.fc2(x)) 17 | x = self.fc3(x) 18 | return x 19 | 20 | 21 | class textualMLP(nn.Module): 22 | def __init__(self, cls_num=8): 23 | super(textualMLP, self).__init__() 24 | self.fc1 = nn.Linear(768, 1024) 25 | self.fc2 = nn.Linear(1024, 512) 26 | self.fc3 = nn.Linear(512, cls_num) 27 | self.hook_names = ['fc1', 'fc2', 'fc3'] 28 | 29 | def forward(self, x): 30 | x = F.relu(self.fc1(x)) 31 | x = F.relu(self.fc2(x)) 32 | x = self.fc3(x) 33 | return x 34 | 35 | 36 
| class visualMLP_woClsHead(nn.Module): 37 | def __init__(self): 38 | super(visualMLP_woClsHead, self).__init__() 39 | self.fc1 = nn.Linear(2048, 1024) 40 | self.fc2 = nn.Linear(1024, 512) 41 | self.hook_names = ['fc1', 'fc2'] 42 | 43 | def forward(self, x): 44 | x = F.relu(self.fc1(x)) 45 | x = F.relu(self.fc2(x)) 46 | return x 47 | 48 | 49 | class textualMLP_woClsHead(nn.Module): 50 | def __init__(self): 51 | super(textualMLP_woClsHead, self).__init__() 52 | self.fc1 = nn.Linear(768, 1024) 53 | self.fc2 = nn.Linear(1024, 512) 54 | self.hook_names = ['fc1', 'fc2'] 55 | 56 | def forward(self, x): 57 | x = F.relu(self.fc1(x)) 58 | x = F.relu(self.fc2(x)) 59 | return x 60 | 61 | 62 | class IntermediateFusionMLP(nn.Module): 63 | def __init__(self): 64 | super(IntermediateFusionMLP, self).__init__() 65 | self.VisualNet = visualMLP_woClsHead() 66 | self.TextualNet = textualMLP_woClsHead() 67 | self.fc1 = nn.Linear(512+512, 512) 68 | self.fc2 = nn.Linear(512, 256) 69 | self.fc3 = nn.Linear(256, 8) 70 | self.hook_names = ['VisualNet.fc1', 'VisualNet.fc2', 'TextualNet.fc1', 'TextualNet.fc2', 71 | 'fc1', 'fc2', 'fc3'] 72 | 73 | def forward(self, x1, x2): 74 | x1 = self.VisualNet(x1) 75 | x2 = self.TextualNet(x2) 76 | x = torch.cat((x1, x2), dim=1) 77 | x = F.relu(self.fc1(x)) 78 | x = F.relu(self.fc2(x)) 79 | x = self.fc3(x) 80 | return x 81 | 82 | 83 | if __name__ == '__main__': 84 | x1 = torch.randn(1, 2048) 85 | x2 = torch.randn(1, 768) 86 | 87 | # Tea.-MM-I 88 | model = IntermediateFusionMLP() 89 | y = model(x1, x2) 90 | 91 | # Stu.-V 92 | # model = visualMLP() 93 | # y = model(x1) 94 | 95 | # Stu.-T 96 | # model = textualMLP() 97 | # y = model(x2) 98 | 99 | print(y.size()) 100 | print(y) -------------------------------------------------------------------------------- /test-T.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch import nn 4 | from torch.optim import Adam 5 | from utils import * 6 | 7 | import numpy as np 8 | import torch.optim as optim 9 | import time 10 | from argparse import ArgumentParser 11 | from tensorboardX import SummaryWriter 12 | from tqdm import tqdm 13 | from Dataset import MultiModalX 14 | import sys 15 | 16 | 17 | if __name__ == '__main__': 18 | ''' 19 | Args Setting for CML. 
20 | ''' 21 | parser = ArgumentParser(description='CML-TO') 22 | parser.add_argument('--database', type=str, default='AV-MNIST', 23 | help="database name must be one of ['NYU-Depth-V2', 'RAVDESS', 'AV-MNIST', 'VGGSound-50k', 'MM-IMDb']") 24 | parser.add_argument('--Tmodel', type=str, default='CNN-I', 25 | help='Teacher model name') 26 | 27 | parser.add_argument('--ckpt_name', type=str, 28 | default='AV-MNIST_CNN-I_seed0_ORG_ep97-100.pth', 29 | help='The name of the weight to be loaded in ./checkpoints/tea') 30 | parser.add_argument('--seed', type=int, default=0, 31 | help='Random seed') 32 | parser.add_argument('--mode', type=str, default='ORG', 33 | help='Data mode: ORG or m1-MSK-0.1 or m2-GN-0.01') 34 | parser.add_argument('--epochs', type=int, default=100, 35 | help='Epoch for training, invalid in this file but for get_module runing.') 36 | parser.add_argument('--batch_size', type=int, default=1, 37 | help='batch size for training') 38 | parser.add_argument('--lr', type=float, default=0.0001, 39 | help='learning rate for training') 40 | parser.add_argument('--record', type=bool, default=True, 41 | help='flag whether to record the learning log') 42 | parser.add_argument('--cuda_id', type=int, default=0, 43 | help='cuda id') 44 | parser.add_argument('--freeze_bn', type=bool, default=True, 45 | help='flag whether to freeze BN layers in the model') 46 | args = parser.parse_args() 47 | 48 | seed_all(args.seed) 49 | 50 | data = get_data(args.database) 51 | data_test = get_dataset(args.database, data, 'test', args.seed) 52 | test_dataset = MultiModalX(data_test, args.database, mode=args.mode) 53 | 54 | test_loader = DataLoader( 55 | test_dataset, 56 | pin_memory=True, 57 | batch_size=args.batch_size 58 | ) 59 | 60 | # ===========GPU Setting==================== 61 | device = torch.device(f"cuda:{args.cuda_id}") 62 | # ==========Initialization=========== 63 | model_t, optimizers, scheduler_t, criterion, preprocessing, postprocessing, metric = get_Tmodules(args, device) 64 | 65 | model_t = model_t.to(device) 66 | model_t.load_state_dict(torch.load(f'checkpoints/tea/{args.ckpt_name}', map_location=device, weights_only=True)) 67 | # test 68 | model_t.eval() 69 | metric.reset() 70 | gt_list, pred_list = [], [] 71 | with torch.no_grad(): 72 | for i, (data, data2, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 73 | total=len(test_loader), dynamic_ncols=True, disable=True, file=sys.stdout): 74 | data, data2, label = data.to(device), data2.to(device), label.to(device) 75 | outputs_t = model_t(data, data2) if preprocessing is None else preprocessing(model_t, data, data2) 76 | metric.update(outputs_t, label) 77 | res = metric.compute() 78 | 79 | print('\n===============Metrics==================') 80 | for e in res.keys(): 81 | print(e) 82 | print(res[e]) 83 | print('----------------------------') 84 | print('=======================================\n') 85 | 86 | -------------------------------------------------------------------------------- /test-S.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from utils import get_data, get_dataset, get_Smodules, seed_all 4 | from argparse import ArgumentParser 5 | from tqdm import tqdm 6 | from Dataset import SingleModalX 7 | import sys 8 | 9 | 10 | if __name__ == '__main__': 11 | ''' 12 | Args Setting for CML. 
13 | ''' 14 | parser = ArgumentParser(description='CML-SO') 15 | parser.add_argument('--database', type=str, default='CMMD-V2', 16 | help="database name must be one of " 17 | "['NYU-Depth-V2', 'RAVDESS', 'AV-MNIST', 'VGGSound-50k', 'MM-IMDb']") 18 | parser.add_argument('--Smodel', type=str, default='MLP-Vb', 19 | help='Student model name') 20 | 21 | parser.add_argument('--KD_method', type=str, default='KD', 22 | help='The name of the dictionary to be loaded in ./checkpoints/stu/...') 23 | parser.add_argument('--ckpt_name', type=str, 24 | default='Your checkpoint name', 25 | help='The name of the weight to be loaded in ./checkpoints/stu/{KD method}/...') 26 | parser.add_argument('--seed', type=int, default=0, 27 | help='Random seed') 28 | parser.add_argument('--mode', type=str, default='m1', 29 | help='Data mode: m1 or m2') 30 | 31 | parser.add_argument('--batch_size', type=int, default=1, 32 | help='batch size for training') 33 | parser.add_argument('--lr', type=float, default=0.0001, 34 | help='learning rate for training') 35 | parser.add_argument('--record', type=bool, default=True, 36 | help='flag whether to record the learning log') 37 | parser.add_argument('--cuda_id', type=int, default=0, 38 | help='cuda id') 39 | parser.add_argument('--epochs', type=int, default=100, 40 | help='epochs for training, invalid in this file but for get_module runing.') 41 | parser.add_argument('--save_model', type=bool, default=True, 42 | help='flag whether to save best model') 43 | parser.add_argument('--commit', type=str, default='MM-T', 44 | help='Commit for logs') 45 | args = parser.parse_args() 46 | 47 | seed_all(args.seed) 48 | 49 | data = get_data(args.database) 50 | data_test = get_dataset(args.database, data, 'test', args.seed) 51 | test_dataset = SingleModalX(data_test, args.database, mode=args.mode) 52 | 53 | test_loader = DataLoader( 54 | test_dataset, 55 | pin_memory=True, 56 | batch_size=args.batch_size 57 | ) 58 | 59 | # ===========GPU Setting==================== 60 | device = torch.device(f"cuda:{args.cuda_id}") 61 | # ==========Initialization=========== 62 | model_s, optimizer_s, scheduler_s, criterion_s, preprocessing, postprocessing, metric = get_Smodules(args) 63 | 64 | model_s = model_s.to(device) 65 | model_s.load_state_dict(torch.load(f'checkpoints/stu/{args.KD_method}/{args.ckpt_name}', map_location=device, weights_only=True)) 66 | # test 67 | model_s.eval() 68 | metric.reset() 69 | gt_list, pred_list = [], [] 70 | with torch.no_grad(): 71 | for i, (data, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 72 | total=len(test_loader), dynamic_ncols=True, disable=True, file=sys.stdout): 73 | data, label = data.to(device), label.to(device) 74 | outputs_s = model_s(data) if preprocessing is None else preprocessing(model_s, data) 75 | metric.update(outputs_s, label) 76 | res = metric.compute() 77 | 78 | print('\n===============Metrics==================') 79 | for e in res.keys(): 80 | print(e) 81 | print(res[e]) 82 | print('----------------------------') 83 | print('=======================================\n') 84 | 85 | -------------------------------------------------------------------------------- /data_preprocess/preprocess_CrisisMMD_1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.models as models 3 | import torchvision.transforms as transforms 4 | from PIL import Image 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from transformers import AutoModel, AutoTokenizer 9 | from 
tqdm import tqdm
10 | 
11 | # Define the mapping function
12 | def map_category(value):
13 |     """
14 |     - Given an integer (0-7), return the category name.
15 |     - Given a category name, return the integer (0-7).
16 |     - Handles invalid input.
17 |     """
18 |     if isinstance(value, int):
19 |         return index_to_category.get(value, "Unknown Category")  # guard against out-of-range indices
20 |     elif isinstance(value, str):
21 |         return category_to_index.get(value, -1)  # unknown categories return -1
22 |     else:
23 |         raise ValueError("Input must be an integer (0-7) or a valid category name.")
24 | 
25 | 
26 | def text_feature_extractor(text):
27 |     with torch.no_grad():
28 |         features = bertweet(text)  # Models outputs are now tuples
29 |     return features.pooler_output.squeeze(0).detach().cpu().numpy()
30 | 
31 | 
32 | def im_feature_extractor(image_path, device):
33 |     image = Image.open(image_path).convert("RGB")  # make sure the image is in RGB mode
34 |     image = transform(image).unsqueeze(0)
35 |     with torch.no_grad():
36 |         features = resnet50(image.to(device))  # Models outputs are now tuples
37 |     return features.squeeze(0).squeeze(-1).squeeze(-1).detach().cpu().numpy()
38 | 
39 | 
40 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
41 | # ====================================== TEXT =====================================
42 | categories = [
43 |     'not_humanitarian',  # 0 6179
44 |     'affected_individuals',  # 1 521
45 |     'infrastructure_and_utility_damage',  # 2 2150
46 |     'injured_or_dead_people',  # 3 301
47 |     'missing_or_found_people',  # 4 32
48 |     'rescue_volunteering_or_donation_effort',  # 5 2594
49 |     'vehicle_damage',  # 6 160
50 |     'other_relevant_information'  # 7 4121
51 | ]
52 | 
53 | # Build the index-to-category and category-to-index mappings
54 | index_to_category = {i: cat for i, cat in enumerate(categories)}
55 | category_to_index = {cat: i for i, cat in enumerate(categories)}
56 | # ================================= IMG ===========================================
57 | resnet50 = models.resnet50(pretrained=True).to(device)
58 | resnet50 = torch.nn.Sequential(*list(resnet50.children())[:-1])  # remove the final classification layer
59 | resnet50.eval()  # set to evaluation mode
60 | 
61 | # 2. Define the image preprocessing pipeline
62 | transform = transforms.Compose([
63 |     transforms.Resize((224, 224)),  # resize to the ResNet input size
64 |     transforms.ToTensor(),  # convert to a tensor
65 |     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # normalize
66 | ])
67 | 
68 | 
69 | data_file = 'E:/Gray/Database/CrisisMMD_v2.0/all_data.csv'
70 | root = 'E:/Gray/Database/CrisisMMD_v2.0'
71 | textpath = 'E:/Gray/Database/CrisisMMD_v2.0/crisismmd_datasplit_all'
72 | 
73 | data = pd.read_csv(data_file, sep='\t', encoding="utf-8")
74 | category = set(data['label'])
75 | # print(data)
76 | 
77 | bertweet = AutoModel.from_pretrained("vinai/bertweet-base").to(device)
78 | tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base")
79 | 
80 | # INPUT TWEET IS ALREADY NORMALIZED!
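# For reference: the normalization in preprocess_CrisisMMD_0.py maps a raw tweet such as
# "DHEC confirms first case www.example-news.com via @some_user 😢" (hypothetical example, not
# taken from the dataset) into the form used below. User mentions become @USER, generic
# http/www links become HTTPURL (https://t.co/ links are dropped entirely), and emojis are
# replaced by their demojize() aliases such as :crying_face:.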
81 | line = "DHEC confirms HTTPURL via @USER :crying_face:" 82 | 83 | text_inputs = [] 84 | text_fs = [] 85 | im_paths = [] 86 | im_fs = [] 87 | labels = [] 88 | 89 | for im_path, text, label in tqdm(zip(data['image'], data['tweet_text'], data['label']), total=len(data), desc='Data Processing'): 90 | input_ids = torch.tensor([tokenizer.encode(text, max_length=128)], device=device) 91 | text_f = text_feature_extractor(input_ids) 92 | text_fs.append(text_f) 93 | 94 | im_path = os.path.join(root, im_path) 95 | im_f = im_feature_extractor(im_path, device) 96 | im_fs.append(im_f) 97 | 98 | labels.append(map_category(label)) 99 | 100 | print(len(text_fs)) 101 | text_fs = np.array(text_fs) 102 | im_fs = np.array(im_fs) 103 | labels = np.array(labels) 104 | res = (im_fs, text_fs, labels) 105 | torch.save(res, f"{root}/CMMD_data.pth") 106 | print('All Done!') 107 | exit() -------------------------------------------------------------------------------- /data_preprocess/preprocess_RAVDESS_1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | from tqdm import tqdm 5 | import torch 6 | from facenet_pytorch import MTCNN 7 | 8 | device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 9 | 10 | mtcnn = MTCNN(image_size=(720, 1280), device=device) 11 | 12 | # mtcnn.to(device) 13 | save_frames = 15 14 | input_fps = 30 15 | 16 | save_length = 3.6 # seconds 17 | save_avi = False 18 | 19 | failed_videos = [] 20 | root = 'E:/Gray/Database/RAVDESS/Speech' 21 | tar_root = 'E:/Gray/Database/RAVDESS-preprocessed/' 22 | 23 | select_distributed = lambda m, n: [i * n // m + n // (2 * m) for i in range(m)] 24 | n_processed = 0 25 | for sess in tqdm(sorted(os.listdir(root)), total=len(os.listdir(root)), desc='Processing video data...'): 26 | for filename in os.listdir(os.path.join(root, sess)): 27 | 28 | if filename.endswith('.mp4') and filename[:2] == '01': 29 | 30 | cap = cv2.VideoCapture(os.path.join(root, sess, filename)) 31 | # calculate length in frames 32 | framen = 0 33 | while True: 34 | i, q = cap.read() 35 | if not i: 36 | break 37 | framen += 1 38 | cap = cv2.VideoCapture(os.path.join(root, sess, filename)) 39 | 40 | if save_length * input_fps > framen: 41 | skip_begin = int((framen - (save_length * input_fps)) // 2) 42 | for i in range(skip_begin): 43 | _, im = cap.read() 44 | 45 | framen = int(save_length * input_fps) 46 | frames_to_select = select_distributed(save_frames, framen) 47 | save_fps = save_frames // (framen // input_fps) 48 | if save_avi: 49 | out = cv2.VideoWriter(os.path.join(tar_root, sess, filename[:-4] + '_facecroppad.avi'), 50 | cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), save_fps, (224, 224)) 51 | 52 | numpy_video = [] 53 | success = 0 54 | frame_ctr = 0 55 | 56 | while True: 57 | ret, im = cap.read() 58 | if not ret: 59 | break 60 | if frame_ctr not in frames_to_select: 61 | frame_ctr += 1 62 | continue 63 | else: 64 | frames_to_select.remove(frame_ctr) 65 | frame_ctr += 1 66 | 67 | try: 68 | gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) 69 | except: 70 | failed_videos.append((sess, i)) 71 | break 72 | 73 | temp = im[:, :, -1] 74 | im_rgb = im.copy() 75 | im_rgb[:, :, -1] = im_rgb[:, :, 0] 76 | im_rgb[:, :, 0] = temp 77 | im_rgb = torch.tensor(im_rgb) 78 | im_rgb = im_rgb.to(device) 79 | 80 | bbox = mtcnn.detect(im_rgb) 81 | if bbox[0] is not None: 82 | bbox = bbox[0][0] 83 | bbox = [round(x) for x in bbox] 84 | x1, y1, x2, y2 = bbox 85 | im = im[y1:y2, x1:x2, :] 86 | im = cv2.resize(im, 
(224, 224)) 87 | if save_avi: 88 | out.write(im) 89 | numpy_video.append(im) 90 | if len(frames_to_select) > 0: 91 | for i in range(len(frames_to_select)): 92 | if save_avi: 93 | out.write(np.zeros((224, 224, 3), dtype=np.uint8)) 94 | numpy_video.append(np.zeros((224, 224, 3), dtype=np.uint8)) 95 | if save_avi: 96 | out.release() 97 | os.makedirs(os.path.join(tar_root, sess), exist_ok=True) 98 | np.save(os.path.join(tar_root, sess, filename[:-4] + '_facecroppad.npy'), np.array(numpy_video)) 99 | if len(numpy_video) != 15: 100 | print('Error', sess, filename) 101 | 102 | n_processed += 1 103 | with open('processed.txt', 'a') as f: 104 | f.write(sess + '\n') 105 | print(failed_videos) 106 | -------------------------------------------------------------------------------- /data_preprocess/preprocess_VGGsound50K_0.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | 4 | filepath = 'E:/Gray/Database/VGGSound' 5 | df = pd.read_csv('E:/Gray/Database/VGGSound-AVEL50k/vggsound-avel50k.csv') 6 | 7 | VGGS50K_names = df['video_id'] 8 | VGGS50K_clss = df['category'] 9 | 10 | # 提取 'category' 列中的唯一元素 11 | unique_categories = df['category'].unique() 12 | # 将唯一元素转换为字典,键为元素,值为从0开始的类别序号 13 | category_to_index = {category: idx for idx, category in enumerate(unique_categories)} 14 | vid_to_cls = {VGGS50K_names[i]: VGGS50K_clss[i] for i in range(len(VGGS50K_names))} 15 | 16 | file_list = [] 17 | vid_to_path = {} 18 | for root, dirs, files in os.walk(filepath): 19 | for file in files: 20 | vid = file[1:12] 21 | file_list.append(vid) 22 | if not vid in vid_to_path.keys(): 23 | vid_to_path.update({f'{vid}': f'{os.path.join(root, file)}'}) 24 | 25 | error = 0 26 | with open('D:/Gray/Database/VGGS50K/VGGS50K_videos.txt', 'w') as file: 27 | for name in VGGS50K_names: 28 | if name in vid_to_path.keys(): 29 | path = vid_to_path[name].replace('\\', '/') 30 | label = category_to_index[vid_to_cls[name]] 31 | file.write(f'{path}&{label}\n') 32 | else: 33 | error+=1 34 | print(f'Error Video Number: {error}') 35 | 36 | """ 37 | {'engine accelerating': 0, 'playing trumpet': 1, 'race car': 2, 'orchestra': 3, 'lighting firecrackers': 4, 'playing violin': 5, 'playing erhu': 6, 'playing bass guitar': 7, 'playing snare drum': 8, 'cat purring': 9, 'playing harp': 10, 'people sniggering': 11, 'child singing': 12, 'goose honking': 13, 'ice cream truck': 14, 'playing bagpipes': 15, 'electric shaver': 16, 'people booing': 17, 'driving buses': 18, 'train horning': 19, 'police car (siren)': 20, 'wind noise': 21, 'playing clarinet': 22, 'people burping': 23, 'vehicle horn': 24, 'playing cymbal': 25, 'singing bowl': 26, 'playing badminton': 27, 'stream burbling': 28, 'cap gun shooting': 29, 'male singing': 30, 'vacuum cleaner cleaning floors': 31, 'rope skipping': 32, 'arc welding': 33, 'scuba diving': 34, 'playing bassoon': 35, 'people clapping': 36, 'playing harpsichord': 37, 'beat boxing': 38, 'playing double bass': 39, 'railroad car': 40, 'playing cello': 41, 'basketball bounce': 42, 'playing tabla': 43, 'civil defense siren': 44, 'pheasant crowing': 45, 'playing accordion': 46, 'gibbon howling': 47, 'playing drum kit': 48, 'people marching': 49, 'rowboat': 50, 'tractor digging': 51, 'dog barking': 52, 'toilet flushing': 53, 'cricket chirping': 54, 'playing french horn': 55, 'playing acoustic guitar': 56, 'playing banjo': 57, 'playing volleyball': 58, 'car engine knocking': 59, 'female singing': 60, 'playing mandolin': 61, 'bird chirping': 62, 'dog howling': 63, 
'playing squash': 64, 'mynah bird singing': 65, 'machine gun shooting': 66, 'airplane flyby': 67, 'child speech': 68, 'missile launch': 69, 'fireworks banging': 70, 'ambulance siren': 71, 'playing marimba': 72, 'fire truck siren': 73, 'playing cornet': 74, 'pigeon': 75, 'skateboarding': 76, 'chainsawing trees': 77, 'people screaming': 78, 'people crowd': 79, 'skidding': 80, 'playing saxophone': 81, 'playing didgeridoo': 82, 'playing vibraphone': 83, 'playing bongo': 84, 'motorboat': 85, 'subway': 86, 'bowling impact': 87, 'playing piano': 88, 'dog growling': 89, 'lions roaring': 90, 'planing timber': 91, 'skiing': 92, 'lawn mowing': 93, 'playing electric guitar': 94, 'playing sitar': 95, 'lathe spinning': 96, 'playing bass drum': 97, 'typing on typewriter': 98, 'driving motorcycle': 99, 'sharpen knife': 100, 'people cheering': 101, 'ocean burbling': 102, 'church bell ringing': 103, 'singing choir': 104, 'playing electronic organ': 105, 'horse clip-clop': 106, 'people whistling': 107, 'playing glockenspiel': 108, 'people whispering': 109, 'male speech': 110, 'owl hooting': 111, 'frog croaking': 112, 'female speech': 113, 'playing tambourine': 114, 'playing table tennis': 115, 'printer printing': 116, 'roller coaster running': 117, 'crow cawing': 118, 'police radio chatter': 119, 'turkey gobbling': 120, 'tap dancing': 121, 'playing synthesizer': 122, 'helicopter': 123, 'playing hammond organ': 124, 'chicken crowing': 125, 'cattle': 126, 'playing steel guitar': 127, 'woodpecker pecking tree': 128, 'cattle mooing': 129, 'playing trombone': 130, 'playing flute': 131, 'playing ukulele': 132, 'volcano explosion': 133, 'canary calling': 134, 'baby laughter': 135, 'playing harmonica': 136, 'slot machine': 137, 'playing theremin': 138, 'yodelling': 139, 'tapping guitar': 140} 38 | """ 39 | -------------------------------------------------------------------------------- /data_preprocess/preprocess_CrisisMMD_0.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from emoji import demojize 4 | import ftfy 5 | import re 6 | from nltk.tokenize import TweetTokenizer 7 | 8 | 9 | tokenizer = TweetTokenizer() 10 | 11 | 12 | def normalizeToken(token): 13 | lowercased_token = token.lower() 14 | if token.startswith("@"): # 用户名替换 15 | return "@USER" 16 | elif lowercased_token.startswith("https://t.co/"): # 删除特定 URL 17 | return "" 18 | elif lowercased_token.startswith("http") or lowercased_token.startswith("www"): # 其他 URL 替换 19 | return "HTTPURL" 20 | elif len(token) == 1: # 单字符处理(可能是表情符号) 21 | return demojize(token) 22 | else: # 特殊字符替换 23 | token = token.replace("\\'", "'") # 处理反斜杠引号 24 | if token == "’": 25 | return "'" 26 | elif token == "…": 27 | return "..." 
28 | else: 29 | return token 30 | 31 | 32 | def normalizeTweet(tweet): 33 | tweet = tweet.replace("\\'", "'") # 处理 Windows 下的转义问题 34 | tokens = tokenizer.tokenize(tweet.replace("’", "'").replace("…", "...")) 35 | normTweet = " ".join([normalizeToken(token) for token in tokens]) 36 | # 处理 \' 为 ' 37 | normTweet = re.sub(r"\\'", "'", normTweet) 38 | normTweet = ( 39 | normTweet.replace("cannot ", "can not ") 40 | .replace("n't ", " n't ") 41 | .replace("n 't ", " n't ") 42 | .replace("ca n't", "can't") 43 | .replace("ai n't", "ain't") 44 | ) 45 | normTweet = ( 46 | normTweet.replace("'m ", " 'm ") 47 | .replace(r"'re ", " 're ") 48 | .replace(r"'s ", " 's ") 49 | .replace(r"'ll ", " 'll ") 50 | .replace(r"'d ", " 'd ") 51 | .replace(r"'ve ", " 've ") 52 | ) 53 | normTweet = ( 54 | normTweet.replace(" p . m .", " p.m.") 55 | .replace(" p . m ", " p.m ") 56 | .replace(" a . m .", " a.m.") 57 | .replace(" a . m ", " a.m ") 58 | ) 59 | 60 | return " ".join(normTweet.split()) 61 | 62 | 63 | def safe_unicode_decode(text): 64 | """ 仅当字符串包含 Unicode 转义字符时才进行 decode 处理 """ 65 | if isinstance(text, str): 66 | if '\\u' in text or '\\U' in text: # 仅处理包含 \u 或 \U 的文本 67 | try: 68 | return text.encode('utf-8').decode('unicode-escape') 69 | except UnicodeDecodeError: 70 | return text # 解析失败,返回原始文本 71 | else: 72 | return text # 直接返回 73 | return text # 非字符串类型,保持不变 74 | 75 | 76 | def fix_unicode_escapes_1(text): 77 | # 先将字符串转换回 Unicode escape 格式 78 | unicode_escaped_text = text.encode('unicode-escape').decode('utf-8') 79 | # 替换 \uXXXX 形式为 \U000XXXXX 以确保兼容性 80 | return re.sub(r'\\u([0-9a-fA-F]{4,6})', lambda m: '\\U' + m.group(1).zfill(8), unicode_escaped_text) 81 | 82 | 83 | root_ = 'E:/Gray/Database/CrisisMMD_v2.0' 84 | impath = 'E:/Gray/Database/CrisisMMD_v2.0/data_image' 85 | textpath = 'E:/Gray/Database/CrisisMMD_v2.0/crisismmd_datasplit_all' 86 | 87 | ims = [] 88 | for root, dirs, files in os.walk(impath): 89 | for file in files: 90 | if file.endswith('.jpg'): 91 | ims.append(file) 92 | 93 | ano_files = [] 94 | for root, dirs, files in os.walk(textpath): 95 | for file in files: 96 | if file.startswith('task_humanitarian'): 97 | ano_files.append(file) 98 | 99 | # print(len(ano_files)) 100 | 101 | datas = [] 102 | for file in ano_files: 103 | data = pd.read_csv(os.path.join(textpath, file), sep='\t', encoding="utf-8") 104 | datas.append(data) 105 | 106 | ano_data = pd.concat(datas, ignore_index=True) 107 | ano_data["tweet_text"] = ano_data["tweet_text"].apply(lambda x: ftfy.fix_text(x)) 108 | ano_data["tweet_text"] = ano_data["tweet_text"].apply(lambda x: fix_unicode_escapes_1(x)) 109 | ano_data["tweet_text"] = ano_data["tweet_text"].apply(safe_unicode_decode) 110 | ano_data["tweet_text"] = ano_data["tweet_text"].apply(demojize) 111 | ano_data["tweet_text"] = ano_data["tweet_text"].apply(normalizeTweet) 112 | 113 | # print(ano_data.loc[58, 'tweet_text']) 114 | 115 | 116 | # 步骤1:提取指定列(包含tweet_id用于去重) 117 | selected_columns = ["tweet_id", "tweet_text", "image", "label"] 118 | df_selected = ano_data[selected_columns] 119 | 120 | # 步骤2:按tweet_id去重(保留第一条记录) 121 | df_unique = df_selected.drop_duplicates(subset=["tweet_id"], keep="first") 122 | 123 | # 步骤3:移除临时使用的tweet_id列(若需要保留则跳过这步) 124 | df_final = df_unique.drop(columns=["tweet_id"]) 125 | 126 | # 步骤4:保存结果到新CSV(保留列名) 127 | df_final.to_csv(f"{root_}/all_data.csv", index=False, sep='\t') 128 | # print(df_final.loc[119, 'tweet_text']) 129 | 130 | -------------------------------------------------------------------------------- /Dataset.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from utils import get_data, get_dataset, data_preprocessing 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | 7 | 8 | class MultiModalX(torch.utils.data.Dataset): 9 | """ Generic class for a MultiModal Data """ 10 | 11 | def __init__(self, data, data_name, mode=None, for_contrast=False, k=4096): 12 | """ 13 | Args: 14 | data_name: [str] database name) 15 | """ 16 | super(MultiModalX, self).__init__() 17 | self.dataset = data_name 18 | self.data = data[0] 19 | self.data1 = data[1] 20 | self.labels = data[2] 21 | self.mode = mode 22 | self.for_contrast = for_contrast 23 | self.k = k 24 | 25 | if self.for_contrast: 26 | n = len(self.data) 27 | num_cls = get_ClsNum(data_name) 28 | self.cls_positive = [[] for _ in range(num_cls)] 29 | for i in range(n): 30 | if self.dataset == 'VGGSound-50k': 31 | self.cls_positive[self.labels[i][0, -1].long()].append(i) 32 | else: 33 | self.cls_positive[self.labels[i]].append(i) 34 | self.cls_negative = [[] for _ in range(num_cls)] 35 | for i in range(num_cls): 36 | for j in range(num_cls): 37 | if j == i: 38 | continue 39 | self.cls_negative[i].extend(self.cls_positive[j]) 40 | self.cls_positive = np.asarray([np.asarray(self.cls_positive[i]) for i in range(num_cls)], dtype=object) 41 | self.cls_negative = np.asarray([np.asarray(self.cls_negative[i]) for i in range(num_cls)], dtype=object) 42 | 43 | def __len__(self): 44 | return len(self.data) 45 | 46 | def __getitem__(self, idx): 47 | # Load the total data from memory 48 | m1_data, m2_data, gt = self.data[idx], self.data1[idx], self.labels[idx] 49 | # ================================================================================ 50 | m1_data, m2_data = data_preprocessing(m1_data, m2_data, self.mode) 51 | if not self.for_contrast: 52 | return m1_data, m2_data, gt 53 | else: 54 | pos_idx = idx 55 | cls_neg_idx = gt[0, -1].long() if self.dataset == 'VGGSound-50k' else gt.long() 56 | replace = True if self.k > len(self.cls_negative[cls_neg_idx]) else False 57 | neg_idx = np.random.choice(self.cls_negative[cls_neg_idx], self.k, replace=replace) 58 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) 59 | return m1_data, m2_data, gt, idx, sample_idx 60 | 61 | 62 | class SingleModalX(torch.utils.data.Dataset): 63 | """ Generic class for a MultiModal Data """ 64 | 65 | def __init__(self, data, data_name, mode=None): 66 | """ 67 | Args: 68 | dataname: [str] database name) 69 | """ 70 | super(SingleModalX, self).__init__() 71 | self.dataset = data_name 72 | self.labels = data[2] 73 | self.mode = mode 74 | if mode in ['m1', 'm2']: 75 | # print(f'Data loading mode: {mode}') 76 | self.data = data[0] if mode == 'm1' else data[1] 77 | 78 | def __len__(self): 79 | return len(self.data) 80 | 81 | def __getitem__(self, idx): 82 | # Load the sample from memory for VGGSound-50k==================================== 83 | # if self.dataset == 'VGGSound-50k_old': 84 | # fea = np.load(self.data[idx]).astype(np.float32) if self.mode == 'm1' else np.load(self.data[idx]) 85 | # label = self.labels[idx].astype(np.float32) 86 | # data, gt = torch.from_numpy(fea), torch.from_numpy(label) 87 | # else: 88 | # data, gt = self.data[idx], self.labels[idx] 89 | # ================================================================================ 90 | # Load the total data from memory 91 | data, gt = self.data[idx], self.labels[idx] 92 | return data, gt 93 | 94 | 95 | def get_ClsNum(data_name): 96 | if 
data_name == 'AV-MNIST': 97 | return 10 98 | elif data_name == 'NYU-Depth-V2': 99 | return 41 100 | elif data_name == 'RAVDESS': 101 | return 8 102 | elif data_name == 'VGGSound-50k': 103 | return 141 104 | elif data_name == 'CMMD-V2': 105 | return 8 106 | else: 107 | raise ValueError(f'Invalid data name: {data_name}') 108 | 109 | 110 | if __name__ == '__main__': 111 | dataname = 'NYU-Depth-V2' 112 | data = get_data(dataname) 113 | data_train = get_dataset(dataname, data, 'train', 0) 114 | train_dataset = MultiModalX(data_train, dataname, mode='none') 115 | train_loader = DataLoader( 116 | train_dataset, 117 | batch_size=1, 118 | pin_memory=True, 119 | # num_workers=16, 120 | shuffle=True 121 | ) 122 | 123 | t0 = time.time() 124 | for i, datas in enumerate(train_loader): 125 | t1 = time.time() 126 | t = t1 - t0 127 | t0 = t1 128 | print(f'Time cost:{t:.2f}') 129 | # print(datas[0].shape, datas[1].shape, datas[2].shape) 130 | -------------------------------------------------------------------------------- /models/CNNs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.LeNet5 import LeNet5_woClsHead 5 | from torchinfo import summary 6 | 7 | 8 | class ThreeLayerCNN_A(nn.Module): 9 | def __init__(self, cls_num=10): 10 | super(ThreeLayerCNN_A, self).__init__() 11 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) 12 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 13 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 14 | self.fc1 = nn.Linear(128 * 14 * 14, cls_num) 15 | self.conv3_maxpool = nn.Identity() 16 | self.conv3_maxpool_flatten = nn.Identity() 17 | self.hook_names = ['conv1', 'conv2', 'conv3', 'conv3_maxpool', 'conv3_maxpool_flatten', 'fc1'] 18 | 19 | def forward(self, x): 20 | x = F.relu(self.conv1(x)) 21 | x = F.max_pool2d(x, 2) 22 | x = F.relu(self.conv2(x)) 23 | x = F.max_pool2d(x, 2) 24 | x = F.relu(self.conv3(x)) 25 | x = F.max_pool2d(x, 2) 26 | conv3_maxpool = self.conv3_maxpool(x) 27 | x = conv3_maxpool.view(-1, 128 * 14 * 14) 28 | x = self.conv3_maxpool_flatten(x) 29 | x = self.fc1(x) 30 | return x 31 | 32 | 33 | class FiveLayerCNN_A(nn.Module): 34 | def __init__(self): 35 | super(FiveLayerCNN_A, self).__init__() 36 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) 37 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 38 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 39 | self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 40 | self.conv5 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 41 | self.fc = nn.Linear(512 * 3 * 3, 10) # Assuming input size is 112x112 42 | self.hook_names = ['conv1', 'conv3', 'conv5'] 43 | 44 | def forward(self, x): 45 | x = F.relu(self.conv1(x)) 46 | x = F.max_pool2d(x, 2) 47 | x = F.relu(self.conv2(x)) 48 | x = F.max_pool2d(x, 2) 49 | x = F.relu(self.conv3(x)) 50 | x = F.max_pool2d(x, 2) 51 | x = F.relu(self.conv4(x)) 52 | x = F.max_pool2d(x, 2) 53 | x = F.relu(self.conv5(x)) 54 | x = F.max_pool2d(x, 2) 55 | x = x.view(-1, 512 * 3 * 3) # Flatten the tensor 56 | x = self.fc(x) 57 | return x 58 | 59 | 60 | class FiveLayerCNN_A_woClsHead(nn.Module): 61 | def __init__(self): 62 | super(FiveLayerCNN_A_woClsHead, self).__init__() 63 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) 64 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 65 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 66 | self.conv4 = nn.Conv2d(128, 256, kernel_size=3, 
padding=1) 67 | self.conv5 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 68 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 69 | self.hook_names = ['conv1', 'conv3', 'conv5'] 70 | 71 | def forward(self, x): 72 | x = F.relu(self.conv1(x)) 73 | x = F.max_pool2d(x, 2) 74 | x = F.relu(self.conv2(x)) 75 | x = F.max_pool2d(x, 2) 76 | x = F.relu(self.conv3(x)) 77 | x = F.max_pool2d(x, 2) 78 | x = F.relu(self.conv4(x)) 79 | x = F.max_pool2d(x, 2) 80 | x = F.relu(self.conv5(x)) 81 | x = self.avg_pool(x).view(-1, 512) 82 | return x 83 | 84 | 85 | class IntermediateFusionNet(nn.Module): 86 | def __init__(self): 87 | super(IntermediateFusionNet, self).__init__() 88 | self.AudioNet = FiveLayerCNN_A_woClsHead() 89 | self.ImageNet = LeNet5_woClsHead() 90 | self.fc1 = nn.Linear(596, 298) 91 | self.fc2 = nn.Linear(298, 10) 92 | self.hook_names = ['AudioNet.conv1', 'AudioNet.conv3', 'AudioNet.conv5', 93 | 'ImageNet.conv1', 'ImageNet.conv2','ImageNet.fc1', 94 | 'concat_identity', 'fc1', 'fc2'] 95 | self.concat_identity = nn.Identity() 96 | 97 | def forward(self, x1, x2): 98 | x1 = self.ImageNet(x1) 99 | x2 = self.AudioNet(x2) 100 | concat_feature = self.concat_identity(torch.cat((x1, x2), dim=1)) 101 | x = self.fc1(concat_feature) 102 | x = self.fc2(x) 103 | return x 104 | 105 | 106 | if __name__ == '__main__': 107 | from utils import seed_all 108 | seed_all(0) 109 | # 示例输入:28x28的灰度图 112x112的mel频谱图 110 | input_tensor1 = torch.randn(1, 1, 28, 28) 111 | input_tensor2 = torch.randn(1, 1, 112, 112) 112 | 113 | # Tea.-MM L 114 | # model = LateFusionNet() 115 | # summary(model, (input_tensor1.shape, input_tensor2.shape)) 116 | # res = model(input_tensor1, input_tensor2) # [bs, 10] 117 | # print(res) 118 | 119 | # Tea.-MM I 120 | # model = IntermediateFusionNet() 121 | # summary(model, (input_tensor1.shape, input_tensor2.shape)) 122 | # res = model(input_tensor1, input_tensor2) # [bs, 10] 123 | # print(res) 124 | 125 | # Stu.-V 126 | # model = LeNet5() 127 | # summary(model, input_tensor1.shape) 128 | # res = model(input_tensor1) # [bs, 10] 129 | 130 | # Stu.-A 131 | # model = ThreeLayerCNN_A() 132 | # summary(model, input_tensor2.shape) 133 | # res = model(input_tensor2) # [bs, 10] -------------------------------------------------------------------------------- /KD_methods/MGDFR.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def my_permute_new(x, index): 7 | if isinstance(x, list): 8 | y = deepcopy(x) 9 | perm_index = torch.randperm(x[0].shape[0]) 10 | for i in index: 11 | y[0][:, i] = x[0][perm_index, i] 12 | y[1][:, i] = x[1][perm_index, i] 13 | return y 14 | else: 15 | y = x.clone() 16 | perm_index = torch.randperm(x.shape[0]) 17 | for i in index: 18 | y[:, i] = x[perm_index, i] 19 | return y 20 | 21 | 22 | def my_freeze_new(x, index): # in-place modification 23 | if isinstance(x, list): 24 | y = deepcopy(x) 25 | # 计算每个列表元素在指定索引上的均值 26 | tmp_mean_0 = x[0][:, index].mean(dim=0) 27 | tmp_mean_1 = x[1][:, index].mean(dim=0) 28 | # 将列表中的每个元素在指定索引上修改为均值 29 | y[0][:, index] = tmp_mean_0 30 | y[1][:, index] = tmp_mean_1 31 | return y 32 | else: 33 | y = x.clone() 34 | # 计算指定索引的均值 35 | tmp_mean = x[:, index].mean(dim=0) 36 | # 将指定索引上的值设置为均值 37 | y[:, index] = tmp_mean 38 | return y 39 | 40 | 41 | def my_change(x, change_type, index): 42 | if change_type == 'permute': 43 | return my_permute_new(x, index) 44 | elif change_type == 'freeze': 45 | return my_freeze_new(x, index) 46 | else: 47 | raise 
ValueError("Undefined change_type") 48 | 49 | 50 | def hook_fn(name, features, fn, axis): 51 | def hook(module, input, output): 52 | if fn == 'zero': 53 | res = torch.zeros_like(output) 54 | elif fn == 'permute': 55 | res = my_change(output, 'permute', axis) 56 | else: 57 | res = output 58 | features[name] = res 59 | return res 60 | return hook 61 | 62 | 63 | def hooks_builder(model, hook_names, fn=None, axis=None): 64 | features = {} 65 | hooks = [] 66 | for name in hook_names: 67 | submodule = get_submodule(model, name) 68 | hook = submodule.register_forward_hook(hook_fn(name, features, fn, axis)) 69 | hooks.append(hook) 70 | return hooks, features 71 | 72 | 73 | def get_submodule(model, submodule_name): 74 | """递归获取子模块""" 75 | names = submodule_name.split('.') 76 | submodule = model 77 | for name in names: 78 | submodule = submodule._modules[name] 79 | return submodule 80 | 81 | 82 | def hooks_remover(hooks): 83 | for hook in hooks: 84 | hook.remove() 85 | 86 | 87 | def get_MGDFRmodules(args): 88 | feat_name, feat_dim = None, None 89 | if args.database == 'AV-MNIST': 90 | if args.Tmodel == 'CNN-I': 91 | feat_name = ['fc1'] 92 | feat_dim = 298 # [bs, 20] 93 | elif args.Tmodel == 'LeNet5': 94 | feat_name = ['fc1'] 95 | feat_dim = 84 # [bs, 84] 96 | elif args.Tmodel == 'ThreeLayerCNN-A': 97 | feat_name = ['conv3'] 98 | feat_dim = 128 # [bs, 128, 28, 28] 99 | 100 | elif args.database == 'NYU-Depth-V2': 101 | if args.Tmodel == 'FuseNet-I': 102 | feat_name = ['after_fusion_identity'] 103 | feat_dim = 512 # [bs, 512, 15, 20] 104 | elif args.Tmodel == 'FuseNet-RGBbranch': 105 | feat_name = ['CBR5_RGB_ENC'] # [bs, 512, 30, 40] 106 | feat_dim = 512 107 | elif args.Tmodel == 'FuseNet-Dbranch': 108 | feat_name = ['CBR5_DEPTH_ENC'] # [bs, 512, 30, 40] 109 | feat_dim = 512 110 | 111 | 112 | elif args.database == 'RAVDESS': 113 | if args.Tmodel == 'DSCNN-I': 114 | feat_name = ['fc2'] 115 | feat_dim = 160 # [bs, 160] 116 | elif args.Tmodel in ['VisualBranchNet', 'AudioBranchNet']: 117 | feat_name = ['fc2'] 118 | feat_dim = 160 # [bs, 160] 119 | 120 | elif args.database == 'VGGSound-50k': 121 | if args.Tmodel == 'DSCNN-VGGS-I': 122 | feat_name = ['fc2'] 123 | feat_dim = 160 # [bs, 160] 124 | elif args.Tmodel in ['VisualBranchNet-VGGS', 'AudioBranchNet-VGGS']: 125 | feat_name = ['fc2'] 126 | feat_dim = 160 # [bs, 160] 127 | 128 | elif args.database == 'CMMD-V2': 129 | if args.Tmodel == 'MLP-I': 130 | feat_name = ['fc1'] 131 | feat_dim = 512 # [bs, 20] 132 | elif args.Tmodel in ['MLP-Vb', 'MLP-Tb']: 133 | feat_name = ['fc2'] 134 | feat_dim = 512 # [bs, 84] 135 | 136 | else: 137 | raise ValueError("Undefined database") 138 | criterion = DistLoss(args) 139 | salience_vector = torch.zeros(args.repeat_permute, feat_dim) 140 | return feat_name, feat_dim, salience_vector, criterion 141 | 142 | 143 | def get_feat(database, outputs): 144 | if database == 'NYU-Depth-V2': 145 | res = outputs[0][-1] 146 | elif database == 'VGGSound-50k': 147 | res = outputs[1] if isinstance(outputs, tuple) else outputs 148 | else: 149 | res = outputs 150 | return res 151 | 152 | 153 | 154 | class DistLoss(nn.Module): 155 | def __init__(self, args): 156 | super(DistLoss, self).__init__() 157 | self.database_name = args.database 158 | self.loss = nn.MSELoss() 159 | 160 | def forward(self, out_t, out_s): 161 | student_logits = get_feat(self.database_name, out_s) 162 | teacher_logits = get_feat(self.database_name, out_t) 163 | return self.loss(teacher_logits, student_logits) 164 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[codz] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | #poetry.toml 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. 114 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control 115 | #pdm.lock 116 | #pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # pixi 121 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. 122 | #pixi.lock 123 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one 124 | # in the .venv directory. It is recommended not to include this directory in version control. 
125 | .pixi 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .envrc 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | # PyCharm 172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 174 | # and can be added to the global gitignore or merged into this file. For a more nuclear 175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 176 | .idea/ 177 | 178 | # Abstra 179 | # Abstra is an AI-powered process automation framework. 180 | # Ignore directories containing user credentials, local state, and settings. 181 | # Learn more at https://abstra.io/docs 182 | .abstra/ 183 | 184 | # Visual Studio Code 185 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 186 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 187 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 188 | # you could uncomment the following to ignore the entire vscode folder 189 | # .vscode/ 190 | 191 | # Ruff stuff: 192 | .ruff_cache/ 193 | 194 | # PyPI configuration file 195 | .pypirc 196 | 197 | # Cursor 198 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 199 | # exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data 200 | #  refer to https://docs.cursor.com/context/ignore-files 201 | .cursorignore 202 | .cursorindexingignore 203 | 204 | # Marimo 205 | marimo/_static/ 206 | marimo/_lsp/ 207 | __marimo__/ 208 | 209 | # Local configuration files 210 | /checkpoints/ 211 | /logs/ 212 | /metadata/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MST-Distill 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2507.07015-b31b1b.svg)](https://arxiv.org/abs/2507.07015) [![ACM MM](https://img.shields.io/badge/ACM%20MM-2025-007ACC.svg?style=flat-square)](https://doi.org/10.1145/3746027.3755276) 4 | 5 | **Paper**: MST-Distill: Mixture of Specialized Teachers for Cross-Modal Knowledge Distillation 6 | 7 | **Authors**: Hui Li, Pengfei Yang, Juanyang Chen, Le Dong, Yanxin Chen, Quan Wang 8 | 9 | **Published in**: ACM Multimedia 2025 10 | 11 | --- 12 | 13 | ## 🎯 Overview 14 | 15 | MST-Distill addresses the key challenges of path selection and knowledge drift in cross-modal knowledge distillation. It constructs a diverse ensemble of teacher models, performs instance-level dynamic distillation through a routing network, and employs feature masking networks to suppress modality discrepancies, significantly improving the quality of knowledge transfer across modalities. 16 | 17 | ![Overall model architecture of MST-Distill.](ims/Figs_framework.jpg) 18 | 19 | --- 20 | 21 | ## 📋 Requirements 22 | 23 | **Main Dependencies:** 24 | 25 | - Python >= 3.9 26 | - PyTorch >= 2.1 27 | 28 | All experiments are conducted on a server equipped with an Intel Xeon Gold 6248R CPU and an NVIDIA A100 GPU. 29 | 30 | Note that results may still differ across devices and implementations; see [randomness@PyTorch Docs](https://pytorch.org/docs/stable/notes/randomness.html). 31 | 32 | --- 33 | 34 | ## 🚀 Quick Start 35 | 36 | ### 1. Dataset Preparation 37 | 38 | Download and prepare the datasets: 39 | 40 | - **AV-MNIST**: Image-audio digit classification 41 | - **RAVDESS**: Visual-audio emotion recognition 42 | - **VGGSound-50k**: Visual-audio scene classification 43 | - **CrisisMMD-V2**: Image-text humanitarian classification 44 | - **NYU-Depth-V2**: RGB-depth semantic segmentation 45 | 46 | Then, generate the index meta files used for data partitioning by running `indices_gen.py`, or download the meta files consistent with ours from [here](https://drive.google.com/drive/folders/11p7GQ9iazVogsImgPvsJjTWNXTCHYCD3?usp=sharing). 47 | 48 | We also provide some preprocessed data for download ([Google Drive](https://drive.google.com/drive/folders/11p7GQ9iazVogsImgPvsJjTWNXTCHYCD3?usp=sharing) | [Hugging Face](https://huggingface.co/Gray1y/datasets)). Alternatively, you can download the original datasets from their respective papers and process them using the code in the `data_preprocess` directory. 49 | 50 | ### 2. Cross-modal Knowledge Distillation 51 | 52 | Run our method: 53 | 54 | ```bash 55 | # Example for RAVDESS dataset (target modality: visual) 56 | python main-MST-Distill.py --database RAVDESS --batch_size 32 --mode m1 --Tmodel 'DSCNN-I' --Smodel 'VisualBranchNet' --AUXmodel 'AudioBranchNet' 57 | ``` 58 | 59 | Run other methods: 60 | If you want to run other CMKD methods, you first need to obtain the pre-trained teacher models. 61 | 62 | 1.
Teacher model training: 63 | 64 | ```bash 65 | # MM Teacher 66 | python main-T.py --database RAVDESS --batch_size 32 --mode m1 --Tmodel 'DSCNN-I' 67 | 68 | # CM Teacher 69 | python main-S.py --database RAVDESS --batch_size 32 --mode m2 --Smodel 'AudioBranchNet' 70 | ``` 71 | 72 | 2. Run the chosen CMKD method: 73 | 74 | ```bash 75 | # Example 1: KD (MM->m1) 76 | python main-KD.py --database RAVDESS --batch_size 32 --mode m1 --Tmodel 'DSCNN-I' --Smodel 'VisualBranchNet' --ckpt_name 'DSCNN-I_weights_file_path' 77 | 78 | # Example 2: KD (m2->m1) 79 | python main-KD-UU.py --database RAVDESS --batch_size 32 --mode m1 --Tmodel 'AudioBranchNet' --Smodel 'VisualBranchNet' --ckpt_name 'AudioBranchNet_weights_file_path' 80 | ``` 81 | 82 | ### 3. Model Testing 83 | 84 | Run `test-T.py` or `test-S.py` to test the multimodal and unimodal models, respectively. The parameter settings follow the same pattern as described above. 85 | 86 | --- 87 | 88 | ## ⚙️ Configuration 89 | 90 | ### Dataset-Specific Parameters 91 | 92 | | Dataset | Batch Size | Learning Rate | Modality (m1-m2) | Model Name (mm, m1, m2) | 93 | | ------------ | :--------: | :-----------: | :--------------: | ----------------- | 94 | | AV-MNIST | 512 | 1e-4 | Image-Audio | CNN-I, LeNet5, ThreeLayerCNN-A | 95 | | RAVDESS | 32 | 1e-4 | Visual-Audio | DSCNN-I, VisualBranchNet, AudioBranchNet | 96 | | VGGSound-50k | 512 | 1e-4 | Visual-Audio | DSCNN-VGGS-I, VisualBranchNet-VGGS, AudioBranchNet-VGGS | 97 | | CrisisMMD-V2 | 512 | 5e-3 | Image-Text | MLP-I, MLP-Vb, MLP-Tb | 98 | | NYU-Depth-V2 | 6 | 1e-4 | RGB-Depth | FuseNet-I, FuseNet-RGBbranch, FuseNet-Dbranch | 99 | 100 | You can use the `get_Tmodules` and `get_Smodules` functions in `utils.py` to find the network architectures of the corresponding models located in the `models` directory. 101 | 102 | ### Important Notes 103 | 104 | - **Gradient Accumulation**: If hardware limitations prevent using the recommended batch sizes, you can use gradient accumulation to keep the effective batch size consistent. 105 | - **Baseline Methods**: Some comparison methods may require different learning rates according to their original papers. 106 | - **Hyperparameter Tuning**: Since our method already achieves good performance with default settings, we did not further optimize the teacher feature layer selection or MaskNet hyperparameters. You can adjust these as needed for your hardware constraints or performance requirements. 107 | 108 | --- 109 | 110 | ## 📄 Citation 111 | 112 | If you find this work helpful for your research, please consider citing our paper: 113 | 114 | ```bibtex 115 | @inproceedings{10.1145/3746027.3755276, 116 | author = {Li, Hui and Yang, Pengfei and Chen, Juanyang and Dong, Le and Chen, Yanxin and Wang, Quan}, 117 | title = {MST-Distill: Mixture of Specialized Teachers for Cross-Modal Knowledge Distillation}, 118 | year = {2025}, 119 | publisher = {Association for Computing Machinery}, 120 | address = {New York, NY, USA}, 121 | doi = {10.1145/3746027.3755276}, 122 | booktitle = {Proceedings of the 33rd ACM International Conference on Multimedia}, 123 | pages = {1588–1597}, 124 | numpages = {10}, 125 | location = {Dublin, Ireland}, 126 | series = {MM '25} 127 | } 128 | ``` 129 | 130 | We appreciate your interest in our work and welcome any feedback or contributions to improve this research!
🙏 131 | 132 | --- 133 | 134 | ## 📞 Contact 135 | 136 | For any questions or issues, please feel free to open an issue in this repository or reach out to us at [gray1y@stu.xidian.edu.cn](mailto:gray1y@stu.xidian.edu.cn). 137 | -------------------------------------------------------------------------------- /KD_methods/RKD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchinfo import summary 5 | 6 | 7 | def get_RKDmodules(args, model_s, n_data): 8 | if args.database == 'AV-MNIST': 9 | if args.Tmodel == 'CNN-I': 10 | t_dim = 298 11 | if args.Smodel == 'LeNet5': 12 | feat_names = ['fc1', 'fc1'] 13 | s_dim = 84 14 | elif args.Smodel == 'ThreeLayerCNN-A': 15 | feat_names = ['fc1', 'conv3'] 16 | s_dim = 128 * 14 * 14 17 | elif args.Tmodel == 'LeNet5': 18 | t_dim = 84 19 | if args.Smodel == 'ThreeLayerCNN-A': 20 | feat_names = ['fc1', 'conv3'] 21 | s_dim = 128 * 14 * 14 22 | elif args.Tmodel == 'ThreeLayerCNN-A': 23 | t_dim = 128 * 14 * 14 24 | if args.Smodel == 'LeNet5': 25 | feat_names = ['conv3', 'fc1'] 26 | s_dim = 84 27 | 28 | elif args.database == 'RAVDESS': 29 | if args.Tmodel == 'DSCNN-I': 30 | t_dim = 160 31 | if args.Smodel in ['AudioBranchNet', 'VisualBranchNet']: 32 | feat_names = ['fc2', 'fc2'] 33 | s_dim = 160 34 | elif args.Tmodel in ['AudioBranchNet', 'VisualBranchNet']: 35 | t_dim = 160 36 | if args.Smodel in ['AudioBranchNet', 'VisualBranchNet']: 37 | feat_names = ['fc2', 'fc2'] 38 | s_dim = 160 39 | 40 | elif args.database == 'VGGSound-50k': 41 | if args.Tmodel == 'DSCNN-VGGS-I': 42 | t_dim = 160 43 | if args.Smodel in ['VisualBranchNet-VGGS', 'AudioBranchNet-VGGS']: 44 | feat_names = ['fc2', 'fc2'] 45 | s_dim = 160 46 | elif args.Tmodel in ['VisualBranchNet-VGGS', 'AudioBranchNet-VGGS']: 47 | t_dim = 160 48 | if args.Smodel in ['VisualBranchNet-VGGS', 'AudioBranchNet-VGGS']: 49 | feat_names = ['fc2', 'fc2'] 50 | s_dim = 160 51 | 52 | elif args.database == 'CMMD-V2': 53 | if args.Tmodel == 'MLP-I': 54 | t_dim = 256 55 | if args.Smodel in ['MLP-Vb', 'MLP-Tb']: 56 | feat_names = ['fc2', 'fc2'] 57 | s_dim = 512 58 | elif args.Tmodel in ['MLP-Vb', 'MLP-Tb']: 59 | t_dim = 512 60 | if args.Smodel in ['MLP-Vb', 'MLP-Tb']: 61 | feat_names = ['fc2', 'fc2'] 62 | s_dim = 512 63 | 64 | else: 65 | raise ValueError(f"Invalid database name {args.database}.") 66 | criterion = RKDLoss(s_dim, t_dim, 128).cuda(args.cuda_id) 67 | proj_s, proj_t = criterion.embed_s, criterion.embed_t 68 | proj_s.cuda(args.cuda_id) 69 | proj_t.cuda(args.cuda_id) 70 | params = list(model_s.parameters()) + list(proj_s.parameters()) + list(proj_t.parameters()) 71 | optim = torch.optim.Adam(params, lr=args.lr) 72 | return feat_names, criterion, optim 73 | 74 | 75 | class Normalize(nn.Module): 76 | """normalization layer""" 77 | def __init__(self, power=2): 78 | super(Normalize, self).__init__() 79 | self.power = power 80 | 81 | def forward(self, x): 82 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 83 | out = x.div(norm) 84 | return out 85 | 86 | 87 | class Embed(nn.Module): 88 | """Embedding module""" 89 | def __init__(self, dim_in=1024, dim_out=128): 90 | super(Embed, self).__init__() 91 | self.linear = nn.Linear(dim_in, dim_out) 92 | self.l2norm = Normalize(2) 93 | 94 | def forward(self, x): 95 | x = x.view(x.shape[0], -1) 96 | x = self.linear(x) 97 | x = self.l2norm(x) 98 | return x 99 | 100 | 101 | def pdist(e, squared=False, eps=1e-12): 102 | e_square = e.pow(2).sum(dim=1) 103 | prod = e @ e.t() 104 | res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps) 105 | 106 | if not squared: 107 | res = res.sqrt() 108 | 109 | res = res.clone() 110 | res[range(len(e)), range(len(e))] = 0 111 | return res 112 | 113 | 114 | class RKdAngle(nn.Module): 115 | def forward(self, student, teacher): 116 | # N x C 117 | # N x N x C 118 | 119 | with torch.no_grad(): 120 | td = (teacher.unsqueeze(0) - teacher.unsqueeze(1)) 121 | norm_td = F.normalize(td, p=2, dim=2) 122 | t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1) 123 | 124 | sd = (student.unsqueeze(0) - student.unsqueeze(1)) 125 | norm_sd = F.normalize(sd, p=2, dim=2) 126 | s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1) 127 | 128 | loss = F.smooth_l1_loss(s_angle, t_angle, reduction='mean') 129 | return loss 130 | 131 | 132 | class RkdDistance(nn.Module): 133 | def forward(self, student, teacher): 134 | with torch.no_grad(): 135 | t_d = pdist(teacher, squared=False) 136 | mean_td = t_d[t_d>0].mean() 137 | t_d = t_d / mean_td 138 | 139 | d = pdist(student, squared=False) 140 | mean_d = d[d>0].mean() 141 | d = d / mean_d 142 | 143 | loss = F.smooth_l1_loss(d, t_d, reduction='mean') 144 | return loss 145 | 146 | 147 | class RKDLoss(nn.Module): 148 | def __init__(self, s_dim, t_dim, feat_dim): 149 | super(RKDLoss, self).__init__() 150 | self.embed_s = Embed(s_dim, feat_dim) 151 | self.embed_t = Embed(t_dim, feat_dim) 152 | self.dist_loss = RkdDistance() 153 | self.dist_angle = RKdAngle() 154 | 155 | def forward(self, f_s, f_t): 156 | f_s = self.embed_s(f_s) 157 | f_t = self.embed_t(f_t) 158 | loss_d = self.dist_loss(f_s, f_t) 159 | loss_a = self.dist_angle(f_s, f_t) 160 | loss = 1*loss_d + 2*loss_a 161 | return loss 162 | 163 | 164 | def penultimate_feature_extractor(feat_name, features, args): 165 | if args.database == 'AV-MNIST': 166 | penultimate_feature = features[feat_name] 167 | if penultimate_feature.dim() == 4: 168 | penultimate_feature = F.max_pool2d(penultimate_feature, 2).view(-1, 128 * 14 * 14) 169 | elif args.database == 'VGGSound-50k': 170 | if feat_name == 'psp': 171 | penultimate_feature = features[feat_name][0].mean(dim=1) 172 | else: 173 | penultimate_feature = features[feat_name] 174 | else: 175 | penultimate_feature = features[feat_name] 176 | return penultimate_feature 177 | 178 | if __name__ == '__main__': 179 | emb1 = Embed(84, 128) 180 | emb2 = Embed(128 * 14 * 14, 128) 181 | 182 | # summary(emb1, input_size=(1, 84)) 183 | summary(emb2, input_size=(1, 128 * 14 * 14)) 184 | -------------------------------------------------------------------------------- /models/SeqNets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchinfo import summary 5 | 6 | 7 | class LSTM(nn.Module): 8 | def __init__(self, cls_num, input_size): 9 | super().__init__() 10 | self.lstm = nn.LSTM(input_size=input_size, hidden_size=input_size//2, num_layers=2, 
batch_first=True) 11 |         self.fc = nn.Linear(input_size//2, cls_num) 12 | 13 |     def forward(self, x): 14 |         lstm_out, (h_n, c_n) = self.lstm(x) 15 |         # Classify using the output of the last time step 16 |         output = self.fc(lstm_out[:, -1, :]) 17 |         return output 18 | 19 | 20 | class BiLSTM(nn.Module): 21 |     def __init__(self, cls_num, input_size): 22 |         super().__init__() 23 |         self.bilstm = nn.LSTM(input_size=input_size, hidden_size=input_size//2, num_layers=2, batch_first=True, bidirectional=True) 24 |         self.fc = nn.Linear(input_size, cls_num) 25 | 26 |     def forward(self, x): 27 |         lstm_out, (h_n, c_n) = self.bilstm(x) 28 |         # Classify using the output of the last time step 29 |         output = self.fc(lstm_out[:, -1, :]) 30 |         return output 31 | 32 | 33 | class GRU(nn.Module): 34 |     def __init__(self, cls_num, input_size): 35 |         super().__init__() 36 |         self.num_layers = 1 37 |         self.hidden_size = input_size//2 38 |         self.gru = nn.GRU(input_size=input_size, hidden_size=self.hidden_size, 39 |                           num_layers=self.num_layers, bias=True, batch_first=True) 40 |         self.fc = nn.Linear(input_size//2, cls_num) 41 | 42 |     def forward(self, x): 43 |         h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device) 44 |         out, _ = self.gru(x, h0) 45 |         out = out[:, -1, :] 46 |         out = self.fc(out) 47 |         return out 48 | 49 | 50 | class Video3DCNN(nn.Module): 51 |     def __init__(self): 52 |         super(Video3DCNN, self).__init__() 53 |         # 3D convolutional layers 54 |         self.conv1 = nn.Conv3d(in_channels=3, out_channels=32, kernel_size=(3, 3, 3), stride=1, padding=1) 55 |         self.conv2 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=(3, 3, 3), stride=1, padding=1) 56 |         self.conv3 = nn.Conv3d(in_channels=64, out_channels=128, kernel_size=(3, 3, 3), stride=1, padding=1) 57 | 58 |         # Max pooling layer 59 |         self.pool = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)) 60 | 61 |         # Global average pooling in place of a fully connected layer 62 |         self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 63 | 64 |         # Final fully connected layer 65 |         self.fc = nn.Linear(128, 10) 66 | 67 |     def forward(self, x): 68 |         # Input x has shape (batch_size, channels, depth, height, width) 69 |         x = self.pool(F.relu(self.conv1(x)))  # (batch_size, 32, depth//2, height//2, width//2) 70 |         x = self.pool(F.relu(self.conv2(x)))  # (batch_size, 64, depth//4, height//4, width//4) 71 |         x = self.pool(F.relu(self.conv3(x)))  # (batch_size, 128, depth//8, height//8, width//8) 72 | 73 |         # Global average pooling 74 |         x = self.global_pool(x)  # (batch_size, 128, 1, 1, 1) 75 |         x = x.view(x.size(0), -1)  # Flatten to (batch_size, 128) 76 | 77 |         # Fully connected layer 78 |         x = self.fc(x)  # Output shape (batch_size, num_classes) 79 |         return x 80 | 81 | 82 | class AudioBranch(nn.Module): 83 |     def __init__(self): 84 |         super(AudioBranch, self).__init__() 85 |         self.conv1 = nn.Conv1d(in_channels=15, out_channels=32, kernel_size=3, padding=1) 86 |         self.bn1 = nn.BatchNorm1d(32) 87 |         self.conv2 = nn.Conv1d(in_channels=32, out_channels=128, kernel_size=3, padding=1) 88 |         self.bn2 = nn.BatchNorm1d(128) 89 |         self.conv3 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, padding=1) 90 |         self.bn3 = nn.BatchNorm1d(256) 91 |         self.pool = nn.MaxPool1d(kernel_size=3) 92 |         self.hook_names = ['conv1', 'conv2', 'conv3'] 93 | 94 |     def forward(self, x): 95 |         x = self.pool(F.relu(self.bn1(self.conv1(x)))) 96 |         x = self.pool(F.relu(self.bn2(self.conv2(x)))) 97 |         x = self.pool(F.relu(self.bn3(self.conv3(x)))) 98 |         x = x.view(x.size(0), -1)  # Flatten 99 |         return x 100 | 101 | 102 | class VisualBranch(nn.Module): 103 |     def __init__(self): 104 |         super(VisualBranch, self).__init__() 105 |         # Three convolutional layers 106 |         self.conv1 = nn.Conv3d(in_channels=3, out_channels=6, kernel_size=3, padding=1) 107 |         self.bn1 = nn.BatchNorm3d(6)
108 |         self.conv2 = nn.Conv3d(in_channels=6, out_channels=12, kernel_size=3, padding=1) 109 |         self.bn2 = nn.BatchNorm3d(12) 110 |         self.conv3 = nn.Conv3d(in_channels=12, out_channels=24, kernel_size=3, padding=1) 111 |         self.bn3 = nn.BatchNorm3d(24) 112 |         self.pool = nn.MaxPool3d(kernel_size=2)  # pooling layer to shrink the feature maps 113 |         # Global average pooling, output size (batch_size, 320, 1, 1, 1) 114 |         # self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1)) 115 |         self.hook_names = ['conv1', 'conv2', 'conv3'] 116 | 117 |     def forward(self, x): 118 |         x = self.pool(F.relu(self.bn1(self.conv1(x))))  # first conv + pooling block 119 |         x = self.pool(F.relu(self.bn2(self.conv2(x))))  # second conv + pooling block 120 |         x = self.pool(F.relu(self.bn3(self.conv3(x))))  # third conv + pooling block 121 |         # x = self.global_max_pool(x.squeeze(2))  # global pooling to reduce dimensionality 122 |         x = x.view(x.size(0), -1)  # Flatten; for the default 15x224x224 inputs the output size is (batch_size, 18816) 123 |         return x 124 | 125 | 126 | class DualStreamCNN(nn.Module): 127 |     def __init__(self, cls_num): 128 |         super(DualStreamCNN, self).__init__() 129 |         self.audio_branch = AudioBranch() 130 |         self.visual_branch = VisualBranch() 131 | 132 |         # MLP 133 |         self.fc1 = nn.Linear(20096, 320) 134 |         self.fc2 = nn.Linear(320, 160) 135 |         self.fc3 = nn.Linear(160, cls_num) 136 |         self.hook_names = ['visual_branch', 'visual_branch.conv1', 'visual_branch.conv2', 'visual_branch.conv3', 137 |                            'audio_branch', 'audio_branch.conv1', 'audio_branch.conv2', 'audio_branch.conv3', 138 |                            'fc1', 'fc2', 'fc3'] 139 | 140 |     def forward(self, x_visual, x_audio): 141 |         audio_features = self.audio_branch(x_audio)  # 1280 142 |         visual_features = self.visual_branch(x_visual)  # 18816 143 | 144 |         # Concatenate audio and visual features 145 |         combined_features = torch.cat((audio_features, visual_features), dim=1) 146 | 147 |         # MLP for classification 148 |         x = F.relu(self.fc1(combined_features)) 149 |         x = F.relu(self.fc2(x)) 150 |         x = self.fc3(x) 151 |         return x 152 | 153 | 154 | class AudioBranchNet(nn.Module): 155 |     def __init__(self, cls_num): 156 |         super(AudioBranchNet, self).__init__() 157 |         self.audio_branch = AudioBranch() 158 | 159 |         # MLP 160 |         self.fc1 = nn.Linear(1280, 320) 161 |         self.fc2 = nn.Linear(320, 160) 162 |         self.fc3 = nn.Linear(160, cls_num) 163 |         self.hook_names = ['audio_branch', 'fc1', 'fc2', 'fc3'] 164 | 165 |     def forward(self, x_audio): 166 |         audio_features = self.audio_branch(x_audio) 167 | 168 |         # MLP for classification 169 |         x = F.relu(self.fc1(audio_features)) 170 |         x = F.relu(self.fc2(x)) 171 |         x = self.fc3(x) 172 |         return x 173 | 174 | 175 | class VisualBranchNet(nn.Module): 176 |     def __init__(self, cls_num): 177 |         super(VisualBranchNet, self).__init__() 178 |         self.visual_branch = VisualBranch() 179 | 180 |         # MLP 181 |         self.fc1 = nn.Linear(18816, 320) 182 |         self.fc2 = nn.Linear(320, 160) 183 |         self.fc3 = nn.Linear(160, cls_num) 184 |         self.hook_names = ['visual_branch', 'fc1', 'fc2', 'fc3'] 185 | 186 |     def forward(self, x_visual): 187 |         visual_features = self.visual_branch(x_visual) 188 | 189 |         # MLP for classification 190 |         x = F.relu(self.fc1(visual_features)) 191 |         x = F.relu(self.fc2(x)) 192 |         x = self.fc3(x) 193 |         return x 194 | 195 | 196 | if __name__ == '__main__': 197 |     # Assume input shape (batch_size, sequence_length, feature_dim) for the audio tensor 198 |     input_tensor_A = torch.randn(1, 15, 156)  # batch size 1, sequence length 15, feature dim 156 199 |     input_tensor_V = torch.randn(1, 3, 15, 224, 224) 200 | 201 |     # Tea.-MM DualStream-I 202 |     # model = DualStreamCNN(cls_num=8) 203 |     # summary(model, (input_tensor_A.shape, input_tensor_V.shape)) 204 |     # out = model(input_tensor_V, input_tensor_A)  # [bs, 8] 205 | 206 |     #
Stu.-A AudioBranchNet 207 | # model = AudioBranchNet(cls_num=8) 208 | # summary(model, input_tensor_A.shape) 209 | # out = model(input_tensor_A) # [bs, 8] 210 | # print(out.shape) 211 | 212 | # Stu.-V VisualBranchNet 213 | # model = VisualBranchNet(cls_num=8) 214 | # summary(model, input_tensor_V.shape) 215 | # out = model(input_tensor_V) # [bs, 8] 216 | # print(out.shape) 217 | -------------------------------------------------------------------------------- /main-S.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch import nn 4 | from utils import get_data, get_dataset, get_Smodules, seed_all 5 | import copy 6 | import time 7 | import os 8 | from argparse import ArgumentParser 9 | from tensorboardX import SummaryWriter 10 | from tqdm import tqdm 11 | from Dataset import SingleModalX 12 | import sys 13 | 14 | 15 | if __name__ == '__main__': 16 | ''' 17 | Args Setting for CML. 18 | ''' 19 | parser = ArgumentParser(description='CML-S') 20 | parser.add_argument('--database', type=str, default='AV-MNIST', 21 | help="database name must be one of ['NYU-Depth-V2', 'RAVDESS', 'AV-MNIST', 'VGGSound-50k', 'CMMD-V2']") 22 | parser.add_argument('--Smodel', type=str, default='LeNet5', 23 | help='Student model name') 24 | parser.add_argument('--mode', type=str, default='m1', 25 | help='Data mode: m1 or m2') 26 | parser.add_argument('--batch_size', type=int, default=128, 27 | help='batch size for training') 28 | parser.add_argument('--seed', type=int, default=0, 29 | help='Random seed') 30 | parser.add_argument('--num_workers', type=int, default=0, 31 | help='num_workers for DataLoader') 32 | parser.add_argument('--lr', type=float, default=0.0001, 33 | help='learning rate for training') 34 | parser.add_argument('--record', type=bool, default=True, 35 | help='flag whether to record the learning log') 36 | parser.add_argument('--cuda_id', type=int, default=0, 37 | help='cuda id') 38 | parser.add_argument('--epochs', type=int, default=100, 39 | help='epochs for training, default: 100-200') 40 | parser.add_argument('--save_model', type=bool, default=True, 41 | help='flag whether to save best model') 42 | parser.add_argument('--test_phase', type=bool, default=False, 43 | help='flag whether to conduct the test phase') 44 | parser.add_argument('--final_test', type=bool, default=True, 45 | help='flag whether to conduct the test phase') 46 | parser.add_argument('--commit', type=str, default='Stu-A', 47 | help='Commit for logs') 48 | args = parser.parse_args() 49 | 50 | seed_all(args.seed) 51 | 52 | 53 | log_dir = f'./logs/stu' 54 | if not os.path.exists(log_dir): 55 | os.makedirs(log_dir) 56 | log_dir = log_dir + f'/{args.database}_{args.lr}_{str(time.time()).split(".")[0]}_{args.Smodel}_{args.commit}' 57 | writer = SummaryWriter(log_dir, write_to_disk=args.record) 58 | 59 | data = get_data(args.database) 60 | data_train = get_dataset(args.database, data, 'train', args.seed) 61 | data_val = get_dataset(args.database, data, 'val', args.seed) 62 | data_test = get_dataset(args.database, data, 'test', args.seed) 63 | 64 | train_dataset = SingleModalX(data_train, args.database, mode=args.mode) 65 | valid_dataset = SingleModalX(data_val, args.database, mode=args.mode) 66 | test_dataset = SingleModalX(data_test, args.database, mode=args.mode) 67 | 68 | train_loader = DataLoader( 69 | train_dataset, 70 | batch_size=args.batch_size, 71 | pin_memory=True, 72 | num_workers=args.num_workers, 73 | shuffle=True 74 | ) 
75 | 76 | valid_loader = DataLoader( 77 | valid_dataset, 78 | pin_memory=True, 79 | batch_size=args.batch_size 80 | ) 81 | 82 | test_loader = DataLoader( 83 | test_dataset, 84 | pin_memory=True, 85 | batch_size=args.batch_size 86 | ) 87 | 88 | # ===========GPU Setting==================== 89 | device = torch.device(f"cuda:{args.cuda_id}") 90 | # ==========Initialization=========== 91 | model_s, optimizer_s, scheduler_s, criterion_s, preprocessing, postprocessing, metric = get_Smodules(args) 92 | 93 | model_s = model_s.to(device) 94 | best_model_state = None 95 | best_val_loss = float('inf') 96 | print('================= Student Model Independent Training =================') 97 | for epoch in range(args.epochs): 98 | start_time1 = time.time() 99 | # train 100 | model_s.train() 101 | tra_LOSS_s = 0 102 | for i, (data, label) in tqdm(enumerate(train_loader), desc="Model Training ...", 103 | total=len(train_loader), dynamic_ncols=True, 104 | disable=False, file=sys.stdout): 105 | # print('Train Iter {}'.format(i)) 106 | data, label = data.to(device), label.to(device) 107 | # outputs_t = model_t(data2) 108 | outputs_s = model_s(data) if preprocessing is None else preprocessing(model_s, data) 109 | loss_s = criterion_s(outputs_s, label) if postprocessing is None else postprocessing(outputs_s, label) 110 | optimizer_s.zero_grad() 111 | loss_s.backward() 112 | optimizer_s.step() 113 | tra_LOSS_s += loss_s.item() 114 | 115 | tra_LOSS_s_avg = tra_LOSS_s / (i + 1) 116 | print(f'Train =====> Epoch {epoch + 1}/{args.epochs}: loss_c = {tra_LOSS_s_avg:.4f}') 117 | writer.add_scalar('train/Loss', tra_LOSS_s_avg, epoch) 118 | 119 | # validation 120 | model_s.eval() 121 | metric.reset() 122 | L_s_val = 0 123 | acc_c = 0 124 | gt_list, pred_list = [], [] 125 | with torch.no_grad(): 126 | for i, (data, label) in tqdm(enumerate(valid_loader), desc="Model Validating ...", 127 | total=len(valid_loader), dynamic_ncols=True, disable=False, file=sys.stdout): 128 | # print('Val Iter {}'.format(i)) 129 | data, label = data.to(device), label.to(device) 130 | outputs_s = model_s(data) if preprocessing is None else preprocessing(model_s, data) 131 | loss_s = criterion_s(outputs_s, label) if postprocessing is None else postprocessing(outputs_s, label) 132 | L_s_val += loss_s.item() 133 | metric.update(outputs_s, label) 134 | L_s_val = L_s_val / (i + 1) 135 | # res = metric.compute() 136 | # print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss_c = {L_s_val:.4f}, OA_c = {res['Accuracy']:.2f}%") 137 | # writer.add_scalar('valid/Loss', L_s_val, epoch) 138 | # writer.add_scalar('valid/Acc', res['Accuracy'], epoch) 139 | 140 | # For NYU-Depth-V2 141 | print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss = {L_s_val:.4f}") 142 | writer.add_scalar('valid/Loss', L_s_val, epoch) 143 | 144 | if L_s_val < best_val_loss: 145 | best_val_loss = L_s_val 146 | best_model_state = copy.deepcopy(model_s.state_dict()) 147 | best_epoch = epoch 148 | 149 | # test 150 | if args.test_phase: 151 | model_s.eval() 152 | metric.reset() 153 | L_s = 0 154 | acc_c = 0 155 | gt_list, pred_list = [], [] 156 | with torch.no_grad(): 157 | for i, (data, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 158 | total=len(test_loader), dynamic_ncols=True, disable=False, file=sys.stdout): 159 | # print('Val Iter {}'.format(i)) 160 | data, label = data.to(device), label.to(device) 161 | outputs_s = model_s(data) if preprocessing is None else preprocessing(model_s, data) 162 | loss_s = criterion_s(outputs_s, label) if postprocessing is 
None else postprocessing(outputs_s, label) 163 | L_s = L_s + loss_s.item() 164 | metric.update(outputs_s, label) 165 | L_s = L_s / (i + 1) 166 | res = metric.compute() 167 | print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss_c = {L_s:.4f}, OA_c = {res['Accuracy']}") 168 | writer.add_scalar('test/Loss', L_s, epoch) 169 | writer.add_scalar('test/Acc', float(res['Accuracy']), epoch) 170 | 171 | # For NYU-Depth-V2 172 | # print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss = {L_s:.4f}") 173 | # writer.add_scalar('test/Loss', L_s, epoch) 174 | 175 | if (epoch + 1) % 10 == 0: 176 | print('\n===============Metrics==================') 177 | for e in res.keys(): 178 | print(e) 179 | print(res[e]) 180 | print('----------------------------') 181 | print('=======================================\n') 182 | 183 | # scheduler_t.step() 184 | 185 | start_time2 = time.time() 186 | time_cost = start_time2 - start_time1 187 | if time_cost > 100: 188 | print(f"Epoch {epoch + 1} time cost: {time_cost / 60:.2f} minutes.\n") 189 | else: 190 | print(f"Epoch {epoch + 1} time cost: {time_cost:.2f} seconds.\n") 191 | 192 | writer.close() 193 | 194 | if args.final_test: 195 | model_s.load_state_dict(best_model_state) 196 | print('================= Final Test for This Model =================') 197 | # test 198 | model_s.eval() 199 | metric.reset() 200 | gt_list, pred_list = [], [] 201 | with torch.no_grad(): 202 | for i, (data, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 203 | total=len(test_loader), dynamic_ncols=True, disable=True, file=sys.stdout): 204 | data, label = data.to(device), label.to(device) 205 | outputs_s = model_s(data) if preprocessing is None else preprocessing(model_s, data) 206 | metric.update(outputs_s, label) 207 | res = metric.compute() 208 | 209 | print('\n===============Metrics==================') 210 | for e in res.keys(): 211 | print(e) 212 | print(res[e]) 213 | print('----------------------------') 214 | print('=======================================\n') 215 | 216 | if args.save_model: 217 | if not os.path.exists('./checkpoints/stu/wo_kd'): 218 | os.makedirs('./checkpoints/stu/wo_kd') 219 | torch.save(best_model_state, f'./checkpoints/stu/wo_kd/{args.database}_{args.Smodel}_seed{args.seed}_{args.mode}_ep{best_epoch+1}-{args.epochs}.pth') -------------------------------------------------------------------------------- /main-T.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import os 4 | from utils import seed_all, get_data, get_dataset, get_Tmodules, optimizers_zero_grad, optimizers_step 5 | import copy 6 | import time 7 | from argparse import ArgumentParser 8 | from tensorboardX import SummaryWriter 9 | from tqdm import tqdm 10 | from Dataset import MultiModalX 11 | import sys 12 | 13 | 14 | if __name__ == '__main__': 15 | ''' 16 | Args Setting for CML. 
17 | ''' 18 | parser = ArgumentParser(description='CML-T') 19 | parser.add_argument('--database', type=str, default='AV-MNIST', 20 | help="database name must be one of " 21 | "['NYU-Depth-V2', 'RAVDESS', 'AV-MNIST', 'VGGSound-50k', 'CMMD-V2']") 22 | parser.add_argument('--Tmodel', type=str, default='CNN-I', 23 | help='Teacher model name') 24 | parser.add_argument('--mode', type=str, default='ORG', 25 | help='Data mode: ORG or m1 or m2') 26 | parser.add_argument('--batch_size', type=int, default=128, 27 | help='batch size for training') 28 | parser.add_argument('--seed', type=int, default=0, 29 | help='Random seed') 30 | parser.add_argument('--num_workers', type=int, default=0, 31 | help='num_workers for DataLoader') 32 | parser.add_argument('--lr', type=float, default=0.0001, 33 | help='learning rate for training') 34 | parser.add_argument('--record', type=bool, default=True, 35 | help='flag whether to record the learning log') 36 | parser.add_argument('--cuda_id', type=int, default=0, 37 | help='cuda id') 38 | parser.add_argument('--epochs', type=int, default=100, 39 | help='epochs for training, default: 100') 40 | parser.add_argument('--save_model', type=bool, default=True, 41 | help='flag whether to save best model') 42 | parser.add_argument('--test_phase', type=bool, default=False, 43 | help='flag whether to conduct the test phase') 44 | parser.add_argument('--final_test', type=bool, default=True, 45 | help='flag whether to conduct the test phase') 46 | parser.add_argument('--commit', type=str, default='MM-T', 47 | help='Commit for logs') 48 | args = parser.parse_args() 49 | 50 | seed_all(args.seed) 51 | 52 | log_dir = f'./logs/tea' 53 | if not os.path.exists(log_dir): 54 | os.makedirs(log_dir) 55 | log_dir = log_dir + f'/{args.database}_{args.lr}_{str(time.time()).split(".")[0]}_{args.commit}' 56 | writer = SummaryWriter(log_dir, write_to_disk=args.record) 57 | 58 | data = get_data(args.database) 59 | data_train = get_dataset(args.database, data, 'train', args.seed) 60 | data_val = get_dataset(args.database, data, 'val', args.seed) 61 | data_test = get_dataset(args.database, data, 'test', args.seed) 62 | 63 | train_dataset = MultiModalX(data_train, args.database, mode=args.mode) 64 | valid_dataset = MultiModalX(data_val, args.database, mode=args.mode) 65 | test_dataset = MultiModalX(data_test, args.database, mode='ORG') 66 | 67 | train_loader = DataLoader( 68 | train_dataset, 69 | batch_size=args.batch_size, 70 | pin_memory=True, 71 | num_workers=args.num_workers, 72 | shuffle=True 73 | ) 74 | 75 | valid_loader = DataLoader( 76 | valid_dataset, 77 | pin_memory=True, 78 | batch_size=args.batch_size 79 | ) 80 | 81 | test_loader = DataLoader( 82 | test_dataset, 83 | pin_memory=True, 84 | batch_size=args.batch_size 85 | ) 86 | 87 | # ===========GPU Setting==================== 88 | device = torch.device(f"cuda:{args.cuda_id}") 89 | # ==========Initialization=========== 90 | model_t, optimizers, scheduler_t, criterion, preprocessing, postprocessing, metric = get_Tmodules(args, device) 91 | 92 | model_t = model_t.to(device) 93 | best_model_state = None 94 | best_val_loss = float('inf') 95 | print('================= Teacher Model Independent Training =================') 96 | for epoch in range(args.epochs): 97 | start_time1 = time.time() 98 | # train 99 | model_t.train() 100 | 101 | tra_LOSS_t = 0 102 | for i, (data, data2, label) in tqdm(enumerate(train_loader), desc="Model Training ...", 103 | total=len(train_loader), dynamic_ncols=True, 104 | disable=False, file=sys.stdout): 105 | # 
print('Train Iter {}'.format(i)) 106 | data, data2, label = data.to(device), data2.to(device), label.to(device) 107 | # outputs_t = model_t(data2) 108 | outputs_t = model_t(data, data2) if preprocessing is None else preprocessing(model_t, data, data2) 109 | loss_t = criterion(outputs_t, label) if postprocessing is None else postprocessing(outputs_t, label) 110 | optimizers_zero_grad(optimizers) 111 | loss_t.backward() 112 | optimizers_step(optimizers, epoch) 113 | tra_LOSS_t += loss_t.item() 114 | 115 | tra_LOSS_c_avg = tra_LOSS_t / (i + 1) 116 | print(f'Train =====> Epoch {epoch + 1}/{args.epochs}: loss = {tra_LOSS_c_avg:.4f}') 117 | writer.add_scalar('train/Loss', tra_LOSS_c_avg, epoch) 118 | 119 | # validation 120 | model_t.eval() 121 | metric.reset() 122 | L_t_val = 0 123 | acc_c = 0 124 | gt_list, pred_list = [], [] 125 | with torch.no_grad(): 126 | for i, (data, data2, label) in tqdm(enumerate(valid_loader), desc="Model Validating ...", 127 | total=len(valid_loader), dynamic_ncols=True, disable=False, file=sys.stdout): 128 | # print('Val Iter {}'.format(i)) 129 | data, data2, label = data.to(device), data2.to(device), label.to(device) 130 | # outputs_t = model_t(data2) 131 | outputs_t = model_t(data, data2) if preprocessing is None else preprocessing(model_t, data, data2) 132 | 133 | loss_t = criterion(outputs_t, label) if postprocessing is None else postprocessing(outputs_t, label) 134 | L_t_val += loss_t.item() 135 | metric.update(outputs_t, label) 136 | L_t_val = L_t_val / (i + 1) 137 | # res = metric.compute() 138 | # print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss_t = {L_t_val:.4f}, OA_t = {res['Accuracy']:.2f}%") 139 | # writer.add_scalar('valid/Loss', L_t_val, epoch) 140 | # writer.add_scalar('valid/Acc', res['Accuracy'], epoch) 141 | 142 | # For NYU-Depth-V2 143 | print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss = {L_t_val:.4f}") 144 | writer.add_scalar('valid/Loss', L_t_val, epoch) 145 | 146 | if L_t_val < best_val_loss: 147 | best_val_loss = L_t_val 148 | best_model_state = copy.deepcopy(model_t.state_dict()) 149 | best_epoch = epoch 150 | 151 | # test 152 | if args.test_phase: 153 | model_t.eval() 154 | metric.reset() 155 | L_t = 0 156 | acc_c = 0 157 | gt_list, pred_list = [], [] 158 | with torch.no_grad(): 159 | for i, (data, data2, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 160 | total=len(test_loader), dynamic_ncols=True, disable=False, file=sys.stdout): 161 | # print('Val Iter {}'.format(i)) 162 | data, data2, label = data.to(device), data2.to(device), label.to(device) 163 | # outputs_t = model_t(data2) 164 | outputs_t = model_t(data, data2) if preprocessing is None else preprocessing(model_t, data, data2) 165 | 166 | loss_t = criterion(outputs_t, label) if postprocessing is None else postprocessing(outputs_t, label) 167 | L_t = L_t + loss_t.item() 168 | metric.update(outputs_t, label) 169 | L_t = L_t / (i + 1) 170 | res = metric.compute() 171 | print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss_t = {L_t:.4f}, OA_t = {res['Accuracy']:.4f}") 172 | writer.add_scalar('test/Loss', L_t, epoch) 173 | writer.add_scalar('test/Acc', float(res['Accuracy']), epoch) 174 | 175 | # # For NYU-Depth-V2 176 | # print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss = {L_t:.4f}") 177 | # writer.add_scalar('test/Loss', L_t, epoch) 178 | 179 | if (epoch + 1) % 10 == 0: 180 | print('\n===============Metrics==================') 181 | for e in res.keys(): 182 | print(e) 183 | print(res[e]) 184 | print('----------------------------') 185 | 
print('=======================================\n') 186 | 187 | # scheduler_t.step() 188 | 189 | start_time2 = time.time() 190 | time_cost = start_time2 - start_time1 191 | if time_cost > 100: 192 | print(f"Epoch {epoch + 1} time cost: {time_cost / 60:.2f} minutes.\n") 193 | else: 194 | print(f"Epoch {epoch + 1} time cost: {time_cost:.2f} seconds.\n") 195 | 196 | writer.close() 197 | 198 | if args.final_test: 199 | model_t.load_state_dict(best_model_state) 200 | print('================= Final Test for This Model =================') 201 | # test 202 | model_t.eval() 203 | metric.reset() 204 | gt_list, pred_list = [], [] 205 | with torch.no_grad(): 206 | for i, (data, data2, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 207 | total=len(test_loader), dynamic_ncols=True, disable=True, file=sys.stdout): 208 | data, data2, label = data.to(device), data2.to(device), label.to(device) 209 | # outputs_t = model_t(data2) 210 | outputs_t = model_t(data, data2) if preprocessing is None else preprocessing(model_t, data, data2) 211 | metric.update(outputs_t, label) 212 | res = metric.compute() 213 | 214 | print('\n===============Metrics==================') 215 | for e in res.keys(): 216 | print(e) 217 | print(res[e]) 218 | print('----------------------------') 219 | print('=======================================\n') 220 | 221 | if args.save_model: 222 | if not os.path.exists('./checkpoints/tea'): 223 | os.makedirs('./checkpoints/tea') 224 | torch.save(best_model_state, f'./checkpoints/tea/{args.database}_{args.Tmodel}_seed{args.seed}_{args.mode}_ep{best_epoch+1}-{args.epochs}.pth') -------------------------------------------------------------------------------- /KD_methods/C2KD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchinfo import summary 5 | 6 | 7 | def get_C2KDmodules(args): 8 | if args.database == 'AV-MNIST': 9 | if args.Tmodel == 'CNN-I': 10 | px_t = ProxyNet1D(num_classes=10, input_dim=596) 11 | if args.Smodel == 'LeNet5': 12 | feat_names = [['concat_identity'], ['fc1']] 13 | px_s = ProxyNet1D(num_classes=10, input_dim=84) 14 | elif args.Smodel == 'ThreeLayerCNN-A': 15 | feat_names = [['concat_identity'], ['conv3_maxpool']] 16 | px_s = ProxyNet(num_classes=10, in_channels=128) 17 | elif args.Tmodel == 'LeNet5': 18 | px_t = ProxyNet1D(num_classes=10, input_dim=84) 19 | feat_names = [['fc1'], ['conv3_maxpool']] 20 | px_s = ProxyNet(num_classes=10, in_channels=128) 21 | elif args.Tmodel == 'ThreeLayerCNN-A': 22 | px_t = ProxyNet(num_classes=10, in_channels=128) 23 | feat_names = [['conv3_maxpool'], ['fc1']] 24 | px_s = ProxyNet1D(num_classes=10, input_dim=84) 25 | 26 | elif args.database == 'NYU-Depth-V2': 27 | if args.Tmodel == 'FuseNet-I': 28 | px_t = ProxyNet_resoFixed(num_classes=41, in_channels=64, up_sample=(480, 640)) 29 | if args.Smodel == 'FuseNet-RGBbranch': 30 | feat_names = [['CBR2_RGBD_DEC'], ['CBR2_RGB_DEC']] 31 | px_s = ProxyNet_resoFixed(num_classes=41, in_channels=64, up_sample=(480, 640)) 32 | elif args.Smodel == 'FuseNet-Dbranch': 33 | feat_names = [['CBR2_RGBD_DEC'], ['CBR2_D_DEC']] 34 | px_s = ProxyNet_resoFixed(num_classes=41, in_channels=64, up_sample=(480, 640)) 35 | elif args.Tmodel == 'FuseNet-RGBbranch': 36 | px_t = ProxyNet_resoFixed(num_classes=41, in_channels=64, up_sample=(480, 640)) 37 | feat_names = [['CBR2_RGB_DEC'], ['CBR2_D_DEC']] 38 | px_s = ProxyNet_resoFixed(num_classes=41, in_channels=64, up_sample=(480, 640)) 
39 | elif args.Tmodel == 'FuseNet-Dbranch': 40 | px_t = ProxyNet_resoFixed(num_classes=41, in_channels=64, up_sample=(480, 640)) 41 | feat_names = [['CBR2_D_DEC'], ['CBR2_RGB_DEC']] 42 | px_s = ProxyNet_resoFixed(num_classes=41, in_channels=64, up_sample=(480, 640)) 43 | 44 | elif args.database == 'RAVDESS': 45 | if args.Tmodel == 'DSCNN-I': 46 | px_t = ProxyNet1D(num_classes=8, input_dim=320) 47 | if args.Smodel in ['AudioBranchNet', 'VisualBranchNet']: 48 | feat_names = [['fc1'], ['fc1']] 49 | px_s = ProxyNet1D(num_classes=8, input_dim=320) 50 | elif args.Tmodel in ['AudioBranchNet', 'VisualBranchNet']: 51 | px_t = ProxyNet1D(num_classes=8, input_dim=320) 52 | if args.Smodel in ['AudioBranchNet', 'VisualBranchNet']: 53 | feat_names = [['fc1'], ['fc1']] 54 | px_s = ProxyNet1D(num_classes=8, input_dim=320) 55 | 56 | elif args.database == 'VGGSound-50k': 57 | if args.Tmodel == 'DSCNN-VGGS-I': 58 | px_t = ProxyNet1D(num_classes=141, input_dim=320) 59 | if args.Smodel in ['VisualBranchNet-VGGS', 'AudioBranchNet-VGGS']: 60 | feat_names = [['fc1'], ['fc1']] 61 | px_s = ProxyNet1D(num_classes=141, input_dim=320) 62 | elif args.Tmodel in ['VisualBranchNet-VGGS', 'AudioBranchNet-VGGS']: 63 | px_t = ProxyNet1D(num_classes=141, input_dim=320) 64 | if args.Smodel in ['VisualBranchNet-VGGS', 'AudioBranchNet-VGGS']: 65 | feat_names = [['fc1'], ['fc1']] 66 | px_s = ProxyNet1D(num_classes=141, input_dim=320) 67 | 68 | elif args.database == 'CMMD-V2': 69 | if args.Tmodel == 'MLP-I': 70 | px_t = ProxyNet1D(num_classes=8, input_dim=256) 71 | if args.Smodel in ['MLP-Vb', 'MLP-Tb']: 72 | feat_names = [['fc2'], ['fc2']] 73 | px_s = ProxyNet1D(num_classes=8, input_dim=512) 74 | elif args.Tmodel in ['MLP-Vb', 'MLP-Tb']: 75 | px_t = ProxyNet1D(num_classes=8, input_dim=512) 76 | if args.Smodel in ['MLP-Vb', 'MLP-Tb']: 77 | feat_names = [['fc2'], ['fc2']] 78 | px_s = ProxyNet1D(num_classes=8, input_dim=512) 79 | 80 | else: 81 | raise ValueError(f"Invalid database name {args.database}.") 82 | 83 | KL_batchmean = torch.nn.KLDivLoss(reduction='batchmean') 84 | KL_none = torch.nn.KLDivLoss(reduction='none') 85 | return feat_names, px_t, px_s, KL_batchmean, KL_none 86 | 87 | 88 | class ProxyNet(nn.Module): 89 | """Proxy network for C$^2$KD, serving as either a teacher or student model""" 90 | 91 | def __init__(self, num_classes=28, in_channels=256): 92 | super(ProxyNet, self).__init__() 93 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 94 | self.layer = conv_1x1_bn(in_channels, in_channels) 95 | self.fc = nn.Linear(in_channels, num_classes) 96 | 97 | # Initialize weights 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 101 | if m.bias is not None: 102 | nn.init.constant_(m.bias, 0) 103 | elif isinstance(m, nn.BatchNorm2d): 104 | nn.init.constant_(m.weight, 1) 105 | nn.init.constant_(m.bias, 0) 106 | 107 | def forward(self, x): 108 | x = self.avgpool(x) 109 | x = self.layer(x) 110 | x = x.view(x.size(0), -1) 111 | x = self.fc(x) 112 | return x 113 | 114 | 115 | def conv_1x1_bn(num_input_channels, num_mid_channel): 116 | return nn.Sequential( 117 | conv1x1(num_input_channels, num_mid_channel), 118 | nn.BatchNorm2d(num_mid_channel), 119 | nn.LeakyReLU(0.1, inplace=True), 120 | ) 121 | 122 | 123 | def conv1x1(in_planes, out_planes, stride=1): 124 | """1x1 convolution""" 125 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 126 | 127 | 128 | class ProxyNet1D(nn.Module): 129 | """Proxy network for 
C$^2$KD, adapted to process 1D vector data with shape (batch_size, dim)""" 130 | 131 |     def __init__(self, num_classes=28, input_dim=256): 132 |         super(ProxyNet1D, self).__init__() 133 |         self.fc1 = nn.Linear(input_dim, input_dim)  # Equivalent to conv1x1 layer 134 |         self.bn = nn.BatchNorm1d(input_dim) 135 |         self.activation = nn.LeakyReLU(0.1, inplace=True) 136 |         self.fc2 = nn.Linear(input_dim, num_classes) 137 | 138 |         # Initialize weights 139 |         for m in self.modules(): 140 |             if isinstance(m, nn.Linear): 141 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 142 |                 if m.bias is not None: 143 |                     nn.init.constant_(m.bias, 0) 144 |             elif isinstance(m, nn.BatchNorm1d): 145 |                 nn.init.constant_(m.weight, 1) 146 |                 nn.init.constant_(m.bias, 0) 147 | 148 |     def forward(self, x): 149 |         x = self.fc1(x) 150 |         x = self.bn(x) 151 |         x = self.activation(x) 152 |         x = self.fc2(x) 153 |         return x 154 | 155 | 156 | def conv_1x1_bn_resoFixed(in_channels, out_channels): 157 |     return nn.Sequential( 158 |         nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False), 159 |         nn.BatchNorm2d(out_channels), 160 |         nn.ReLU(inplace=True) 161 |     ) 162 | 163 | 164 | class ProxyNet_resoFixed(nn.Module): 165 |     """Proxy network for C$^2$KD, serving as either a teacher or student model""" 166 | 167 |     def __init__(self, num_classes=28, in_channels=256, up_sample=None): 168 |         super(ProxyNet_resoFixed, self).__init__() 169 |         self.layer = conv_1x1_bn_resoFixed(in_channels, in_channels)  # keep the feature-map resolution 170 |         self.fc = nn.Conv2d(in_channels, num_classes, kernel_size=1)  # 1x1 convolution in place of a fully connected layer 171 |         self.up_sample = up_sample 172 | 173 |         # Initialize weights 174 |         for m in self.modules(): 175 |             if isinstance(m, nn.Conv2d): 176 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 177 |                 if m.bias is not None: 178 |                     nn.init.constant_(m.bias, 0) 179 |             elif isinstance(m, nn.BatchNorm2d): 180 |                 nn.init.constant_(m.weight, 1) 181 |                 nn.init.constant_(m.bias, 0) 182 | 183 |     def forward(self, x): 184 |         x = self.layer(x) 185 |         x = self.fc(x)  # feature-map resolution stays unchanged 186 |         x = F.interpolate(x, size=self.up_sample, mode='bilinear', align_corners=False) if self.up_sample else x 187 |         return x 188 | 189 | 190 | def ntkl(logits_student, logits_teacher, target, mask=None, criterion4=None, temperature=1): 191 | 192 |     gt_mask = _get_gt_mask(logits_student, target) 193 |     logits_teacher = logits_teacher * (~gt_mask) 194 |     pred_teacher_part2 = F.softmax(logits_teacher / temperature, dim=1) 195 |     logits_student = logits_student * (~gt_mask) 196 |     log_pred_student_part2 = F.log_softmax(logits_student / temperature, dim=1) 197 |     if mask.sum() == 0: 198 |         temp = torch.tensor(0) 199 |     else: 200 |         temp = ((mask * (criterion4(log_pred_student_part2, pred_teacher_part2.detach()).sum(1)))).mean() 201 |     return temp 202 | 203 | 204 | def ntkl_ss(logits_student, logits_teacher, target, original_shape, mask=None, criterion3=None, temperature=1): 205 |     gt_mask = _get_gt_mask(logits_student, target) 206 |     logits_teacher = logits_teacher * (~gt_mask) 207 |     pred_teacher_part2 = F.softmax(logits_teacher / temperature, dim=1) 208 |     logits_student = logits_student * (~gt_mask) 209 |     log_pred_student_part2 = F.log_softmax(logits_student / temperature, dim=1) 210 | 211 |     b, h, w = original_shape 212 |     log_pred_student_part2 = log_pred_student_part2.view(b, h * w, -1) 213 |     pred_teacher_part2 = pred_teacher_part2.view(b, h * w, -1) 214 | 215 |     if mask.sum() == 0: 216 |         temp = torch.tensor(0) 217 |     else: 218 |         temp =
criterion3(log_pred_student_part2, pred_teacher_part2.detach())/(h*w) 219 | return temp 220 | 221 | 222 | def _get_gt_mask(logits, target): 223 | if target.dim() ==2: 224 | mask = target.bool() 225 | return mask 226 | else: 227 | target = target.reshape(-1) 228 | mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool() 229 | return mask 230 | 231 | 232 | def get_feats(args, features, feat_name): 233 | if args.database == 'NYU-Depth-V2' and args.Tmodel == 'CEN': 234 | if args.mode == 'm1': 235 | return features[feat_name][0] 236 | elif args.mode == 'm2': 237 | return features[feat_name][1] 238 | else: 239 | return features[feat_name] 240 | 241 | 242 | def get_logits(args, outputs): 243 | if isinstance(outputs, tuple): 244 | if len(outputs) == 4: 245 | return outputs[1] 246 | elif outputs.dim() == 4 and args.database != 'NYU-Depth-V2': 247 | return outputs.view(-1) 248 | else: 249 | return outputs 250 | 251 | 252 | def get_labels(args, label): 253 | if args.database == 'NYU-Depth-V2': 254 | # [bs, 480, 640] 255 | one_hot_label = F.one_hot(label, num_classes=41).permute(0, 3, 1, 2).float() 256 | return one_hot_label 257 | elif args.database == 'VGGSound-50k': 258 | return label[:, 0, -1].long() 259 | else: 260 | return label 261 | 262 | 263 | if __name__ == '__main__': 264 | # input_tensor_1 = torch.randn(2, 256, 28, 14) 265 | # input_tensor_2 = torch.randn(2, 256) 266 | 267 | # Test the ProxyNet model 268 | model1 = ProxyNet(10, 128) 269 | # Test the ProxyNet1D model 270 | model2 = ProxyNet1D(10, 84) 271 | 272 | # res1 = model1(input_tensor_1) 273 | # res2 = model2(input_tensor_2) 274 | 275 | # print(res1.shape) 276 | # print(res2.shape) 277 | 278 | # summary(model1, input_size=(1, 128, 14, 14)) 279 | # summary(model2, input_size=(1, 84)) 280 | 281 | 282 | -------------------------------------------------------------------------------- /KD_methods/OFA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import trunc_normal_ 5 | 6 | 7 | def get_OFAmodules(args, model_s): 8 | feat_names = [[], []] 9 | projectors = [] 10 | if args.database == 'AV-MNIST': 11 | cls_num = 10 12 | if args.Tmodel == 'CNN-I': 13 | if args.Smodel == 'LeNet5': 14 | feat_names = [['fc2'], ['conv1', 'conv2', 'fc1']] 15 | proj_1 = Projector(6, 2*cls_num, cls_num) 16 | projectors.append(proj_1) 17 | proj_2 = Projector(16, 2*cls_num, cls_num) 18 | projectors.append(proj_2) 19 | proj_3 = Projector_1D(84, 2*cls_num, cls_num) 20 | projectors.append(proj_3) 21 | elif args.Smodel == 'ThreeLayerCNN-A': 22 | feat_names = [['fc2'], ['conv1', 'conv2', 'conv3']] 23 | proj_1 = Projector(32, 2*cls_num, cls_num) 24 | projectors.append(proj_1) 25 | proj_2 = Projector(64, 2*cls_num, cls_num) 26 | projectors.append(proj_2) 27 | proj_3 = Projector(128, 2*cls_num, cls_num) 28 | projectors.append(proj_3) 29 | elif args.Tmodel == 'LeNet5': 30 | if args.Smodel == 'ThreeLayerCNN-A': 31 | feat_names = [['fc2'], ['conv1', 'conv2', 'conv3']] 32 | proj_1 = Projector(32, 2 * cls_num, cls_num) 33 | projectors.append(proj_1) 34 | proj_2 = Projector(64, 2 * cls_num, cls_num) 35 | projectors.append(proj_2) 36 | proj_3 = Projector(128, 2 * cls_num, cls_num) 37 | projectors.append(proj_3) 38 | elif args.Tmodel == 'ThreeLayerCNN-A': 39 | if args.Smodel == 'LeNet5': 40 | feat_names = [['fc1'], ['conv1', 'conv2', 'fc1']] 41 | proj_1 = Projector(6, 2 * cls_num, cls_num) 42 | projectors.append(proj_1) 43 | 
proj_2 = Projector(16, 2 * cls_num, cls_num) 44 | projectors.append(proj_2) 45 | proj_3 = Projector_1D(84, 2 * cls_num, cls_num) 46 | projectors.append(proj_3) 47 | 48 | elif args.database == 'RAVDESS': 49 | cls_num = 8 50 | if args.Tmodel == 'DSCNN-I': 51 | if args.Smodel == 'AudioBranchNet': 52 | feat_names = [['fc3'], ['audio_branch', 'fc1', 'fc2']] 53 | proj_1 = Projector_1D(1280, 2*cls_num, cls_num) 54 | projectors.append(proj_1) 55 | proj_2 = Projector_1D(320, 2*cls_num, cls_num) 56 | projectors.append(proj_2 ) 57 | proj_3 = Projector_1D(160, 2*cls_num, cls_num) 58 | projectors.append(proj_3) 59 | elif args.Smodel == 'VisualBranchNet': 60 | feat_names = [['fc3'], ['visual_branch', 'fc1', 'fc2']] 61 | proj_1 = Projector_1D(18816, 2*cls_num, cls_num) 62 | projectors.append(proj_1) 63 | proj_2 = Projector_1D(320, 2*cls_num, cls_num) 64 | projectors.append(proj_2) 65 | proj_3 = Projector_1D(160, 2*cls_num, cls_num) 66 | projectors.append(proj_3) 67 | elif args.Tmodel == 'VisualBranchNet': 68 | if args.Smodel == 'AudioBranchNet': 69 | feat_names = [['fc3'], ['audio_branch', 'fc1', 'fc2']] 70 | proj_1 = Projector_1D(1280, 2 * cls_num, cls_num) 71 | projectors.append(proj_1) 72 | proj_2 = Projector_1D(320, 2 * cls_num, cls_num) 73 | projectors.append(proj_2) 74 | proj_3 = Projector_1D(160, 2 * cls_num, cls_num) 75 | projectors.append(proj_3) 76 | elif args.Tmodel == 'AudioBranchNet': 77 | if args.Smodel == 'VisualBranchNet': 78 | feat_names = [['fc3'], ['visual_branch', 'fc1', 'fc2']] 79 | proj_1 = Projector_1D(18816, 2 * cls_num, cls_num) 80 | projectors.append(proj_1) 81 | proj_2 = Projector_1D(320, 2 * cls_num, cls_num) 82 | projectors.append(proj_2) 83 | proj_3 = Projector_1D(160, 2 * cls_num, cls_num) 84 | projectors.append(proj_3) 85 | 86 | elif args.database == 'VGGSound-50k': 87 | cls_num = 141 88 | if args.Tmodel == 'DSCNN-VGGS-I': 89 | if args.Smodel == 'VisualBranchNet-VGGS': 90 | feat_names = [['fc3'], ['visual_branch', 'fc1', 'fc2']] 91 | proj_1 = Projector_1D(128, 2*cls_num, cls_num) 92 | projectors.append(proj_1) 93 | proj_2 = Projector_1D(320, 2*cls_num, cls_num) 94 | projectors.append(proj_2) 95 | proj_3 = Projector_1D(160, 2*cls_num, cls_num) 96 | projectors.append(proj_3) 97 | elif args.Smodel == 'AudioBranchNet-VGGS': 98 | feat_names = [['fc3'], ['audio_branch', 'fc1', 'fc2']] 99 | proj_1 = Projector_1D(256, 2*cls_num, cls_num) 100 | projectors.append(proj_1) 101 | proj_2 = Projector_1D(320, 2*cls_num, cls_num) 102 | projectors.append(proj_2) 103 | proj_3 = Projector_1D(160, 2*cls_num, cls_num) 104 | projectors.append(proj_3) 105 | elif args.Tmodel == 'VisualBranchNet-VGGS': 106 | if args.Smodel == 'AudioBranchNet-VGGS': 107 | feat_names = [['fc3'], ['audio_branch', 'fc1', 'fc2']] 108 | proj_1 = Projector_1D(256, 2 * cls_num, cls_num) 109 | projectors.append(proj_1) 110 | proj_2 = Projector_1D(320, 2 * cls_num, cls_num) 111 | projectors.append(proj_2) 112 | proj_3 = Projector_1D(160, 2 * cls_num, cls_num) 113 | projectors.append(proj_3) 114 | elif args.Tmodel == 'AudioBranchNet-VGGS': 115 | if args.Smodel == 'VisualBranchNet-VGGS': 116 | feat_names = [['fc3'], ['visual_branch', 'fc1', 'fc2']] 117 | proj_1 = Projector_1D(128, 2 * cls_num, cls_num) 118 | projectors.append(proj_1) 119 | proj_2 = Projector_1D(320, 2 * cls_num, cls_num) 120 | projectors.append(proj_2) 121 | proj_3 = Projector_1D(160, 2 * cls_num, cls_num) 122 | projectors.append(proj_3) 123 | 124 | elif args.database == 'CMMD-V2': 125 | cls_num = 8 126 | if args.Tmodel == 'MLP-I': 127 | if args.Smodel in 
['MLP-Vb', 'MLP-Tb']: 128 | feat_names = [['fc3'], ['fc1', 'fc2']] 129 | proj_1 = Projector_1D(1024, 2 * cls_num, cls_num) 130 | projectors.append(proj_1) 131 | proj_2 = Projector_1D(512, 2 * cls_num, cls_num) 132 | projectors.append(proj_2) 133 | 134 | elif args.Tmodel in ['MLP-Vb', 'MLP-Tb']: 135 | if args.Smodel in ['MLP-Vb', 'MLP-Tb']: 136 | feat_names = [['fc3'], ['fc1', 'fc2']] 137 | proj_1 = Projector_1D(1024, 2 * cls_num, cls_num) 138 | projectors.append(proj_1) 139 | proj_2 = Projector_1D(512, 2 * cls_num, cls_num) 140 | projectors.append(proj_2) 141 | 142 | else: 143 | raise ValueError(f"Invalid database name {args.database}.") 144 | criterion = OFA_Loss(feat_names, args, projectors, cls_num) 145 | params = list(model_s.parameters()) 146 | for proj in projectors: 147 | proj.apply(init_weights) 148 | params += list(proj.parameters()) 149 | optim = torch.optim.Adam(params, lr=args.lr) 150 | return projectors, criterion, optim 151 | 152 | 153 | class SepConv(nn.Module): 154 | def __init__(self, channel_in, channel_out, kernel_size=3, stride=2, padding=1, affine=True): 155 | # depthwise and pointwise convolution, downsample by 2 156 | super(SepConv, self).__init__() 157 | self.op = nn.Sequential( 158 | nn.Conv2d(channel_in, channel_in, kernel_size=kernel_size, stride=stride, padding=padding, 159 | groups=channel_in, bias=False), 160 | nn.Conv2d(channel_in, channel_in, kernel_size=1, padding=0, bias=False), 161 | nn.BatchNorm2d(channel_in, affine=affine), 162 | nn.ReLU(inplace=False), 163 | nn.Conv2d(channel_in, channel_in, kernel_size=kernel_size, stride=1, padding=padding, groups=channel_in, 164 | bias=False), 165 | nn.Conv2d(channel_in, channel_out, kernel_size=1, padding=0, bias=False), 166 | nn.BatchNorm2d(channel_out, affine=affine), 167 | nn.ReLU(inplace=False), 168 | ) 169 | 170 | def forward(self, x): 171 | return self.op(x) 172 | 173 | 174 | class Projector(nn.Module): 175 | def __init__(self, in_channels, out_channels, num_classes): 176 | super(Projector, self).__init__() 177 | down_sample_blks = [SepConv(in_channels, 2*in_channels), 178 | SepConv(2*in_channels, out_channels)] 179 | self.blks = nn.Sequential( 180 | *down_sample_blks, 181 | nn.AdaptiveAvgPool2d(1), 182 | nn.Flatten(), 183 | nn.Linear(out_channels, num_classes) 184 | ) 185 | 186 | def forward(self, x): 187 | return self.blks(x) 188 | 189 | 190 | class Projector_1D(nn.Module): 191 | def __init__(self, in_dim, hidden_dim, num_classes): 192 | super(Projector_1D, self).__init__() 193 | self.blks = nn.Sequential( 194 | nn.Linear(in_dim, hidden_dim), 195 | nn.ReLU(), 196 | nn.Linear(hidden_dim, num_classes) 197 | ) 198 | 199 | def forward(self, x): 200 | return self.blks(x) 201 | 202 | 203 | def init_weights(module): 204 | for n, m in module.named_modules(): 205 | if isinstance(m, nn.Conv2d): 206 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 207 | elif isinstance(m, nn.BatchNorm2d): 208 | nn.init.ones_(m.weight) 209 | nn.init.zeros_(m.bias) 210 | elif isinstance(m, nn.Linear): 211 | trunc_normal_(m.weight, std=.02) 212 | if m.bias is not None: 213 | nn.init.zeros_(m.bias) 214 | 215 | 216 | class OFA_Loss(nn.Module): 217 | def __init__(self, feat_names, args, projectors, cls_num): 218 | super(OFA_Loss, self).__init__() 219 | self.feat_names = feat_names 220 | self.database = args.database 221 | self.mode = args.mode 222 | self.eps = args.ofa_eps 223 | self.temperature = args.ofa_temperature 224 | self.projectors = projectors 225 | self.cls_num = cls_num 226 | 227 | def forward(self, 
teacher_feats, student_feats, labels): 228 | loss = 0 229 | tea_logits = teacher_feats[self.feat_names[0][0]].detach() 230 | N = len(self.feat_names[1]) 231 | labels = labels[:, 0, -1].long() if self.database == 'VGGSound-50k' else labels 232 | target_mask = F.one_hot(labels, self.cls_num) 233 | for i in range(N): 234 | stu_feat = student_feats[self.feat_names[1][i]] 235 | proj = self.projectors[i] 236 | stu_logits = proj(stu_feat) 237 | loss += ofa_loss(stu_logits, tea_logits, target_mask, self.eps, self.temperature) 238 | return loss / N 239 | 240 | 241 | def ofa_loss(logits_student, logits_teacher, target_mask, eps, temperature=1.): 242 | pred_student = F.softmax(logits_student / temperature, dim=1) 243 | pred_teacher = F.softmax(logits_teacher / temperature, dim=1) 244 | prod = (pred_teacher + target_mask) ** eps 245 | loss = torch.sum(- (prod - target_mask) * torch.log(pred_student), dim=-1) 246 | return loss.mean() 247 | 248 | 249 | def projectors_train(projectors): 250 | for projector in projectors: 251 | projector.train() 252 | 253 | 254 | def projectors_eval(projectors): 255 | for projector in projectors: 256 | projector.eval() 257 | 258 | 259 | if __name__ == '__main__': 260 | input_tensor = torch.randn(3, 84) 261 | proj = Projector_1D(84, 20, 10) 262 | output = proj(input_tensor) 263 | print(output.shape) -------------------------------------------------------------------------------- /main-KD-UU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from utils import get_data, get_dataset, get_Tmodules, get_Smodules, seed_all 4 | import copy 5 | import time 6 | from argparse import ArgumentParser 7 | from tensorboardX import SummaryWriter 8 | from tqdm import tqdm 9 | from Dataset import MultiModalX 10 | from KD_methods.KD import distillation_loss 11 | import sys 12 | import os 13 | 14 | 15 | if __name__ == '__main__': 16 | ''' 17 | Args Setting for CML. 
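    Unimodal-to-unimodal response KD: a frozen unimodal teacher (loaded below from
    ./checkpoints/stu/wo_kd) supervises a unimodal student trained on the other
    modality; the task loss and a temperature-scaled KL term are mixed via --alpha,
    alpha is halved every 30 epochs, and the checkpoint with the lowest validation
    task loss is kept.

    A minimal, self-contained sketch of that objective (illustrative shapes and an
    assumed temperature; the actual term is computed by distillation_loss imported
    from KD_methods/KD.py):

        import torch
        import torch.nn.functional as F

        T, alpha = 4.0, 0.5                       # illustrative hyper-parameters
        s_logits = torch.randn(8, 10)             # student logits [batch, classes]
        t_logits = torch.randn(8, 10).detach()    # frozen teacher logits
        labels = torch.randint(0, 10, (8,))

        kd = F.kl_div(F.log_softmax(s_logits / T, dim=1),
                      F.softmax(t_logits / T, dim=1),
                      reduction='batchmean') * T * T
        loss = (1.0 - alpha) * F.cross_entropy(s_logits, labels) + alpha * kd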
18 | ''' 19 | parser = ArgumentParser(description='CML-KD') 20 | parser.add_argument('--database', type=str, default='AV-MNIST', 21 | help="database name must be one of ['NYU-Depth-V2', 'RAVDESS', 'AV-MNIST', 'VGGSound-50k', 'MM-IMDb']") 22 | parser.add_argument('--Tmodel', type=str, default='ThreeLayerCNN-A', 23 | help='Teacher model name') 24 | 25 | parser.add_argument('--Smodel', type=str, default='LeNet5', 26 | help='Student model name') 27 | parser.add_argument('--mode', type=str, default='m1', 28 | help='modality mode: m1 or m2') 29 | parser.add_argument('--ckpt_name', type=str, 30 | default='AV-MNIST_ThreeLayerCNN-A_seed0_m2_ep24-100.pth', 31 | help='The name of the weight to be loaded in ./checkpoints/stu') 32 | parser.add_argument('--seed', type=int, default=0, 33 | help='Random seed') 34 | 35 | parser.add_argument('--num_workers', type=int, default=0, 36 | help='num_workers for DataLoader') 37 | parser.add_argument('--alpha', type=float, default=0.5, 38 | help='weight for loss') 39 | parser.add_argument('--batch_size', type=int, default=512, 40 | help='batch size for training') 41 | parser.add_argument('--lr', type=float, default=0.0001, 42 | help='learning rate for training') 43 | parser.add_argument('--record', type=bool, default=True, 44 | help='flag whether to record the learning log') 45 | parser.add_argument('--cuda_id', type=int, default=0, 46 | help='cuda id') 47 | parser.add_argument('--epochs', type=int, default=100, 48 | help='epochs for training') 49 | parser.add_argument('--save_model', type=bool, default=True, 50 | help='flag whether to save best model') 51 | parser.add_argument('--test_phase', type=bool, default=False, 52 | help='flag whether to conduct the test phase') 53 | parser.add_argument('--commit', type=str, default='UUKD', 54 | help='Commit for logs') 55 | args = parser.parse_args() 56 | 57 | seed_all(args.seed) 58 | 59 | # 保存log 60 | log_dir = f'./logs/kd' 61 | if not os.path.exists(log_dir): 62 | os.makedirs(log_dir) 63 | log_dir = log_dir + f'/{args.database}_{args.lr}_{str(time.time()).split(".")[0]}_{args.commit}' 64 | writer = SummaryWriter(log_dir, write_to_disk=args.record) 65 | 66 | data = get_data(args.database) 67 | data_train = get_dataset(args.database, data, 'train', args.seed) 68 | data_val = get_dataset(args.database, data, 'val', args.seed) 69 | data_test = get_dataset(args.database, data, 'test', args.seed) 70 | 71 | train_dataset = MultiModalX(data_train, args.database, mode=args.mode) 72 | valid_dataset = MultiModalX(data_val, args.database, mode=args.mode) 73 | test_dataset = MultiModalX(data_test, args.database, mode=args.mode) 74 | 75 | train_loader = DataLoader( 76 | train_dataset, 77 | batch_size=args.batch_size, 78 | pin_memory=True, 79 | num_workers=args.num_workers, 80 | shuffle=True 81 | ) 82 | 83 | valid_loader = DataLoader( 84 | valid_dataset, 85 | pin_memory=True, 86 | batch_size=args.batch_size 87 | ) 88 | 89 | test_loader = DataLoader( 90 | test_dataset, 91 | pin_memory=True, 92 | batch_size=args.batch_size 93 | ) 94 | 95 | # ===========GPU Setting==================== 96 | device = torch.device(f"cuda:{args.cuda_id}") 97 | # ==========Initialization=========== 98 | model_t, _, _, _, preprocessing_t, postprocessing_t, _ = get_Tmodules(args, device) 99 | model_s, optimizer_s, scheduler_s, criterion_s, preprocessing_s, postprocessing_s, metric = get_Smodules(args) 100 | 101 | model_t = model_t.to(device) 102 | model_s = model_s.to(device) 103 | model_t.load_state_dict(torch.load(f'checkpoints/stu/wo_kd/{args.ckpt_name}', 
weights_only=True, map_location=f'cuda:{args.cuda_id}')) 104 | model_t.eval() 105 | 106 | best_model_state = None 107 | best_val_loss = float('inf') 108 | # Cloud-only Learning 109 | print('================= CML-KD =================') 110 | for epoch in range(args.epochs): 111 | start_time1 = time.time() 112 | # train 113 | model_s.train() 114 | 115 | tra_LOSS_s, tra_LOSS_task, tra_LOSS_kl = 0, 0 ,0 116 | for i, (data, data2, label) in tqdm(enumerate(train_loader), desc="Model Training ...", 117 | total=len(train_loader), dynamic_ncols=True, 118 | disable=False, file=sys.stdout): 119 | # print('Train Iter {}'.format(i)) 120 | data, data2, label = data.to(device), data2.to(device), label.to(device) 121 | if args.mode == 'm1': 122 | data_t, data_s = data2, data 123 | else: 124 | data_t, data_s = data, data2 125 | outputs_t = model_t(data_t) if preprocessing_t is None else preprocessing_t(model_t, data_t) 126 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 127 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 128 | loss_kl = distillation_loss(args, outputs_s, outputs_t) 129 | loss = (1.0 - args.alpha) * loss_s + args.alpha * loss_kl 130 | optimizer_s.zero_grad() 131 | loss.backward() 132 | optimizer_s.step() 133 | tra_LOSS_s += loss.item() 134 | tra_LOSS_task += loss_s.item() 135 | tra_LOSS_kl += loss_kl.item() 136 | 137 | tra_LOSS_s_avg = tra_LOSS_s / (i + 1) 138 | tra_LOSS_task_avg = tra_LOSS_task / (i + 1) 139 | tra_LOSS_kl_avg = tra_LOSS_kl / (i + 1) 140 | print(f'Train =====> Epoch {epoch + 1}/{args.epochs}: loss = {tra_LOSS_s_avg:.4f} | ' 141 | f'loss_task = {tra_LOSS_task_avg:.4f} | loss_kl = {tra_LOSS_kl_avg:.4f}') 142 | writer.add_scalar('train/Loss', tra_LOSS_s_avg, epoch) 143 | 144 | # validation 145 | model_s.eval() 146 | metric.reset() 147 | L_s_val = 0 148 | acc_c = 0 149 | Loss_s, Loss_kl = 0, 0 150 | gt_list, pred_list = [], [] 151 | with torch.no_grad(): 152 | for i, (data, data2, label) in tqdm(enumerate(valid_loader), desc="Model Validating ...", 153 | total=len(valid_loader), dynamic_ncols=True, disable=False, 154 | file=sys.stdout): 155 | # print('Val Iter {}'.format(i)) 156 | data, data2, label = data.to(device), data2.to(device), label.to(device) 157 | if args.mode == 'm1': 158 | data_t, data_s = data2, data 159 | else: 160 | data_t, data_s = data, data2 161 | outputs_t = model_t(data_t) if preprocessing_t is None else preprocessing_t(model_t, data_t) 162 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 163 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 164 | loss_kl = distillation_loss(args, outputs_s, outputs_t) 165 | loss = loss_s 166 | L_s_val += loss.item() 167 | Loss_kl += loss_kl.item() 168 | Loss_s += loss_s.item() 169 | metric.update(outputs_s, label) 170 | L_s_val = L_s_val / (i + 1) 171 | Loss_kl_avg = Loss_kl / (i + 1) 172 | Loss_s_avg = Loss_s / (i + 1) 173 | # res = metric.compute() 174 | # print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: OA_s = {float(res['Accuracy'])} | loss = {L_s_val:.4f} | " 175 | # f"Loss_s = {Loss_s_avg:.4f} | Loss_kl = {Loss_kl_avg:.4f}") 176 | # writer.add_scalar('valid/Loss', L_s_val, epoch) 177 | # writer.add_scalar('valid/Acc', float(res['Accuracy']), epoch) 178 | 179 | # For NYU-Depth-V2 180 | print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss = {L_s_val:.4f} | loss_task = {Loss_s_avg:.4f} | loss_kl = 
{Loss_kl_avg:.4f}") 181 | writer.add_scalar('valid/Loss', L_s_val, epoch) 182 | 183 | 184 | if L_s_val < best_val_loss: 185 | best_val_loss = L_s_val 186 | best_model_state = copy.deepcopy(model_s.state_dict()) 187 | best_epoch = epoch 188 | 189 | # test 190 | if args.test_phase: 191 | model_s.eval() 192 | metric.reset() 193 | L_t = 0 194 | acc_c = 0 195 | Loss_s, Loss_kl = 0, 0 196 | gt_list, pred_list = [], [] 197 | with torch.no_grad(): 198 | for i, (data, data2, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 199 | total=len(test_loader), dynamic_ncols=True, disable=False, 200 | file=sys.stdout): 201 | data, data2, label = data.to(device), data2.to(device), label.to(device) 202 | if args.mode == 'm1': 203 | data_t, data_s = data2, data 204 | else: 205 | data_t, data_s = data, data2 206 | outputs_t = model_t(data_t) if preprocessing_t is None else preprocessing_t(model_t, data_t) 207 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 208 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 209 | loss_kl = distillation_loss(args, outputs_s, outputs_t) 210 | loss = (1.0 - args.alpha) * loss_s + args.alpha * loss_kl 211 | L_t = L_t + loss.item() 212 | Loss_kl += loss_kl.item() 213 | Loss_s += loss_s.item() 214 | metric.update(outputs_s, label) 215 | L_t = L_t / (i + 1) 216 | Loss_kl_avg = Loss_kl / (i + 1) 217 | Loss_s_avg = Loss_s / (i + 1) 218 | # res = metric.compute() 219 | # print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: OA_s = {float(res['Accuracy'])} | loss = {L_t:.4f} | " 220 | # f"Loss_s = {Loss_s_avg:.4f} | Loss_kl = {Loss_kl_avg:.4f}") 221 | # writer.add_scalar('test/Loss', L_t, epoch) 222 | # writer.add_scalar('test/Acc', float(res['Accuracy']), epoch) 223 | 224 | # For NYU-Depth-V2 225 | print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss = {L_t:.4f}") 226 | writer.add_scalar('test/Loss', L_t, epoch) 227 | 228 | # if (epoch + 1) % 10 == 0: 229 | # print('\n===============Metrics==================') 230 | # for e in res.keys(): 231 | # print(e) 232 | # print(res[e]) 233 | # print('----------------------------') 234 | # print('=======================================\n') 235 | 236 | # scheduler_t.step() # 学习率衰减(当训练SAFN模型时,需加入监听指标作为参数) 237 | args.alpha *= 0.5 if (epoch + 1) % 30 == 0 else 1.0 238 | start_time2 = time.time() 239 | time_cost = start_time2 - start_time1 240 | if time_cost > 100: 241 | print(f"Epoch {epoch + 1} time cost: {time_cost / 60:.2f} minutes.\n") 242 | else: 243 | print(f"Epoch {epoch + 1} time cost: {time_cost:.2f} seconds.\n") 244 | 245 | writer.close() 246 | 247 | if args.save_model: 248 | if not os.path.exists('./checkpoints/stu/kd'): 249 | os.makedirs('./checkpoints/stu/kd') 250 | names = args.ckpt_name.split('_') 251 | Tmodel_mode = names[3] 252 | torch.save(best_model_state, 253 | f'./checkpoints/stu/kd/{args.database}_{args.Tmodel}_{Tmodel_mode}--{args.Smodel}_seed{args.seed}_{args.mode}_ep{best_epoch + 1}-{args.epochs}.pth') 254 | -------------------------------------------------------------------------------- /main-MLLD-UU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from utils import get_data, get_dataset, get_Tmodules, get_Smodules, seed_all 4 | import copy 5 | import time 6 | from argparse import ArgumentParser 7 | from tensorboardX import SummaryWriter 8 | from tqdm import tqdm 9 | from Dataset import MultiModalX 10 
| from KD_methods.MLLD import MultiLevelLogitDistillation 11 | import sys 12 | import os 13 | 14 | 15 | if __name__ == '__main__': 16 | ''' 17 | Args Setting for CML. 18 | ''' 19 | parser = ArgumentParser(description='CML-MLLD') 20 | parser.add_argument('--database', type=str, default='AV-MNIST', 21 | help="database name must be one of ['NYU-Depth-V2', 'RAVDESS', 'AV-MNIST', 'VGGSound-50k', 'MM-IMDb']") 22 | parser.add_argument('--Tmodel', type=str, default='ThreeLayerCNN-A', 23 | help='Teacher model name') 24 | 25 | parser.add_argument('--Smodel', type=str, default='LeNet5', 26 | help='Student model name') 27 | parser.add_argument('--mode', type=str, default='m1', 28 | help='modality mode: m1 or m2') 29 | parser.add_argument('--ckpt_name', type=str, 30 | default='AV-MNIST_ThreeLayerCNN-A_seed0_m2_ep24-100.pth', 31 | help='The name of the weight to be loaded in ./checkpoints/stu') 32 | parser.add_argument('--seed', type=int, default=0, 33 | help='Random seed') 34 | 35 | parser.add_argument('--num_workers', type=int, default=0, 36 | help='num_workers for DataLoader') 37 | parser.add_argument('--alpha', type=float, default=0.5, 38 | help='weight for loss') 39 | parser.add_argument('--batch_size', type=int, default=512, 40 | help='batch size for training') 41 | parser.add_argument('--lr', type=float, default=0.0001, 42 | help='learning rate for training') 43 | parser.add_argument('--record', type=bool, default=True, 44 | help='flag whether to record the learning log') 45 | parser.add_argument('--cuda_id', type=int, default=0, 46 | help='cuda id') 47 | parser.add_argument('--epochs', type=int, default=100, 48 | help='epochs for training') 49 | parser.add_argument('--save_model', type=bool, default=True, 50 | help='flag whether to save best model') 51 | parser.add_argument('--test_phase', type=bool, default=False, 52 | help='flag whether to conduct the test phase') 53 | parser.add_argument('--commit', type=str, default='UU-MLLD', 54 | help='Commit for logs') 55 | args = parser.parse_args() 56 | 57 | seed_all(args.seed) 58 | 59 | 60 | log_dir = f'./logs/mlld' 61 | if not os.path.exists(log_dir): 62 | os.makedirs(log_dir) 63 | log_dir = log_dir + f'/{args.database}_{args.lr}_{str(time.time()).split(".")[0]}_{args.commit}' 64 | writer = SummaryWriter(log_dir, write_to_disk=args.record) 65 | 66 | data = get_data(args.database) 67 | data_train = get_dataset(args.database, data, 'train', args.seed) 68 | data_val = get_dataset(args.database, data, 'val', args.seed) 69 | data_test = get_dataset(args.database, data, 'test', args.seed) 70 | 71 | train_dataset = MultiModalX(data_train, args.database, mode=args.mode) 72 | valid_dataset = MultiModalX(data_val, args.database, mode=args.mode) 73 | test_dataset = MultiModalX(data_test, args.database, mode=args.mode) 74 | 75 | train_loader = DataLoader( 76 | train_dataset, 77 | batch_size=args.batch_size, 78 | pin_memory=True, 79 | num_workers=args.num_workers, 80 | shuffle=True 81 | ) 82 | 83 | valid_loader = DataLoader( 84 | valid_dataset, 85 | pin_memory=True, 86 | batch_size=args.batch_size 87 | ) 88 | 89 | test_loader = DataLoader( 90 | test_dataset, 91 | pin_memory=True, 92 | batch_size=args.batch_size 93 | ) 94 | 95 | # ===========GPU Setting==================== 96 | device = torch.device(f"cuda:{args.cuda_id}") 97 | # ==========Initialization=========== 98 | model_t, _, _, _, preprocessing_t, postprocessing_t, _ = get_Tmodules(args, device) 99 | model_s, optimizer_s, scheduler_s, criterion_s, preprocessing_s, postprocessing_s, metric = 
get_Smodules(args) 100 | mlld_loss = MultiLevelLogitDistillation() 101 | 102 | model_t = model_t.to(device) 103 | model_s = model_s.to(device) 104 | model_t.load_state_dict(torch.load(f'checkpoints/stu/wo_kd/{args.ckpt_name}', weights_only=True, map_location=f'cuda:{args.cuda_id}')) 105 | model_t.eval() 106 | 107 | best_model_state = None 108 | best_val_loss = float('inf') 109 | # Cloud-only Learning 110 | print('================= CML-KD =================') 111 | for epoch in range(args.epochs): 112 | start_time1 = time.time() 113 | # train 114 | model_s.train() 115 | 116 | tra_LOSS_s, tra_LOSS_task, tra_LOSS_kl = 0, 0 ,0 117 | for i, (data, data2, label) in tqdm(enumerate(train_loader), desc="Model Training ...", 118 | total=len(train_loader), dynamic_ncols=True, 119 | disable=False, file=sys.stdout): 120 | # print('Train Iter {}'.format(i)) 121 | data, data2, label = data.to(device), data2.to(device), label.to(device) 122 | if args.mode == 'm1': 123 | data_t, data_s = data2, data 124 | else: 125 | data_t, data_s = data, data2 126 | outputs_t = model_t(data_t) if preprocessing_t is None else preprocessing_t(model_t, data_t) 127 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 128 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 129 | loss_kl = mlld_loss(outputs_s, outputs_t) 130 | loss = (1.0 - args.alpha) * loss_s + args.alpha * loss_kl 131 | optimizer_s.zero_grad() 132 | loss.backward() 133 | optimizer_s.step() 134 | tra_LOSS_s += loss.item() 135 | tra_LOSS_task += loss_s.item() 136 | tra_LOSS_kl += loss_kl.item() 137 | 138 | tra_LOSS_s_avg = tra_LOSS_s / (i + 1) 139 | tra_LOSS_task_avg = tra_LOSS_task / (i + 1) 140 | tra_LOSS_kl_avg = tra_LOSS_kl / (i + 1) 141 | print(f'Train =====> Epoch {epoch + 1}/{args.epochs}: loss = {tra_LOSS_s_avg:.4f} | ' 142 | f'loss_task = {tra_LOSS_task_avg:.4f} | loss_mlld = {tra_LOSS_kl_avg:.4f}') 143 | writer.add_scalar('train/Loss', tra_LOSS_s_avg, epoch) 144 | 145 | # validation 146 | model_s.eval() 147 | metric.reset() 148 | L_s_val = 0 149 | acc_c = 0 150 | Loss_s, Loss_kl = 0, 0 151 | gt_list, pred_list = [], [] 152 | with torch.no_grad(): 153 | for i, (data, data2, label) in tqdm(enumerate(valid_loader), desc="Model Validating ...", 154 | total=len(valid_loader), dynamic_ncols=True, disable=False, 155 | file=sys.stdout): 156 | # print('Val Iter {}'.format(i)) 157 | data, data2, label = data.to(device), data2.to(device), label.to(device) 158 | if args.mode == 'm1': 159 | data_t, data_s = data2, data 160 | else: 161 | data_t, data_s = data, data2 162 | outputs_t = model_t(data_t) if preprocessing_t is None else preprocessing_t(model_t, data_t) 163 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 164 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 165 | loss_kl = mlld_loss(outputs_s, outputs_t) 166 | loss = loss_s 167 | L_s_val += loss.item() 168 | Loss_kl += loss_kl.item() 169 | Loss_s += loss_s.item() 170 | metric.update(outputs_s, label) 171 | L_s_val = L_s_val / (i + 1) 172 | Loss_kl_avg = Loss_kl / (i + 1) 173 | Loss_s_avg = Loss_s / (i + 1) 174 | # res = metric.compute() 175 | # print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: OA_s = {float(res['Accuracy'])} | loss = {L_s_val:.4f} | " 176 | # f"Loss_s = {Loss_s_avg:.4f} | Loss_kl = {Loss_kl_avg:.4f}") 177 | # writer.add_scalar('valid/Loss', L_s_val, epoch) 178 | # 
writer.add_scalar('valid/Acc', float(res['Accuracy']), epoch) 179 | 180 | # For NYU-Depth-V2 181 | print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss = {L_s_val:.4f} | loss_task = {Loss_s_avg:.4f} | loss_mlld = {Loss_kl_avg:.4f}") 182 | writer.add_scalar('valid/Loss', L_s_val, epoch) 183 | 184 | # 保存验证集上表现最好的模型 185 | if L_s_val < best_val_loss: 186 | best_val_loss = L_s_val 187 | best_model_state = copy.deepcopy(model_s.state_dict()) 188 | best_epoch = epoch 189 | 190 | # test 191 | if args.test_phase: 192 | model_s.eval() 193 | metric.reset() 194 | L_t = 0 195 | acc_c = 0 196 | Loss_s, Loss_kl = 0, 0 197 | gt_list, pred_list = [], [] 198 | with torch.no_grad(): 199 | for i, (data, data2, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 200 | total=len(test_loader), dynamic_ncols=True, disable=False, 201 | file=sys.stdout): 202 | data, data2, label = data.to(device), data2.to(device), label.to(device) 203 | if args.mode == 'm1': 204 | data_t, data_s = data2, data 205 | else: 206 | data_t, data_s = data, data2 207 | outputs_t = model_t(data_t) if preprocessing_t is None else preprocessing_t(model_t, data_t) 208 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 209 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 210 | loss_kl = mlld_loss(outputs_s, outputs_t) 211 | loss = (1.0 - args.alpha) * loss_s + args.alpha * loss_kl 212 | L_t = L_t + loss.item() 213 | Loss_kl += loss_kl.item() 214 | Loss_s += loss_s.item() 215 | metric.update(outputs_s, label) 216 | L_t = L_t / (i + 1) 217 | Loss_kl_avg = Loss_kl / (i + 1) 218 | Loss_s_avg = Loss_s / (i + 1) 219 | # res = metric.compute() 220 | # print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: OA_s = {float(res['Accuracy'])} | loss = {L_t:.4f} | " 221 | # f"Loss_s = {Loss_s_avg:.4f} | Loss_kl = {Loss_kl_avg:.4f}") 222 | # writer.add_scalar('test/Loss', L_t, epoch) 223 | # writer.add_scalar('test/Acc', float(res['Accuracy']), epoch) 224 | 225 | # For NYU-Depth-V2 226 | print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss = {L_t:.4f}") 227 | writer.add_scalar('test/Loss', L_t, epoch) 228 | 229 | # if (epoch + 1) % 10 == 0: 230 | # print('\n===============Metrics==================') 231 | # for e in res.keys(): 232 | # print(e) 233 | # print(res[e]) 234 | # print('----------------------------') 235 | # print('=======================================\n') 236 | 237 | # scheduler_t.step() 238 | args.alpha *= 0.5 if (epoch + 1) % 30 == 0 else 1.0 239 | start_time2 = time.time() 240 | time_cost = start_time2 - start_time1 241 | if time_cost > 100: 242 | print(f"Epoch {epoch + 1} time cost: {time_cost / 60:.2f} minutes.\n") 243 | else: 244 | print(f"Epoch {epoch + 1} time cost: {time_cost:.2f} seconds.\n") 245 | 246 | writer.close() 247 | 248 | if args.save_model: 249 | if not os.path.exists('./checkpoints/stu/mlld'): 250 | os.makedirs('./checkpoints/stu/mlld') 251 | names = args.ckpt_name.split('_') 252 | Tmodel_mode = names[3] 253 | torch.save(best_model_state, 254 | f'./checkpoints/stu/mlld/{args.database}_{args.Tmodel}_{Tmodel_mode}--{args.Smodel}_seed{args.seed}_{args.mode}_ep{best_epoch + 1}-{args.epochs}.pth') 255 | -------------------------------------------------------------------------------- /main-RKD-UU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import os 4 | from utils import get_data, get_dataset, 
get_Smodules, seed_all, get_Tmodules, hooks_builder 5 | import copy 6 | import time 7 | from argparse import ArgumentParser 8 | from tensorboardX import SummaryWriter 9 | from tqdm import tqdm 10 | from Dataset import MultiModalX 11 | from KD_methods.RKD import get_RKDmodules, penultimate_feature_extractor 12 | import sys 13 | 14 | 15 | if __name__ == '__main__': 16 | ''' 17 | Args Setting for CML. 18 | ''' 19 | parser = ArgumentParser(description='CML-RKD') 20 | parser.add_argument('--database', type=str, default='AV-MNIST', 21 | help="database name must be one of " 22 | "['NYU-Depth-V2', 'RAVDESS', 'AV-MNIST', 'VGGSound-50k', 'MM-IMDb']") 23 | parser.add_argument('--Tmodel', type=str, default='ThreeLayerCNN-A', 24 | help='Teacher model name') 25 | parser.add_argument('--Smodel', type=str, default='LeNet5', 26 | help='Student model name') 27 | parser.add_argument('--mode', type=str, default='m1', 28 | help='modality mode: m1 or m2') 29 | 30 | parser.add_argument('--ckpt_name', type=str, 31 | default='', 32 | help='The name of the weight to be loaded in ./checkpoints/stu') 33 | parser.add_argument('--seed', type=int, default=0, 34 | help='Random seed') 35 | parser.add_argument('--num_workers', type=int, default=0, 36 | help='num_workers for DataLoader') 37 | parser.add_argument('--alpha', type=float, default=0.5, 38 | help='weight for loss') 39 | parser.add_argument('--batch_size', type=int, default=32, 40 | help='batch size for training') 41 | parser.add_argument('--lr', type=float, default=0.0001, 42 | help='learning rate for training') 43 | parser.add_argument('--record', type=bool, default=True, 44 | help='flag whether to record the learning log') 45 | parser.add_argument('--cuda_id', type=int, default=0, 46 | help='cuda id') 47 | parser.add_argument('--epochs', type=int, default=100, 48 | help='epochs for training') 49 | parser.add_argument('--save_model', type=bool, default=True, 50 | help='flag whether to save best model') 51 | parser.add_argument('--test_phase', type=bool, default=False, 52 | help='flag whether to conduct the test phase') 53 | parser.add_argument('--commit', type=str, default='RKD-baseline', 54 | help='Commit for logs') 55 | args = parser.parse_args() 56 | 57 | seed_all(args.seed) 58 | 59 | log_dir = f'./logs/rkd' 60 | if not os.path.exists(log_dir): 61 | os.makedirs(log_dir) 62 | log_dir = log_dir + f'/{args.database}_{args.lr}_{str(time.time()).split(".")[0]}_{args.commit}' 63 | writer = SummaryWriter(log_dir, write_to_disk=args.record) 64 | 65 | data = get_data(args.database) 66 | data_train = get_dataset(args.database, data, 'train', args.seed) 67 | data_val = get_dataset(args.database, data, 'val', args.seed) 68 | data_test = get_dataset(args.database, data, 'test', args.seed) 69 | 70 | train_dataset = MultiModalX(data_train, args.database, mode=args.mode) 71 | valid_dataset = MultiModalX(data_val, args.database, mode=args.mode) 72 | test_dataset = MultiModalX(data_test, args.database, mode=args.mode) 73 | 74 | train_loader = DataLoader( 75 | train_dataset, 76 | batch_size=args.batch_size, 77 | pin_memory=True, 78 | num_workers=args.num_workers, 79 | shuffle=True 80 | ) 81 | 82 | valid_loader = DataLoader( 83 | valid_dataset, 84 | pin_memory=True, 85 | batch_size=args.batch_size 86 | ) 87 | 88 | test_loader = DataLoader( 89 | test_dataset, 90 | pin_memory=True, 91 | batch_size=args.batch_size 92 | ) 93 | 94 | # ===========GPU Setting==================== 95 | device = torch.device(f"cuda:{args.cuda_id}") 96 | # ==========Initialization=========== 97 | model_t, 
_, _, _, preprocessing_t, postprocessing_t, _ = get_Tmodules(args, device) 98 | model_s, _, scheduler_s, criterion_s, preprocessing_s, postprocessing_s, metric = get_Smodules(args) 99 | feat_names, criterion_rkd, optimizer_RKD = get_RKDmodules(args, model_s, n_data=len(train_dataset)) 100 | 101 | model_t = model_t.to(device) 102 | model_s = model_s.to(device) 103 | model_t.load_state_dict(torch.load(f'checkpoints/stu/wo_kd/{args.ckpt_name}', map_location=f'cuda:{args.cuda_id}', weights_only=True)) 104 | model_t.eval() 105 | 106 | best_model_state = None 107 | best_val_loss = float('inf') 108 | # Cloud-only Learning 109 | print('================= CMKD-RKD =================') 110 | for epoch in range(args.epochs): 111 | start_time1 = time.time() 112 | # train 113 | model_s.train() 114 | criterion_rkd.embed_t.train() 115 | criterion_rkd.embed_s.train() 116 | tra_LOSS_s, tra_LOSS_task, tra_LOSS_rkd = 0, 0, 0 117 | for i, (data, data2, label) in tqdm(enumerate(train_loader), desc="Model Training ...", 118 | total=len(train_loader), dynamic_ncols=True, 119 | disable=False, file=sys.stdout): 120 | # print('Train Iter {}'.format(i)) 121 | data, data2, label = data.to(device), data2.to(device), label.to(device) 122 | if args.mode == 'm1': 123 | data_t, data_s = data2, data 124 | else: 125 | data_t, data_s = data, data2 126 | 127 | hooks_t, features_t = hooks_builder(model_t, model_t.hook_names) 128 | hooks_s, features_s = hooks_builder(model_s, model_s.hook_names) 129 | 130 | outputs_t = model_t(data_t) if preprocessing_t is None else preprocessing_t(model_t, data_t) 131 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 132 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 133 | f_s = penultimate_feature_extractor(feat_names[1], features_s, args) 134 | f_t = penultimate_feature_extractor(feat_names[0], features_t, args) 135 | f_t = f_t.detach() 136 | loss_rkd = criterion_rkd(f_s, f_t) 137 | loss = (1.0 - args.alpha) * loss_s + args.alpha * loss_rkd 138 | optimizer_RKD.zero_grad() 139 | loss.backward() 140 | optimizer_RKD.step() 141 | tra_LOSS_s += loss.item() 142 | tra_LOSS_task += loss_s.item() 143 | tra_LOSS_rkd += loss_rkd.item() 144 | 145 | tra_LOSS_s_avg = tra_LOSS_s / (i + 1) 146 | tra_LOSS_task_avg = tra_LOSS_task / (i + 1) 147 | tra_LOSS_rkd_avg = tra_LOSS_rkd / (i + 1) 148 | print(f'Train =====> Epoch {epoch + 1}/{args.epochs}: loss = {tra_LOSS_s_avg:.4f} | ' 149 | f'task_loss = {tra_LOSS_task_avg:.4f} | rkd_loss = {tra_LOSS_rkd_avg:.4f}') 150 | writer.add_scalar('train/Loss', tra_LOSS_s_avg, epoch) 151 | 152 | # validation 153 | model_s.eval() 154 | criterion_rkd.embed_t.eval() 155 | criterion_rkd.embed_s.eval() 156 | metric.reset() 157 | L_s_val = 0 158 | acc_c = 0 159 | gt_list, pred_list = [], [] 160 | with torch.no_grad(): 161 | for i, (data, data2, label) in tqdm(enumerate(valid_loader), desc="Model Validating ...", 162 | total=len(valid_loader), dynamic_ncols=True, disable=False, 163 | file=sys.stdout): 164 | # print('Val Iter {}'.format(i)) 165 | data, data2, label = data.to(device), data2.to(device), label.to(device) 166 | if args.mode == 'm1': 167 | data_t, data_s = data2, data 168 | else: 169 | data_t, data_s = data, data2 170 | 171 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 172 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 173 | loss = loss_s 174 | L_s_val += 
loss.item() 175 | metric.update(outputs_s, label) 176 | L_s_val = L_s_val / (i + 1) 177 | # res = metric.compute() 178 | # print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss_CRD = {L_s_val:.4f}, OA_s = {float(res['Accuracy'])}") 179 | # writer.add_scalar('valid/Loss', L_s_val, epoch) 180 | # writer.add_scalar('valid/Acc', float(res['Accuracy']), epoch) 181 | 182 | # For NYU-Depth-V2 183 | print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss_s = {L_s_val:.4f}") 184 | writer.add_scalar('valid/Loss', L_s_val, epoch) 185 | 186 | if L_s_val < best_val_loss: 187 | best_val_loss = L_s_val 188 | best_model_state = copy.deepcopy(model_s.state_dict()) 189 | best_epoch = epoch 190 | 191 | # test 192 | if args.test_phase: 193 | model_s.eval() 194 | criterion_rkd.embed_t.eval() 195 | criterion_rkd.embed_s.eval() 196 | metric.reset() 197 | L_t = 0 198 | acc_c = 0 199 | gt_list, pred_list = [], [] 200 | with torch.no_grad(): 201 | for i, (data, data2, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 202 | total=len(test_loader), dynamic_ncols=True, disable=False, 203 | file=sys.stdout): 204 | data, data2, label = data.to(device), data2.to(device), label.to(device) 205 | if args.mode == 'm1': 206 | data_t, data_s = data2, data 207 | else: 208 | data_t, data_s = data, data2 209 | 210 | hooks_t, features_t = hooks_builder(model_t, model_t.hook_names) 211 | hooks_s, features_s = hooks_builder(model_s, model_s.hook_names) 212 | 213 | outputs_t = model_t(data_t) if preprocessing_t is None else preprocessing_t(model_t, data_t) 214 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 215 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 216 | f_s = penultimate_feature_extractor(feat_names[1], features_s, args) 217 | f_t = penultimate_feature_extractor(feat_names[0], features_t, args) 218 | f_t = f_t.detach() 219 | loss_rkd = criterion_rkd(f_s, f_t) 220 | loss = (1.0 - args.alpha) * loss_s + args.alpha * loss_rkd 221 | L_t = L_t + loss.item() 222 | metric.update(outputs_s, label) 223 | L_t = L_t / (i + 1) 224 | res = metric.compute() 225 | print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss_RKD = {L_t:.4f}, OA_s = {float(res['Accuracy'])}") 226 | writer.add_scalar('test/Loss', L_t, epoch) 227 | writer.add_scalar('test/Acc', float(res['Accuracy']), epoch) 228 | 229 | # # For NYU-Depth-V2 230 | # print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss_CRD = {L_t:.4f}") 231 | # writer.add_scalar('test/Loss', L_t, epoch) 232 | 233 | # if (epoch + 1) % 10 == 0: 234 | # print('\n===============Metrics==================') 235 | # for e in res.keys(): 236 | # print(e) 237 | # print(res[e]) 238 | # print('----------------------------') 239 | # print('=======================================\n') 240 | 241 | # scheduler_t.step() 242 | args.alpha *= 0.5 if (epoch + 1) % 30 == 0 else 1.0 243 | start_time2 = time.time() 244 | time_cost = start_time2 - start_time1 245 | if time_cost > 100: 246 | print(f"Epoch {epoch + 1} time cost: {time_cost / 60:.2f} minutes.\n") 247 | else: 248 | print(f"Epoch {epoch + 1} time cost: {time_cost:.2f} seconds.\n") 249 | 250 | writer.close() 251 | 252 | if args.save_model: 253 | if not os.path.exists('./checkpoints/stu/rkd'): 254 | os.makedirs('./checkpoints/stu/rkd') 255 | names = args.ckpt_name.split('_') 256 | Tmodel_mode = names[3] 257 | torch.save(best_model_state, 258 | 
f'./checkpoints/stu/rkd/{args.database}_{args.Tmodel}_{Tmodel_mode}--{args.Smodel}_seed{args.seed}_{args.mode}_ep{best_epoch + 1}-{args.epochs}.pth') 259 | -------------------------------------------------------------------------------- /KD_methods/CRD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | eps = 1e-7 7 | 8 | 9 | def get_CRDmodules(args, model_s, n_data): 10 | if args.database == 'AV-MNIST': 11 | if args.Tmodel == 'CNN-I': 12 | t_dim = 298 13 | if args.Smodel == 'LeNet5': 14 | feat_names = ['fc1', 'fc1'] 15 | s_dim = 84 16 | elif args.Smodel == 'ThreeLayerCNN-A': 17 | feat_names = ['fc1', 'conv3'] 18 | s_dim = 128 * 14 * 14 19 | elif args.Tmodel == 'LeNet5': 20 | t_dim = 84 21 | if args.Smodel == 'ThreeLayerCNN-A': 22 | feat_names = ['fc1', 'conv3'] 23 | s_dim = 128 * 14 * 14 24 | elif args.Tmodel == 'ThreeLayerCNN-A': 25 | t_dim = 128 * 14 * 14 26 | if args.Smodel == 'LeNet5': 27 | feat_names = ['conv3', 'fc1'] 28 | s_dim = 84 29 | 30 | elif args.database == 'RAVDESS': 31 | if args.Tmodel == 'DSCNN-I': 32 | t_dim = 160 33 | if args.Smodel in ['AudioBranchNet', 'VisualBranchNet']: 34 | feat_names = ['fc2', 'fc2'] 35 | s_dim = 160 36 | elif args.Tmodel in ['AudioBranchNet', 'VisualBranchNet']: 37 | t_dim = 160 38 | if args.Smodel in ['AudioBranchNet', 'VisualBranchNet']: 39 | feat_names = ['fc2', 'fc2'] 40 | s_dim = 160 41 | 42 | elif args.database == 'VGGSound-50k': 43 | if args.Tmodel == 'DSCNN-VGGS-I': 44 | t_dim = 160 45 | if args.Smodel in ['VisualBranchNet-VGGS', 'AudioBranchNet-VGGS']: 46 | feat_names = ['fc2', 'fc2'] 47 | s_dim = 160 48 | elif args.Tmodel in ['VisualBranchNet-VGGS', 'AudioBranchNet-VGGS']: 49 | t_dim = 160 50 | if args.Smodel in ['VisualBranchNet-VGGS', 'AudioBranchNet-VGGS']: 51 | feat_names = ['fc2', 'fc2'] 52 | s_dim = 160 53 | 54 | elif args.database == 'CMMD-V2': 55 | if args.Tmodel == 'MLP-I': 56 | t_dim = 256 57 | if args.Smodel in ['MLP-Vb', 'MLP-Tb']: 58 | feat_names = ['fc2', 'fc2'] 59 | s_dim = 512 60 | elif args.Tmodel in ['MLP-Vb', 'MLP-Tb']: 61 | t_dim = 512 62 | if args.Smodel in ['MLP-Vb', 'MLP-Tb']: 63 | feat_names = ['fc2', 'fc2'] 64 | s_dim = 512 65 | 66 | else: 67 | raise ValueError(f"Invalid database name {args.database}.") 68 | criterion = CRDLoss(s_dim, t_dim, 128, 4096, args.nce_t, 0.5, n_data, args).cuda(args.cuda_id) 69 | proj_s, proj_t = criterion.embed_s, criterion.embed_t 70 | proj_s.cuda(args.cuda_id) 71 | proj_t.cuda(args.cuda_id) 72 | params = list(model_s.parameters()) + list(proj_s.parameters()) + list(proj_t.parameters()) 73 | optim = torch.optim.Adam(params, lr=args.lr) 74 | return feat_names, criterion, optim 75 | 76 | 77 | class CRDLoss(nn.Module): 78 | """CRD Loss function 79 | includes two symmetric parts: 80 | (a) using teacher as anchor, choose positive and negatives over the student side 81 | (b) using student as anchor, choose positive and negatives over the teacher side 82 | 83 | Args: 84 | opt.s_dim: the dimension of student's feature 85 | opt.t_dim: the dimension of teacher's feature 86 | opt.feat_dim: the dimension of the projection space 87 | opt.nce_k: number of negatives paired with each positive 88 | opt.nce_t: the temperature 89 | opt.nce_m: the momentum for updating the memory buffer 90 | opt.n_data: the number of samples in the training set, therefor the memory buffer is: opt.n_data x opt.feat_dim 91 | """ 92 | def __init__(self, s_dim, t_dim, feat_dim, nce_k, 
nce_t, nce_m, n_data, args): 93 | super(CRDLoss, self).__init__() 94 | self.embed_s = Embed(s_dim, feat_dim) 95 | self.embed_t = Embed(t_dim, feat_dim) 96 | self.contrast = ContrastMemory(feat_dim, n_data, nce_k, nce_t, nce_m, args.cuda_id) 97 | self.criterion_t = ContrastLoss(n_data) 98 | self.criterion_s = ContrastLoss(n_data) 99 | 100 | def forward(self, f_s, f_t, idx, contrast_idx=None): 101 | """ 102 | Args: 103 | f_s: the feature of student network, size [batch_size, s_dim] 104 | f_t: the feature of teacher network, size [batch_size, t_dim] 105 | idx: the indices of these positive samples in the dataset, size [batch_size] 106 | contrast_idx: the indices of negative samples, size [batch_size, nce_k] 107 | 108 | Returns: 109 | The contrastive loss 110 | """ 111 | 112 | f_s = self.embed_s(f_s) 113 | f_t = self.embed_t(f_t) 114 | out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx) 115 | s_loss = self.criterion_s(out_s) 116 | t_loss = self.criterion_t(out_t) 117 | loss = s_loss + t_loss 118 | return loss 119 | 120 | 121 | def penultimate_feature_extractor(feat_name, features, args): 122 | if args.database == 'AV-MNIST': 123 | penultimate_feature = features[feat_name] 124 | if penultimate_feature.dim() == 4: 125 | penultimate_feature = F.max_pool2d(penultimate_feature, 2).view(-1, 128 * 14 * 14) 126 | elif args.database == 'VGGSound-50k': 127 | if feat_name == 'psp': 128 | penultimate_feature = features[feat_name][0].mean(dim=1) 129 | else: 130 | penultimate_feature = features[feat_name] 131 | else: 132 | penultimate_feature = features[feat_name] 133 | return penultimate_feature 134 | 135 | 136 | class ContrastLoss(nn.Module): 137 | """ 138 | contrastive loss, corresponding to Eq (18) 139 | """ 140 | def __init__(self, n_data): 141 | super(ContrastLoss, self).__init__() 142 | self.n_data = n_data 143 | 144 | def forward(self, x): 145 | bsz = x.shape[0] 146 | m = x.size(1) - 1 147 | 148 | # noise distribution 149 | Pn = 1 / float(self.n_data) 150 | 151 | # loss for positive pair 152 | P_pos = x.select(1, 0) 153 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_() 154 | 155 | # loss for K negative pair 156 | P_neg = x.narrow(1, 1, m) 157 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_() 158 | 159 | loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz 160 | 161 | return loss 162 | 163 | 164 | class Embed(nn.Module): 165 | """Embedding module""" 166 | def __init__(self, dim_in=1024, dim_out=128): 167 | super(Embed, self).__init__() 168 | self.linear = nn.Linear(dim_in, dim_out) 169 | self.l2norm = Normalize(2) 170 | 171 | def forward(self, x): 172 | x = x.view(x.shape[0], -1) 173 | x = self.linear(x) 174 | x = self.l2norm(x) 175 | return x 176 | 177 | 178 | class Normalize(nn.Module): 179 | """normalization layer""" 180 | def __init__(self, power=2): 181 | super(Normalize, self).__init__() 182 | self.power = power 183 | 184 | def forward(self, x): 185 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 186 | out = x.div(norm) 187 | return out 188 | 189 | 190 | class ContrastMemory(nn.Module): 191 | """ 192 | memory buffer that supplies large amount of negative samples. 
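    Two buffers, memory_v1 and memory_v2 (each of shape [outputSize, inputSize], i.e.
    one row per training sample), cache an embedding of every sample for the two views.
    forward(v1, v2, y, idx) gathers, for each positive indexed by y, K negatives plus
    the positive itself in column 0 (sampling them with the alias method when idx is
    not supplied), scores each pair as exp(<v, m>/T), and divides by a normalisation
    constant Z estimated once from the first batch; finally the memory rows selected
    by y are updated with momentum and re-normalised to unit length.

    Rough usage sketch (shapes only; assumes a CUDA device, since the alias sampler is
    moved to the GPU in __init__):

        mem = ContrastMemory(inputSize=128, outputSize=50000, K=4096)
        out_v1, out_v2 = mem(f_s, f_t, idx)   # f_s, f_t: [bs, 128]; idx: [bs]
        # out_v1, out_v2: [bs, K + 1, 1], with column 0 holding the positive pair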
193 | """ 194 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5, cuda_id=0): 195 | super(ContrastMemory, self).__init__() 196 | self.nLem = outputSize 197 | self.unigrams = torch.ones(self.nLem) 198 | self.multinomial = AliasMethod(self.unigrams, cuda_id) 199 | self.multinomial.cuda() 200 | self.K = K 201 | 202 | self.register_buffer('params', torch.tensor([K, T, -1, -1, momentum])) 203 | stdv = 1. / math.sqrt(inputSize / 3) 204 | self.register_buffer('memory_v1', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 205 | self.register_buffer('memory_v2', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 206 | 207 | def forward(self, v1, v2, y, idx=None): 208 | K = int(self.params[0].item()) 209 | T = self.params[1].item() 210 | Z_v1 = self.params[2].item() 211 | Z_v2 = self.params[3].item() 212 | 213 | momentum = self.params[4].item() 214 | batchSize = v1.size(0) 215 | outputSize = self.memory_v1.size(0) 216 | inputSize = self.memory_v1.size(1) 217 | 218 | # original score computation 219 | if idx is None: 220 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1) 221 | idx.select(1, 0).copy_(y.data) 222 | # sample 223 | weight_v1 = torch.index_select(self.memory_v1, 0, idx.view(-1)).detach() 224 | weight_v1 = weight_v1.view(batchSize, K + 1, inputSize) 225 | out_v2 = torch.bmm(weight_v1, v2.view(batchSize, inputSize, 1)) 226 | out_v2 = torch.exp(torch.div(out_v2, T)) 227 | # sample 228 | weight_v2 = torch.index_select(self.memory_v2, 0, idx.view(-1)).detach() 229 | weight_v2 = weight_v2.view(batchSize, K + 1, inputSize) 230 | out_v1 = torch.bmm(weight_v2, v1.view(batchSize, inputSize, 1)) 231 | out_v1 = torch.exp(torch.div(out_v1, T)) 232 | 233 | # set Z if haven't been set yet 234 | if Z_v1 < 0: 235 | self.params[2] = out_v1.mean() * outputSize 236 | Z_v1 = self.params[2].clone().detach().item() 237 | # print("normalization constant Z_v1 is set to {:.1f}".format(Z_v1)) 238 | if Z_v2 < 0: 239 | self.params[3] = out_v2.mean() * outputSize 240 | Z_v2 = self.params[3].clone().detach().item() 241 | # print("normalization constant Z_v2 is set to {:.1f}".format(Z_v2)) 242 | 243 | # compute out_v1, out_v2 244 | out_v1 = torch.div(out_v1, Z_v1).contiguous() 245 | out_v2 = torch.div(out_v2, Z_v2).contiguous() 246 | 247 | # update memory 248 | with torch.no_grad(): 249 | l_pos = torch.index_select(self.memory_v1, 0, y.view(-1)) 250 | l_pos.mul_(momentum) 251 | l_pos.add_(torch.mul(v1, 1 - momentum)) 252 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5) 253 | updated_v1 = l_pos.div(l_norm) 254 | self.memory_v1.index_copy_(0, y, updated_v1) 255 | 256 | ab_pos = torch.index_select(self.memory_v2, 0, y.view(-1)) 257 | ab_pos.mul_(momentum) 258 | ab_pos.add_(torch.mul(v2, 1 - momentum)) 259 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5) 260 | updated_v2 = ab_pos.div(ab_norm) 261 | self.memory_v2.index_copy_(0, y, updated_v2) 262 | 263 | return out_v1, out_v2 264 | 265 | 266 | class AliasMethod(object): 267 | """ 268 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ 269 | """ 270 | def __init__(self, probs, cuda_id): 271 | 272 | if probs.sum() > 1: 273 | probs.div_(probs.sum()) 274 | K = len(probs) 275 | self.prob = torch.zeros(K) 276 | self.alias = torch.LongTensor([0]*K) 277 | self.cuda_id = cuda_id 278 | 279 | # Sort the data into the outcomes with probabilities 280 | # that are larger and smaller than 1/K. 
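        # Walker's alias construction: the loop below first scales every probability by K and
        # partitions the buckets into `smaller` (scaled mass < 1) and `larger` (>= 1); the
        # while-loop then pairs one under-full bucket with one over-full bucket, records the
        # large index as the small bucket's alias, and pushes the large bucket's leftover mass
        # back into a queue, so each of the K buckets ends up covering at most two outcomes.
        # draw() can then sample in O(1) per draw: pick a bucket uniformly and keep it with
        # probability prob[k], otherwise return its alias.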
281 | smaller = [] 282 | larger = [] 283 | for kk, prob in enumerate(probs): 284 | self.prob[kk] = K*prob 285 | if self.prob[kk] < 1.0: 286 | smaller.append(kk) 287 | else: 288 | larger.append(kk) 289 | 290 | # Loop though and create little binary mixtures that 291 | # appropriately allocate the larger outcomes over the 292 | # overall uniform mixture. 293 | while len(smaller) > 0 and len(larger) > 0: 294 | small = smaller.pop() 295 | large = larger.pop() 296 | 297 | self.alias[small] = large 298 | self.prob[large] = (self.prob[large] - 1.0) + self.prob[small] 299 | 300 | if self.prob[large] < 1.0: 301 | smaller.append(large) 302 | else: 303 | larger.append(large) 304 | 305 | for last_one in smaller+larger: 306 | self.prob[last_one] = 1 307 | 308 | def cuda(self): 309 | self.prob = self.prob.cuda(self.cuda_id) 310 | self.alias = self.alias.cuda(self.cuda_id) 311 | 312 | def draw(self, N): 313 | """ Draw N samples from multinomial """ 314 | K = self.alias.size(0) 315 | 316 | kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K) 317 | prob = self.prob.index_select(0, kk) 318 | alias = self.alias.index_select(0, kk) 319 | # b is whether a random number is greater than q 320 | b = torch.bernoulli(prob) 321 | oq = kk.mul(b.long()) 322 | oj = alias.mul((1-b).long()) 323 | 324 | return oq + oj -------------------------------------------------------------------------------- /main-RKD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch import nn 4 | from torch.optim import Adam 5 | from utils import * 6 | import copy 7 | import numpy as np 8 | import torch.optim as optim 9 | import time 10 | from argparse import ArgumentParser 11 | from tensorboardX import SummaryWriter 12 | from tqdm import tqdm 13 | from Dataset import MultiModalX 14 | from KD_methods.RKD import get_RKDmodules, penultimate_feature_extractor 15 | import sys 16 | 17 | if __name__ == '__main__': 18 | ''' 19 | Args Setting for CML. 
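    Relational KD baseline: a pre-trained multimodal teacher (loaded below from
    ./checkpoints/tea) and the unimodal student are hooked at the features named by
    get_RKDmodules, and the relational loss is mixed with the task loss as
    (1 - alpha) * task + alpha * rkd, with alpha halved every 30 epochs.

    For orientation, a minimal sketch of the standard distance-wise RKD term (the
    repository's own criterion is built in KD_methods/RKD.py and may differ in detail):

        import torch
        import torch.nn.functional as F

        def rkd_distance(f_s, f_t):
            # pairwise L2 distances within each batch, normalised by their mean non-zero entry
            d_s = torch.cdist(f_s, f_s, p=2)
            d_s = d_s / d_s[d_s > 0].mean()
            d_t = torch.cdist(f_t, f_t, p=2)
            d_t = d_t / d_t[d_t > 0].mean()
            return F.smooth_l1_loss(d_s, d_t)

        loss_rkd = rkd_distance(torch.randn(8, 84), torch.randn(8, 298))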
20 | ''' 21 | parser = ArgumentParser(description='CML-RKD') 22 | parser.add_argument('--database', type=str, default='AV-MNIST', 23 | help="database name must be one of " 24 | "['NYU-Depth-V2', 'RAVDESS', 'AV-MNIST', 'VGGSound-50k', 'MM-IMDb']") 25 | parser.add_argument('--Tmodel', type=str, default='CNN-I', 26 | help='Teacher model name') 27 | parser.add_argument('--Smodel', type=str, default='LeNet5', 28 | help='Student model name') 29 | parser.add_argument('--mode', type=str, default='m1', 30 | help='modality mode: m1 or m2') 31 | 32 | parser.add_argument('--ckpt_name', type=str, 33 | default='AV-MNIST_CNN-I_seed0_ORG_ep5-5.pth', 34 | help='The name of the weight to be loaded in ./checkpoints/stu') 35 | parser.add_argument('--seed', type=int, default=0, 36 | help='Random seed') 37 | parser.add_argument('--num_workers', type=int, default=0, 38 | help='num_workers for DataLoader') 39 | parser.add_argument('--alpha', type=float, default=0.5, 40 | help='weight for loss') 41 | parser.add_argument('--batch_size', type=int, default=512, 42 | help='batch size for training') 43 | parser.add_argument('--lr', type=float, default=0.0001, 44 | help='learning rate for training') 45 | parser.add_argument('--record', type=bool, default=True, 46 | help='flag whether to record the learning log') 47 | parser.add_argument('--cuda_id', type=int, default=0, 48 | help='cuda id') 49 | parser.add_argument('--epochs', type=int, default=100, 50 | help='epochs for training') 51 | parser.add_argument('--save_model', type=bool, default=True, 52 | help='flag whether to save best model') 53 | parser.add_argument('--test_phase', type=bool, default=False, 54 | help='flag whether to conduct the test phase') 55 | parser.add_argument('--commit', type=str, default='RKD-baseline', 56 | help='Commit for logs') 57 | args = parser.parse_args() 58 | 59 | seed_all(args.seed) 60 | 61 | # 保存log 62 | log_dir = f'./logs/rkd' 63 | if not os.path.exists(log_dir): 64 | os.makedirs(log_dir) 65 | log_dir = log_dir + f'/{args.database}_{args.lr}_{str(time.time()).split(".")[0]}_{args.commit}' 66 | writer = SummaryWriter(log_dir, write_to_disk=args.record) 67 | 68 | data = get_data(args.database) 69 | data_train = get_dataset(args.database, data, 'train', args.seed) 70 | data_val = get_dataset(args.database, data, 'val', args.seed) 71 | data_test = get_dataset(args.database, data, 'test', args.seed) 72 | 73 | train_dataset = MultiModalX(data_train, args.database, mode=args.mode) 74 | valid_dataset = MultiModalX(data_val, args.database, mode=args.mode) 75 | test_dataset = MultiModalX(data_test, args.database, mode=args.mode) 76 | 77 | train_loader = DataLoader( 78 | train_dataset, 79 | batch_size=args.batch_size, 80 | pin_memory=True, 81 | num_workers=args.num_workers, 82 | shuffle=True 83 | ) 84 | 85 | valid_loader = DataLoader( 86 | valid_dataset, 87 | pin_memory=True, 88 | batch_size=args.batch_size 89 | ) 90 | 91 | test_loader = DataLoader( 92 | test_dataset, 93 | pin_memory=True, 94 | batch_size=args.batch_size 95 | ) 96 | 97 | # ===========GPU Setting==================== 98 | device = torch.device(f"cuda:{args.cuda_id}") 99 | # ==========Initialization=========== 100 | model_t, _, _, _, preprocessing_t, postprocessing_t, _ = get_Tmodules(args, device) 101 | model_s, _, scheduler_s, criterion_s, preprocessing_s, postprocessing_s, metric = get_Smodules(args) 102 | feat_names, criterion_rkd, optimizer_RKD = get_RKDmodules(args, model_s, n_data=len(train_dataset)) 103 | 104 | model_t = model_t.to(device) 105 | model_s = model_s.to(device) 
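    # Unlike the -UU variants, which distil from a unimodal teacher stored under
    # ./checkpoints/stu/wo_kd, this baseline loads a full multimodal teacher from
    # ./checkpoints/tea and keeps it frozen: it is set to eval() right after loading,
    # its features are detached before the RKD loss, and optimizer_RKD is built from
    # the student side in get_RKDmodules, so the teacher receives no updates.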
106 | model_t.load_state_dict(torch.load(f'checkpoints/tea/{args.ckpt_name}', map_location=f'cuda:{args.cuda_id}', weights_only=True)) 107 | model_t.eval() 108 | 109 | best_model_state = None 110 | best_val_loss = float('inf') 111 | # Cloud-only Learning 112 | print('================= CMKD-RKD =================') 113 | for epoch in range(args.epochs): 114 | start_time1 = time.time() 115 | # train 116 | model_s.train() 117 | criterion_rkd.embed_t.train() 118 | criterion_rkd.embed_s.train() 119 | tra_LOSS_s, tra_LOSS_task, tra_LOSS_rkd = 0, 0, 0 120 | for i, (data, data2, label) in tqdm(enumerate(train_loader), desc="Model Training ...", 121 | total=len(train_loader), dynamic_ncols=True, 122 | disable=False, file=sys.stdout): 123 | # print('Train Iter {}'.format(i)) 124 | data, data2, label = data.to(device), data2.to(device), label.to(device) 125 | data_s = data if args.mode == 'm1' else data2 126 | 127 | hooks_t, features_t = hooks_builder(model_t, model_t.hook_names) 128 | hooks_s, features_s = hooks_builder(model_s, model_s.hook_names) 129 | 130 | outputs_t = model_t(data, data2) if preprocessing_t is None else preprocessing_t(model_t, data, data2) 131 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 132 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 133 | f_s = penultimate_feature_extractor(feat_names[1], features_s, args) 134 | f_t = penultimate_feature_extractor(feat_names[0], features_t, args) 135 | f_t = f_t.detach() 136 | loss_rkd = criterion_rkd(f_s, f_t) 137 | loss = (1.0 - args.alpha) * loss_s + args.alpha * loss_rkd 138 | optimizer_RKD.zero_grad() 139 | loss.backward() 140 | optimizer_RKD.step() 141 | tra_LOSS_s += loss.item() 142 | tra_LOSS_task += loss_s.item() 143 | tra_LOSS_rkd += loss_rkd.item() 144 | 145 | tra_LOSS_s_avg = tra_LOSS_s / (i + 1) 146 | tra_LOSS_task_avg = tra_LOSS_task / (i + 1) 147 | tra_LOSS_rkd_avg = tra_LOSS_rkd / (i + 1) 148 | print(f'Train =====> Epoch {epoch + 1}/{args.epochs}: loss = {tra_LOSS_s_avg:.4f} | ' 149 | f'task_loss = {tra_LOSS_task_avg:.4f} | rkd_loss = {tra_LOSS_rkd_avg:.4f}') 150 | writer.add_scalar('train/Loss', tra_LOSS_s_avg, epoch) 151 | 152 | # validation 153 | model_s.eval() 154 | criterion_rkd.embed_t.eval() 155 | criterion_rkd.embed_s.eval() 156 | metric.reset() 157 | L_s_val = 0 158 | acc_c = 0 159 | gt_list, pred_list = [], [] 160 | with torch.no_grad(): 161 | for i, (data, data2, label) in tqdm(enumerate(valid_loader), desc="Model Validating ...", 162 | total=len(valid_loader), dynamic_ncols=True, disable=False, 163 | file=sys.stdout): 164 | # print('Val Iter {}'.format(i)) 165 | data, data2, label = data.to(device), data2.to(device), label.to(device) 166 | data_s = data if args.mode == 'm1' else data2 167 | 168 | # hooks_t, features_t = hooks_builder(model_t, model_t.hook_names) 169 | # hooks_s, features_s = hooks_builder(model_s, model_s.hook_names) 170 | 171 | # outputs_t = model_t(data, data2) if preprocessing_t is None else preprocessing_t(model_t, data, data2) 172 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 173 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 174 | # f_s = penultimate_feature_extractor(feat_names[1], features_s, args) 175 | # f_t = penultimate_feature_extractor(feat_names[0], features_t, args) 176 | # f_t = f_t.detach() 177 | # loss_crd = criterion_crd(f_s, f_t, idx, contrast_idx) 178 | 
loss = loss_s 179 | L_s_val += loss.item() 180 | metric.update(outputs_s, label) 181 | L_s_val = L_s_val / (i + 1) 182 | # res = metric.compute() 183 | # print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss_CRD = {L_s_val:.4f}, OA_s = {float(res['Accuracy'])}") 184 | # writer.add_scalar('valid/Loss', L_s_val, epoch) 185 | # writer.add_scalar('valid/Acc', float(res['Accuracy']), epoch) 186 | 187 | # For NYU-Depth-V2 188 | print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss_s = {L_s_val:.4f}") 189 | writer.add_scalar('valid/Loss', L_s_val, epoch) 190 | 191 | # Save the model that performs best on the validation set 192 | if L_s_val < best_val_loss: 193 | best_val_loss = L_s_val 194 | best_model_state = copy.deepcopy(model_s.state_dict()) 195 | best_epoch = epoch 196 | 197 | # test 198 | if args.test_phase: 199 | model_s.eval() 200 | criterion_rkd.embed_t.eval() 201 | criterion_rkd.embed_s.eval() 202 | metric.reset() 203 | L_t = 0 204 | acc_c = 0 205 | gt_list, pred_list = [], [] 206 | with torch.no_grad(): 207 | for i, (data, data2, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 208 | total=len(test_loader), dynamic_ncols=True, disable=False, 209 | file=sys.stdout): 210 | data, data2, label = data.to(device), data2.to(device), label.to(device) 211 | data_s = data if args.mode == 'm1' else data2 212 | 213 | hooks_t, features_t = hooks_builder(model_t, model_t.hook_names) 214 | hooks_s, features_s = hooks_builder(model_s, model_s.hook_names) 215 | 216 | outputs_t = model_t(data, data2) if preprocessing_t is None else preprocessing_t(model_t, data, data2) 217 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 218 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 219 | f_s = penultimate_feature_extractor(feat_names[1], features_s, args) 220 | f_t = penultimate_feature_extractor(feat_names[0], features_t, args) 221 | f_t = f_t.detach() 222 | loss_rkd = criterion_rkd(f_s, f_t) 223 | loss = (1.0 - args.alpha) * loss_s + args.alpha * loss_rkd 224 | L_t = L_t + loss.item() 225 | metric.update(outputs_s, label) 226 | L_t = L_t / (i + 1) 227 | res = metric.compute() 228 | print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss_RKD = {L_t:.4f}, OA_s = {float(res['Accuracy'])}") 229 | writer.add_scalar('test/Loss', L_t, epoch) 230 | writer.add_scalar('test/Acc', float(res['Accuracy']), epoch) 231 | 232 | # # For NYU-Depth-V2 233 | # print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss_CRD = {L_t:.4f}") 234 | # writer.add_scalar('test/Loss', L_t, epoch) 235 | 236 | # if (epoch + 1) % 10 == 0: 237 | # print('\n===============Metrics==================') 238 | # for e in res.keys(): 239 | # print(e) 240 | # print(res[e]) 241 | # print('----------------------------') 242 | # print('=======================================\n') 243 | 244 | # scheduler_t.step() # learning-rate decay (when training the SAFN model, a monitored metric must be passed as an argument) 245 | args.alpha *= 0.5 if (epoch + 1) % 30 == 0 else 1.0 246 | start_time2 = time.time() 247 | time_cost = start_time2 - start_time1 248 | if time_cost > 100: 249 | print(f"Epoch {epoch + 1} time cost: {time_cost / 60:.2f} minutes.\n") 250 | else: 251 | print(f"Epoch {epoch + 1} time cost: {time_cost:.2f} seconds.\n") 252 | 253 | writer.close() 254 | 255 | if args.save_model: 256 | if not os.path.exists('./checkpoints/stu/rkd'): 257 | os.makedirs('./checkpoints/stu/rkd') 258 | names = args.ckpt_name.split('_') 259 | Tmodel_mode = names[3] 260 | torch.save(best_model_state, 261 | 
f'./checkpoints/stu/rkd/{args.database}_{args.Tmodel}_{Tmodel_mode}--{args.Smodel}_seed{args.seed}_{args.mode}_ep{best_epoch + 1}-{args.epochs}.pth') 262 | -------------------------------------------------------------------------------- /main-OFA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch import nn 4 | from utils import seed_all, hooks_builder, hooks_remover, get_data, get_dataset, get_Smodules, get_Tmodules 5 | import copy 6 | import os 7 | import time 8 | from argparse import ArgumentParser 9 | from tensorboardX import SummaryWriter 10 | from tqdm import tqdm 11 | from Dataset import MultiModalX 12 | from KD_methods.OFA import get_OFAmodules, projectors_train, projectors_eval 13 | import sys 14 | 15 | 16 | if __name__ == '__main__': 17 | ''' 18 | Args Setting for CML. 19 | ''' 20 | parser = ArgumentParser(description='CML-OFA') 21 | parser.add_argument('--database', type=str, default='AV-MNIST', 22 | help="database name must be one of ['NYU-Depth-V2', 'RAVDESS', 'AV-MNIST', 'VGGSound-50k', 'MM-IMDb']") 23 | parser.add_argument('--Tmodel', type=str, default='CNN-I', 24 | help='Teacher model name') 25 | parser.add_argument('--Smodel', type=str, default='LeNet5', 26 | help='Student model name') 27 | parser.add_argument('--mode', type=str, default='m1', 28 | help='modality mode: m1 or m2') 29 | 30 | parser.add_argument('--ckpt_name', type=str, 31 | default='AV-MNIST_CNN-I_seed0_ORG_ep5-5.pth', 32 | help='The name of the teacher weight to be loaded from ./checkpoints/tea') 33 | parser.add_argument('--seed', type=int, default=0, 34 | help='Random seed') 35 | parser.add_argument('--num_workers', type=int, default=0, 36 | help='num_workers for DataLoader') 37 | parser.add_argument('--alpha', type=float, default=0.5, 38 | help='weight of the distillation term in the total loss') 39 | parser.add_argument('--ofa_eps', type=float, default=1.1, 40 | help='Modulating parameter') 41 | parser.add_argument('--ofa_temperature', type=float, default=1., 42 | help='The temperature for OFA loss') 43 | parser.add_argument('--batch_size', type=int, default=32, 44 | help='batch size for training') 45 | parser.add_argument('--lr', type=float, default=0.0001, 46 | help='learning rate for training') 47 | parser.add_argument('--record', type=bool, default=True, 48 | help='flag whether to record the learning log') 49 | parser.add_argument('--cuda_id', type=int, default=0, 50 | help='cuda id') 51 | parser.add_argument('--epochs', type=int, default=100, 52 | help='epochs for training') 53 | parser.add_argument('--save_model', type=bool, default=True, 54 | help='flag whether to save best model') 55 | parser.add_argument('--test_phase', type=bool, default=False, 56 | help='flag whether to conduct the test phase') 57 | parser.add_argument('--commit', type=str, default='CMKD-OFA', 58 | help='Commit for logs') 59 | args = parser.parse_args() 60 | 61 | seed_all(args.seed) 62 | 63 | 64 | log_dir = f'./logs/ofa' 65 | if not os.path.exists(log_dir): 66 | os.makedirs(log_dir) 67 | log_dir = log_dir + f'/{args.database}_{args.lr}_{str(time.time()).split(".")[0]}_{args.commit}' 68 | writer = SummaryWriter(log_dir, write_to_disk=args.record) 69 | 70 | data = get_data(args.database) 71 | data_train = get_dataset(args.database, data, 'train', args.seed) 72 | data_val = get_dataset(args.database, data, 'val', args.seed) 73 | data_test = get_dataset(args.database, data, 'test', args.seed) 74 | 75 | train_dataset = MultiModalX(data_train, args.database, 
mode=args.mode) 76 | valid_dataset = MultiModalX(data_val, args.database, mode=args.mode) 77 | test_dataset = MultiModalX(data_test, args.database, mode=args.mode) 78 | 79 | train_loader = DataLoader( 80 | train_dataset, 81 | batch_size=args.batch_size, 82 | pin_memory=True, 83 | num_workers=args.num_workers, 84 | shuffle=True 85 | ) 86 | 87 | valid_loader = DataLoader( 88 | valid_dataset, 89 | pin_memory=True, 90 | batch_size=args.batch_size 91 | ) 92 | 93 | test_loader = DataLoader( 94 | test_dataset, 95 | pin_memory=True, 96 | batch_size=args.batch_size 97 | ) 98 | 99 | # ===========GPU Setting==================== 100 | device = torch.device(f"cuda:{args.cuda_id}") 101 | # ==========Initialization=========== 102 | model_t, _, _, _, preprocessing_t, postprocessing_t, _ = get_Tmodules(args, device) 103 | model_s, _, scheduler_s, criterion_s, preprocessing_s, postprocessing_s, metric = get_Smodules(args) 104 | projectors, criterion_ofa, optimizer_ofa = get_OFAmodules(args, model_s) 105 | 106 | model_t = model_t.to(device) 107 | model_s = model_s.to(device) 108 | projectors = [projector.to(device) for projector in projectors] 109 | model_t.load_state_dict(torch.load(f'checkpoints/tea/{args.ckpt_name}', map_location=f'cuda:{args.cuda_id}', weights_only=True)) 110 | model_t.eval() 111 | 112 | best_model_state = None 113 | best_val_loss = float('inf') 114 | # Cloud-only Learning 115 | print('================= CMKD-OFA =================') 116 | for epoch in range(args.epochs): 117 | start_time1 = time.time() 118 | # train 119 | model_s.train() 120 | projectors_train(projectors) 121 | if args.Tmodel == 'CEN': 122 | for module in model_t.modules(): 123 | if isinstance(module, nn.BatchNorm2d): 124 | module.eval() 125 | if args.Smodel in ['CEN_RGB-branch', 'CEN_D-branch']: 126 | for module in model_s.modules(): 127 | if isinstance(module, nn.BatchNorm2d): 128 | module.eval() 129 | tra_LOSS_s, tra_LOSS_ofa = 0, 0 130 | for i, (data, data2, label) in tqdm(enumerate(train_loader), desc="Model Training ...", 131 | total=len(train_loader), dynamic_ncols=True, 132 | disable=False, file=sys.stdout): 133 | # print('Train Iter {}'.format(i)) 134 | data, data2, label = data.to(device), data2.to(device), label.to(device) 135 | data_s = data if args.mode == 'm1' else data2 136 | 137 | hooks_t, features_t = hooks_builder(model_t, model_t.hook_names) 138 | hooks_s, features_s = hooks_builder(model_s, model_s.hook_names) 139 | 140 | outputs_t = model_t(data, data2) if preprocessing_t is None else preprocessing_t(model_t, data, data2) 141 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 142 | 143 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 144 | loss_ofa = criterion_ofa(features_t, features_s, label) 145 | 146 | hooks_remover(hooks_t) 147 | hooks_remover(hooks_s) 148 | 149 | loss = (1.0 - args.alpha) * loss_s + args.alpha * loss_ofa 150 | optimizer_ofa.zero_grad() 151 | loss.backward() 152 | optimizer_ofa.step() 153 | tra_LOSS_s += loss.item() 154 | tra_LOSS_ofa += loss_ofa.item() 155 | 156 | tra_LOSS_s_avg = tra_LOSS_s / (i + 1) 157 | tra_LOSS_ofa_avg = tra_LOSS_ofa / (i + 1) 158 | print(f'Train =====> Epoch {epoch + 1}/{args.epochs}: loss = {tra_LOSS_s_avg:.4f} | loss_ofa = {tra_LOSS_ofa_avg:.4f}') 159 | writer.add_scalar('train/Loss', tra_LOSS_s_avg, epoch) 160 | 161 | # validation 162 | model_s.eval() 163 | projectors_eval(projectors) 164 | metric.reset() 165 | L_s_val = 0 166 | val_LOSS_ofa = 0 167 | 
acc_c = 0 168 | gt_list, pred_list = [], [] 169 | with torch.no_grad(): 170 | for i, (data, data2, label) in tqdm(enumerate(valid_loader), desc="Model Validating ...", 171 | total=len(valid_loader), dynamic_ncols=True, disable=False, 172 | file=sys.stdout): 173 | # print('Val Iter {}'.format(i)) 174 | data, data2, label = data.to(device), data2.to(device), label.to(device) 175 | data_s = data if args.mode == 'm1' else data2 176 | 177 | hooks_t, features_t = hooks_builder(model_t, model_t.hook_names) 178 | hooks_s, features_s = hooks_builder(model_s, model_s.hook_names) 179 | 180 | outputs_t = model_t(data, data2) if preprocessing_t is None else preprocessing_t(model_t, data, data2) 181 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 182 | 183 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 184 | loss_ofa = criterion_ofa(features_t, features_s, label) 185 | 186 | hooks_remover(hooks_t) 187 | hooks_remover(hooks_s) 188 | 189 | loss = loss_s 190 | L_s_val += loss.item() 191 | val_LOSS_ofa += loss_ofa.item() 192 | metric.update(outputs_s, label) 193 | L_s_val = L_s_val / (i + 1) 194 | val_LOSS_ofa_avg = val_LOSS_ofa / (i + 1) 195 | # res = metric.compute() 196 | # print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss_c = {L_s_val:.4f}, OA_c = {res['Accuracy']:.2f}%") 197 | # writer.add_scalar('valid/Loss', L_s_val, epoch) 198 | # writer.add_scalar('valid/Acc', res['Accuracy'], epoch) 199 | 200 | # For NYU-Depth-V2 201 | print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss = {L_s_val:.4f} | loss_ofa = {val_LOSS_ofa_avg:.4f}") 202 | writer.add_scalar('valid/Loss', L_s_val, epoch) 203 | 204 | # Save the model that performs best on the validation set 205 | if L_s_val < best_val_loss: 206 | best_val_loss = L_s_val 207 | best_model_state = copy.deepcopy(model_s.state_dict()) 208 | best_epoch = epoch 209 | 210 | # test 211 | if args.test_phase: 212 | model_s.eval() 213 | projectors_eval(projectors) 214 | metric.reset() 215 | L_t = 0 216 | acc_c = 0 217 | gt_list, pred_list = [], [] 218 | with torch.no_grad(): 219 | for i, (data, data2, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 220 | total=len(test_loader), dynamic_ncols=True, disable=False, 221 | file=sys.stdout): 222 | data, data2, label = data.to(device), data2.to(device), label.to(device) 223 | data_s = data if args.mode == 'm1' else data2 224 | 225 | # hooks_t, features_t = hooks_builder(model_t, model_t.hook_names) 226 | hooks_s, features_s = hooks_builder(model_s, model_s.hook_names) 227 | 228 | # outputs_t = model_t(data, data2) if preprocessing_t is None else preprocessing_t(model_t, data, data2) 229 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 230 | 231 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 232 | # loss_ofa = criterion_ofa(features_t, features_s, label) 233 | 234 | # hooks_remover(hooks_t) 235 | hooks_remover(hooks_s) 236 | 237 | loss = loss_s 238 | L_t = L_t + loss.item() 239 | metric.update(outputs_s, label) 240 | L_t = L_t / (i + 1) 241 | # res = metric.compute() 242 | # print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss_c = {L_t:.4f}, OA_c = {res['Accuracy']:.2f}%") 243 | # writer.add_scalar('test/Loss', L_t, epoch) 244 | # writer.add_scalar('test/Acc', res['Accuracy'], epoch) 245 | 246 | # For NYU-Depth-V2 247 | print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss = {L_t:.4f}") 248 | 
writer.add_scalar('test/Loss', L_t, epoch) 249 | 250 | # if (epoch + 1) % 10 == 0: 251 | # print('\n===============Metrics==================') 252 | # for e in res.keys(): 253 | # print(e) 254 | # print(res[e]) 255 | # print('----------------------------') 256 | # print('=======================================\n') 257 | 258 | # scheduler_t.step() 259 | args.alpha *= 0.5 if (epoch + 1) % 30 == 0 else 1.0 260 | start_time2 = time.time() 261 | time_cost = start_time2 - start_time1 262 | if time_cost > 100: 263 | print(f"Epoch {epoch + 1} time cost: {time_cost / 60:.2f} minutes.\n") 264 | else: 265 | print(f"Epoch {epoch + 1} time cost: {time_cost:.2f} seconds.\n") 266 | 267 | writer.close() 268 | 269 | if args.save_model: 270 | if not os.path.exists('./checkpoints/stu/ofa'): 271 | os.makedirs('./checkpoints/stu/ofa') 272 | names = args.ckpt_name.split('_') 273 | Tmodel_mode = names[3] 274 | torch.save(best_model_state, 275 | f'./checkpoints/stu/ofa/{args.database}_{args.Tmodel}_{Tmodel_mode}--{args.Smodel}_seed{args.seed}_{args.mode}_ep{best_epoch + 1}-{args.epochs}.pth') 276 | -------------------------------------------------------------------------------- /main-KD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from utils import get_data, get_dataset, get_Tmodules, get_Smodules, seed_all 4 | import copy 5 | import time 6 | from argparse import ArgumentParser 7 | from tensorboardX import SummaryWriter 8 | from tqdm import tqdm 9 | from Dataset import MultiModalX 10 | from KD_methods.KD import distillation_loss 11 | import sys 12 | import os 13 | 14 | 15 | if __name__ == '__main__': 16 | ''' 17 | Args Setting for CML. 18 | ''' 19 | parser = ArgumentParser(description='CML-KD') 20 | parser.add_argument('--database', type=str, default='AV-MNIST', 21 | help="database name must be one of ['NYU-Depth-V2', 'RAVDESS', 'AV-MNIST', 'VGGSound-50k', 'MM-IMDb']") 22 | parser.add_argument('--Tmodel', type=str, default='CNN-I', 23 | help='Teacher model name') 24 | parser.add_argument('--Smodel', type=str, default='LeNet5', 25 | help='Student model name') 26 | parser.add_argument('--mode', type=str, default='m1', 27 | help='modality mode: m1 or m2') 28 | 29 | parser.add_argument('--ckpt_name', type=str, 30 | default='AV-MNIST_CNN-I_seed0_ORG_ep5-5.pth', 31 | help='The name of the teacher weight to be loaded from ./checkpoints/tea') 32 | parser.add_argument('--seed', type=int, default=0, 33 | help='Random seed') 34 | parser.add_argument('--num_workers', type=int, default=0, 35 | help='num_workers for DataLoader') 36 | parser.add_argument('--alpha', type=float, default=0.5, 37 | help='weight of the distillation term in the total loss') 38 | parser.add_argument('--batch_size', type=int, default=512, 39 | help='batch size for training') 40 | parser.add_argument('--lr', type=float, default=0.0001, 41 | help='learning rate for training') 42 | parser.add_argument('--record', type=bool, default=True, 43 | help='flag whether to record the learning log') 44 | parser.add_argument('--cuda_id', type=int, default=0, 45 | help='cuda id') 46 | parser.add_argument('--epochs', type=int, default=100, 47 | help='epochs for training') 48 | parser.add_argument('--save_model', type=bool, default=True, 49 | help='flag whether to save best model') 50 | parser.add_argument('--test_phase', type=bool, default=False, 51 | help='flag whether to conduct the test phase') 52 | parser.add_argument('--final_test', type=bool, default=False, 53 | help='flag whether to run a final test with the best model after the training 
phase') 54 | parser.add_argument('--commit', type=str, default='KD-baseline', 55 | help='Commit for logs') 56 | args = parser.parse_args() 57 | 58 | seed_all(args.seed) 59 | 60 | 61 | log_dir = f'./logs/kd' 62 | if not os.path.exists(log_dir): 63 | os.makedirs(log_dir) 64 | log_dir = log_dir + f'/{args.database}_{args.lr}_{str(time.time()).split(".")[0]}_{args.commit}' 65 | writer = SummaryWriter(log_dir, write_to_disk=args.record) 66 | 67 | data = get_data(args.database) 68 | data_train = get_dataset(args.database, data, 'train', args.seed) 69 | data_val = get_dataset(args.database, data, 'val', args.seed) 70 | data_test = get_dataset(args.database, data, 'test', args.seed) 71 | 72 | train_dataset = MultiModalX(data_train, args.database, mode=args.mode) 73 | valid_dataset = MultiModalX(data_val, args.database, mode=args.mode) 74 | test_dataset = MultiModalX(data_test, args.database, mode=args.mode) 75 | 76 | train_loader = DataLoader( 77 | train_dataset, 78 | batch_size=args.batch_size, 79 | pin_memory=True, 80 | num_workers=args.num_workers, 81 | shuffle=True 82 | ) 83 | 84 | valid_loader = DataLoader( 85 | valid_dataset, 86 | pin_memory=True, 87 | batch_size=args.batch_size 88 | ) 89 | 90 | test_loader = DataLoader( 91 | test_dataset, 92 | pin_memory=True, 93 | batch_size=args.batch_size 94 | ) 95 | 96 | # ===========GPU Setting==================== 97 | device = torch.device(f"cuda:{args.cuda_id}") 98 | # ==========Initialization=========== 99 | model_t, _, _, _, preprocessing_t, postprocessing_t, _ = get_Tmodules(args, device) 100 | model_s, optimizer_s, scheduler_s, criterion_s, preprocessing_s, postprocessing_s, metric = get_Smodules(args) 101 | 102 | model_t = model_t.to(device) 103 | model_s = model_s.to(device) 104 | model_t.load_state_dict(torch.load(f'checkpoints/tea/{args.ckpt_name}', weights_only=True, map_location=f'cuda:{args.cuda_id}')) 105 | model_t.eval() 106 | 107 | best_model_state = None 108 | best_val_loss = float('inf') 109 | # Cloud-only Learning 110 | print('================= CML-KD =================') 111 | for epoch in range(args.epochs): 112 | start_time1 = time.time() 113 | # train 114 | model_s.train() 115 | 116 | tra_LOSS_s, tra_LOSS_task, tra_LOSS_kl = 0, 0 ,0 117 | for i, (data, data2, label) in tqdm(enumerate(train_loader), desc="Model Training ...", 118 | total=len(train_loader), dynamic_ncols=True, 119 | disable=False, file=sys.stdout): 120 | # print('Train Iter {}'.format(i)) 121 | data, data2, label = data.to(device), data2.to(device), label.to(device) 122 | data_s = data if args.mode == 'm1' else data2 123 | outputs_t = model_t(data, data2) if preprocessing_t is None else preprocessing_t(model_t, data, data2) 124 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 125 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 126 | loss_kl = distillation_loss(args, outputs_s, outputs_t) 127 | loss = (1.0 - args.alpha) * loss_s + args.alpha * loss_kl 128 | optimizer_s.zero_grad() 129 | loss.backward() 130 | optimizer_s.step() 131 | tra_LOSS_s += loss.item() 132 | tra_LOSS_task += loss_s.item() 133 | tra_LOSS_kl += loss_kl.item() 134 | 135 | tra_LOSS_s_avg = tra_LOSS_s / (i + 1) 136 | tra_LOSS_task_avg = tra_LOSS_task / (i + 1) 137 | tra_LOSS_kl_avg = tra_LOSS_kl / (i + 1) 138 | print(f'Train =====> Epoch {epoch + 1}/{args.epochs}: loss = {tra_LOSS_s_avg:.4f} | ' 139 | f'loss_task = {tra_LOSS_task_avg:.4f} | loss_kl = {tra_LOSS_kl_avg:.4f}') 140 | 
writer.add_scalar('train/Loss', tra_LOSS_s_avg, epoch) 141 | 142 | # validation 143 | model_s.eval() 144 | metric.reset() 145 | L_s_val = 0 146 | acc_c = 0 147 | Loss_s, Loss_kl = 0, 0 148 | gt_list, pred_list = [], [] 149 | with torch.no_grad(): 150 | for i, (data, data2, label) in tqdm(enumerate(valid_loader), desc="Model Validating ...", 151 | total=len(valid_loader), dynamic_ncols=True, disable=False, 152 | file=sys.stdout): 153 | # print('Val Iter {}'.format(i)) 154 | data, data2, label = data.to(device), data2.to(device), label.to(device) 155 | data_s = data if args.mode == 'm1' else data2 156 | outputs_t = model_t(data, data2) if preprocessing_t is None else preprocessing_t(model_t, data, data2) 157 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 158 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 159 | loss_kl = distillation_loss(args, outputs_s, outputs_t) 160 | loss = loss_s 161 | L_s_val += loss.item() 162 | Loss_kl += loss_kl.item() 163 | Loss_s += loss_s.item() 164 | metric.update(outputs_s, label) 165 | L_s_val = L_s_val / (i + 1) 166 | Loss_kl_avg = Loss_kl / (i + 1) 167 | Loss_s_avg = Loss_s / (i + 1) 168 | # res = metric.compute() 169 | # print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: OA_s = {float(res['Accuracy'])} | loss = {L_s_val:.4f} | " 170 | # f"Loss_s = {Loss_s_avg:.4f} | Loss_kl = {Loss_kl_avg:.4f}") 171 | # writer.add_scalar('valid/Loss', L_s_val, epoch) 172 | # writer.add_scalar('valid/Acc', float(res['Accuracy']), epoch) 173 | 174 | # For NYU-Depth-V2 175 | print(f"Valid =====> Epoch {epoch + 1}/{args.epochs}: loss = {L_s_val:.4f} | loss_task = {Loss_s_avg:.4f} | loss_kl = {Loss_kl_avg:.4f}") 176 | writer.add_scalar('valid/Loss', L_s_val, epoch) 177 | 178 | if L_s_val < best_val_loss: 179 | best_val_loss = L_s_val 180 | best_model_state = copy.deepcopy(model_s.state_dict()) 181 | best_epoch = epoch 182 | 183 | # test 184 | if args.test_phase: 185 | model_s.eval() 186 | metric.reset() 187 | L_t = 0 188 | acc_c = 0 189 | Loss_s, Loss_kl = 0, 0 190 | gt_list, pred_list = [], [] 191 | with torch.no_grad(): 192 | for i, (data, data2, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 193 | total=len(test_loader), dynamic_ncols=True, disable=False, 194 | file=sys.stdout): 195 | data, data2, label = data.to(device), data2.to(device), label.to(device) 196 | data_s = data if args.mode == 'm1' else data2 197 | outputs_t = model_t(data, data2) if preprocessing_t is None else preprocessing_t(model_t, data, data2) 198 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 199 | loss_s = criterion_s(outputs_s, label) if postprocessing_s is None else postprocessing_s(outputs_s, label) 200 | loss_kl = distillation_loss(args, outputs_s, outputs_t) 201 | loss = (1.0 - args.alpha) * loss_s + args.alpha * loss_kl 202 | L_t = L_t + loss.item() 203 | Loss_kl += loss_kl.item() 204 | Loss_s += loss_s.item() 205 | metric.update(outputs_s, label) 206 | L_t = L_t / (i + 1) 207 | Loss_kl_avg = Loss_kl / (i + 1) 208 | Loss_s_avg = Loss_s / (i + 1) 209 | # res = metric.compute() 210 | # print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: OA_s = {float(res['Accuracy'])} | loss = {L_t:.4f} | " 211 | # f"Loss_s = {Loss_s_avg:.4f} | Loss_kl = {Loss_kl_avg:.4f}") 212 | # writer.add_scalar('test/Loss', L_t, epoch) 213 | # writer.add_scalar('test/Acc', float(res['Accuracy']), epoch) 214 | 215 | # For NYU-Depth-V2 216 | 
print(f"Test =====> Epoch {epoch + 1}/{args.epochs}: loss = {L_t:.4f}") 217 | writer.add_scalar('test/Loss', L_t, epoch) 218 | 219 | # if (epoch + 1) % 10 == 0: 220 | # print('\n===============Metrics==================') 221 | # for e in res.keys(): 222 | # print(e) 223 | # print(res[e]) 224 | # print('----------------------------') 225 | # print('=======================================\n') 226 | 227 | # scheduler_t.step() 228 | args.alpha *= 0.5 if (epoch+1) % 30 == 0 else 1.0 229 | start_time2 = time.time() 230 | time_cost = start_time2 - start_time1 231 | if time_cost > 100: 232 | print(f"Epoch {epoch + 1} time cost: {time_cost / 60:.2f} minutes.\n") 233 | else: 234 | print(f"Epoch {epoch + 1} time cost: {time_cost:.2f} seconds.\n") 235 | 236 | writer.close() 237 | 238 | if args.final_test: 239 | model_s.load_state_dict(best_model_state) 240 | print('================= Final Test for This Model =================') 241 | # test 242 | model_s.eval() 243 | metric.reset() 244 | gt_list, pred_list = [], [] 245 | with torch.no_grad(): 246 | for i, (data, data2, label) in tqdm(enumerate(test_loader), desc="Model Testing ...", 247 | total=len(test_loader), dynamic_ncols=True, disable=True, file=sys.stdout): 248 | data, data2, label = data.to(device), data2.to(device), label.to(device) 249 | data_s = data if args.mode == 'm1' else data2 250 | outputs_s = model_s(data_s) if preprocessing_s is None else preprocessing_s(model_s, data_s) 251 | metric.update(outputs_s, label) 252 | res = metric.compute() 253 | 254 | print('\n===============Metrics==================') 255 | for e in res.keys(): 256 | print(e) 257 | print(res[e]) 258 | print('----------------------------') 259 | print('=======================================\n') 260 | 261 | if args.save_model: 262 | if not os.path.exists('./checkpoints/stu/kd'): 263 | os.makedirs('./checkpoints/stu/kd') 264 | names = args.ckpt_name.split('_') 265 | Tmodel_mode = names[3] 266 | torch.save(best_model_state, 267 | f'./checkpoints/stu/kd/{args.database}_{args.Tmodel}_{Tmodel_mode}--{args.Smodel}_seed{args.seed}_{args.mode}_ep{best_epoch + 1}-{args.epochs}.pth') 268 | --------------------------------------------------------------------------------