├── .gitignore ├── README.md ├── balancemm ├── __main__.py ├── analysis │ └── t_sne.py ├── datasets │ ├── KS_dataset.py │ ├── Mosei_dataset.py │ ├── VGG_dataset.py │ ├── __init__.py │ ├── balance_dataset.py │ ├── cremad_dataset.py │ ├── food_dataset.py │ └── ucf101_dataset.py ├── encoders │ ├── VisionTransformer_encoder.py │ ├── __init__.py │ ├── pretrained_encoder.py │ └── resnet18_encoder.py ├── evaluation │ ├── __init__.py │ ├── complex.py │ ├── modalitys.py │ └── precisions.py ├── models │ ├── __init__.py │ ├── avclassify_model.py │ ├── encoders.py │ ├── fusion_arch.py │ └── resnet_arch.py ├── test.py ├── train.py ├── trainer │ ├── AGM_trainer.py │ ├── AMCo_trainer.py │ ├── CML_trainer.py │ ├── GBlending_trainer.py │ ├── Greedy_trainer.py │ ├── LFM_trainer.py │ ├── LinearProbe_trainer.py │ ├── MBSD_trainer.py │ ├── MLA_trainer.py │ ├── MMCosine_trainer.py │ ├── MMPareto_trainer.py │ ├── OGM_trainer.py │ ├── OPM_trainer.py │ ├── PMR_trainer.py │ ├── ReLearning_trainer.py │ ├── ReconBoost_trainer.py │ ├── Sample_trainer.py │ ├── UMT_trainer.py │ ├── __init__.py │ ├── base_trainer.py │ ├── baseline_trainer.py │ └── unimodal_trainer.py └── utils │ ├── data_utils.py │ ├── encoder_module.py │ ├── logger.py │ ├── optimizer.py │ ├── parser_utils.py │ ├── scheduler.py │ └── train_utils.py ├── configs ├── dataset_config.yaml ├── encoder_config.yaml ├── global_config.yaml ├── model_config.yaml ├── trainer_config.yaml └── user_default.yaml ├── environment ├── images ├── Algorithms.jpeg ├── Results.jpeg └── frame6_00.png ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore experiments 2 | experiments/ 3 | # ignore python cache 4 | __pycache__/ 5 | *.out 6 | # ignore temps 7 | temps/ 8 | training_* 9 | result* 10 | # ignore setup -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BalanceBenchmark: A Survey for Multimodal Imbalance Learning 2 | 3 | ## Paper 4 | [**BalanceBenchmark: A Survey for Multimodal Imbalance Learning**](http://arxiv.org/abs/2502.10816)
5 | Shaoxuan Xu, Menglu Cui, Chengxiang Huang, Hongfa Wang and Di Hu 6 | 7 | If you find this repository useful, please cite our paper and corresponding toolkit: 8 | ```bibtex 9 | @misc{xu2025balancebenchmarksurveymultimodalimbalancelearning, 10 | title={BalanceBenchmark: A Survey for Multimodal Imbalance Learning}, 11 | author={Shaoxuan Xu and Menglu Cui and Chengxiang Huang and Hongfa Wang and Di Hu}, 12 | year={2025}, 13 | eprint={2502.10816}, 14 | archivePrefix={arXiv}, 15 | primaryClass={cs.LG}, 16 | url={https://arxiv.org/abs/2502.10816}, 17 | } 18 | ``` 19 | 20 | ## Overview 21 | ![](images/frame6_00.png) 22 | 23 | Multimodal learning has gained attention for its capacity to integrate information from different modalities. However, it is often hindered by the multimodal imbalance problem, where certain modalities disproportionately dominate while others remain underutilized. Although recent studies have proposed various methods to alleviate this problem, they lack comprehensive and fair comparisons. 24 | To facilitate this field, we introduce BalanceBenchmark, a systematic and unified benchmark for evaluating multimodal imbalance learning methods. BalanceBenchmark spans 17 algorithms and 7 datasets, providing a comprehensive framework for method evaluation and comparison. 25 | 26 | To accompany BalanceBenchmark, we release **BalanceMM**, a standardized toolkit that implements 17 state-of-the-art approaches spanning four research directions: data-level adjustments, feed-forward modifications, objective adaptations, and optimization-based methods. The toolkit provides a standardized pipeline that unifies innovations in fusion paradigms, optimization objectives, and training approaches. 27 | Our toolkit simplifies the research workflow through: 28 | 29 | + Standardized data loading for 7 multimodal datasets 30 | + Unified implementation of various imbalance learning methods 31 | + Automated experimental pipeline from training to evaluation 32 | + Comprehensive metrics for assessing performance, imbalance degree, and complexity 33 | 34 | BalanceMM is designed with modularity and extensibility in mind, enabling easy integration of new methods and datasets. It provides researchers with the necessary tools to reproduce experiments, conduct fair comparisons, and develop new approaches for addressing the multimodal imbalance problem. 35 | ## Datasets currently supported 36 | + Audio-Visual: KineticsSounds, CREMA-D, BalancedAV, VGGSound 37 | + RGB-Optical Flow: UCF-101 38 | + Image-Text: FOOD-101 39 | + Audio-Visual-Text: CMU-MOSEI 40 | 41 | To add a new dataset: 42 | 43 | 1. Go to `balancemm/datasets/` 44 | 2. Create a new Python file and a new dataset class 45 | 3. Implement the required data loading and preprocessing methods in the `corresponding_dataset.py` file 46 | 4. Add configuration file in `balancemm/configs/dataset_config.yaml` 47 | 48 | ## Algorithms currently supported 49 | + Data-level methods: Modality-valuation 50 | + Feed-forward methods: MLA, OPM, Greedy, AMCo 51 | + Objective methods: MMCosine, UMT, MBSD, CML, MMPareto, GBlending, LFM 52 | + Optimization methods: OGM, AGM, PMR, Relearning, ReconBoost 53 | 54 | See Section 3 in our paper for detailed descriptions of each method. 55 | 56 | ![](images/Algorithms.jpeg) 57 | 58 | To add a new method: 59 | 60 | 1. 
Determine which category your method belongs to: 61 | + "Data" : methods that adjust data processing 62 | + "Feed-forward" : methods that modify network architecture 63 | + "Objective" : methods that adapt learning objectives 64 | + "Optimization" : methods that adjust the optimization process 65 | 2. Go to `balancemm/trainer/` 66 | 3. Create a new Python file implementing your method 67 | 4. Implement the `corresponding_trainer.py` file based on `base_trainer.py`; in most cases you only need to override `trainer.training_step`. 68 | 5. Make any additional changes required by your method's category: 69 | + If your method belongs to "Data", go to `balancemm/datasets/__init__.py` and modify it accordingly. 70 | + If your method belongs to "Feed-forward", go to `balancemm/models/avclassify_model.py`, create a new model class and override the relevant functions. 71 | + If your method belongs to "Objective", you usually do not need to modify anything beyond the trainer. 72 | + If your method belongs to "Optimization", you may need to modify the trainer or any of the parts mentioned above. 73 | + You can also modify any combination of the parts mentioned above, depending on your method. 74 | 6. Add the configuration in `balancemm/configs/trainer_config.yaml` 75 | ## Installation 76 | ``` 77 | git clone https://github.com/GeWu-Lab/BalanceBenchmark.git 78 | cd BalanceBenchmark 79 | conda create -n balancemm python=3.10 80 | conda activate balancemm 81 | pip install torch==1.12.1+cu113 82 | pip install -r requirements.txt 83 | pip install lightning==2.0.0 84 | pip install lightning-cloud==0.5.68 85 | pip install lightning-utilities==0.11.2 86 | ``` 87 | ## Experiment 88 | To run experiments, you'll need to download the datasets from their open-source links. After downloading, place the datasets in your preferred directory and update the dataset path in your configuration file. 89 | 90 | You can run any experiment using a single command line: 91 | ``` 92 | python -m balancemm \ 93 | --trainer [trainer_name] \ 94 | --dataset [dataset_name] \ 95 | --model [model_name] \ 96 | --hyper-params [param_file.yaml] \ 97 | --device [0/cpu] 98 | ``` 99 | For example, to run OGM on the CREMA-D dataset: 100 | ``` 101 | python -m balancemm \ 102 | --trainer OGM \ 103 | --dataset CREMAD \ 104 | --model BaseClassifier \ 105 | --alpha 0.5 \ 106 | --device 0 107 | ``` 108 | ## Results 109 | We have conducted comprehensive experiments using the proposed BalanceBenchmark on 7 datasets. The results indicate that almost all related methods outperform the Baseline in terms of accuracy and F1 score, demonstrating that the multimodal imbalance problem is prevalent across various scenarios.
110 | 111 | ![](images/Results.jpeg) 112 | 113 | -------------------------------------------------------------------------------- /balancemm/analysis/t_sne.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.manifold import TSNE 3 | import matplotlib.pyplot as plt 4 | from matplotlib.colors import ListedColormap 5 | 6 | def visualize_tsne_2d(datas, perplexity=30, n_components=2, random_state=42, save_dir = 't-sne', modality = '', n_class = 10): 7 | """ 8 | use t-SNE to do dimensionality reduction visualization 9 | 10 | """ 11 | data = [] 12 | labels = [] 13 | for i in range(len(datas)): 14 | data.append(datas[i][0].cpu().detach().numpy()) 15 | labels.append(datas[i][1].cpu().numpy()) 16 | data = np.array(data) 17 | labels = np.array(labels) 18 | colors = plt.cm.rainbow(np.linspace(0, 1, n_class)) 19 | custom_cmap = ListedColormap(colors) 20 | 21 | tsne = TSNE( 22 | n_components=n_components, 23 | perplexity=perplexity, 24 | random_state=random_state 25 | ) 26 | tsne_results = tsne.fit_transform(data) 27 | 28 | plt.figure(figsize=(10, 8)) 29 | 30 | if labels is not None: 31 | scatter = plt.scatter( 32 | tsne_results[:, 0], 33 | tsne_results[:, 1], 34 | c=labels, 35 | cmap=custom_cmap 36 | ) 37 | plt.colorbar(scatter) 38 | else: 39 | plt.scatter( 40 | tsne_results[:, 0], 41 | tsne_results[:, 1], 42 | alpha=0.5 43 | ) 44 | 45 | plt.title(f't-SNE Visualization {modality} {n_components}d') 46 | plt.xlabel('t-SNE 1') 47 | plt.ylabel('t-SNE 2') 48 | plt.savefig(save_dir) 49 | return plt.gcf() 50 | 51 | def visualize_tsne_3d(datas, perplexity=30, n_components=3, random_state=42, save_dir='t-sne', modality='',n_class = 10): 52 | """ 53 | use t-SNE to do 3D visualization 54 | 55 | """ 56 | # prepare data 57 | data = [] 58 | labels = [] 59 | for i in range(len(datas)): 60 | data.append(datas[i][0].cpu().detach().numpy()) 61 | labels.append(datas[i][1].cpu().numpy()) 62 | data = np.array(data) 63 | labels = np.array(labels) 64 | colors = plt.cm.rainbow(np.linspace(0, 1, n_class)) 65 | custom_cmap = ListedColormap(colors) 66 | # t-SNE dimension reduction 67 | tsne = TSNE( 68 | n_components=n_components, 69 | perplexity=perplexity, 70 | random_state=random_state 71 | ) 72 | tsne_results = tsne.fit_transform(data) 73 | 74 | fig = plt.figure(figsize=(12, 10)) 75 | ax = fig.add_subplot(111, projection='3d') 76 | 77 | scatter = ax.scatter( 78 | tsne_results[:, 0], 79 | tsne_results[:, 1], 80 | tsne_results[:, 2], 81 | c=labels, 82 | cmap=custom_cmap, 83 | alpha=0.6 84 | ) 85 | 86 | plt.colorbar(scatter) 87 | ax.set_title(f't-SNE Visualization {modality} {n_components}d') 88 | ax.set_xlabel('t-SNE 1') 89 | ax.set_ylabel('t-SNE 2') 90 | ax.set_zlabel('t-SNE 3') 91 | 92 | ax.grid(True) 93 | 94 | plt.savefig(save_dir) 95 | 96 | for angle in range(0, 360, 45): 97 | ax.view_init(30, angle) 98 | plt.savefig(f'{save_dir}_angle_{angle}_{modality}.jpg') 99 | 100 | return fig 101 | -------------------------------------------------------------------------------- /balancemm/datasets/KS_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import math 3 | import os 4 | import random 5 | import copy 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional 9 | import torchaudio 10 | from PIL import Image 11 | from scipy import signal 12 | from torch.utils.data import Dataset 13 | from torchvision import transforms 14 | 15 | class KineticsSoundsDataset(Dataset): 16 | def 
__init__(self, args:dict, transforms=None): 17 | self.data = [] 18 | self.label = [] 19 | self.mode = args['mode'] 20 | if self.mode == "train": 21 | self.csv_path = args['csv_path_train'] 22 | self.audio_path = args['audio_path_train'] 23 | self.visual_path = args['visual_path_train'] 24 | else: 25 | self.csv_path = args['csv_path_test'] 26 | self.audio_path = args['audio_path_test'] 27 | self.visual_path = args['visual_path_test'] 28 | 29 | 30 | with open(self.csv_path) as f: 31 | for line in f: 32 | item = line.split("\n")[0].split(" ") 33 | name = item[0] 34 | 35 | if os.path.exists(self.audio_path + '/' + name + '.npy'): 36 | path = self.visual_path + '/' + name 37 | files_list=[lists for lists in os.listdir(path)] 38 | if(len(files_list)>3): 39 | self.data.append(name) 40 | self.label.append(int(item[-1])) 41 | 42 | print('data load finish') 43 | self.transforms = transforms 44 | 45 | self._init_atransform() 46 | 47 | print('# of files = %d ' % len(self.data)) 48 | 49 | def _init_atransform(self): 50 | self.aid_transform = transforms.Compose([transforms.ToTensor()]) 51 | 52 | def __len__(self): 53 | return len(self.data) 54 | 55 | def __getitem__(self, idx): 56 | av_file = self.data[idx] 57 | 58 | spectrogram = np.load(self.audio_path + '/' + av_file + '.npy') 59 | spectrogram = np.expand_dims(spectrogram, axis=0) 60 | 61 | # Visual 62 | path = self.visual_path + '/' + av_file 63 | files_list=[lists for lists in os.listdir(path)] 64 | file_num = len([fn for fn in files_list if fn.endswith("jpg")]) 65 | if self.mode == 'train': 66 | transf = transforms.Compose([ 67 | transforms.RandomResizedCrop(224), 68 | transforms.RandomHorizontalFlip(), 69 | transforms.ToTensor(), 70 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 71 | ]) 72 | else: 73 | transf = transforms.Compose([ 74 | transforms.Resize(size=(224, 224)), 75 | transforms.ToTensor(), 76 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 77 | ]) 78 | 79 | pick_num = 3 80 | seg = int(file_num / pick_num) 81 | path1 = [] 82 | image = [] 83 | image_arr = [] 84 | t = [0] * pick_num 85 | 86 | for i in range(pick_num): 87 | if self.mode == 'train': 88 | t[i] = random.randint(i * seg + 1, i * seg + seg) if file_num > 6 else 1 89 | if t[i] >= 10: 90 | t[i] = 9 91 | else: 92 | t[i] = i*seg + max(int(seg/2), 1) if file_num > 6 else 1 93 | 94 | path1.append('frame_0000' + str(t[i]) + '.jpg') 95 | image.append(Image.open(path + "/" + path1[i]).convert('RGB')) 96 | 97 | image_arr.append(transf(image[i])) 98 | image_arr[i] = image_arr[i].unsqueeze(1).float() 99 | 100 | if i == 0: 101 | image_n = copy.copy(image_arr[i]) 102 | else: 103 | image_n = torch.cat((image_n, image_arr[i]), 1) 104 | 105 | 106 | label = self.label[idx] 107 | 108 | return {'visual':image_n, 'audio':spectrogram, 'label': label,'idx': idx} -------------------------------------------------------------------------------- /balancemm/datasets/Mosei_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data.dataset import Dataset 3 | import pickle 4 | import os 5 | import torch 6 | 7 | 8 | def acc3(i): 9 | if i<-0.5: 10 | return 0 11 | if i>0.5: 12 | return 1 13 | return 2 14 | 15 | def acc7(i): 16 | if i < -2: 17 | res = 0 18 | if -2 <= i and i < -1: 19 | res = 1 20 | if -1 <= i and i < 0: 21 | res = 2 22 | if 0 <= i and i <= 0: 23 | res = 3 24 | if 0 < i and i <= 1: 25 | res = 4 26 | if 1 < i and i <= 2: 27 | res = 5 28 | if i > 2: 29 | res = 6 30 | return 
res 31 | 32 | def acc2(i): 33 | if i<0: 34 | return 0 35 | else : 36 | return 1 37 | 38 | class CMUMOSEIDataset(Dataset): 39 | def __init__(self, args: dict, transforms = None): 40 | super(CMUMOSEIDataset, self).__init__() 41 | dataset_path = args['dataset_path'] 42 | data= args['data'] 43 | split_type= args['mode'] 44 | if_align = args['if_align'] 45 | dataset_path = os.path.join(dataset_path, data+'_data.pkl' if if_align else data+'_data_noalign.pkl' ) 46 | dataset = pickle.load(open(dataset_path, 'rb')) 47 | 48 | self.vision = torch.tensor(dataset[split_type]['vision'].astype(np.float32)).cpu().detach() 49 | self.text = torch.tensor(dataset[split_type]['text'].astype(np.float32)).cpu().detach() 50 | self.audio = dataset[split_type]['audio'].astype(np.float32) 51 | self.audio[self.audio == -np.inf] = 0 52 | self.audio = torch.tensor(self.audio).cpu().detach() 53 | self.labels = torch.tensor(dataset[split_type]['labels'].astype(np.float32)).cpu().detach() 54 | 55 | self.meta = dataset[split_type]['id'] if 'id' in dataset[split_type].keys() else None 56 | 57 | self.data = data 58 | 59 | self.n_modalities = 3 # vision/ text/ audio 60 | def get_n_modalities(self): 61 | return self.n_modalities 62 | def get_seq_len(self): 63 | return self.text.shape[1], self.audio.shape[1], self.vision.shape[1] 64 | def get_dim(self): 65 | return self.text.shape[2], self.audio.shape[2], self.vision.shape[2] 66 | def get_lbl_info(self): 67 | # return number_of_labels, label_dim 68 | return self.labels.shape[1], self.labels.shape[2] 69 | def __len__(self): 70 | return len(self.labels) 71 | 72 | def __getitem__(self, index): 73 | X = [index, self.text[index], self.audio[index], self.vision[index]] 74 | Y = self.labels[index] 75 | 76 | Y = acc2(Y[0,0]) 77 | META = (0,0,0) if self.meta is None else (self.meta[index][0], self.meta[index][1], self.meta[index][2]) 78 | if self.data == 'mosi': 79 | META = (self.meta[index][0].decode('UTF-8'), self.meta[index][1].decode('UTF-8'), self.meta[index][2].decode('UTF-8')) 80 | if self.data == 'iemocap': 81 | Y = torch.argmax(Y, dim=-1) 82 | 83 | return {'text' : X[1], 'visual': X[3], 'audio' : X[2], 'label':Y, 'idx': index} 84 | -------------------------------------------------------------------------------- /balancemm/datasets/VGG_dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import csv 3 | import os 4 | import pickle 5 | import librosa 6 | import numpy as np 7 | from scipy import signal 8 | import torch 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | from torchvision import transforms 12 | import pdb 13 | import random 14 | 15 | class VGGSoundDataset(Dataset): 16 | 17 | def __init__(self, args, mode='train'): 18 | 19 | self.args = args 20 | self.mode = args['mode'] 21 | train_video_data = [] 22 | train_audio_data = [] 23 | test_video_data = [] 24 | test_audio_data = [] 25 | train_label = [] 26 | test_label = [] 27 | train_class = [] 28 | test_class = [] 29 | csv_root = args['csv_root'] 30 | video_train_root = args['video_train_root'] 31 | video_test_root = args['video_test_root'] 32 | audio_train_root = args['audio_train_root'] 33 | audio_test_root = args['audio_test_root'] 34 | 35 | print(video_train_root) 36 | print(video_test_root) 37 | print(audio_train_root) 38 | print(audio_test_root) 39 | train_valid = 0 40 | test_valid = 0 41 | with open(csv_root) as f: 42 | csv_reader = csv.reader(f) 43 | for item in csv_reader: 44 | if item[3] == 'train': 45 | 46 | video_dir = 
os.path.join(video_train_root, item[0]+'_'+item[1]) 47 | audio_dir = os.path.join(audio_train_root, item[0]+'_'+item[1] + '.npy') 48 | 49 | if os.path.exists(video_dir) and os.path.exists(audio_dir) and len(os.listdir(video_dir))>3 : 50 | train_video_data.append(video_dir) 51 | train_audio_data.append(audio_dir) 52 | if item[2] not in train_class: train_class.append(item[2]) 53 | train_label.append(item[2]) 54 | train_valid += 1 55 | 56 | if item[3] == 'test': 57 | 58 | video_dir = os.path.join(video_test_root, item[0]+'_'+item[1]) 59 | audio_dir = os.path.join(audio_test_root, item[0]+'_'+item[1] + '.npy') 60 | 61 | if os.path.exists(video_dir) and os.path.exists(audio_dir) and len(os.listdir(video_dir))>3: 62 | test_video_data.append(video_dir) 63 | test_audio_data.append(audio_dir) 64 | if item[2] not in test_class: test_class.append(item[2]) 65 | test_label.append(item[2]) 66 | test_valid += 1 67 | 68 | print("Get Valid Train Sample: " + str(train_valid)) 69 | print("Get Valid Test Sample: " + str(test_valid)) 70 | 71 | assert len(train_class) == len(test_class) 72 | 73 | if len(train_class) == 0: 74 | raise ValueError("If you see this, it means you have problem in reading dataset") 75 | 76 | self.classes = train_class 77 | 78 | class_dict = dict(zip(self.classes, range(len(self.classes)))) 79 | 80 | if self.mode == 'train': 81 | self.video = train_video_data 82 | self.audio = train_audio_data 83 | self.label = [class_dict[train_label[idx]] for idx in range(len(train_label))] 84 | else: 85 | self.video = test_video_data 86 | self.audio = test_audio_data 87 | self.label = [class_dict[test_label[idx]] for idx in range(len(test_label))] 88 | 89 | 90 | def __len__(self): 91 | return len(self.video) 92 | 93 | def __getitem__(self, idx): 94 | 95 | spectrogram = np.load(self.audio[idx]) 96 | 97 | # Def Image Transform 98 | if self.mode == 'train': 99 | transform = transforms.Compose([ 100 | transforms.RandomResizedCrop(224), 101 | transforms.RandomHorizontalFlip(), 102 | transforms.ToTensor(), 103 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 104 | ]) 105 | else: 106 | transform = transforms.Compose([ 107 | transforms.Resize(size=(224, 224)), 108 | transforms.ToTensor(), 109 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 110 | ]) 111 | 112 | pick_num = self.args['use_video_frames'] 113 | image_samples = os.listdir(self.video[idx]) 114 | image_samples = sorted(image_samples) 115 | file_num = len(image_samples) 116 | select_index = np.random.choice(len(image_samples), size=pick_num, replace=False) 117 | select_index.sort() 118 | images = torch.zeros((pick_num, 3, 224, 224)) 119 | t = [0] * pick_num 120 | seg = (file_num//pick_num) 121 | for i in range(pick_num): 122 | if self.mode == 'train': 123 | t[i] = random.randint(i * seg + 1, i * seg + seg) if file_num > 6 else 1 124 | if t[i] >= 10: 125 | t[i] = 9 126 | else: 127 | t[i] = i*seg + max(int(seg/2), 1) if file_num > 6 else 1 128 | for i, idx_frame in enumerate(select_index): 129 | img_path = os.path.join(self.video[idx], image_samples[t[i]-1]) 130 | # img_path = os.path.join(self.video[idx], image_samples[min(i * seg + max(int(seg/2), 1), len(image_samples)-1)]) 131 | img = Image.open(img_path).convert('RGB') 132 | img = transform(img) 133 | images[i] = img 134 | 135 | spectrogram = torch.tensor(spectrogram).unsqueeze(0).float() 136 | images = images.permute(1,0,2,3) 137 | 138 | # label 139 | label = self.label[idx] 140 | 141 | return { 142 | 'audio': spectrogram, 143 | 'visual': images, 144 | 
'label': label, 145 | 'idx': idx 146 | } 147 | 148 | -------------------------------------------------------------------------------- /balancemm/datasets/balance_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import h5py 3 | import os 4 | import pickle 5 | from PIL import Image 6 | import io 7 | import torch 8 | import csv 9 | from torchvision import transforms 10 | from torch.utils.data import Dataset 11 | import numpy as np 12 | import random 13 | import copy 14 | 15 | class BalanceDataset(Dataset): 16 | 17 | def __init__(self, args: dict, transforms=None): 18 | self.data = [] 19 | self.label = [] 20 | self.mode = None 21 | csv_path = args['csv_path'] 22 | self.visual_path = args['visual_path'] 23 | self.audio_path = args['audio_path'] 24 | 25 | self.mode = args['mode'] 26 | 27 | 28 | with open(csv_path) as f: 29 | annotation_data = json.load(f) 30 | all_data = annotation_data['database'] 31 | # choose = ['playing piano', 'playing cello', 'lawn mowing', 'singing', 'cleaning floor', 'bowling', 'swimming', 'whistling', 'motorcycling', 'playing flute', 'writing on blackboard', 'beat boxing'] 32 | class_labels = annotation_data['labels'] 33 | 34 | self.class_to_idx = {label : i for i,label in enumerate(class_labels)} 35 | print(len(class_labels)) 36 | 37 | for key in all_data.keys(): 38 | if all_data[key]['subset'] == (self.mode + 'ing'): 39 | if os.path.exists(self.visual_path + key + '.hdf5') and os.path.exists(self.audio_path + key + '.pkl'): 40 | self.data.append(key) 41 | self.label.append(self.class_to_idx[all_data[key]['label']]) 42 | 43 | 44 | print('data load finish') 45 | 46 | self.transforms = transforms 47 | 48 | self._init_atransform() 49 | 50 | print('# of files = %d ' % len(self.data)) 51 | 52 | def _init_atransform(self): 53 | self.aid_transform = transforms.Compose([transforms.ToTensor()]) 54 | 55 | def __len__(self): 56 | return len(self.data) 57 | 58 | def __getitem__(self, idx): 59 | av_file = self.data[idx] 60 | 61 | with open(self.audio_path + av_file + '.pkl',"rb") as f: 62 | spectrogram = pickle.load(f) 63 | spectrogram = np.expand_dims(spectrogram, axis=0) 64 | 65 | # Visual 66 | path = self.visual_path + av_file + '.hdf5' 67 | with h5py.File(path, 'r') as f: 68 | video_data = f['video'] 69 | file_num = len(video_data) 70 | 71 | if self.mode == 'training': 72 | 73 | transf = transforms.Compose([ 74 | transforms.RandomResizedCrop(224), 75 | transforms.RandomHorizontalFlip(), 76 | transforms.ToTensor(), 77 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 78 | ]) 79 | else: 80 | transf = transforms.Compose([ 81 | transforms.Resize(size=(224, 224)), 82 | transforms.ToTensor(), 83 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 84 | ]) 85 | 86 | pick_num = 3 87 | seg = int(file_num / pick_num) 88 | image = [] 89 | image_arr = [] 90 | t = [0] * pick_num 91 | 92 | for i in range(pick_num): 93 | if self.mode == 'train': 94 | t[i] = random.randint(i * seg + 1, i * seg + seg) if file_num > 6 else 1 95 | if t[i] >= 9: 96 | t[i] = 8 97 | else: 98 | t[i] = i*seg + max(int(seg/2), 1) if file_num > 6 else 1 99 | 100 | image.append(Image.open(io.BytesIO(video_data[t[i]])).convert('RGB')) 101 | 102 | image_arr.append(transf(image[i])) 103 | image_arr[i] = image_arr[i].unsqueeze(1).float() 104 | if i == 0: 105 | image_n = copy.copy(image_arr[i]) 106 | else: 107 | image_n = torch.cat((image_n, image_arr[i]), 1) 108 | 109 | 110 | label = self.label[idx] 111 | 112 | return 
{'visual':image_n, 'audio':spectrogram, 'label': label,'idx': idx} 113 | # return image_n,spectrogram,label,idx -------------------------------------------------------------------------------- /balancemm/datasets/cremad_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import librosa 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | import os.path as osp 10 | 11 | class CREMADDataset(Dataset): 12 | def __init__(self, args: dict): 13 | 14 | self.image = [] 15 | self.audio = [] 16 | self.label = [] 17 | class_dict = {'NEU':0, 'HAP':1, 'SAD':2, 'FEA':3, 'DIS':4, 'ANG':5} 18 | 19 | self.mode = args['mode'] 20 | self.fps = args['fps'] 21 | self.visual_path = args['visual_path'] 22 | self.audio_path = args['audio_path'] 23 | self.train_txt = args['train_txt'] 24 | self.test_txt = args['test_txt'] 25 | self.aid_transform = transforms.Compose([transforms.ToTensor()]) 26 | if self.mode == 'train': 27 | self.csv_file = self.train_txt 28 | else: 29 | self.csv_file = self.test_txt 30 | self.data = [] 31 | self.label = [] 32 | with open(self.csv_file) as f: 33 | csv_reader = csv.reader(f) 34 | for item in csv_reader: 35 | if item[1] in class_dict and os.path.exists(osp.join(self.audio_path, item[0] + '.pt')) and os.path.exists(osp.join(self.visual_path, item[0])) and len(os.listdir(osp.join(self.visual_path, item[0]))) >= self.fps: 36 | self.data.append(item[0]) 37 | self.label.append(class_dict[item[1]]) 38 | 39 | def __len__(self): 40 | return len(self.data) 41 | 42 | def __getitem__(self, idx): 43 | datum = self.data[idx] 44 | 45 | # Audio 46 | fbank = torch.load(self.audio_path + datum + '.pt').unsqueeze(0) 47 | 48 | # Visual 49 | if self.mode == 'train': 50 | transf = transforms.Compose([ 51 | transforms.RandomResizedCrop(224), 52 | transforms.RandomHorizontalFlip(), 53 | transforms.ToTensor(), 54 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 55 | ]) 56 | else: 57 | transf = transforms.Compose([ 58 | transforms.Resize(size=(224, 224)), 59 | transforms.ToTensor(), 60 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 61 | ]) 62 | 63 | folder_path = self.visual_path + datum 64 | file_num = len(os.listdir(folder_path)) 65 | file_list = os.listdir(folder_path) 66 | pick_num = self.fps 67 | seg = int(file_num/pick_num) 68 | image_arr = [] 69 | 70 | for i in range(pick_num): 71 | if self.mode == 'train': 72 | index = i*seg + np.random.randint(seg) 73 | else: 74 | index = i*seg + seg//2 75 | path = os.path.join(folder_path, file_list[index]) 76 | image_arr.append(transf(Image.open(path).convert('RGB')).unsqueeze(0)) 77 | 78 | images = torch.cat(image_arr) 79 | 80 | label = self.label[idx] 81 | images = images.permute(1,0,2,3) 82 | return {'audio': fbank, 'visual': images, 'label': label, 'idx': idx} 83 | 84 | if __name__ == '__main__': 85 | print('start') 86 | a = CREMADDataset({'mode':'train','fps':2}) 87 | a.__getitem__(0) -------------------------------------------------------------------------------- /balancemm/datasets/food_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | import torch.utils.data as data 5 | from pathlib import Path 6 | from random import randrange 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | from transformers import BertTokenizer 10 | import torch 11 | 
12 | import os 13 | import pandas as pd 14 | from PIL import Image 15 | import re 16 | import random 17 | from torchvision import transforms 18 | 19 | def find_classes(directory) : 20 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) 21 | if not classes: 22 | raise FileNotFoundError(f"Couldn't find any classes in {directory}.") 23 | 24 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 25 | idx_to_class = {i: cls_name for i, cls_name in enumerate(classes)} 26 | return classes, class_to_idx,idx_to_class 27 | 28 | # 1. Subclass torch.utils.data.Dataset 29 | class FOOD101Dataset(Dataset): 30 | 31 | # 2. Initialize with a targ_dir and transform (optional) parameter 32 | def __init__(self, args:dict,transform=None): 33 | # 3. Create class attributes 34 | # targ_dir = args['dataset_path'] 35 | targ_dir = args['targ_dir'] 36 | phase = args['mode'] 37 | mode = 'all' 38 | #resize = 384 39 | resize = 224 40 | train_transforms = transforms.Compose([transforms.RandomRotation(30), 41 | transforms.Resize((resize,resize)), 42 | transforms.RandomHorizontalFlip(), 43 | transforms.ToTensor(), 44 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 45 | 46 | test_transforms = transforms.Compose([transforms.Resize((resize,resize)), 47 | #transforms.CenterCrop(resize), 48 | transforms.ToTensor(), 49 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 50 | self.dataset_root = targ_dir 51 | self.csv_file_path = '%s/texts/%s_titles.csv' % (self.dataset_root, phase) 52 | self.img_dir='%s/images/%s' % (self.dataset_root, phase) 53 | 54 | self.data = pd.read_csv(self.csv_file_path) 55 | self.tokenizer = BertTokenizer.from_pretrained('') 56 | # Setup transforms 57 | if phase == 'train': 58 | self.transform = train_transforms 59 | else: 60 | self.transform = test_transforms 61 | if transform != None: self.transform = transform 62 | # Create classes and class_to_idx attributes 63 | self.classes, self.class_to_idx,self.idx_to_class = find_classes(self.img_dir) 64 | self.mode=mode 65 | 66 | # 4. Make function to load images 67 | def load_image(self, index,img_path): #-> Image.Image: 68 | "Opens an image via a path and returns it." 69 | return Image.open(img_path).convert('RGB') 70 | 71 | def clean_text(self,raw_text): 72 | t = re.sub(r'^RT[\s]+', '', raw_text)# remove old style retweet text "RT" 73 | t = re.sub(r'https?:\/\/.*[\r\n]*', '', t)# remove hyperlinks 74 | t = re.sub(r'#', '', t) # remove hashtags 75 | return t 76 | 77 | def tokenize(self, sentence): 78 | ids = self.tokenizer(sentence , 79 | padding='max_length', max_length=40, truncation=True).items() 80 | return {k: torch.tensor(v) for k, v in ids} 81 | 82 | # 5. Overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset) 83 | def __len__(self):# -> int: 84 | "Returns the total number of samples." 85 | return len(self.data) 86 | 87 | # 6. Overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset) 88 | def __getitem__(self, index): #returns Tuple[torch.Tensor, int]: 89 | "Returns one sample of data, data and label (X, y)." 
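        # In the default 'all' mode (with a transform set), the returned sample is a dict:
        #   'visual' -> image tensor of shape (3, 1, H, W),
        #   'text'   -> BERT input_ids padded/truncated to length 40 (shape (1, 40)),
        #   'label'  -> class index, 'idx' -> dataset index.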
90 | sample=self.data.iloc[index] 91 | txt=sample['text'] 92 | txt = self.clean_text(txt) 93 | text_tokens= self.tokenize(txt) 94 | class_name = sample['label'] 95 | class_idx = self.class_to_idx[class_name] 96 | if self.mode =="all": 97 | img_path = os.path.join(self.img_dir,sample['label'] ,sample["Image_path"] ) 98 | img = self.load_image(index,img_path) 99 | # Transform if necessary 100 | if self.transform: 101 | x = {'visual':self.transform(img).unsqueeze(1),'text':text_tokens['input_ids'].unsqueeze(0), 'label':class_idx,'idx': index} ##text:40 102 | return x 103 | else: 104 | return img, text_tokens, txt, class_idx ,index 105 | elif self.mode =="Text_only": 106 | return text_tokens, txt, class_idx ,index -------------------------------------------------------------------------------- /balancemm/datasets/ucf101_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from genericpath import isdir 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torchvision.datasets as datasets 8 | import torchvision.transforms as transforms 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | 12 | 13 | import csv 14 | from genericpath import isdir 15 | import os 16 | import random 17 | import numpy as np 18 | import torch 19 | import torchvision.datasets as datasets 20 | import torchvision.transforms as transforms 21 | from PIL import Image 22 | from torch.utils.data import Dataset 23 | 24 | class UCF101Dataset(Dataset): 25 | 26 | def __init__(self, args:dict, v_norm = True, a_norm = False, name = "UCF101"): 27 | self.data = [] 28 | classes = [] 29 | data2class = {} 30 | self.mode= args['mode'] 31 | self.v_norm = v_norm 32 | self.a_norm = a_norm 33 | self.stat_path = args['stat_path'] 34 | self.train_txt = args['train_txt'] 35 | self.test_txt = args['test_txt'] 36 | self.visual_path = args['visual_path'] 37 | self.flow_path_v = args['flow_path_v'] 38 | self.flow_path_u = args['flow_path_u'] 39 | 40 | if self.mode == 'train': 41 | csv_file = self.train_txt 42 | else: 43 | csv_file = self.test_txt 44 | 45 | with open(self.stat_path) as f: 46 | for line in f: 47 | item = line.split("\n")[0].split(" ")[1] 48 | classes.append(item) 49 | with open(csv_file) as f: 50 | for line in f: 51 | class_name = line.split('/')[0] 52 | name = line.split('/')[1].split('.')[0] 53 | if os.path.isdir(self.visual_path + name) and os.path.isdir(self.flow_path_u + name) and os.path.isdir(self.flow_path_v + name): 54 | self.data.append(name) 55 | data2class[name] = class_name 56 | self.classes = sorted(classes) 57 | self.data2class = data2class 58 | self.class_num = len(self.classes) 59 | print(self.class_num) 60 | print('# of files = %d ' % len(self.data)) 61 | 62 | def _init_atransform(self): 63 | self.aid_transform = transforms.Compose([transforms.ToTensor()]) 64 | 65 | def __len__(self): 66 | return len(self.data) 67 | 68 | def __getitem__(self, idx): 69 | datum = self.data[idx] 70 | # crop = transforms.RandomResizedCrop(112, (1/4, 1.0), (3/4, 4/3)) 71 | if self.mode == 'train': 72 | rgb_transf = [ 73 | transforms.Resize(size=(224, 224)), 74 | transforms.RandomHorizontalFlip(), 75 | transforms.ToTensor(), 76 | ] 77 | diff_transf = [transforms.ToTensor()] 78 | 79 | flow_transf = [ 80 | transforms.Resize(size=(224, 224)), 81 | transforms.RandomHorizontalFlip(), 82 | transforms.ToTensor(), 83 | ] 84 | else: 85 | rgb_transf = [ 86 | transforms.Resize(size=(224, 224)), 87 | transforms.ToTensor() 88 | ] 89 | diff_transf = 
[transforms.ToTensor()] 90 | flow_transf = [ 91 | transforms.Resize(size=(224, 224)), 92 | transforms.ToTensor(), 93 | ] 94 | 95 | if self.v_norm: 96 | rgb_transf.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])) 97 | diff_transf.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])) 98 | if self.a_norm : 99 | flow_transf.append(transforms.Normalize([0.1307], [0.3081])) 100 | rgb_transf = transforms.Compose(rgb_transf) 101 | diff_transf = transforms.Compose(diff_transf) 102 | flow_transf = transforms.Compose(flow_transf) 103 | folder_path = self.visual_path + datum 104 | 105 | ####### RGB 106 | file_num = 6 107 | 108 | pick_num = 3 109 | seg = int(file_num/pick_num) 110 | image_arr = [] 111 | 112 | for i in range(pick_num): 113 | if self.mode == 'train': 114 | chosen_index = random.randint(i*seg + 1, i*seg + seg) 115 | else: 116 | chosen_index = i*seg + max(int(seg/2), 1) 117 | path = folder_path + '/frame_0000' + str(chosen_index) + '.jpg' 118 | tranf_image = rgb_transf(Image.open(path).convert('RGB')) 119 | image_arr.append(tranf_image.unsqueeze(0)) 120 | 121 | images = torch.cat(image_arr) 122 | 123 | num_u = len(os.listdir(self.flow_path_u + datum)) 124 | pick_num = 3 125 | flow_arr = [] 126 | seg = int(num_u/pick_num) 127 | 128 | for i in range(pick_num): 129 | if self.mode == 'train': 130 | chosen_index = random.randint(i*seg + 1, i*seg + seg) 131 | else: 132 | chosen_index = i*seg + max(int(seg/2), 1) 133 | 134 | flow_u = self.flow_path_u + datum + '/frame00' + str(chosen_index).zfill(4) + '.jpg' 135 | flow_v = self.flow_path_v + datum + '/frame00' + str(chosen_index).zfill(4) + '.jpg' 136 | u = flow_transf(Image.open(flow_u)) 137 | v = flow_transf(Image.open(flow_v)) 138 | flow = torch.cat((u,v),0) 139 | flow_arr.append(flow.unsqueeze(0)) 140 | 141 | flow_n = torch.cat(flow_arr) 142 | images = images.permute(1,0,2,3) 143 | sample = { 144 | 'flow':flow_n, 145 | 'visual':images, 146 | 'label': self.classes.index(self.data2class[datum]), 147 | 'raw':datum, 148 | 'idx':idx 149 | } 150 | 151 | 152 | return sample 153 | 154 | -------------------------------------------------------------------------------- /balancemm/encoders/VisionTransformer_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | ## waiting for update 4 | class Transformer_Encoder(nn.Module): 5 | """Extends nn.Transformer.""" 6 | 7 | def __init__(self, input_dim, dim, n_head, n_layers): 8 | """Initialize Transformer object. 9 | 10 | Args: 11 | n_features (int): Number of features in the input. 12 | dim (int): Dimension which to embed upon / Hidden dimension size. 13 | """ 14 | super().__init__() 15 | self.embed_dim = dim 16 | self.conv = nn.Conv1d(input_dim, self.embed_dim, 17 | kernel_size=1, padding=0, bias=False) 18 | layer = nn.TransformerEncoderLayer(d_model=self.embed_dim, nhead=n_head) 19 | self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers) 20 | 21 | def forward(self, x): 22 | """Apply Transformer to Input. 
23 | 24 | Args: 25 | x (torch.Tensor): Layer Input 26 | 27 | Returns: 28 | torch.Tensor: Layer Output 29 | """ 30 | if type(x) is list: 31 | x = x[0] 32 | x = self.conv(x.permute([0, 2, 1])) 33 | x = x.permute([2, 0, 1]) 34 | x = self.transformer(x)[-1] 35 | return x -------------------------------------------------------------------------------- /balancemm/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from os import path as osp 4 | from ..utils.parser_utils import find_module 5 | import torch.nn as nn 6 | import torch 7 | __all__ = ['find_encoder', 'create_encoder'] 8 | from .pretrained_encoder import text_encoder 9 | from torchvision.models import vit_b_16, vit_h_14 10 | 11 | 12 | # import encoder modules from those with '_encoder' in file names 13 | _encoder_folder = osp.dirname(osp.abspath(__file__)) 14 | _encoder_filenames = [ 15 | osp.splitext(v)[0] for v in os.listdir(_encoder_folder) 16 | if v.endswith('_encoder.py') 17 | ] 18 | _encoder_modules = [ 19 | importlib.import_module(f'.{file_name}', package="balancemm.encoders") 20 | for file_name in _encoder_filenames 21 | ] 22 | 23 | # find encoder from encoder_opt 24 | def find_encoder(encoder_name: str) -> object: 25 | encoder_cls = find_module(_encoder_modules, encoder_name, 'Encoder') 26 | return encoder_cls 27 | 28 | def create_encoders(encoder_opt: dict[str, dict])->dict[str, nn.Module]: 29 | modalitys = encoder_opt.keys() 30 | encoders = {} 31 | for modality in modalitys: 32 | pre_train = encoder_opt[modality]['if_pretrain'] 33 | path = encoder_opt[modality]['pretrain_path'] 34 | name = encoder_opt[modality]['name'] 35 | if name == "ViT_B": 36 | if pre_train: 37 | encoders[modality] = vit_b_16(path) 38 | else: 39 | encoders[modality] = vit_b_16() 40 | continue 41 | del encoder_opt[modality]['pretrain_path'] 42 | del encoder_opt[modality]['if_pretrain'] 43 | encoder = find_encoder(encoder_opt[modality]['name']) 44 | del encoder_opt[modality]['name'] 45 | encoders[modality] = encoder(**encoder_opt[modality]) 46 | encoder_opt[modality]['name'] = name 47 | if pre_train: 48 | if modality == 'text': 49 | encoders[modality] = text_encoder() 50 | else: 51 | state = torch.load(path) 52 | if modality == 'flow': 53 | del state['conv1.weight'] 54 | encoders[modality].load_state_dict(state, strict=False) 55 | print('pretrain load finish') 56 | encoder_opt[modality]['if_pretrain'] = pre_train 57 | encoder_opt[modality]['pretrain_path'] = path 58 | print (f'Encoder {name} - {modality} is created.') 59 | return encoders -------------------------------------------------------------------------------- /balancemm/encoders/pretrained_encoder.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import random 3 | from functools import partialmethod 4 | 5 | import torch 6 | import numpy as np 7 | from sklearn.metrics import precision_recall_fscore_support 8 | import os 9 | 10 | from transformers import BertModel, BertConfig 11 | import torch.nn.functional as F 12 | from torch import nn 13 | 14 | class text_encoder(nn.Module): 15 | def __init__(self, output_dim=1024): 16 | super().__init__() 17 | config = BertConfig() 18 | self.textEncoder= BertModel(config).from_pretrained('') 19 | self.linear = nn.Linear(config.hidden_size, output_dim) 20 | 21 | def forward(self, x): 22 | text = x.squeeze(1) 23 | hidden_states = self.textEncoder(text) 24 | e_i = self.linear(hidden_states[1]) 25 | e_i = F.dropout(e_i) 26 | 
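        # Note: F.dropout is called with its defaults (p=0.5, training=True),
        # so this dropout remains active even when the encoder is in eval mode.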
return e_i 27 | 28 | class image_encoder(nn.Module): 29 | def __init__(self,model_arch,num_classes=10,weights="IMAGENET1K_V1",device="cuda"): 30 | super().__init__() 31 | self.model_arch=model_arch 32 | self.device=device 33 | self.num_classes=num_classes 34 | print(model_arch) 35 | self.model=torch.hub.load("pytorch/vision", self.model_arch, weights=weights).to(self.device) 36 | # print(self.model.parameters) 37 | self.model.heads = nn.Sequential() 38 | 39 | def forward(self,x): 40 | y=self.model(x) 41 | return y 42 | -------------------------------------------------------------------------------- /balancemm/encoders/resnet18_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 4 | """3x3 convolution with padding""" 5 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 6 | padding=dilation, groups=groups, bias=False, dilation=dilation) 7 | 8 | 9 | def conv1x1(in_planes, out_planes, stride=1): 10 | """1x1 convolution""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 18 | base_width=64, dilation=1, norm_layer=None): 19 | super(BasicBlock, self).__init__() 20 | if norm_layer is None: 21 | norm_layer = nn.BatchNorm2d 22 | if groups != 1 or base_width != 64: 23 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 24 | if dilation > 1: 25 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 26 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = norm_layer(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = norm_layer(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | identity = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | identity = self.downsample(x) 47 | 48 | out += identity 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class ResNet(nn.Module): 55 | def __init__(self, block, layers, modality, zero_init_residual=False, 56 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 57 | norm_layer=None, input_dim = None): 58 | super(ResNet, self).__init__() 59 | self.modality = modality 60 | if norm_layer is None: 61 | norm_layer = nn.BatchNorm2d 62 | self._norm_layer = norm_layer 63 | 64 | self.inplanes = 64 65 | self.dilation = 1 66 | if replace_stride_with_dilation is None: 67 | # each element in the tuple indicates if we should replace 68 | # the 2x2 stride with a dilated convolution instead 69 | replace_stride_with_dilation = [False, False, False] 70 | if len(replace_stride_with_dilation) != 3: 71 | raise ValueError("replace_stride_with_dilation should be None " 72 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 73 | self.groups = groups 74 | self.base_width = width_per_group 75 | if modality == 'audio': 76 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, 77 | bias=False) 78 | elif modality == 'visual' or modality == 'front_view' or modality == 'back_view': 79 | self.conv1 = nn.Conv2d(3, 
self.inplanes, kernel_size=7, stride=2, padding=3, 80 | bias=False) 81 | elif modality == 'flow': 82 | self.conv1 = nn.Conv2d(2, self.inplanes, kernel_size=7, stride=2, padding=3, 83 | bias=False) 84 | else: 85 | raise NotImplementedError('Incorrect modality, should be audio or visual but got {}'.format(modality)) 86 | self.bn1 = norm_layer(self.inplanes) 87 | self.relu = nn.ReLU(inplace=True) 88 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 89 | self.layer1 = self._make_layer(block, 64, layers[0]) 90 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 91 | dilate=replace_stride_with_dilation[0]) 92 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 93 | dilate=replace_stride_with_dilation[1]) 94 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 95 | dilate=replace_stride_with_dilation[2]) 96 | 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 100 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 101 | nn.init.normal_(m.weight, mean=1, std=0.02) 102 | nn.init.constant_(m.bias, 0) 103 | 104 | if zero_init_residual: 105 | for m in self.modules(): 106 | if isinstance(m, Bottleneck): 107 | nn.init.constant_(m.bn3.weight, 0) 108 | elif isinstance(m, BasicBlock): 109 | nn.init.constant_(m.bn2.weight, 0) 110 | 111 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 112 | norm_layer = self._norm_layer 113 | downsample = None 114 | previous_dilation = self.dilation 115 | if dilate: 116 | self.dilation *= stride 117 | stride = 1 118 | if stride != 1 or self.inplanes != planes * block.expansion: 119 | downsample = nn.Sequential( 120 | conv1x1(self.inplanes, planes * block.expansion, stride), 121 | norm_layer(planes * block.expansion), 122 | ) 123 | 124 | layers = [] 125 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 126 | self.base_width, previous_dilation, norm_layer)) 127 | self.inplanes = planes * block.expansion 128 | for _ in range(1, blocks): 129 | layers.append(block(self.inplanes, planes, groups=self.groups, 130 | base_width=self.base_width, dilation=self.dilation, 131 | norm_layer=norm_layer)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | 137 | if self.modality == 'visual' or self.modality == 'flow': 138 | (B, T, C, H, W) = x.size() 139 | x = x.view(B * T, C, H, W) 140 | 141 | x = self.conv1(x) 142 | x = self.bn1(x) 143 | x = self.relu(x) 144 | x = self.maxpool(x) 145 | 146 | x = self.layer1(x) 147 | x = self.layer2(x) 148 | x = self.layer3(x) 149 | x = self.layer4(x) 150 | out = x 151 | 152 | return out 153 | 154 | class ResNet18Encoder(ResNet): 155 | def __init__(self, **kwargs): 156 | super(ResNet18Encoder, self).__init__(BasicBlock, [2, 2, 2, 2], **kwargs) 157 | 158 | class Bottleneck(nn.Module): 159 | expansion = 4 160 | 161 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 162 | base_width=64, dilation=1, norm_layer=None): 163 | super(Bottleneck, self).__init__() 164 | if norm_layer is None: 165 | norm_layer = nn.BatchNorm2d 166 | width = int(planes * (base_width / 64.)) * groups 167 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 168 | self.conv1 = conv1x1(inplanes, width) 169 | self.bn1 = norm_layer(width) 170 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 171 | self.bn2 = norm_layer(width) 172 | self.conv3 = conv1x1(width, planes * self.expansion) 173 | self.bn3 = 
norm_layer(planes * self.expansion) 174 | self.relu = nn.ReLU(inplace=True) 175 | self.downsample = downsample 176 | self.stride = stride 177 | 178 | def forward(self, x): 179 | identity = x 180 | 181 | out = self.conv1(x) 182 | out = self.bn1(out) 183 | out = self.relu(out) 184 | 185 | out = self.conv2(out) 186 | out = self.bn2(out) 187 | out = self.relu(out) 188 | 189 | out = self.conv3(out) 190 | out = self.bn3(out) 191 | 192 | if self.downsample is not None: 193 | identity = self.downsample(x) 194 | 195 | out += identity 196 | out = self.relu(out) 197 | 198 | return out -------------------------------------------------------------------------------- /balancemm/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .precisions import BatchMetricsCalculator 2 | def Evaluation(trainer, model, temp_model,train_dataloader, val_dataloader, optimizer, scheduler, logger): 3 | if temp_model is None: 4 | trainer(model, train_dataloader, val_dataloader, optimizer, scheduler, logger) 5 | class ComprehensiveModelEvaluator: 6 | def __init__(self, args): 7 | self.Metrics = BatchMetricsCalculator(args['Metrics']) 8 | self.Complex = {} 9 | self.Modalitys = {} 10 | 11 | -------------------------------------------------------------------------------- /balancemm/evaluation/complex.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from thop.vision.basic_hooks import * 3 | from thop.rnn_hooks import * 4 | from thop.utils import prRed 5 | import torch.nn as nn 6 | ## thop 7 | register_hooks = { 8 | nn.ZeroPad2d: zero_ops, # padding does not involve any multiplication. 9 | nn.Conv1d: count_convNd, 10 | nn.Conv2d: count_convNd, 11 | nn.Conv3d: count_convNd, 12 | nn.ConvTranspose1d: count_convNd, 13 | nn.ConvTranspose2d: count_convNd, 14 | nn.ConvTranspose3d: count_convNd, 15 | nn.BatchNorm1d: count_normalization, 16 | nn.BatchNorm2d: count_normalization, 17 | nn.BatchNorm3d: count_normalization, 18 | nn.LayerNorm: count_normalization, 19 | nn.InstanceNorm1d: count_normalization, 20 | nn.InstanceNorm2d: count_normalization, 21 | nn.InstanceNorm3d: count_normalization, 22 | nn.PReLU: count_prelu, 23 | nn.Softmax: count_softmax, 24 | nn.ReLU: zero_ops, 25 | nn.ReLU6: zero_ops, 26 | nn.LeakyReLU: count_relu, 27 | nn.MaxPool1d: zero_ops, 28 | nn.MaxPool2d: zero_ops, 29 | nn.MaxPool3d: zero_ops, 30 | nn.AdaptiveMaxPool1d: zero_ops, 31 | nn.AdaptiveMaxPool2d: zero_ops, 32 | nn.AdaptiveMaxPool3d: zero_ops, 33 | nn.AvgPool1d: count_avgpool, 34 | nn.AvgPool2d: count_avgpool, 35 | nn.AvgPool3d: count_avgpool, 36 | nn.AdaptiveAvgPool1d: count_adap_avgpool, 37 | nn.AdaptiveAvgPool2d: count_adap_avgpool, 38 | nn.AdaptiveAvgPool3d: count_adap_avgpool, 39 | nn.Linear: count_linear, 40 | nn.Dropout: zero_ops, 41 | nn.Upsample: count_upsample, 42 | nn.UpsamplingBilinear2d: count_upsample, 43 | nn.UpsamplingNearest2d: count_upsample, 44 | nn.RNNCell: count_rnn_cell, 45 | nn.GRUCell: count_gru_cell, 46 | nn.LSTMCell: count_lstm_cell, 47 | nn.RNN: count_rnn, 48 | nn.GRU: count_gru, 49 | nn.LSTM: count_lstm, 50 | nn.Sequential: zero_ops, 51 | nn.PixelShuffle: zero_ops, 52 | } 53 | 54 | def profile( 55 | model: nn.Module, 56 | inputs, 57 | method_name = 'forward', 58 | custom_ops=None, 59 | verbose=True, 60 | ret_layer_info=False, 61 | report_missing=False, 62 | ): 63 | ##change from thop 64 | handler_collection = {} 65 | types_collection = set() 66 | if custom_ops is None: 67 | custom_ops = {} 68 | if report_missing: 
69 | # overwrite `verbose` option when enable report_missing 70 | verbose = True 71 | 72 | def add_hooks(m: nn.Module): 73 | m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64)) 74 | m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64)) 75 | 76 | # for p in m.parameters(): 77 | # m.total_params += torch.DoubleTensor([p.numel()]) 78 | 79 | m_type = type(m) 80 | 81 | fn = None 82 | if m_type in custom_ops: 83 | # if defined both op maps, use custom_ops to overwrite. 84 | fn = custom_ops[m_type] 85 | if m_type not in types_collection and verbose: 86 | print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type)) 87 | elif m_type in register_hooks: 88 | fn = register_hooks[m_type] 89 | if m_type not in types_collection and verbose: 90 | print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type)) 91 | else: 92 | if m_type not in types_collection and report_missing: 93 | prRed( 94 | "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." 95 | % m_type 96 | ) 97 | 98 | if fn is not None: 99 | handler_collection[m] = ( 100 | m.register_forward_hook(fn), 101 | m.register_forward_hook(count_parameters), 102 | ) 103 | types_collection.add(m_type) 104 | 105 | prev_training_status = model.training 106 | 107 | model.eval() 108 | model.apply(add_hooks) 109 | 110 | with torch.no_grad(): 111 | method = getattr(model, method_name) 112 | method(*inputs) 113 | 114 | def dfs_count(module: nn.Module, prefix="\t") -> (int, int): 115 | total_ops, total_params = module.total_ops.item(), 0 116 | ret_dict = {} 117 | for n, m in module.named_children(): 118 | # if not hasattr(m, "total_ops") and not hasattr(m, "total_params"): # and len(list(m.children())) > 0: 119 | # m_ops, m_params = dfs_count(m, prefix=prefix + "\t") 120 | # else: 121 | # m_ops, m_params = m.total_ops, m.total_params 122 | next_dict = {} 123 | if m in handler_collection and not isinstance( 124 | m, (nn.Sequential, nn.ModuleList) 125 | ): 126 | m_ops, m_params = m.total_ops.item(), m.total_params.item() 127 | else: 128 | m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t") 129 | ret_dict[n] = (m_ops, m_params, next_dict) 130 | total_ops += m_ops 131 | total_params += m_params 132 | # print(prefix, module._get_name(), (total_ops, total_params)) 133 | return total_ops, total_params, ret_dict 134 | 135 | total_ops, total_params, ret_dict = dfs_count(model) 136 | 137 | # reset model to original status 138 | model.train(prev_training_status) 139 | for m, (op_handler, params_handler) in handler_collection.items(): 140 | op_handler.remove() 141 | params_handler.remove() 142 | m._buffers.pop("total_ops") 143 | m._buffers.pop("total_params") 144 | 145 | if ret_layer_info: 146 | return total_ops, total_params, ret_dict 147 | return total_ops, total_params 148 | 149 | class FLOPsMonitor: 150 | def __init__(self): 151 | self.total_flops = 0 152 | self.forward_flops = 0 153 | self.backward_flops = 0 154 | 155 | def update(self, flops, operation='total'): 156 | if operation == 'forward': 157 | self.forward_flops += flops 158 | self.total_flops += flops 159 | 160 | def report(self, logger): 161 | print(f"Total FLOPs: {self.total_flops}") 162 | print(f"Forward FLOPs: {self.forward_flops}") 163 | logger.info(f"Total FLOPs: {self.total_flops}") 164 | logger.info(f"Forward FLOPs: {self.forward_flops}") 165 | 166 | def get_flops(model, input_sample): 167 | with torch.no_grad(): 168 | flops, params = profile(model, inputs= input_sample) 169 | return flops, params 
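
As a point of reference, here is a minimal usage sketch of the profiling helpers defined above. It is not part of `complex.py`: only `FLOPsMonitor`, `get_flops`, and the underlying `profile` come from this file, while the toy model, dummy input, and logger are illustrative assumptions.

```python
# Minimal sketch: assumes the definitions from this file (FLOPsMonitor, get_flops)
# are available and that `thop` is installed. The toy model and dummy input are
# hypothetical stand-ins, not part of BalanceMM.
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 4),
)
dummy_input = (torch.randn(1, 3, 224, 224),)  # a tuple, matching profile(..., inputs=...)

monitor = FLOPsMonitor()
flops, params = get_flops(toy_model, dummy_input)  # thop-style op and parameter count
monitor.update(flops, operation='forward')         # accumulate forward-pass FLOPs
monitor.report(logger)                             # print and log the running totals
```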
-------------------------------------------------------------------------------- /balancemm/evaluation/modalitys.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | from collections import defaultdict 3 | from torch.utils.data.dataset import Dataset 4 | import logging 5 | from copy import deepcopy 6 | from tqdm import tqdm 7 | from math import factorial 8 | import torch 9 | def generate_all_combinations(input_list: list[str], include_empty: bool = True): 10 | all_combinations = [] 11 | 12 | # generate all combinations of length from 0 to len(input_list) 13 | start_range = 0 if include_empty else 1 14 | for r in range(start_range, len(input_list) + 1): 15 | all_combinations.extend(combinations(input_list, r)) 16 | # converts a combination to a list 17 | return [list(combo) for combo in all_combinations] 18 | 19 | def Calculate_Shapley(trainer, model, CalcuLoader: Dataset, logger: logging.Logger, conduct: bool = True) -> dict[str: float]: 20 | 21 | if conduct: 22 | modalitys = model.modalitys 23 | n = len(modalitys) 24 | Shapley = defaultdict(float) ##default is 0 25 | res_cahce = defaultdict(lambda:float('inf')) ## store the middle results 26 | for modality in modalitys: 27 | temp_modalitys = list(modalitys) 28 | temp_modalitys.remove(modality) 29 | combinations = generate_all_combinations(temp_modalitys, include_empty = True) 30 | for combo in combinations: 31 | S_size = len(combo) 32 | indentifer = tuple(sorted(combo)) 33 | if res_cahce[indentifer] == float('inf'): 34 | with torch.no_grad(): 35 | _, v_combo = trainer.val_loop(model = model, val_loader= CalcuLoader, limit_modalitys= combo.copy()) 36 | res_cahce[indentifer] = v_combo 37 | else: 38 | v_combo = res_cahce[indentifer] 39 | if modality not in combo: 40 | add_combo = combo.copy() 41 | add_combo.append(modality) 42 | add_combo = sorted(add_combo) 43 | indentifer = tuple(add_combo) 44 | if res_cahce[indentifer] == float('inf'): 45 | with torch.no_grad(): 46 | _, v_add = trainer.val_loop(model = model, val_loader= CalcuLoader, limit_modalitys= add_combo) 47 | res_cahce[indentifer] = v_add 48 | else: 49 | v_add = res_cahce[indentifer] 50 | else: 51 | v_add = v_combo 52 | Shapley[modality] += (factorial(S_size) * factorial(n - S_size - 1)) / factorial(n)*(v_add['acc']['output'] - v_combo['acc']['output']) 53 | logger.info(Shapley) 54 | return Shapley 55 | else: 56 | return 57 | 58 | def Calculate_Shapley_Sample(trainer, model, CalcuLoader: Dataset, logger: logging.Logger,conduct: bool = True,is_print: bool = False) -> dict[str: float]: 59 | if not conduct: 60 | return None 61 | modalitys = model.modalitys 62 | n = len(modalitys) 63 | Shapley = {modality: {} for modality in modalitys} 64 | res_cache = defaultdict(lambda: None) 65 | 66 | for batch_idx, batch in tqdm(enumerate(CalcuLoader)): 67 | label = batch['label'].to(model.device) 68 | batch_size = len(label) 69 | 70 | all_combinations = generate_all_combinations(modalitys, include_empty=True) 71 | for combo in all_combinations: 72 | identifier = tuple(sorted(combo)) 73 | if res_cache[identifier] is None: 74 | if not combo: 75 | res_cache[identifier] = torch.zeros(batch_size, dtype=torch.bool) 76 | else: 77 | with torch.no_grad(): 78 | model.validation_step(batch, batch_idx,limit_modality=combo) 79 | res_cache[identifier] = (model.pridiction['output'] == label) 80 | 81 | for i in range(batch_size): 82 | sample_idx = int(batch['idx'][i]) 83 | 84 | for modality in modalitys: 85 | shapley_value = 0.0 86 | 
temp_modalitys = [m for m in modalitys if m != modality] 87 | combinations = generate_all_combinations(temp_modalitys, include_empty=True) 88 | 89 | for combo in combinations: 90 | S_size = len(combo) 91 | v_combo = res_cache[tuple(sorted(combo))][i] 92 | 93 | add_combo = sorted(combo + [modality]) 94 | v_add = res_cache[tuple(add_combo)][i] 95 | 96 | weight = (factorial(S_size) * factorial(n - S_size - 1)) / factorial(n) 97 | marginal_contribution = float(v_add) - float(v_combo) 98 | shapley_value += weight * marginal_contribution 99 | 100 | Shapley[modality][sample_idx] = shapley_value 101 | 102 | 103 | return Shapley -------------------------------------------------------------------------------- /balancemm/evaluation/precisions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import confusion_matrix 3 | from collections import defaultdict 4 | 5 | class BatchMetricsCalculator: 6 | def __init__(self, num_classes: int, modalitys: list): 7 | self.num_classes = num_classes 8 | self.confusion_matrix = {} 9 | self.modalitys = modalitys 10 | for modality in modalitys: 11 | self.confusion_matrix[modality] = np.zeros((num_classes, num_classes)) 12 | self.total_samples = 0 13 | 14 | def update(self, y_true, y_pred): 15 | for modality in self.confusion_matrix.keys(): 16 | batch_cm = confusion_matrix(y_true, y_pred[modality].cpu(), labels=range(self.num_classes)) 17 | self.confusion_matrix[modality] += batch_cm 18 | self.total_samples += len(y_true) 19 | 20 | def compute_metrics(self): 21 | # calculate accuracy 22 | Metrics_res = defaultdict(dict) 23 | for modality in self.confusion_matrix.keys(): 24 | accuracy = np.sum(np.diag(self.confusion_matrix[modality])) / self.total_samples 25 | 26 | # calculate f1 score of each class 27 | fps = self.confusion_matrix[modality].sum(axis=0) - np.diag(self.confusion_matrix[modality]) 28 | fns = self.confusion_matrix[modality].sum(axis=1) - np.diag(self.confusion_matrix[modality]) 29 | tps = np.diag(self.confusion_matrix[modality]) 30 | precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps, dtype=float), where=(tps + fps) != 0) 31 | recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps, dtype=float), where=(tps + fns) != 0) 32 | f1_scores = np.divide(2 * (precisions * recalls), precisions + recalls, 33 | out=np.zeros_like(precisions, dtype=float), 34 | where=(precisions + recalls) != 0) 35 | Metrics_res['f1'][modality] = np.mean(f1_scores) 36 | Metrics_res['acc'][modality] = accuracy 37 | return Metrics_res 38 | def ClearAll(self): 39 | for modality in self.modalitys: 40 | self.confusion_matrix[modality] = np.zeros((self.num_classes, self.num_classes)) 41 | self.total_samples = 0 42 | -------------------------------------------------------------------------------- /balancemm/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from os import path as osp 4 | from types import SimpleNamespace 5 | from ..utils.parser_utils import find_module 6 | __all__ = ['find_model', 'create_model'] 7 | 8 | 9 | # import model modules from those with '_model' in file names 10 | _model_folder = osp.dirname(osp.abspath(__file__)) 11 | _model_filenames = [ 12 | osp.splitext(v)[0] for v in os.listdir(_model_folder) 13 | if v.endswith('_model.py') 14 | ] 15 | _model_modules = [ 16 | importlib.import_module(f'.{file_name}', package="balancemm.models") 17 | for file_name in _model_filenames 18 | ] 19 | 20 | 
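# Note: model classes are discovered by filename convention. Every '*_model.py' file in this
# folder is imported automatically above, and find_model() below resolves the config 'type'
# string to a class via find_module(_model_modules, model_name, 'Model'). Adding a new model
# therefore only requires a new '<name>_model.py' file defining a class that find_module can
# match with the 'Model' suffix; the exact matching rule lives in utils/parser_utils.py,
# which is not shown here.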
# find model from model_opt 21 | def find_model(model_name: str) -> object: 22 | model_cls = find_module(_model_modules, model_name, 'Model') 23 | return model_cls 24 | 25 | # create model from model_opt 26 | def create_model(model_opt: dict): 27 | if 'type' not in model_opt: 28 | raise ValueError('Model type is required.') 29 | model_cls = find_model(model_opt['type']) 30 | model = model_cls(model_opt) 31 | 32 | print( 33 | f'Model {model.__class__.__name__} - {model_opt["type"]} is created.') 34 | return model -------------------------------------------------------------------------------- /balancemm/models/encoders.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import random 3 | from functools import partialmethod 4 | 5 | import torch 6 | import numpy as np 7 | from sklearn.metrics import precision_recall_fscore_support 8 | import os 9 | 10 | 11 | 12 | from transformers import BertModel, BertConfig 13 | import torch.nn.functional as F 14 | from torch import nn 15 | 16 | class text_encoder(nn.Module): 17 | def __init__(self, dim_text_repr=768): 18 | super().__init__() 19 | config = BertConfig() 20 | self.textEncoder= BertModel(config).from_pretrained('') 21 | 22 | def forward(self, x): 23 | text = x 24 | hidden_states = self.textEncoder(**text) # B, T, dim_text_repr 25 | e_i = F.dropout(hidden_states[1]) 26 | return e_i 27 | 28 | class image_encoder(nn.Module): 29 | def __init__(self,model_arch,num_classes=10,weights="IMAGENET1K_V1",device="cuda"): 30 | super().__init__() 31 | self.model_arch=model_arch 32 | self.device=device 33 | self.num_classes=num_classes 34 | print(model_arch) 35 | self.model=torch.hub.load("pytorch/vision", self.model_arch, weights=weights).to(self.device) 36 | # print(self.model.parameters) 37 | self.model.heads = nn.Sequential() 38 | 39 | def forward(self,x): 40 | y=self.model(x) 41 | return y 42 | -------------------------------------------------------------------------------- /balancemm/models/fusion_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SumFusion(nn.Module): 6 | def __init__(self, input_dim=512, output_dim=100): 7 | super(SumFusion, self).__init__() 8 | self.fc_x = nn.Linear(input_dim, output_dim) 9 | self.fc_y = nn.Linear(input_dim, output_dim) 10 | 11 | def forward(self, x, y): 12 | output = self.fc_x(x) + self.fc_y(y) 13 | return x, y, output 14 | 15 | class SharedHead(nn.Module): 16 | def __init__(self, input_dim=512, output_dim=100): 17 | super(SharedHead, self).__init__() 18 | self.fc_out = nn.Linear(input_dim, output_dim) 19 | def forward(self, x): 20 | output = self.fc_out(x) 21 | return output 22 | 23 | class ConcatFusion(nn.Module): 24 | def __init__(self, input_dim=1024, output_dim=100): 25 | super(ConcatFusion, self).__init__() 26 | self.fc_out = nn.Linear(input_dim, output_dim) 27 | def forward(self, x, y): 28 | output = torch.cat((x, y), dim=1) 29 | output = self.fc_out(output) 30 | return x, y, output 31 | 32 | 33 | class ConcatFusion_N(nn.Module): 34 | def __init__(self, input_dim=3072, output_dim=100): 35 | super(ConcatFusion_N, self).__init__() 36 | self.fc_out = nn.Linear(input_dim, output_dim) 37 | def forward(self, encoder_res): 38 | output = torch.cat(list(encoder_res.values()),dim = 1) 39 | output = self.fc_out(output) 40 | return output 41 | 42 | class ConcatFusion_Mask(nn.Module): 43 | def __init__(self, input_dim=3072, output_dim=100): 44 | super(ConcatFusion_Mask,
self).__init__() 45 | self.fc_out = nn.Linear(input_dim, output_dim) 46 | def forward(self, encoder_res): 47 | output = torch.cat(list(encoder_res.values()),dim = 1) 48 | output = self.fc_out(output) 49 | return output 50 | 51 | class ConcatFusion_3(nn.Module): 52 | def __init__(self, input_dim=3072, output_dim=100): 53 | super(ConcatFusion_3, self).__init__() 54 | self.fc_out = nn.Linear(input_dim, output_dim) 55 | self.input_dim = input_dim 56 | 57 | def forward(self, x, y, z): 58 | output = torch.cat((x, y), dim=1) 59 | output = torch.cat((output, z),dim = 1) 60 | output = self.fc_out(output) 61 | # x = (torch.mm(x, torch.transpose(self.fc_out.weight[:, self.input_dim // 3: 2 * self.input_dim // 3], 0, 1)) 62 | # + self.fc_out.bias / 2) 63 | 64 | # y = (torch.mm(y, torch.transpose(self.fc_out.weight[:, 2* self.input_dim // 3: self.input_dim ], 0, 1)) 65 | # + self.fc_out.bias / 2) 66 | 67 | return x, y, z, output 68 | 69 | class FiLM(nn.Module): 70 | """ 71 | FiLM: Visual Reasoning with a General Conditioning Layer, 72 | https://arxiv.org/pdf/1709.07871.pdf. 73 | """ 74 | 75 | def __init__(self, input_dim=512, dim=512, output_dim=100, x_film=True): 76 | super(FiLM, self).__init__() 77 | 78 | self.dim = input_dim 79 | self.fc = nn.Linear(input_dim, 2 * dim) 80 | self.fc_out = nn.Linear(dim, output_dim) 81 | 82 | self.x_film = x_film 83 | 84 | def forward(self, x, y): 85 | 86 | if self.x_film: 87 | film = x 88 | to_be_film = y 89 | else: 90 | film = y 91 | to_be_film = x 92 | 93 | gamma, beta = torch.split(self.fc(film), self.dim, 1) 94 | 95 | output = gamma * to_be_film + beta 96 | output = self.fc_out(output) 97 | 98 | return x, y, output 99 | 100 | 101 | class GatedFusion(nn.Module): 102 | """ 103 | Efficient Large-Scale Multi-Modal Classification, 104 | https://arxiv.org/pdf/1802.02892.pdf. 
105 | """ 106 | 107 | def __init__(self, input_dim=512, dim=512, output_dim=100, x_gate=True): 108 | super(GatedFusion, self).__init__() 109 | 110 | self.fc_x = nn.Linear(input_dim, dim) 111 | self.fc_y = nn.Linear(input_dim, dim) 112 | self.fc_out = nn.Linear(dim, output_dim) 113 | 114 | self.x_gate = x_gate # whether to choose the x to obtain the gate 115 | 116 | self.sigmoid = nn.Sigmoid() 117 | 118 | def forward(self, x, y): 119 | out_x = self.fc_x(x) 120 | out_y = self.fc_y(y) 121 | 122 | if self.x_gate: 123 | gate = self.sigmoid(out_x) 124 | output = self.fc_out(torch.mul(gate, out_y)) 125 | else: 126 | gate = self.sigmoid(out_y) 127 | output = self.fc_out(torch.mul(out_x, gate)) 128 | 129 | return out_x, out_y, output 130 | 131 | -------------------------------------------------------------------------------- /balancemm/models/resnet_arch.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 4 | """3x3 convolution with padding""" 5 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 6 | padding=dilation, groups=groups, bias=False, dilation=dilation) 7 | 8 | 9 | def conv1x1(in_planes, out_planes, stride=1): 10 | """1x1 convolution""" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 18 | base_width=64, dilation=1, norm_layer=None): 19 | super(BasicBlock, self).__init__() 20 | if norm_layer is None: 21 | norm_layer = nn.BatchNorm2d 22 | if groups != 1 or base_width != 64: 23 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 24 | if dilation > 1: 25 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 26 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = norm_layer(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = norm_layer(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | identity = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | identity = self.downsample(x) 47 | 48 | out += identity 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class ResNet(nn.Module): 55 | 56 | def __init__(self, block, layers, modality, zero_init_residual=False, 57 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 58 | norm_layer=None): 59 | super(ResNet, self).__init__() 60 | self.modality = modality 61 | if norm_layer is None: 62 | norm_layer = nn.BatchNorm2d 63 | self._norm_layer = norm_layer 64 | 65 | self.inplanes = 64 66 | self.dilation = 1 67 | if replace_stride_with_dilation is None: 68 | # each element in the tuple indicates if we should replace 69 | # the 2x2 stride with a dilated convolution instead 70 | replace_stride_with_dilation = [False, False, False] 71 | if len(replace_stride_with_dilation) != 3: 72 | raise ValueError("replace_stride_with_dilation should be None " 73 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 74 | self.groups = groups 75 | self.base_width = width_per_group 76 | if modality == 'audio': 77 | self.conv1 = 
nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, 78 | bias=False) 79 | elif modality == 'visual': 80 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 81 | bias=False) 82 | elif modality == 'flow': 83 | self.conv1 = nn.Conv2d(2, self.inplanes, kernel_size=7, stride=2, padding=3, 84 | bias=False) 85 | else: 86 | raise NotImplementedError('Incorrect modality, should be audio or visual but got {}'.format(modality)) 87 | self.bn1 = norm_layer(self.inplanes) 88 | self.relu = nn.ReLU(inplace=True) 89 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 90 | self.layer1 = self._make_layer(block, 64, layers[0]) 91 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 92 | dilate=replace_stride_with_dilation[0]) 93 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 94 | dilate=replace_stride_with_dilation[1]) 95 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 96 | dilate=replace_stride_with_dilation[2]) 97 | 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 101 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 102 | nn.init.normal_(m.weight, mean=1, std=0.02) 103 | nn.init.constant_(m.bias, 0) 104 | 105 | # Zero-initialize the last BN in each residual branch, 106 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 107 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 108 | if zero_init_residual: 109 | for m in self.modules(): 110 | if isinstance(m, Bottleneck): 111 | nn.init.constant_(m.bn3.weight, 0) 112 | elif isinstance(m, BasicBlock): 113 | nn.init.constant_(m.bn2.weight, 0) 114 | 115 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 116 | norm_layer = self._norm_layer 117 | downsample = None 118 | previous_dilation = self.dilation 119 | if dilate: 120 | self.dilation *= stride 121 | stride = 1 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | conv1x1(self.inplanes, planes * block.expansion, stride), 125 | norm_layer(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 130 | self.base_width, previous_dilation, norm_layer)) 131 | self.inplanes = planes * block.expansion 132 | for _ in range(1, blocks): 133 | layers.append(block(self.inplanes, planes, groups=self.groups, 134 | base_width=self.base_width, dilation=self.dilation, 135 | norm_layer=norm_layer)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | 141 | if self.modality == 'visual' or self.modality == 'flow': 142 | (B, T, C, H, W) = x.size() 143 | x = x.view(B * T, C, H, W) 144 | 145 | x = self.conv1(x) 146 | x = self.bn1(x) 147 | x = self.relu(x) 148 | x = self.maxpool(x) 149 | 150 | x = self.layer1(x) 151 | x = self.layer2(x) 152 | x = self.layer3(x) 153 | x = self.layer4(x) 154 | out = x 155 | 156 | return out 157 | 158 | class ResNet18(ResNet): 159 | def __init__(self, **kwargs): 160 | super(ResNet18, self).__init__(BasicBlock, [2, 2, 2, 2], **kwargs) 161 | 162 | 163 | class Bottleneck(nn.Module): 164 | expansion = 4 165 | 166 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 167 | base_width=64, dilation=1, norm_layer=None): 168 | super(Bottleneck, self).__init__() 169 | if norm_layer is None: 170 | norm_layer = nn.BatchNorm2d 171 | 
width = int(planes * (base_width / 64.)) * groups 172 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 173 | self.conv1 = conv1x1(inplanes, width) 174 | self.bn1 = norm_layer(width) 175 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 176 | self.bn2 = norm_layer(width) 177 | self.conv3 = conv1x1(width, planes * self.expansion) 178 | self.bn3 = norm_layer(planes * self.expansion) 179 | self.relu = nn.ReLU(inplace=True) 180 | self.downsample = downsample 181 | self.stride = stride 182 | 183 | def forward(self, x): 184 | identity = x 185 | 186 | out = self.conv1(x) 187 | out = self.bn1(out) 188 | out = self.relu(out) 189 | 190 | out = self.conv2(out) 191 | out = self.bn2(out) 192 | out = self.relu(out) 193 | 194 | out = self.conv3(out) 195 | out = self.bn3(out) 196 | 197 | if self.downsample is not None: 198 | identity = self.downsample(x) 199 | 200 | out += identity 201 | out = self.relu(out) 202 | 203 | return out -------------------------------------------------------------------------------- /balancemm/train.py: -------------------------------------------------------------------------------- 1 | from .utils.logger import setup_logger 2 | from .models import create_model 3 | from .trainer import create_trainer 4 | from .utils.train_utils import choose_logger 5 | import subprocess 6 | from .utils.data_utils import create_train_val_dataloader 7 | from types import SimpleNamespace 8 | from .utils.optimizer import create_optimizer 9 | from .utils.scheduler import create_scheduler 10 | from lightning.fabric import Fabric 11 | import lightning as L 12 | from os import path as osp 13 | import os 14 | import torch 15 | import logging 16 | from datetime import datetime 17 | from .evaluation.modalitys import Calculate_Shapley 18 | from .trainer.LinearProbe_trainer import NewLinearHead 19 | from .models.avclassify_model import MultiModalParallel 20 | 21 | import copy 22 | def train_and_test(args: dict): 23 | dict_args = args 24 | args = SimpleNamespace(**args) 25 | 26 | if args.trainer['name'] == 'unimodal': 27 | args.out_dir = args.out_dir.replace('unimodalTrainer','unimodalTrainer_' + list(args.model['encoders'].keys())[0]) 28 | 29 | log_dir = osp.join(args.out_dir, "logs") 30 | print("logg:{}".format(log_dir)) 31 | loggers_online = [choose_logger(logger_name, log_dir = log_dir, project = args.name, comment = args.log['comment']) for logger_name in args.log['logger_name']] 32 | logger = logging.getLogger(__name__) 33 | args.checkpoint_dir = osp.join(args.out_dir, 'checkpoints') 34 | os.makedirs(args.out_dir, exist_ok=True) 35 | os.makedirs(args.checkpoint_dir, exist_ok=True) 36 | file_handler = logging.FileHandler(args.out_dir + '/training.log') 37 | file_handler.setLevel(logging.INFO) 38 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 39 | file_handler.setFormatter(formatter) 40 | logger.addHandler(file_handler) 41 | loggers = [logger] 42 | logger.setLevel(logging.DEBUG) 43 | logger.info(dict_args) 44 | for tb_logger in loggers_online: 45 | if isinstance(tb_logger, L.fabric.loggers.CSVLogger): 46 | os.makedirs(os.path.join(log_dir, 'csv'), exist_ok=True) 47 | elif isinstance(tb_logger, L.fabric.loggers.TensorBoardLogger): 48 | os.makedirs(os.path.join(log_dir, 'tensorboard'), exist_ok=True) 49 | elif isinstance(tb_logger, L.pytorch.loggers.wandb.WandbLogger): 50 | os.makedirs(os.path.join(log_dir, 'wandb'), exist_ok=True) 51 | if tb_logger: 52 | tb_logger.log_hyperparams(dict_args) 53 | fabric = 
Fabric(**(args.fabric), 54 | loggers = loggers, 55 | ) 56 | if isinstance(fabric.accelerator, L.fabric.accelerators.CUDAAccelerator): 57 | fabric.print('set float32 matmul precision to high') 58 | torch.set_float32_matmul_precision('high') 59 | device = args.Main_config['device'] 60 | # if device == '': 61 | # device = torch.device('cpu') 62 | # else: 63 | # device = torch.device('cuda:' + args.Main_config['device']) 64 | if device == '': 65 | model = create_model(args.model) 66 | else: 67 | model = create_model(args.model) 68 | print(list(range(torch.cuda.device_count()))) 69 | model = MultiModalParallel(model, device_ids = list(range(torch.cuda.device_count()))) 70 | model = model.cuda() 71 | # args.model['device'] = device 72 | if args.trainer['name'] != 'Sample': 73 | train_dataloader, val_dataloader, test_dataloader = create_train_val_dataloader(fabric, args) 74 | else: 75 | train_dataloader, train_val_dataloader, val_dataloader, test_dataloader = create_train_val_dataloader(fabric, args) 76 | args.trainer['checkpoint_dir'] = args.checkpoint_dir ## 77 | optimizer = create_optimizer(model, args.train['optimizer'], args.train['parameter']) 78 | scheduler = create_scheduler(optimizer, args.train['scheduler']) 79 | trainer = create_trainer(fabric, args.Main_config, args.trainer, args, logger,tb_logger) 80 | 81 | start_time = datetime.now() 82 | if args.trainer['name'] == 'GBlendingTrainer': 83 | temp_model = create_model(args.model) 84 | temp_model = MultiModalParallel(temp_model,device_ids = list(range(torch.cuda.device_count()))) 85 | temp_model.cuda() 86 | temp_optimizer = create_optimizer(temp_model, args.train['optimizer'], args.train['parameter']) 87 | temp_optimizer_origin = copy.deepcopy(temp_optimizer.state_dict()) 88 | trainer.fit(model, temp_model,train_dataloader, val_dataloader, optimizer, scheduler, temp_optimizer,temp_optimizer_origin,logger,tb_logger) 89 | 90 | elif args.trainer['name'] == 'SampleTrainer': 91 | trainer.fit(model, train_dataloader, train_val_dataloader, val_dataloader, optimizer, scheduler, logger,tb_logger) 92 | else : 93 | trainer.fit(model, train_dataloader, val_dataloader, optimizer, scheduler, logger,tb_logger) 94 | end_time = datetime.now() 95 | total_time = end_time - start_time 96 | total_time = total_time.total_seconds() / 3600 97 | 98 | logger.info("Training time :{:.2f}".format(total_time)) 99 | #val best 100 | print(f'The best val acc is : {trainer.best_acc}') 101 | logger.info(f'The best val acc is : {trainer.best_acc}') 102 | logger.info('Use the best model to Test') 103 | #load best 104 | model.eval() 105 | best_state = torch.load(args.checkpoint_dir+ '/epoch_normal.ckpt') 106 | model.load_state_dict(best_state['model']) 107 | #test sharply 108 | logger.info('Calculate the shapley value of best model') 109 | Calculate_Shapley(trainer = trainer, model = model, CalcuLoader = test_dataloader, logger= logger) 110 | #best test 111 | _, Metrics_res = trainer.val_loop(model, test_dataloader) 112 | test_acc = Metrics_res['acc']['output'] 113 | info = '' 114 | output_info = '' 115 | for metircs in sorted(Metrics_res.keys()): 116 | if metircs == 'acc': 117 | valid_acc = Metrics_res[metircs] 118 | for modality in sorted(valid_acc.keys()): 119 | tag = "valid_acc" 120 | if modality == 'output': 121 | output_info += f"test_acc: {valid_acc[modality]}" 122 | 123 | else: 124 | info += f", acc_{modality}: {valid_acc[modality]}" 125 | 126 | 127 | if metircs == 'f1': 128 | valid_f1 = Metrics_res[metircs] 129 | for modality in sorted(valid_f1.keys()): 130 | tag = 
"valid_f1" 131 | if modality == 'output': 132 | output_info += f", test_f1: {valid_f1[modality]}" 133 | 134 | else: 135 | info += f", f1_{modality}: {valid_f1[modality]}" 136 | 137 | info = output_info+ ', ' + info 138 | 139 | logger.info(info) 140 | for handler in logger.handlers: 141 | handler.flush() 142 | logger.info(f'The best test acc is : {test_acc}') 143 | print(f'The best test acc is : {test_acc}') 144 | # ======= 145 | # _, Metrics_res = trainer.val_loop(model, test_dataloader) 146 | # logger.info('Calculate the shapley value of best model') 147 | # Calculate_Shapley(trainer = trainer, model = model, CalcuLoader = test_dataloader, logger= logger) 148 | # logger.info(f'The best val acc is : {trainer.best_acc}') 149 | # print(f'The best val acc is : {trainer.best_acc}') 150 | # # for metircs in sorted(Metrics_res.keys()): 151 | # # if metircs == 'acc': 152 | # # valid_acc = Metrics_res[metircs] 153 | # # for modality in sorted(valid_acc.keys()): 154 | # # if modality == 'output': 155 | # # print(f'The test acc is : {Metrics_res[metrics][modality]}') 156 | # # print(f'The imbalance is :{imbalance}') 157 | # >>>>>>> main 158 | 159 | 160 | 161 | def linear_probe_eval(args: dict): 162 | dict_args = args 163 | args = SimpleNamespace(**args) 164 | args.out_dir = osp.join(args.out_dir,args.trainer['trainer_probed']) 165 | log_dir = osp.join(args.out_dir, "logs") 166 | print("logg:{}".format(log_dir)) 167 | loggers_online = [choose_logger(logger_name, log_dir = log_dir, project = args.name, comment = args.log['comment']) for logger_name in args.log['logger_name']] 168 | logger = logging.getLogger(__name__) 169 | args.checkpoint_dir = osp.join(args.out_dir, 'checkpoints') 170 | os.makedirs(args.out_dir, exist_ok=True) 171 | os.makedirs(args.checkpoint_dir, exist_ok=True) 172 | file_handler = logging.FileHandler(args.out_dir + '/training.log') 173 | file_handler.setLevel(logging.INFO) 174 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 175 | file_handler.setFormatter(formatter) 176 | logger.addHandler(file_handler) 177 | loggers = [logger] 178 | logger.setLevel(logging.DEBUG) 179 | logger.info(dict_args) 180 | for tb_logger in loggers_online: 181 | if isinstance(tb_logger, L.fabric.loggers.CSVLogger): 182 | os.makedirs(os.path.join(log_dir, 'csv'), exist_ok=True) 183 | elif isinstance(tb_logger, L.fabric.loggers.TensorBoardLogger): 184 | os.makedirs(os.path.join(log_dir, 'tensorboard'), exist_ok=True) 185 | elif isinstance(tb_logger, L.pytorch.loggers.wandb.WandbLogger): 186 | os.makedirs(os.path.join(log_dir, 'wandb'), exist_ok=True) 187 | if tb_logger: 188 | tb_logger.log_hyperparams(dict_args) 189 | fabric = Fabric(**(args.fabric), 190 | loggers = loggers, 191 | ) 192 | if isinstance(fabric.accelerator, L.fabric.accelerators.CUDAAccelerator): 193 | fabric.print('set float32 matmul precision to high') 194 | torch.set_float32_matmul_precision('high') 195 | device = args.Main_config['device'] 196 | if device == '': 197 | device = torch.device('cpu') 198 | else: 199 | device = torch.device('cuda:' + args.Main_config['device']) 200 | args.model['device'] = device 201 | train_dataloader, val_dataloader, test_dataloader = create_train_val_dataloader(fabric, args) 202 | args.trainer['checkpoint_dir'] = args.checkpoint_dir ## 203 | model = create_model(args.model) 204 | model.to(device) 205 | model.device = device 206 | input_dim = sum(model.modality_size.values()) 207 | # Create a linear-classifier-head 208 | new_head = NewLinearHead(input_dim, 
model.n_classes).to(model.device) 209 | optimizer = create_optimizer(new_head, args.train['optimizer'], args.train['parameter']) 210 | scheduler = create_scheduler(optimizer, args.train['scheduler']) 211 | trainer = create_trainer(fabric, args.Main_config, args.trainer, args, logger,tb_logger) 212 | 213 | start_time = datetime.now() 214 | if args.trainer['name'] == 'GBlendingTrainer': 215 | temp_model = create_model(args.model) 216 | temp_model.to(device) 217 | temp_model.device = device 218 | temp_optimizer = create_optimizer(temp_model, args.train['optimizer'], args.train['parameter']) 219 | trainer.fit(model, temp_model,train_dataloader, val_dataloader, optimizer, scheduler, temp_optimizer,logger,tb_logger) 220 | else : 221 | trainer.fit(model,new_head, train_dataloader, val_dataloader, optimizer, scheduler, logger,tb_logger) 222 | end_time = datetime.now() 223 | total_time = end_time - start_time 224 | total_time = total_time.total_seconds() / 3600 225 | logger.info("Training time :{:.2f}".format(total_time)) 226 | logger.info('Use the best model to Test') 227 | model.eval() 228 | new_head.eval() 229 | best_state = torch.load(args.checkpoint_dir+ '/epoch_normal.ckpt') 230 | model.load_state_dict(best_state['model']) 231 | new_head.load_state_dict(best_state['new_head']) 232 | trainer.val_loop(model, new_head,test_dataloader) 233 | logger.info(f'The best val acc is : {trainer.best_acc}') 234 | print(f'The best val acc is : {trainer.best_acc}') 235 | -------------------------------------------------------------------------------- /balancemm/trainer/AMCo_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | from lightning import LightningModule 3 | from torch.optim.optimizer import Optimizer as Optimizer 4 | from .base_trainer import BaseTrainer 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import os 11 | from collections.abc import Mapping 12 | from functools import partial 13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast 14 | 15 | import lightning as L 16 | import torch 17 | import numpy as np 18 | from ..models.avclassify_model import BaseClassifierModel 19 | class AMCoTrainer(BaseTrainer): 20 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}): 21 | super(AMCoTrainer,self).__init__(fabric,**para_dict) 22 | self.alpha = method_dict['alpha'] 23 | self.method = method_dict['method'] 24 | self.modulation_starts = method_dict['modulation_starts'] 25 | self.modulation_ends = method_dict['modulation_ends'] 26 | 27 | self.sigma = method_dict['sigma'] 28 | self.U = method_dict['U'] 29 | self.eps = method_dict['eps'] 30 | self.modality = method_dict['modality'] 31 | 32 | def train_loop( 33 | self, 34 | model: BaseClassifierModel, 35 | optimizer: torch.optim.Optimizer, 36 | train_loader: torch.utils.data.DataLoader, 37 | limit_batches: Union[int, float] = float("inf"), 38 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None, 39 | ): 40 | """The training loop running a single training epoch. 41 | 42 | Args: 43 | model: the LightningModule to train 44 | optimizer: the optimizer, optimizing the LightningModule. 45 | train_loader: The dataloader yielding the training batches. 46 | limit_batches: Limits the batches during this training epoch. 47 | If greater than the number of batches in the ``train_loader``, this has no effect. 
48 | scheduler_cfg: The learning rate scheduler configuration. 49 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers` 50 | for supported values. 51 | 52 | """ 53 | self.fabric.call("on_train_epoch_start") 54 | all_modalitys = list(model.modalitys) 55 | all_modalitys.append('output') 56 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys) 57 | iterable = self.progbar_wrapper( 58 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}" 59 | ) 60 | 61 | dependent_modality = {} 62 | for modality in model.modalitys: 63 | dependent_modality[modality] = False 64 | l_t = 0 65 | for batch_idx, batch in enumerate(iterable): 66 | # end epoch if stopping training completely or max batches for this epoch reached 67 | if self.should_stop or batch_idx >= limit_batches: 68 | break 69 | 70 | self.fabric.call("on_train_batch_start", batch, batch_idx) 71 | 72 | # prepare the mask 73 | pt = np.sin(np.pi/2*(min(self.eps,l_t)/self.eps)) 74 | N = int(pt * model.n_classes) 75 | mask_t = np.ones(model.n_classes-N) 76 | mask_t = np.pad(mask_t,(0,N)) 77 | np.random.shuffle(mask_t) 78 | mask_t = torch.from_numpy(mask_t) 79 | mask_t = mask_t.to(model.device) 80 | mask_t = mask_t.float() 81 | mask_t = mask_t.unsqueeze(0) 82 | mask_t = mask_t.expand(2, -1) 83 | l_t += self.current_epoch/10 84 | 85 | # check if optimizer should step in gradient accumulation 86 | should_optim_step = self.global_step % self.grad_accum_steps == 0 87 | if should_optim_step: 88 | # currently only supports a single optimizer 89 | self.fabric.call("on_before_optimizer_step", optimizer, 0) 90 | 91 | # optimizer step runs train step internally through closure 92 | loss, dependent_modality = self.training_step( model=model, batch=batch, 93 | batch_idx=batch_idx, mask= mask_t, 94 | dependent_modality= dependent_modality, 95 | pt = pt) 96 | optimizer.step() 97 | self.fabric.call("on_before_zero_grad", optimizer) 98 | 99 | optimizer.zero_grad() 100 | 101 | else: 102 | # gradient accumulation -> no optimizer step 103 | self.training_step(model=model, batch=batch, batch_idx=batch_idx) 104 | 105 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction) 106 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx) 107 | 108 | # this guard ensures, we only step the scheduler once per global step 109 | # if should_optim_step: 110 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step) 111 | 112 | # add output values to progress bar 113 | 114 | self._format_iterable(iterable, self._current_train_return, "train") 115 | 116 | # only increase global step if optimizer stepped 117 | self.global_step += int(should_optim_step) 118 | 119 | self._current_metrics = self.precision_calculator.compute_metrics() 120 | self.fabric.call("on_train_epoch_end") 121 | 122 | def training_step(self, model: BaseClassifierModel, batch, batch_idx, dependent_modality, mask ,pt): 123 | 124 | # TODO: make it simpler and easier to extend 125 | criterion = nn.CrossEntropyLoss() 126 | softmax = nn.Softmax(dim=1) 127 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: 128 | model(batch,dependent_modality = dependent_modality, mask = mask,\ 129 | pt = pt) 130 | else: 131 | model(batch) 132 | # model.Unimodality_Calculate(mask, dependent_modality) 133 | 134 | label = batch['label'] 135 | label = label.to(model.device) 136 | # print(a.shape, v.shape, 
model.head.weight.shape) 137 | 138 | ## our modality-wise normalization on weight and feature 139 | out = model.unimodal_result['output'] 140 | loss = criterion(out, label) 141 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: 142 | loss.backward(retain_graph = True) 143 | for modality in model.modalitys: 144 | loss_uni = criterion(model.unimodal_result[modality],label) 145 | loss_uni.backward() 146 | out_combine = torch.cat([value for key,value in model.unimodal_result.items() if key != 'output'],1) 147 | sft_out = softmax(out_combine) 148 | now_dim = 0 149 | for modality in model.modalitys: 150 | if now_dim < sft_out.shape[1] - model.n_classes: 151 | sft_uni = torch.sum(sft_out[:, now_dim: now_dim + model.n_classes])/(len(label)) 152 | else: 153 | sft_uni = torch.sum(sft_out[:, now_dim: ])/(len(label)) 154 | dependent_modality[modality] = bool(sft_uni > self.sigma) 155 | now_dim += model.n_classes 156 | else: 157 | loss.backward() 158 | 159 | return loss, dependent_modality -------------------------------------------------------------------------------- /balancemm/trainer/CML_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | from lightning import LightningModule 3 | from torch.optim.optimizer import Optimizer as Optimizer 4 | from .base_trainer import BaseTrainer 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import os 11 | from collections.abc import Mapping 12 | from functools import partial 13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast 14 | from lightning_utilities import apply_to_collection 15 | import lightning as L 16 | import torch 17 | import random 18 | 19 | def conf_loss(conf, pred, conf_x, pred_x, label): 20 | sign = (~((pred == label) & (pred_x != label))).long() # trick 1 21 | #print(sign) 22 | return (max(0, torch.sub(conf_x, conf).sum())), sign.sum() 23 | 24 | class CMLTrainer(BaseTrainer): 25 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}): 26 | super(CMLTrainer,self).__init__(fabric,**para_dict) 27 | 28 | self.modulation_starts = method_dict['modulation_starts'] 29 | self.modulation_ends = method_dict['modulation_ends'] 30 | 31 | self.lam = method_dict['lam'] 32 | 33 | def train_loop( 34 | self, 35 | model: L.LightningModule, 36 | optimizer: torch.optim.Optimizer, 37 | train_loader: torch.utils.data.DataLoader, 38 | limit_batches: Union[int, float] = float("inf"), 39 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None, 40 | ): 41 | """The training loop running a single training epoch. 42 | 43 | Args: 44 | model: the LightningModule to train 45 | optimizer: the optimizer, optimizing the LightningModule. 46 | train_loader: The dataloader yielding the training batches. 47 | limit_batches: Limits the batches during this training epoch. 48 | If greater than the number of batches in the ``train_loader``, this has no effect. 49 | scheduler_cfg: The learning rate scheduler configuration. 50 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers` 51 | for supported values. 
52 | 53 | """ 54 | self.fabric.call("on_train_epoch_start") 55 | all_modalitys = list(model.modalitys) 56 | all_modalitys.append('output') 57 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys) 58 | iterable = self.progbar_wrapper( 59 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}" 60 | ) 61 | 62 | random_dict = list(model.modalitys) 63 | # if self.modality == 3: 64 | # random_dict = ["audio", "visual", "text"] 65 | # else: 66 | # random_dict = ['audio', "visual" ] 67 | random.shuffle(random_dict) 68 | for batch_idx, batch in enumerate(iterable): 69 | # end epoch if stopping training completely or max batches for this epoch reached 70 | if self.should_stop or batch_idx >= limit_batches: 71 | break 72 | 73 | self.fabric.call("on_train_batch_start", batch, batch_idx) 74 | 75 | # check if optimizer should step in gradient accumulation 76 | should_optim_step = self.global_step % self.grad_accum_steps == 0 77 | if should_optim_step: 78 | # currently only supports a single optimizer 79 | self.fabric.call("on_before_optimizer_step", optimizer, 0) 80 | 81 | # optimizer step runs train step internally through closure 82 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx, random_dict_ = random_dict)) 83 | self.fabric.call("on_before_zero_grad", optimizer) 84 | 85 | optimizer.zero_grad() 86 | 87 | else: 88 | # gradient accumulation -> no optimizer step 89 | self.training_step(model=model, batch=batch, batch_idx=batch_idx) 90 | 91 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction) 92 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx) 93 | 94 | # this guard ensures, we only step the scheduler once per global step 95 | # if should_optim_step: 96 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step) 97 | 98 | # add output values to progress bar 99 | self._format_iterable(iterable, self._current_train_return, "train") 100 | 101 | # only increase global step if optimizer stepped 102 | self.global_step += int(should_optim_step) 103 | 104 | self._current_metrics = self.precision_calculator.compute_metrics() 105 | self.fabric.call("on_train_epoch_end") 106 | 107 | def training_step(self, model, batch, batch_idx, random_dict_ ): 108 | 109 | # TODO: make it simpler and easier to extend 110 | criterion = nn.CrossEntropyLoss() 111 | softmax = nn.Softmax(dim=1) 112 | label = batch['label'] 113 | label = label.to(model.device) 114 | 115 | _loss_c = 0 116 | modality_num = len(model.modalitys) 117 | modality_list = model.modalitys 118 | key = list(modality_list) 119 | m = {} 120 | if modality_num == 3: 121 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: ###### 122 | pad_audio = False 123 | pad_visual = False 124 | pad_text = False 125 | loss_mm = 0 126 | model(batch) 127 | for modality in modality_list: 128 | m[modality] = model.encoder_result[modality] 129 | m['out'] = model.encoder_result['output'] 130 | # a, v, t, out = model(batch) 131 | unimodal_result = model.Unimodality_Calculate() 132 | out_s = unimodal_result['output'] 133 | # out_a, out_v, out_t = model.AVTCalculate(a, v, t, out) 134 | # out_s = out 135 | random_dict = random_dict_.copy() 136 | for i in range(modality_num - 1): 137 | removed_mm = random_dict.pop() 138 | 139 | out_p = out_s - unimodal_result[removed_mm] +model.fusion_module.fc_out.bias/3 140 | 141 | prediction_s = 
softmax(out_s) 142 | conf_s, pred_s = torch.max(prediction_s, dim=1) 143 | 144 | prediction_p = softmax(out_p) 145 | conf_p, pred_p = torch.max(prediction_p, dim=1) 146 | 147 | if i ==0 : loss = criterion(out_s, label) 148 | 149 | loss_p = criterion(out_p, label) 150 | loss_pc ,_ = conf_loss(conf_s, pred_s, conf_p, pred_p, label) 151 | loss = loss + loss_p 152 | _loss_c = _loss_c + loss_pc 153 | 154 | out_s = out_p 155 | 156 | loss = (loss) / 3 +self.lam * _loss_c 157 | else: 158 | model(batch) 159 | for modality in modality_list: 160 | m[modality] = model.encoder_result[modality] 161 | m['out'] = model.encoder_result['output'] 162 | # a, v, t, out = model(batch) 163 | unimodal_result = model.Unimodality_Calculate() 164 | out_s = unimodal_result['output'] 165 | 166 | 167 | loss = criterion(m['out'], label) 168 | 169 | else: 170 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: ###### 171 | pad_audio = False 172 | pad_visual = False 173 | pad_text = False 174 | loss_mm = 0 175 | model(batch) 176 | for modality in modality_list: 177 | m[modality] = model.encoder_result[modality] 178 | m['out'] = model.encoder_result['output'] 179 | unimodal_result = model.Unimodality_Calculate() 180 | out_s = unimodal_result['output'] 181 | random_dict = random_dict_.copy() 182 | for i in range(modality_num - 1): 183 | removed_mm = random_dict.pop() 184 | 185 | out_p = out_s - unimodal_result[removed_mm] +model.fusion_module.fc_out.bias/2 186 | 187 | prediction_s = softmax(out_s) 188 | conf_s, pred_s = torch.max(prediction_s, dim=1) 189 | 190 | prediction_p = softmax(out_p) 191 | conf_p, pred_p = torch.max(prediction_p, dim=1) 192 | 193 | if i ==0 : loss = criterion(out_s, label) 194 | 195 | loss_p = criterion(out_p, label) 196 | loss_pc ,_ = conf_loss(conf_s, pred_s, conf_p, pred_p, label) 197 | loss += loss_p 198 | _loss_c += loss_pc 199 | 200 | out_s = out_p 201 | loss = (loss) / 2 +self.lam * _loss_c 202 | else: 203 | model(batch) 204 | for modality in modality_list: 205 | m[modality] = model.encoder_result[modality] 206 | m['out'] = model.encoder_result['output'] 207 | # out_a, out_v = model.AVCalculate(a, v, out) 208 | 209 | loss = criterion(m['out'], label) 210 | loss.backward() 211 | 212 | # # avoid gradients in stored/accumulated values -> prevents potential OOM 213 | # self._current_train_return = apply_to_collection(outputs, dtype=torch.Tensor, function=lambda x: x.detach()) 214 | 215 | 216 | return loss -------------------------------------------------------------------------------- /balancemm/trainer/Greedy_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | from lightning import LightningModule 3 | from torch.optim.optimizer import Optimizer as Optimizer 4 | from .base_trainer import BaseTrainer 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import os 11 | from collections.abc import Mapping 12 | from functools import partial 13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast 14 | 15 | import lightning as L 16 | import torch 17 | import numpy as np 18 | from ..models.avclassify_model import BaseClassifierGreedyModel 19 | class GreedyTrainer(BaseTrainer): 20 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}): 21 | super(GreedyTrainer,self).__init__(fabric,**para_dict) 22 | self.alpha = method_dict['alpha'] 23 | self.modulation_starts = method_dict['modulation_starts'] 24 | self.modulation_ends = 
method_dict['modulation_ends'] 25 | 26 | self.modality = method_dict['modality'] 27 | self.window_size = method_dict['window_size'] 28 | self.M_bypass_modal_0 = 0 29 | self.M_bypass_modal_1 = 0 30 | self.M_main_modal_0 = 0 31 | self.M_main_modal_1 = 0 32 | self.curation_mode = False 33 | self.caring_modality = 0 34 | self.curation_step = self.window_size 35 | self.speed = 0 36 | 37 | def train_loop( 38 | self, 39 | model: BaseClassifierGreedyModel, 40 | optimizer: torch.optim.Optimizer, 41 | train_loader: torch.utils.data.DataLoader, 42 | limit_batches: Union[int, float] = float("inf"), 43 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None, 44 | ): 45 | """The training loop running a single training epoch. 46 | 47 | Args: 48 | model: the LightningModule to train 49 | optimizer: the optimizer, optimizing the LightningModule. 50 | train_loader: The dataloader yielding the training batches. 51 | limit_batches: Limits the batches during this training epoch. 52 | If greater than the number of batches in the ``train_loader``, this has no effect. 53 | scheduler_cfg: The learning rate scheduler configuration. 54 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers` 55 | for supported values. 56 | 57 | """ 58 | self.fabric.call("on_train_epoch_start") 59 | all_modalitys = list(model.modalitys) 60 | all_modalitys.append('output') 61 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys) 62 | iterable = self.progbar_wrapper( 63 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}" 64 | ) 65 | 66 | for batch_idx, batch in enumerate(iterable): 67 | # end epoch if stopping training completely or max batches for this epoch reached 68 | if self.should_stop or batch_idx >= limit_batches: 69 | break 70 | self.fabric.call("on_train_batch_start", batch, batch_idx) 71 | 72 | 73 | # check if optimizer should step in gradient accumulation 74 | should_optim_step = self.global_step % self.grad_accum_steps == 0 75 | if should_optim_step: 76 | # currently only supports a single optimizer 77 | self.fabric.call("on_before_optimizer_step", optimizer, 0) 78 | 79 | # optimizer step runs train step internally through closure 80 | self.training_step(model=model, batch=batch, batch_idx=batch_idx) 81 | if not self.curation_mode: 82 | self.speed = self.compute_learning_speed(model) 83 | if abs(self.speed) > self.alpha: 84 | biased_direction=np.sign(self.speed) 85 | self.curation_mode = True 86 | self.curation_step = 0 87 | 88 | if biased_direction==-1: #BDR0<BDR1 89 | self.caring_modality = 1 90 | elif biased_direction==1: #BDR0>BDR1 91 | self.caring_modality = 0 92 | else: 93 | self.curation_mode = False 94 | self.caring_modality = 0 95 | else: 96 | self.curation_step +=1 97 | if self.curation_step==self.window_size: 98 | self.curation_mode=False 99 | optimizer.step() 100 | self.fabric.call("on_before_zero_grad", optimizer) 101 | 102 | optimizer.zero_grad() 103 | 104 | else: 105 | # gradient accumulation -> no optimizer step 106 | self.training_step(model=model, batch=batch, batch_idx=batch_idx) 107 | 108 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction) 109 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx) 110 | 111 | # this guard ensures, we only step the scheduler once per global step 112 | # if should_optim_step: 113 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step) 114 | 115 | # add output values to progress 
bar 116 | 117 | self._format_iterable(iterable, self._current_train_return, "train") 118 | 119 | # only increase global step if optimizer stepped 120 | self.global_step += int(should_optim_step) 121 | 122 | self._current_metrics = self.precision_calculator.compute_metrics() 123 | self.fabric.call("on_train_epoch_end") 124 | 125 | def training_step(self, model: BaseClassifierGreedyModel, batch, batch_idx): 126 | 127 | # TODO: make it simpler and easier to extend 128 | modality_list = model.modalitys 129 | criterion = nn.CrossEntropyLoss() 130 | softmax = nn.Softmax(dim=1) 131 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: 132 | model(batch,self.curation_mode,self.caring_modality) 133 | else: 134 | model(batch,self.curation_mode,self.caring_modality) 135 | model.Unimodality_Calculate() 136 | 137 | label = batch['label'] 138 | label = label.to(model.device) 139 | # print(a.shape, v.shape, model.head.weight.shape) 140 | 141 | ## our modality-wise normalization on weight and feature 142 | out = model.encoder_result['output'] 143 | loss = criterion(out, label) 144 | loss.backward() 145 | return loss 146 | 147 | def compute_learning_speed(self,model:BaseClassifierGreedyModel): 148 | modality_list = model.modalitys 149 | wn_main, wn_bypass = [0]*len(modality_list), [0]*len(modality_list) 150 | gn_main, gn_bypass = [0]*len(modality_list), [0]*len(modality_list) 151 | for name, parameter in model.named_parameters(): 152 | wn = (parameter ** 2).sum().item() 153 | gn = (parameter.grad.data ** 2).sum().item()#(grad ** 2).sum().item() 154 | if 'mmtm_layers' in name: 155 | shared=True 156 | for ind, modal in enumerate(modality_list): 157 | if modal in name: 158 | wn_bypass[ind]+=wn 159 | gn_bypass[ind]+=gn 160 | shared = False 161 | if shared: 162 | for ind, modal in enumerate(modality_list): 163 | wn_bypass[ind]+=wn 164 | gn_bypass[ind]+=gn 165 | 166 | else: 167 | for ind, modal in enumerate(modality_list): 168 | if modal in name: 169 | wn_main[ind]+=wn 170 | gn_main[ind]+=gn 171 | 172 | self.M_bypass_modal_0 += gn_bypass[0]/wn_bypass[0] 173 | self.M_bypass_modal_1 += gn_bypass[1]/wn_bypass[1] 174 | self.M_main_modal_0 += gn_main[0]/wn_main[0] 175 | self.M_main_modal_1 += gn_main[1]/wn_main[1] 176 | 177 | BDR_0 = np.log10(self.M_bypass_modal_0/self.M_main_modal_0) 178 | BDR_1 = np.log10(self.M_bypass_modal_1/self.M_main_modal_1) 179 | 180 | return BDR_0 - BDR_1 -------------------------------------------------------------------------------- /balancemm/trainer/MBSD_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | from lightning import LightningModule 3 | from torch.optim.optimizer import Optimizer as Optimizer 4 | from .base_trainer import BaseTrainer 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import os 11 | from collections.abc import Mapping 12 | from functools import partial 13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast 14 | 15 | import lightning as L 16 | import torch 17 | 18 | 19 | class MBSDTrainer(BaseTrainer): 20 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}): 21 | super(MBSDTrainer,self).__init__(fabric,**para_dict) 22 | self.modulation_starts = method_dict['modulation_starts'] 23 | self.modulation_ends = method_dict['modulation_ends'] 24 | 25 | def train_loop( 26 | self, 27 | model: L.LightningModule, 28 | optimizer: torch.optim.Optimizer, 29 | train_loader: 
torch.utils.data.DataLoader, 30 | limit_batches: Union[int, float] = float("inf"), 31 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None, 32 | ): 33 | """The training loop running a single training epoch. 34 | 35 | Args: 36 | model: the LightningModule to train 37 | optimizer: the optimizer, optimizing the LightningModule. 38 | train_loader: The dataloader yielding the training batches. 39 | limit_batches: Limits the batches during this training epoch. 40 | If greater than the number of batches in the ``train_loader``, this has no effect. 41 | scheduler_cfg: The learning rate scheduler configuration. 42 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers` 43 | for supported values. 44 | 45 | """ 46 | self.fabric.call("on_train_epoch_start") 47 | all_modalitys = list(model.modalitys) 48 | all_modalitys.append('output') 49 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys) 50 | iterable = self.progbar_wrapper( 51 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}" 52 | ) 53 | 54 | 55 | for batch_idx, batch in enumerate(iterable): 56 | # end epoch if stopping training completely or max batches for this epoch reached 57 | if self.should_stop or batch_idx >= limit_batches: 58 | break 59 | 60 | self.fabric.call("on_train_batch_start", batch, batch_idx) 61 | 62 | # check if optimizer should step in gradient accumulation 63 | should_optim_step = self.global_step % self.grad_accum_steps == 0 64 | if should_optim_step: 65 | # currently only supports a single optimizer 66 | self.fabric.call("on_before_optimizer_step", optimizer, 0) 67 | 68 | # optimizer step runs train step internally through closure 69 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx)) 70 | self.fabric.call("on_before_zero_grad", optimizer) 71 | 72 | optimizer.zero_grad() 73 | 74 | else: 75 | # gradient accumulation -> no optimizer step 76 | self.training_step(model=model, batch=batch, batch_idx=batch_idx) 77 | 78 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction) 79 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx) 80 | 81 | # this guard ensures, we only step the scheduler once per global step 82 | # if should_optim_step: 83 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step) 84 | 85 | # add output values to progress bar 86 | 87 | self._format_iterable(iterable, self._current_train_return, "train") 88 | 89 | # only increase global step if optimizer stepped 90 | self.global_step += int(should_optim_step) 91 | 92 | self._current_metrics = self.precision_calculator.compute_metrics() 93 | self.fabric.call("on_train_epoch_end") 94 | 95 | def training_step(self, model, batch, batch_idx , dependent_modality : str = 'none'): 96 | 97 | # TODO: make it simpler and easier to extend 98 | criterion = nn.CrossEntropyLoss() 99 | softmax = nn.Softmax(dim=1) 100 | modality_list = model.modalitys 101 | key = list(modality_list) 102 | m = {} 103 | loss = {} 104 | loss_modality = {} 105 | prediction = {} 106 | y_pred = {} 107 | model(batch) 108 | for modality in modality_list: 109 | m[modality] = model.encoder_result[modality] 110 | m['out'] = model.encoder_result['output'] 111 | model.Unimodality_Calculate() 112 | # out_a, out_v = model.AVCalculate(a, v, out) 113 | label = batch['label'] 114 | device = model.device 115 | 
# print(a.shape, v.shape, model.head.weight.shape) 116 | 117 | ## our modality-wise normalization on weight and feature 118 | 119 | loss['out'] = criterion(m['out'], label) 120 | for modality in modality_list: 121 | loss_modality[modality] = criterion(model.unimodal_result[modality], label) 122 | # loss_v = criterion(unimodal_result[], label) 123 | # loss_a = criterion(out_a, label) 124 | 125 | for modality in modality_list: 126 | prediction[modality] = softmax(model.unimodal_result[modality]) 127 | # prediction_a = softmax(out_a) 128 | # prediction_v = softmax(out_v) 129 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: 130 | if len(modality_list) == 2: 131 | 132 | loss_RS = 1/model.unimodal_result[key[0]].shape[1] * torch.sum((model.unimodal_result[key[0]] - model.unimodal_result[key[1]])**2, dim = 1) 133 | 134 | w = torch.tensor([0.0 for _ in range(len(m['out']))]) 135 | w = w.to(device) 136 | for modality in modality_list: 137 | y_pred[modality] = prediction[modality] 138 | y_pred[modality] = y_pred[modality].argmax(dim=-1) 139 | # y_pred_a = prediction_a 140 | # y_pred_a = y_pred_a.argmax(dim = -1) 141 | # y_pred_v = prediction_v 142 | # y_pred_v = y_pred_v.argmax(dim = -1) 143 | ps = torch.tensor([0.0 for _ in range(len(m['out']))]) 144 | ps = ps.to(device) 145 | pw = torch.tensor([0.0 for _ in range(len(m['out']))]) 146 | pw = pw.to(device) 147 | for i in range(len(m['out'])): 148 | if y_pred[key[0]][i] == label[i] or y_pred[key[1]][i] == label[i]: 149 | w[i] = max(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]]) - min(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]]) 150 | ps[i] = max(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]]) 151 | pw[i] = min(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]]) 152 | 153 | loss_KL = F.kl_div(ps, pw, reduction = 'none') 154 | w = w.reshape(1,-1) 155 | loss_KL = loss_KL.reshape(-1,1) 156 | loss_KL = torch.mm(w, loss_KL) / len(m['out']) 157 | loss_RS = loss_RS.reshape(-1,1) 158 | loss_RS = torch.mm(w, loss_RS) / len(m['out']) 159 | total_loss = loss['out'] + loss_modality[key[0]] + loss_modality[key[1]] + loss_RS.squeeze() + loss_KL.squeeze() ## erase the dim of 1 160 | else: 161 | 162 | w1 = torch.tensor([0.0 for _ in range(len(m['out']))]) 163 | w1 = w1.to(device) 164 | w2 = torch.tensor([0.0 for _ in range(len(m['out']))]) 165 | w2 = w2.to(device) 166 | w3 = torch.tensor([0.0 for _ in range(len(m['out']))]) 167 | w3 = w3.to(device) 168 | ps1 = torch.tensor([0.0 for _ in range(len(m['out']))]) 169 | ps2 = torch.tensor([0.0 for _ in range(len(m['out']))]) 170 | ps3 = torch.tensor([0.0 for _ in range(len(m['out']))]) 171 | ps1 = ps1.to(device) 172 | ps2 = ps2.to(device) 173 | ps3 = ps3.to(device) 174 | pw1 = torch.tensor([0.0 for _ in range(len(m['out']))]) 175 | pw2 = torch.tensor([0.0 for _ in range(len(m['out']))]) 176 | pw3 = torch.tensor([0.0 for _ in range(len(m['out']))]) 177 | pw1 = pw1.to(device) 178 | pw2 = pw2.to(device) 179 | pw3 = pw3.to(device) 180 | for modality in modality_list: 181 | y_pred[modality] = prediction[modality] 182 | y_pred[modality] = y_pred[modality].argmax(dim=-1) 183 | 184 | for i in range(len(m['out'])): 185 | if y_pred[key[0]][i] == label[i] or y_pred[key[1]][i] == label[i]: 186 | w1[i] = max(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]]) - min(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]]) 187 | if y_pred[key[0]][i] == label[i] or y_pred[key[2]][i] == label[i]: 188 | w2[i] = 
max(prediction[key[0]][i][label[i]], prediction[key[2]][i][label[i]]) - min(prediction[key[0]][i][label[i]], prediction[key[2]][i][label[i]]) 189 | if y_pred[key[1]][i] == label[i] or y_pred[key[2]][i] == label[i]: 190 | w3[i] = max(prediction[key[1]][i][label[i]], prediction[key[2]][i][label[i]]) - min(prediction[key[1]][i][label[i]], prediction[key[2]][i][label[i]]) 191 | ps1[i] = max(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]]) 192 | pw1[i] = min(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]]) 193 | ps2[i] = max(prediction[key[0]][i][label[i]], prediction[key[2]][i][label[i]]) 194 | pw2[i] = min(prediction[key[0]][i][label[i]], prediction[key[2]][i][label[i]]) 195 | ps3[i] = max(prediction[key[1]][i][label[i]], prediction[key[2]][i][label[i]]) 196 | pw3[i] = min(prediction[key[1]][i][label[i]], prediction[key[2]][i][label[i]]) 197 | loss_RS1 = 1/model.unimodal_result[key[0]].shape[1] * torch.sum((prediction[key[0]]-prediction[key[1]])**2,dim=1) 198 | loss_RS2 = 1/model.unimodal_result[key[0]].shape[1] * torch.sum((prediction[key[0]]-prediction[key[2]])**2,dim=1) 199 | loss_RS3 = 1/model.unimodal_result[key[0]].shape[1] * torch.sum((prediction[key[1]]-prediction[key[2]])**2,dim=1) 200 | 201 | loss_KL1 = F.kl_div(ps1, pw1, reduction = 'none') 202 | loss_KL2 = F.kl_div(ps2, pw2, reduction = 'none') 203 | loss_KL3 = F.kl_div(ps3, pw3, reduction = 'none') 204 | 205 | w1 = w1.reshape(1,-1) 206 | w2 = w2.reshape(1,-1) 207 | w3 = w3.reshape(1,-1) 208 | loss_KL1 = loss_KL1.reshape(-1,1) 209 | loss_KL1 = torch.mm(w1, loss_KL1) / len(m['out']) 210 | loss_KL2 = loss_KL2.reshape(-1,1) 211 | loss_KL2 = torch.mm(w2, loss_KL2) / len(m['out']) 212 | loss_KL3 = loss_KL3.reshape(-1,1) 213 | loss_KL3 = torch.mm(w3, loss_KL3) / len(m['out']) 214 | loss_KL = (loss_KL1 + loss_KL2 + loss_KL3) / 3 215 | 216 | loss_RS1 = loss_RS1.reshape(-1,1) 217 | loss_RS2 = loss_RS2.reshape(-1,1) 218 | loss_RS3 = loss_RS3.reshape(-1,1) 219 | loss_RS1 = torch.mm(w1, loss_RS1) / len(m['out']) 220 | loss_RS2 = torch.mm(w2, loss_RS2) / len(m['out']) 221 | loss_RS3 = torch.mm(w3, loss_RS3) / len(m['out']) 222 | loss_RS = (loss_RS1 + loss_RS2 + loss_RS3) / 3 223 | 224 | total_loss = loss['out'] + loss_modality[key[0]] + loss_modality[key[1]] + loss_modality[key[2]] + loss_KL.squeeze() + loss_RS.squeeze()## erase the dim of 1 225 | 226 | else: 227 | 228 | total_loss = loss['out'] 229 | total_loss.backward() 230 | 231 | return total_loss -------------------------------------------------------------------------------- /balancemm/trainer/MLA_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | from lightning import LightningModule 3 | from torch.optim.optimizer import Optimizer as Optimizer 4 | from .base_trainer import BaseTrainer 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import os 11 | from collections.abc import Mapping 12 | from functools import partial 13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast 14 | from lightning_utilities import apply_to_collection 15 | import lightning as L 16 | import torch 17 | 18 | 19 | 20 | class GSPlugin(): 21 | def __init__(self, device, gs_flag = True ): 22 | 23 | super().__init__() 24 | 25 | # dtype = torch.cuda.FloatTensor # run on GPU 26 | if device != ' ': 27 | device = 'cuda:0' 28 | else: 29 | device = 'cpu' 30 | with torch.no_grad(): 31 | # self.Pl = torch.eye(1024).to(device) 32 | # depend on 
modality_size 33 | self.Pl = torch.eye(512).to(device) 34 | self.exp_count = 0 35 | 36 | # @torch.no_grad() 37 | def before_update(self, model, before_batch_input, batch_index, len_dataloader, train_exp_counter): 38 | lamda = batch_index / len_dataloader + 1 39 | alpha = 1.0 * 0.1 ** lamda 40 | 41 | # x_mean = torch.mean(strategy.mb_x, 0, True) 42 | if train_exp_counter != 0: 43 | for n, w in model.named_parameters(): 44 | if n == "weight": 45 | 46 | r = torch.mean(before_batch_input, 0, True) 47 | k = torch.mm(self.Pl, torch.t(r)) 48 | self.Pl = torch.sub(self.Pl, torch.mm(k, torch.t(k)) / (alpha + torch.mm(k, r))) 49 | 50 | pnorm2 = torch.norm(self.Pl.data, p='fro') 51 | 52 | self.Pl.data = self.Pl.data / pnorm2 53 | w.grad.add_(torch.mm(w.grad.data, torch.t(self.Pl.data))) 54 | 55 | 56 | 57 | 58 | class MLATrainer(BaseTrainer): 59 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {},args = {}): 60 | super(MLATrainer,self).__init__(fabric,**para_dict) 61 | self.modulation_starts = method_dict['modulation_starts'] 62 | self.modulation_ends = method_dict['modulation_ends'] 63 | self.device = args.model['device'] 64 | self.gs_plugin = GSPlugin(self.device) 65 | self.criterion = nn.CrossEntropyLoss() 66 | 67 | 68 | def train_loop( 69 | self, 70 | model: L.LightningModule, 71 | optimizer: torch.optim.Optimizer, 72 | train_loader: torch.utils.data.DataLoader, 73 | limit_batches: Union[int, float] = float("inf"), 74 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None, 75 | ): 76 | """The training loop running a single training epoch. 77 | 78 | Args: 79 | model: the LightningModule to train 80 | optimizer: the optimizer, optimizing the LightningModule. 81 | train_loader: The dataloader yielding the training batches. 82 | limit_batches: Limits the batches during this training epoch. 83 | If greater than the number of batches in the ``train_loader``, this has no effect. 84 | scheduler_cfg: The learning rate scheduler configuration. 85 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers` 86 | for supported values. 
87 | 88 | """ 89 | modality_list = model.modalitys 90 | all_modalitys = list(model.modalitys) 91 | all_modalitys.append('output') 92 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys) 93 | iterable = self.progbar_wrapper( 94 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}" 95 | ) 96 | len_dataloader = len(train_loader) 97 | self.fabric.call("on_train_epoch_start") 98 | 99 | for batch_idx, batch in enumerate(iterable): 100 | # end epoch if stopping training completely or max batches for this epoch reached 101 | if self.should_stop or batch_idx >= limit_batches: 102 | break 103 | 104 | self.fabric.call("on_train_batch_start", batch, batch_idx) 105 | 106 | # check if optimizer should step in gradient accumulation 107 | should_optim_step = self.global_step % self.grad_accum_steps == 0 108 | if should_optim_step: 109 | self.training_step(model=model, batch=batch, batch_idx=batch_idx,len_dataloader=len_dataloader,optimizer=optimizer) 110 | 111 | else: 112 | # gradient accumulation -> no optimizer step 113 | self.training_step(model=model, batch=batch, batch_idx=batch_idx,len_dataloader=len_dataloader) 114 | # self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction) 115 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx) 116 | 117 | # this guard ensures, we only step the scheduler once per global step 118 | # if should_optim_step: 119 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step) 120 | 121 | # add output values to progress bar 122 | self._format_iterable(iterable, self._current_train_return, "train") 123 | # only increase global step if optimizer stepped 124 | self.global_step += int(should_optim_step) 125 | 126 | # self._current_metrics = self.precision_calculator.compute_metrics() 127 | self.fabric.call("on_train_epoch_end") 128 | 129 | def training_step(self, model, batch, batch_idx,len_dataloader,optimizer): 130 | 131 | # TODO: make it simpler and easier to extend 132 | modality_list = model.modalitys 133 | key = list(modality_list) 134 | # out_v,out_a,out = unimodal_result['visual'], unimodal_result['audio'], unimodal_result['output'] 135 | label = batch['label'] 136 | # label = label.to(model.device) 137 | loss = 0 138 | loss_modality = {} 139 | # feature =model(batch) 140 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: 141 | # feature = {} 142 | # for modality in modality_list: 143 | # feature[modality] = model.encoder_result[modality].clone().contiguous() 144 | for modality in modality_list: 145 | feature = model.feature_extract(batch, modality = modality) 146 | out = model.fusion_module.fc_out(feature) 147 | loss = self.criterion(out,label) 148 | try: 149 | loss.backward() 150 | except RuntimeError as e: 151 | 152 | print("Computation graph:") 153 | for name, param in model.named_parameters(): 154 | if param.grad is not None: 155 | print(f"{name} grad shape: {param.grad.shape}") 156 | raise e 157 | encoder_out = feature.detach() 158 | self.gs_plugin.before_update(model.fusion_module.fc_out, encoder_out, 159 | batch_idx, len_dataloader, self.gs_plugin.exp_count) 160 | 161 | self.fabric.call("on_before_optimizer_step", optimizer, 0) 162 | optimizer.step() 163 | self.fabric.call("on_before_zero_grad", optimizer) 164 | optimizer.zero_grad() 165 | loss_modality[modality] = loss.item() 166 | self.gs_plugin.exp_count += 1 167 | 168 | for n, p in 
model.named_parameters(): 169 | if p.grad != None: 170 | del p.grad 171 | 172 | loss = self.alpha*loss_modality[key[0]]+(1-self.alpha)*loss_modality[key[1]] 173 | 174 | 175 | else: 176 | loss = self.criterion(model.unimodal_result['output'], label) 177 | loss.backward() 178 | 179 | return loss 180 | -------------------------------------------------------------------------------- /balancemm/trainer/OGM_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | from torch.optim.optimizer import Optimizer as Optimizer 3 | from .base_trainer import BaseTrainer 4 | import copy 5 | import torch 6 | import torch.nn as nn 7 | from balancemm.models.avclassify_model import BaseClassifierModel 8 | from collections.abc import Mapping 9 | from functools import partial 10 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast 11 | import lightning as L 12 | import torch 13 | 14 | class OGMTrainer(BaseTrainer): 15 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}): 16 | super(OGMTrainer,self).__init__(fabric,**para_dict) 17 | self.alpha = method_dict['alpha'] 18 | self.method = method_dict['method'] 19 | self.modulation_starts = method_dict['modulation_starts'] 20 | self.modulation_ends = method_dict['modulation_ends'] 21 | # self.modality = method_dict['modality'] 22 | 23 | def train_loop( 24 | self, 25 | model: L.LightningModule, 26 | optimizer: torch.optim.Optimizer, 27 | train_loader: torch.utils.data.DataLoader, 28 | limit_batches: Union[int, float] = float("inf"), 29 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None, 30 | ): 31 | """The training loop running a single training epoch. 32 | 33 | Args: 34 | model: the LightningModule to train 35 | optimizer: the optimizer, optimizing the LightningModule. 36 | train_loader: The dataloader yielding the training batches. 37 | limit_batches: Limits the batches during this training epoch. 38 | If greater than the number of batches in the ``train_loader``, this has no effect. 39 | scheduler_cfg: The learning rate scheduler configuration. 40 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers` 41 | for supported values. 
42 | 43 | """ 44 | self.fabric.call("on_train_epoch_start") 45 | all_modalitys = list(model.modalitys) 46 | all_modalitys.append('output') 47 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys) 48 | iterable = self.progbar_wrapper( 49 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}" 50 | ) 51 | 52 | for batch_idx, batch in enumerate(iterable): 53 | # end epoch if stopping training completely or max batches for this epoch reached 54 | if self.should_stop or batch_idx >= limit_batches: 55 | break 56 | 57 | self.fabric.call("on_train_batch_start", batch, batch_idx) 58 | 59 | # check if optimizer should step in gradient accumulation 60 | should_optim_step = self.global_step % self.grad_accum_steps == 0 61 | if should_optim_step: 62 | # currently only supports a single optimizer 63 | self.fabric.call("on_before_optimizer_step", optimizer, 0) 64 | 65 | # optimizer step runs train step internally through closure 66 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx)) 67 | self.fabric.call("on_before_zero_grad", optimizer) 68 | # torch.cuda.empty_cache() 69 | optimizer.zero_grad() 70 | 71 | else: 72 | # gradient accumulation -> no optimizer step 73 | self.training_step(model=model, batch=batch, batch_idx=batch_idx) 74 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction) 75 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx) 76 | 77 | # this guard ensures, we only step the scheduler once per global step 78 | # if should_optim_step: 79 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step) 80 | 81 | # add output values to progress bar 82 | 83 | self._format_iterable(iterable, self._current_train_return, "train") 84 | 85 | # only increase global step if optimizer stepped 86 | self.global_step += int(should_optim_step) 87 | self._current_metrics = self.precision_calculator.compute_metrics() 88 | 89 | def training_step(self, model : BaseClassifierModel, batch, batch_idx): 90 | 91 | # TODO: make it simpler and easier to extend 92 | softmax = nn.Softmax(dim=1) 93 | criterion = nn.CrossEntropyLoss() 94 | relu = nn.ReLU(inplace=True) 95 | tanh = nn.Tanh() 96 | label = batch['label'] 97 | label = label.to(model.device) 98 | model(batch) 99 | # model.Unimodality_Calculate() 100 | loss = criterion(model.unimodal_result['output'], label) 101 | loss.backward() 102 | modality_list = model.modalitys 103 | 104 | # Modulation starts here ! 
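The modulation below first scores each modality by how much softmax probability its unimodal head assigns to the ground-truth class, turns the score ratio into a per-modality coefficient, and then rescales that modality's gradients (adding Gaussian noise in the GE variants). A minimal two-modality sketch of the coefficient computation, using hypothetical standalone arguments (`unimodal_logits`, `alpha`) rather than the trainer's attributes:

```python
# Illustrative sketch only: mirrors the score -> ratio -> coefficient steps below
# for the two-modality case; `unimodal_logits` maps modality name -> [B, n_classes] logits.
import torch
import torch.nn.functional as F

def ogm_ge_coefficients(unimodal_logits: dict, label: torch.Tensor, alpha: float) -> dict:
    scores = {}
    for m, logits in unimodal_logits.items():
        probs = F.softmax(logits, dim=1)                       # [B, n_classes]
        idx = torch.arange(label.size(0), device=logits.device)
        scores[m] = probs[idx, label.to(logits.device)].sum()  # confidence on the true class
    (m1, s1), (m2, s2) = scores.items()
    coeffs = {m1: 1.0, m2: 1.0}
    for m, ratio in ((m1, s1 / (s2 + 1e-5)), (m2, s2 / (s1 + 1e-5))):
        if ratio > 1:  # dominant modality: shrink its gradient
            coeffs[m] = max(1.0 - torch.tanh(alpha * F.relu(ratio)).item(), 0.0)
    return coeffs
```

In the parameter loop at the end of this `training_step`, gradients of parameters whose names contain the modality string are multiplied by these coefficients; the GE variants additionally add zero-mean Gaussian noise whose standard deviation matches that of the gradient itself.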
105 | modality_nums = len(modality_list) 106 | scores = {} 107 | ratios = {} 108 | coeffs = {} 109 | minscore = float('inf') 110 | #Calculate the scores 111 | 112 | for modality in modality_list: 113 | if modality_nums == 2 or self.method == 'OGM_GE3': 114 | score_modality = sum([softmax(model.unimodal_result[modality])[i][label[i]] for i in range(model.unimodal_result['output'].size(0))]) 115 | elif modality_nums == 3: 116 | score_modality = sum([softmax(torch.cos(model.unimodal_result[modality]))[i][label[i]] if label[i] == torch.argmax(model.unimodal_result[modality][i]) else 0 for i in range(model.unimodal_result['output'].size(0))]) 117 | else: 118 | raise("Wrong number of modalitys for OGM, it should be 2 or 3, but given {:0}".format(modality_nums)) 119 | try: 120 | scores[modality] = score_modality.detach().clone() 121 | except: 122 | continue 123 | minscore = min(score_modality, minscore) 124 | ##Calculate the ratios 125 | if self.method == 'OGM_GE3': 126 | for modality in modality_list: 127 | count = 0 128 | ratios[modality] = 0 129 | for modality_another in modality_list: 130 | if modality_another == modality: 131 | continue 132 | ratios[modality] += scores[modality].detach().clone()/(scores[modality_another]+ 1e-5) 133 | count += 1 134 | ratios[modality]/= count 135 | elif self.method == "OGM_GE": 136 | for modality in modality_list: 137 | ratios[modality] = scores[modality].detach().clone() 138 | if modality_nums == 2: 139 | for modality_another in modality_list: 140 | if modality_another == modality: 141 | continue 142 | 143 | ratios[modality] /= (scores[modality_another]+ 1e-5) # prevent OOM 144 | if modality_nums == 3: 145 | ratios[modality] /= (minscore + 1e-5) 146 | #Calculate the coeffects 147 | for modality in modality_list: 148 | if ratios[modality] > 1 : 149 | coeffs[modality] = max(1 - tanh(self.alpha * relu(ratios[modality])),0) 150 | else: 151 | coeffs[modality] = 1 152 | 153 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: # bug fixed 154 | for name, parms in model.named_parameters(): 155 | layer = str(name) 156 | for modality in modality_list: 157 | 158 | if modality in layer and len(parms.grad.size()) != 1: ##Don't change the grad of bias for layer 159 | if self.method == 'OGM_GE' or self.method == 'OGM_GE3': # bug fixed 160 | parms.grad = parms.grad * coeffs[modality] + \ 161 | torch.zeros_like(parms.grad).normal_(0, parms.grad.std().item() + 1e-8) 162 | elif self.method == 'OGM': 163 | parms.grad *= coeffs[modality] 164 | else: 165 | pass 166 | 167 | 168 | return loss -------------------------------------------------------------------------------- /balancemm/trainer/OPM_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | from torch.optim.optimizer import Optimizer as Optimizer 3 | from .base_trainer import BaseTrainer 4 | 5 | import torch 6 | import torch.nn as nn 7 | from balancemm.models.avclassify_model import BaseClassifierModel 8 | from collections.abc import Mapping 9 | from functools import partial 10 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast 11 | import lightning as L 12 | import torch 13 | import numpy as np 14 | 15 | 16 | class Modality_drop(): 17 | def __init__(self, dim_list, p_exe=0.7, device='cuda'): 18 | self.dim_list = dim_list 19 | self.p_exe = p_exe 20 | self.device = device 21 | 22 | def execute_drop(self, feat_list, q, model): 23 | modality_list = list(model.modalitys) 24 | B = 
feat_list[modality_list[0]].shape[0] # batch size 25 | exe_drop = torch.tensor(np.random.rand(1)).to(device=self.device) >= 1-self.p_exe 26 | if not exe_drop: 27 | return feat_list, torch.ones([B], dtype=torch.float32, device=self.device) 28 | 29 | d_sum = sum(self.dim_list.values()) 30 | q_sum = sum(self.dim_list[m] * q[m] for m in modality_list) 31 | theta = q_sum/d_sum 32 | num_mod = len(modality_list) 33 | q_temp = torch.tensor([q[m] for m in modality_list], device=self.device) 34 | mask = torch.distributions.Bernoulli(1 - q_temp).sample([B, 1]).permute(2, 1, 0).contiguous().reshape(num_mod, B, -1).to(self.device) 35 | 36 | cleaned = {} 37 | for idx, modality in enumerate(modality_list): 38 | D = feat_list[modality].shape[1] 39 | current_mask = mask[idx].expand(-1,D) 40 | cleaned_fea = torch.mul(feat_list[modality], current_mask) 41 | cleaned_fea = cleaned_fea / (1 - theta + 1e-5) 42 | cleaned[modality] = cleaned_fea 43 | 44 | mask = mask.squeeze(-1).transpose(0,1) # [B,num_mod] 45 | 46 | update_flag = torch.sum(mask,dim=1) > 0 47 | for modality in modality_list: 48 | cleaned[modality] = cleaned[modality][update_flag] 49 | return cleaned,update_flag 50 | 51 | 52 | def clip(a, b, c): 53 | if b= limit_batches: 104 | break 105 | 106 | self.fabric.call("on_train_batch_start", batch, batch_idx) 107 | 108 | # check if optimizer should step in gradient accumulation 109 | should_optim_step = self.global_step % self.grad_accum_steps == 0 110 | if should_optim_step: 111 | # currently only supports a single optimizer 112 | self.fabric.call("on_before_optimizer_step", optimizer, 0) 113 | 114 | # optimizer step runs train step internally through closure 115 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx)) 116 | self.fabric.call("on_before_zero_grad", optimizer) 117 | # torch.cuda.empty_cache() 118 | optimizer.zero_grad() 119 | 120 | else: 121 | # gradient accumulation -> no optimizer step 122 | self.training_step(model=model, batch=batch, batch_idx=batch_idx) 123 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction) 124 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx) 125 | 126 | # this guard ensures, we only step the scheduler once per global step 127 | # if should_optim_step: 128 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step) 129 | 130 | # add output values to progress bar 131 | 132 | self._format_iterable(iterable, self._current_train_return, "train") 133 | 134 | # only increase global step if optimizer stepped 135 | self.global_step += int(should_optim_step) 136 | self._current_metrics = self.precision_calculator.compute_metrics() 137 | 138 | def training_step(self, model : BaseClassifierModel, batch, batch_idx): 139 | 140 | # TODO: make it simpler and easier to extend 141 | softmax = nn.Softmax(dim=1) 142 | criterion = nn.CrossEntropyLoss() 143 | relu = nn.ReLU(inplace=True) 144 | tanh = nn.Tanh() 145 | label = batch['label'] 146 | label = label.to(model.device) 147 | model(batch) 148 | model.Unimodality_Calculate() 149 | loss = {} 150 | modality_list = model.modalitys 151 | key = list(model.modalitys) 152 | 153 | # Modulation starts here ! 
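As in OGM, per-modality confidence scores are turned into ratios below, but the resulting coefficients are used as drop probabilities rather than gradient scales: `Modality_drop.execute_drop` randomly masks the dominant modality's features and rescales the survivors, in the spirit of inverted dropout. A minimal sketch of that drop step, with hypothetical names (`features`, `drop_probs`, `dims`) standing in for `model.encoder_result`, `coeffs`, and `dim_list`:

```python
# Illustrative sketch only; assumes each feature tensor is [B, D_m] and each
# drop probability is a plain float in [0, 1].
import torch

def drop_modalities(features: dict, drop_probs: dict, dims: dict):
    B = next(iter(features.values())).shape[0]
    device = next(iter(features.values())).device
    # Expected fraction of dropped feature dimensions across all modalities.
    theta = sum(dims[m] * drop_probs[m] for m in features) / sum(dims.values())
    out, kept = {}, torch.zeros(B, dtype=torch.bool, device=device)
    for m, feat in features.items():
        keep = torch.bernoulli(torch.full((B, 1), 1.0 - drop_probs[m], device=device))
        out[m] = feat * keep / (1.0 - theta + 1e-5)   # inverted-dropout style rescaling
        kept |= keep.squeeze(1).bool()                # keep samples with at least one live modality
    return {m: v[kept] for m, v in out.items()}, kept
```

Because samples whose modalities were all dropped are removed, the labels and unimodal outputs are filtered with the returned mask before the losses are computed, which is what the `select_mask` logic below does.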
154 | modality_nums = len(modality_list) 155 | scores = {} 156 | ratios = {} 157 | coeffs = {} 158 | #Calculate the scores 159 | 160 | 161 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: # bug fixed 162 | for modality in modality_list: 163 | scores[modality] = sum([softmax(model.unimodal_result[modality])[i][label[i]] for i in range(model.unimodal_result['output'].shape[0])]) 164 | ##Calculate the ratios 165 | for modality in modality_list: 166 | ratios[modality] = scores[modality] 167 | if modality_nums == 2: 168 | for modality_another in modality_list: 169 | if modality_another == modality: 170 | continue 171 | 172 | ratios[modality] /= (scores[modality_another]+ 1e-5) # prevent OOM 173 | ratios[modality] = tanh(relu(ratios[modality]-1)) 174 | if modality_nums == 3: 175 | temp_score = 0.0 176 | for modality_another in modality_list: 177 | if modality_another == modality: 178 | continue 179 | temp_score += scores[modality_another] 180 | ratios[modality] /= (temp_score + 1e-5) 181 | ratios[modality] = tanh(relu(ratios[modality]-1)) 182 | #Calculate the coeffs 183 | for modality in modality_list: 184 | coeffs[modality] = self.q_base * (1 + self.alpha * ratios[modality]) if ratios[modality]>0 else 0 185 | coeffs[modality] = clip(coeffs[modality],0.0,1.0) 186 | model.encoder_result.pop('output') 187 | 188 | cleaned_fea,update_flag=self.modality_drop.execute_drop(model.encoder_result,coeffs,model) 189 | 190 | model.unimodal_result['output'] = model.fusion_module(cleaned_fea) 191 | select_mask=update_flag!=0 192 | label=label[select_mask] 193 | 194 | 195 | for modality in modality_list: 196 | 197 | model.unimodal_result[modality]=model.unimodal_result[modality][select_mask] 198 | 199 | 200 | for modality in model.unimodal_result.keys(): 201 | loss[modality] = criterion(model.unimodal_result[modality],label) 202 | 203 | loss['output'].backward() 204 | 205 | else: 206 | pass 207 | 208 | return loss['output'] 209 | -------------------------------------------------------------------------------- /balancemm/trainer/PMR_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | from lightning import LightningModule 3 | from torch.optim.optimizer import Optimizer as Optimizer 4 | from .base_trainer import BaseTrainer 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import os 11 | from collections.abc import Mapping 12 | from functools import partial 13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast 14 | from lightning_utilities import apply_to_collection 15 | import lightning as L 16 | import torch 17 | 18 | def clip(a, b, c): 19 | if b= limit_batches: 73 | break 74 | 75 | self.fabric.call("on_train_batch_start", batch, batch_idx) 76 | 77 | # check if optimizer should step in gradient accumulation 78 | should_optim_step = self.global_step % self.grad_accum_steps == 0 79 | if should_optim_step: 80 | # currently only supports a single optimizer 81 | self.fabric.call("on_before_optimizer_step", optimizer, 0) 82 | 83 | # optimizer step runs train step internally through closure 84 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx, proto = self.proto)) 85 | self.fabric.call("on_before_zero_grad", optimizer) 86 | 87 | optimizer.zero_grad() 88 | 89 | else: 90 | # gradient accumulation -> no optimizer step 91 | self.training_step(model=model, batch=batch, batch_idx=batch_idx) 92 | 
self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction) 93 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx) 94 | 95 | # this guard ensures, we only step the scheduler once per global step 96 | # if should_optim_step: 97 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step) 98 | 99 | # add output values to progress bar 100 | self._format_iterable(iterable, self._current_train_return, "train") 101 | # only increase global step if optimizer stepped 102 | self.global_step += int(should_optim_step) 103 | 104 | self._current_metrics = self.precision_calculator.compute_metrics() 105 | self.fabric.call("on_train_epoch_end") 106 | 107 | def training_step(self, model, batch, batch_idx, proto): 108 | 109 | # TODO: make it simpler and easier to extend 110 | criterion = nn.CrossEntropyLoss() 111 | softmax = nn.Softmax(dim=1) 112 | log_softmax = nn.LogSoftmax(dim=1) 113 | tanh = nn.Tanh() 114 | modality_list = model.modalitys 115 | key = list(modality_list) 116 | m = {} 117 | model(batch) 118 | for modality in modality_list: 119 | m[modality] = model.encoder_result[modality] 120 | # a, v = model(batch)['audio'], model(batch)['visual'] 121 | unimodal_result = model.Unimodality_Calculate() 122 | # out_v,out_a,out = unimodal_result['visual'], unimodal_result['audio'], unimodal_result['output'] 123 | label = batch['label'] 124 | label = label.to(model.device) 125 | loss_modality = {} 126 | for modality in modality_list: 127 | # print(unimodal_result[modality]) 128 | # print(label) 129 | loss_modality[modality] = criterion(unimodal_result[modality],label) 130 | 131 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: 132 | sim = {} 133 | for modality in modality_list: 134 | sim[modality] = -EU_dist(m[modality],proto[modality]) 135 | # audio_sim = -EU_dist(a, audio_proto) # B x n_class 136 | # visual_sim = -EU_dist(v, visual_proto) # B x n_class 137 | # print('sim: ', audio_sim[0][0].data, visual_sim[0][0].data, a[0][0].data, v[0][0].data) 138 | 139 | score_p = {} 140 | # score = {} 141 | loss_proto = {} 142 | for modality in modality_list: 143 | score_p[modality] = sum([softmax(sim[modality])[i][label[i]] for i in range(sim[modality].size(0))]) 144 | 145 | # score_a_p = sum([softmax(audio_sim)[i][label[i]] for i in range(audio_sim.size(0))]) 146 | # score_v_p = sum([softmax(visual_sim)[i][label[i]] for i in range(visual_sim.size(0))]) 147 | if len(modality_list) == 2: 148 | ratio_a_p = score_p[key[0]] / score_p[key[1]] 149 | else: 150 | ratio = {} 151 | min_score = min(score_p.values()) 152 | for modality in modality_list: 153 | ratio[modality] = score_p[modality] / min_score 154 | 155 | # for modality in modality_list: 156 | # score[modality] = sum([softmax(unimodal_result[modality])[i][label[i]] for i in range(unimodal_result[modality].size(0))]) 157 | # # score_v = sum([softmax(out_v)[i][label[i]] for i in range(out_v.size(0))]) 158 | # # score_a = sum([softmax(out_a)[i][label[i]] for i in range(out_a.size(0))]) 159 | # ratio_a = score[key[0]] / score[key[1]] 160 | 161 | for modality in modality_list: 162 | loss_proto[modality] = criterion(sim[modality],label) 163 | # loss_proto_a = criterion(audio_sim, label) 164 | # loss_proto_v = criterion(visual_sim, label) 165 | if len(modality_list) == 2: 166 | if ratio_a_p > 1: 167 | beta = 0 # audio coef 168 | lam = 1 * self.alpha # visual coef 169 | elif ratio_a_p < 1: 170 | beta = 1 * self.alpha 171 | lam = 0 172 | else: 173 | beta 
= 0 174 | lam = 0 175 | loss = criterion(unimodal_result['output'], label) + beta * loss_proto[key[0]] + lam * loss_proto[key[1]] 176 | loss.backward() 177 | else: 178 | loss = criterion(unimodal_result['output'], label) 179 | loss.backward() 180 | k_t = {} 181 | for modality in modality_list: 182 | if ratio[modality] > 1: 183 | k_t[modality] = 1-tanh(self.eta * ratio[modality]) 184 | else: 185 | k_t[modality] = 1 186 | 187 | for name, parms in model.named_parameters(): 188 | layer = str(name) 189 | for modality in modality_list: 190 | if modality in layer and len(parms.grad.size()) != 1: ##Don't change the grad of bias for layer 191 | parms.grad = parms.grad * k_t[modality] - \ 192 | torch.zeros_like(parms.grad).normal_(0, parms.grad.std().item() + 1e-8) 193 | # loss_a = criterion(out_a, label) 194 | else: 195 | loss = criterion(unimodal_result['output'], label) 196 | loss.backward() 197 | # # avoid gradients in stored/accumulated values -> prevents potential OOM 198 | # self._current_train_return = apply_to_collection(model.encoder_result, dtype=torch.Tensor, function=lambda x: x.detach()) 199 | return loss 200 | def calculate_prototype(self, model, dataloader, proto0): 201 | # todo customed output of prototype 202 | n_classes = model.n_classes 203 | device = next(model.parameters()).device 204 | proto = {} 205 | modality_list = model.modalitys 206 | for modality in modality_list: 207 | proto[modality] = torch.zeros(n_classes, model.modality_size[modality]).to(device) 208 | count_class = [0 for _ in range(n_classes)] 209 | 210 | # calculate prototype 211 | model.eval() 212 | with torch.no_grad(): 213 | sample_count = 0 214 | all_num = len(dataloader) 215 | m = {} 216 | for batch_idx, batch in enumerate(dataloader): 217 | model(batch) 218 | for modality in modality_list: 219 | m[modality] = model.encoder_result[modality] 220 | label = batch['label'] 221 | 222 | 223 | for c, l in enumerate(label): 224 | l = l.long() 225 | count_class[l] += 1 226 | for modality in modality_list: 227 | 228 | proto[modality][l,:] += m[modality][c,:] 229 | 230 | 231 | sample_count += 1 232 | 233 | if sample_count >= all_num // 10: 234 | break 235 | for modality in modality_list: 236 | for c in range(proto[modality].shape[0]): 237 | proto[modality][c,:] /= count_class[c] 238 | # audio_prototypes[c, :] /= count_class[c] 239 | # visual_prototypes[c, :] /= count_class[c] 240 | 241 | if self.current_epoch <= 0: 242 | for modality in modality_list: 243 | proto[modality] = proto[modality] 244 | # audio_prototypes = audio_prototypes 245 | # visual_prototypes = visual_prototypes 246 | else: 247 | for modality in modality_list: 248 | proto[modality] = (1-self.momentum_coef) * proto[modality] + self.momentum_coef * proto0[modality] 249 | # audio_prototypes = (1 - self.momentum_coef) * audio_prototypes + self.momentum_coef * a_proto 250 | # visual_prototypes = (1 - self.momentum_coef) * visual_prototypes + self.momentum_coef * v_proto 251 | return proto -------------------------------------------------------------------------------- /balancemm/trainer/UMT_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | from lightning import LightningModule 3 | from torch.optim.optimizer import Optimizer as Optimizer 4 | from .base_trainer import BaseTrainer 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import os 11 | from collections.abc import Mapping 12 | from functools import partial 13 | from typing import Any, 
Iterable, List, Literal, Optional, Tuple, Union, cast 14 | 15 | import lightning as L 16 | import torch 17 | from ..models.avclassify_model import BaseClassifierModel 18 | from ..evaluation.complex import get_flops 19 | from ..models import create_model 20 | import copy 21 | from ..utils.train_utils import get_checkpoint_files, get_newest_path 22 | import os.path as osp 23 | import yaml 24 | from models.avclassify_model import MultiModalParallel 25 | class UMTTrainer(BaseTrainer): 26 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}, args = {}): 27 | super(UMTTrainer,self).__init__(fabric,**para_dict) 28 | self.scaling = method_dict['scaling'] 29 | self.modulation_starts = method_dict['modulation_starts'] 30 | self.modulation_ends = method_dict['modulation_ends'] 31 | 32 | self.loaded_model = [] 33 | loaded_model = {} 34 | temp_args = copy.deepcopy(args) 35 | # root_path = osp.dirname(osp.dirname(__file__)) 36 | # with open(osp.join(root_path ,"configs", "encoder_config.yaml"), 'r') as f: 37 | # encoder_settings = yaml.safe_load(f) 38 | if args.mode == "train_and_test": 39 | out_dir = '_'.join(temp_args.out_dir.split('/')[:-1]) 40 | for modality in args.model['encoders'].keys(): 41 | temp_args.model['encoders'] = {modality: args.model['encoders'][modality]} 42 | temp_args.model['modality_size'] = {modality: args.model['modality_size'][modality]} 43 | loaded_model[modality] = create_model(temp_args.model) 44 | out_dir = temp_args.out_dir.replace('UMTTrainer', 'unimodalTrainer_' + modality) 45 | out_dir = '/'.join(out_dir.split('/')[:-1]) 46 | path = get_newest_path(out_dir) 47 | # loaded_model[modality].load_state_dict(torch.load(get_checkpoint_files(path)[0])['model']) 48 | loaded_model[modality].load_state_dict(torch.load(get_checkpoint_files(path)[0])['model']) 49 | loaded_model[modality] = MultiModalParallel(loaded_model[modality],device_ids=[0,1]) 50 | loaded_model[modality] =loaded_model[modality].cuda() 51 | # loaded_model[modality] = torch.load(get_checkpoint_files(path)[0])['model'] 52 | # print(type(loaded_model)) 53 | self.loaded_model = loaded_model 54 | 55 | def train_loop( 56 | self, 57 | model: L.LightningModule, 58 | optimizer: torch.optim.Optimizer, 59 | train_loader: torch.utils.data.DataLoader, 60 | limit_batches: Union[int, float] = float("inf"), 61 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None, 62 | ): 63 | """The training loop running a single training epoch. 64 | 65 | Args: 66 | model: the LightningModule to train 67 | optimizer: the optimizer, optimizing the LightningModule. 68 | train_loader: The dataloader yielding the training batches. 69 | limit_batches: Limits the batches during this training epoch. 70 | If greater than the number of batches in the ``train_loader``, this has no effect. 71 | scheduler_cfg: The learning rate scheduler configuration. 72 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers` 73 | for supported values. 
74 | 75 | """ 76 | self.fabric.call("on_train_epoch_start") 77 | all_modalitys = list(model.modalitys) 78 | all_modalitys.append('output') 79 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys) 80 | iterable = self.progbar_wrapper( 81 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}" 82 | ) 83 | # if self.current_epoch == 0: 84 | # for batch_idx, batch in enumerate(iterable): 85 | # batch_sample = batch 86 | # break 87 | # print(batch_sample.keys()) 88 | # model_flops, _ =get_flops(model = model, input_sample = batch_sample) 89 | # self.FlopsMonitor.update(model_flops / len(batch_sample['label']) * len(train_loader), 'forward') 90 | # self.FlopsMonitor.report(logger = self.logger) 91 | for batch_idx, batch in enumerate(iterable): 92 | # end epoch if stopping training completely or max batches for this epoch reached 93 | if self.should_stop or batch_idx >= limit_batches: 94 | break 95 | 96 | self.fabric.call("on_train_batch_start", batch, batch_idx) 97 | 98 | # check if optimizer should step in gradient accumulation 99 | should_optim_step = self.global_step % self.grad_accum_steps == 0 100 | if should_optim_step: 101 | # currently only supports a single optimizer 102 | self.fabric.call("on_before_optimizer_step", optimizer, 0) 103 | 104 | # optimizer step runs train step internally through closure 105 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx)) 106 | self.fabric.call("on_before_zero_grad", optimizer) 107 | 108 | optimizer.zero_grad() 109 | 110 | else: 111 | # gradient accumulation -> no optimizer step 112 | self.training_step(model=model, batch=batch, batch_idx=batch_idx) 113 | 114 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction) 115 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx) 116 | 117 | # this guard ensures, we only step the scheduler once per global step 118 | # if should_optim_step: 119 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step) 120 | 121 | # add output values to progress bar 122 | 123 | self._format_iterable(iterable, self._current_train_return, "train") 124 | 125 | # only increase global step if optimizer stepped 126 | self.global_step += int(should_optim_step) 127 | self._current_metrics = self.precision_calculator.compute_metrics() 128 | self.fabric.call("on_train_epoch_end") 129 | def training_step(self, model: BaseClassifierModel, batch, batch_idx): 130 | 131 | # TODO: make it simpler and easier to extend 132 | criterion = nn.CrossEntropyLoss() 133 | MSE = nn.MSELoss() 134 | label = batch['label'] 135 | label = label.to(model.device) 136 | _ = model(batch= batch) 137 | out = model.encoder_result['output'] 138 | loss = criterion(out, label) 139 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: 140 | for modality in self.loaded_model.keys(): 141 | with torch.no_grad(): 142 | self.loaded_model[modality](batch) 143 | out_unimodal = self.loaded_model[modality].encoder_result[modality] 144 | loss += self.scaling * MSE(out_unimodal, model.encoder_result[modality]) 145 | 146 | loss.backward() 147 | return loss -------------------------------------------------------------------------------- /balancemm/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from os import path as osp 4 | from types import 
SimpleNamespace 5 | import lightning as L 6 | from lightning.fabric.loggers import CSVLogger, TensorBoardLogger 7 | 8 | __all__ = ['create_trainer'] 9 | 10 | # automatically scan and import trainer modules 11 | # scan all the files under the data folder with '_trainer' in file names 12 | trainer_folder = osp.dirname(osp.abspath(__file__)) 13 | trainer_filenames = [ 14 | osp.splitext(v)[0] for v in os.listdir(trainer_folder) 15 | if v.endswith('_trainer.py') 16 | ] 17 | 18 | # import all the trainer modules 19 | _trainer_modules = [ 20 | importlib.import_module(f'.{file_name}', package="balancemm.trainer") 21 | for file_name in trainer_filenames 22 | ] 23 | 24 | def create_trainer(fabric: L.Fabric ,trainer_opt:dict, para_opt, args, logger,tb_logger): 25 | # dynamic instantiation 26 | for module in _trainer_modules: 27 | trainer_cls = getattr(module, trainer_opt["trainer"], None) 28 | if trainer_cls is not None: 29 | break 30 | if trainer_cls is None: 31 | raise ValueError(f'trainer {trainer} is not found.') 32 | para_opt['base_para']['logger'] = logger 33 | if args.trainer['name'] != 'UMT' and args.trainer['name'] != 'LinearProbe' and args.trainer['name'] != "MLA" and args.trainer['name'] != "OPM" and args.trainer["name"] != "Sample": 34 | trainer = trainer_cls(fabric, para_opt, para_opt['base_para']) 35 | else: 36 | trainer = trainer_cls(fabric, para_opt, para_opt['base_para'], args) 37 | trainer.checkpoint_dir = args.checkpoint_dir 38 | 39 | print( 40 | f'Trainer {trainer.__class__.__name__} - {trainer_opt["trainer"]} ' 41 | 'is created.') 42 | para_opt['name'] = trainer_opt["trainer"] 43 | # logger.info("normal Settings: %s", para_opt) 44 | logger.info("trainer Settings: %s", para_opt) 45 | if isinstance(tb_logger, TensorBoardLogger): 46 | tb_logger.log_hyperparams(trainer_opt) 47 | tb_logger.experiment.add_text("Trainer Setup", f"Trainer: {trainer_opt['trainer']}") 48 | return trainer -------------------------------------------------------------------------------- /balancemm/trainer/baseline_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | from lightning import LightningModule 3 | from torch.optim.optimizer import Optimizer as Optimizer 4 | from .base_trainer import BaseTrainer 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import os 11 | from collections.abc import Mapping 12 | from functools import partial 13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast 14 | 15 | import lightning as L 16 | import torch 17 | from ..models.avclassify_model import BaseClassifierModel 18 | class baselineTrainer(BaseTrainer): 19 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}): 20 | super(baselineTrainer,self).__init__(fabric,**para_dict) 21 | 22 | self.modulation_starts = method_dict['modulation_starts'] 23 | self.modulation_ends = method_dict['modulation_ends'] 24 | self.modality = method_dict['modality'] 25 | 26 | def train_loop( 27 | self, 28 | model: L.LightningModule, 29 | optimizer: torch.optim.Optimizer, 30 | train_loader: torch.utils.data.DataLoader, 31 | limit_batches: Union[int, float] = float("inf"), 32 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None, 33 | ): 34 | """The training loop running a single training epoch. 35 | 36 | Args: 37 | model: the LightningModule to train 38 | optimizer: the optimizer, optimizing the LightningModule. 
39 | train_loader: The dataloader yielding the training batches. 40 | limit_batches: Limits the batches during this training epoch. 41 | If greater than the number of batches in the ``train_loader``, this has no effect. 42 | scheduler_cfg: The learning rate scheduler configuration. 43 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers` 44 | for supported values. 45 | 46 | """ 47 | self.fabric.call("on_train_epoch_start") 48 | all_modalitys = list(model.modalitys) 49 | all_modalitys.append('output') 50 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys) 51 | iterable = self.progbar_wrapper( 52 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}" 53 | ) 54 | for batch_idx, batch in enumerate(iterable): 55 | # end epoch if stopping training completely or max batches for this epoch reached 56 | if self.should_stop or batch_idx >= limit_batches: 57 | break 58 | 59 | self.fabric.call("on_train_batch_start", batch, batch_idx) 60 | 61 | # check if optimizer should step in gradient accumulation 62 | should_optim_step = self.global_step % self.grad_accum_steps == 0 63 | if should_optim_step: 64 | # currently only supports a single optimizer 65 | self.fabric.call("on_before_optimizer_step", optimizer, 0) 66 | 67 | # optimizer step runs train step internally through closure 68 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx)) 69 | self.fabric.call("on_before_zero_grad", optimizer) 70 | 71 | optimizer.zero_grad() 72 | 73 | else: 74 | # gradient accumulation -> no optimizer step 75 | self.training_step(model=model, batch=batch, batch_idx=batch_idx) 76 | # print(len(batch['label']), len(model.module.prediction['output'])) 77 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction) 78 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx) 79 | 80 | self._format_iterable(iterable, self._current_train_return, "train") 81 | 82 | # only increase global step if optimizer stepped 83 | self.global_step += int(should_optim_step) 84 | self._current_metrics = self.precision_calculator.compute_metrics() 85 | self.fabric.call("on_train_epoch_end") 86 | def training_step(self, model: BaseClassifierModel, batch, batch_idx): 87 | 88 | # TODO: make it simpler and easier to extend 89 | criterion = nn.CrossEntropyLoss() 90 | label = batch['label'] 91 | label = label.to(model.device) 92 | 93 | model(batch) 94 | out = model.encoder_result['output'] 95 | loss = criterion(out, label) 96 | loss.backward() 97 | return loss -------------------------------------------------------------------------------- /balancemm/trainer/unimodal_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | from lightning import LightningModule 3 | from torch.optim.optimizer import Optimizer as Optimizer 4 | from .base_trainer import BaseTrainer 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | import os 11 | from collections.abc import Mapping 12 | from functools import partial 13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast 14 | 15 | import lightning as L 16 | import torch 17 | from ..models.avclassify_model import BaseClassifierModel 18 | from ..evaluation.complex import get_flops 19 | class unimodalTrainer(BaseTrainer): 20 | def __init__(self,fabric, method_dict: dict = {}, 
para_dict : dict = {}): 21 | super(unimodalTrainer,self).__init__(fabric,**para_dict) 22 | self.alpha = method_dict['alpha'] 23 | self.method = method_dict['method'] 24 | self.modulation_starts = method_dict['modulation_starts'] 25 | self.modulation_ends = method_dict['modulation_ends'] 26 | self.modality = method_dict['modality'] 27 | 28 | def train_loop( 29 | self, 30 | model: L.LightningModule, 31 | optimizer: torch.optim.Optimizer, 32 | train_loader: torch.utils.data.DataLoader, 33 | limit_batches: Union[int, float] = float("inf"), 34 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None, 35 | ): 36 | 37 | self.fabric.call("on_train_epoch_start") 38 | all_modalitys = list(model.modalitys) 39 | all_modalitys.append('output') 40 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys) 41 | iterable = self.progbar_wrapper( 42 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}" 43 | ) 44 | # if self.current_epoch == 0: 45 | # for batch_idx, batch in enumerate(iterable): 46 | # batch_sample = batch 47 | # break 48 | # print(batch_sample.keys()) 49 | # model_flops, _ =get_flops(model = model, input_sample = batch_sample) 50 | # self.FlopsMonitor.update(model_flops / len(batch_sample['label']) * len(train_loader), 'forward') 51 | # self.FlopsMonitor.report(logger = self.logger) 52 | for batch_idx, batch in enumerate(iterable): 53 | # end epoch if stopping training completely or max batches for this epoch reached 54 | if self.should_stop or batch_idx >= limit_batches: 55 | break 56 | 57 | self.fabric.call("on_train_batch_start", batch, batch_idx) 58 | 59 | # check if optimizer should step in gradient accumulation 60 | should_optim_step = self.global_step % self.grad_accum_steps == 0 61 | if should_optim_step: 62 | # currently only supports a single optimizer 63 | self.fabric.call("on_before_optimizer_step", optimizer, 0) 64 | 65 | # optimizer step runs train step internally through closure 66 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx)) 67 | self.fabric.call("on_before_zero_grad", optimizer) 68 | 69 | optimizer.zero_grad() 70 | 71 | else: 72 | # gradient accumulation -> no optimizer step 73 | self.training_step(model=model, batch=batch, batch_idx=batch_idx) 74 | 75 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction) 76 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx) 77 | 78 | # this guard ensures, we only step the scheduler once per global step 79 | # if should_optim_step: 80 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step) 81 | 82 | # add output values to progress bar 83 | 84 | self._format_iterable(iterable, self._current_train_return, "train") 85 | 86 | # only increase global step if optimizer stepped 87 | self.global_step += int(should_optim_step) 88 | self._current_metrics = self.precision_calculator.compute_metrics() 89 | self.fabric.call("on_train_epoch_end") 90 | def training_step(self, model: BaseClassifierModel, batch, batch_idx): 91 | 92 | # TODO: make it simpler and easier to extend 93 | criterion = nn.CrossEntropyLoss() 94 | label = batch['label'] 95 | label = label.to(model.device) 96 | 97 | model(batch) 98 | out = model.encoder_result['output'] 99 | loss = criterion(out, label) 100 | loss.backward() 101 | return loss 
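All trainers in this package follow the same contract as the baseline and unimodal trainers above: subclass `BaseTrainer`, pull method-specific hyperparameters from `method_dict`, and override `training_step` (plus `train_loop` when custom gradient handling is needed). Since `trainer/__init__.py` scans for `*_trainer.py` files and instantiates the class named in the config, a new method can be added with a single file along the lines of this hypothetical skeleton:

```python
# balancemm/trainer/MyMethod_trainer.py -- hypothetical skeleton, not part of the toolkit
import torch.nn as nn
from .base_trainer import BaseTrainer

class MyMethodTrainer(BaseTrainer):
    """Picked up automatically because the file name ends in `_trainer.py`
    and the class name matches the `trainer` entry in the config."""

    def __init__(self, fabric, method_dict: dict = {}, para_dict: dict = {}):
        super().__init__(fabric, **para_dict)
        # Method-specific hyperparameters arrive through method_dict.
        self.modulation_starts = method_dict['modulation_starts']
        self.modulation_ends = method_dict['modulation_ends']

    def training_step(self, model, batch, batch_idx):
        criterion = nn.CrossEntropyLoss()
        label = batch['label'].to(model.device)
        model(batch)                                    # fills model.encoder_result
        loss = criterion(model.encoder_result['output'], label)
        if self.modulation_starts <= self.current_epoch <= self.modulation_ends:
            pass                                        # re-balancing logic would go here
        loss.backward()
        return loss
```

Its hyperparameters would then typically be registered in `configs/trainer_config.yaml` so that `create_trainer` can pass them through as `method_dict`.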
-------------------------------------------------------------------------------- /balancemm/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from torchvision import datasets, transforms 3 | from types import SimpleNamespace 4 | from ..datasets import create_dataset 5 | import lightning as L 6 | import numpy as np 7 | import random 8 | def worker_init_fn(worker_id): 9 | worker_seed = torch.initial_seed() % 2**32 10 | np.random.seed(worker_seed) 11 | random.seed(worker_seed) 12 | 13 | def create_train_val_dataloader(fabric: L.Fabric, config: dict): 14 | 15 | train_dataset = create_dataset(config.dataset, 'train') 16 | if config.dataset.get('validation', False): 17 | val_dataset = create_dataset(config.dataset, 'valid') 18 | test_dataset = create_dataset(config.dataset, 'test') 19 | else: 20 | val_dataset = create_dataset(config.dataset, 'test') 21 | test_dataset = val_dataset 22 | 23 | config_dataloader = SimpleNamespace(**config.dataloader) 24 | if config_dataloader.fast_run == True: 25 | train_dataset = torch.utils.data.Subset(train_dataset, list(range(config_dataloader.eff_batch_size*4))) 26 | val_dataset = torch.utils.data.Subset(val_dataset, list(range(config_dataloader.eff_batch_size*2))) 27 | test_dataset = torch.utils.data.Subset(test_dataset, list(range(config_dataloader.eff_batch_size*2))) 28 | # print len of datasets 29 | fabric.print(f"Train dataset: {train_dataset.__class__.__name__} - {config.Train['dataset']}, {len(train_dataset)} samples") 30 | fabric.print(f"Val dataset: {val_dataset.__class__.__name__} - {config.Val['dataset']}, {len(val_dataset)} samples") 31 | 32 | if (not hasattr(config_dataloader, 'batch_size') or config_dataloader.batch_size == -1): 33 | config_dataloader.batch_size = round(config_dataloader.eff_batch_size/fabric.world_size) # using the effective batch_size to calculate the batch_size per gpu 34 | 35 | train_dataloader = torch.utils.data.DataLoader(train_dataset, 36 | batch_size=config_dataloader.batch_size, 37 | shuffle=config_dataloader.shuffle, 38 | drop_last = config_dataloader.drop_last, 39 | num_workers = config_dataloader.num_workers, 40 | multiprocessing_context='spawn', 41 | pin_memory = config_dataloader.pin_memory, 42 | prefetch_factor=6) 43 | if config.trainer['name'] == 'Sample': 44 | train_val_dataloader = torch.utils.data.DataLoader(train_dataset, 45 | batch_size=config_dataloader.batch_size, 46 | shuffle=False, 47 | drop_last = config_dataloader.drop_last, 48 | num_workers = config_dataloader.num_workers, 49 | multiprocessing_context='spawn', 50 | pin_memory = config_dataloader.pin_memory, 51 | prefetch_factor=6) 52 | 53 | val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=config_dataloader.batch_size, drop_last = config_dataloader.drop_last, 54 | num_workers = config_dataloader.num_workers, 55 | multiprocessing_context='spawn', 56 | pin_memory = config_dataloader.pin_memory, 57 | prefetch_factor=6) 58 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config_dataloader.batch_size, drop_last = config_dataloader.drop_last, 59 | num_workers = config_dataloader.num_workers, 60 | multiprocessing_context='spawn', 61 | pin_memory = config_dataloader.pin_memory, 62 | prefetch_factor=6) 63 | 64 | if config.Train['dataset'] == 'FOOD101': 65 | g = torch.Generator() 66 | train_dataloader = torch.utils.data.DataLoader(train_dataset, 67 | batch_size=config_dataloader.batch_size, 68 | shuffle=config_dataloader.shuffle, 69 | 
drop_last = config_dataloader.drop_last, 70 | num_workers = config_dataloader.num_workers, 71 | multiprocessing_context='spawn', 72 | pin_memory = config_dataloader.pin_memory, 73 | worker_init_fn=worker_init_fn, 74 | generator=g) 75 | 76 | if config.trainer['name'] == 'Sample': 77 | train_val_dataloader = torch.utils.data.DataLoader(train_dataset, 78 | batch_size=config_dataloader.batch_size, 79 | shuffle=False, 80 | drop_last = config_dataloader.drop_last, 81 | num_workers = config_dataloader.num_workers, 82 | multiprocessing_context='spawn', 83 | pin_memory = config_dataloader.pin_memory, 84 | worker_init_fn=worker_init_fn, 85 | generator=g) 86 | 87 | val_dataloader = torch.utils.data.DataLoader(val_dataset, 88 | batch_size=config_dataloader.batch_size, 89 | shuffle=config_dataloader.shuffle, 90 | drop_last = config_dataloader.drop_last, 91 | num_workers = config_dataloader.num_workers, 92 | multiprocessing_context='spawn', 93 | pin_memory = config_dataloader.pin_memory, 94 | worker_init_fn=worker_init_fn, 95 | generator=g) 96 | test_dataloader = torch.utils.data.DataLoader(test_dataset, 97 | batch_size=config_dataloader.batch_size, 98 | shuffle=config_dataloader.shuffle, 99 | drop_last = config_dataloader.drop_last, 100 | num_workers = config_dataloader.num_workers, 101 | multiprocessing_context='spawn', 102 | pin_memory = config_dataloader.pin_memory, 103 | worker_init_fn=worker_init_fn, 104 | generator=g) 105 | # print batchsize and len 106 | fabric.print(f"Train dataloader: {len(train_dataloader)} batches, {len(train_dataloader.dataset)} samples") 107 | fabric.print(f"Val dataloader: {len(val_dataloader)} batches, {len(val_dataloader.dataset)} samples") 108 | if config.trainer['name'] == 'Sample': 109 | return train_dataloader, train_val_dataloader, val_dataloader, test_dataloader 110 | return train_dataloader, val_dataloader, test_dataloader -------------------------------------------------------------------------------- /balancemm/utils/encoder_module.py: -------------------------------------------------------------------------------- 1 | def find_module(encoder): 2 | if hasattr(optim, optimizer_type): 3 | return getattr(optim, optimizer_type) 4 | # elif globals().get(optimizer_type + 'Optimizer'): 5 | # return globals()[optimizer_type + 'Optimizer'] 6 | else: 7 | raise ValueError(f'Optimizer {optimizer_type} not found in torch.optim or current module.') -------------------------------------------------------------------------------- /balancemm/utils/logger.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import sys 3 | 4 | def setup_logger(log_file: str): 5 | logger.remove() 6 | logger.add(log_file, rotation="10 MB", level="INFO") 7 | logger.add(sys.stdout, level="INFO") 8 | return logger -------------------------------------------------------------------------------- /balancemm/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | def find_module(optimizer_type): 4 | if hasattr(optim, optimizer_type): 5 | return getattr(optim, optimizer_type) 6 | # elif globals().get(optimizer_type + 'Optimizer'): 7 | # return globals()[optimizer_type + 'Optimizer'] 8 | else: 9 | raise ValueError(f'Optimizer {optimizer_type} not found in torch.optim or current module.') 10 | 11 | def create_optimizer(model, args: dict, parameter: dict): 12 | if 'type' not in args: 13 | raise ValueError('Optimizer type is required.') 14 | optimizer_type = 
args['type'] 15 | optimizer_cls = find_module(optimizer_type) 16 | optimizer_args = {k: v for k, v in args.items() if k != 'type'} 17 | optimizer_args['lr'] = parameter['base_lr'] if parameter['lr'] == -1 else parameter['lr'] 18 | optimizer = optimizer_cls(model.parameters(), **optimizer_args) 19 | print(f'Optimizer {optimizer.__class__.__name__} - {optimizer_type} is created.') 20 | return optimizer -------------------------------------------------------------------------------- /balancemm/utils/parser_utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from dynaconf import Dynaconf 3 | 4 | __all__ = ['parse_cli_args_to_dict', 'load_config_dict', 'ensure_and_get_config_path', 'find_module'] 5 | 6 | def _merge_into_dict(target_dict: dict, key_path: str, value: any) -> None: 7 | """ 8 | Merge a hierarchically structured key and its value into the target dictionary. 9 | 10 | :param target_dict: The dictionary to be updated. 11 | :param key_path: A string representing the hierarchical keys, separated by dots. 12 | :param value: The value to be set at the innermost key. 13 | """ 14 | keys = key_path.split('.') 15 | for key in keys[:-1]: 16 | if key not in target_dict or not isinstance(target_dict[key], dict): 17 | target_dict[key] = {} 18 | target_dict = target_dict[key] 19 | target_dict[keys[-1]] = value 20 | 21 | def parse_cli_args_to_dict(cli_args: list[str], args_dict: dict = None) -> dict: 22 | """ 23 | Parses a list of command-line arguments into a dictionary, supporting keys prefixed with '-' or '--'. 24 | Values can be specified either by '=' or as the next argument. Keys without explicit values are assigned 25 | an empty string. When the same key is specified multiple times, the last value takes precedence. 26 | Notice that format '-a.b' or '-a .b' is not supported, use "--a.b" or "-a '.b'" instead. 27 | 28 | Args: 29 | cli_args (list[str]): The list of command-line arguments. 30 | args_dict (dict, optional): An existing dictionary to update with the command-line arguments. 31 | Defaults to None, in which case a new dictionary is created. 32 | 33 | Returns: 34 | dict: A dictionary containing the parsed command-line arguments, merged with any existing 35 | entries in `args_dict`. 36 | """ 37 | args_dict = {} if args_dict is None else args_dict; i = 0 # create a fresh dict per call instead of sharing a mutable default 38 | while i < len(cli_args): 39 | if not cli_args[i].startswith('-'): 40 | i += 1 41 | continue 42 | if not cli_args[i].startswith('--'): 43 | # check for invalid argument format '-a.b' or '-a .b' 44 | if i + 1 < len(cli_args) and cli_args[i + 1].startswith('.'): 45 | raise ValueError(f"Invalid argument: '{cli_args[i]} {cli_args[i + 1]}'. Use --a.b or -a '.b' instead.") 46 | arg = cli_args[i].lstrip('-') 47 | value = '' # Default value if no explicit value is provided 48 | # Check for '=' in the current argument 49 | if '=' in arg: 50 | arg, value = arg.split('=', 1) 51 | else: 52 | # Check if the next argument is a value (not another key) 53 | if i + 1 < len(cli_args) and not cli_args[i + 1].startswith('-'): 54 | value = cli_args[i + 1] 55 | i += 1 # Skip the next argument since it's a value 56 | # Dots in the key denote nesting; merge the value into the corresponding nested dictionary 57 | _merge_into_dict(args_dict, arg, value) 58 | i += 1 59 | 60 | return args_dict 61 | 62 | def load_config_dict(config_path: str) -> dict: 63 | """ 64 | Load configuration from a YAML file. 65 | 66 | :param config_path: The file path to the YAML configuration file.
67 | :return: A dictionary containing the loaded configuration, or None if an error occurs. 68 | """ 69 | try: 70 | with open(config_path, "r") as f: 71 | args = yaml.safe_load(f) 72 | except Exception as e: 73 | print(f"Error reading file {config_path}: \n{e}") 74 | return None 75 | print(f"Loaded config from {config_path}") 76 | return args 77 | def ensure_and_get_config_path(args: list[str], default_config_path: str) -> str: 78 | """ 79 | Ensure that the --config argument is present in args, supporting both '--config=' and '--config ' formats, 80 | as well as the '-c' abbreviation. If not present, extend args with the default config path. 81 | 82 | :param args: List of command-line arguments. 83 | :param default_config_path: The default path to the configuration file if --config or -c is not specified. 84 | :return: config_path (str): The path to the configuration file. 85 | """ 86 | config_key_long = '--config' 87 | config_key_short = '-c' 88 | config_path = None 89 | 90 | for i, arg in enumerate(args): 91 | if arg.startswith(config_key_long + '=') or arg.startswith(config_key_short + '='): 92 | config_path = arg.split('=', 1)[1] 93 | break 94 | elif arg == config_key_long or arg == config_key_short: 95 | if i + 1 < len(args) and not args[i + 1].startswith('-'): 96 | config_path = args[i + 1] 97 | break 98 | 99 | if config_path is None: 100 | config_path = default_config_path 101 | args.extend([f'{config_key_long}={config_path}']) 102 | 103 | return config_path 104 | 105 | def find_module(module_list: list, module_name: str, module_type: str = 'Module') -> any: 106 | module_name = module_name + module_type 107 | for module in module_list: 108 | module_cls = getattr(module, module_name, None) 109 | if module_cls is not None: 110 | return module_cls 111 | raise ValueError(f'{module_type} {module_name} is not found.') -------------------------------------------------------------------------------- /balancemm/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch.optim.lr_scheduler as lr_scheduler 2 | 3 | def find_scheduler(scheduler_type): 4 | if hasattr(lr_scheduler, scheduler_type): 5 | return getattr(lr_scheduler, scheduler_type) 6 | else: 7 | raise ValueError(f'Scheduler {scheduler_type} not found.') 8 | 9 | def create_scheduler(optimizer, args: dict): 10 | if 'type' not in args: 11 | raise ValueError('Scheduler type is required.') 12 | scheduler_type = args['type'] 13 | scheduler_cls = find_scheduler(scheduler_type) 14 | scheduler_args = {k: v for k, v in args.items() if k != 'type'} 15 | scheduler = scheduler_cls(optimizer, **scheduler_args) 16 | print(f'Scheduler {scheduler.__class__.__name__} - {scheduler_type} is created.') 17 | return scheduler 18 | -------------------------------------------------------------------------------- /balancemm/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | from typing import Optional 3 | import subprocess 4 | 5 | import torch.nn as nn 6 | 7 | from lightning.fabric.loggers import CSVLogger, TensorBoardLogger 8 | from lightning.pytorch.loggers import WandbLogger 9 | 10 | import random 11 | import numpy as np 12 | import torch 13 | import os 14 | def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int: 15 | total = 0 16 | for p in module.parameters(): 17 | if requires_grad is None or p.requires_grad == requires_grad: 18 | total += p.numel() 19 | return total 20 | 21 | def
choose_logger(logger_name: str, log_dir, project: Optional[str] = None, comment: Optional[str] = None, *args, **kwargs): 22 | if logger_name == "csv": 23 | return CSVLogger(root_dir = log_dir, name = 'csv', *args, **kwargs) 24 | elif logger_name == "tensorboard": 25 | logger = TensorBoardLogger(root_dir=log_dir, name='tensorboard',default_hp_metric=False, *args, **kwargs) 26 | tensorboard_log_dir = os.path.join(log_dir, 'tensorboard') 27 | # subprocess.Popen(['tensorboard', '--logdir', tensorboard_log_dir]) 28 | 29 | return logger 30 | elif logger_name == "wandb": 31 | return WandbLogger(project = project, save_dir = log_dir, notes = comment, *args, **kwargs) 32 | else: 33 | raise ValueError(f"`logger={logger_name}` is not a valid option.") 34 | 35 | def get_checkpoint_files(checkpoint_dir): 36 | checkpoint_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*.ckpt"))) 37 | print(f'the checkpoint is {checkpoint_files}') 38 | return checkpoint_files 39 | 40 | def get_newest_path(out_dir): 41 | folders = [f for f in os.listdir(out_dir) if os.path.isdir(os.path.join(out_dir, f)) and len(os.listdir(os.path.join(out_dir, f + '/checkpoints')))>0 ] 42 | # folder = max(folders, key=lambda f: os.path.getmtime(os.path.join(out_dir, f))) 43 | folder = max(folders) 44 | folder = os.path.join(out_dir, folder + '/checkpoints') 45 | if folder: 46 | return folder 47 | else: 48 | raise ValueError('there are no pretrained model') 49 | def set_seed(seed): 50 | """ 51 | Set random seed for training 52 | 53 | Args: 54 | seed (int): random value 55 | """ 56 | random.seed(seed) 57 | np.random.seed(seed) 58 | torch.manual_seed(seed) 59 | torch.cuda.manual_seed_all(seed) 60 | 61 | torch.backends.cudnn.deterministic = True 62 | torch.backends.cudnn.benchmark = False 63 | 64 | os.environ['PYTHONHASHSEED'] = str(seed) 65 | 66 | print(f"Random seed set as {seed}") -------------------------------------------------------------------------------- /configs/dataset_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | 3 | dataset: 4 | CMUMOSEI: 5 | dataset_path: None 6 | data : mosei_senti 7 | if_align: false 8 | classes: 2 9 | validation: True 10 | audio: 11 | input_dim: 74 12 | visual: 13 | input_dim: 35 14 | text: 15 | input_dim: 300 16 | 17 | FOOD101: 18 | targ_dir: None 19 | dataset_path: None 20 | classes: 101 21 | validation: false 22 | text: 23 | input_dim: 40 24 | visual: 25 | input_dim: 224 26 | 27 | UCF101: 28 | classes: 101 29 | validation: false 30 | stat_path : None 31 | train_txt : None 32 | test_txt : None 33 | visual_path : None 34 | flow_path_v : None 35 | flow_path_u : None 36 | flow: 37 | input_dim: 224 38 | visual: 39 | input_dim: 224 40 | 41 | 42 | Balance: 43 | csv_path: None 44 | visual_path: None 45 | audio_path: None 46 | validation: false 47 | classes: 30 48 | audio: 49 | input_dim: None 50 | visual: 51 | input_dim: None 52 | 53 | CREMAD: 54 | classes: 6 55 | fps: 2 56 | visual_path : None 57 | audio_path : None 58 | stat_path : None 59 | train_txt : None 60 | test_txt : None 61 | validation: false 62 | audio: 63 | input_dim: None 64 | visual: 65 | input_dim: None 66 | 67 | KineticsSounds: 68 | csv_path_train: None 69 | visual_path_train: None 70 | audio_path_train: None 71 | csv_path_test: None 72 | visual_path_test: None 73 | audio_path_test: None 74 | validation: false 75 | classes: 31 76 | audio: 77 | input_dim: None 78 | visual: 79 | input_dim: None 80 | VGGSound: 81 | validation: false 82 | classes: 309 83 | use_video_frames: 3 84 | 
csv_root: None 85 | video_train_root: None 86 | video_test_root: None 87 | audio_train_root: None 88 | audio_test_root: None 89 | audio: 90 | input_dim: None 91 | visual: 92 | input_dim: None 93 | -------------------------------------------------------------------------------- /configs/encoder_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | ViT_B: 3 | if_pretrain: False 4 | pretrain_path: None 5 | audio: 6 | tpye: None 7 | visual: 8 | tpye: None 9 | text: 10 | tpye: None 11 | 12 | Transformer_LA: 13 | if_pretrain: False 14 | pretrain_path: None 15 | output_dim: 1000 16 | audio: 17 | layer: 6 18 | hidden_size: 512 19 | dropout_r: 0.1 20 | multi_head: 4 21 | ff_size: 1024 22 | seq_len: 60 23 | modality: "audio" 24 | visual: 25 | layer: 6 26 | hidden_size: 512 27 | dropout_r: 0.1 28 | multi_head: 4 29 | ff_size: 1024 30 | seq_len: 60 31 | modality: "visual" 32 | text: 33 | layer: 6 34 | hidden_size: 512 35 | dropout_r: 0.1 36 | multi_head: 4 37 | ff_size: 1024 38 | seq_len: 60 39 | modality: "text" 40 | Transformer: 41 | if_pretrain: True 42 | pretrain_path: None 43 | output_dim: 1024 44 | audio: 45 | n_features: 512 46 | dim: 1024 47 | n_head: 4 48 | n_layers: 2 49 | visual: 50 | n_features: 512 51 | dim: 1024 52 | n_head: 4 53 | n_layers: 2 54 | text: 55 | n_features: 512 56 | dim: 1024 57 | n_head: 4 58 | n_layers: 2 59 | Transformer_: 60 | if_pretrain: False 61 | pretrain_path: None 62 | output_dim: 1024 63 | audio: 64 | dim: 1024 65 | n_head: 4 66 | n_layers: 2 67 | visual: 68 | dim: 1024 69 | n_head: 4 70 | n_layers: 2 71 | text: 72 | dim: 1024 73 | n_head: 4 74 | n_layers: 2 75 | ResNet18: 76 | output_dim: 512 77 | if_pretrain: False 78 | pretrain_path: '/data/users/shaoxuan_xu/3_OGM/Pretrained_model/resnet18.pth' 79 | audio: 80 | modality: 'audio' 81 | visual: 82 | modality: 'visual' 83 | flow: 84 | modality: 'flow' 85 | front_view: 86 | modality: 'front_view' 87 | back_view: 88 | modality: 'back_view' -------------------------------------------------------------------------------- /configs/global_config.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | 3 | dataloader: 4 | eff_batch_size: 32 5 | num_workers: 32 # number of workers for dataloader 6 | fast_run: false # if true, only use part of the dataset 7 | shuffle: true # if true, shuffle the dataset 8 | drop_last: true # if true, drop the last batch if the size is not equal to batch_size 9 | pin_memory: true # if true, use pin_memory for dataloader 10 | 11 | fabric: 12 | accelerator: "gpu" # "cpu", "gpu", "cuda", "mps", "tpu" 13 | devices: [1] # number of devices or list of indexs 14 | precision: "32-true" # "32-true", "16-mixed", "bf16-mixed", etc. 15 | strategy: "dp" # "dp", "ddp", "ddp2", "ddp_spawn" 16 | 17 | log: 18 | logger_name: ['tensorboard'] # list of logger name 19 | wandb_name: '' 20 | log_interval: 100 # how many batches to wait before logging training status 21 | log_per_epoch: -1 # number of log for each epoch. If set will override log_interval 22 | comment: '' # comment for the logger 23 | 24 | 25 | train: 26 | parameter: 27 | total_epoch: 30 # number of epochs to train 28 | warmup: 10 # warmup beta and lr 29 | base_lr: 0.001 # base learning rate. Using linear law, lr = eff_batch_size/base_bs*base_lr 30 | lr: -1 # learning rate. 
If set, override base_lr, set to -1 to use base_lr 31 | base_batch_size: 2 # base batch_size 32 | lr_scaling_rule: "linear" # "sqrt", "linear" learning rate should scale with batch size 33 | checkpoint: 34 | resume: '' # path to the checkpoint 35 | checkpoint_frequency: 1 # number of epoch interval to save the checkpoint 36 | num_checkpoint_keep: 3 # set to -1 to save all checkpoints 37 | # optimizer: 38 | # type: "SGD" # type of optimizer 39 | # momentum: 0.9 40 | # weight_decay: 0.0001 41 | optimizer: 42 | type: "AdamW" # type of optimizer 43 | # momentum: 0.9 44 | weight_decay: 0.01 45 | scheduler: 46 | type: "StepLR" # type of scheduler 47 | step_size: 10 # step size for StepLR 48 | gamma: 0.1 # gamma for StepLR 49 | validation: 50 | frequency: 1 # number of epoch interval to do validation 51 | select_best: 'last' # 'last', 'best' select the best model based on validation loss 52 | loss: 53 | type: "CrossEntropyLoss" # type of loss 54 | 55 | -------------------------------------------------------------------------------- /configs/model_config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | BaseClassifier: 3 | encoders: {audio: Transformer_, visual: Transformer_, text: Transformer_} 4 | fusion: "concat" 5 | modality_size: 1024 6 | BaseClassifier_AMCo: 7 | encoders: {audio: ResNet18, visual: ResNet18} 8 | fusion: "concat" 9 | modality_size: 512 10 | BaseClassifier_Greedy: 11 | encoders: {audio: ResNet18, visual: ResNet18} 12 | fusion: "concat" 13 | modality_size: 512 14 | BaseClassifier_MLA: 15 | encoders: {audio: ResNet18, visual: ResNet18} 16 | fusion: "shared" 17 | modality_size: 1024 18 | # If use Transformer as encoder, set the modality_size to 1024. If use ResNet18 as encoder, set the modality_size to 512. 
19 | BaseClassifier_ReconBoost: 20 | encoders: {audio: ResNet18, visual: ResNet18} 21 | fusion: "shared" 22 | modality_size: 1024 23 | # expand other configuration of fusion methods 24 | fusion: 25 | concat: 26 | out_put_dim: 100 27 | shared: 28 | out_put_dim: 100 -------------------------------------------------------------------------------- /configs/trainer_config.yaml: -------------------------------------------------------------------------------- 1 | trainer_para: 2 | base: 3 | grad_accum_steps: 1 4 | max_epochs: 30 5 | baseline: 6 | modulation_starts: 0 7 | modulation_ends: 80 8 | modality: 2 9 | OGM: 10 | alpha: 0.3 11 | method: 'OGM_GE' 12 | modulation_starts: 10 # 13 | modulation_ends: 80 14 | AGM: 15 | alpha: 0.1 16 | method: "None" 17 | modulation_starts: 0 18 | modulation_ends: 80 19 | modality: 2 20 | AMCo: 21 | alpha: 0.1 22 | method: "None" 23 | modulation_starts: 0 24 | modulation_ends: 80 25 | sigma: 0.5 26 | U: 512 27 | eps: 0.3 28 | modality: 2 29 | CML: 30 | modulation_starts: 0 31 | modulation_ends: 80 32 | lam: 0.1 33 | modality: 2 34 | GBlending: 35 | method: "online" 36 | modulation_starts: 0 37 | modulation_ends: 80 38 | super_epoch: 10 39 | modality: 2 40 | PMR: 41 | alpha: 0.6 42 | modulation_starts: 0 43 | modulation_ends: 80 44 | modality: 2 45 | momentum_coef: 0.5 46 | eta: 0.01 47 | MBSD: 48 | modulation_starts: 0 49 | modulation_ends: 80 50 | modality: 2 51 | MMCosine: 52 | modulation_starts: 0 53 | modulation_ends: 80 54 | modality: 2 55 | scaling: 15 56 | Greedy: 57 | alpha: 0.001 58 | modulation_starts: 0 59 | modulation_ends: 80 60 | modality: 2 61 | window_size: 5 62 | UMT: 63 | alpha: 1 64 | modulation_starts: 0 65 | modulation_ends: 80 66 | scaling: 10 67 | MLA: 68 | modulation_starts: 0 69 | modulation_ends: 80 70 | ReconBoost: 71 | alpha: 0.5 72 | modulation_starts: 0 73 | modulation_ends: 80 74 | T_epochs: 1 75 | weight1: 5 76 | weight2: 1 77 | MMPareto: 78 | alpha: 1.5 79 | method: 'None' 80 | modulation_starts: 0 81 | modulation_ends: 80 82 | OPM: 83 | alpha: 0.5 84 | p_exe: 0.7 85 | q_base: 0.4 86 | method: 'None' 87 | modulation_starts: 0 88 | modulation_ends: 80 89 | Sample: 90 | alpha: 0.5 91 | # method: 'Sample-level' 92 | method: 'Modality-level' 93 | modulation_starts: 1 94 | modulation_ends: 80 95 | part_ratio: 0.2 96 | ReLearning: 97 | move_lambda: 3 98 | method: None 99 | modulation_starts: 0 100 | modulation_ends: 80 101 | reinit_epoch: 15 102 | reinit_num: 2 103 | LinearProbe: 104 | alpha: 0.3 105 | method: 'None' 106 | modulation_starts: 0 107 | modulation_ends: 80 108 | modality: 2 109 | trainer_probed: "OGMTrainer" 110 | LFM: 111 | alpha: 1 112 | method: 'learning-fitted' 113 | modulation_starts: 0 114 | modulation_ends: 80 115 | modality: 2 116 | trainer_probed: "OGMTrainer" 117 | lr_alpha: 0.0001 118 | -------------------------------------------------------------------------------- /configs/user_default.yaml: -------------------------------------------------------------------------------- 1 | mode: "train_and_test" 2 | check_point_path: '' 3 | 4 | seed: 42 5 | Main_config: 6 | model: "BaseClassifier" # 7 | model: "BaseClassifier" # 8 | tasks: "Classifier" # 9 | trainer: "LinearProbeTrainer" 10 | device: '1' 11 | dataset: CREMAD 12 | 13 | Train: 14 | dataset: CREMAD 15 | 16 | 17 | Test: 18 | dataset: CREMAD 19 | 20 | -------------------------------------------------------------------------------- /environment: -------------------------------------------------------------------------------- 1 | # conda create -n balancemm 
python=3.10 2 | # pip install torch==1.12.1+cu113 3 | 4 | # pip install lightning==2.0.0 5 | 6 | # pip install lightning-cloud==0.5.68 7 | # pip install lightning-utilities==0.11.2 8 | # -------------------------------------------------------------------------------- /images/Algorithms.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeWu-Lab/BalanceBenchmark/4c1482fca5aaa8d4278af8e26020daa039e3ccf3/images/Algorithms.jpeg -------------------------------------------------------------------------------- /images/Results.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeWu-Lab/BalanceBenchmark/4c1482fca5aaa8d4278af8e26020daa039e3ccf3/images/Results.jpeg -------------------------------------------------------------------------------- /images/frame6_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeWu-Lab/BalanceBenchmark/4c1482fca5aaa8d4278af8e26020daa039e3ccf3/images/frame6_00.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiohttp==3.9.5 3 | aiosignal==1.3.1 4 | anyio==4.3.0 5 | arrow==1.3.0 6 | async-timeout==4.0.3 7 | attrs==23.2.0 8 | audioread==3.0.1 9 | beautifulsoup4==4.12.3 10 | blessed==1.20.0 11 | boto3==1.34.100 12 | botocore==1.34.100 13 | cachetools==5.3.3 14 | cffi==1.16.0 15 | click==8.1.7 16 | colorama==0.4.6 17 | contourpy==1.3.1 18 | croniter==1.3.15 19 | cycler==0.12.1 20 | dateutils==0.6.12 21 | decorator==5.1.1 22 | deepdiff==7.0.1 23 | dynaconf==3.2.5 24 | editor==1.6.6 25 | exceptiongroup==1.2.1 26 | fastapi==0.88.0 27 | filelock==3.14.0 28 | fonttools==4.55.3 29 | frozenlist==1.4.1 30 | fsspec==2023.12.2 31 | grpcio==1.66.2 32 | h11==0.14.0 33 | h5py==3.11.0 34 | huggingface-hub==0.25.0 35 | inquirer==3.2.4 36 | itsdangerous==2.2.0 37 | jinja2==3.1.4 38 | jmespath==1.0.1 39 | joblib==1.4.2 40 | kiwisolver==1.4.7 41 | lazy-loader==0.4 42 | librosa==0.10.1 43 | llvmlite==0.42.0 44 | loguru==0.7.2 45 | markdown==3.7 46 | markdown-it-py==3.0.0 47 | matplotlib==3.10.0 48 | mdurl==0.1.2 49 | memory-profiler==0.61.0 50 | msgpack==1.0.8 51 | multidict==6.0.5 52 | networkx==3.3 53 | numba==0.59.1 54 | ordered-set==4.1.0 55 | packaging==24.0 56 | pandas==2.2.2 57 | platformdirs==4.2.1 58 | pooch==1.8.1 59 | pretty-errors==1.2.25 60 | protobuf==5.26.1 61 | psutil==5.9.8 62 | pycparser==2.22 63 | pydantic==1.10.15 64 | pygments==2.18.0 65 | pyjwt==2.8.0 66 | pyparsing==3.2.0 67 | python-dateutil==2.9.0.post0 68 | python-multipart==0.0.9 69 | pytz==2024.1 70 | pyyaml==6.0.1 71 | readchar==4.0.6 72 | regex==2024.9.11 73 | rich==13.7.1 74 | runs==1.2.2 75 | s3transfer==0.10.1 76 | safetensors==0.4.5 77 | scikit-learn==1.4.2 78 | scipy==1.13.0 79 | six==1.16.0 80 | sniffio==1.3.1 81 | soundfile==0.12.1 82 | soupsieve==2.5 83 | soxr==0.3.7 84 | starlette==0.22.0 85 | starsessions==1.3.0 86 | tensorboard==2.18.0 87 | tensorboard-data-server==0.7.2 88 | tensorboardx==2.6.2.2 89 | termcolor==2.4.0 90 | thop==0.1.1-2209072238 91 | threadpoolctl==3.5.0 92 | tokenizers==0.19.1 93 | tqdm==4.66.4 94 | traitlets==5.14.3 95 | transformers==4.40.2 96 | triton==2.3.0 97 | types-python-dateutil==2.9.0.20240316 98 | tzdata==2024.1 99 | uvicorn==0.29.0 100 | wcwidth==0.2.13 101 | websocket-client==1.8.0 102 | websockets==11.0.3 
103 | werkzeug==3.0.4 104 | xmod==1.8.1 105 | yarl==1.9.4 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Always prefer setuptools over distutils 2 | from setuptools import setup, find_packages 3 | import pathlib 4 | 5 | here = pathlib.Path(__file__).parent.resolve() 6 | 7 | # Get the long description from the README file 8 | long_description = (here / 'README.md').read_text(encoding='utf-8') 9 | 10 | setup( 11 | name='balancemm', 12 | version='1.0', 13 | description='A Benchmark for balanced multimodal learning.', 14 | long_description=long_description, 15 | long_description_content_type='text/markdown', 16 | url=' ', 17 | author=' ', 18 | author_email=' ', 19 | classifiers=[ 20 | "License :: OSI Approved :: MIT License", 21 | "Development Status :: 3 - Alpha", 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 3.10", 24 | "Operating System :: POSIX :: Linux", 25 | "Operating System :: Microsoft :: Windows", 26 | ], 27 | keywords='multimodal, benchmark, balanced learning, classification, deep learning', 28 | packages=find_packages(), 29 | install_requires=[ 30 | 'pyyaml', 31 | 'docstring-parser', 32 | 'lightning', 33 | 'torch', 34 | 'torchvision', 35 | 'torchaudio', 36 | 'xformers', 37 | ], 38 | ) --------------------------------------------------------------------------------
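To illustrate how the `optimizer`, `scheduler`, and `train.parameter` entries in `configs/global_config.yaml` map onto the helpers in `balancemm/utils/optimizer.py` and `balancemm/utils/scheduler.py`, here is a minimal usage sketch. It assumes the package has been installed (e.g. `pip install -e .`) so that `balancemm` is importable; the `nn.Linear` model is only a stand-in for a real multimodal classifier.

```python
import torch.nn as nn
from balancemm.utils.optimizer import create_optimizer
from balancemm.utils.scheduler import create_scheduler

# Toy model standing in for a multimodal classifier (illustration only).
model = nn.Linear(512, 6)

# Mirrors the optimizer / scheduler / train.parameter sections of configs/global_config.yaml.
optimizer_args = {"type": "AdamW", "weight_decay": 0.01}
scheduler_args = {"type": "StepLR", "step_size": 10, "gamma": 0.1}
parameter = {"base_lr": 0.001, "lr": -1}  # lr == -1 falls back to base_lr

optimizer = create_optimizer(model, optimizer_args, parameter)  # AdamW with lr=0.001, weight_decay=0.01
scheduler = create_scheduler(optimizer, scheduler_args)         # StepLR with step_size=10, gamma=0.1
```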
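The command-line helpers in `balancemm/utils/parser_utils.py` can be exercised in isolation in the same way. The sketch below is only illustrative: the dotted override keys shown (`trainer.name`, `train.parameter.lr`) are hypothetical examples of the syntax, not a guarantee of which options the toolkit actually reads, and every parsed value is a plain string.

```python
from balancemm.utils.parser_utils import parse_cli_args_to_dict, ensure_and_get_config_path

# Dotted keys are merged into nested dictionaries; values stay strings.
cli = ["--trainer.name", "OGM", "--train.parameter.lr=0.0005"]
overrides = parse_cli_args_to_dict(cli, {})
print(overrides)
# {'trainer': {'name': 'OGM'}, 'train': {'parameter': {'lr': '0.0005'}}}

# If neither --config nor -c is present, the default path is returned
# and appended to the argument list as '--config=<path>'.
args = ["--trainer.name=OGM"]
config_path = ensure_and_get_config_path(args, "configs/user_default.yaml")
print(config_path)  # configs/user_default.yaml
print(args)         # ['--trainer.name=OGM', '--config=configs/user_default.yaml']
```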
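Finally, `find_module` in `parser_utils.py` resolves classes by concatenating a name with a type suffix and searching a list of modules. The snippet below uses a `SimpleNamespace` as a stand-in for an imported module, purely to show the lookup rule; the class names used here (`OGMTrainer`, `MLATrainer`) are hypothetical placeholders and the real trainer modules under `balancemm/trainer/` may expose their classes differently.

```python
from types import SimpleNamespace
from balancemm.utils.parser_utils import find_module

# Stand-in for an imported module; real usage would pass e.g. the trainer package modules.
class OGMTrainer: ...
class MLATrainer: ...
fake_trainer_module = SimpleNamespace(OGMTrainer=OGMTrainer, MLATrainer=MLATrainer)

trainer_cls = find_module([fake_trainer_module], "OGM", module_type="Trainer")
print(trainer_cls.__name__)  # OGMTrainer

# An unknown name raises a descriptive error:
# find_module([fake_trainer_module], "Unknown", module_type="Trainer")  -> ValueError
```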