├── .DS_Store
├── OGM.jpg
├── OPM.jpg
├── README.md
└── code
    ├── config.py
    ├── dataset
    │   ├── KS.json
    │   ├── KS.py
    │   ├── KS_train_val.json
    │   ├── loader.py
    │   └── spatial_transforms.py
    ├── main.py
    ├── models
    │   ├── Audio_Classifier.py
    │   ├── BasicModule.py
    │   ├── Classifier.py
    │   ├── Resnet_18.py
    │   ├── Visual_Classifier.py
    │   └── fusion_model.py
    ├── requirements.txt
    └── scripts
        ├── inference.sh
        ├── train_ogm.sh
        └── train_opm.sh

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/BML_TPAMI2024/4b3aa4fd841856c4acbebf48549e6c59fbf7635e/.DS_Store
--------------------------------------------------------------------------------
/OGM.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/BML_TPAMI2024/4b3aa4fd841856c4acbebf48549e6c59fbf7635e/OGM.jpg
--------------------------------------------------------------------------------
/OPM.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/BML_TPAMI2024/4b3aa4fd841856c4acbebf48549e6c59fbf7635e/OPM.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code of On-the-fly Modulation for Balanced Multimodal Learning
2 | The repo for "On-the-fly Modulation for Balanced Multimodal Learning", T-PAMI 2024
3 | 
4 | Here is the official PyTorch implementation of ''*On-the-fly Modulation for Balanced Multimodal Learning*'', which analyzes and alleviates the imbalanced multimodal learning problem from both the feed-forward and the back-propagation stages during optimization. Please refer to our [T-PAMI 2024 paper](https://ieeexplore.ieee.org/abstract/document/10694738) for more details. This journal paper is an extension of our previous CVPR 2022 paper [\[Balanced Multimodal Learning via On-the-fly Gradient Modulation\]](https://arxiv.org/abs/2203.15332).
5 | 
6 | **Paper Title: "On-the-fly Modulation for Balanced Multimodal Learning"**
7 | 
8 | **Authors: [Yake Wei](https://echo0409.github.io/), [Di Hu](https://dtaoo.github.io/index.html), Henghui Du and Ji-Rong Wen**
9 | 
10 | 
11 | ## On-the-fly Modulation for Balanced Multimodal Learning
12 | Multimodal learning is expected to boost model performance by integrating information from different modalities. However, its potential is not fully exploited because the widely-used joint training strategy, which applies a uniform objective to all modalities, leads to imbalanced and under-optimized uni-modal representations. Specifically, we point out that there is often a modality with more discriminative information, e.g., the vision of playing football or the sound of blowing wind. Such a modality can dominate the joint training process, leaving the other modalities significantly under-optimized.
13 | 
14 | To alleviate this problem, we first analyze the under-optimized phenomenon from both the feed-forward and the back-propagation stages during optimization. Then, **On-the-fly Prediction Modulation (OPM)** and **On-the-fly Gradient Modulation (OGM)** strategies are proposed to modulate the optimization of each modality, by monitoring the discriminative discrepancy between modalities during training. Concretely, OPM weakens the influence of the dominant modality by dropping its features with a dynamic probability in the feed-forward stage, while OGM mitigates its gradient in the back-propagation stage.
15 | In experiments, our methods demonstrate considerable improvement across a variety of multimodal tasks. These simple yet effective strategies not only enhance performance in vanilla and task-oriented multimodal models, but also in more complex multimodal tasks, showcasing their effectiveness and flexibility.
16 | 
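In this repository, OGM lives in `code/main.py` (after `loss.backward()`, the conv-layer gradients of the dominant modality's encoder are scaled by a coefficient, with extra Gaussian noise in the OGM_GE variant), and OPM lives in `code/models/Classifier.py` (`calcu_q` assigns a drop probability to the dominant modality and `Modality_drop` zeroes its features with Bernoulli masks, rescaling the survivors). The snippet below is only a simplified sketch of the two rules with made-up confidence values: `perf_a`/`perf_v` stand in for the per-batch softmax confidences `performance_1`/`performance_2` computed by the model, and the defaults mirror `config.py`.

```python
import torch

def ogm_coeffs(perf_a, perf_v, alpha=1.0):
    # OGM: the modality with the higher confidence ratio gets its gradient scaled down
    # (in main.py this coefficient multiplies the dominant encoder's conv gradients;
    # OGM_GE additionally adds zero-mean Gaussian noise to those gradients).
    ratio_a = perf_a / perf_v
    ratio_v = perf_v / perf_a
    if ratio_a > 1:  # audio dominates
        return 1 - torch.tanh(alpha * torch.relu(ratio_a)), 1.0
    return 1.0, 1 - torch.tanh(alpha * torch.relu(ratio_v))

def opm_drop_probs(perf_a, perf_v, q_base=0.4, lam=0.5):
    # OPM: only the dominant modality receives a non-zero feature-drop probability.
    ratio_a = torch.tanh(torch.relu(perf_a / perf_v - 1))
    ratio_v = torch.tanh(torch.relu(perf_v / perf_a - 1))
    q_a = q_base * (1 + lam * ratio_a) if ratio_a > 0 else torch.tensor(0.0)
    q_v = q_base * (1 + lam * ratio_v) if ratio_v > 0 else torch.tensor(0.0)
    return torch.clamp(q_a, 0.0, 1.0), torch.clamp(q_v, 0.0, 1.0)

# e.g. an audio branch that is twice as confident as the visual branch:
coeff_a, coeff_v = ogm_coeffs(torch.tensor(0.6), torch.tensor(0.3))
q_a, q_v = opm_drop_probs(torch.tensor(0.6), torch.tensor(0.3))
# coeff_a ~ 0.04, coeff_v = 1.0 : audio gradients are damped in the backward pass
# q_a    ~ 0.55, q_v    = 0.0  : only audio features are dropped in the forward pass
```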
17 | 
18 | [Figure: Pipeline of OPM method. (OPM.jpg)]
19 | 
20 | [Figure: Pipeline of OGM method. (OGM.jpg)]
74 | @article{wei2024on,
75 | title={On-the-fly modulation for balanced multimodal learning},
76 | author={Wei, Yake and Hu, Di and Du, Henghui and Wen, Ji-Rong},
77 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
78 | year={2024}
79 | }
80 | 
--------------------------------------------------------------------------------
/code/config.py:
--------------------------------------------------------------------------------
1 |
2 | import warnings
3 |
4 |
5 | class Config():
6 |
7 | def __init__(self) -> None:
8 |
9 | # dataset setting
10 | self.dataset='KineticSound'
11 | self.num_classes={'VGGSound':309,'KineticSound':31,'CREMAD':6,'AVE':28}
12 | self.modality=['audio','visual']
13 | self.fps = 1
14 | self.use_video_frames = 3
15 |
16 | # backbone setting
17 | self.in_c=3
18 | self.out_c=64
19 |
20 | # train setting
21 | self.train = False
22 | self.batch_size = 32
23 | self.epochs=100
24 | self.optimizer='Adamw'
25 |
26 | self.learning_rate=5e-5
27 | self.lr_decay_ratio=0.1
28 | # self.lr_decay_step=[30,50,70]
29 | self.lr_decay_step=40
30 |
31 | # modulation setting
32 | self.use_modulation=False
33 | self.modulation = 'OGM_GE'
34 | self.modulation_starts = 0
35 | self.modulation_ends = 80
36 |
37 | self.alpha = 1
38 |
39 | # fusion setting
40 | self.fusion_method = 'concat'
41 | self.d = [512, 512]
42 | # gated_fusion
43 | self.mid_c=512
44 | self.x_gated=False
45 |
46 | # adam-drop lambda setting
47 | self.use_adam_drop = False
48 | self.key=50
49 | self.sigma=2
50 |
51 | self.p_exe=0.7
52 | self.q_base=0.4
53 | self.lam=0.5
54 |
55 | # other setting
56 | self.checkpoint_path = 'result'
57 |
58 | self.resume_model=False
59 | self.resume_model_path=None
60 |
61 | self.use_tensorboard = True
62 |
63 | self.random_seed = 0
64 | self.gpu_ids = [0,1]
65 |
66 | self.func='tanh'
67 | self.form='/'
68 |
69 | self.device=0
70 |
71 |
72 | # transforms setting
73 | self.decrease_epoch=10
74 | self.sample_size=112
75 | self.sample_t_stride=1
76 | self.train_crop='random'
77 | self.value_scale=1
78 | self.scale_h=128
79 | self.scale_w=171
80 | self.train_crop_min_scale=0.25
81 | self.train_crop_min_ratio=0.75
82 | self.no_hflip=False
83 | self.colorjitter=False
84 | self.train_t_crop='random'
85 |
86 | self.audio_drop=0.0
87 | self.visual_drop=0.0
88 |
89 |
90 | def parse(self,kwargs):
91 | for k,v in kwargs.items():
92 | if not hasattr(self,k):
93 | warnings.warn('Config has no attribute %s' % k)
94 | setattr(self,k,v)
95 |
96 | # print('config info:')
97 | # for k,v in self.__dict__.items():
98 | # if not k.startswith('__'):
99 | # print(k,getattr(self,k))
100 |
101 | if __name__ == "__main__":
102 | import argparse
103 | cfg=Config()
104 |
105 | parser = argparse.ArgumentParser()
106 |
107 | parser.add_argument('--use_modulation',action='store_true',help='use gradient modulation')
108 | parser.add_argument('--use_adam_drop',action='store_true',help='use adam-drop')
109 | parser.add_argument('--modulation', default='OGM_GE', type=str,choices=['Normal', 'OGM', 'OGM_GE'])
110 | parser.add_argument('--fusion_method', default='concat', type=str,choices=['sum', 'concat', 'gated'])
111 | parser.add_argument('--train', action='store_true', help='turn on train mode')
112 | parser.add_argument('--resume_model',action='store_true',help='whether to resume model')
113 | parser.add_argument('--checkpoint_path',type=str,help='load checkpoints')
114 |
115 | args=parser.parse_args()
116 | cfg.parse(vars(args))
117 |
118 |
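For reference, `main.py` overrides these defaults by passing `vars(args)` from argparse into `Config.parse`, exactly as in the `__main__` block above. A minimal sketch with made-up override values, assuming it is run from the `code/` directory:

```python
# Sketch: override a few Config defaults the same way main.py does with vars(args).
from config import Config

cfg = Config()
cfg.parse({
    'train': True,            # switch to training mode
    'use_modulation': True,   # enable OGM / OGM_GE gradient modulation
    'modulation': 'OGM_GE',
    'alpha': 1.0,
    'batch_size': 16,
})
print(cfg.modulation, cfg.batch_size)  # -> OGM_GE 16
```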
--------------------------------------------------------------------------------
/code/dataset/KS.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import torch
4 | import torch.utils.data as data
5 | from pathlib import Path
6 | from random import randrange
7 | import numpy as np
8 | import h5py
9 | import pickle
10 |
11 |
12 | from .loader import VideoLoaderHDF5
13 | from .loader import AudioFeatureLoader
14 | from .spatial_transforms import get_spatial_transform,get_val_spatial_transforms
15 | HDF5_DIR=''
16 | PKL_DIR=''
17 | PROJECT_DIR=''
18 |
19 | def get_dataset(annotation_data,mode):
20 | video_names = []
21 | video_labels = []
22 |
23 | for key in annotation_data.keys():
24 | if annotation_data[key]['subset'] == mode:
25 | video_names.append(key)
26 | video_labels.append(annotation_data[key]['label'])
27 | return video_names,video_labels
28 |
29 |
30 | class KSDataset(data.Dataset):
31 | def __init__(self,
32 | annotation_path=os.path.join(PROJECT_DIR,'dataset/KS_train_val.json'),
33 | mode='training',
34 | spatial_transform=None,
35 | video_loader = None,
36 | audio_drop=0.0,
37 | visual_drop=0.0
38 | ):
39 |
40 | self.video_dir = HDF5_DIR
41 | self.audio_dir = PKL_DIR
42 |
43 | self.audio_drop=audio_drop
44 | self.visual_drop=visual_drop
45 |
46 | self.dataset,self.idx_to_class,self.n_videos = self.__make_dataset(self.video_dir,annotation_path,mode)
47 |
48 | self.spatial_transform = spatial_transform
49 |
50 | self.loader = video_loader
51 |
52 |
53 | def __make_dataset(self,video_dir,annotation_path,subset):
54 | with open(annotation_path) as f:
55 | annotation_data = json.load(f)
56 | class_labels = annotation_data['labels']
57 | annotation_data = annotation_data['database']
58 |
59 | video_names , video_labels = get_dataset(annotation_data,subset)
60 |
61 | class_to_idx = {label : i for i,label in enumerate(class_labels)}
62 | idx_to_class = {i : label for i,label in enumerate(class_labels)}
63 |
64 | n_videos = len(video_names)
65 |
66 | dataset = []
67 | max_len = 0
68 |
69 | for i in range(n_videos):
70 |
71 | label = video_labels[i]
72 | label_id = class_to_idx[label]
73 |
74 | video_path = os.path.join(video_dir,video_names[i] + ".hdf5")
75 | audio_path = os.path.join(self.audio_dir,video_names[i] + ".pkl")
76 | if not os.path.exists(video_path) or not os.path.exists(audio_path):
77 | continue
78 |
79 | sample = {
80 | 'video': video_names[i],
81 | 'label': label_id,
82 | }
83 |
84 | dataset.append(sample)
85 | return dataset,idx_to_class,n_videos
86 |
87 | def add_mask_visual(self, image, ratio):
88 | patch_w = 10
89 | patch_l = 10
90 | w_num = int(224 / patch_w)
91 | l_num = int(224 / patch_l)
92 | total_num = w_num * l_num
93 | patch_num = int(total_num * ratio)
94 | # print(total_num, patch_num)
95 | patch_list = np.random.choice(total_num, patch_num, replace=False)
96 | for index in patch_list:
97 | patch_x = index % w_num * patch_w
98 | patch_y = int(index / w_num) * patch_l
99 |
100 | image[:, patch_x:patch_x+patch_w, patch_y:patch_y+patch_l] = 0.0
101 |
102 | return image
103 |
104 | def add_mask_audio(self, image, ratio):
105 | patch_w = 10
106 | patch_l = 10
107 | w_num = int(224 / patch_w)
108 | l_num = int(224 / patch_l)
109 | total_num = w_num * l_num
110 | patch_num = int(total_num * ratio)
111 | # print(total_num, patch_num)
112 | patch_list = np.random.choice(total_num, patch_num, replace=False)
113 | for index in patch_list:
114 | patch_x = index % w_num * patch_w
115 | patch_y = int(index / w_num) * patch_l
116 |
117 | image[patch_x:patch_x+patch_w, patch_y:patch_y+patch_l] = 0.0
118 |
119 | return image
120 |
121 | def __len__(self):
122 | return len(self.dataset)
123 |
124 |
125 | def __loading(self, path, video_name):
126 |
127 | clip=None
128 | try:
129 | clip = self.loader(path)
130 | except Exception as e:
131 | print("path {} has error".format(path))
132 |
133 | len_clip = len(clip)
134 | clip = [clip[0],clip[int((len_clip-1)/2)],clip[len_clip-1]]
135 |
136 | if self.spatial_transform is not None:
137 | self.spatial_transform.randomize_parameters()
138 | clip = [self.spatial_transform(img) for img in clip]
139 | if self.visual_drop>0.0:
140 | clip=[self.add_mask_visual(img,self.visual_drop) for img in clip]
141 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3) # c t h w
142 | return clip
143 |
144 | def __load_audio(self,audio_path):
145 | with open(audio_path,"rb") as f:
146 | audio = pickle.load(f)
147 |
148 | if self.audio_drop>0.0:
149 | audio=self.add_mask_audio(audio,self.audio_drop)
150 | return audio
151 |
152 | def __getitem__(self, index):
153 |
154 | video_name = self.dataset[index]['video']
155 |
156 | video_path = os.path.join(self.video_dir,video_name + ".hdf5")
157 | label = self.dataset[index]['label']
158 |
159 | clip = self.__loading(video_path, video_name)
160 |
161 | audio_path = os.path.join(self.audio_dir,video_name + ".pkl")
162 |
163 | audio = self.__load_audio(audio_path)
164 |
165 | return audio,clip,label
166 |
167 |
168 |
169 | class VisualDataset(data.Dataset):
170 | def __init__(self,
171 | annotation_path=os.path.join(PROJECT_DIR,'dataset/KS_train_val.json'),
172 | mode='training',
173 | spatial_transform=None,
174 | video_loader = None
175 | ):
176 |
177 | self.video_dir = HDF5_DIR
178 | self.audio_dir = PKL_DIR
179 |
180 | self.dataset,self.idx_to_class,self.n_videos = self.__make_dataset(self.video_dir,annotation_path,mode)
181 |
182 | self.spatial_transform = spatial_transform
183 |
184 | self.loader = video_loader
185 |
186 |
187 | def __make_dataset(self,video_dir,annotation_path,subset):
188 | with open(annotation_path) as f:
189 | annotation_data = json.load(f)
190 | class_labels = annotation_data['labels']
191 | annotation_data = annotation_data['database']
192 |
193 | video_names , video_labels = get_dataset(annotation_data,subset)
194 |
195 | class_to_idx = {label : i for i,label in enumerate(class_labels)}
196 | idx_to_class = {i : label for i,label in enumerate(class_labels)}
197 |
198 | n_videos = len(video_names)
199 |
200 | dataset = []
201 | max_len = 0
202 |
203 | for i in range(n_videos):
204 |
205 | label = video_labels[i]
206 | label_id = class_to_idx[label]
207 |
208 | video_path = os.path.join(video_dir,video_names[i] + ".hdf5")
209 | audio_path = os.path.join(self.audio_dir,video_names[i] + ".pkl")
210 | if not os.path.exists(video_path) or not os.path.exists(audio_path):
211 | continue
212 |
213 | sample = {
214 | 'video': video_names[i],
215 | 'label': label_id,
216 | }
217 |
218 | dataset.append(sample)
219 | return dataset,idx_to_class,n_videos
220 |
221 |
222 |
223 | def __len__(self):
224 | return len(self.dataset)
225 |
226 |
227 | def __loading(self, path, video_name):
228 |
229 | clip=None
230 | try:
231 | clip = self.loader(path)
232 | except Exception as e:
233 | print("path {} has error".format(path))
234 |
235 | len_clip = len(clip)
236 | clip = [clip[0],clip[int((len_clip-1)/2)],clip[len_clip-1]]
237 |
238 | if self.spatial_transform is not None:
239 | self.spatial_transform.randomize_parameters()
240 | clip = [self.spatial_transform(img) for img in clip]
241 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3)
242 | return clip
243 |
244 | def __load_audio(self,audio_path):
245 | with open(audio_path,"rb") as f:
246 | audio = pickle.load(f)
247 |
248 | return audio
249 |
250 | def __getitem__(self, index):
251 |
252 | video_name = self.dataset[index]['video']
253 |
254 | video_path = os.path.join(self.video_dir,video_name + ".hdf5")
255 | label = self.dataset[index]['label']
256 |
257 | clip = self.__loading(video_path, video_name)
258 |
259 | # audio_path = os.path.join(self.audio_dir,video_name + ".pkl")
260 |
261 | # audio = self.__load_audio(audio_path)
262 |
263 | return clip,label
264 |
265 |
266 | class AudioDataset(data.Dataset):
267 | def __init__(self,
268 | annotation_path=os.path.join(PROJECT_DIR,'dataset/KS_train_val.json'),
269 | mode='training',
270 | spatial_transform=None,
271 | video_loader = None
272 | ):
273 |
274 | self.video_dir = HDF5_DIR
275 | self.audio_dir = PKL_DIR
276 |
277 | self.dataset,self.idx_to_class,self.n_videos = self.__make_dataset(self.video_dir,annotation_path,mode)
278 |
279 | self.spatial_transform = spatial_transform
280 |
281 | self.loader = video_loader
282 |
283 |
284 | def __make_dataset(self,video_dir,annotation_path,subset):
285 | with open(annotation_path) as f:
286 | annotation_data = json.load(f)
287 | class_labels = annotation_data['labels']
288 | annotation_data = annotation_data['database']
289 |
290 | video_names , video_labels = get_dataset(annotation_data,subset)
291 |
292 | class_to_idx = {label : i for i,label in enumerate(class_labels)}
293 | idx_to_class = {i : label for i,label in enumerate(class_labels)}
294 |
295 | n_videos = len(video_names)
296 |
297 | dataset = []
298 | max_len = 0
299 |
300 | for i in range(n_videos):
301 |
302 | label = video_labels[i]
303 | label_id = class_to_idx[label]
304 |
305 | video_path = os.path.join(video_dir,video_names[i] + ".hdf5")
306 | audio_path = os.path.join(self.audio_dir,video_names[i] + ".pkl")
307 | if not os.path.exists(video_path) or not os.path.exists(audio_path):
308 | continue
309 |
310 | sample = {
311 | 'video': video_names[i],
312 | 'label': label_id,
313 | }
314 |
315 | dataset.append(sample)
316 | return dataset,idx_to_class,n_videos
317 |
318 |
319 |
320 | def __len__(self):
321 | return len(self.dataset)
322 |
323 |
324 | def __loading(self, path, video_name):
325 |
326 | clip=None
327 | try:
328 | clip = self.loader(path)
329 | except Exception as e:
330 | print("path {} has error".format(path))
331 |
332 | len_clip = len(clip)
333 | clip = [clip[0],clip[int((len_clip-1)/2)],clip[len_clip-1]]
334 |
335 | if self.spatial_transform is not None:
336 | self.spatial_transform.randomize_parameters()
337 | clip = [self.spatial_transform(img) for img in clip]
338 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3)
339 | return clip
340 |
341 | def __load_audio(self,audio_path):
342 | with open(audio_path,"rb") as f:
343 | audio = pickle.load(f)
344 |
345 | return audio
346 |
347 | def __getitem__(self, index):
348 |
349 | video_name = self.dataset[index]['video']
350 |
351 | # video_path = os.path.join(self.video_dir,video_name + ".hdf5")
352 | label = self.dataset[index]['label']
353 |
354 | # clip = self.__loading(video_path, video_name)
355 |
356 | audio_path = os.path.join(self.audio_dir,video_name + ".pkl")
357 |
358 | audio = self.__load_audio(audio_path)
359 |
360 | return audio,label
361 |
362 |
363 |
364 | if __name__=='__main__':
365 |
366 | import argparse
367 | parser=argparse.ArgumentParser()
368 |
369 | parser.add_argument("--decrease_epoch",type = int,default = 10)
370 | parser.add_argument('--sample_size',type = int,default = 112)
371 | parser.add_argument('--sample_t_stride',type = int,default = 1)
372 | parser.add_argument('--train_crop',
373 | default='random',
374 | type=str,
375 | help=('Spatial cropping method in training. '
376 | 'random is uniform. '
377 | 'corner is selection from 4 corners and 1 center. '
378 | '(random | corner | center)'))
379 | parser.add_argument('--value_scale',
380 | default=1,
381 | type=int,
382 | help=
383 | 'If 1, range of inputs is [0-1]. If 255, range of inputs is [0-255].')
384 | parser.add_argument("--scale_h", type=int, default=128,
385 | help="Scale image height to")
386 | parser.add_argument("--scale_w", type=int, default=171,
387 | help="Scale image width to")
388 | parser.add_argument('--train_crop_min_scale',
389 | default=0.25,
390 | type=float,
391 | help='Min scale for random cropping in training')
392 | parser.add_argument('--train_crop_min_ratio',
393 | default=0.75,
394 | type=float,
395 | help='Min aspect ratio for random cropping in training')
396 | parser.add_argument('--no_hflip',
397 | action='store_true',
398 | help='If true holizontal flipping is not performed.')
399 | parser.add_argument('--colorjitter',
400 | action='store_true',
401 | help='If true colorjitter is performed.')
402 | parser.add_argument('--train_t_crop',
403 | default='random',
404 | type=str,
405 | help=('Temporal cropping method in training. '
406 | 'random is uniform. '
407 | '(random | center)'))
408 |
409 | args=parser.parse_args()
410 |
411 | spatial_transforms=get_spatial_transform(opt=args)
412 |
413 | dataset=KSDataset(video_loader=VideoLoaderHDF5(),spatial_transform=spatial_transforms)
414 |
415 |
416 |
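Before using these datasets, fill in `HDF5_DIR` (per-video `<id>.hdf5` files read by `VideoLoaderHDF5`), `PKL_DIR` (pickled audio spectrograms, `<id>.pkl`) and `PROJECT_DIR` at the top of this file; `KS_train_val.json` supplies the `labels` list and a `database` dict whose entries carry a `subset` and a `label`. A minimal sketch of how `main.py` wires `KSDataset` into a `DataLoader` (batch size and worker count are illustrative, and the imports assume the `code/` directory):

```python
# Sketch: build the Kinetics-Sounds training loader the way main.py does.
from torch.utils.data import DataLoader

from config import Config
from dataset.KS import KSDataset
from dataset.loader import VideoLoaderHDF5
from dataset.spatial_transforms import get_spatial_transform

cfg = Config()  # remember to set HDF5_DIR / PKL_DIR / PROJECT_DIR in dataset/KS.py first
train_set = KSDataset(mode='training',
                      spatial_transform=get_spatial_transform(opt=cfg),
                      video_loader=VideoLoaderHDF5())
train_loader = DataLoader(train_set, batch_size=16, shuffle=True,
                          num_workers=8, pin_memory=True)

spec, clip, label = next(iter(train_loader))
# spec: [B, H, W] audio spectrograms, clip: [B, C, T, H, W] with T=3 frames, label: [B]
```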
--------------------------------------------------------------------------------
/code/dataset/loader.py:
--------------------------------------------------------------------------------
1 | import io
2 | import h5py
3 | import numpy as np
4 | from os import path
5 | from PIL import Image
6 |
7 |
8 | class ImageLoaderPIL(object):
9 |
10 | def __call__(self, path):
11 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
12 | with path.open('rb') as f:
13 | with Image.open(f) as img:
14 | return img.convert('RGB')
15 |
16 |
17 | # class ImageLoaderAccImage(object):
18 |
19 | # def __call__(self, path):
20 | # import accimage
21 | # return accimage.Image(str(path))
22 |
23 |
24 | class NumpyLoader(object):
25 |
26 | def __call__(self, path):
27 | return np.load(path)
28 |
29 |
30 | class AudioFeatureLoader(object):
31 | # load an audio feature stored as a numpy file (.npy)
32 | def __init__(self):
33 | self.npyloader = NumpyLoader()
34 |
35 | def __call__(self, filename):
36 | if path.isfile(filename):
37 | features = self.npyloader(filename)
38 | else:
39 | features = None
40 | return features
41 |
42 |
43 | class VideoLoader(object):
44 |
45 | def __init__(self, image_name_formatter, image_loader=None):
46 | self.image_name_formatter = image_name_formatter
47 | if image_loader is None:
48 | self.image_loader = ImageLoaderPIL()
49 | else:
50 | self.image_loader = image_loader
51 |
52 | def __call__(self, video_path, frame_indices):
53 | video = []
54 | for i in frame_indices:
55 | image_path = video_path / self.image_name_formatter(i)
56 | if image_path.exists():
57 | video.append(self.image_loader(image_path))
58 | return video
59 |
60 |
61 |
62 | class VideoLoaderHDF5(object):
63 |
64 | def __call__(self, video_path):
65 | with h5py.File(video_path, 'r') as f:
66 | video_data = f['video']
67 | video = []
68 | for i in range(len(video_data)):
69 | video.append(Image.open(io.BytesIO(video_data[i])))  # decode the i-th stored frame in order
70 | return video
71 |
72 |
73 | class VideoLoaderFlowHDF5(object):
74 |
75 | def __init__(self):
76 | self.flows = ['u', 'v']
77 |
78 | def __call__(self, video_path, frame_indices):
79 | with h5py.File(video_path, 'r') as f:
80 |
81 | flow_data = []
82 | for flow in self.flows:
83 | flow_data.append(f[f'video_{flow}'])
84 |
85 | video = []
86 | for i in frame_indices:
87 | if i < len(flow_data[0]):
88 | frame = [
89 | Image.open(io.BytesIO(video_data[i]))
90 | for video_data in flow_data
91 | ]
92 | frame.append(frame[-1]) # add dummy data into third channel
93 | video.append(Image.merge('RGB', frame))
94 | return video
95 |
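`VideoLoaderHDF5` expects one HDF5 file per video with a `'video'` dataset whose entries are the encoded bytes of each frame. Below is a minimal sketch of producing such a file with h5py; the helper name and the JPEG encoding are assumptions, chosen only because the loader re-opens each entry with PIL.

```python
# Sketch: pack a list of PIL frames into the HDF5 layout VideoLoaderHDF5 reads.
import io
import h5py
import numpy as np
from PIL import Image

def write_video_hdf5(frames, out_path):
    """frames: list of PIL.Image; out_path: e.g. '<video_id>.hdf5' (hypothetical)."""
    byte_dtype = h5py.vlen_dtype(np.dtype('uint8'))
    with h5py.File(out_path, 'w') as f:
        ds = f.create_dataset('video', (len(frames),), dtype=byte_dtype)
        for i, img in enumerate(frames):
            buf = io.BytesIO()
            img.save(buf, format='JPEG')  # the loader decodes these bytes with PIL
            ds[i] = np.frombuffer(buf.getvalue(), dtype='uint8')
```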
--------------------------------------------------------------------------------
/code/dataset/spatial_transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from torchvision.transforms import transforms
4 | from torchvision.transforms import functional as F
5 | from PIL import Image
6 |
7 |
8 | class Compose(transforms.Compose):
9 |
10 | def randomize_parameters(self):
11 | for t in self.transforms:
12 | t.randomize_parameters()
13 |
14 |
15 | class ToTensor(transforms.ToTensor):
16 |
17 | def randomize_parameters(self):
18 | pass
19 |
20 |
21 | class Normalize(transforms.Normalize):
22 |
23 | def randomize_parameters(self):
24 | pass
25 |
26 |
27 | class ScaleValue(object):
28 |
29 | def __init__(self, s):
30 | self.s = s
31 |
32 | def __call__(self, tensor):
33 | tensor *= self.s
34 | return tensor
35 |
36 | def randomize_parameters(self):
37 | pass
38 |
39 |
40 | class Resize(transforms.Resize):
41 |
42 | def randomize_parameters(self):
43 | pass
44 |
45 |
46 | class RandomCrop(transforms.RandomCrop):
47 |
48 | def randomize_parameters(self):
49 | pass
50 |
51 |
52 | # class Scale(transforms.Scale):
53 |
54 | # def randomize_parameters(self):
55 | # pass
56 |
57 |
58 | class CenterCrop(transforms.CenterCrop):
59 |
60 | def randomize_parameters(self):
61 | pass
62 |
63 |
64 | class CornerCrop(object):
65 |
66 | def __init__(self,
67 | size,
68 | crop_position=None,
69 | crop_positions=['c', 'tl', 'tr', 'bl', 'br']):
70 | self.size = size
71 | self.crop_position = crop_position
72 | self.crop_positions = crop_positions
73 |
74 | if crop_position is None:
75 | self.randomize = True
76 | else:
77 | self.randomize = False
78 | self.randomize_parameters()
79 |
80 | def __call__(self, img):
81 | image_width = img.size[0]
82 | image_height = img.size[1]
83 |
84 | h, w = (self.size, self.size)
85 | if self.crop_position == 'c':
86 | i = int(round((image_height - h) / 2.))
87 | j = int(round((image_width - w) / 2.))
88 | elif self.crop_position == 'tl':
89 | i = 0
90 | j = 0
91 | elif self.crop_position == 'tr':
92 | i = 0
93 | j = image_width - self.size
94 | elif self.crop_position == 'bl':
95 | i = image_height - self.size
96 | j = 0
97 | elif self.crop_position == 'br':
98 | i = image_height - self.size
99 | j = image_width - self.size
100 |
101 | img = F.crop(img, i, j, h, w)
102 |
103 | return img
104 |
105 | def randomize_parameters(self):
106 | if self.randomize:
107 | self.crop_position = self.crop_positions[random.randint(
108 | 0,
109 | len(self.crop_positions) - 1)]
110 |
111 | def __repr__(self):
112 | return self.__class__.__name__ + '(size={0}, crop_position={1}, randomize={2})'.format(
113 | self.size, self.crop_position, self.randomize)
114 |
115 |
116 | class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
117 |
118 | def __init__(self, p=0.5):
119 | super().__init__(p)
120 | self.randomize_parameters()
121 |
122 | def __call__(self, img):
123 | """
124 | Args:
125 | img (PIL.Image): Image to be flipped.
126 | Returns:
127 | PIL.Image: Randomly flipped image.
128 | """
129 | if self.random_p < self.p:
130 | return F.hflip(img)
131 | return img
132 |
133 | def randomize_parameters(self):
134 | self.random_p = random.random()
135 |
136 |
137 | class MultiScaleCornerCrop(object):
138 |
139 | def __init__(self,
140 | size,
141 | scales,
142 | crop_positions=['c', 'tl', 'tr', 'bl', 'br'],
143 | interpolation=Image.BILINEAR):
144 | self.size = size
145 | self.scales = scales
146 | self.interpolation = interpolation
147 | self.crop_positions = crop_positions
148 |
149 | self.randomize_parameters()
150 |
151 | def __call__(self, img):
152 | short_side = min(img.size[0], img.size[1])
153 | crop_size = int(short_side * self.scale)
154 | self.corner_crop.size = crop_size
155 |
156 | img = self.corner_crop(img)
157 | return img.resize((self.size, self.size), self.interpolation)
158 |
159 | def randomize_parameters(self):
160 | self.scale = self.scales[random.randint(0, len(self.scales) - 1)]
161 | crop_position = self.crop_positions[random.randint(
162 | 0,
163 | len(self.crop_positions) - 1)]
164 |
165 | self.corner_crop = CornerCrop(None, crop_position)
166 |
167 | def __repr__(self):
168 | return self.__class__.__name__ + '(size={0}, scales={1}, interpolation={2})'.format(
169 | self.size, self.scales, self.interpolation)
170 |
171 |
172 | class RandomResizedCrop(transforms.RandomResizedCrop):
173 |
174 | def __init__(self,
175 | size,
176 | scale=(0.08, 1.0),
177 | ratio=(3. / 4., 4. / 3.),
178 | interpolation=Image.BILINEAR):
179 | super().__init__(size, scale, ratio, interpolation)
180 | self.randomize_parameters()
181 |
182 | def __call__(self, img):
183 | if self.randomize:
184 | self.random_crop = self.get_params(img, self.scale, self.ratio)
185 | self.randomize = False
186 |
187 | i, j, h, w = self.random_crop
188 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
189 |
190 | def randomize_parameters(self):
191 | self.randomize = True
192 |
193 |
194 | class ColorJitter(transforms.ColorJitter):
195 |
196 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
197 | super().__init__(brightness, contrast, saturation, hue)
198 | self.randomize_parameters()
199 |
200 | def __call__(self, img):
201 | if self.randomize:
202 | self.transform = self.get_params(self.brightness, self.contrast,
203 | self.saturation, self.hue)
204 | self.randomize = False
205 |
206 | return self.transform(img)
207 |
208 | def randomize_parameters(self):
209 | self.randomize = True
210 |
211 |
212 | class PickFirstChannels(object):
213 |
214 | def __init__(self, n):
215 | self.n = n
216 |
217 | def __call__(self, tensor):
218 | return tensor[:self.n, :, :]
219 |
220 | def randomize_parameters(self):
221 | pass
222 |
223 |
224 | def get_normalize_method():
225 | mean=[0.485, 0.456, 0.406]
226 | std=[0.229, 0.224, 0.225]
227 |
228 | return Normalize(mean,std)
229 |
230 | def get_spatial_transform(opt):
231 | assert opt.train_crop in ['random', 'corner', 'center', 'other']
232 | spatial_transform = []
233 | if opt.train_crop == 'random':
234 | spatial_transform.append(
235 | RandomResizedCrop(
236 | opt.sample_size, (opt.train_crop_min_scale, 1.0),
237 | (opt.train_crop_min_ratio, 1.0 / opt.train_crop_min_ratio)))
238 | elif opt.train_crop == 'corner':
239 | scales = [1.0]
240 | scale_step = 1 / (2**(1 / 4))
241 | for _ in range(1, 5):
242 | scales.append(scales[-1] * scale_step)
243 | spatial_transform.append(MultiScaleCornerCrop(opt.sample_size, scales))
244 | elif opt.train_crop == 'center':
245 | spatial_transform.append(Resize(opt.sample_size))
246 | spatial_transform.append(CenterCrop(opt.sample_size))
247 | elif opt.train_crop == 'other':
248 | print('other')
249 | spatial_transform.append(Resize((opt.scale_h, opt.scale_w)))
250 | spatial_transform.append(RandomCrop(opt.sample_size))
251 |
252 | normalize = get_normalize_method()
253 | if not opt.no_hflip:
254 | spatial_transform.append(RandomHorizontalFlip())
255 | if opt.colorjitter:
256 | spatial_transform.append(ColorJitter())
257 | spatial_transform.append(ToTensor())
258 |
259 | spatial_transform.append(ScaleValue(opt.value_scale))
260 | spatial_transform.append(normalize)
261 | spatial_transform = Compose(spatial_transform)
262 |
263 | return spatial_transform
264 |
265 |
266 | def get_val_spatial_transforms(opt):
267 | normalize=get_normalize_method()
268 | if opt.train_crop=='other':
269 | spatial_transforms=[
270 | Resize((opt.scale_h,opt.scale_w)),
271 | RandomCrop(opt.sample_size),
272 | ToTensor()
273 | ]
274 | else:
275 | spatial_transforms=[
276 | Resize(opt.sample_size),
277 | CenterCrop(opt.sample_size),
278 | ToTensor()
279 | ]
280 | spatial_transforms.extend([ScaleValue(opt.value_scale),normalize])
281 | spatial_transforms=Compose(spatial_transforms)
282 |
283 | return spatial_transforms
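Both factory functions only read a handful of attributes from `opt` (`train_crop`, `sample_size`, `scale_h`, `scale_w`, `train_crop_min_scale`, `train_crop_min_ratio`, `no_hflip`, `colorjitter`, `value_scale`), so any namespace carrying those fields works. A small sketch with default-like values:

```python
# Sketch: build train/val transforms from a plain namespace instead of Config/argparse.
from types import SimpleNamespace
from dataset.spatial_transforms import get_spatial_transform, get_val_spatial_transforms

opt = SimpleNamespace(train_crop='random', sample_size=112,
                      scale_h=128, scale_w=171,
                      train_crop_min_scale=0.25, train_crop_min_ratio=0.75,
                      no_hflip=False, colorjitter=False, value_scale=1)
train_tf = get_spatial_transform(opt)     # RandomResizedCrop + flip + ToTensor + Normalize
val_tf = get_val_spatial_transforms(opt)  # Resize + CenterCrop + ToTensor + Normalize
```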
--------------------------------------------------------------------------------
/code/main.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import torch
4 | import numpy as np
5 | import random
6 | import json
7 | import os
8 | from os.path import join
9 | import sys
10 | from tqdm import tqdm
11 | from torch import nn
12 | from torch.utils.data import DataLoader
13 | from torch.utils.tensorboard import SummaryWriter
14 |
15 | from models.Classifier import Classifier
16 | from config import Config
17 | from dataset.KS import KSDataset
18 | # from utils.log_file import Logger  # utils/ is not in the repo tree; the Logger calls below are commented out
19 | from datetime import datetime
20 |
21 | from dataset.spatial_transforms import get_spatial_transform,get_val_spatial_transforms
22 | from dataset.loader import VideoLoaderHDF5
23 |
24 | from sklearn.metrics import average_precision_score
25 | import torch.nn.functional as F
26 |
27 |
28 | TIMESTAMP = "{0:%Y-%m-%d-%H-%M-%S/}".format(datetime.now())
29 |
30 |
31 | def get_arguments():
32 | parser = argparse.ArgumentParser()
33 |
34 | parser.add_argument('--use_modulation',action='store_true',help='use gradient modulation')
35 | parser.add_argument('--use_adam_drop',action='store_true',help='use adam-drop')
36 | parser.add_argument('--modulation', default='OGM_GE', type=str,choices=['Normal', 'OGM', 'OGM_GE'])
37 | parser.add_argument('--use_OGM_plus',action='store_true')
38 | parser.add_argument('--fusion_method', default='concat', type=str,choices=['sum', 'concat', 'gated'])
39 | parser.add_argument('--train', action='store_true', help='turn on train mode')
40 | parser.add_argument('--resume_model',action='store_true',help='whether to resume model')
41 | parser.add_argument('--resume_model_path')
42 | parser.add_argument('--q_base',type=float,default=0.5)
43 | parser.add_argument('--lam',type=float,default=0.5)
44 | parser.add_argument('--p_exe',type=float,default=0.7)
45 | parser.add_argument('--alpha',type=float,default=1.0)
46 | parser.add_argument('--modulation_starts',type=int,default=0)
47 | parser.add_argument('--modulation_ends',type=int,default=80)
48 | parser.add_argument('--audio_drop',type=float,default=0.0)
49 | parser.add_argument('--visual_drop',type=float,default=0.0)
50 | parser.add_argument('--exp_name',type=str,default='exp')
51 |
52 | return parser.parse_args()
53 |
54 |
55 | def setup_seed(seed):
56 | torch.manual_seed(seed)
57 | torch.cuda.manual_seed_all(seed)
58 | np.random.seed(seed)
59 | random.seed(seed)
60 | torch.backends.cudnn.deterministic = True
61 |
62 |
63 | def weight_init(m):
64 | if isinstance(m, nn.Linear):
65 | nn.init.xavier_normal_(m.weight)
66 | nn.init.constant_(m.bias, 0)
67 | elif isinstance(m, nn.Conv2d):
68 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
69 | elif isinstance(m, nn.BatchNorm2d):
70 | nn.init.constant_(m.weight, 1)
71 | nn.init.constant_(m.bias, 0)
72 |
73 | weight_a=0.36
74 | weight_v=0.27
75 | weight_av=0.37
76 |
77 | def train(cfg,epoch,model,device,dataloader,optimizer,scheduler,tb=None):
78 | loss_fn=nn.CrossEntropyLoss().to(device)
79 | relu=nn.ReLU(inplace=True)
80 | tanh=nn.Tanh()
81 | model.train()
82 | total_loss=0
83 | total_loss_1=0
84 | total_loss_2=0
85 | with tqdm(total=len(dataloader), desc=f"Train-epoch-{epoch}") as pbar:
86 | for step, (spec,image,label) in enumerate(dataloader):
87 | spec=spec.to(device) # b,h,w
88 | image=image.to(device) # b,c,t,h,w
89 | label=label.to(device)
90 | optimizer.zero_grad()
91 | warm_up=1 if epoch<=5 else 0
92 | # warm_up=0
93 | out_1,out_2,out,update_flag,performance_1,performance_2=model(spec.unsqueeze(1).float(),image.float(),label,warm_up)
94 |
95 | if warm_up==0 and cfg.use_adam_drop:
96 | if torch.sum(update_flag,dim=0)==0:
97 | continue
98 | select_mask=update_flag!=0
99 | label=label[select_mask]
100 | out_1=out_1[select_mask]
101 | out_2=out_2[select_mask]
102 |
103 |
104 | loss=loss_fn(out,label)
105 | loss_1=loss_fn(out_1,label)
106 | loss_2=loss_fn(out_2,label)
107 | total_loss+=loss.item()
108 | total_loss_1+=loss_1.item()
109 | total_loss_2+=loss_2.item()
110 |
111 | # if warm_up==0:
112 | # loss=loss*weight_av+loss_1*weight_a+loss_2*weight_v
113 |
114 | loss.backward()
115 |
116 | if warm_up==0 and cfg.use_modulation:
117 | # log.logger.info('per_1:{} per_2:{} '.format(performance_1,performance_2))
118 | coeff_1,coeff_2=None,None
119 | radio_1=performance_1/performance_2
120 | radio_2=performance_2/performance_1
121 | # if cfg.form=='/':
122 | # radio_1=performance_1/performance_2
123 | # radio_2=performance_2/performance_1
124 | # else:
125 | # radio_1=performance_1-performance_2
126 | # radio_2=performance_2-performance_1
127 |
128 | if cfg.use_OGM_plus:
129 | if radio_1>1:
130 | # coeff_2=1+tanh(cfg.alpha*relu(radio_1))
131 | coeff_2=4
132 | coeff_1=1
133 | else:
134 | coeff_2=1
135 | # coeff_1=1+tanh(cfg.alpha*relu(radio_2))
136 | coeff_1=4
137 | else:
138 | if radio_1>1:
139 | coeff_1=1-tanh(cfg.alpha*relu(radio_1))
140 | # if cfg.func=='tanh':
141 | # coeff_1=1-tanh(cfg.alpha*relu(radio_1))
142 | # else:
143 | # coeff_1=1-sigmoid(cfg.alpha*relu(radio_1))
144 |
145 | coeff_2=1
146 | else:
147 | coeff_1=1
148 | coeff_2=1-tanh(cfg.alpha*relu(radio_2))
149 | # if cfg.func=='tanh':
150 | # coeff_2=1-tanh(cfg.alpha*relu(radio_2))
151 | # else:
152 | # coeff_2=1-sigmoid(cfg.alpha*relu(radio_2))
153 |
154 | if cfg.modulation_starts<=epoch<=cfg.modulation_ends:
155 | for name,parms in model.named_parameters():
156 | layer_name=str(name).split('.')[0]
157 | if 'encoder_1' in layer_name and parms.grad is not None and len(parms.grad.size()) == 4:
158 | if cfg.modulation == 'OGM_GE':
159 | parms.grad = parms.grad * coeff_1 + \
160 | torch.zeros_like(parms.grad).normal_(0, parms.grad.std().item() + 1e-8)
161 | elif cfg.modulation == 'OGM':
162 | parms.grad *= coeff_1
163 |
164 | if 'encoder_2' in layer_name and parms.grad is not None and len(parms.grad.size()) == 4:
165 | if cfg.modulation == 'OGM_GE':
166 | parms.grad = parms.grad * coeff_2 + \
167 | torch.zeros_like(parms.grad).normal_(0, parms.grad.std().item() + 1e-8)
168 | elif cfg.modulation == 'OGM':
169 | parms.grad *= coeff_2
170 |
171 | optimizer.step()
172 | pbar.update(1)
173 |
174 | scheduler.step()
175 |
176 | return total_loss/len(dataloader),total_loss_1/len(dataloader),total_loss_2/len(dataloader)
177 |
178 |
179 | def val(model,device,dataloader):
180 | softmax=nn.Softmax(dim=1)
181 | sum_all=0
182 | sum_1=0
183 | sum_2=0
184 | tot=0
185 | all_out=[]
186 | all_label=[]
187 | with torch.no_grad():
188 | model.eval()
189 | for step,(spec,img,label) in enumerate(dataloader):
190 | spec=spec.to(device)
191 | img=img.to(device)
192 | label=label.to(device)
193 | out_1,out_2,out,update_flag,performance_1,performance_2=model(spec.unsqueeze(1).float(),img.float(),label,warm_up=1)
194 | prediction=softmax(out)
195 | pred_1=softmax(out_1)
196 | pred_2=softmax(out_2)
197 | tot+=img.shape[0]
198 | sum_all+=torch.sum(torch.argmax(prediction,dim=1)==label).item()
199 | sum_1+=torch.sum(torch.argmax(pred_1,dim=1)==label).item()
200 | sum_2+=torch.sum(torch.argmax(pred_2,dim=1)==label).item()
201 |
202 | for i in range(label.shape[0]):
203 | all_out.append(prediction[i].cpu().data.numpy())
204 | ss=torch.zeros(31)
205 | ss[label[i]]=1
206 | all_label.append(ss.numpy())
207 |
208 | all_out=np.array(all_out)
209 | all_label=np.array(all_label)
210 | mAP=average_precision_score(all_label,all_out)
211 |
212 |
213 | return mAP,sum_all/tot,sum_1/tot,sum_2/tot
214 |
215 |
216 | def write2txt(fp,info,mode='a'):
217 | with open(fp,mode=mode) as f:
218 | f.write(info)
219 | f.write('\n')
220 |
221 |
222 | def main():
223 | # job_id=datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
224 | cfg = Config()
225 | args=get_arguments()
226 | cfg.parse(vars(args))
227 | setup_seed(cfg.random_seed)
228 |
229 | job_name=args.exp_name
230 | cur_dir=os.path.join('results',job_name)
231 | os.makedirs(cur_dir,exist_ok=True)
232 |
233 | # log=Logger(os.path.join(cur_dir,'log.log'),level='info')
234 | writer=None
235 | if cfg.use_tensorboard:
236 | writer_path=os.path.join(cur_dir,'tensorboard')
237 | os.makedirs(writer_path,exist_ok=True)
238 | writer=SummaryWriter(writer_path)
239 |
240 | saved_data=vars(cfg)
241 | cmd=' '.join(sys.argv)
242 | saved_data.update({'cmd':cmd})
243 | saved_data=json.dumps(saved_data,indent=4)
244 | with open(os.path.join(cur_dir,'config.json'),'w') as f:
245 | f.write(saved_data)
246 |
247 | device=torch.device('cuda')
248 |
249 | spatial_transforms=get_spatial_transform(opt=cfg)
250 | val_spatial_transforms=get_val_spatial_transforms(opt=cfg)
251 | train_dataset=KSDataset(mode='training',spatial_transform=spatial_transforms,video_loader=VideoLoaderHDF5())
252 | test_dataset=KSDataset(mode='testing',spatial_transform=val_spatial_transforms,video_loader=VideoLoaderHDF5(),audio_drop=cfg.audio_drop,visual_drop=cfg.visual_drop)
253 |
254 | train_loader=DataLoader(train_dataset,batch_size=cfg.batch_size,shuffle=True,num_workers=32,pin_memory=True)
255 | test_loader=DataLoader(test_dataset,batch_size=cfg.batch_size,shuffle=False,num_workers=32,pin_memory=True)
256 |
257 | model=Classifier(cfg,device=device)
258 |
259 | if cfg.resume_model:
260 | state_dict=torch.load(cfg.resume_model_path,map_location='cuda')
261 | model.load_state_dict(state_dict=state_dict)
262 | else:
263 | model.apply(weight_init)
264 |
265 | model.to(device)
266 |
267 | optimizer=torch.optim.AdamW(model.parameters(),lr=cfg.learning_rate,weight_decay=0.01)
268 | scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=cfg.lr_decay_step,gamma=cfg.lr_decay_ratio)
269 |
270 | start_epoch=-1
271 | best_acc=0.0
272 | logger_path=join(cur_dir,'log.txt')
273 |
274 | if cfg.train:
275 | for epoch in range(start_epoch+1,cfg.epochs):
276 | loss,loss_1,loss_2=train(cfg,epoch,model,device,train_loader,optimizer,scheduler,tb=writer)
277 | mAP,acc,acc_1,acc_2=val(model,device,test_loader)
278 | # log.logger.info('epoch:{} acc:{:.4f} acc_1:{:.4f} acc_2:{:.4f} mAP:{:.4f}'.format(epoch,acc,acc_1,acc_2,mAP))
279 | write2txt(fp=logger_path,info=f'epoch:{epoch} acc:{acc:.4f} acc_1:{acc_1:.4f} acc_2:{acc_2:.4f} mAP:{mAP:.4f}')
280 | if writer is not None:
281 | writer.add_scalars(main_tag='Loss',tag_scalar_dict={'loss':loss,'loss_1':loss_1,'loss_2':loss_2},global_step=epoch)
282 | writer.add_scalars(main_tag='Acc',tag_scalar_dict={'acc':acc,'acc_1':acc_1,'acc_2':acc_2},global_step=epoch)
283 |
284 | if acc>best_acc:
285 | best_acc=acc
286 | saved_data={}
287 | saved_data['epoch']=epoch
288 | saved_data['acc']=acc
289 | saved_data['mAP']=mAP
290 | saved_data['acc_1']=acc_1
291 | saved_data['acc_2']=acc_2
292 | saved_data=json.dumps(saved_data,indent=4)
293 |
294 | with open(os.path.join(cur_dir,'best_model.json'),'w') as f:
295 | f.write(saved_data)
296 |
297 | torch.save(model.state_dict(),os.path.join(cur_dir,'best_model.pth'))
298 | else:
299 | mAP,acc,acc_1,acc_2=val(model,device,test_loader)
300 | print('mAP:{} Acc:{}'.format(mAP,acc))
301 |
302 |
303 | if __name__ == "__main__":
304 | main()
305 |
--------------------------------------------------------------------------------
/code/models/Audio_Classifier.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 |
4 | from .Resnet_18 import resnet18
5 |
6 | class Classifier(nn.Module):
7 |
8 | def __init__(self,cfg,device='cuda:0'):
9 | super().__init__()
10 |
11 | self.encoder_1=resnet18(modality=cfg.modality[0])
12 |
13 | self.cfg=cfg
14 | self.device=device
15 |
16 | self.linear=nn.Linear(512,31)
17 |
18 | def forward(self,mod_1):
19 | out_1=self.encoder_1(mod_1)
20 | out_1=F.adaptive_avg_pool2d(out_1,1)
21 | out_1=out_1.squeeze(2).squeeze(2) # [B,512]
22 |
23 | out_1=self.linear(out_1)
24 | return out_1
25 |
26 |
27 |
28 | if __name__ == "__main__":
29 | pass
30 |
--------------------------------------------------------------------------------
/code/models/BasicModule.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import nn
4 |
5 | import time
6 |
7 | class BasicModule(nn.Module):
8 |
9 | def __init__(self) -> None:
10 | super().__init__()
11 | self.model_name=str(type(self))
12 |
13 | def load(self,path):
14 | self.load_state_dict(torch.load(path))
15 |
16 | def save(self,name=None):
17 | if name is None:
18 | name=time.strftime('checkpoints/'+self.model_name+'_'+'%m%d_%H:%M:%S.pth')
19 | torch.save(self.state_dict(),name)
20 |
21 | return name
22 |
23 | if __name__ == "__main__":
24 | pass
25 |
--------------------------------------------------------------------------------
/code/models/Classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from .fusion_model import ConcatFusion,SumFusion,GatedFusion,LMF
7 | from .Resnet_18 import resnet18
8 |
9 | class custom_autograd(torch.autograd.Function):
10 |
11 | @staticmethod
12 | def forward(ctx,input,theta):
13 | ctx.save_for_backward(input,theta)
14 | return input/(1-theta.item())
15 |
16 | @staticmethod
17 | def backward(ctx,grad_output):
18 | input,theta=ctx.saved_tensors
19 | input_grad=1/(1-theta.item())*grad_output.clone()
20 |
21 | return input_grad,theta
22 |
23 |
24 | class Modality_drop():
25 |
26 | def __init__(self,dim_list,p_exe=0.7,device='cuda'):
27 | self.dim_list=dim_list
28 | self.p_exe=p_exe
29 | self.device=device
30 |
31 | def execute_drop(self,fead_list,q):
32 | B = fead_list[0].shape[0]
33 | D = fead_list[0].shape[1]
34 | exe_drop = torch.tensor(np.random.rand(1)).to(device=self.device) >= 1-self.p_exe
35 | if not exe_drop:
36 | return fead_list, torch.ones([B],dtype=torch.int32,device=self.device)
37 |
38 | num_mod=len(fead_list)
39 | d_sum=sum(self.dim_list)
40 | q_sum=sum(self.dim_list*q)
41 | theta=q_sum/d_sum
42 | # p_sum=sum(self.dim_list*(1-q))
43 | # theta=p_sum/d_sum
44 |
45 | mask=torch.distributions.Bernoulli(1-q).sample([B,1]).permute(2,1,0).contiguous().reshape(num_mod,B,-1).to(device=self.device) # [2,B,1]
46 | # print(f'mask:{mask}')
47 | concat_list=torch.stack(fead_list,dim=0) # [2,B,D]
48 | concat_list=torch.mul(concat_list,mask)
49 | concat_list=custom_autograd.apply(concat_list,theta)
50 | mask=torch.transpose(mask,0,1).squeeze(-1) # [B,2]
51 | update_flag=torch.sum(mask,dim=1)>0
52 | cleaned_fea=torch.masked_select(concat_list,update_flag.unsqueeze(-1)).reshape(num_mod,-1,D)
53 | cleaned_fea=torch.chunk(cleaned_fea,num_mod,dim=0)
54 | cleaned_fea=[_.squeeze(0) for _ in cleaned_fea] # [B,D]
55 | return cleaned_fea,update_flag
56 |
57 |
58 | def calcu_q(performance_1,performance_2,q_base,fix_lambda):
59 | q=torch.tensor([0.0,0.0])
60 | relu = nn.ReLU(inplace=True)
61 | ratio_1=torch.tanh(relu(performance_1/performance_2-1))
62 | ratio_2=torch.tanh(relu(performance_2/performance_1-1))
63 |
64 | lamda = fix_lambda
65 |
66 |
67 | q[0] = q_base * (1 + lamda * ratio_1) if ratio_1>0 else 0
68 | q[1] = q_base * (1 + lamda * ratio_2) if ratio_2>0 else 0
69 |
70 | q=torch.clip(q,0.0,1.0)
71 |
72 | return q
73 |
74 |
75 | class Classifier(nn.Module):
76 |
77 | def __init__(self,cfg,device='cuda'):
78 | super().__init__()
79 |
80 | self.encoder_1=resnet18(modality='audio')
81 | self.encoder_2=resnet18(modality='visual')
82 |
83 | self.cfg=cfg
84 | self.device=device
85 |
86 | self.softmax=nn.Softmax(dim=1)
87 | self.fusion_model=ConcatFusion(in_c_x=512,in_c_y=512,out_c=31)
88 |
89 | if self.cfg.use_adam_drop:
90 | self.modality_drop=Modality_drop(dim_list=torch.tensor(self.cfg.d),p_exe=self.cfg.p_exe,device=self.device)
91 |
92 |
93 | def forward(self,mod_1,mod_2,label,warm_up=1):
94 | out_1=self.encoder_1(mod_1)
95 | out_2=self.encoder_2(mod_2) # [B,C,T,H,W] --> [B*T,512,h,w]
96 |
97 | _,C,H,W=out_2.shape
98 | B=out_1.shape[0]
99 |
100 | out_2=out_2.reshape(B,-1,C,H,W).permute(0,2,1,3,4)
101 |
102 | out_1=F.adaptive_avg_pool2d(out_1,1)
103 | out_2=F.adaptive_avg_pool3d(out_2,1)
104 |
105 | out_1=out_1.squeeze(2).squeeze(2) # [B,512]
106 | out_2=out_2.squeeze(2).squeeze(2).squeeze(2) # [B,512]
107 |
108 | performance_1=None
109 | performance_2=None
110 | t1,t2=None,None
111 |
112 | w=self.fusion_model.fxy.weight.clone().detach()
113 | b=self.fusion_model.fxy.bias.clone().detach()
114 |
115 | # if self.cfg.t1_bias==0.5:
116 | # t1_bias=b/2
117 | # elif self.cfg.t1_bias==0.0:
118 | # t1_bias=0.0
119 | # elif self.cfg.t1_bias==0.3:
120 | # t1_bias=b/3
121 | # elif self.cfg.t1_bias==0.6:
122 | # t1_bias=2*b/3
123 | # elif self.cfg.t1_bias==1.0:
124 | # t1_bias=b
125 | t1_bias=b/2
126 |
127 | # if self.cfg.t2_bias==0.5:
128 | # t2_bias=b/2
129 | # elif self.cfg.t2_bias==0.0:
130 | # t2_bias=0.0
131 | # elif self.cfg.t2_bias==0.3:
132 | # t2_bias=b/3
133 | # elif self.cfg.t2_bias==0.6:
134 | # t2_bias=2*b/3
135 | # elif self.cfg.t2_bias==1.0:
136 | # t2_bias=b
137 | t2_bias=b/2
138 |
139 | t1=torch.mm(out_1,torch.transpose(w[:,:512],0,1))+t1_bias
140 | t2=torch.mm(out_2,torch.transpose(w[:,512:],0,1))+t2_bias
141 |
142 | performance_1=sum([self.softmax(t1)[i][int(label[i].item())] for i in range(t1.shape[0])])
143 | performance_2=sum([self.softmax(t2)[i][int(label[i].item())] for i in range(t2.shape[0])])
144 |
145 | if warm_up==0 and self.cfg.use_adam_drop:
146 | self.q=calcu_q(performance_1,performance_2,self.cfg.q_base,fix_lambda=self.cfg.lam)
147 | cleaned_fea,update_flag=self.modality_drop.execute_drop([out_1,out_2],self.q)
148 | cleaned_fea_1,cleaned_fea_2,out=self.fusion_model(cleaned_fea[0],cleaned_fea[1])
149 | return t1,t2,out,update_flag,performance_1,performance_2
150 |
151 | else:
152 | x,y,out=self.fusion_model(out_1,out_2)
153 | return t1,t2,out,torch.ones([B],dtype=torch.int32,device=self.device),performance_1,performance_2
154 |
155 |
156 |
157 | class AVClassifier_gb(nn.Module):
158 | def __init__(self, n_classes):
159 | super(AVClassifier_gb, self).__init__()
160 | self.n_classes = n_classes
161 |
162 | self.encoder_1=resnet18(modality='audio')
163 | self.encoder_2=resnet18(modality='visual')
164 |
165 | self.fusion_model = ConcatFusion(512,512,31)
166 |
167 | self.audio_head = nn.Linear(512, n_classes)
168 | self.visual_head = nn.Linear(512, n_classes)
169 |
170 |
171 | def forward(self, audio, visual):
172 | out_1=self.encoder_1(audio)
173 | out_2=self.encoder_2(visual) # [B,C,T,H,W] --> [B*T,512,h,w]
174 |
175 | _,C,H,W=out_2.shape
176 | B=out_1.shape[0]
177 |
178 | out_2=out_2.reshape(B,-1,C,H,W).permute(0,2,1,3,4)
179 |
180 | out_1=F.adaptive_avg_pool2d(out_1,1)
181 | out_2=F.adaptive_avg_pool3d(out_2,1)
182 |
183 | out_1=out_1.squeeze(2).squeeze(2) # [B,512]
184 | out_2=out_2.squeeze(2).squeeze(2).squeeze(2)
185 |
186 | x,y,out=self.fusion_model(out_1,out_2)
187 | return x,y,out
188 |
189 |
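For reference, a quick check of the OPM drop probabilities produced by `calcu_q` above. The confidence values are made up, `q_base` and `fix_lambda` match the defaults in `config.py`, and the import path assumes the script is run from the `code/` directory.

```python
# Quick check of calcu_q with hypothetical per-batch confidence values.
import torch
from models.Classifier import calcu_q

perf_audio = torch.tensor(0.6)   # summed softmax confidence of the audio branch
perf_visual = torch.tensor(0.3)  # summed softmax confidence of the visual branch

q = calcu_q(perf_audio, perf_visual, q_base=0.4, fix_lambda=0.5)
print(q)  # ~tensor([0.55, 0.00]): only the dominant audio features get dropped by OPM
```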
--------------------------------------------------------------------------------
/code/models/Resnet_18.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
5 | """3x3 convolution with padding"""
6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
7 | padding=dilation, groups=groups, bias=False, dilation=dilation)
8 |
9 |
10 | def conv1x1(in_planes, out_planes, stride=1):
11 | """1x1 convolution"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
13 |
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 |
18 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
19 | base_width=64, dilation=1, norm_layer=None):
20 | super(BasicBlock, self).__init__()
21 | if norm_layer is None:
22 | norm_layer = nn.BatchNorm2d
23 | if groups != 1 or base_width != 64:
24 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
25 | if dilation > 1:
26 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
27 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
28 | self.conv1 = conv3x3(inplanes, planes, stride)
29 | self.bn1 = norm_layer(planes)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.conv2 = conv3x3(planes, planes)
32 | self.bn2 = norm_layer(planes)
33 | self.downsample = downsample
34 | self.stride = stride
35 |
36 | def forward(self, x):
37 | identity = x
38 |
39 | out = self.conv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 |
43 | out = self.conv2(out)
44 | out = self.bn2(out)
45 |
46 | if self.downsample is not None:
47 | identity = self.downsample(x)
48 |
49 | out += identity
50 | out = self.relu(out)
51 |
52 | return out
53 |
54 |
55 | class ResNet(nn.Module):
56 |
57 | def __init__(self, block, layers, modality, num_classes=1000, pool='avgpool', zero_init_residual=False,
58 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
59 | norm_layer=None):
60 | super(ResNet, self).__init__()
61 | self.modality = modality
62 | self.pool = pool
63 | if norm_layer is None:
64 | norm_layer = nn.BatchNorm2d
65 | self._norm_layer = norm_layer
66 |
67 | self.inplanes = 64
68 | self.dilation = 1
69 | if replace_stride_with_dilation is None:
70 | # each element in the tuple indicates if we should replace
71 | # the 2x2 stride with a dilated convolution instead
72 | replace_stride_with_dilation = [False, False, False]
73 | if len(replace_stride_with_dilation) != 3:
74 | raise ValueError("replace_stride_with_dilation should be None "
75 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
76 | self.groups = groups
77 | self.base_width = width_per_group
78 | if modality == 'audio':
79 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
80 | bias=False)
81 | elif modality == 'visual':
82 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
83 | bias=False)
84 | elif modality=='optical':
85 | self.conv1 = nn.Conv2d(2, self.inplanes, kernel_size=7, stride=2, padding=3,
86 | bias=False)
87 | else:
88 | raise NotImplementedError('Incorrect modality, should be audio, visual or optical but got {}'.format(modality))
89 | self.bn1 = norm_layer(self.inplanes)
90 | self.relu = nn.ReLU(inplace=True)
91 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
92 | self.layer1 = self._make_layer(block, 64, layers[0])
93 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
94 | dilate=replace_stride_with_dilation[0])
95 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
96 | dilate=replace_stride_with_dilation[1])
97 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
98 | dilate=replace_stride_with_dilation[2])
99 | # if self.pool == 'avgpool':
100 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
101 | #
102 | # self.fc = nn.Linear(512 * block.expansion, num_classes) # 8192
103 |
104 | for m in self.modules():
105 | if isinstance(m, nn.Conv2d):
106 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
107 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
108 | nn.init.normal_(m.weight, mean=1, std=0.02)
109 | nn.init.constant_(m.bias, 0)
110 |
111 | # Zero-initialize the last BN in each residual branch,
112 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
113 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
114 | if zero_init_residual:
115 | for m in self.modules():
116 | if isinstance(m, Bottleneck):
117 | nn.init.constant_(m.bn3.weight, 0)
118 | elif isinstance(m, BasicBlock):
119 | nn.init.constant_(m.bn2.weight, 0)
120 |
121 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
122 | norm_layer = self._norm_layer
123 | downsample = None
124 | previous_dilation = self.dilation
125 | if dilate:
126 | self.dilation *= stride
127 | stride = 1
128 | if stride != 1 or self.inplanes != planes * block.expansion:
129 | downsample = nn.Sequential(
130 | conv1x1(self.inplanes, planes * block.expansion, stride),
131 | norm_layer(planes * block.expansion),
132 | )
133 |
134 | layers = []
135 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
136 | self.base_width, previous_dilation, norm_layer))
137 | self.inplanes = planes * block.expansion
138 | for _ in range(1, blocks):
139 | layers.append(block(self.inplanes, planes, groups=self.groups,
140 | base_width=self.base_width, dilation=self.dilation,
141 | norm_layer=norm_layer))
142 |
143 | return nn.Sequential(*layers)
144 |
145 | def forward(self, x):
146 |
147 | if self.modality == 'visual':
148 | (B, C, T, H, W) = x.size()
149 | x = x.permute(0, 2, 1, 3, 4).contiguous()
150 | x = x.view(B * T, C, H, W)
151 |
152 | x = self.conv1(x)
153 | x = self.bn1(x)
154 | x = self.relu(x)
155 | x = self.maxpool(x)
156 |
157 | x = self.layer1(x)
158 | x = self.layer2(x)
159 | x = self.layer3(x)
160 | x = self.layer4(x)
161 | out = x
162 |
163 | return out
164 |
165 |
166 | class Bottleneck(nn.Module):
167 | expansion = 4
168 |
169 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
170 | base_width=64, dilation=1, norm_layer=None):
171 | super(Bottleneck, self).__init__()
172 | if norm_layer is None:
173 | norm_layer = nn.BatchNorm2d
174 | width = int(planes * (base_width / 64.)) * groups
175 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
176 | self.conv1 = conv1x1(inplanes, width)
177 | self.bn1 = norm_layer(width)
178 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
179 | self.bn2 = norm_layer(width)
180 | self.conv3 = conv1x1(width, planes * self.expansion)
181 | self.bn3 = norm_layer(planes * self.expansion)
182 | self.relu = nn.ReLU(inplace=True)
183 | self.downsample = downsample
184 | self.stride = stride
185 |
186 | def forward(self, x):
187 | identity = x
188 |
189 | out = self.conv1(x)
190 | out = self.bn1(out)
191 | out = self.relu(out)
192 |
193 | out = self.conv2(out)
194 | out = self.bn2(out)
195 | out = self.relu(out)
196 |
197 | out = self.conv3(out)
198 | out = self.bn3(out)
199 |
200 | if self.downsample is not None:
201 | identity = self.downsample(x)
202 |
203 | out += identity
204 | out = self.relu(out)
205 |
206 | return out
207 |
208 |
209 | def _resnet(arch, block, layers, modality, progress, **kwargs):
210 | model = ResNet(block, layers, modality, **kwargs)
211 | return model
212 |
213 |
214 | def resnet18(modality, progress=True, **kwargs):
215 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], modality, progress,
216 | **kwargs)
217 |
--------------------------------------------------------------------------------
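For orientation, here is a minimal shape check of the visual branch of this backbone, mirroring the reshaping done in forward() above. It is only a sketch: the 3-channel 224x224 frames and the `from models.Resnet_18 import resnet18` import path (running from code/ with models importable as a package) are assumptions about the surrounding setup.

import torch
from models.Resnet_18 import resnet18  # assumed import path when running from code/

net = resnet18(modality='visual')
clip = torch.randn(2, 3, 4, 224, 224)   # (B, C, T, H, W): 2 clips, 3 channels, 4 frames
feat = net(clip)
print(feat.shape)                       # expected: torch.Size([8, 512, 7, 7]) -- frames folded into the batch, layer4 features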
/code/models/Visual_Classifier.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 | from .Resnet_18 import resnet18
6 |
7 | class Classifier(nn.Module):
8 |
9 | def __init__(self,cfg,device='cuda:0'):
10 | super().__init__()
11 |
12 |         self.encoder_2=resnet18(modality=cfg.modality[1])  # second entry of cfg.modality is the visual stream
13 |
14 | self.cfg=cfg
15 | self.device=device
16 |
17 |         self.linear=nn.Linear(512,31)  # 512-d pooled feature -> 31 Kinetics-Sounds classes
18 |
19 | def forward(self,mod_1):
20 | B=mod_1.shape[0]
21 | out_1=self.encoder_2(mod_1)
22 |
23 | _,C,H,W=out_1.shape
24 |         out_1=out_1.reshape(B,-1,C,H,W).permute(0,2,1,3,4)  # (B*T, C, H, W) -> (B, C, T, H, W)
25 |         out_1=F.adaptive_avg_pool3d(out_1,1)  # average over time and space
26 |         out_1=torch.flatten(out_1,1)  # (B, 512)
27 |
28 | out_1=self.linear(out_1)
29 | return out_1
30 |
31 |
32 |
33 | if __name__ == "__main__":
34 | pass
35 |
--------------------------------------------------------------------------------
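A matching sketch for the visual classifier above. cfg here is a stand-in: based on the cfg.modality[1] indexing, cfg.modality is assumed to be a sequence whose second entry is the visual modality; the real fields come from config.py, and the import path again assumes running from code/.

import torch
from types import SimpleNamespace
from models.Visual_Classifier import Classifier  # assumed import path when running from code/

cfg = SimpleNamespace(modality=('audio', 'visual'))  # minimal stand-in for config.py
net = Classifier(cfg, device='cpu')

clip = torch.randn(2, 3, 4, 224, 224)   # (B, C, T, H, W)
logits = net(clip)
print(logits.shape)                     # expected: torch.Size([2, 31])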
/code/models/fusion_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class SumFusion(nn.Module):
5 |
6 | def __init__(self,in_c_x,out_c_x,in_c_y,out_c_y) -> None:
7 | super().__init__()
8 |
9 | self.fx=nn.Linear(in_c_x,out_c_x)
10 | self.fy=nn.Linear(in_c_y,out_c_y)
11 |
12 | def forward(self,x,y):
13 | out=self.fx(x)+self.fy(y)
14 | return x,y,out
15 |
16 | class ConcatFusion(nn.Module):
17 |
18 | def __init__(self,in_c_x,in_c_y,out_c) -> None:
19 | super().__init__()
20 | self.fxy=nn.Linear(in_c_x+in_c_y,out_c)
21 |
22 | def forward(self,x,y):
23 | out=torch.cat([x,y],dim=1)
24 | out=self.fxy(out)
25 | return x,y,out
26 |
27 | class GatedFusion(nn.Module):
28 |
29 | def __init__(self,in_c_x,in_c_y,mid_c,out_c,x_gate=True) -> None:
30 | super().__init__()
31 |
32 | self.fx=nn.Linear(in_c_x,mid_c)
33 | self.fy=nn.Linear(in_c_y,mid_c)
34 | self.f_out=nn.Linear(mid_c,out_c)
35 |
36 | self.x_gate=x_gate
37 | self.sigmoid=nn.Sigmoid()
38 |
39 | def forward(self,x,y):
40 | out_x=self.fx(x)
41 | out_y=self.fy(y)
42 |
43 | if self.x_gate:
44 | gate=self.sigmoid(out_x)
45 | out=self.f_out(torch.mul(gate,out_y))
46 | else:
47 | gate=self.sigmoid(out_y)
48 | out=self.f_out(torch.mul(out_x,gate))
49 |
50 | return out_x,out_y,out
51 |
52 | from torch.autograd import Variable
53 | from torch.nn.parameter import Parameter
54 | class LMF(nn.Module):
55 |
56 | def __init__(self,rank=4,hidden_dim=512,out_dim=31,device='cuda:0'):
57 | super().__init__()
58 | self.device=device
59 | self.rank=rank
60 | self.hidden_dim=hidden_dim
61 | self.out_dim=out_dim
62 |         self.x_factor=Parameter(torch.Tensor(self.rank,self.hidden_dim+1,self.out_dim).to(device)) # r,d+1,cls; .to(device) stays inside Parameter(...) so the factor remains a registered, trainable parameter
63 |         self.y_factor=Parameter(torch.Tensor(self.rank,self.hidden_dim+1,self.out_dim).to(device))
64 |         self.fusion_weights=Parameter(torch.Tensor(1,self.rank).to(device)) # 1,r
65 |         self.fusion_bias=Parameter(torch.Tensor(1,self.out_dim).to(device))
66 |
67 | torch.nn.init.xavier_normal_(self.x_factor)
68 | torch.nn.init.xavier_normal_(self.y_factor)
69 | torch.nn.init.xavier_normal_(self.fusion_weights)
70 | self.fusion_bias.data.fill_(0)
71 |
72 | def forward(self,x,y):
73 | b=x.shape[0]
74 | _x=torch.cat((Variable(torch.ones(b,1).to(self.device),requires_grad=False),x),dim=1) # b,d+1
75 | _y=torch.cat((Variable(torch.ones(b,1).to(self.device),requires_grad=False),y),dim=1)
76 |
77 | fusion_x=torch.matmul(_x,self.x_factor) # r,b,cls
78 | fusion_y=torch.matmul(_y,self.y_factor)
79 | fusion_zy=fusion_x*fusion_y
80 |
81 | output=torch.matmul(self.fusion_weights,fusion_zy.permute(1,0,2)).squeeze()+self.fusion_bias # b,cls
82 | # output=output.view(-1,self.out_dim)
83 |
84 |         return output,x,y  # note: returns (fused_out, x, y), unlike the other fusion modules which return (x, y, out)
85 |
86 | if __name__ == "__main__":
87 | net=GatedFusion(10,10,10,20)
88 | x=torch.zeros([1,10])
89 | y=torch.zeros([1,10])
90 | x_out,y_out,z=net(x,y)
91 | print(x_out.shape,y_out.shape) # torch.Size([1, 10]) torch.Size([1, 10])
92 | print(z.shape) # torch.Size([1, 20])
93 |
94 |     print(net.f_out.weight.shape)  # GatedFusion has no .weight of its own; inspect its output projection instead
--------------------------------------------------------------------------------
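A short sketch of the fusion heads above on dummy 512-d features (512 and 31 match the defaults in this file and the classifiers). The import path is an assumption as before; the parameter count at the end is a quick check that LMF's low-rank factors are registered as trainable parameters.

import torch
from models.fusion_model import ConcatFusion, LMF  # assumed import path when running from code/

x = torch.randn(4, 512)   # e.g. audio feature
y = torch.randn(4, 512)   # e.g. visual feature

fusion = ConcatFusion(512, 512, 31)
_, _, out = fusion(x, y)
print(out.shape)          # torch.Size([4, 31])

lmf = LMF(rank=4, hidden_dim=512, out_dim=31, device='cpu')
out, _, _ = lmf(x, y)     # LMF returns (out, x, y)
print(out.shape)          # torch.Size([4, 31])
print(sum(p.numel() for p in lmf.parameters()))  # non-zero: the low-rank factors are trainable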
/code/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.31.0
3 | aiohttp==3.9.1
4 | aiosignal==1.3.1
5 | alembic==1.13.1
6 | annotated-types==0.6.0
7 | antlr4-python3-runtime==4.9.3
8 | anyio==4.2.0
9 | asteroid-filterbanks==0.4.0
10 | async-timeout==4.0.3
11 | attrs==23.2.0
12 | audioread==3.0.1
13 | certifi==2023.11.17
14 | cffi==1.16.0
15 | charset-normalizer==3.3.2
16 | click==8.1.7
17 | colorama==0.4.6
18 | colorlog==6.8.0
19 | contourpy==1.2.0
20 | cycler==0.12.1
21 | decorator==4.4.2
22 | decord==0.6.0
23 | deepspeed==0.14.3
24 | distro==1.9.0
25 | docopt==0.6.2
26 | einops==0.7.0
27 | exceptiongroup==1.2.0
28 | filelock==3.13.1
29 | fonttools==4.47.2
30 | frozenlist==1.4.1
31 | fsspec==2023.12.2
32 | greenlet==3.0.3
33 | grpcio==1.63.0
34 | h11==0.14.0
35 | h5py==3.11.0
36 | hjson==3.1.0
37 | httpcore==1.0.2
38 | httpx==0.26.0
39 | huggingface-hub==0.23.4
40 | HyperPyYAML==1.2.2
41 | idna==3.6
42 | imageio==2.34.1
43 | imageio-ffmpeg==0.5.1
44 | importlib-resources==6.1.1
45 | importlib_metadata==7.1.0
46 | Jinja2==3.1.3
47 | joblib==1.3.2
48 | julius==0.2.7
49 | kiwisolver==1.4.5
50 | lazy_loader==0.3
51 | librosa==0.10.1
52 | lightning==2.1.3
53 | lightning-utilities==0.10.1
54 | llvmlite==0.41.1
55 | Mako==1.3.0
56 | Markdown==3.6
57 | markdown-it-py==3.0.0
58 | MarkupSafe==2.1.4
59 | matplotlib==3.8.2
60 | mdurl==0.1.2
61 | more-itertools==10.2.0
62 | moviepy==1.0.3
63 | mpmath==1.3.0
64 | msgpack==1.0.7
65 | multidict==6.0.4
66 | networkx==3.2.1
67 | ninja==1.11.1.1
68 | numba==0.58.1
69 | numpy==1.26.3
70 | nvidia-cublas-cu12==12.1.3.1
71 | nvidia-cuda-cupti-cu12==12.1.105
72 | nvidia-cuda-nvrtc-cu12==12.1.105
73 | nvidia-cuda-runtime-cu12==12.1.105
74 | nvidia-cudnn-cu12==8.9.2.26
75 | nvidia-cufft-cu12==11.0.2.54
76 | nvidia-curand-cu12==10.3.2.106
77 | nvidia-cusolver-cu12==11.4.5.107
78 | nvidia-cusparse-cu12==12.1.0.106
79 | nvidia-ml-py==12.555.43
80 | nvidia-nccl-cu12==2.18.1
81 | nvidia-nvjitlink-cu12==12.3.101
82 | nvidia-nvtx-cu12==12.1.105
83 | omegaconf==2.3.0
84 | openai==1.9.0
85 | openai-whisper==20231117
86 | opencv-python==4.9.0.80
87 | optuna==3.5.0
88 | packaging==23.2
89 | pandas==2.2.0
90 | peft==0.3.0
91 | pillow==10.2.0
92 | platformdirs==4.1.0
93 | pooch==1.8.0
94 | primePy==1.3
95 | proglog==0.1.10
96 | protobuf==4.25.2
97 | psutil==6.0.0
98 | py-cpuinfo==9.0.0
99 | pyannote.core==5.0.0
100 | pyannote.database==5.0.1
101 | pyannote.metrics==3.2.1
102 | pyannote.pipeline==3.0.1
103 | pycparser==2.21
104 | pydantic==2.5.3
105 | pydantic_core==2.14.6
106 | Pygments==2.17.2
107 | pyparsing==3.1.1
108 | python-dateutil==2.8.2
109 | pytorch-lightning==2.1.3
110 | pytorch-metric-learning==2.4.1
111 | pytube==15.0.0
112 | pytz==2023.3.post1
113 | PyYAML==6.0.1
114 | regex==2023.12.25
115 | requests==2.31.0
116 | rich==13.7.0
117 | ruamel.yaml==0.18.5
118 | ruamel.yaml.clib==0.2.8
119 | safetensors==0.4.3
120 | scikit-learn==1.4.0
121 | scipy==1.11.4
122 | semver==3.0.2
123 | sentencepiece==0.1.99
124 | shellingham==1.5.4
125 | six==1.16.0
126 | sniffio==1.3.0
127 | sortedcontainers==2.4.0
128 | soundfile==0.12.1
129 | soxr==0.3.7
130 | speechbrain==0.5.16
131 | SQLAlchemy==2.0.25
132 | sympy==1.12
133 | tabulate==0.9.0
134 | tensorboard==2.16.2
135 | tensorboard-data-server==0.7.2
136 | tensorboardX==2.6.2.2
137 | threadpoolctl==3.2.0
138 | tiktoken==0.5.2
139 | timm==0.9.12
140 | tokenizers==0.19.1
141 | torch==1.13.1+cu116
142 | torch-audiomentations==0.11.0
143 | torch-pitch-shift==1.2.4
144 | torchaudio==0.13.1+cu116
145 | torchmetrics==1.3.0.post0
146 | torchvision==0.14.1+cu116
147 | tqdm==4.66.1
148 | transformers==4.41.2
149 | triton==2.1.0
150 | typer==0.9.0
151 | typing_extensions==4.9.0
152 | tzdata==2023.4
153 | urllib3==2.1.0
154 | Werkzeug==3.0.3
155 | yarl==1.9.4
156 | youtube-dl==2021.12.17
157 | zipp==3.17.0
158 |
--------------------------------------------------------------------------------
/code/scripts/inference.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --resume_model \
3 | --resume_model_path 'ckpt_path'
4 |
5 |
--------------------------------------------------------------------------------
/code/scripts/train_ogm.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --train \
3 | --use_modulation \
4 | --fusion_method concat \
5 | --alpha 0.8 \
6 | --modulation_starts 0 \
7 | --modulation_ends 60
8 |
9 |
--------------------------------------------------------------------------------
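The flags above are consumed in main.py: alpha scales the strength of the gradient modulation, and modulation_starts/modulation_ends bound the epochs in which it is applied. As a heavily simplified illustration of the idea (not the repo's exact rule), the coefficient applied to the dominant modality's gradients shrinks as its confidence outruns the other modality; the function below is an assumption-labelled sketch of such a rule.

import torch

def ogm_coefficients(score_a, score_v, alpha=0.8):
    # Illustrative only: map per-modality confidence scores to gradient scaling
    # factors; the actual rule used by --use_modulation lives in main.py.
    score_a = torch.as_tensor(score_a, dtype=torch.float32)
    score_v = torch.as_tensor(score_v, dtype=torch.float32)
    ratio_a = score_a / score_v            # > 1 when audio currently dominates
    ratio_v = score_v / score_a
    coeff_a = 1.0 - torch.tanh(alpha * torch.relu(ratio_a - 1.0))
    coeff_v = 1.0 - torch.tanh(alpha * torch.relu(ratio_v - 1.0))
    return coeff_a, coeff_v

# Between loss.backward() and optimizer.step(), and only while
# modulation_starts <= epoch <= modulation_ends, the dominant encoder's
# gradients would be multiplied by its coefficient, e.g.:
#   for p in audio_encoder.parameters():
#       if p.grad is not None:
#           p.grad *= coeff_a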
/code/scripts/train_opm.sh:
--------------------------------------------------------------------------------
1 | python main.py \
2 | --train \
3 | --use_adam_drop \
4 | --fusion_method concat \
5 | --q_base 0.5 \
6 | --lam 0.5 \
7 | --p_exe 0.7
8 |
--------------------------------------------------------------------------------
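For the OPM script, q_base, lam and p_exe are likewise interpreted in main.py: the dominant modality's features are dropped with a dynamically chosen probability during the forward pass, with p_exe the chance of executing the modulation at a given step. The helper below only illustrates the drop-and-rescale operation itself; how the drop probability is derived from q_base, lam and the measured discrepancy is left to main.py.

import torch

def drop_features(feat, drop_prob):
    # Illustrative only: zero the whole feature tensor with probability
    # drop_prob and rescale the kept case so its expected magnitude is
    # unchanged; applied to the dominant modality during the forward pass.
    if drop_prob <= 0.0:
        return feat
    if torch.rand(()) < drop_prob:
        return torch.zeros_like(feat)
    return feat / (1.0 - drop_prob)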