├── .DS_Store
├── OGM.jpg
├── OPM.jpg
├── README.md
└── code
    ├── config.py
    ├── dataset
    │   ├── KS.json
    │   ├── KS.py
    │   ├── KS_train_val.json
    │   ├── loader.py
    │   └── spatial_transforms.py
    ├── main.py
    ├── models
    │   ├── Audio_Classifier.py
    │   ├── BasicModule.py
    │   ├── Classifier.py
    │   ├── Resnet_18.py
    │   ├── Visual_Classifier.py
    │   └── fusion_model.py
    ├── requirements.txt
    └── scripts
        ├── inference.sh
        ├── train_ogm.sh
        └── train_opm.sh

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/BML_TPAMI2024/4b3aa4fd841856c4acbebf48549e6c59fbf7635e/.DS_Store
--------------------------------------------------------------------------------
/OGM.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/BML_TPAMI2024/4b3aa4fd841856c4acbebf48549e6c59fbf7635e/OGM.jpg
--------------------------------------------------------------------------------
/OPM.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/BML_TPAMI2024/4b3aa4fd841856c4acbebf48549e6c59fbf7635e/OPM.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code of On-the-fly Modulation for Balanced Multimodal Learning
2 | The repo for "On-the-fly Modulation for Balanced Multimodal Learning", T-PAMI 2024
3 | 
4 | This is the official PyTorch implementation of "*On-the-fly Modulation for Balanced Multimodal Learning*", which analyzes and alleviates the imbalanced multimodal learning problem from both the feed-forward and the back-propagation stages of optimization. Please refer to our [T-PAMI 2024 paper](https://ieeexplore.ieee.org/abstract/document/10694738) for more details. This journal paper is an extension of our previous CVPR 2022 paper [\[Balanced Multimodal Learning via On-the-fly Gradient Modulation\]](https://arxiv.org/abs/2203.15332).
5 | 
6 | **Paper Title: "On-the-fly Modulation for Balanced Multimodal Learning"**
7 | 
8 | **Authors: [Yake Wei](https://echo0409.github.io/), [Di Hu](https://dtaoo.github.io/index.html), Henghui Du and Ji-Rong Wen**
9 | 
10 | 
11 | ## On-the-fly Modulation for Balanced Multimodal Learning
12 | Multimodal learning is expected to boost model performance by integrating information from different modalities. However, its potential is not fully exploited because the widely used joint training strategy, which applies a uniform objective to all modalities, leads to imbalanced and under-optimized uni-modal representations. Specifically, we point out that there is often a modality with more discriminative information, e.g., the vision of playing football or the sound of blowing wind. Such a modality can dominate the joint training process, leaving the other modalities significantly under-optimized.
13 | 
14 | To alleviate this problem, we first analyze the under-optimization phenomenon from both the feed-forward and the back-propagation stages of optimization. Then, the **On-the-fly Prediction Modulation (OPM)** and **On-the-fly Gradient Modulation (OGM)** strategies are proposed to modulate the optimization of each modality by monitoring the discriminative discrepancy between modalities during training. Concretely, OPM weakens the influence of the dominant modality by dropping its features with a dynamic probability in the feed-forward stage, while OGM mitigates its gradient in the back-propagation stage.
In experiments, our methods demonstrate considerable improvement across a variety of multimodal tasks. These simple yet effective strategies not only enhance performance in vanilla and task-oriented multimodal models, but also benefit more complex multimodal tasks, showcasing their effectiveness and flexibility.
15 | 
16 | 
17 | 
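As a rough illustration of the OPM step, the hedged sketch below uses a hypothetical helper `opm_drop`; the repository's actual logic lives in `calcu_q` and `Modality_drop` in `code/models/Classifier.py`. Each branch's per-batch "performance" is the softmax score it assigns to the ground-truth class; the dominant branch has its features dropped with a probability that grows with its advantage, and the kept features are rescaled as in inverted dropout.

```python
import torch

def opm_drop(feat_a, feat_v, perf_a, perf_v, q_base=0.4, lam=0.5):
    """Sketch of OPM: drop features of the dominant modality with a dynamic
    probability and rescale the kept ones (illustrative, not the exact repo code)."""
    # Discrepancy ratios: positive only for the currently dominant modality.
    r_a = torch.tanh(torch.relu(torch.tensor(perf_a / perf_v - 1.0)))
    r_v = torch.tanh(torch.relu(torch.tensor(perf_v / perf_a - 1.0)))
    q_a = torch.clamp(q_base * (1 + lam * r_a), 0.0, 1.0) if r_a > 0 else torch.tensor(0.0)
    q_v = torch.clamp(q_base * (1 + lam * r_v), 0.0, 1.0) if r_v > 0 else torch.tensor(0.0)

    B = feat_a.shape[0]
    keep_a = (torch.rand(B, 1) >= q_a).float()   # per-sample Bernoulli keep mask
    keep_v = (torch.rand(B, 1) >= q_v).float()
    theta = (q_a + q_v) / 2                      # expected dropped fraction
    scale = 1.0 / (1.0 - theta)                  # compensate, as in inverted dropout
    return feat_a * keep_a * scale, feat_v * keep_v * scale
```

In the full implementation the drop is additionally executed only with probability `p_exe`, and samples whose features are dropped in every modality are excluded from the loss via an `update_flag` mask.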
18 | ![Pipeline of OPM method.](OPM.jpg)
19 | 
20 | Pipeline of OPM method.
21 | 
22 | 
23 | 
24 | ![Pipeline of OGM method.](OGM.jpg)
25 | 
26 | Pipeline of OGM method.
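On the back-propagation side, OGM damps the convolutional gradients of the dominant encoder by `1 - tanh(alpha * ratio)`, where the ratio compares the same per-batch ground-truth softmax scores; the GE variant further adds zero-mean Gaussian noise to restore gradient diversity. The following is a condensed, illustrative sketch of the loop in `code/main.py` (the helper name `ogm_ge_step` is made up here), meant to run between `loss.backward()` and `optimizer.step()`:

```python
import math
import torch

def ogm_ge_step(model, perf_a, perf_v, alpha=1.0):
    """Sketch of OGM-GE: damp conv-layer gradients of the dominant encoder and
    add Gaussian noise. perf_a / perf_v are the per-batch ground-truth softmax
    scores (scalars) of the audio and visual branches."""
    if perf_a / perf_v > 1:                       # audio branch dominates
        k_a, k_v = 1 - math.tanh(alpha * (perf_a / perf_v)), 1.0
    else:                                         # visual branch dominates
        k_a, k_v = 1.0, 1 - math.tanh(alpha * (perf_v / perf_a))

    for name, p in model.named_parameters():
        if p.grad is None or p.grad.dim() != 4:   # modulate conv kernels only
            continue
        noise = torch.zeros_like(p.grad).normal_(0, p.grad.std().item() + 1e-8)
        if name.startswith('encoder_1'):          # audio encoder in this repo
            p.grad = p.grad * k_a + noise
        elif name.startswith('encoder_2'):        # visual encoder
            p.grad = p.grad * k_v + noise
```

In the repository this modulation is applied only after the warm-up epochs and while `modulation_starts <= epoch <= modulation_ends`.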
27 | 
28 | 
29 | ## Code instruction
30 | 
31 | ### Data Preparation
32 | The original datasets can be found here:
33 | [CREMA-D](https://github.com/CheyneyComputerScience/CREMA-D),
34 | [Kinetics-Sounds](https://github.com/cvdfoundation/kinetics-dataset),
35 | [UCF101](https://www.crcv.ucf.edu/data/UCF101.php).
36 | 
37 | The data preprocessing follows [OGM-GE](https://github.com/GeWu-Lab/OGM-GE_CVPR2022).
38 | 
39 | 
40 | 
41 | ### Install
42 | 
43 | ```shell
44 | pip install -r requirements.txt
45 | ```
46 | 
47 | ### Prepare dataset
48 | 1. Place the preprocessed KS data at "YOUR_PATH". In our case, the video data are HDF5 files and the audio data are PKL files.
49 | 2. Set HDF5_DIR = "YOUR_PATH_VIDEO" and PKL_DIR = "YOUR_PATH_AUDIO" in dataset/KS.py.
50 | 
51 | 
52 | ### Training
53 | 
54 | ```shell
55 | cd code
56 | # OGM
57 | bash scripts/train_ogm.sh
58 | 
59 | # OPM
60 | bash scripts/train_opm.sh
61 | ```
62 | 
63 | ### Inference
64 | 
65 | ```shell
66 | bash scripts/inference.sh
67 | ```
68 | 
69 | 
70 | ## Citation
71 | If you find this work useful, please consider citing it.
72 | 
73 | 

74 | @article{wei2024on,
75 |   title={On-the-fly modulation for balanced multimodal learning},
76 |   author={Wei, Yake and Hu, Di and Du, Henghui and Wen, Ji-Rong},
77 |   journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
78 |   year={2024}
79 | }
80 | 
81 | 82 | 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /code/config.py: -------------------------------------------------------------------------------- 1 | 2 | import warnings 3 | 4 | 5 | class Config(): 6 | 7 | def __init__(self) -> None: 8 | 9 | # dataset setting 10 | self.dataset='KineticSound' 11 | self.num_classes={'VGGSound':309,'KineticSound':31,'CREMAD':6,'AVE':28} 12 | self.modality=['audio','visual'] 13 | self.fps = 1 14 | self.use_video_frames = 3 15 | 16 | # backbone setting 17 | self.in_c=3 18 | self.out_c=64 19 | 20 | # train setting 21 | self.train = False 22 | self.batch_size = 32 23 | self.epochs=100 24 | self.optimizer='Adamw' 25 | 26 | self.learning_rate=5e-5 27 | self.lr_decay_ratio=0.1 28 | # self.lr_decay_step=[30,50,70] 29 | self.lr_decay_step=40 30 | 31 | # modulation setting 32 | self.use_modulation=False 33 | self.modulation = 'OGM_GE' 34 | self.modulation_starts = 0 35 | self.modulation_ends = 80 36 | 37 | self.alpha = 1 38 | 39 | # fusion setting 40 | self.fusion_method = 'concat' 41 | self.d = [512, 512] 42 | # gated_fusion 43 | self.mid_c=512 44 | self.x_gated=False 45 | 46 | # adam-drop lambda setting 47 | self.use_adam_drop = False 48 | self.key=50 49 | self.sigma=2 50 | 51 | self.p_exe=0.7 52 | self.q_base=0.4 53 | self.lam=0.5 54 | 55 | # other setting 56 | self.checkpoint_path = 'result' 57 | 58 | self.resume_model=False 59 | self.resume_model_path=None 60 | 61 | self.use_tensorboard = True 62 | 63 | self.random_seed = 0 64 | self.gpu_ids = [0,1] 65 | 66 | self.func='tanh' 67 | self.form='/' 68 | 69 | self.device=0 70 | 71 | 72 | # transforms setting 73 | self.decrease_epoch=10 74 | self.sample_size=112 75 | self.sample_t_stride=1 76 | self.train_crop='random' 77 | self.value_scale=1 78 | self.scale_h=128 79 | self.scale_w=171 80 | self.train_crop_min_scale=0.25 81 | self.train_crop_min_ratio=0.75 82 | self.no_hflip=False 83 | self.colorjitter=False 84 | self.train_t_crop='random' 85 | 86 | self.audio_drop=0.0 87 | self.visual_drop=0.0 88 | 89 | 90 | def parse(self,kwargs): 91 | for k,v in kwargs.items(): 92 | if not hasattr(self,k): 93 | warnings.warn('has not attribute %s'%k) 94 | setattr(self,k,v) 95 | 96 | # print('config info:') 97 | # for k,v in self.__dict__.items(): 98 | # if not k.startswith('__'): 99 | # print(k,getattr(self,k)) 100 | 101 | if __name__ == "__main__": 102 | import argparse 103 | cfg=Config() 104 | 105 | parser = argparse.ArgumentParser() 106 | 107 | parser.add_argument('--use_modulation',action='store_true',help='use gradient modulation') 108 | parser.add_argument('--use_adam_drop',action='store_true',help='use adam-drop') 109 | parser.add_argument('--modulation', default='OGM_GE', type=str,choices=['Normal', 'OGM', 'OGM_GE']) 110 | parser.add_argument('--fusion_method', default='concat', type=str,choices=['sum', 'concat', 'gated']) 111 | parser.add_argument('--train', action='store_true', help='turn on train mode') 112 | parser.add_argument('--resume_model',action='store_true',help='whether to resume model') 113 | parser.add_argument('--checkpoint_path',type=str,help='load checkpoints') 114 | 115 | args=parser.parse_args() 116 | cfg.parse(vars(args)) 117 | 118 | -------------------------------------------------------------------------------- /code/dataset/KS.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | import torch.utils.data as data 5 | from pathlib import Path 6 | from random 
import randrange 7 | import numpy as np 8 | import h5py 9 | import pickle 10 | 11 | 12 | from .loader import VideoLoaderHDF5 13 | from .loader import AudioFeatureLoader 14 | from .spatial_transforms import get_spatial_transform,get_val_spatial_transforms 15 | HDF5_DIR='' 16 | PKL_DIR='' 17 | PROJECT_DIR='' 18 | 19 | def get_dataset(annotation_data,mode): 20 | video_names = [] 21 | video_labels = [] 22 | 23 | for key in annotation_data.keys(): 24 | if annotation_data[key]['subset'] == mode: 25 | video_names.append(key) 26 | video_labels.append(annotation_data[key]['label']) 27 | return video_names,video_labels 28 | 29 | 30 | class KSDataset(data.Dataset): 31 | def __init__(self, 32 | annotation_path=os.path.join(PROJECT_DIR,'dataset/KS_train_val.json'), 33 | mode='training', 34 | spatial_transform=None, 35 | video_loader = None, 36 | audio_drop=0.0, 37 | visual_drop=0.0 38 | ): 39 | 40 | self.video_dir = HDF5_DIR 41 | self.audio_dir = PKL_DIR 42 | 43 | self.audio_drop=audio_drop 44 | self.visual_drop=visual_drop 45 | 46 | self.dataset,self.idx_to_class,self.n_videos = self.__make_dataset(self.video_dir,annotation_path,mode) 47 | 48 | self.spatial_transform = spatial_transform 49 | 50 | self.loader = video_loader 51 | 52 | 53 | def __make_dataset(self,video_dir,annotation_path,subset): 54 | with open(annotation_path) as f: 55 | annotation_data = json.load(f) 56 | class_labels = annotation_data['labels'] 57 | annotation_data = annotation_data['database'] 58 | 59 | video_names , video_labels = get_dataset(annotation_data,subset) 60 | 61 | class_to_idx = {label : i for i,label in enumerate(class_labels)} 62 | idx_to_class = {i : label for i,label in enumerate(class_labels)} 63 | 64 | n_videos = len(video_names) 65 | 66 | dataset = [] 67 | max_len = 0 68 | 69 | for i in range(n_videos): 70 | 71 | label = video_labels[i] 72 | label_id = class_to_idx[label] 73 | 74 | video_path = os.path.join(video_dir,video_names[i] + ".hdf5") 75 | audio_path = os.path.join(self.audio_dir,video_names[i] + ".pkl") 76 | if not os.path.exists(video_path) or not os.path.exists(audio_path): 77 | continue 78 | 79 | sample = { 80 | 'video': video_names[i], 81 | 'label': label_id, 82 | } 83 | 84 | dataset.append(sample) 85 | return dataset,idx_to_class,n_videos 86 | 87 | def add_mask_visual(self, image, ratio): 88 | patch_w = 10 89 | patch_l = 10 90 | w_num = int(224 / patch_w) 91 | l_num = int(224 / patch_l) 92 | total_num = w_num * l_num 93 | patch_num = int(total_num * ratio) 94 | # print(total_num, patch_num) 95 | patch_list = np.random.choice(total_num, patch_num, replace=False) 96 | for index in patch_list: 97 | patch_x = index % w_num * patch_w 98 | patch_y = int(index / w_num) * patch_l 99 | 100 | image[:, patch_x:patch_x+patch_w, patch_y:patch_y+patch_l] = 0.0 101 | 102 | return image 103 | 104 | def add_mask_audio(self, image, ratio): 105 | patch_w = 10 106 | patch_l = 10 107 | w_num = int(224 / patch_w) 108 | l_num = int(224 / patch_l) 109 | total_num = w_num * l_num 110 | patch_num = int(total_num * ratio) 111 | # print(total_num, patch_num) 112 | patch_list = np.random.choice(total_num, patch_num, replace=False) 113 | for index in patch_list: 114 | patch_x = index % w_num * patch_w 115 | patch_y = int(index / w_num) * patch_l 116 | 117 | image[patch_x:patch_x+patch_w, patch_y:patch_y+patch_l] = 0.0 118 | 119 | return image 120 | 121 | def __len__(self): 122 | return len(self.dataset) 123 | 124 | 125 | def __loading(self, path, video_name): 126 | 127 | clip=None 128 | try: 129 | clip = self.loader(path) 130 | 
except Exception as e: 131 | print("path {} has error".format(path)) 132 | 133 | len_clip = len(clip) 134 | clip = [clip[0],clip[int((len_clip-1)/2)],clip[len_clip-1]] 135 | 136 | if self.spatial_transform is not None: 137 | self.spatial_transform.randomize_parameters() 138 | clip = [self.spatial_transform(img) for img in clip] 139 | if self.visual_drop>0.0: 140 | clip=[self.add_mask_visual(img,self.visual_drop) for img in clip] 141 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3) # c t h w 142 | return clip 143 | 144 | def __load_audio(self,audio_path): 145 | with open(audio_path,"rb") as f: 146 | audio = pickle.load(f) 147 | 148 | if self.audio_drop>0.0: 149 | audio=self.add_mask_audio(audio,self.audio_drop) 150 | return audio 151 | 152 | def __getitem__(self, index): 153 | 154 | video_name = self.dataset[index]['video'] 155 | 156 | video_path = os.path.join(self.video_dir,video_name + ".hdf5") 157 | label = self.dataset[index]['label'] 158 | 159 | clip = self.__loading(video_path, video_name) 160 | 161 | audio_path = os.path.join(self.audio_dir,video_name + ".pkl") 162 | 163 | audio = self.__load_audio(audio_path) 164 | 165 | return audio,clip,label 166 | 167 | 168 | 169 | class VisualDataset(data.Dataset): 170 | def __init__(self, 171 | annotation_path=os.path.join(PROJECT_DIR,'dataset/KS_train_val.json'), 172 | mode='training', 173 | spatial_transform=None, 174 | video_loader = None 175 | ): 176 | 177 | self.video_dir = HDF5_DIR 178 | self.audio_dir = PKL_DIR 179 | 180 | self.dataset,self.idx_to_class,self.n_videos = self.__make_dataset(self.video_dir,annotation_path,mode) 181 | 182 | self.spatial_transform = spatial_transform 183 | 184 | self.loader = video_loader 185 | 186 | 187 | def __make_dataset(self,video_dir,annotation_path,subset): 188 | with open(annotation_path) as f: 189 | annotation_data = json.load(f) 190 | class_labels = annotation_data['labels'] 191 | annotation_data = annotation_data['database'] 192 | 193 | video_names , video_labels = get_dataset(annotation_data,subset) 194 | 195 | class_to_idx = {label : i for i,label in enumerate(class_labels)} 196 | idx_to_class = {i : label for i,label in enumerate(class_labels)} 197 | 198 | n_videos = len(video_names) 199 | 200 | dataset = [] 201 | max_len = 0 202 | 203 | for i in range(n_videos): 204 | 205 | label = video_labels[i] 206 | label_id = class_to_idx[label] 207 | 208 | video_path = os.path.join(video_dir,video_names[i] + ".hdf5") 209 | audio_path = os.path.join(self.audio_dir,video_names[i] + ".pkl") 210 | if not os.path.exists(video_path) or not os.path.exists(audio_path): 211 | continue 212 | 213 | sample = { 214 | 'video': video_names[i], 215 | 'label': label_id, 216 | } 217 | 218 | dataset.append(sample) 219 | return dataset,idx_to_class,n_videos 220 | 221 | 222 | 223 | def __len__(self): 224 | return len(self.dataset) 225 | 226 | 227 | def __loading(self, path, video_name): 228 | 229 | clip=None 230 | try: 231 | clip = self.loader(path) 232 | except Exception as e: 233 | print("path {} has error".format(path)) 234 | 235 | len_clip = len(clip) 236 | clip = [clip[0],clip[int((len_clip-1)/2)],clip[len_clip-1]] 237 | 238 | if self.spatial_transform is not None: 239 | self.spatial_transform.randomize_parameters() 240 | clip = [self.spatial_transform(img) for img in clip] 241 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3) 242 | return clip 243 | 244 | def __load_audio(self,audio_path): 245 | with open(audio_path,"rb") as f: 246 | audio = pickle.load(f) 247 | 248 | return audio 249 | 250 | def __getitem__(self, 
index): 251 | 252 | video_name = self.dataset[index]['video'] 253 | 254 | video_path = os.path.join(self.video_dir,video_name + ".hdf5") 255 | label = self.dataset[index]['label'] 256 | 257 | clip = self.__loading(video_path, video_name) 258 | 259 | # audio_path = os.path.join(self.audio_dir,video_name + ".pkl") 260 | 261 | # audio = self.__load_audio(audio_path) 262 | 263 | return clip,label 264 | 265 | 266 | class AudioDataset(data.Dataset): 267 | def __init__(self, 268 | annotation_path=os.path.join(PROJECT_DIR,'dataset/KS_train_val.json'), 269 | mode='training', 270 | spatial_transform=None, 271 | video_loader = None 272 | ): 273 | 274 | self.video_dir = HDF5_DIR 275 | self.audio_dir = PKL_DIR 276 | 277 | self.dataset,self.idx_to_class,self.n_videos = self.__make_dataset(self.video_dir,annotation_path,mode) 278 | 279 | self.spatial_transform = spatial_transform 280 | 281 | self.loader = video_loader 282 | 283 | 284 | def __make_dataset(self,video_dir,annotation_path,subset): 285 | with open(annotation_path) as f: 286 | annotation_data = json.load(f) 287 | class_labels = annotation_data['labels'] 288 | annotation_data = annotation_data['database'] 289 | 290 | video_names , video_labels = get_dataset(annotation_data,subset) 291 | 292 | class_to_idx = {label : i for i,label in enumerate(class_labels)} 293 | idx_to_class = {i : label for i,label in enumerate(class_labels)} 294 | 295 | n_videos = len(video_names) 296 | 297 | dataset = [] 298 | max_len = 0 299 | 300 | for i in range(n_videos): 301 | 302 | label = video_labels[i] 303 | label_id = class_to_idx[label] 304 | 305 | video_path = os.path.join(video_dir,video_names[i] + ".hdf5") 306 | audio_path = os.path.join(self.audio_dir,video_names[i] + ".pkl") 307 | if not os.path.exists(video_path) or not os.path.exists(audio_path): 308 | continue 309 | 310 | sample = { 311 | 'video': video_names[i], 312 | 'label': label_id, 313 | } 314 | 315 | dataset.append(sample) 316 | return dataset,idx_to_class,n_videos 317 | 318 | 319 | 320 | def __len__(self): 321 | return len(self.dataset) 322 | 323 | 324 | def __loading(self, path, video_name): 325 | 326 | clip=None 327 | try: 328 | clip = self.loader(path) 329 | except Exception as e: 330 | print("path {} has error".format(path)) 331 | 332 | len_clip = len(clip) 333 | clip = [clip[0],clip[int((len_clip-1)/2)],clip[len_clip-1]] 334 | 335 | if self.spatial_transform is not None: 336 | self.spatial_transform.randomize_parameters() 337 | clip = [self.spatial_transform(img) for img in clip] 338 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3) 339 | return clip 340 | 341 | def __load_audio(self,audio_path): 342 | with open(audio_path,"rb") as f: 343 | audio = pickle.load(f) 344 | 345 | return audio 346 | 347 | def __getitem__(self, index): 348 | 349 | video_name = self.dataset[index]['video'] 350 | 351 | # video_path = os.path.join(self.video_dir,video_name + ".hdf5") 352 | label = self.dataset[index]['label'] 353 | 354 | # clip = self.__loading(video_path, video_name) 355 | 356 | audio_path = os.path.join(self.audio_dir,video_name + ".pkl") 357 | 358 | audio = self.__load_audio(audio_path) 359 | 360 | return audio,label 361 | 362 | 363 | 364 | if __name__=='__main__': 365 | 366 | import argparse 367 | parser=argparse.ArgumentParser() 368 | 369 | parser.add_argument("--decrease_epoch",type = int,default = 10) 370 | parser.add_argument('--sample_size',type = int,default = 112) 371 | parser.add_argument('--sample_t_stride',type = int,default = 1) 372 | parser.add_argument('--train_crop', 373 | 
default='random', 374 | type=str, 375 | help=('Spatial cropping method in training. ' 376 | 'random is uniform. ' 377 | 'corner is selection from 4 corners and 1 center. ' 378 | '(random | corner | center)')) 379 | parser.add_argument('--value_scale', 380 | default=1, 381 | type=int, 382 | help= 383 | 'If 1, range of inputs is [0-1]. If 255, range of inputs is [0-255].') 384 | parser.add_argument("--scale_h", type=int, default=128, 385 | help="Scale image height to") 386 | parser.add_argument("--scale_w", type=int, default=171, 387 | help="Scale image width to") 388 | parser.add_argument('--train_crop_min_scale', 389 | default=0.25, 390 | type=float, 391 | help='Min scale for random cropping in training') 392 | parser.add_argument('--train_crop_min_ratio', 393 | default=0.75, 394 | type=float, 395 | help='Min aspect ratio for random cropping in training') 396 | parser.add_argument('--no_hflip', 397 | action='store_true', 398 | help='If true holizontal flipping is not performed.') 399 | parser.add_argument('--colorjitter', 400 | action='store_true', 401 | help='If true colorjitter is performed.') 402 | parser.add_argument('--train_t_crop', 403 | default='random', 404 | type=str, 405 | help=('Temporal cropping method in training. ' 406 | 'random is uniform. ' 407 | '(random | center)')) 408 | 409 | args=parser.parse_args() 410 | 411 | spatial_transforms=get_spatial_transform(opt=args) 412 | 413 | dataset=KSDataset(video_loader=VideoLoaderHDF5(),spatial_transform=spatial_transforms) 414 | 415 | 416 | -------------------------------------------------------------------------------- /code/dataset/loader.py: -------------------------------------------------------------------------------- 1 | import io 2 | import h5py 3 | import numpy as np 4 | from os import path 5 | from PIL import Image 6 | 7 | 8 | class ImageLoaderPIL(object): 9 | 10 | def __call__(self, path): 11 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 12 | with path.open('rb') as f: 13 | with Image.open(f) as img: 14 | return img.convert('RGB') 15 | 16 | 17 | # class ImageLoaderAccImage(object): 18 | 19 | # def __call__(self, path): 20 | # import accimage 21 | # return accimage.Image(str(path)) 22 | 23 | 24 | class NumpyLoader(object): 25 | 26 | def __call__(self, path): 27 | return np.load(path) 28 | 29 | 30 | class AudioFeatureLoader(object): 31 | # load an audio feature stored as numpy file ('.npy) 32 | def __init__(self): 33 | self.npyloader = NumpyLoader() 34 | 35 | def __call__(self, filename): 36 | if path.isfile(filename): 37 | features = self.npyloader(filename) 38 | else: 39 | features = None 40 | return features 41 | 42 | 43 | class VideoLoader(object): 44 | 45 | def __init__(self, image_name_formatter, image_loader=None): 46 | self.image_name_formatter = image_name_formatter 47 | if image_loader is None: 48 | self.image_loader = ImageLoaderPIL() 49 | else: 50 | self.image_loader = image_loader 51 | 52 | def __call__(self, video_path, frame_indices): 53 | video = [] 54 | for i in frame_indices: 55 | image_path = video_path / self.image_name_formatter(i) 56 | if image_path.exists(): 57 | video.append(self.image_loader(image_path)) 58 | return video 59 | 60 | 61 | 62 | class VideoLoaderHDF5(object): 63 | 64 | def __call__(self, video_path): 65 | with h5py.File(video_path, 'r') as f: 66 | video_data = f['video'] 67 | video = [] 68 | for i in range(len(video_data)): 69 | video.append(Image.open(io.BytesIO(video_data[i - 1]))) 70 | return video 71 | 72 | 73 | class 
VideoLoaderFlowHDF5(object): 74 | 75 | def __init__(self): 76 | self.flows = ['u', 'v'] 77 | 78 | def __call__(self, video_path, frame_indices): 79 | with h5py.File(video_path, 'r') as f: 80 | 81 | flow_data = [] 82 | for flow in self.flows: 83 | flow_data.append(f[f'video_{flow}']) 84 | 85 | video = [] 86 | for i in frame_indices: 87 | if i < len(flow_data[0]): 88 | frame = [ 89 | Image.open(io.BytesIO(video_data[i])) 90 | for video_data in flow_data 91 | ] 92 | frame.append(frame[-1]) # add dummy data into third channel 93 | video.append(Image.merge('RGB', frame)) 94 | return video 95 | -------------------------------------------------------------------------------- /code/dataset/spatial_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from torchvision.transforms import transforms 4 | from torchvision.transforms import functional as F 5 | from PIL import Image 6 | 7 | 8 | class Compose(transforms.Compose): 9 | 10 | def randomize_parameters(self): 11 | for t in self.transforms: 12 | t.randomize_parameters() 13 | 14 | 15 | class ToTensor(transforms.ToTensor): 16 | 17 | def randomize_parameters(self): 18 | pass 19 | 20 | 21 | class Normalize(transforms.Normalize): 22 | 23 | def randomize_parameters(self): 24 | pass 25 | 26 | 27 | class ScaleValue(object): 28 | 29 | def __init__(self, s): 30 | self.s = s 31 | 32 | def __call__(self, tensor): 33 | tensor *= self.s 34 | return tensor 35 | 36 | def randomize_parameters(self): 37 | pass 38 | 39 | 40 | class Resize(transforms.Resize): 41 | 42 | def randomize_parameters(self): 43 | pass 44 | 45 | 46 | class RandomCrop(transforms.RandomCrop): 47 | 48 | def randomize_parameters(self): 49 | pass 50 | 51 | 52 | # class Scale(transforms.Scale): 53 | 54 | # def randomize_parameters(self): 55 | # pass 56 | 57 | 58 | class CenterCrop(transforms.CenterCrop): 59 | 60 | def randomize_parameters(self): 61 | pass 62 | 63 | 64 | class CornerCrop(object): 65 | 66 | def __init__(self, 67 | size, 68 | crop_position=None, 69 | crop_positions=['c', 'tl', 'tr', 'bl', 'br']): 70 | self.size = size 71 | self.crop_position = crop_position 72 | self.crop_positions = crop_positions 73 | 74 | if crop_position is None: 75 | self.randomize = True 76 | else: 77 | self.randomize = False 78 | self.randomize_parameters() 79 | 80 | def __call__(self, img): 81 | image_width = img.size[0] 82 | image_height = img.size[1] 83 | 84 | h, w = (self.size, self.size) 85 | if self.crop_position == 'c': 86 | i = int(round((image_height - h) / 2.)) 87 | j = int(round((image_width - w) / 2.)) 88 | elif self.crop_position == 'tl': 89 | i = 0 90 | j = 0 91 | elif self.crop_position == 'tr': 92 | i = 0 93 | j = image_width - self.size 94 | elif self.crop_position == 'bl': 95 | i = image_height - self.size 96 | j = 0 97 | elif self.crop_position == 'br': 98 | i = image_height - self.size 99 | j = image_width - self.size 100 | 101 | img = F.crop(img, i, j, h, w) 102 | 103 | return img 104 | 105 | def randomize_parameters(self): 106 | if self.randomize: 107 | self.crop_position = self.crop_positions[random.randint( 108 | 0, 109 | len(self.crop_positions) - 1)] 110 | 111 | def __repr__(self): 112 | return self.__class__.__name__ + '(size={0}, crop_position={1}, randomize={2})'.format( 113 | self.size, self.crop_position, self.randomize) 114 | 115 | 116 | class RandomHorizontalFlip(transforms.RandomHorizontalFlip): 117 | 118 | def __init__(self, p=0.5): 119 | super().__init__(p) 120 | self.randomize_parameters() 121 | 122 | def 
__call__(self, img): 123 | """ 124 | Args: 125 | img (PIL.Image): Image to be flipped. 126 | Returns: 127 | PIL.Image: Randomly flipped image. 128 | """ 129 | if self.random_p < self.p: 130 | return F.hflip(img) 131 | return img 132 | 133 | def randomize_parameters(self): 134 | self.random_p = random.random() 135 | 136 | 137 | class MultiScaleCornerCrop(object): 138 | 139 | def __init__(self, 140 | size, 141 | scales, 142 | crop_positions=['c', 'tl', 'tr', 'bl', 'br'], 143 | interpolation=Image.BILINEAR): 144 | self.size = size 145 | self.scales = scales 146 | self.interpolation = interpolation 147 | self.crop_positions = crop_positions 148 | 149 | self.randomize_parameters() 150 | 151 | def __call__(self, img): 152 | short_side = min(img.size[0], img.size[1]) 153 | crop_size = int(short_side * self.scale) 154 | self.corner_crop.size = crop_size 155 | 156 | img = self.corner_crop(img) 157 | return img.resize((self.size, self.size), self.interpolation) 158 | 159 | def randomize_parameters(self): 160 | self.scale = self.scales[random.randint(0, len(self.scales) - 1)] 161 | crop_position = self.crop_positions[random.randint( 162 | 0, 163 | len(self.crop_positions) - 1)] 164 | 165 | self.corner_crop = CornerCrop(None, crop_position) 166 | 167 | def __repr__(self): 168 | return self.__class__.__name__ + '(size={0}, scales={1}, interpolation={2})'.format( 169 | self.size, self.scales, self.interpolation) 170 | 171 | 172 | class RandomResizedCrop(transforms.RandomResizedCrop): 173 | 174 | def __init__(self, 175 | size, 176 | scale=(0.08, 1.0), 177 | ratio=(3. / 4., 4. / 3.), 178 | interpolation=Image.BILINEAR): 179 | super().__init__(size, scale, ratio, interpolation) 180 | self.randomize_parameters() 181 | 182 | def __call__(self, img): 183 | if self.randomize: 184 | self.random_crop = self.get_params(img, self.scale, self.ratio) 185 | self.randomize = False 186 | 187 | i, j, h, w = self.random_crop 188 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) 189 | 190 | def randomize_parameters(self): 191 | self.randomize = True 192 | 193 | 194 | class ColorJitter(transforms.ColorJitter): 195 | 196 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 197 | super().__init__(brightness, contrast, saturation, hue) 198 | self.randomize_parameters() 199 | 200 | def __call__(self, img): 201 | if self.randomize: 202 | self.transform = self.get_params(self.brightness, self.contrast, 203 | self.saturation, self.hue) 204 | self.randomize = False 205 | 206 | return self.transform(img) 207 | 208 | def randomize_parameters(self): 209 | self.randomize = True 210 | 211 | 212 | class PickFirstChannels(object): 213 | 214 | def __init__(self, n): 215 | self.n = n 216 | 217 | def __call__(self, tensor): 218 | return tensor[:self.n, :, :] 219 | 220 | def randomize_parameters(self): 221 | pass 222 | 223 | 224 | def get_normalize_method(): 225 | mean=[0.485, 0.456, 0.406] 226 | std=[0.229, 0.224, 0.225] 227 | 228 | return Normalize(mean,std) 229 | 230 | def get_spatial_transform(opt): 231 | assert opt.train_crop in ['random', 'corner', 'center', 'other'] 232 | spatial_transform = [] 233 | if opt.train_crop == 'random': 234 | spatial_transform.append( 235 | RandomResizedCrop( 236 | opt.sample_size, (opt.train_crop_min_scale, 1.0), 237 | (opt.train_crop_min_ratio, 1.0 / opt.train_crop_min_ratio))) 238 | elif opt.train_crop == 'corner': 239 | scales = [1.0] 240 | scale_step = 1 / (2**(1 / 4)) 241 | for _ in range(1, 5): 242 | scales.append(scales[-1] * scale_step) 243 | 
spatial_transform.append(MultiScaleCornerCrop(opt.sample_size, scales)) 244 | elif opt.train_crop == 'center': 245 | spatial_transform.append(Resize(opt.sample_size)) 246 | spatial_transform.append(CenterCrop(opt.sample_size)) 247 | elif opt.train_crop == 'other': 248 | print('other') 249 | spatial_transform.append(Resize((opt.scale_h, opt.scale_w))) 250 | spatial_transform.append(RandomCrop(opt.sample_size)) 251 | 252 | normalize = get_normalize_method() 253 | if not opt.no_hflip: 254 | spatial_transform.append(RandomHorizontalFlip()) 255 | if opt.colorjitter: 256 | spatial_transform.append(ColorJitter()) 257 | spatial_transform.append(ToTensor()) 258 | 259 | spatial_transform.append(ScaleValue(opt.value_scale)) 260 | spatial_transform.append(normalize) 261 | spatial_transform = Compose(spatial_transform) 262 | 263 | return spatial_transform 264 | 265 | 266 | def get_val_spatial_transforms(opt): 267 | normalize=get_normalize_method() 268 | if opt.train_crop=='other': 269 | spatial_transforms=[ 270 | Resize((opt.scale_h,opt.scale_w)), 271 | RandomCrop(opt.sample_size), 272 | ToTensor() 273 | ] 274 | else: 275 | spatial_transforms=[ 276 | Resize(opt.sample_size), 277 | CenterCrop(opt.sample_size), 278 | ToTensor() 279 | ] 280 | spatial_transforms.extend([ScaleValue(opt.value_scale),normalize]) 281 | spatial_transforms=Compose(spatial_transforms) 282 | 283 | return spatial_transforms -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import torch 4 | import numpy as np 5 | import random 6 | import json 7 | import os 8 | from os.path import join 9 | import sys 10 | from tqdm import tqdm 11 | from torch import nn 12 | from torch.utils.data import DataLoader 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | from models.Classifier import Classifier 16 | from config import Config 17 | from dataset.KS import KSDataset 18 | from utils.log_file import Logger 19 | from datetime import datetime 20 | 21 | from dataset.spatial_transforms import get_spatial_transform,get_val_spatial_transforms 22 | from dataset.loader import VideoLoaderHDF5 23 | 24 | from sklearn.metrics import average_precision_score 25 | import torch.nn.functional as F 26 | 27 | 28 | TIMESTAMP = "{0:%Y-%m-%d-%H-%M-%S/}".format(datetime.now()) 29 | 30 | 31 | def get_arguments(): 32 | parser = argparse.ArgumentParser() 33 | 34 | parser.add_argument('--use_modulation',action='store_true',help='use gradient modulation') 35 | parser.add_argument('--use_adam_drop',action='store_true',help='use adam-drop') 36 | parser.add_argument('--modulation', default='OGM_GE', type=str,choices=['Normal', 'OGM', 'OGM_GE']) 37 | parser.add_argument('--use_OGM_plus',action='store_true') 38 | parser.add_argument('--fusion_method', default='concat', type=str,choices=['sum', 'concat', 'gated']) 39 | parser.add_argument('--train', action='store_true', help='turn on train mode') 40 | parser.add_argument('--resume_model',action='store_true',help='whether to resume model') 41 | parser.add_argument('--resume_model_path') 42 | parser.add_argument('--q_base',type=float,default=0.5) 43 | parser.add_argument('--lam',type=float,default=0.5) 44 | parser.add_argument('--p_exe',type=float,default=0.7) 45 | parser.add_argument('--alpha',type=float,default=1.0) 46 | parser.add_argument('--modulation_starts',type=int,default=0) 47 | parser.add_argument('--modulation_ends',type=int,default=80) 48 | 
parser.add_argument('--audio_drop',type=float,default=0.0) 49 | parser.add_argument('--visual_drop',type=float,default=0.0) 50 | parser.add_argument('--exp_name',type=str,default='exp') 51 | 52 | return parser.parse_args() 53 | 54 | 55 | def setup_seed(seed): 56 | torch.manual_seed(seed) 57 | torch.cuda.manual_seed_all(seed) 58 | np.random.seed(seed) 59 | random.seed(seed) 60 | torch.backends.cudnn.deterministic = True 61 | 62 | 63 | def weight_init(m): 64 | if isinstance(m, nn.Linear): 65 | nn.init.xavier_normal_(m.weight) 66 | nn.init.constant_(m.bias, 0) 67 | elif isinstance(m, nn.Conv2d): 68 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 69 | elif isinstance(m, nn.BatchNorm2d): 70 | nn.init.constant_(m.weight, 1) 71 | nn.init.constant_(m.bias, 0) 72 | 73 | weight_a=0.36 74 | weight_v=0.27 75 | weight_av=0.37 76 | 77 | def train(cfg,epoch,model,device,dataloader,optimizer,scheduler,tb=None): 78 | loss_fn=nn.CrossEntropyLoss().to(device) 79 | relu=nn.ReLU(inplace=True) 80 | tanh=nn.Tanh() 81 | model.train() 82 | total_loss=0 83 | total_loss_1=0 84 | total_loss_2=0 85 | with tqdm(total=len(dataloader), desc=f"Train-epoch-{epoch}") as pbar: 86 | for step, (spec,image,label) in enumerate(dataloader): 87 | spec=spec.to(device) # b,h,w 88 | image=image.to(device) # b,c,t,h,w 89 | label=label.to(device) 90 | optimizer.zero_grad() 91 | warm_up=1 if epoch<=5 else 0 92 | # warm_up=0 93 | out_1,out_2,out,update_flag,performance_1,performance_2=model(spec.unsqueeze(1).float(),image.float(),label,warm_up) 94 | 95 | if warm_up==0 and cfg.use_adam_drop: 96 | if torch.sum(update_flag,dim=0)==0: 97 | continue 98 | select_mask=update_flag!=0 99 | label=label[select_mask] 100 | out_1=out_1[select_mask] 101 | out_2=out_2[select_mask] 102 | 103 | 104 | loss=loss_fn(out,label) 105 | loss_1=loss_fn(out_1,label) 106 | loss_2=loss_fn(out_2,label) 107 | total_loss+=loss.item() 108 | total_loss_1+=loss_1.item() 109 | total_loss_2+=loss_2.item() 110 | 111 | # if warm_up==0: 112 | # loss=loss*weight_av+loss_1*weight_a+loss_2*weight_v 113 | 114 | loss.backward() 115 | 116 | if warm_up==0 and cfg.use_modulation: 117 | # log.logger.info('per_1:{} per_2:{} '.format(performance_1,performance_2)) 118 | coeff_1,coeff_2=None,None 119 | radio_1=performance_1/performance_2 120 | radio_2=performance_2/performance_1 121 | # if cfg.form=='/': 122 | # radio_1=performance_1/performance_2 123 | # radio_2=performance_2/performance_1 124 | # else: 125 | # radio_1=performance_1-performance_2 126 | # radio_2=performance_2-performance_1 127 | 128 | if cfg.use_OGM_plus: 129 | if radio_1>1: 130 | # coeff_2=1+tanh(cfg.alpha*relu(radio_1)) 131 | coeff_2=4 132 | coeff_1=1 133 | else: 134 | coeff_2=1 135 | # coeff_1=1+tanh(cfg.alpha*relu(radio_2)) 136 | coeff_1=4 137 | else: 138 | if radio_1>1: 139 | coeff_1=1-tanh(cfg.alpha*relu(radio_1)) 140 | # if cfg.func=='tanh': 141 | # coeff_1=1-tanh(cfg.alpha*relu(radio_1)) 142 | # else: 143 | # coeff_1=1-sigmoid(cfg.alpha*relu(radio_1)) 144 | 145 | coeff_2=1 146 | else: 147 | coeff_1=1 148 | coeff_2=1-tanh(cfg.alpha*relu(radio_2)) 149 | # if cfg.func=='tanh': 150 | # coeff_2=1-tanh(cfg.alpha*relu(radio_2)) 151 | # else: 152 | # coeff_2=1-sigmoid(cfg.alpha*relu(radio_2)) 153 | 154 | if cfg.modulation_starts<=epoch<=cfg.modulation_ends: 155 | for name,parms in model.named_parameters(): 156 | layer_name=str(name).split('.')[0] 157 | if 'encoder_1' in layer_name and parms.grad is not None and len(parms.grad.size()) == 4: 158 | if cfg.modulation == 'OGM_GE': 159 | parms.grad = 
parms.grad * coeff_1 + \ 160 | torch.zeros_like(parms.grad).normal_(0, parms.grad.std().item() + 1e-8) 161 | elif cfg.modulation == 'OGM': 162 | parms.grad *= coeff_1 163 | 164 | if 'encoder_2' in layer_name and parms.grad is not None and len(parms.grad.size()) == 4: 165 | if cfg.modulation == 'OGM_GE': 166 | parms.grad = parms.grad * coeff_2 + \ 167 | torch.zeros_like(parms.grad).normal_(0, parms.grad.std().item() + 1e-8) 168 | elif cfg.modulation == 'OGM': 169 | parms.grad *= coeff_2 170 | 171 | optimizer.step() 172 | pbar.update(1) 173 | 174 | scheduler.step() 175 | 176 | return total_loss/len(dataloader),total_loss_1/len(dataloader),total_loss_2/len(dataloader) 177 | 178 | 179 | def val(model,device,dataloader): 180 | softmax=nn.Softmax(dim=1) 181 | sum_all=0 182 | sum_1=0 183 | sum_2=0 184 | tot=0 185 | all_out=[] 186 | all_label=[] 187 | with torch.no_grad(): 188 | model.eval() 189 | for step,(spec,img,label) in enumerate(dataloader): 190 | spec=spec.to(device) 191 | img=img.to(device) 192 | label=label.to(device) 193 | out_1,out_2,out,update_flag,performance_1,performance_2=model(spec.unsqueeze(1).float(),img.float(),label,warm_up=1) 194 | prediction=softmax(out) 195 | pred_1=softmax(out_1) 196 | pred_2=softmax(out_2) 197 | tot+=img.shape[0] 198 | sum_all+=torch.sum(torch.argmax(prediction,dim=1)==label).item() 199 | sum_1+=torch.sum(torch.argmax(pred_1,dim=1)==label).item() 200 | sum_2+=torch.sum(torch.argmax(pred_2,dim=1)==label).item() 201 | 202 | for i in range(label.shape[0]): 203 | all_out.append(prediction[i].cpu().data.numpy()) 204 | ss=torch.zeros(31) 205 | ss[label[i]]=1 206 | all_label.append(ss.numpy()) 207 | 208 | all_out=np.array(all_out) 209 | all_label=np.array(all_label) 210 | mAP=average_precision_score(all_label,all_out) 211 | 212 | 213 | return mAP,sum_all/tot,sum_1/tot,sum_2/tot 214 | 215 | 216 | def write2txt(fp,info,mode='a'): 217 | with open(fp,mode=mode) as f: 218 | f.write(info) 219 | f.write('\n') 220 | 221 | 222 | def main(): 223 | # job_id=datetime.now().strftime("%Y-%m-%d-%H-%M-%S") 224 | cfg = Config() 225 | args=get_arguments() 226 | cfg.parse(vars(args)) 227 | setup_seed(cfg.random_seed) 228 | 229 | job_name=args.exp_name 230 | cur_dir=os.path.join('results',job_name) 231 | os.makedirs(cur_dir,exist_ok=True) 232 | 233 | # log=Logger(os.path.join(cur_dir,'log.log'),level='info') 234 | writer=None 235 | if cfg.use_tensorboard: 236 | writer_path=os.path.join(cur_dir,'tensorboard') 237 | os.makedirs(writer_path,exist_ok=True) 238 | writer=SummaryWriter(writer_path) 239 | 240 | saved_data=vars(cfg) 241 | cmd=' '.join(sys.argv) 242 | saved_data.update({'cmd':cmd}) 243 | saved_data=json.dumps(saved_data,indent=4) 244 | with open(os.path.join(cur_dir,'config.json'),'w') as f: 245 | f.write(saved_data) 246 | 247 | device=torch.device('cuda') 248 | 249 | spatial_transforms=get_spatial_transform(opt=cfg) 250 | val_spatial_transforms=get_val_spatial_transforms(opt=cfg) 251 | train_dataset=KSDataset(mode='training',spatial_transform=spatial_transforms,video_loader=VideoLoaderHDF5()) 252 | test_dataset=KSDataset(mode='testing',spatial_transform=val_spatial_transforms,video_loader=VideoLoaderHDF5(),audio_drop=cfg.audio_drop,visual_drop=cfg.visual_drop) 253 | 254 | train_loader=DataLoader(train_dataset,batch_size=cfg.batch_size,shuffle=True,num_workers=32,pin_memory=True) 255 | test_loader=DataLoader(test_dataset,batch_size=cfg.batch_size,shuffle=False,num_workers=32,pin_memory=True) 256 | 257 | model=Classifier(cfg,device=device) 258 | 259 | if cfg.resume_model: 
260 | state_dict=torch.load(cfg.resume_model_path,map_location='cuda') 261 | model.load_state_dict(state_dict=state_dict) 262 | else: 263 | model.apply(weight_init) 264 | 265 | model.to(device) 266 | 267 | optimizer=torch.optim.AdamW(model.parameters(),lr=cfg.learning_rate,weight_decay=0.01) 268 | scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=cfg.lr_decay_step,gamma=cfg.lr_decay_ratio) 269 | 270 | start_epoch=-1 271 | best_acc=0.0 272 | logger_path=join(cur_dir,'log.txt') 273 | 274 | if cfg.train: 275 | for epoch in range(start_epoch+1,cfg.epochs): 276 | loss,loss_1,loss_2=train(cfg,epoch,model,device,train_loader,optimizer,scheduler,tb=writer) 277 | mAP,acc,acc_1,acc_2=val(model,device,test_loader) 278 | # log.logger.info('epoch:{} acc:{:.4f} acc_1:{:.4f} acc_2:{:.4f} mAP:{:.4f}'.format(epoch,acc,acc_1,acc_2,mAP)) 279 | write2txt(fp=logger_path,info=f'epoch:{epoch} acc:{acc:.4f} acc_1:{acc_1:.4f} acc_2:{acc_2:.4f} mAP:{mAP:.4f}') 280 | if writer is not None: 281 | writer.add_scalars(main_tag='Loss',tag_scalar_dict={'loss':loss,'loss_1':loss_1,'loss_2':loss_2},global_step=epoch) 282 | writer.add_scalars(main_tag='Acc',tag_scalar_dict={'acc':acc,'acc_1':acc_1,'acc_2':acc_2},global_step=epoch) 283 | 284 | if acc>best_acc: 285 | best_acc=acc 286 | saved_data={} 287 | saved_data['epoch']=epoch 288 | saved_data['acc']=acc 289 | saved_data['mAP']=mAP 290 | saved_data['acc_1']=acc_1 291 | saved_data['acc_2']=acc_2 292 | saved_data=json.dumps(saved_data,indent=4) 293 | 294 | with open(os.path.join(cur_dir,'best_model.json'),'w') as f: 295 | f.write(saved_data) 296 | 297 | torch.save(model.state_dict(),os.path.join(cur_dir,'best_model.pth')) 298 | else: 299 | mAP,acc,acc_1,acc_2=val(model,device,test_loader) 300 | print('mAP:{} Acc:{}'.format(mAP,acc)) 301 | 302 | 303 | if __name__ == "__main__": 304 | main() 305 | -------------------------------------------------------------------------------- /code/models/Audio_Classifier.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | from .Resnet_18 import resnet18 5 | 6 | class Classifier(nn.Module): 7 | 8 | def __init__(self,cfg,device='cuda:0'): 9 | super().__init__() 10 | 11 | self.encoder_1=resnet18(modality=cfg.modality[0]) 12 | 13 | self.cfg=cfg 14 | self.device=device 15 | 16 | self.linear=nn.Linear(512,31) 17 | 18 | def forward(self,mod_1): 19 | out_1=self.encoder_1(mod_1) 20 | out_1=F.adaptive_avg_pool2d(out_1,1) 21 | out_1=out_1.squeeze(2).squeeze(2) # [B,2048] 22 | 23 | out_1=self.linear(out_1) 24 | return out_1 25 | 26 | 27 | 28 | if __name__ == "__main__": 29 | pass 30 | -------------------------------------------------------------------------------- /code/models/BasicModule.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | 5 | import time 6 | 7 | class BasicModule(nn.Module): 8 | 9 | def __init__(self) -> None: 10 | super().__init__() 11 | self.model_name=str(type(self)) 12 | 13 | def load(self,path): 14 | self.load_state_dict(torch.load(path)) 15 | 16 | def save(self,name=None): 17 | if name is None: 18 | name=time.strftime('checkpoints/'+self.model_name+'_'+'%m%d_%H:%M:%S.pth') 19 | torch.save(self.state_dict(),name) 20 | 21 | return name 22 | 23 | if __name__ == "__main__": 24 | pass 25 | -------------------------------------------------------------------------------- /code/models/Classifier.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from .fusion_model import ConcatFusion,SumFusion,GatedFusion,LMF 7 | from .Resnet_18 import resnet18 8 | 9 | class custom_autograd(torch.autograd.Function): 10 | 11 | @staticmethod 12 | def forward(ctx,input,theta): 13 | ctx.save_for_backward(input,theta) 14 | return input/(1-theta.item()) 15 | 16 | @staticmethod 17 | def backward(ctx,grad_output): 18 | input,theta=ctx.saved_tensors 19 | input_grad=1/(1-theta.item())*grad_output.clone() 20 | 21 | return input_grad,theta 22 | 23 | 24 | class Modality_drop(): 25 | 26 | def __init__(self,dim_list,p_exe=0.7,device='cuda'): 27 | self.dim_list=dim_list 28 | self.p_exe=p_exe 29 | self.device=device 30 | 31 | def execute_drop(self,fead_list,q): 32 | B = fead_list[0].shape[0] 33 | D = fead_list[0].shape[1] 34 | exe_drop = torch.tensor(np.random.rand(1)).to(device=self.device) >= 1-self.p_exe 35 | if not exe_drop: 36 | return fead_list, torch.ones([B],dtype=torch.int32,device=self.device) 37 | 38 | num_mod=len(fead_list) 39 | d_sum=sum(self.dim_list) 40 | q_sum=sum(self.dim_list*q) 41 | theta=q_sum/d_sum 42 | # p_sum=sum(self.dim_list*(1-q)) 43 | # theta=p_sum/d_sum 44 | 45 | mask=torch.distributions.Bernoulli(1-q).sample([B,1]).permute(2,1,0).contiguous().reshape(num_mod,B,-1).to(device=self.device) # [2,B,1] 46 | # print(f'mask:{mask}') 47 | concat_list=torch.stack(fead_list,dim=0) # [2,B,D] 48 | concat_list=torch.mul(concat_list,mask) 49 | concat_list=custom_autograd.apply(concat_list,theta) 50 | mask=torch.transpose(mask,0,1).squeeze(-1) # [B,2] 51 | update_flag=torch.sum(mask,dim=1)>0 52 | cleaned_fea=torch.masked_select(concat_list,update_flag.unsqueeze(-1)).reshape(num_mod,-1,D) 53 | cleaned_fea=torch.chunk(cleaned_fea,num_mod,dim=0) ] 54 | cleaned_fea=[_.squeeze(0) for _ in cleaned_fea] # [B,D] 55 | return cleaned_fea,update_flag 56 | 57 | 58 | def calcu_q(performance_1,performance_2,q_base,fix_lambda): 59 | q=torch.tensor([0.0,0.0]) 60 | relu = nn.ReLU(inplace=True) 61 | ratio_1=torch.tanh(relu(performance_1/performance_2-1)) 62 | ratio_2=torch.tanh(relu(performance_2/performance_1-1)) 63 | 64 | lamda = fix_lambda 65 | 66 | 67 | q[0] = q_base * (1 + lamda * ratio_1) if ratio_1>0 else 0 68 | q[1] = q_base * (1 + lamda * ratio_2) if ratio_2>0 else 0 69 | 70 | q=torch.clip(q,0.0,1.0) 71 | 72 | return q 73 | 74 | 75 | class Classifier(nn.Module): 76 | 77 | def __init__(self,cfg,device='cuda'): 78 | super().__init__() 79 | 80 | self.encoder_1=resnet18(modality='audio') 81 | self.encoder_2=resnet18(modality='visual') 82 | 83 | self.cfg=cfg 84 | self.device=device 85 | 86 | self.softmax=nn.Softmax(dim=1) 87 | self.fusion_model=ConcatFusion(in_c_x=512,in_c_y=512,out_c=31) 88 | 89 | if self.cfg.use_adam_drop: 90 | self.modality_drop=Modality_drop(dim_list=torch.tensor(self.cfg.d),p_exe=self.cfg.p_exe,device=self.device) 91 | 92 | 93 | def forward(self,mod_1,mod_2,label,warm_up=1): 94 | out_1=self.encoder_1(mod_1) 95 | out_2=self.encoder_2(mod_2) # [B,T,C,H,W]--> [B,2048,2,2] 96 | 97 | _,C,H,W=out_2.shape 98 | B=out_1.shape[0] 99 | 100 | out_2=out_2.reshape(B,-1,C,H,W).permute(0,2,1,3,4) 101 | 102 | out_1=F.adaptive_avg_pool2d(out_1,1) 103 | out_2=F.adaptive_avg_pool3d(out_2,1) 104 | 105 | out_1=out_1.squeeze(2).squeeze(2) # [B,2048] 106 | out_2=out_2.squeeze(2).squeeze(2).squeeze(2) # [B,2048] 107 | 108 | performance_1=None 109 | performance_2=None 110 | 
t1,t2=None,None 111 | 112 | w=self.fusion_model.fxy.weight.clone().detach() 113 | b=self.fusion_model.fxy.bias.clone().detach() 114 | 115 | # if self.cfg.t1_bias==0.5: 116 | # t1_bias=b/2 117 | # elif self.cfg.t1_bias==0.0: 118 | # t1_bias=0.0 119 | # elif self.cfg.t1_bias==0.3: 120 | # t1_bias=b/3 121 | # elif self.cfg.t1_bias==0.6: 122 | # t1_bias=2*b/3 123 | # elif self.cfg.t1_bias==1.0: 124 | # t1_bias=b 125 | t1_bias=b/2 126 | 127 | # if self.cfg.t2_bias==0.5: 128 | # t2_bias=b/2 129 | # elif self.cfg.t2_bias==0.0: 130 | # t2_bias=0.0 131 | # elif self.cfg.t2_bias==0.3: 132 | # t2_bias=b/3 133 | # elif self.cfg.t2_bias==0.6: 134 | # t2_bias=2*b/3 135 | # elif self.cfg.t2_bias==1.0: 136 | # t2_bias=b 137 | t2_bias=b/2 138 | 139 | t1=torch.mm(out_1,torch.transpose(w[:,:512],0,1))+t1_bias 140 | t2=torch.mm(out_2,torch.transpose(w[:,512:],0,1))+t2_bias 141 | 142 | performance_1=sum([self.softmax(t1)[i][int(label[i].item())] for i in range(t1.shape[0])]) 143 | performance_2=sum([self.softmax(t2)[i][int(label[i].item())] for i in range(t2.shape[0])]) 144 | 145 | if warm_up==0 and self.cfg.use_adam_drop: 146 | self.q=calcu_q(performance_1,performance_2,self.cfg.q_base,fix_lambda=self.cfg.lam) 147 | cleaned_fea,update_flag=self.modality_drop.execute_drop([out_1,out_2],self.q) 148 | cleaned_fae_1,cleaned_fea_2,out=self.fusion_model(cleaned_fea[0],cleaned_fea[1]) 149 | return t1,t2,out,update_flag,performance_1,performance_2 150 | 151 | else: 152 | x,y,out=self.fusion_model(out_1,out_2) 153 | return t1,t2,out,torch.ones([B],dtype=torch.int32,device=self.device),performance_1,performance_2 154 | 155 | 156 | 157 | class AVClassifier_gb(nn.Module): 158 | def __init__(self, n_classes): 159 | super(AVClassifier_gb, self).__init__() 160 | self.n_classes = n_classes 161 | 162 | self.encoder_1=resnet18(modality='audio') 163 | self.encoder_2=resnet18(modality='visual') 164 | 165 | self.fusion_model = ConcatFusion(512,512,31) 166 | 167 | self.audio_head = nn.Linear(512, n_classes) 168 | self.visual_head = nn.Linear(512, n_classes) 169 | 170 | 171 | def forward(self, audio, visual): 172 | out_1=self.encoder_1(audio) 173 | out_2=self.encoder_2(visual) # [B,T,C,H,W]--> [B,2048,2,2] 174 | 175 | _,C,H,W=out_2.shape 176 | B=out_1.shape[0] 177 | 178 | out_2=out_2.reshape(B,-1,C,H,W).permute(0,2,1,3,4) 179 | 180 | out_1=F.adaptive_avg_pool2d(out_1,1) 181 | out_2=F.adaptive_avg_pool3d(out_2,1) 182 | 183 | out_1=out_1.squeeze(2).squeeze(2) # [B,2048] 184 | out_2=out_2.squeeze(2).squeeze(2).squeeze(2) 185 | 186 | x,y,out=self.fusion_model(out_1,out_2) 187 | return x,y,out 188 | 189 | -------------------------------------------------------------------------------- /code/models/Resnet_18.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 5 | """3x3 convolution with padding""" 6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 7 | padding=dilation, groups=groups, bias=False, dilation=dilation) 8 | 9 | 10 | def conv1x1(in_planes, out_planes, stride=1): 11 | """1x1 convolution""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 19 | base_width=64, dilation=1, norm_layer=None): 20 | super(BasicBlock, self).__init__() 21 | if norm_layer is None: 22 | norm_layer = nn.BatchNorm2d 23 
| if groups != 1 or base_width != 64: 24 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 25 | if dilation > 1: 26 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 27 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = norm_layer(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = norm_layer(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | identity = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | identity = self.downsample(x) 48 | 49 | out += identity 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class ResNet(nn.Module): 56 | 57 | def __init__(self, block, layers, modality, num_classes=1000, pool='avgpool', zero_init_residual=False, 58 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 59 | norm_layer=None): 60 | super(ResNet, self).__init__() 61 | self.modality = modality 62 | self.pool = pool 63 | if norm_layer is None: 64 | norm_layer = nn.BatchNorm2d 65 | self._norm_layer = norm_layer 66 | 67 | self.inplanes = 64 68 | self.dilation = 1 69 | if replace_stride_with_dilation is None: 70 | # each element in the tuple indicates if we should replace 71 | # the 2x2 stride with a dilated convolution instead 72 | replace_stride_with_dilation = [False, False, False] 73 | if len(replace_stride_with_dilation) != 3: 74 | raise ValueError("replace_stride_with_dilation should be None " 75 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 76 | self.groups = groups 77 | self.base_width = width_per_group 78 | if modality == 'audio': 79 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, 80 | bias=False) 81 | elif modality == 'visual': 82 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 83 | bias=False) 84 | elif modality=='optical': 85 | self.conv1 = nn.Conv2d(2, self.inplanes, kernel_size=7, stride=2, padding=3, 86 | bias=False) 87 | else: 88 | raise NotImplementedError('Incorrect modality, should be audio or visual but got {}'.format(modality)) 89 | self.bn1 = norm_layer(self.inplanes) 90 | self.relu = nn.ReLU(inplace=True) 91 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 92 | self.layer1 = self._make_layer(block, 64, layers[0]) 93 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 94 | dilate=replace_stride_with_dilation[0]) 95 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 96 | dilate=replace_stride_with_dilation[1]) 97 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 98 | dilate=replace_stride_with_dilation[2]) 99 | # if self.pool == 'avgpool': 100 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 101 | # 102 | # self.fc = nn.Linear(512 * block.expansion, num_classes) # 8192 103 | 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 107 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 108 | nn.init.normal_(m.weight, mean=1, std=0.02) 109 | nn.init.constant_(m.bias, 0) 110 | 111 | # Zero-initialize the last BN in each residual branch, 112 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
113 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 114 | if zero_init_residual: 115 | for m in self.modules(): 116 | if isinstance(m, Bottleneck): 117 | nn.init.constant_(m.bn3.weight, 0) 118 | elif isinstance(m, BasicBlock): 119 | nn.init.constant_(m.bn2.weight, 0) 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 122 | norm_layer = self._norm_layer 123 | downsample = None 124 | previous_dilation = self.dilation 125 | if dilate: 126 | self.dilation *= stride 127 | stride = 1 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | conv1x1(self.inplanes, planes * block.expansion, stride), 131 | norm_layer(planes * block.expansion), 132 | ) 133 | 134 | layers = [] 135 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 136 | self.base_width, previous_dilation, norm_layer)) 137 | self.inplanes = planes * block.expansion 138 | for _ in range(1, blocks): 139 | layers.append(block(self.inplanes, planes, groups=self.groups, 140 | base_width=self.base_width, dilation=self.dilation, 141 | norm_layer=norm_layer)) 142 | 143 | return nn.Sequential(*layers) 144 | 145 | def forward(self, x): 146 | 147 | if self.modality == 'visual': 148 | (B, C, T, H, W) = x.size() 149 | x = x.permute(0, 2, 1, 3, 4).contiguous() 150 | x = x.view(B * T, C, H, W) 151 | 152 | x = self.conv1(x) 153 | x = self.bn1(x) 154 | x = self.relu(x) 155 | x = self.maxpool(x) 156 | 157 | x = self.layer1(x) 158 | x = self.layer2(x) 159 | x = self.layer3(x) 160 | x = self.layer4(x) 161 | out = x 162 | 163 | return out 164 | 165 | 166 | class Bottleneck(nn.Module): 167 | expansion = 4 168 | 169 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 170 | base_width=64, dilation=1, norm_layer=None): 171 | super(Bottleneck, self).__init__() 172 | if norm_layer is None: 173 | norm_layer = nn.BatchNorm2d 174 | width = int(planes * (base_width / 64.)) * groups 175 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 176 | self.conv1 = conv1x1(inplanes, width) 177 | self.bn1 = norm_layer(width) 178 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 179 | self.bn2 = norm_layer(width) 180 | self.conv3 = conv1x1(width, planes * self.expansion) 181 | self.bn3 = norm_layer(planes * self.expansion) 182 | self.relu = nn.ReLU(inplace=True) 183 | self.downsample = downsample 184 | self.stride = stride 185 | 186 | def forward(self, x): 187 | identity = x 188 | 189 | out = self.conv1(x) 190 | out = self.bn1(out) 191 | out = self.relu(out) 192 | 193 | out = self.conv2(out) 194 | out = self.bn2(out) 195 | out = self.relu(out) 196 | 197 | out = self.conv3(out) 198 | out = self.bn3(out) 199 | 200 | if self.downsample is not None: 201 | identity = self.downsample(x) 202 | 203 | out += identity 204 | out = self.relu(out) 205 | 206 | return out 207 | 208 | 209 | def _resnet(arch, block, layers, modality, progress, **kwargs): 210 | model = ResNet(block, layers, modality, **kwargs) 211 | return model 212 | 213 | 214 | def resnet18(modality, progress=True, **kwargs): 215 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], modality, progress, 216 | **kwargs) 217 | -------------------------------------------------------------------------------- /code/models/Visual_Classifier.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | from 
6 | 
7 | class Classifier(nn.Module):
8 | 
9 |     def __init__(self,cfg,device='cuda:0'):
10 |         super().__init__()
11 | 
12 |         self.encoder_2=resnet18(modality=cfg.modality[1])
13 | 
14 |         self.cfg=cfg
15 |         self.device=device
16 | 
17 |         self.linear=nn.Linear(512,31)
18 | 
19 |     def forward(self,mod_1):
20 |         B=mod_1.shape[0]
21 |         out_1=self.encoder_2(mod_1)
22 | 
23 |         _,C,H,W=out_1.shape
24 |         out_1=out_1.reshape(B,-1,C,H,W).permute(0,2,1,3,4)
25 |         out_1=F.adaptive_avg_pool3d(out_1,1)
26 |         out_1=torch.flatten(out_1,1)
27 | 
28 |         out_1=self.linear(out_1)
29 |         return out_1
30 | 
31 | 
32 | 
33 | if __name__ == "__main__":
34 |     pass
35 | 
--------------------------------------------------------------------------------
/code/models/fusion_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | class SumFusion(nn.Module):
5 | 
6 |     def __init__(self,in_c_x,out_c_x,in_c_y,out_c_y) -> None:
7 |         super().__init__()
8 | 
9 |         self.fx=nn.Linear(in_c_x,out_c_x)
10 |         self.fy=nn.Linear(in_c_y,out_c_y)
11 | 
12 |     def forward(self,x,y):
13 |         out=self.fx(x)+self.fy(y)
14 |         return x,y,out
15 | 
16 | class ConcatFusion(nn.Module):
17 | 
18 |     def __init__(self,in_c_x,in_c_y,out_c) -> None:
19 |         super().__init__()
20 |         self.fxy=nn.Linear(in_c_x+in_c_y,out_c)
21 | 
22 |     def forward(self,x,y):
23 |         out=torch.cat([x,y],dim=1)
24 |         out=self.fxy(out)
25 |         return x,y,out
26 | 
27 | class GatedFusion(nn.Module):
28 | 
29 |     def __init__(self,in_c_x,in_c_y,mid_c,out_c,x_gate=True) -> None:
30 |         super().__init__()
31 | 
32 |         self.fx=nn.Linear(in_c_x,mid_c)
33 |         self.fy=nn.Linear(in_c_y,mid_c)
34 |         self.f_out=nn.Linear(mid_c,out_c)
35 | 
36 |         self.x_gate=x_gate
37 |         self.sigmoid=nn.Sigmoid()
38 | 
39 |     def forward(self,x,y):
40 |         out_x=self.fx(x)
41 |         out_y=self.fy(y)
42 | 
43 |         if self.x_gate:
44 |             gate=self.sigmoid(out_x)
45 |             out=self.f_out(torch.mul(gate,out_y))
46 |         else:
47 |             gate=self.sigmoid(out_y)
48 |             out=self.f_out(torch.mul(out_x,gate))
49 | 
50 |         return out_x,out_y,out
51 | 
52 | from torch.autograd import Variable
53 | from torch.nn.parameter import Parameter
54 | class LMF(nn.Module):
55 | 
56 |     def __init__(self,rank=4,hidden_dim=512,out_dim=31,device='cuda:0'):
57 |         super().__init__()
58 |         self.device=device
59 |         self.rank=rank
60 |         self.hidden_dim=hidden_dim
61 |         self.out_dim=out_dim
62 |         self.x_factor=Parameter(torch.Tensor(self.rank,self.hidden_dim+1,self.out_dim).to(device)) # r,d+1,cls; keep .to(device) inside Parameter() so the factor stays a registered parameter
63 |         self.y_factor=Parameter(torch.Tensor(self.rank,self.hidden_dim+1,self.out_dim).to(device))
64 |         self.fusion_weights=Parameter(torch.Tensor(1,self.rank).to(device)) # 1,r
65 |         self.fusion_bias=Parameter(torch.Tensor(1,self.out_dim).to(device))
66 | 
67 |         torch.nn.init.xavier_normal_(self.x_factor)
68 |         torch.nn.init.xavier_normal_(self.y_factor)
69 |         torch.nn.init.xavier_normal_(self.fusion_weights)
70 |         self.fusion_bias.data.fill_(0)
71 | 
72 |     def forward(self,x,y):
73 |         b=x.shape[0]
74 |         _x=torch.cat((Variable(torch.ones(b,1).to(self.device),requires_grad=False),x),dim=1) # b,d+1
75 |         _y=torch.cat((Variable(torch.ones(b,1).to(self.device),requires_grad=False),y),dim=1)
76 | 
77 |         fusion_x=torch.matmul(_x,self.x_factor) # r,b,cls
78 |         fusion_y=torch.matmul(_y,self.y_factor)
79 |         fusion_zy=fusion_x*fusion_y
80 | 
81 |         output=torch.matmul(self.fusion_weights,fusion_zy.permute(1,0,2)).squeeze()+self.fusion_bias # b,cls
82 |         # output=output.view(-1,self.out_dim)
83 | 
84 |         return output,x,y
85 | 
86 | if __name__ == "__main__":
87 |     net=GatedFusion(10,10,10,20)
88 |     x=torch.zeros([1,10])
89 |     y=torch.zeros([1,10])
90 |     x_out,y_out,z=net(x,y)
91 |     print(x_out.shape,y_out.shape) # torch.Size([1, 10]) torch.Size([1, 10])
92 |     print(z.shape) # torch.Size([1, 20])
93 | 
94 |     print(net.f_out.weight.shape) # torch.Size([20, 10]); GatedFusion has no .weight of its own, so inspect the output head
--------------------------------------------------------------------------------
/code/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.31.0
3 | aiohttp==3.9.1
4 | aiosignal==1.3.1
5 | alembic==1.13.1
6 | annotated-types==0.6.0
7 | antlr4-python3-runtime==4.9.3
8 | anyio==4.2.0
9 | asteroid-filterbanks==0.4.0
10 | async-timeout==4.0.3
11 | attrs==23.2.0
12 | audioread==3.0.1
13 | certifi==2023.11.17
14 | cffi==1.16.0
15 | charset-normalizer==3.3.2
16 | click==8.1.7
17 | colorama==0.4.6
18 | colorlog==6.8.0
19 | contourpy==1.2.0
20 | cycler==0.12.1
21 | decorator==4.4.2
22 | decord==0.6.0
23 | deepspeed==0.14.3
24 | distro==1.9.0
25 | docopt==0.6.2
26 | einops==0.7.0
27 | exceptiongroup==1.2.0
28 | filelock==3.13.1
29 | fonttools==4.47.2
30 | frozenlist==1.4.1
31 | fsspec==2023.12.2
32 | greenlet==3.0.3
33 | grpcio==1.63.0
34 | h11==0.14.0
35 | h5py==3.11.0
36 | hjson==3.1.0
37 | httpcore==1.0.2
38 | httpx==0.26.0
39 | huggingface-hub==0.23.4
40 | HyperPyYAML==1.2.2
41 | idna==3.6
42 | imageio==2.34.1
43 | imageio-ffmpeg==0.5.1
44 | importlib-resources==6.1.1
45 | importlib_metadata==7.1.0
46 | Jinja2==3.1.3
47 | joblib==1.3.2
48 | julius==0.2.7
49 | kiwisolver==1.4.5
50 | lazy_loader==0.3
51 | librosa==0.10.1
52 | lightning==2.1.3
53 | lightning-utilities==0.10.1
54 | llvmlite==0.41.1
55 | Mako==1.3.0
56 | Markdown==3.6
57 | markdown-it-py==3.0.0
58 | MarkupSafe==2.1.4
59 | matplotlib==3.8.2
60 | mdurl==0.1.2
61 | more-itertools==10.2.0
62 | moviepy==1.0.3
63 | mpmath==1.3.0
64 | msgpack==1.0.7
65 | multidict==6.0.4
66 | networkx==3.2.1
67 | ninja==1.11.1.1
68 | numba==0.58.1
69 | numpy==1.26.3
70 | nvidia-cublas-cu12==12.1.3.1
71 | nvidia-cuda-cupti-cu12==12.1.105
72 | nvidia-cuda-nvrtc-cu12==12.1.105
73 | nvidia-cuda-runtime-cu12==12.1.105
74 | nvidia-cudnn-cu12==8.9.2.26
75 | nvidia-cufft-cu12==11.0.2.54
76 | nvidia-curand-cu12==10.3.2.106
77 | nvidia-cusolver-cu12==11.4.5.107
78 | nvidia-cusparse-cu12==12.1.0.106
79 | nvidia-ml-py==12.555.43
80 | nvidia-nccl-cu12==2.18.1
81 | nvidia-nvjitlink-cu12==12.3.101
82 | nvidia-nvtx-cu12==12.1.105
83 | omegaconf==2.3.0
84 | openai==1.9.0
85 | openai-whisper==20231117
86 | opencv-python==4.9.0.80
87 | optuna==3.5.0
88 | packaging==23.2
89 | pandas==2.2.0
90 | peft==0.3.0
91 | pillow==10.2.0
92 | platformdirs==4.1.0
93 | pooch==1.8.0
94 | primePy==1.3
95 | proglog==0.1.10
96 | protobuf==4.25.2
97 | psutil==6.0.0
98 | py-cpuinfo==9.0.0
99 | pyannote.core==5.0.0
100 | pyannote.database==5.0.1
101 | pyannote.metrics==3.2.1
102 | pyannote.pipeline==3.0.1
103 | pycparser==2.21
104 | pydantic==2.5.3
105 | pydantic_core==2.14.6
106 | Pygments==2.17.2
107 | pyparsing==3.1.1
108 | python-dateutil==2.8.2
109 | pytorch-lightning==2.1.3
110 | pytorch-metric-learning==2.4.1
111 | pytube==15.0.0
112 | pytz==2023.3.post1
113 | PyYAML==6.0.1
114 | regex==2023.12.25
115 | requests==2.31.0
116 | rich==13.7.0
117 | ruamel.yaml==0.18.5
118 | ruamel.yaml.clib==0.2.8
119 | safetensors==0.4.3
120 | scikit-learn==1.4.0
121 | scipy==1.11.4
122 | semver==3.0.2
123 | sentencepiece==0.1.99
124 | shellingham==1.5.4
125 | six==1.16.0
126 | sniffio==1.3.0
127 | sortedcontainers==2.4.0
128 | soundfile==0.12.1
129 | soxr==0.3.7
130 | speechbrain==0.5.16
131 | SQLAlchemy==2.0.25
132 | sympy==1.12
133 | tabulate==0.9.0
134 | tensorboard==2.16.2
135 | tensorboard-data-server==0.7.2
136 | tensorboardX==2.6.2.2
137 | threadpoolctl==3.2.0
138 | tiktoken==0.5.2
139 | timm==0.9.12
140 | tokenizers==0.19.1
141 | torch==1.13.1+cu116
142 | torch-audiomentations==0.11.0
143 | torch-pitch-shift==1.2.4
144 | torchaudio==0.13.1+cu116
145 | torchmetrics==1.3.0.post0
146 | torchvision==0.14.1+cu116
147 | tqdm==4.66.1
148 | transformers==4.41.2
149 | triton==2.1.0
150 | typer==0.9.0
151 | typing_extensions==4.9.0
152 | tzdata==2023.4
153 | urllib3==2.1.0
154 | Werkzeug==3.0.3
155 | yarl==1.9.4
156 | youtube-dl==2021.12.17
157 | zipp==3.17.0
158 | 
--------------------------------------------------------------------------------
/code/scripts/inference.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 |     --resume_model \
3 |     --resume_model_path 'ckpt_path'
4 | 
5 | 
--------------------------------------------------------------------------------
/code/scripts/train_ogm.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 |     --train \
3 |     --use_modulation \
4 |     --fusion_method concat \
5 |     --alpha 0.8 \
6 |     --modulation_starts 0 \
7 |     --modulation_ends 60
8 | 
9 | 
--------------------------------------------------------------------------------
/code/scripts/train_opm.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 |     --train \
3 |     --use_adam_drop \
4 |     --fusion_method concat \
5 |     --q_base 0.5 \
6 |     --lam 0.5 \
7 |     --p_exe 0.7
--------------------------------------------------------------------------------
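
The fusion modules in `code/models/fusion_model.py` all share the interface `forward(x, y) -> (x_branch, y_branch, fused_logits)`, which is presumably what lets `main.py` swap them via the `--fusion_method` flag used in the scripts above. The snippet below is only a small smoke test on dummy 512-dimensional features (the 31-way output matches the `nn.Linear(512, 31)` heads elsewhere in the repo); it assumes it is run from `code/` so that `models` is importable.

```python
# Illustrative only: exercising the fusion modules on random features.
import torch

from models.fusion_model import ConcatFusion, GatedFusion, LMF

audio_feat = torch.randn(2, 512)   # pooled audio embedding, batch of 2
visual_feat = torch.randn(2, 512)  # pooled visual embedding

concat = ConcatFusion(in_c_x=512, in_c_y=512, out_c=31)
x, y, logits = concat(audio_feat, visual_feat)
print(logits.shape)                # torch.Size([2, 31])

gated = GatedFusion(in_c_x=512, in_c_y=512, mid_c=512, out_c=31)
out_x, out_y, logits = gated(audio_feat, visual_feat)
print(out_x.shape, logits.shape)   # torch.Size([2, 512]) torch.Size([2, 31])

# LMF defaults to device='cuda:0'; pass device='cpu' for a CPU-only check.
lmf = LMF(rank=4, hidden_dim=512, out_dim=31, device='cpu')
logits, x, y = lmf(audio_feat, visual_feat)
print(logits.shape)                # torch.Size([2, 31])
```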
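
On how the pieces compose: `resnet18(modality=...)` returns only the convolutional feature maps (its avgpool/fc head is commented out), so a classifier has to pool those maps into 512-dimensional vectors itself, exactly as `Visual_Classifier.py` does for a single modality. The repo's joint audio-visual model lives in `Classifier.py`, which is not included in this excerpt; the sketch below is therefore an assumption of the same wiring for an audio-visual model with `ConcatFusion`, where the class name `AVConcatClassifier` and the dummy input shapes are illustrative, not taken from the repo.

```python
# A sketch, not the repo's Classifier.py: audio-visual classification with the
# resnet18 encoders and ConcatFusion shown above. Pooling follows Visual_Classifier.py.
import torch
import torch.nn.functional as F
from torch import nn

from models.Resnet_18 import resnet18
from models.fusion_model import ConcatFusion


class AVConcatClassifier(nn.Module):
    def __init__(self, n_classes=31):
        super().__init__()
        self.audio_encoder = resnet18(modality='audio')    # expects (B, 1, F, T) spectrograms
        self.visual_encoder = resnet18(modality='visual')  # expects (B, 3, T, H, W) frames
        self.fusion = ConcatFusion(in_c_x=512, in_c_y=512, out_c=n_classes)

    def forward(self, spec, frames):
        B = spec.shape[0]

        a = self.audio_encoder(spec)                  # (B, 512, h, w)
        a = F.adaptive_avg_pool2d(a, 1).flatten(1)    # (B, 512)

        v = self.visual_encoder(frames)               # (B*T, 512, h, w)
        _, C, H, W = v.shape
        v = v.reshape(B, -1, C, H, W).permute(0, 2, 1, 3, 4)  # (B, 512, T, h, w)
        v = F.adaptive_avg_pool3d(v, 1).flatten(1)    # (B, 512)

        a_out, v_out, logits = self.fusion(a, v)
        return a_out, v_out, logits


if __name__ == "__main__":
    model = AVConcatClassifier()
    spec = torch.randn(2, 1, 257, 300)       # dummy log-spectrogram
    frames = torch.randn(2, 3, 3, 224, 224)  # 3 RGB frames per clip
    _, _, logits = model(spec, frames)
    print(logits.shape)                      # torch.Size([2, 31])
```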
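
`train_ogm.sh` enables the back-propagation-side modulation with `--alpha 0.8` inside the epoch window `--modulation_starts 0 --modulation_ends 60`. The actual OGM logic lives in `main.py`, which is not part of this excerpt; the sketch below only illustrates the idea: estimate each modality's discriminative score from its logits, then shrink the gradients of the currently dominant encoder by a factor controlled by `alpha`. Function names and the exact scaling formula here are illustrative, not copied from the repo.

```python
# Illustrative sketch of on-the-fly gradient modulation (OGM); not the code in main.py.
import torch


def ogm_coefficients(logits_a, logits_v, labels, alpha):
    """Return gradient-scaling coefficients (coeff_a, coeff_v) for the two encoders."""
    prob_a = torch.softmax(logits_a, dim=1)
    prob_v = torch.softmax(logits_v, dim=1)
    # probability mass assigned to the ground-truth class, summed over the batch
    score_a = prob_a[range(len(labels)), labels].sum()
    score_v = prob_v[range(len(labels)), labels].sum()

    ratio_v = score_v / (score_a + 1e-8)   # > 1 means the visual branch is dominant
    if ratio_v > 1:
        coeff_v = 1.0 - torch.tanh(alpha * ratio_v).item()
        coeff_a = 1.0
    else:
        coeff_a = 1.0 - torch.tanh(alpha / (ratio_v + 1e-8)).item()
        coeff_v = 1.0
    return coeff_a, coeff_v


def modulate_gradients(audio_encoder, visual_encoder, coeff_a, coeff_v):
    """Scale encoder gradients in place; call between loss.backward() and optimizer.step()."""
    for p in audio_encoder.parameters():
        if p.grad is not None:
            p.grad *= coeff_a
    for p in visual_encoder.parameters():
        if p.grad is not None:
            p.grad *= coeff_v
```

In the repo this modulation would presumably only be applied while `modulation_starts <= epoch < modulation_ends`, mirroring the flags passed in `train_ogm.sh`.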
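
`train_opm.sh` switches to the feed-forward side (`--use_adam_drop`): the dominant modality's feature is dropped before fusion with a dynamically chosen probability. Judging by their names, `q_base`, `lam` and `p_exe` look like a base drop probability, a scaling factor and an execution probability, but the exact schedule is defined in `main.py` and is not reproduced here. The snippet below is only a loose, placeholder illustration of the mechanism: with probability `p_exe` the modulation fires at all, and the drop probability grows from `q_base` as the score discrepancy grows.

```python
# Loose illustration of on-the-fly prediction modulation (OPM): drop the dominant
# modality's feature before fusion. The drop-probability formula is a placeholder,
# not the schedule implemented in main.py.
import torch


def maybe_drop_dominant(feat_a, feat_v, score_a, score_v,
                        q_base=0.5, lam=0.5, p_exe=0.7):
    if torch.rand(1).item() > p_exe:           # only modulate part of the time
        return feat_a, feat_v

    ratio = score_v / (score_a + 1e-8)         # > 1: visual dominant, < 1: audio dominant
    if ratio > 1:
        q = min(1.0, q_base + lam * (ratio.item() - 1.0))        # placeholder schedule
        mask = (torch.rand(feat_v.shape[0], 1, device=feat_v.device) > q).float()
        feat_v = feat_v * mask / max(1.0 - q, 1e-8)              # dropout-style rescaling
    else:
        q = min(1.0, q_base + lam * (1.0 / ratio.item() - 1.0))
        mask = (torch.rand(feat_a.shape[0], 1, device=feat_a.device) > q).float()
        feat_a = feat_a * mask / max(1.0 - q, 1e-8)
    return feat_a, feat_v
```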