├── .gitignore
├── README.md
├── balancemm
│   ├── __main__.py
│   ├── analysis
│   │   └── t_sne.py
│   ├── datasets
│   │   ├── KS_dataset.py
│   │   ├── Mosei_dataset.py
│   │   ├── VGG_dataset.py
│   │   ├── __init__.py
│   │   ├── balance_dataset.py
│   │   ├── cremad_dataset.py
│   │   ├── food_dataset.py
│   │   └── ucf101_dataset.py
│   ├── encoders
│   │   ├── VisionTransformer_encoder.py
│   │   ├── __init__.py
│   │   ├── pretrained_encoder.py
│   │   └── resnet18_encoder.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── complex.py
│   │   ├── modalitys.py
│   │   └── precisions.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── avclassify_model.py
│   │   ├── encoders.py
│   │   ├── fusion_arch.py
│   │   └── resnet_arch.py
│   ├── test.py
│   ├── train.py
│   ├── trainer
│   │   ├── AGM_trainer.py
│   │   ├── AMCo_trainer.py
│   │   ├── CML_trainer.py
│   │   ├── GBlending_trainer.py
│   │   ├── Greedy_trainer.py
│   │   ├── LFM_trainer.py
│   │   ├── LinearProbe_trainer.py
│   │   ├── MBSD_trainer.py
│   │   ├── MLA_trainer.py
│   │   ├── MMCosine_trainer.py
│   │   ├── MMPareto_trainer.py
│   │   ├── OGM_trainer.py
│   │   ├── OPM_trainer.py
│   │   ├── PMR_trainer.py
│   │   ├── ReLearning_trainer.py
│   │   ├── ReconBoost_trainer.py
│   │   ├── Sample_trainer.py
│   │   ├── UMT_trainer.py
│   │   ├── __init__.py
│   │   ├── base_trainer.py
│   │   ├── baseline_trainer.py
│   │   └── unimodal_trainer.py
│   └── utils
│       ├── data_utils.py
│       ├── encoder_module.py
│       ├── logger.py
│       ├── optimizer.py
│       ├── parser_utils.py
│       ├── scheduler.py
│       └── train_utils.py
├── configs
│   ├── dataset_config.yaml
│   ├── encoder_config.yaml
│   ├── global_config.yaml
│   ├── model_config.yaml
│   ├── trainer_config.yaml
│   └── user_default.yaml
├── environment
├── images
│   ├── Algorithms.jpeg
│   ├── Results.jpeg
│   └── frame6_00.png
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore experiments
2 | experiments/
3 | # ignore python cache
4 | __pycache__/
5 | *.out
6 | # ignore temps
7 | temps/
8 | training_*
9 | result*
10 | # ignore setup
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BalanceBenchmark: A Survey for Multimodal Imbalance Learning
2 |
3 | ## Paper
4 | [**BalanceBenchmark: A Survey for Multimodal Imbalance Learning**](http://arxiv.org/abs/2502.10816)
5 | Shaoxuan Xu, Menglu Cui, Chengxiang Huang, Hongfa Wang and Di Hu
6 |
7 | If you find this repository useful, please cite our paper and corresponding toolkit:
8 | ```bibtex
9 | @misc{xu2025balancebenchmarksurveymultimodalimbalancelearning,
10 | title={BalanceBenchmark: A Survey for Multimodal Imbalance Learning},
11 | author={Shaoxuan Xu and Menglu Cui and Chengxiang Huang and Hongfa Wang and Di Hu},
12 | year={2025},
13 | eprint={2502.10816},
14 | archivePrefix={arXiv},
15 | primaryClass={cs.LG},
16 | url={https://arxiv.org/abs/2502.10816},
17 | }
18 | ```
19 |
20 | ## Overview
21 | 
22 |
23 | Multimodal learning has gained attention for its capacity to integrate information from different modalities. However, it is often hindered by the multimodal imbalance problem, where certain modalities disproportionately dominate while others remain underutilized. Although recent studies have proposed various methods to alleviate this problem, they lack comprehensive and fair comparisons.
24 | To facilitate this field, we introduce BalanceBenchmark, a systematic and unified benchmark for evaluating multimodal imbalance learning methods. BalanceBenchmark spans 17 algorithms and 7 datasets, providing a comprehensive framework for method evaluation and comparison.
25 |
26 | To accompany BalanceBenchmark, we release **BalanceMM**, a standardized toolkit that implements 17 state-of-the-art approaches spanning four research directions: data-level adjustments, feed-forward modifications, objective adaptations, and optimization-based methods. The toolkit provides a standardized pipeline that unifies innovations in fusion paradigms, optimization objectives, and training approaches.
27 | Our toolkit simplifies the research workflow through:
28 |
29 | + Standardized data loading for 7 multimodal datasets
30 | + Unified implementation of various imbalance learning methods
31 | + Automated experimental pipeline from training to evaluation
32 | + Comprehensive metrics for assessing performance, imbalance degree, and complexity
33 |
34 | BalanceMM is designed with modularity and extensibility in mind, enabling easy integration of new methods and datasets. It provides researchers with the necessary tools to reproduce experiments, conduct fair comparisons, and develop new approaches for addressing the multimodal imbalance problem.
35 | ## Datasets currently supported
36 | + Audio-Visual: KineticsSounds, CREMA-D, BalancedAV, VGGSound
37 | + RGB-Optical Flow: UCF-101
38 | + Image-Text: FOOD-101
39 | + Audio-Visual-Text: CMU-MOSEI
40 |
41 | To add a new dataset:
42 |
43 | 1. Go to `balancemm/datasets/`
44 | 2. Create a new Python file and a new dataset class
45 | 3. Implement the required data loading and preprocessing methods in the new `*_dataset.py` file (see the sketch after this list)
46 | 4. Add configuration file in `balancemm/configs/dataset_config.yaml`
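
A minimal sketch of such a dataset class is shown below. It is illustrative only: the class name, modality keys, and tensor shapes are placeholders, so follow an existing file such as `cremad_dataset.py` for the exact conventions. The one requirement visible in the existing datasets is that `__getitem__` returns a dict with one entry per modality plus `label` and `idx`.

```
# balancemm/datasets/my_dataset.py -- illustrative skeleton, not part of the toolkit
import torch
from torch.utils.data import Dataset

class MyNewDataset(Dataset):
    def __init__(self, args: dict, transforms=None):
        self.mode = args['mode']        # 'train' or 'test', as in the existing datasets
        self.data, self.label = [], []  # fill these from your annotation files
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Load and preprocess each modality here. The existing datasets return a dict
        # with one key per modality plus 'label' and 'idx'.
        visual = torch.zeros(3, 3, 224, 224)   # placeholder (C, T, H, W) clip
        audio = torch.zeros(1, 128, 128)       # placeholder spectrogram
        return {'visual': visual, 'audio': audio, 'label': self.label[idx], 'idx': idx}
```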
47 |
48 | ## Algorithms currently supported
49 | + Data-level methods: Modality-valuation
50 | + Feed-forward methods: MLA, OPM, Greedy, AMCo
51 | + Objective methods: MMCosine, UMT, MBSD, CML, MMPareto, GBlending, LFM
52 | + Optimization methods: OGM, AGM, PMR, Relearning, ReconBoost
53 |
54 | See Section 3 in our paper for detailed descriptions of each method.
55 |
56 | 
57 |
58 | To add a new method:
59 |
60 | 1. Determine which category your method belongs to:
61 | + "Data" : methods that adjust data processing
62 | + "Feed-forward" : methods that modify network architecture
63 | + "Objective" : methods that adapt learning objectives
64 | + "Optimization" : methods that adjust optimization process
65 | 2. Go to `balancemm/trainer/`
66 | 3. Create a new Python file implementing your method
67 | 4. Implement the corresponding `*_trainer.py` file based on `base_trainer.py`; in most cases you only need to override `trainer.training_step` (see the sketch after this list).
68 | 5. Make any additional changes required by your method's category:
69 | + If your method belongs to "Data", go to `balancemm/datasets/__init__.py` and modify it accordingly.
70 | + If your method belongs to "Feed-forward", go to `balancemm/models/avclassify_model.py`, create a new model class and override the relevant functions.
71 | + If your method belongs to "Objective", you usually only need to modify the trainer.
72 | + If your method belongs to "Optimization", you may need to modify the trainer or any of the parts mentioned above.
73 | + You can also modify any combination of the parts mentioned above, depending on your method.
74 | 6. Add configuration file in `balancemm/configs/trainer_config.yaml`
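
A minimal trainer sketch is given below. It assumes that `base_trainer.py` exposes a `BaseTrainer` class and that `training_step` receives the model and a batch; the actual class name, hook signature, and model output format may differ, so verify them against `base_trainer.py` and an existing trainer such as `OGM_trainer.py` before copying this.

```
# balancemm/trainer/MyMethod_trainer.py -- illustrative sketch only
import torch.nn.functional as F
from .base_trainer import BaseTrainer   # assumed class name; verify in base_trainer.py

class MyMethodTrainer(BaseTrainer):
    def training_step(self, model, batch, batch_idx):
        # Assumed hook: run the multimodal forward pass, compute the task loss,
        # add your method-specific balancing term, and return the total loss so
        # the base training loop can backpropagate it.
        out = model(batch)
        loss = F.cross_entropy(out['output'], batch['label'])  # 'output' key is an assumption
        balance_term = 0.0                                      # your method-specific term
        return loss + balance_term
```
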
75 | ## Installation
76 | ```
77 | git clone https://github.com/GeWu-Lab/BalanceBenchmark.git
78 | cd BalanceBenchmark
79 | conda create -n balancemm python=3.10
80 | conda activate balancemm
81 | pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
82 | pip install -r requirements.txt
83 | pip install lightning==2.0.0
84 | pip install lightning-cloud==0.5.68
85 | pip install lightning-utilities==0.11.2
86 | ```
87 | ## Experiment
88 | To run experiments, you'll need to download the datasets from their open-sourced links. After downloading, place the datasets in your preferred directory and update the dataset path in your configuration file.
89 |
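The exact schema is defined in the dataset config file (`dataset_config.yaml`). As an illustration only, an entry for KineticsSounds would need to provide the path keys that `KS_dataset.py` reads; the surrounding YAML structure below is an assumption, so mirror the existing entries in your copy of the config:

```
# illustrative only; mirror the real structure of dataset_config.yaml
KineticsSounds:
  csv_path_train: /path/to/ks/train_list.txt
  csv_path_test: /path/to/ks/test_list.txt
  audio_path_train: /path/to/ks/audio/train
  audio_path_test: /path/to/ks/audio/test
  visual_path_train: /path/to/ks/visual/train
  visual_path_test: /path/to/ks/visual/test
```
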
90 | You can run any experiment using a single command line:
91 | ```
92 | python -m balancemm \
93 | --trainer [trainer_name] \
94 | --dataset [dataset_name] \
95 | --model [model_name] \
96 | --hyper-params [param_file.yaml] \
97 | --device [0/cpu]
98 | ```
99 | For example, to run OGM on CREMA-D dataset:
100 | ```
101 | python -m balancemm \
102 | --trainer OGM \
103 | --dataset CREMAD \
104 | --model BaseClassifier \
105 | --alpha 0.5 \
106 | --device 0
107 | ```
108 | ## Results
109 | We have conducted comprehensive experiments using the proposed BalanceBenchmark on 7 datasets. The results indicate that almost all related methods outperform the Baseline in terms of accuracy and F1 score, demonstrating that the multimodal imbalance problem is prevalent across various scenarios.
110 |
111 | 
112 |
113 |
--------------------------------------------------------------------------------
/balancemm/analysis/t_sne.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.manifold import TSNE
3 | import matplotlib.pyplot as plt
4 | from matplotlib.colors import ListedColormap
5 |
6 | def visualize_tsne_2d(datas, perplexity=30, n_components=2, random_state=42, save_dir = 't-sne', modality = '', n_class = 10):
7 | """
8 | use t-SNE to do dimensionality reduction visualization
9 |
10 | """
11 | data = []
12 | labels = []
13 | for i in range(len(datas)):
14 | data.append(datas[i][0].cpu().detach().numpy())
15 | labels.append(datas[i][1].cpu().numpy())
16 | data = np.array(data)
17 | labels = np.array(labels)
18 | colors = plt.cm.rainbow(np.linspace(0, 1, n_class))
19 | custom_cmap = ListedColormap(colors)
20 |
21 | tsne = TSNE(
22 | n_components=n_components,
23 | perplexity=perplexity,
24 | random_state=random_state
25 | )
26 | tsne_results = tsne.fit_transform(data)
27 |
28 | plt.figure(figsize=(10, 8))
29 |
30 | if labels is not None:
31 | scatter = plt.scatter(
32 | tsne_results[:, 0],
33 | tsne_results[:, 1],
34 | c=labels,
35 | cmap=custom_cmap
36 | )
37 | plt.colorbar(scatter)
38 | else:
39 | plt.scatter(
40 | tsne_results[:, 0],
41 | tsne_results[:, 1],
42 | alpha=0.5
43 | )
44 |
45 | plt.title(f't-SNE Visualization {modality} {n_components}d')
46 | plt.xlabel('t-SNE 1')
47 | plt.ylabel('t-SNE 2')
48 | plt.savefig(save_dir)
49 | return plt.gcf()
50 |
51 | def visualize_tsne_3d(datas, perplexity=30, n_components=3, random_state=42, save_dir='t-sne', modality='',n_class = 10):
52 | """
53 | use t-SNE to do 3D visualization
54 |
55 | """
56 | # prepare data
57 | data = []
58 | labels = []
59 | for i in range(len(datas)):
60 | data.append(datas[i][0].cpu().detach().numpy())
61 | labels.append(datas[i][1].cpu().numpy())
62 | data = np.array(data)
63 | labels = np.array(labels)
64 | colors = plt.cm.rainbow(np.linspace(0, 1, n_class))
65 | custom_cmap = ListedColormap(colors)
66 | # t-SNE dimension reduction
67 | tsne = TSNE(
68 | n_components=n_components,
69 | perplexity=perplexity,
70 | random_state=random_state
71 | )
72 | tsne_results = tsne.fit_transform(data)
73 |
74 | fig = plt.figure(figsize=(12, 10))
75 | ax = fig.add_subplot(111, projection='3d')
76 |
77 | scatter = ax.scatter(
78 | tsne_results[:, 0],
79 | tsne_results[:, 1],
80 | tsne_results[:, 2],
81 | c=labels,
82 | cmap=custom_cmap,
83 | alpha=0.6
84 | )
85 |
86 | plt.colorbar(scatter)
87 | ax.set_title(f't-SNE Visualization {modality} {n_components}d')
88 | ax.set_xlabel('t-SNE 1')
89 | ax.set_ylabel('t-SNE 2')
90 | ax.set_zlabel('t-SNE 3')
91 |
92 | ax.grid(True)
93 |
94 | plt.savefig(save_dir)
95 |
96 | for angle in range(0, 360, 45):
97 | ax.view_init(30, angle)
98 | plt.savefig(f'{save_dir}_angle_{angle}_{modality}.jpg')
99 |
100 | return fig
101 |
--------------------------------------------------------------------------------
/balancemm/datasets/KS_dataset.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import math
3 | import os
4 | import random
5 | import copy
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional
9 | import torchaudio
10 | from PIL import Image
11 | from scipy import signal
12 | from torch.utils.data import Dataset
13 | from torchvision import transforms
14 |
15 | class KineticsSoundsDataset(Dataset):
16 | def __init__(self, args:dict, transforms=None):
17 | self.data = []
18 | self.label = []
19 | self.mode = args['mode']
20 | if self.mode == "train":
21 | self.csv_path = args['csv_path_train']
22 | self.audio_path = args['audio_path_train']
23 | self.visual_path = args['visual_path_train']
24 | else:
25 | self.csv_path = args['csv_path_test']
26 | self.audio_path = args['audio_path_test']
27 | self.visual_path = args['visual_path_test']
28 |
29 |
30 | with open(self.csv_path) as f:
31 | for line in f:
32 | item = line.split("\n")[0].split(" ")
33 | name = item[0]
34 |
35 | if os.path.exists(self.audio_path + '/' + name + '.npy'):
36 | path = self.visual_path + '/' + name
37 | files_list=[lists for lists in os.listdir(path)]
38 | if(len(files_list)>3):
39 | self.data.append(name)
40 | self.label.append(int(item[-1]))
41 |
42 | print('data load finish')
43 | self.transforms = transforms
44 |
45 | self._init_atransform()
46 |
47 | print('# of files = %d ' % len(self.data))
48 |
49 | def _init_atransform(self):
50 | self.aid_transform = transforms.Compose([transforms.ToTensor()])
51 |
52 | def __len__(self):
53 | return len(self.data)
54 |
55 | def __getitem__(self, idx):
56 | av_file = self.data[idx]
57 |
58 | spectrogram = np.load(self.audio_path + '/' + av_file + '.npy')
59 | spectrogram = np.expand_dims(spectrogram, axis=0)
60 |
61 | # Visual
62 | path = self.visual_path + '/' + av_file
63 | files_list=[lists for lists in os.listdir(path)]
64 | file_num = len([fn for fn in files_list if fn.endswith("jpg")])
65 | if self.mode == 'train':
66 | transf = transforms.Compose([
67 | transforms.RandomResizedCrop(224),
68 | transforms.RandomHorizontalFlip(),
69 | transforms.ToTensor(),
70 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
71 | ])
72 | else:
73 | transf = transforms.Compose([
74 | transforms.Resize(size=(224, 224)),
75 | transforms.ToTensor(),
76 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
77 | ])
78 |
79 | pick_num = 3
80 | seg = int(file_num / pick_num)
81 | path1 = []
82 | image = []
83 | image_arr = []
84 | t = [0] * pick_num
85 |
86 | for i in range(pick_num):
87 | if self.mode == 'train':
88 | t[i] = random.randint(i * seg + 1, i * seg + seg) if file_num > 6 else 1
89 | if t[i] >= 10:
90 | t[i] = 9
91 | else:
92 | t[i] = i*seg + max(int(seg/2), 1) if file_num > 6 else 1
93 |
94 | path1.append('frame_0000' + str(t[i]) + '.jpg')
95 | image.append(Image.open(path + "/" + path1[i]).convert('RGB'))
96 |
97 | image_arr.append(transf(image[i]))
98 | image_arr[i] = image_arr[i].unsqueeze(1).float()
99 |
100 | if i == 0:
101 | image_n = copy.copy(image_arr[i])
102 | else:
103 | image_n = torch.cat((image_n, image_arr[i]), 1)
104 |
105 |
106 | label = self.label[idx]
107 |
108 | return {'visual':image_n, 'audio':spectrogram, 'label': label,'idx': idx}
--------------------------------------------------------------------------------
/balancemm/datasets/Mosei_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data.dataset import Dataset
3 | import pickle
4 | import os
5 | import torch
6 |
7 |
8 | def acc3(i):
9 | if i<-0.5:
10 | return 0
11 | if i>0.5:
12 | return 1
13 | return 2
14 |
15 | def acc7(i):
16 | if i < -2:
17 | res = 0
18 | if -2 <= i and i < -1:
19 | res = 1
20 | if -1 <= i and i < 0:
21 | res = 2
22 | if 0 <= i and i <= 0:
23 | res = 3
24 | if 0 < i and i <= 1:
25 | res = 4
26 | if 1 < i and i <= 2:
27 | res = 5
28 | if i > 2:
29 | res = 6
30 | return res
31 |
32 | def acc2(i):
33 | if i<0:
34 | return 0
35 | else :
36 | return 1
37 |
38 | class CMUMOSEIDataset(Dataset):
39 | def __init__(self, args: dict, transforms = None):
40 | super(CMUMOSEIDataset, self).__init__()
41 | dataset_path = args['dataset_path']
42 | data= args['data']
43 | split_type= args['mode']
44 | if_align = args['if_align']
45 | dataset_path = os.path.join(dataset_path, data+'_data.pkl' if if_align else data+'_data_noalign.pkl' )
46 | dataset = pickle.load(open(dataset_path, 'rb'))
47 |
48 | self.vision = torch.tensor(dataset[split_type]['vision'].astype(np.float32)).cpu().detach()
49 | self.text = torch.tensor(dataset[split_type]['text'].astype(np.float32)).cpu().detach()
50 | self.audio = dataset[split_type]['audio'].astype(np.float32)
51 | self.audio[self.audio == -np.inf] = 0
52 | self.audio = torch.tensor(self.audio).cpu().detach()
53 | self.labels = torch.tensor(dataset[split_type]['labels'].astype(np.float32)).cpu().detach()
54 |
55 | self.meta = dataset[split_type]['id'] if 'id' in dataset[split_type].keys() else None
56 |
57 | self.data = data
58 |
59 | self.n_modalities = 3 # vision/ text/ audio
60 | def get_n_modalities(self):
61 | return self.n_modalities
62 | def get_seq_len(self):
63 | return self.text.shape[1], self.audio.shape[1], self.vision.shape[1]
64 | def get_dim(self):
65 | return self.text.shape[2], self.audio.shape[2], self.vision.shape[2]
66 | def get_lbl_info(self):
67 | # return number_of_labels, label_dim
68 | return self.labels.shape[1], self.labels.shape[2]
69 | def __len__(self):
70 | return len(self.labels)
71 |
72 | def __getitem__(self, index):
73 | X = [index, self.text[index], self.audio[index], self.vision[index]]
74 | Y = self.labels[index]
75 |
76 | Y = acc2(Y[0,0])
77 | META = (0,0,0) if self.meta is None else (self.meta[index][0], self.meta[index][1], self.meta[index][2])
78 | if self.data == 'mosi':
79 | META = (self.meta[index][0].decode('UTF-8'), self.meta[index][1].decode('UTF-8'), self.meta[index][2].decode('UTF-8'))
80 | if self.data == 'iemocap':
81 | Y = torch.argmax(Y, dim=-1)
82 |
83 | return {'text' : X[1], 'visual': X[3], 'audio' : X[2], 'label':Y, 'idx': index}
84 |
--------------------------------------------------------------------------------
/balancemm/datasets/VGG_dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import csv
3 | import os
4 | import pickle
5 | import librosa
6 | import numpy as np
7 | from scipy import signal
8 | import torch
9 | from PIL import Image
10 | from torch.utils.data import Dataset
11 | from torchvision import transforms
12 | import pdb
13 | import random
14 |
15 | class VGGSoundDataset(Dataset):
16 |
17 | def __init__(self, args, mode='train'):
18 |
19 | self.args = args
20 | self.mode = args['mode']
21 | train_video_data = []
22 | train_audio_data = []
23 | test_video_data = []
24 | test_audio_data = []
25 | train_label = []
26 | test_label = []
27 | train_class = []
28 | test_class = []
29 | csv_root = args['csv_root']
30 | video_train_root = args['video_train_root']
31 | video_test_root = args['video_test_root']
32 | audio_train_root = args['audio_train_root']
33 | audio_test_root = args['audio_test_root']
34 |
35 | print(video_train_root)
36 | print(video_test_root)
37 | print(audio_train_root)
38 | print(audio_test_root)
39 | train_valid = 0
40 | test_valid = 0
41 | with open(csv_root) as f:
42 | csv_reader = csv.reader(f)
43 | for item in csv_reader:
44 | if item[3] == 'train':
45 |
46 | video_dir = os.path.join(video_train_root, item[0]+'_'+item[1])
47 | audio_dir = os.path.join(audio_train_root, item[0]+'_'+item[1] + '.npy')
48 |
49 | if os.path.exists(video_dir) and os.path.exists(audio_dir) and len(os.listdir(video_dir))>3 :
50 | train_video_data.append(video_dir)
51 | train_audio_data.append(audio_dir)
52 | if item[2] not in train_class: train_class.append(item[2])
53 | train_label.append(item[2])
54 | train_valid += 1
55 |
56 | if item[3] == 'test':
57 |
58 | video_dir = os.path.join(video_test_root, item[0]+'_'+item[1])
59 | audio_dir = os.path.join(audio_test_root, item[0]+'_'+item[1] + '.npy')
60 |
61 | if os.path.exists(video_dir) and os.path.exists(audio_dir) and len(os.listdir(video_dir))>3:
62 | test_video_data.append(video_dir)
63 | test_audio_data.append(audio_dir)
64 | if item[2] not in test_class: test_class.append(item[2])
65 | test_label.append(item[2])
66 | test_valid += 1
67 |
68 | print("Get Valid Train Sample: " + str(train_valid))
69 | print("Get Valid Test Sample: " + str(test_valid))
70 |
71 | assert len(train_class) == len(test_class)
72 |
73 | if len(train_class) == 0:
74 | raise ValueError("If you see this, it means you have problem in reading dataset")
75 |
76 | self.classes = train_class
77 |
78 | class_dict = dict(zip(self.classes, range(len(self.classes))))
79 |
80 | if self.mode == 'train':
81 | self.video = train_video_data
82 | self.audio = train_audio_data
83 | self.label = [class_dict[train_label[idx]] for idx in range(len(train_label))]
84 | else:
85 | self.video = test_video_data
86 | self.audio = test_audio_data
87 | self.label = [class_dict[test_label[idx]] for idx in range(len(test_label))]
88 |
89 |
90 | def __len__(self):
91 | return len(self.video)
92 |
93 | def __getitem__(self, idx):
94 |
95 | spectrogram = np.load(self.audio[idx])
96 |
97 | # Def Image Transform
98 | if self.mode == 'train':
99 | transform = transforms.Compose([
100 | transforms.RandomResizedCrop(224),
101 | transforms.RandomHorizontalFlip(),
102 | transforms.ToTensor(),
103 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
104 | ])
105 | else:
106 | transform = transforms.Compose([
107 | transforms.Resize(size=(224, 224)),
108 | transforms.ToTensor(),
109 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
110 | ])
111 |
112 | pick_num = self.args['use_video_frames']
113 | image_samples = os.listdir(self.video[idx])
114 | image_samples = sorted(image_samples)
115 | file_num = len(image_samples)
116 | select_index = np.random.choice(len(image_samples), size=pick_num, replace=False)
117 | select_index.sort()
118 | images = torch.zeros((pick_num, 3, 224, 224))
119 | t = [0] * pick_num
120 | seg = (file_num//pick_num)
121 | for i in range(pick_num):
122 | if self.mode == 'train':
123 | t[i] = random.randint(i * seg + 1, i * seg + seg) if file_num > 6 else 1
124 | if t[i] >= 10:
125 | t[i] = 9
126 | else:
127 | t[i] = i*seg + max(int(seg/2), 1) if file_num > 6 else 1
128 | for i, idx_frame in enumerate(select_index):
129 | img_path = os.path.join(self.video[idx], image_samples[t[i]-1])
130 | # img_path = os.path.join(self.video[idx], image_samples[min(i * seg + max(int(seg/2), 1), len(image_samples)-1)])
131 | img = Image.open(img_path).convert('RGB')
132 | img = transform(img)
133 | images[i] = img
134 |
135 | spectrogram = torch.tensor(spectrogram).unsqueeze(0).float()
136 | images = images.permute(1,0,2,3)
137 |
138 | # label
139 | label = self.label[idx]
140 |
141 | return {
142 | 'audio': spectrogram,
143 | 'visual': images,
144 | 'label': label,
145 | 'idx': idx
146 | }
147 |
148 |
--------------------------------------------------------------------------------
/balancemm/datasets/balance_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import h5py
3 | import os
4 | import pickle
5 | from PIL import Image
6 | import io
7 | import torch
8 | import csv
9 | from torchvision import transforms
10 | from torch.utils.data import Dataset
11 | import numpy as np
12 | import random
13 | import copy
14 |
15 | class BalanceDataset(Dataset):
16 |
17 | def __init__(self, args: dict, transforms=None):
18 | self.data = []
19 | self.label = []
20 | self.mode = None
21 | csv_path = args['csv_path']
22 | self.visual_path = args['visual_path']
23 | self.audio_path = args['audio_path']
24 |
25 | self.mode = args['mode']
26 |
27 |
28 | with open(csv_path) as f:
29 | annotation_data = json.load(f)
30 | all_data = annotation_data['database']
31 | # choose = ['playing piano', 'playing cello', 'lawn mowing', 'singing', 'cleaning floor', 'bowling', 'swimming', 'whistling', 'motorcycling', 'playing flute', 'writing on blackboard', 'beat boxing']
32 | class_labels = annotation_data['labels']
33 |
34 | self.class_to_idx = {label : i for i,label in enumerate(class_labels)}
35 | print(len(class_labels))
36 |
37 | for key in all_data.keys():
38 | if all_data[key]['subset'] == (self.mode + 'ing'):
39 | if os.path.exists(self.visual_path + key + '.hdf5') and os.path.exists(self.audio_path + key + '.pkl'):
40 | self.data.append(key)
41 | self.label.append(self.class_to_idx[all_data[key]['label']])
42 |
43 |
44 | print('data load finish')
45 |
46 | self.transforms = transforms
47 |
48 | self._init_atransform()
49 |
50 | print('# of files = %d ' % len(self.data))
51 |
52 | def _init_atransform(self):
53 | self.aid_transform = transforms.Compose([transforms.ToTensor()])
54 |
55 | def __len__(self):
56 | return len(self.data)
57 |
58 | def __getitem__(self, idx):
59 | av_file = self.data[idx]
60 |
61 | with open(self.audio_path + av_file + '.pkl',"rb") as f:
62 | spectrogram = pickle.load(f)
63 | spectrogram = np.expand_dims(spectrogram, axis=0)
64 |
65 | # Visual
66 | path = self.visual_path + av_file + '.hdf5'
67 | with h5py.File(path, 'r') as f:
68 | video_data = f['video']
69 | file_num = len(video_data)
70 |
71 | if self.mode == 'training':
72 |
73 | transf = transforms.Compose([
74 | transforms.RandomResizedCrop(224),
75 | transforms.RandomHorizontalFlip(),
76 | transforms.ToTensor(),
77 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
78 | ])
79 | else:
80 | transf = transforms.Compose([
81 | transforms.Resize(size=(224, 224)),
82 | transforms.ToTensor(),
83 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
84 | ])
85 |
86 | pick_num = 3
87 | seg = int(file_num / pick_num)
88 | image = []
89 | image_arr = []
90 | t = [0] * pick_num
91 |
92 | for i in range(pick_num):
93 | if self.mode == 'train':
94 | t[i] = random.randint(i * seg + 1, i * seg + seg) if file_num > 6 else 1
95 | if t[i] >= 9:
96 | t[i] = 8
97 | else:
98 | t[i] = i*seg + max(int(seg/2), 1) if file_num > 6 else 1
99 |
100 | image.append(Image.open(io.BytesIO(video_data[t[i]])).convert('RGB'))
101 |
102 | image_arr.append(transf(image[i]))
103 | image_arr[i] = image_arr[i].unsqueeze(1).float()
104 | if i == 0:
105 | image_n = copy.copy(image_arr[i])
106 | else:
107 | image_n = torch.cat((image_n, image_arr[i]), 1)
108 |
109 |
110 | label = self.label[idx]
111 |
112 | return {'visual':image_n, 'audio':spectrogram, 'label': label,'idx': idx}
113 | # return image_n,spectrogram,label,idx
--------------------------------------------------------------------------------
/balancemm/datasets/cremad_dataset.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | import librosa
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 | import os.path as osp
10 |
11 | class CREMADDataset(Dataset):
12 | def __init__(self, args: dict):
13 |
14 | self.image = []
15 | self.audio = []
16 | self.label = []
17 | class_dict = {'NEU':0, 'HAP':1, 'SAD':2, 'FEA':3, 'DIS':4, 'ANG':5}
18 |
19 | self.mode = args['mode']
20 | self.fps = args['fps']
21 | self.visual_path = args['visual_path']
22 | self.audio_path = args['audio_path']
23 | self.train_txt = args['train_txt']
24 | self.test_txt = args['test_txt']
25 | self.aid_transform = transforms.Compose([transforms.ToTensor()])
26 | if self.mode == 'train':
27 | self.csv_file = self.train_txt
28 | else:
29 | self.csv_file = self.test_txt
30 | self.data = []
31 | self.label = []
32 | with open(self.csv_file) as f:
33 | csv_reader = csv.reader(f)
34 | for item in csv_reader:
35 | if item[1] in class_dict and os.path.exists(osp.join(self.audio_path, item[0] + '.pt')) and os.path.exists(osp.join(self.visual_path, item[0])) and len(os.listdir(osp.join(self.visual_path, item[0]))) >= self.fps:
36 | self.data.append(item[0])
37 | self.label.append(class_dict[item[1]])
38 |
39 | def __len__(self):
40 | return len(self.data)
41 |
42 | def __getitem__(self, idx):
43 | datum = self.data[idx]
44 |
45 | # Audio
46 | fbank = torch.load(self.audio_path + datum + '.pt').unsqueeze(0)
47 |
48 | # Visual
49 | if self.mode == 'train':
50 | transf = transforms.Compose([
51 | transforms.RandomResizedCrop(224),
52 | transforms.RandomHorizontalFlip(),
53 | transforms.ToTensor(),
54 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
55 | ])
56 | else:
57 | transf = transforms.Compose([
58 | transforms.Resize(size=(224, 224)),
59 | transforms.ToTensor(),
60 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
61 | ])
62 |
63 | folder_path = self.visual_path + datum
64 | file_num = len(os.listdir(folder_path))
65 | file_list = os.listdir(folder_path)
66 | pick_num = self.fps
67 | seg = int(file_num/pick_num)
68 | image_arr = []
69 |
70 | for i in range(pick_num):
71 | if self.mode == 'train':
72 | index = i*seg + np.random.randint(seg)
73 | else:
74 | index = i*seg + seg//2
75 | path = os.path.join(folder_path, file_list[index])
76 | image_arr.append(transf(Image.open(path).convert('RGB')).unsqueeze(0))
77 |
78 | images = torch.cat(image_arr)
79 |
80 | label = self.label[idx]
81 | images = images.permute(1,0,2,3)
82 | return {'audio': fbank, 'visual': images, 'label': label, 'idx': idx}
83 |
84 | if __name__ == '__main__':
85 | print('start')
86 | a = CREMADDataset({'mode':'train','fps':2})
87 | a.__getitem__(0)
--------------------------------------------------------------------------------
/balancemm/datasets/food_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import torch
4 | import torch.utils.data as data
5 | from pathlib import Path
6 | from random import randrange
7 | import numpy as np
8 | from torch.utils.data import Dataset
9 | from transformers import BertTokenizer
10 | import torch
11 |
12 | import os
13 | import pandas as pd
14 | from PIL import Image
15 | import re
16 | import random
17 | from torchvision import transforms
18 |
19 | def find_classes(directory) :
20 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
21 | if not classes:
22 | raise FileNotFoundError(f"Couldn't find any classes in {directory}.")
23 |
24 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
25 | idx_to_class = {i: cls_name for i, cls_name in enumerate(classes)}
26 | return classes, class_to_idx,idx_to_class
27 |
28 | # 1. Subclass torch.utils.data.Dataset
29 | class FOOD101Dataset(Dataset):
30 |
31 | # 2. Initialize with a targ_dir and transform (optional) parameter
32 | def __init__(self, args:dict,transform=None):
33 | # 3. Create class attributes
34 | # targ_dir = args['dataset_path']
35 | targ_dir = args['targ_dir']
36 | phase = args['mode']
37 | mode = 'all'
38 | #resize = 384
39 | resize = 224
40 | train_transforms = transforms.Compose([transforms.RandomRotation(30),
41 | transforms.Resize((resize,resize)),
42 | transforms.RandomHorizontalFlip(),
43 | transforms.ToTensor(),
44 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
45 |
46 | test_transforms = transforms.Compose([transforms.Resize((resize,resize)),
47 | #transforms.CenterCrop(resize),
48 | transforms.ToTensor(),
49 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
50 | self.dataset_root = targ_dir
51 | self.csv_file_path = '%s/texts/%s_titles.csv' % (self.dataset_root, phase)
52 | self.img_dir='%s/images/%s' % (self.dataset_root, phase)
53 |
54 | self.data = pd.read_csv(self.csv_file_path)
55 | self.tokenizer = BertTokenizer.from_pretrained('')
56 | # Setup transforms
57 | if phase == 'train':
58 | self.transform = train_transforms
59 | else:
60 | self.transform = test_transforms
61 | if transform != None: self.transform = transform
62 | # Create classes and class_to_idx attributes
63 | self.classes, self.class_to_idx,self.idx_to_class = find_classes(self.img_dir)
64 | self.mode=mode
65 |
66 | # 4. Make function to load images
67 | def load_image(self, index,img_path): #-> Image.Image:
68 | "Opens an image via a path and returns it."
69 | return Image.open(img_path).convert('RGB')
70 |
71 | def clean_text(self,raw_text):
72 | t = re.sub(r'^RT[\s]+', '', raw_text)# remove old style retweet text "RT"
73 | t = re.sub(r'https?:\/\/.*[\r\n]*', '', t)# remove hyperlinks
74 | t = re.sub(r'#', '', t) # remove hashtags
75 | return t
76 |
77 | def tokenize(self, sentence):
78 | ids = self.tokenizer(sentence ,
79 | padding='max_length', max_length=40, truncation=True).items()
80 | return {k: torch.tensor(v) for k, v in ids}
81 |
82 | # 5. Overwrite the __len__() method (optional but recommended for subclasses of torch.utils.data.Dataset)
83 | def __len__(self):# -> int:
84 | "Returns the total number of samples."
85 | return len(self.data)
86 |
87 | # 6. Overwrite the __getitem__() method (required for subclasses of torch.utils.data.Dataset)
88 | def __getitem__(self, index): #returns Tuple[torch.Tensor, int]:
89 | "Returns one sample of data, data and label (X, y)."
90 | sample=self.data.iloc[index]
91 | txt=sample['text']
92 | txt = self.clean_text(txt)
93 | text_tokens= self.tokenize(txt)
94 | class_name = sample['label']
95 | class_idx = self.class_to_idx[class_name]
96 | if self.mode =="all":
97 | img_path = os.path.join(self.img_dir,sample['label'] ,sample["Image_path"] )
98 | img = self.load_image(index,img_path)
99 | # Transform if necessary
100 | if self.transform:
101 | x = {'visual':self.transform(img).unsqueeze(1),'text':text_tokens['input_ids'].unsqueeze(0), 'label':class_idx,'idx': index} ##text:40
102 | return x
103 | else:
104 | return img, text_tokens, txt, class_idx ,index
105 | elif self.mode =="Text_only":
106 | return text_tokens, txt, class_idx ,index
--------------------------------------------------------------------------------
/balancemm/datasets/ucf101_dataset.py:
--------------------------------------------------------------------------------
1 | import csv
2 | from genericpath import isdir
3 | import os
4 | import random
5 | import numpy as np
6 | import torch
7 | import torchvision.datasets as datasets
8 | import torchvision.transforms as transforms
9 | from PIL import Image
10 | from torch.utils.data import Dataset
11 |
12 |
23 |
24 | class UCF101Dataset(Dataset):
25 |
26 | def __init__(self, args:dict, v_norm = True, a_norm = False, name = "UCF101"):
27 | self.data = []
28 | classes = []
29 | data2class = {}
30 | self.mode= args['mode']
31 | self.v_norm = v_norm
32 | self.a_norm = a_norm
33 | self.stat_path = args['stat_path']
34 | self.train_txt = args['train_txt']
35 | self.test_txt = args['test_txt']
36 | self.visual_path = args['visual_path']
37 | self.flow_path_v = args['flow_path_v']
38 | self.flow_path_u = args['flow_path_u']
39 |
40 | if self.mode == 'train':
41 | csv_file = self.train_txt
42 | else:
43 | csv_file = self.test_txt
44 |
45 | with open(self.stat_path) as f:
46 | for line in f:
47 | item = line.split("\n")[0].split(" ")[1]
48 | classes.append(item)
49 | with open(csv_file) as f:
50 | for line in f:
51 | class_name = line.split('/')[0]
52 | name = line.split('/')[1].split('.')[0]
53 | if os.path.isdir(self.visual_path + name) and os.path.isdir(self.flow_path_u + name) and os.path.isdir(self.flow_path_v + name):
54 | self.data.append(name)
55 | data2class[name] = class_name
56 | self.classes = sorted(classes)
57 | self.data2class = data2class
58 | self.class_num = len(self.classes)
59 | print(self.class_num)
60 | print('# of files = %d ' % len(self.data))
61 |
62 | def _init_atransform(self):
63 | self.aid_transform = transforms.Compose([transforms.ToTensor()])
64 |
65 | def __len__(self):
66 | return len(self.data)
67 |
68 | def __getitem__(self, idx):
69 | datum = self.data[idx]
70 | # crop = transforms.RandomResizedCrop(112, (1/4, 1.0), (3/4, 4/3))
71 | if self.mode == 'train':
72 | rgb_transf = [
73 | transforms.Resize(size=(224, 224)),
74 | transforms.RandomHorizontalFlip(),
75 | transforms.ToTensor(),
76 | ]
77 | diff_transf = [transforms.ToTensor()]
78 |
79 | flow_transf = [
80 | transforms.Resize(size=(224, 224)),
81 | transforms.RandomHorizontalFlip(),
82 | transforms.ToTensor(),
83 | ]
84 | else:
85 | rgb_transf = [
86 | transforms.Resize(size=(224, 224)),
87 | transforms.ToTensor()
88 | ]
89 | diff_transf = [transforms.ToTensor()]
90 | flow_transf = [
91 | transforms.Resize(size=(224, 224)),
92 | transforms.ToTensor(),
93 | ]
94 |
95 | if self.v_norm:
96 | rgb_transf.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
97 | diff_transf.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
98 | if self.a_norm :
99 | flow_transf.append(transforms.Normalize([0.1307], [0.3081]))
100 | rgb_transf = transforms.Compose(rgb_transf)
101 | diff_transf = transforms.Compose(diff_transf)
102 | flow_transf = transforms.Compose(flow_transf)
103 | folder_path = self.visual_path + datum
104 |
105 | ####### RGB
106 | file_num = 6
107 |
108 | pick_num = 3
109 | seg = int(file_num/pick_num)
110 | image_arr = []
111 |
112 | for i in range(pick_num):
113 | if self.mode == 'train':
114 | chosen_index = random.randint(i*seg + 1, i*seg + seg)
115 | else:
116 | chosen_index = i*seg + max(int(seg/2), 1)
117 | path = folder_path + '/frame_0000' + str(chosen_index) + '.jpg'
118 | tranf_image = rgb_transf(Image.open(path).convert('RGB'))
119 | image_arr.append(tranf_image.unsqueeze(0))
120 |
121 | images = torch.cat(image_arr)
122 |
123 | num_u = len(os.listdir(self.flow_path_u + datum))
124 | pick_num = 3
125 | flow_arr = []
126 | seg = int(num_u/pick_num)
127 |
128 | for i in range(pick_num):
129 | if self.mode == 'train':
130 | chosen_index = random.randint(i*seg + 1, i*seg + seg)
131 | else:
132 | chosen_index = i*seg + max(int(seg/2), 1)
133 |
134 | flow_u = self.flow_path_u + datum + '/frame00' + str(chosen_index).zfill(4) + '.jpg'
135 | flow_v = self.flow_path_v + datum + '/frame00' + str(chosen_index).zfill(4) + '.jpg'
136 | u = flow_transf(Image.open(flow_u))
137 | v = flow_transf(Image.open(flow_v))
138 | flow = torch.cat((u,v),0)
139 | flow_arr.append(flow.unsqueeze(0))
140 |
141 | flow_n = torch.cat(flow_arr)
142 | images = images.permute(1,0,2,3)
143 | sample = {
144 | 'flow':flow_n,
145 | 'visual':images,
146 | 'label': self.classes.index(self.data2class[datum]),
147 | 'raw':datum,
148 | 'idx':idx
149 | }
150 |
151 |
152 | return sample
153 |
154 |
--------------------------------------------------------------------------------
/balancemm/encoders/VisionTransformer_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | ## waiting for update
4 | class Transformer_Encoder(nn.Module):
5 | """Extends nn.Transformer."""
6 |
7 | def __init__(self, input_dim, dim, n_head, n_layers):
8 | """Initialize Transformer object.
9 |
10 | Args:
11 | n_features (int): Number of features in the input.
12 | dim (int): Dimension which to embed upon / Hidden dimension size.
13 | """
14 | super().__init__()
15 | self.embed_dim = dim
16 | self.conv = nn.Conv1d(input_dim, self.embed_dim,
17 | kernel_size=1, padding=0, bias=False)
18 | layer = nn.TransformerEncoderLayer(d_model=self.embed_dim, nhead=n_head)
19 | self.transformer = nn.TransformerEncoder(layer, num_layers=n_layers)
20 |
21 | def forward(self, x):
22 | """Apply Transformer to Input.
23 |
24 | Args:
25 | x (torch.Tensor): Layer Input
26 |
27 | Returns:
28 | torch.Tensor: Layer Output
29 | """
30 | if type(x) is list:
31 | x = x[0]
32 | x = self.conv(x.permute([0, 2, 1]))
33 | x = x.permute([2, 0, 1])
34 | x = self.transformer(x)[-1]
35 | return x
--------------------------------------------------------------------------------
/balancemm/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | from os import path as osp
4 | from ..utils.parser_utils import find_module
5 | import torch.nn as nn
6 | import torch
7 | __all__ = ['find_encoder', 'create_encoder']
8 | from .pretrained_encoder import text_encoder
9 | from torchvision.models import vit_b_16, vit_h_14
10 |
11 |
12 | # import encoder modules from those with '_encoder' in file names
13 | _encoder_folder = osp.dirname(osp.abspath(__file__))
14 | _encoder_filenames = [
15 | osp.splitext(v)[0] for v in os.listdir(_encoder_folder)
16 | if v.endswith('_encoder.py')
17 | ]
18 | _encoder_modules = [
19 | importlib.import_module(f'.{file_name}', package="balancemm.encoders")
20 | for file_name in _encoder_filenames
21 | ]
22 |
23 | # find encoder from encoder_opt
24 | def find_encoder(encoder_name: str) -> object:
25 | encoder_cls = find_module(_encoder_modules, encoder_name, 'Encoder')
26 | return encoder_cls
27 |
28 | def create_encoders(encoder_opt: dict[str, dict])->dict[str, nn.Module]:
29 | modalitys = encoder_opt.keys()
30 | encoders = {}
31 | for modality in modalitys:
32 | pre_train = encoder_opt[modality]['if_pretrain']
33 | path = encoder_opt[modality]['pretrain_path']
34 | name = encoder_opt[modality]['name']
35 | if name == "ViT_B":
36 | if pre_train:
37 | encoders[modality] = vit_b_16(path)
38 | else:
39 | encoders[modality] = vit_b_16()
40 | continue
41 | del encoder_opt[modality]['pretrain_path']
42 | del encoder_opt[modality]['if_pretrain']
43 | encoder = find_encoder(encoder_opt[modality]['name'])
44 | del encoder_opt[modality]['name']
45 | encoders[modality] = encoder(**encoder_opt[modality])
46 | encoder_opt[modality]['name'] = name
47 | if pre_train:
48 | if modality == 'text':
49 | encoders[modality] = text_encoder()
50 | else:
51 | state = torch.load(path)
52 | if modality == 'flow':
53 | del state['conv1.weight']
54 | encoders[modality].load_state_dict(state, strict=False)
55 | print('pretrain load finish')
56 | encoder_opt[modality]['if_pretrain'] = pre_train
57 | encoder_opt[modality]['pretrain_path'] = path
58 | print (f'Encoder {name} - {modality} is created.')
59 | return encoders
--------------------------------------------------------------------------------
/balancemm/encoders/pretrained_encoder.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import random
3 | from functools import partialmethod
4 |
5 | import torch
6 | import numpy as np
7 | from sklearn.metrics import precision_recall_fscore_support
8 | import os
9 |
10 | from transformers import BertModel, BertConfig
11 | import torch.nn.functional as F
12 | from torch import nn
13 |
14 | class text_encoder(nn.Module):
15 | def __init__(self, output_dim=1024):
16 | super().__init__()
17 | config = BertConfig()
18 | self.textEncoder= BertModel(config).from_pretrained('')
19 | self.linear = nn.Linear(config.hidden_size, output_dim)
20 |
21 | def forward(self, x):
22 | text = x.squeeze(1)
23 | hidden_states = self.textEncoder(text)
24 | e_i = self.linear(hidden_states[1])
25 | e_i = F.dropout(e_i)
26 | return e_i
27 |
28 | class image_encoder(nn.Module):
29 | def __init__(self,model_arch,num_classes=10,weights="IMAGENET1K_V1",device="cuda"):
30 | super().__init__()
31 | self.model_arch=model_arch
32 | self.device=device
33 | self.num_classes=num_classes
34 | print(model_arch)
35 | self.model=torch.hub.load("pytorch/vision", self.model_arch, weights=weights).to(self.device)
36 | # print(self.model.parameters)
37 | self.model.heads = nn.Sequential()
38 |
39 | def forward(self,x):
40 | y=self.model(x)
41 | return y
42 |
--------------------------------------------------------------------------------
/balancemm/encoders/resnet18_encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
4 | """3x3 convolution with padding"""
5 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
6 | padding=dilation, groups=groups, bias=False, dilation=dilation)
7 |
8 |
9 | def conv1x1(in_planes, out_planes, stride=1):
10 | """1x1 convolution"""
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
18 | base_width=64, dilation=1, norm_layer=None):
19 | super(BasicBlock, self).__init__()
20 | if norm_layer is None:
21 | norm_layer = nn.BatchNorm2d
22 | if groups != 1 or base_width != 64:
23 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
24 | if dilation > 1:
25 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
26 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
27 | self.conv1 = conv3x3(inplanes, planes, stride)
28 | self.bn1 = norm_layer(planes)
29 | self.relu = nn.ReLU(inplace=True)
30 | self.conv2 = conv3x3(planes, planes)
31 | self.bn2 = norm_layer(planes)
32 | self.downsample = downsample
33 | self.stride = stride
34 |
35 | def forward(self, x):
36 | identity = x
37 |
38 | out = self.conv1(x)
39 | out = self.bn1(out)
40 | out = self.relu(out)
41 |
42 | out = self.conv2(out)
43 | out = self.bn2(out)
44 |
45 | if self.downsample is not None:
46 | identity = self.downsample(x)
47 |
48 | out += identity
49 | out = self.relu(out)
50 |
51 | return out
52 |
53 |
54 | class ResNet(nn.Module):
55 | def __init__(self, block, layers, modality, zero_init_residual=False,
56 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
57 | norm_layer=None, input_dim = None):
58 | super(ResNet, self).__init__()
59 | self.modality = modality
60 | if norm_layer is None:
61 | norm_layer = nn.BatchNorm2d
62 | self._norm_layer = norm_layer
63 |
64 | self.inplanes = 64
65 | self.dilation = 1
66 | if replace_stride_with_dilation is None:
67 | # each element in the tuple indicates if we should replace
68 | # the 2x2 stride with a dilated convolution instead
69 | replace_stride_with_dilation = [False, False, False]
70 | if len(replace_stride_with_dilation) != 3:
71 | raise ValueError("replace_stride_with_dilation should be None "
72 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
73 | self.groups = groups
74 | self.base_width = width_per_group
75 | if modality == 'audio':
76 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
77 | bias=False)
78 | elif modality == 'visual' or modality == 'front_view' or modality == 'back_view':
79 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
80 | bias=False)
81 | elif modality == 'flow':
82 | self.conv1 = nn.Conv2d(2, self.inplanes, kernel_size=7, stride=2, padding=3,
83 | bias=False)
84 | else:
85 |             raise NotImplementedError('Incorrect modality, expected audio, visual, front_view, back_view or flow but got {}'.format(modality))
86 | self.bn1 = norm_layer(self.inplanes)
87 | self.relu = nn.ReLU(inplace=True)
88 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
89 | self.layer1 = self._make_layer(block, 64, layers[0])
90 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
91 | dilate=replace_stride_with_dilation[0])
92 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
93 | dilate=replace_stride_with_dilation[1])
94 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
95 | dilate=replace_stride_with_dilation[2])
96 |
97 | for m in self.modules():
98 | if isinstance(m, nn.Conv2d):
99 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
100 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
101 | nn.init.normal_(m.weight, mean=1, std=0.02)
102 | nn.init.constant_(m.bias, 0)
103 |
104 | if zero_init_residual:
105 | for m in self.modules():
106 | if isinstance(m, Bottleneck):
107 | nn.init.constant_(m.bn3.weight, 0)
108 | elif isinstance(m, BasicBlock):
109 | nn.init.constant_(m.bn2.weight, 0)
110 |
111 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
112 | norm_layer = self._norm_layer
113 | downsample = None
114 | previous_dilation = self.dilation
115 | if dilate:
116 | self.dilation *= stride
117 | stride = 1
118 | if stride != 1 or self.inplanes != planes * block.expansion:
119 | downsample = nn.Sequential(
120 | conv1x1(self.inplanes, planes * block.expansion, stride),
121 | norm_layer(planes * block.expansion),
122 | )
123 |
124 | layers = []
125 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
126 | self.base_width, previous_dilation, norm_layer))
127 | self.inplanes = planes * block.expansion
128 | for _ in range(1, blocks):
129 | layers.append(block(self.inplanes, planes, groups=self.groups,
130 | base_width=self.base_width, dilation=self.dilation,
131 | norm_layer=norm_layer))
132 |
133 | return nn.Sequential(*layers)
134 |
135 | def forward(self, x):
136 |
137 | if self.modality == 'visual' or self.modality == 'flow':
138 | (B, T, C, H, W) = x.size()
139 | x = x.view(B * T, C, H, W)
140 |
141 | x = self.conv1(x)
142 | x = self.bn1(x)
143 | x = self.relu(x)
144 | x = self.maxpool(x)
145 |
146 | x = self.layer1(x)
147 | x = self.layer2(x)
148 | x = self.layer3(x)
149 | x = self.layer4(x)
150 | out = x
151 |
152 | return out
153 |
154 | class ResNet18Encoder(ResNet):
155 | def __init__(self, **kwargs):
156 | super(ResNet18Encoder, self).__init__(BasicBlock, [2, 2, 2, 2], **kwargs)
157 |
158 | class Bottleneck(nn.Module):
159 | expansion = 4
160 |
161 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
162 | base_width=64, dilation=1, norm_layer=None):
163 | super(Bottleneck, self).__init__()
164 | if norm_layer is None:
165 | norm_layer = nn.BatchNorm2d
166 | width = int(planes * (base_width / 64.)) * groups
167 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
168 | self.conv1 = conv1x1(inplanes, width)
169 | self.bn1 = norm_layer(width)
170 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
171 | self.bn2 = norm_layer(width)
172 | self.conv3 = conv1x1(width, planes * self.expansion)
173 | self.bn3 = norm_layer(planes * self.expansion)
174 | self.relu = nn.ReLU(inplace=True)
175 | self.downsample = downsample
176 | self.stride = stride
177 |
178 | def forward(self, x):
179 | identity = x
180 |
181 | out = self.conv1(x)
182 | out = self.bn1(out)
183 | out = self.relu(out)
184 |
185 | out = self.conv2(out)
186 | out = self.bn2(out)
187 | out = self.relu(out)
188 |
189 | out = self.conv3(out)
190 | out = self.bn3(out)
191 |
192 | if self.downsample is not None:
193 | identity = self.downsample(x)
194 |
195 | out += identity
196 | out = self.relu(out)
197 |
198 | return out
--------------------------------------------------------------------------------
/balancemm/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .precisions import BatchMetricsCalculator
2 | def Evaluation(trainer, model, temp_model,train_dataloader, val_dataloader, optimizer, scheduler, logger):
3 | if temp_model is None:
4 | trainer(model, train_dataloader, val_dataloader, optimizer, scheduler, logger)
5 | class ComprehensiveModelEvaluator:
6 | def __init__(self, args):
7 | self.Metrics = BatchMetricsCalculator(args['Metrics'])
8 | self.Complex = {}
9 | self.Modalitys = {}
10 |
11 |
--------------------------------------------------------------------------------
/balancemm/evaluation/complex.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from thop.vision.basic_hooks import *
3 | from thop.rnn_hooks import *
4 | from thop.utils import prRed
5 | import torch.nn as nn
6 | ## thop
7 | register_hooks = {
8 | nn.ZeroPad2d: zero_ops, # padding does not involve any multiplication.
9 | nn.Conv1d: count_convNd,
10 | nn.Conv2d: count_convNd,
11 | nn.Conv3d: count_convNd,
12 | nn.ConvTranspose1d: count_convNd,
13 | nn.ConvTranspose2d: count_convNd,
14 | nn.ConvTranspose3d: count_convNd,
15 | nn.BatchNorm1d: count_normalization,
16 | nn.BatchNorm2d: count_normalization,
17 | nn.BatchNorm3d: count_normalization,
18 | nn.LayerNorm: count_normalization,
19 | nn.InstanceNorm1d: count_normalization,
20 | nn.InstanceNorm2d: count_normalization,
21 | nn.InstanceNorm3d: count_normalization,
22 | nn.PReLU: count_prelu,
23 | nn.Softmax: count_softmax,
24 | nn.ReLU: zero_ops,
25 | nn.ReLU6: zero_ops,
26 | nn.LeakyReLU: count_relu,
27 | nn.MaxPool1d: zero_ops,
28 | nn.MaxPool2d: zero_ops,
29 | nn.MaxPool3d: zero_ops,
30 | nn.AdaptiveMaxPool1d: zero_ops,
31 | nn.AdaptiveMaxPool2d: zero_ops,
32 | nn.AdaptiveMaxPool3d: zero_ops,
33 | nn.AvgPool1d: count_avgpool,
34 | nn.AvgPool2d: count_avgpool,
35 | nn.AvgPool3d: count_avgpool,
36 | nn.AdaptiveAvgPool1d: count_adap_avgpool,
37 | nn.AdaptiveAvgPool2d: count_adap_avgpool,
38 | nn.AdaptiveAvgPool3d: count_adap_avgpool,
39 | nn.Linear: count_linear,
40 | nn.Dropout: zero_ops,
41 | nn.Upsample: count_upsample,
42 | nn.UpsamplingBilinear2d: count_upsample,
43 | nn.UpsamplingNearest2d: count_upsample,
44 | nn.RNNCell: count_rnn_cell,
45 | nn.GRUCell: count_gru_cell,
46 | nn.LSTMCell: count_lstm_cell,
47 | nn.RNN: count_rnn,
48 | nn.GRU: count_gru,
49 | nn.LSTM: count_lstm,
50 | nn.Sequential: zero_ops,
51 | nn.PixelShuffle: zero_ops,
52 | }
53 |
54 | def profile(
55 | model: nn.Module,
56 | inputs,
57 | method_name = 'forward',
58 | custom_ops=None,
59 | verbose=True,
60 | ret_layer_info=False,
61 | report_missing=False,
62 | ):
63 | ##change from thop
64 | handler_collection = {}
65 | types_collection = set()
66 | if custom_ops is None:
67 | custom_ops = {}
68 | if report_missing:
69 | # overwrite `verbose` option when enable report_missing
70 | verbose = True
71 |
72 | def add_hooks(m: nn.Module):
73 | m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
74 | m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))
75 |
76 | # for p in m.parameters():
77 | # m.total_params += torch.DoubleTensor([p.numel()])
78 |
79 | m_type = type(m)
80 |
81 | fn = None
82 | if m_type in custom_ops:
83 | # if defined both op maps, use custom_ops to overwrite.
84 | fn = custom_ops[m_type]
85 | if m_type not in types_collection and verbose:
86 | print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
87 | elif m_type in register_hooks:
88 | fn = register_hooks[m_type]
89 | if m_type not in types_collection and verbose:
90 | print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
91 | else:
92 | if m_type not in types_collection and report_missing:
93 | prRed(
94 | "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params."
95 | % m_type
96 | )
97 |
98 | if fn is not None:
99 | handler_collection[m] = (
100 | m.register_forward_hook(fn),
101 | m.register_forward_hook(count_parameters),
102 | )
103 | types_collection.add(m_type)
104 |
105 | prev_training_status = model.training
106 |
107 | model.eval()
108 | model.apply(add_hooks)
109 |
110 | with torch.no_grad():
111 | method = getattr(model, method_name)
112 | method(*inputs)
113 |
114 |     def dfs_count(module: nn.Module, prefix="\t") -> tuple:  # (total_ops, total_params, per-layer dict)
115 | total_ops, total_params = module.total_ops.item(), 0
116 | ret_dict = {}
117 | for n, m in module.named_children():
118 | # if not hasattr(m, "total_ops") and not hasattr(m, "total_params"): # and len(list(m.children())) > 0:
119 | # m_ops, m_params = dfs_count(m, prefix=prefix + "\t")
120 | # else:
121 | # m_ops, m_params = m.total_ops, m.total_params
122 | next_dict = {}
123 | if m in handler_collection and not isinstance(
124 | m, (nn.Sequential, nn.ModuleList)
125 | ):
126 | m_ops, m_params = m.total_ops.item(), m.total_params.item()
127 | else:
128 | m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
129 | ret_dict[n] = (m_ops, m_params, next_dict)
130 | total_ops += m_ops
131 | total_params += m_params
132 | # print(prefix, module._get_name(), (total_ops, total_params))
133 | return total_ops, total_params, ret_dict
134 |
135 | total_ops, total_params, ret_dict = dfs_count(model)
136 |
137 | # reset model to original status
138 | model.train(prev_training_status)
139 | for m, (op_handler, params_handler) in handler_collection.items():
140 | op_handler.remove()
141 | params_handler.remove()
142 | m._buffers.pop("total_ops")
143 | m._buffers.pop("total_params")
144 |
145 | if ret_layer_info:
146 | return total_ops, total_params, ret_dict
147 | return total_ops, total_params
148 |
149 | class FLOPsMonitor:
150 | def __init__(self):
151 | self.total_flops = 0
152 | self.forward_flops = 0
153 | self.backward_flops = 0
154 |
155 | def update(self, flops, operation='total'):
156 | if operation == 'forward':
157 | self.forward_flops += flops
158 | self.total_flops += flops
159 |
160 | def report(self, logger):
161 | print(f"Total FLOPs: {self.total_flops}")
162 | print(f"Forward FLOPs: {self.forward_flops}")
163 | logger.info(f"Total FLOPs: {self.total_flops}")
164 | logger.info(f"Forward FLOPs: {self.forward_flops}")
165 |
166 | def get_flops(model, input_sample):
167 | with torch.no_grad():
168 | flops, params = profile(model, inputs= input_sample)
169 | return flops, params
--------------------------------------------------------------------------------
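A minimal usage sketch for the profiling utilities in the file above. The `TinyNet` model, the input shape, and the import path are illustrative assumptions, not part of the toolkit:

```python
# Illustrative only: profile a toy model with get_flops and accumulate the
# result in FLOPsMonitor (assumed import path; TinyNet is a hypothetical model).
import torch
import torch.nn as nn

from balancemm.evaluation.complex import get_flops, FLOPsMonitor  # assumption


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8 * 32 * 32, 10)

    def forward(self, x):
        return self.fc(self.conv(x).flatten(1))


model = TinyNet()
sample = (torch.randn(1, 3, 32, 32),)        # get_flops unpacks this tuple into model(*inputs)
flops, params = get_flops(model, sample)

monitor = FLOPsMonitor()
monitor.update(flops, operation='forward')   # note: update() currently only accumulates 'forward'
print(monitor.total_flops, params)
```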
/balancemm/evaluation/modalitys.py:
--------------------------------------------------------------------------------
1 | from itertools import combinations
2 | from collections import defaultdict
3 | from torch.utils.data.dataset import Dataset
4 | import logging
5 | from copy import deepcopy
6 | from tqdm import tqdm
7 | from math import factorial
8 | import torch
9 | def generate_all_combinations(input_list: list[str], include_empty: bool = True):
10 | all_combinations = []
11 |
12 | # generate all combinations of length from 0 to len(input_list)
13 | start_range = 0 if include_empty else 1
14 | for r in range(start_range, len(input_list) + 1):
15 | all_combinations.extend(combinations(input_list, r))
16 | # converts a combination to a list
17 | return [list(combo) for combo in all_combinations]
18 |
19 | def Calculate_Shapley(trainer, model, CalcuLoader: Dataset, logger: logging.Logger, conduct: bool = True) -> dict[str, float]:
20 |
21 | if conduct:
22 | modalitys = model.modalitys
23 | n = len(modalitys)
24 |         Shapley = defaultdict(float)  ## defaults to 0
25 |         res_cache = defaultdict(lambda: float('inf'))  ## cache of intermediate validation results
26 | for modality in modalitys:
27 | temp_modalitys = list(modalitys)
28 | temp_modalitys.remove(modality)
29 | combinations = generate_all_combinations(temp_modalitys, include_empty = True)
30 | for combo in combinations:
31 | S_size = len(combo)
32 |                 identifier = tuple(sorted(combo))
33 |                 if res_cache[identifier] == float('inf'):
34 | with torch.no_grad():
35 | _, v_combo = trainer.val_loop(model = model, val_loader= CalcuLoader, limit_modalitys= combo.copy())
36 |                     res_cache[identifier] = v_combo
37 | else:
38 |                     v_combo = res_cache[identifier]
39 | if modality not in combo:
40 | add_combo = combo.copy()
41 | add_combo.append(modality)
42 | add_combo = sorted(add_combo)
43 |                     identifier = tuple(add_combo)
44 |                     if res_cache[identifier] == float('inf'):
45 | with torch.no_grad():
46 | _, v_add = trainer.val_loop(model = model, val_loader= CalcuLoader, limit_modalitys= add_combo)
47 |                         res_cache[identifier] = v_add
48 | else:
49 |                         v_add = res_cache[identifier]
50 | else:
51 | v_add = v_combo
52 | Shapley[modality] += (factorial(S_size) * factorial(n - S_size - 1)) / factorial(n)*(v_add['acc']['output'] - v_combo['acc']['output'])
53 | logger.info(Shapley)
54 | return Shapley
55 | else:
56 | return
57 |
58 | def Calculate_Shapley_Sample(trainer, model, CalcuLoader: Dataset, logger: logging.Logger, conduct: bool = True, is_print: bool = False) -> dict[str, float]:
59 | if not conduct:
60 | return None
61 | modalitys = model.modalitys
62 | n = len(modalitys)
63 | Shapley = {modality: {} for modality in modalitys}
64 | res_cache = defaultdict(lambda: None)
65 |
66 | for batch_idx, batch in tqdm(enumerate(CalcuLoader)):
67 | label = batch['label'].to(model.device)
68 | batch_size = len(label)
69 |
70 | all_combinations = generate_all_combinations(modalitys, include_empty=True)
71 | for combo in all_combinations:
72 | identifier = tuple(sorted(combo))
73 | if res_cache[identifier] is None:
74 | if not combo:
75 | res_cache[identifier] = torch.zeros(batch_size, dtype=torch.bool)
76 | else:
77 | with torch.no_grad():
78 | model.validation_step(batch, batch_idx,limit_modality=combo)
79 | res_cache[identifier] = (model.pridiction['output'] == label)
80 |
81 | for i in range(batch_size):
82 | sample_idx = int(batch['idx'][i])
83 |
84 | for modality in modalitys:
85 | shapley_value = 0.0
86 | temp_modalitys = [m for m in modalitys if m != modality]
87 | combinations = generate_all_combinations(temp_modalitys, include_empty=True)
88 |
89 | for combo in combinations:
90 | S_size = len(combo)
91 | v_combo = res_cache[tuple(sorted(combo))][i]
92 |
93 | add_combo = sorted(combo + [modality])
94 | v_add = res_cache[tuple(add_combo)][i]
95 |
96 | weight = (factorial(S_size) * factorial(n - S_size - 1)) / factorial(n)
97 | marginal_contribution = float(v_add) - float(v_combo)
98 | shapley_value += weight * marginal_contribution
99 |
100 | Shapley[modality][sample_idx] = shapley_value
101 |
102 |
103 | return Shapley
--------------------------------------------------------------------------------
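As a sanity check on the weighting used in `Calculate_Shapley` above, the subset weights `factorial(|S|) * factorial(n - |S| - 1) / factorial(n)` sum to 1 for each modality. A small standalone sketch with hypothetical modality names:

```python
# Illustrative only: verify that the Shapley subset weights used above sum to 1.
from itertools import combinations
from math import factorial

modalitys = ['audio', 'visual', 'text']          # hypothetical modality names
n = len(modalitys)
for modality in modalitys:
    rest = [m for m in modalitys if m != modality]
    total = 0.0
    for r in range(len(rest) + 1):
        for combo in combinations(rest, r):
            total += factorial(len(combo)) * factorial(n - len(combo) - 1) / factorial(n)
    print(modality, round(total, 6))             # prints 1.0 for every modality
```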
/balancemm/evaluation/precisions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import confusion_matrix
3 | from collections import defaultdict
4 |
5 | class BatchMetricsCalculator:
6 | def __init__(self, num_classes: int, modalitys: list):
7 | self.num_classes = num_classes
8 | self.confusion_matrix = {}
9 | self.modalitys = modalitys
10 | for modality in modalitys:
11 | self.confusion_matrix[modality] = np.zeros((num_classes, num_classes))
12 | self.total_samples = 0
13 |
14 | def update(self, y_true, y_pred):
15 | for modality in self.confusion_matrix.keys():
16 | batch_cm = confusion_matrix(y_true, y_pred[modality].cpu(), labels=range(self.num_classes))
17 | self.confusion_matrix[modality] += batch_cm
18 | self.total_samples += len(y_true)
19 |
20 | def compute_metrics(self):
21 | # calculate accuracy
22 | Metrics_res = defaultdict(dict)
23 | for modality in self.confusion_matrix.keys():
24 | accuracy = np.sum(np.diag(self.confusion_matrix[modality])) / self.total_samples
25 |
26 | # calculate f1 score of each class
27 | fps = self.confusion_matrix[modality].sum(axis=0) - np.diag(self.confusion_matrix[modality])
28 | fns = self.confusion_matrix[modality].sum(axis=1) - np.diag(self.confusion_matrix[modality])
29 | tps = np.diag(self.confusion_matrix[modality])
30 | precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps, dtype=float), where=(tps + fps) != 0)
31 | recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps, dtype=float), where=(tps + fns) != 0)
32 | f1_scores = np.divide(2 * (precisions * recalls), precisions + recalls,
33 | out=np.zeros_like(precisions, dtype=float),
34 | where=(precisions + recalls) != 0)
35 | Metrics_res['f1'][modality] = np.mean(f1_scores)
36 | Metrics_res['acc'][modality] = accuracy
37 | return Metrics_res
38 | def ClearAll(self):
39 | for modality in self.modalitys:
40 | self.confusion_matrix[modality] = np.zeros((self.num_classes, self.num_classes))
41 | self.total_samples = 0
42 |
--------------------------------------------------------------------------------
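A minimal usage sketch for `BatchMetricsCalculator` with synthetic predictions; the import path and the three-modality setup are assumptions:

```python
# Illustrative only: per-modality accuracy / macro-F1 from one synthetic batch.
import torch
from balancemm.evaluation.precisions import BatchMetricsCalculator  # assumed path

calc = BatchMetricsCalculator(num_classes=3, modalitys=['audio', 'visual', 'output'])
y_true = [0, 1, 2, 1]
y_pred = {
    'audio':  torch.tensor([0, 1, 1, 1]),
    'visual': torch.tensor([0, 2, 2, 1]),
    'output': torch.tensor([0, 1, 2, 1]),
}
calc.update(y_true, y_pred)
metrics = calc.compute_metrics()
print(metrics['acc']['output'], metrics['f1']['audio'])
calc.ClearAll()   # reset the confusion matrices before the next evaluation loop
```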
/balancemm/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | from os import path as osp
4 | from types import SimpleNamespace
5 | from ..utils.parser_utils import find_module
6 | __all__ = ['find_model', 'create_model']
7 |
8 |
9 | # import model modules from those with '_model' in file names
10 | _model_folder = osp.dirname(osp.abspath(__file__))
11 | _model_filenames = [
12 | osp.splitext(v)[0] for v in os.listdir(_model_folder)
13 | if v.endswith('_model.py')
14 | ]
15 | _model_modules = [
16 | importlib.import_module(f'.{file_name}', package="balancemm.models")
17 | for file_name in _model_filenames
18 | ]
19 |
20 | # find model from model_opt
21 | def find_model(model_name: str) -> object:
22 | model_cls = find_module(_model_modules, model_name, 'Model')
23 | return model_cls
24 |
25 | # create model from model_opt
26 | def create_model(model_opt: dict):
27 | if 'type' not in model_opt:
28 | raise ValueError('Model type is required.')
29 | model_cls = find_model(model_opt['type'])
30 | model = model_cls(model_opt)
31 |
32 | print(
33 | f'Model {model.__class__.__name__} - {model_opt["type"]} is created.')
34 | return model
--------------------------------------------------------------------------------
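The registry above auto-imports every `*_model.py` file in this package and resolves `model_opt['type']` through `find_module` (defined in `utils/parser_utils.py`, not shown here). A hedged usage sketch; the `type` value and any extra options are hypothetical and must match a model class actually defined in the package:

```python
# Illustrative only: the 'type' string below is hypothetical; it must correspond
# to a class that find_module can resolve from one of the *_model.py modules.
from balancemm.models import create_model  # assumed import path

model_opt = {
    'type': 'BaseClassifier',   # hypothetical; replace with a registered model type
    # ... any further keys that the chosen model class reads from model_opt ...
}
model = create_model(model_opt)
```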
/balancemm/models/encoders.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import random
3 | from functools import partialmethod
4 |
5 | import torch
6 | import numpy as np
7 | from sklearn.metrics import precision_recall_fscore_support
8 | import os
9 |
10 |
11 |
12 | from transformers import BertModel, BertConfig
13 | import torch.nn.functional as F
14 | from torch import nn
15 |
16 | class text_encoder(nn.Module):
17 | def __init__(self, dim_text_repr=768):
18 | super().__init__()
19 | config = BertConfig()
20 | self.textEncoder= BertModel(config).from_pretrained('')
21 |
22 | def forward(self, x):
23 | text = x
24 | hidden_states = self.textEncoder(**text) # B, T, dim_text_repr
25 | e_i = F.dropout(hidden_states[1])
26 | return e_i
27 |
28 | class image_encoder(nn.Module):
29 | def __init__(self,model_arch,num_classes=10,weights="IMAGENET1K_V1",device="cuda"):
30 | super().__init__()
31 | self.model_arch=model_arch
32 | self.device=device
33 | self.num_classes=num_classes
34 | print(model_arch)
35 | self.model=torch.hub.load("pytorch/vision", self.model_arch, weights=weights).to(self.device)
36 | # print(self.model.parameters)
37 | self.model.heads = nn.Sequential()
38 |
39 | def forward(self,x):
40 | y=self.model(x)
41 | return y
42 |
--------------------------------------------------------------------------------
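A minimal usage sketch for `image_encoder` above. The `vit_b_16` architecture string is an example, `torch.hub` downloads the weights on first use, and replacing `.heads` assumes a torchvision ViT-style model:

```python
# Illustrative only: use image_encoder as a ViT feature extractor on CPU.
import torch
from balancemm.models.encoders import image_encoder  # assumed import path

enc = image_encoder(model_arch='vit_b_16', num_classes=10, device='cpu')
x = torch.randn(2, 3, 224, 224)   # ViT-B/16 expects 224x224 RGB inputs
feats = enc(x)                    # the classification head is replaced, so features come back
print(feats.shape)                # e.g. torch.Size([2, 768])
```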
/balancemm/models/fusion_arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SumFusion(nn.Module):
6 | def __init__(self, input_dim=512, output_dim=100):
7 | super(SumFusion, self).__init__()
8 | self.fc_x = nn.Linear(input_dim, output_dim)
9 | self.fc_y = nn.Linear(input_dim, output_dim)
10 |
11 | def forward(self, x, y):
12 | output = self.fc_x(x) + self.fc_y(y)
13 | return x, y, output
14 |
15 | class SharedHead(nn.Module):
16 | def __init__(self, input_dim=512, output_dim=100):
17 | super(SharedHead, self).__init__()
18 | self.fc_out = nn.Linear(input_dim, output_dim)
19 | def forward(self, x):
20 | output = self.fc_out(x)
21 | return output
22 |
23 | class ConcatFusion(nn.Module):
24 | def __init__(self, input_dim=1024, output_dim=100):
25 | super(ConcatFusion, self).__init__()
26 | self.fc_out = nn.Linear(input_dim, output_dim)
27 | def forward(self, x, y):
28 | output = torch.cat((x, y), dim=1)
29 | output = self.fc_out(output)
30 | return x, y, output
31 |
32 |
33 | class ConcatFusion_N(nn.Module):
34 | def __init__(self, input_dim=3072, output_dim=100):
35 | super(ConcatFusion_N, self).__init__()
36 | self.fc_out = nn.Linear(input_dim, output_dim)
37 | def forward(self, encoder_res):
38 | output = torch.cat(list(encoder_res.values()),dim = 1)
39 | output = self.fc_out(output)
40 | return output
41 |
42 | class ConcatFusion_Mask(nn.Module):
43 | def __init__(self, input_dim=3072, output_dim=100):
44 |         super(ConcatFusion_Mask, self).__init__()
45 | self.fc_out = nn.Linear(input_dim, output_dim)
46 | def forward(self, encoder_res):
47 | output = torch.cat(list(encoder_res.values()),dim = 1)
48 | output = self.fc_out(output)
49 | return output
50 |
51 | class ConcatFusion_3(nn.Module):
52 | def __init__(self, input_dim=3072, output_dim=100):
53 | super(ConcatFusion_3, self).__init__()
54 | self.fc_out = nn.Linear(input_dim, output_dim)
55 | self.input_dim = input_dim
56 |
57 | def forward(self, x, y, z):
58 | output = torch.cat((x, y), dim=1)
59 | output = torch.cat((output, z),dim = 1)
60 | output = self.fc_out(output)
61 | # x = (torch.mm(x, torch.transpose(self.fc_out.weight[:, self.input_dim // 3: 2 * self.input_dim // 3], 0, 1))
62 | # + self.fc_out.bias / 2)
63 |
64 | # y = (torch.mm(y, torch.transpose(self.fc_out.weight[:, 2* self.input_dim // 3: self.input_dim ], 0, 1))
65 | # + self.fc_out.bias / 2)
66 |
67 | return x, y, z, output
68 |
69 | class FiLM(nn.Module):
70 | """
71 | FiLM: Visual Reasoning with a General Conditioning Layer,
72 | https://arxiv.org/pdf/1709.07871.pdf.
73 | """
74 |
75 | def __init__(self, input_dim=512, dim=512, output_dim=100, x_film=True):
76 | super(FiLM, self).__init__()
77 |
78 | self.dim = input_dim
79 | self.fc = nn.Linear(input_dim, 2 * dim)
80 | self.fc_out = nn.Linear(dim, output_dim)
81 |
82 | self.x_film = x_film
83 |
84 | def forward(self, x, y):
85 |
86 | if self.x_film:
87 | film = x
88 | to_be_film = y
89 | else:
90 | film = y
91 | to_be_film = x
92 |
93 | gamma, beta = torch.split(self.fc(film), self.dim, 1)
94 |
95 | output = gamma * to_be_film + beta
96 | output = self.fc_out(output)
97 |
98 | return x, y, output
99 |
100 |
101 | class GatedFusion(nn.Module):
102 | """
103 | Efficient Large-Scale Multi-Modal Classification,
104 | https://arxiv.org/pdf/1802.02892.pdf.
105 | """
106 |
107 | def __init__(self, input_dim=512, dim=512, output_dim=100, x_gate=True):
108 | super(GatedFusion, self).__init__()
109 |
110 | self.fc_x = nn.Linear(input_dim, dim)
111 | self.fc_y = nn.Linear(input_dim, dim)
112 | self.fc_out = nn.Linear(dim, output_dim)
113 |
114 | self.x_gate = x_gate # whether to choose the x to obtain the gate
115 |
116 | self.sigmoid = nn.Sigmoid()
117 |
118 | def forward(self, x, y):
119 | out_x = self.fc_x(x)
120 | out_y = self.fc_y(y)
121 |
122 | if self.x_gate:
123 | gate = self.sigmoid(out_x)
124 | output = self.fc_out(torch.mul(gate, out_y))
125 | else:
126 | gate = self.sigmoid(out_y)
127 | output = self.fc_out(torch.mul(out_x, gate))
128 |
129 | return out_x, out_y, output
130 |
131 |
--------------------------------------------------------------------------------
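A small sketch with random features showing the shared two-modality interface of the fusion modules above (dimensions are illustrative): each module returns the unimodal inputs (possibly transformed) together with the fused logits.

```python
# Illustrative only: all two-modality fusion heads return (x-term, y-term, fused logits).
import torch
from balancemm.models.fusion_arch import ConcatFusion, FiLM, GatedFusion  # assumed path

x = torch.randn(4, 512)   # e.g. audio features
y = torch.randn(4, 512)   # e.g. visual features

_, _, out_concat = ConcatFusion(input_dim=1024, output_dim=100)(x, y)
_, _, out_film   = FiLM(input_dim=512, dim=512, output_dim=100)(x, y)
_, _, out_gated  = GatedFusion(input_dim=512, dim=512, output_dim=100)(x, y)
print(out_concat.shape, out_film.shape, out_gated.shape)   # all torch.Size([4, 100])
```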
/balancemm/models/resnet_arch.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
4 | """3x3 convolution with padding"""
5 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
6 | padding=dilation, groups=groups, bias=False, dilation=dilation)
7 |
8 |
9 | def conv1x1(in_planes, out_planes, stride=1):
10 | """1x1 convolution"""
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
18 | base_width=64, dilation=1, norm_layer=None):
19 | super(BasicBlock, self).__init__()
20 | if norm_layer is None:
21 | norm_layer = nn.BatchNorm2d
22 | if groups != 1 or base_width != 64:
23 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
24 | if dilation > 1:
25 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
26 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
27 | self.conv1 = conv3x3(inplanes, planes, stride)
28 | self.bn1 = norm_layer(planes)
29 | self.relu = nn.ReLU(inplace=True)
30 | self.conv2 = conv3x3(planes, planes)
31 | self.bn2 = norm_layer(planes)
32 | self.downsample = downsample
33 | self.stride = stride
34 |
35 | def forward(self, x):
36 | identity = x
37 |
38 | out = self.conv1(x)
39 | out = self.bn1(out)
40 | out = self.relu(out)
41 |
42 | out = self.conv2(out)
43 | out = self.bn2(out)
44 |
45 | if self.downsample is not None:
46 | identity = self.downsample(x)
47 |
48 | out += identity
49 | out = self.relu(out)
50 |
51 | return out
52 |
53 |
54 | class ResNet(nn.Module):
55 |
56 | def __init__(self, block, layers, modality, zero_init_residual=False,
57 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
58 | norm_layer=None):
59 | super(ResNet, self).__init__()
60 | self.modality = modality
61 | if norm_layer is None:
62 | norm_layer = nn.BatchNorm2d
63 | self._norm_layer = norm_layer
64 |
65 | self.inplanes = 64
66 | self.dilation = 1
67 | if replace_stride_with_dilation is None:
68 | # each element in the tuple indicates if we should replace
69 | # the 2x2 stride with a dilated convolution instead
70 | replace_stride_with_dilation = [False, False, False]
71 | if len(replace_stride_with_dilation) != 3:
72 | raise ValueError("replace_stride_with_dilation should be None "
73 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
74 | self.groups = groups
75 | self.base_width = width_per_group
76 | if modality == 'audio':
77 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
78 | bias=False)
79 | elif modality == 'visual':
80 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
81 | bias=False)
82 | elif modality == 'flow':
83 | self.conv1 = nn.Conv2d(2, self.inplanes, kernel_size=7, stride=2, padding=3,
84 | bias=False)
85 | else:
86 | raise NotImplementedError('Incorrect modality, should be audio or visual but got {}'.format(modality))
87 | self.bn1 = norm_layer(self.inplanes)
88 | self.relu = nn.ReLU(inplace=True)
89 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
90 | self.layer1 = self._make_layer(block, 64, layers[0])
91 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
92 | dilate=replace_stride_with_dilation[0])
93 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
94 | dilate=replace_stride_with_dilation[1])
95 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
96 | dilate=replace_stride_with_dilation[2])
97 |
98 | for m in self.modules():
99 | if isinstance(m, nn.Conv2d):
100 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
101 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
102 | nn.init.normal_(m.weight, mean=1, std=0.02)
103 | nn.init.constant_(m.bias, 0)
104 |
105 | # Zero-initialize the last BN in each residual branch,
106 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
107 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
108 | if zero_init_residual:
109 | for m in self.modules():
110 | if isinstance(m, Bottleneck):
111 | nn.init.constant_(m.bn3.weight, 0)
112 | elif isinstance(m, BasicBlock):
113 | nn.init.constant_(m.bn2.weight, 0)
114 |
115 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
116 | norm_layer = self._norm_layer
117 | downsample = None
118 | previous_dilation = self.dilation
119 | if dilate:
120 | self.dilation *= stride
121 | stride = 1
122 | if stride != 1 or self.inplanes != planes * block.expansion:
123 | downsample = nn.Sequential(
124 | conv1x1(self.inplanes, planes * block.expansion, stride),
125 | norm_layer(planes * block.expansion),
126 | )
127 |
128 | layers = []
129 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
130 | self.base_width, previous_dilation, norm_layer))
131 | self.inplanes = planes * block.expansion
132 | for _ in range(1, blocks):
133 | layers.append(block(self.inplanes, planes, groups=self.groups,
134 | base_width=self.base_width, dilation=self.dilation,
135 | norm_layer=norm_layer))
136 |
137 | return nn.Sequential(*layers)
138 |
139 | def forward(self, x):
140 |
141 | if self.modality == 'visual' or self.modality == 'flow':
142 | (B, T, C, H, W) = x.size()
143 | x = x.view(B * T, C, H, W)
144 |
145 | x = self.conv1(x)
146 | x = self.bn1(x)
147 | x = self.relu(x)
148 | x = self.maxpool(x)
149 |
150 | x = self.layer1(x)
151 | x = self.layer2(x)
152 | x = self.layer3(x)
153 | x = self.layer4(x)
154 | out = x
155 |
156 | return out
157 |
158 | class ResNet18(ResNet):
159 | def __init__(self, **kwargs):
160 | super(ResNet18, self).__init__(BasicBlock, [2, 2, 2, 2], **kwargs)
161 |
162 |
163 | class Bottleneck(nn.Module):
164 | expansion = 4
165 |
166 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
167 | base_width=64, dilation=1, norm_layer=None):
168 | super(Bottleneck, self).__init__()
169 | if norm_layer is None:
170 | norm_layer = nn.BatchNorm2d
171 | width = int(planes * (base_width / 64.)) * groups
172 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
173 | self.conv1 = conv1x1(inplanes, width)
174 | self.bn1 = norm_layer(width)
175 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
176 | self.bn2 = norm_layer(width)
177 | self.conv3 = conv1x1(width, planes * self.expansion)
178 | self.bn3 = norm_layer(planes * self.expansion)
179 | self.relu = nn.ReLU(inplace=True)
180 | self.downsample = downsample
181 | self.stride = stride
182 |
183 | def forward(self, x):
184 | identity = x
185 |
186 | out = self.conv1(x)
187 | out = self.bn1(out)
188 | out = self.relu(out)
189 |
190 | out = self.conv2(out)
191 | out = self.bn2(out)
192 | out = self.relu(out)
193 |
194 | out = self.conv3(out)
195 | out = self.bn3(out)
196 |
197 | if self.downsample is not None:
198 | identity = self.downsample(x)
199 |
200 | out += identity
201 | out = self.relu(out)
202 |
203 | return out
--------------------------------------------------------------------------------
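A minimal sketch of the modality-specific stems above (input shapes are illustrative): the audio branch takes a one-channel spectrogram, while visual/flow clips of shape (B, T, C, H, W) are folded to (B*T, C, H, W) before the stem.

```python
# Illustrative only: per-modality ResNet18 backbones return feature maps, not logits.
import torch
from balancemm.models.resnet_arch import ResNet18  # assumed import path

audio_net = ResNet18(modality='audio')
visual_net = ResNet18(modality='visual')

spec = torch.randn(2, 1, 257, 188)       # hypothetical log-spectrogram batch
clip = torch.randn(2, 3, 3, 224, 224)    # (B, T, C, H, W) video clip

print(audio_net(spec).shape)             # (2, 512, H', W') feature map
print(visual_net(clip).shape)            # (6, 512, 7, 7): T is folded into the batch dim
```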
/balancemm/train.py:
--------------------------------------------------------------------------------
1 | from .utils.logger import setup_logger
2 | from .models import create_model
3 | from .trainer import create_trainer
4 | from .utils.train_utils import choose_logger
5 | import subprocess
6 | from .utils.data_utils import create_train_val_dataloader
7 | from types import SimpleNamespace
8 | from .utils.optimizer import create_optimizer
9 | from .utils.scheduler import create_scheduler
10 | from lightning.fabric import Fabric
11 | import lightning as L
12 | from os import path as osp
13 | import os
14 | import torch
15 | import logging
16 | from datetime import datetime
17 | from .evaluation.modalitys import Calculate_Shapley
18 | from .trainer.LinearProbe_trainer import NewLinearHead
19 | from .models.avclassify_model import MultiModalParallel
20 |
21 | import copy
22 | def train_and_test(args: dict):
23 | dict_args = args
24 | args = SimpleNamespace(**args)
25 |
26 | if args.trainer['name'] == 'unimodal':
27 | args.out_dir = args.out_dir.replace('unimodalTrainer','unimodalTrainer_' + list(args.model['encoders'].keys())[0])
28 |
29 | log_dir = osp.join(args.out_dir, "logs")
30 |     print("log dir: {}".format(log_dir))
31 | loggers_online = [choose_logger(logger_name, log_dir = log_dir, project = args.name, comment = args.log['comment']) for logger_name in args.log['logger_name']]
32 | logger = logging.getLogger(__name__)
33 | args.checkpoint_dir = osp.join(args.out_dir, 'checkpoints')
34 | os.makedirs(args.out_dir, exist_ok=True)
35 | os.makedirs(args.checkpoint_dir, exist_ok=True)
36 | file_handler = logging.FileHandler(args.out_dir + '/training.log')
37 | file_handler.setLevel(logging.INFO)
38 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
39 | file_handler.setFormatter(formatter)
40 | logger.addHandler(file_handler)
41 | loggers = [logger]
42 | logger.setLevel(logging.DEBUG)
43 | logger.info(dict_args)
44 | for tb_logger in loggers_online:
45 | if isinstance(tb_logger, L.fabric.loggers.CSVLogger):
46 | os.makedirs(os.path.join(log_dir, 'csv'), exist_ok=True)
47 | elif isinstance(tb_logger, L.fabric.loggers.TensorBoardLogger):
48 | os.makedirs(os.path.join(log_dir, 'tensorboard'), exist_ok=True)
49 | elif isinstance(tb_logger, L.pytorch.loggers.wandb.WandbLogger):
50 | os.makedirs(os.path.join(log_dir, 'wandb'), exist_ok=True)
51 | if tb_logger:
52 | tb_logger.log_hyperparams(dict_args)
53 | fabric = Fabric(**(args.fabric),
54 | loggers = loggers,
55 | )
56 | if isinstance(fabric.accelerator, L.fabric.accelerators.CUDAAccelerator):
57 | fabric.print('set float32 matmul precision to high')
58 | torch.set_float32_matmul_precision('high')
59 | device = args.Main_config['device']
60 | # if device == '':
61 | # device = torch.device('cpu')
62 | # else:
63 | # device = torch.device('cuda:' + args.Main_config['device'])
64 | if device == '':
65 | model = create_model(args.model)
66 | else:
67 | model = create_model(args.model)
68 | print(list(range(torch.cuda.device_count())))
69 | model = MultiModalParallel(model, device_ids = list(range(torch.cuda.device_count())))
70 | model = model.cuda()
71 | # args.model['device'] = device
72 | if args.trainer['name'] != 'Sample':
73 | train_dataloader, val_dataloader, test_dataloader = create_train_val_dataloader(fabric, args)
74 | else:
75 | train_dataloader, train_val_dataloader, val_dataloader, test_dataloader = create_train_val_dataloader(fabric, args)
76 | args.trainer['checkpoint_dir'] = args.checkpoint_dir ##
77 | optimizer = create_optimizer(model, args.train['optimizer'], args.train['parameter'])
78 | scheduler = create_scheduler(optimizer, args.train['scheduler'])
79 | trainer = create_trainer(fabric, args.Main_config, args.trainer, args, logger,tb_logger)
80 |
81 | start_time = datetime.now()
82 | if args.trainer['name'] == 'GBlendingTrainer':
83 | temp_model = create_model(args.model)
84 | temp_model = MultiModalParallel(temp_model,device_ids = list(range(torch.cuda.device_count())))
85 | temp_model.cuda()
86 | temp_optimizer = create_optimizer(temp_model, args.train['optimizer'], args.train['parameter'])
87 | temp_optimizer_origin = copy.deepcopy(temp_optimizer.state_dict())
88 | trainer.fit(model, temp_model,train_dataloader, val_dataloader, optimizer, scheduler, temp_optimizer,temp_optimizer_origin,logger,tb_logger)
89 |
90 | elif args.trainer['name'] == 'SampleTrainer':
91 | trainer.fit(model, train_dataloader, train_val_dataloader, val_dataloader, optimizer, scheduler, logger,tb_logger)
92 | else :
93 | trainer.fit(model, train_dataloader, val_dataloader, optimizer, scheduler, logger,tb_logger)
94 | end_time = datetime.now()
95 | total_time = end_time - start_time
96 | total_time = total_time.total_seconds() / 3600
97 |
98 |     logger.info("Training time: {:.2f} hours".format(total_time))
99 | #val best
100 | print(f'The best val acc is : {trainer.best_acc}')
101 | logger.info(f'The best val acc is : {trainer.best_acc}')
102 | logger.info('Use the best model to Test')
103 | #load best
104 | model.eval()
105 | best_state = torch.load(args.checkpoint_dir+ '/epoch_normal.ckpt')
106 | model.load_state_dict(best_state['model'])
107 | #test sharply
108 | logger.info('Calculate the shapley value of best model')
109 | Calculate_Shapley(trainer = trainer, model = model, CalcuLoader = test_dataloader, logger= logger)
110 | #best test
111 | _, Metrics_res = trainer.val_loop(model, test_dataloader)
112 | test_acc = Metrics_res['acc']['output']
113 | info = ''
114 | output_info = ''
115 | for metircs in sorted(Metrics_res.keys()):
116 | if metircs == 'acc':
117 | valid_acc = Metrics_res[metircs]
118 | for modality in sorted(valid_acc.keys()):
119 | tag = "valid_acc"
120 | if modality == 'output':
121 | output_info += f"test_acc: {valid_acc[modality]}"
122 |
123 | else:
124 | info += f", acc_{modality}: {valid_acc[modality]}"
125 |
126 |
127 | if metircs == 'f1':
128 | valid_f1 = Metrics_res[metircs]
129 | for modality in sorted(valid_f1.keys()):
130 | tag = "valid_f1"
131 | if modality == 'output':
132 | output_info += f", test_f1: {valid_f1[modality]}"
133 |
134 | else:
135 | info += f", f1_{modality}: {valid_f1[modality]}"
136 |
137 | info = output_info+ ', ' + info
138 |
139 | logger.info(info)
140 | for handler in logger.handlers:
141 | handler.flush()
142 | logger.info(f'The best test acc is : {test_acc}')
143 | print(f'The best test acc is : {test_acc}')
144 | # =======
145 | # _, Metrics_res = trainer.val_loop(model, test_dataloader)
146 | # logger.info('Calculate the shapley value of best model')
147 | # Calculate_Shapley(trainer = trainer, model = model, CalcuLoader = test_dataloader, logger= logger)
148 | # logger.info(f'The best val acc is : {trainer.best_acc}')
149 | # print(f'The best val acc is : {trainer.best_acc}')
150 | # # for metircs in sorted(Metrics_res.keys()):
151 | # # if metircs == 'acc':
152 | # # valid_acc = Metrics_res[metircs]
153 | # # for modality in sorted(valid_acc.keys()):
154 | # # if modality == 'output':
155 | # # print(f'The test acc is : {Metrics_res[metrics][modality]}')
156 | # # print(f'The imbalance is :{imbalance}')
157 | # >>>>>>> main
158 |
159 |
160 |
161 | def linear_probe_eval(args: dict):
162 | dict_args = args
163 | args = SimpleNamespace(**args)
164 | args.out_dir = osp.join(args.out_dir,args.trainer['trainer_probed'])
165 | log_dir = osp.join(args.out_dir, "logs")
166 |     print("log dir: {}".format(log_dir))
167 | loggers_online = [choose_logger(logger_name, log_dir = log_dir, project = args.name, comment = args.log['comment']) for logger_name in args.log['logger_name']]
168 | logger = logging.getLogger(__name__)
169 | args.checkpoint_dir = osp.join(args.out_dir, 'checkpoints')
170 | os.makedirs(args.out_dir, exist_ok=True)
171 | os.makedirs(args.checkpoint_dir, exist_ok=True)
172 | file_handler = logging.FileHandler(args.out_dir + '/training.log')
173 | file_handler.setLevel(logging.INFO)
174 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
175 | file_handler.setFormatter(formatter)
176 | logger.addHandler(file_handler)
177 | loggers = [logger]
178 | logger.setLevel(logging.DEBUG)
179 | logger.info(dict_args)
180 | for tb_logger in loggers_online:
181 | if isinstance(tb_logger, L.fabric.loggers.CSVLogger):
182 | os.makedirs(os.path.join(log_dir, 'csv'), exist_ok=True)
183 | elif isinstance(tb_logger, L.fabric.loggers.TensorBoardLogger):
184 | os.makedirs(os.path.join(log_dir, 'tensorboard'), exist_ok=True)
185 | elif isinstance(tb_logger, L.pytorch.loggers.wandb.WandbLogger):
186 | os.makedirs(os.path.join(log_dir, 'wandb'), exist_ok=True)
187 | if tb_logger:
188 | tb_logger.log_hyperparams(dict_args)
189 | fabric = Fabric(**(args.fabric),
190 | loggers = loggers,
191 | )
192 | if isinstance(fabric.accelerator, L.fabric.accelerators.CUDAAccelerator):
193 | fabric.print('set float32 matmul precision to high')
194 | torch.set_float32_matmul_precision('high')
195 | device = args.Main_config['device']
196 | if device == '':
197 | device = torch.device('cpu')
198 | else:
199 | device = torch.device('cuda:' + args.Main_config['device'])
200 | args.model['device'] = device
201 | train_dataloader, val_dataloader, test_dataloader = create_train_val_dataloader(fabric, args)
202 | args.trainer['checkpoint_dir'] = args.checkpoint_dir ##
203 | model = create_model(args.model)
204 | model.to(device)
205 | model.device = device
206 | input_dim = sum(model.modality_size.values())
207 | # Create a linear-classifier-head
208 | new_head = NewLinearHead(input_dim, model.n_classes).to(model.device)
209 | optimizer = create_optimizer(new_head, args.train['optimizer'], args.train['parameter'])
210 | scheduler = create_scheduler(optimizer, args.train['scheduler'])
211 | trainer = create_trainer(fabric, args.Main_config, args.trainer, args, logger,tb_logger)
212 |
213 | start_time = datetime.now()
214 | if args.trainer['name'] == 'GBlendingTrainer':
215 | temp_model = create_model(args.model)
216 | temp_model.to(device)
217 | temp_model.device = device
218 | temp_optimizer = create_optimizer(temp_model, args.train['optimizer'], args.train['parameter'])
219 | trainer.fit(model, temp_model,train_dataloader, val_dataloader, optimizer, scheduler, temp_optimizer,logger,tb_logger)
220 | else :
221 | trainer.fit(model,new_head, train_dataloader, val_dataloader, optimizer, scheduler, logger,tb_logger)
222 | end_time = datetime.now()
223 | total_time = end_time - start_time
224 | total_time = total_time.total_seconds() / 3600
225 |     logger.info("Training time: {:.2f} hours".format(total_time))
226 | logger.info('Use the best model to Test')
227 | model.eval()
228 | new_head.eval()
229 | best_state = torch.load(args.checkpoint_dir+ '/epoch_normal.ckpt')
230 | model.load_state_dict(best_state['model'])
231 | new_head.load_state_dict(best_state['new_head'])
232 | trainer.val_loop(model, new_head,test_dataloader)
233 | logger.info(f'The best val acc is : {trainer.best_acc}')
234 | print(f'The best val acc is : {trainer.best_acc}')
235 |
--------------------------------------------------------------------------------
/balancemm/trainer/AMCo_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 | from lightning import LightningModule
3 | from torch.optim.optimizer import Optimizer as Optimizer
4 | from .base_trainer import BaseTrainer
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | import os
11 | from collections.abc import Mapping
12 | from functools import partial
13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
14 |
15 | import lightning as L
16 | import torch
17 | import numpy as np
18 | from ..models.avclassify_model import BaseClassifierModel
19 | class AMCoTrainer(BaseTrainer):
20 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}):
21 | super(AMCoTrainer,self).__init__(fabric,**para_dict)
22 | self.alpha = method_dict['alpha']
23 | self.method = method_dict['method']
24 | self.modulation_starts = method_dict['modulation_starts']
25 | self.modulation_ends = method_dict['modulation_ends']
26 |
27 | self.sigma = method_dict['sigma']
28 | self.U = method_dict['U']
29 | self.eps = method_dict['eps']
30 | self.modality = method_dict['modality']
31 |
32 | def train_loop(
33 | self,
34 | model: BaseClassifierModel,
35 | optimizer: torch.optim.Optimizer,
36 | train_loader: torch.utils.data.DataLoader,
37 | limit_batches: Union[int, float] = float("inf"),
38 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
39 | ):
40 | """The training loop running a single training epoch.
41 |
42 | Args:
43 | model: the LightningModule to train
44 | optimizer: the optimizer, optimizing the LightningModule.
45 | train_loader: The dataloader yielding the training batches.
46 | limit_batches: Limits the batches during this training epoch.
47 | If greater than the number of batches in the ``train_loader``, this has no effect.
48 | scheduler_cfg: The learning rate scheduler configuration.
49 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`
50 | for supported values.
51 |
52 | """
53 | self.fabric.call("on_train_epoch_start")
54 | all_modalitys = list(model.modalitys)
55 | all_modalitys.append('output')
56 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys)
57 | iterable = self.progbar_wrapper(
58 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}"
59 | )
60 |
61 | dependent_modality = {}
62 | for modality in model.modalitys:
63 | dependent_modality[modality] = False
64 | l_t = 0
65 | for batch_idx, batch in enumerate(iterable):
66 | # end epoch if stopping training completely or max batches for this epoch reached
67 | if self.should_stop or batch_idx >= limit_batches:
68 | break
69 |
70 | self.fabric.call("on_train_batch_start", batch, batch_idx)
71 |
72 | # prepare the mask
73 | pt = np.sin(np.pi/2*(min(self.eps,l_t)/self.eps))
74 | N = int(pt * model.n_classes)
75 | mask_t = np.ones(model.n_classes-N)
76 | mask_t = np.pad(mask_t,(0,N))
77 | np.random.shuffle(mask_t)
78 | mask_t = torch.from_numpy(mask_t)
79 | mask_t = mask_t.to(model.device)
80 | mask_t = mask_t.float()
81 | mask_t = mask_t.unsqueeze(0)
82 | mask_t = mask_t.expand(2, -1)
83 | l_t += self.current_epoch/10
84 |
85 | # check if optimizer should step in gradient accumulation
86 | should_optim_step = self.global_step % self.grad_accum_steps == 0
87 | if should_optim_step:
88 | # currently only supports a single optimizer
89 | self.fabric.call("on_before_optimizer_step", optimizer, 0)
90 |
91 | # optimizer step runs train step internally through closure
92 | loss, dependent_modality = self.training_step( model=model, batch=batch,
93 | batch_idx=batch_idx, mask= mask_t,
94 | dependent_modality= dependent_modality,
95 | pt = pt)
96 | optimizer.step()
97 | self.fabric.call("on_before_zero_grad", optimizer)
98 |
99 | optimizer.zero_grad()
100 |
101 | else:
102 | # gradient accumulation -> no optimizer step
103 | self.training_step(model=model, batch=batch, batch_idx=batch_idx)
104 |
105 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction)
106 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)
107 |
108 | # this guard ensures, we only step the scheduler once per global step
109 | # if should_optim_step:
110 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)
111 |
112 | # add output values to progress bar
113 |
114 | self._format_iterable(iterable, self._current_train_return, "train")
115 |
116 | # only increase global step if optimizer stepped
117 | self.global_step += int(should_optim_step)
118 |
119 | self._current_metrics = self.precision_calculator.compute_metrics()
120 | self.fabric.call("on_train_epoch_end")
121 |
122 | def training_step(self, model: BaseClassifierModel, batch, batch_idx, dependent_modality, mask ,pt):
123 |
124 | # TODO: make it simpler and easier to extend
125 | criterion = nn.CrossEntropyLoss()
126 | softmax = nn.Softmax(dim=1)
127 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends:
128 | model(batch,dependent_modality = dependent_modality, mask = mask,\
129 | pt = pt)
130 | else:
131 | model(batch)
132 | # model.Unimodality_Calculate(mask, dependent_modality)
133 |
134 | label = batch['label']
135 | label = label.to(model.device)
136 | # print(a.shape, v.shape, model.head.weight.shape)
137 |
138 | ## our modality-wise normalization on weight and feature
139 | out = model.unimodal_result['output']
140 | loss = criterion(out, label)
141 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends:
142 | loss.backward(retain_graph = True)
143 | for modality in model.modalitys:
144 | loss_uni = criterion(model.unimodal_result[modality],label)
145 | loss_uni.backward()
146 | out_combine = torch.cat([value for key,value in model.unimodal_result.items() if key != 'output'],1)
147 | sft_out = softmax(out_combine)
148 | now_dim = 0
149 | for modality in model.modalitys:
150 | if now_dim < sft_out.shape[1] - model.n_classes:
151 | sft_uni = torch.sum(sft_out[:, now_dim: now_dim + model.n_classes])/(len(label))
152 | else:
153 | sft_uni = torch.sum(sft_out[:, now_dim: ])/(len(label))
154 | dependent_modality[modality] = bool(sft_uni > self.sigma)
155 | now_dim += model.n_classes
156 | else:
157 | loss.backward()
158 |
159 | return loss, dependent_modality
--------------------------------------------------------------------------------
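The class-mask schedule in `train_loop` above can be inspected in isolation. A standalone sketch with hypothetical `eps` and `n_classes`, showing how the masked fraction grows with the accumulated counter `l_t`:

```python
# Illustrative only: reproduce the AMCo class-mask schedule outside the trainer.
import numpy as np
import torch

eps, n_classes = 100.0, 10                       # hypothetical hyperparameters
for l_t in [0.0, 25.0, 50.0, 100.0]:
    pt = np.sin(np.pi / 2 * (min(eps, l_t) / eps))
    N = int(pt * n_classes)                      # number of masked class positions
    mask_t = np.ones(n_classes - N)
    mask_t = np.pad(mask_t, (0, N))
    np.random.shuffle(mask_t)
    mask_t = torch.from_numpy(mask_t).float().unsqueeze(0)
    print(f"l_t={l_t:5.1f}  pt={pt:.2f}  masked={N}")
```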
/balancemm/trainer/CML_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 | from lightning import LightningModule
3 | from torch.optim.optimizer import Optimizer as Optimizer
4 | from .base_trainer import BaseTrainer
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | import os
11 | from collections.abc import Mapping
12 | from functools import partial
13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
14 | from lightning_utilities import apply_to_collection
15 | import lightning as L
16 | import torch
17 | import random
18 |
19 | def conf_loss(conf, pred, conf_x, pred_x, label):
20 | sign = (~((pred == label) & (pred_x != label))).long() # trick 1
21 | #print(sign)
22 | return (max(0, torch.sub(conf_x, conf).sum())), sign.sum()
23 |
24 | class CMLTrainer(BaseTrainer):
25 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}):
26 | super(CMLTrainer,self).__init__(fabric,**para_dict)
27 |
28 | self.modulation_starts = method_dict['modulation_starts']
29 | self.modulation_ends = method_dict['modulation_ends']
30 |
31 | self.lam = method_dict['lam']
32 |
33 | def train_loop(
34 | self,
35 | model: L.LightningModule,
36 | optimizer: torch.optim.Optimizer,
37 | train_loader: torch.utils.data.DataLoader,
38 | limit_batches: Union[int, float] = float("inf"),
39 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
40 | ):
41 | """The training loop running a single training epoch.
42 |
43 | Args:
44 | model: the LightningModule to train
45 | optimizer: the optimizer, optimizing the LightningModule.
46 | train_loader: The dataloader yielding the training batches.
47 | limit_batches: Limits the batches during this training epoch.
48 | If greater than the number of batches in the ``train_loader``, this has no effect.
49 | scheduler_cfg: The learning rate scheduler configuration.
50 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`
51 | for supported values.
52 |
53 | """
54 | self.fabric.call("on_train_epoch_start")
55 | all_modalitys = list(model.modalitys)
56 | all_modalitys.append('output')
57 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys)
58 | iterable = self.progbar_wrapper(
59 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}"
60 | )
61 |
62 | random_dict = list(model.modalitys)
63 | # if self.modality == 3:
64 | # random_dict = ["audio", "visual", "text"]
65 | # else:
66 | # random_dict = ['audio', "visual" ]
67 | random.shuffle(random_dict)
68 | for batch_idx, batch in enumerate(iterable):
69 | # end epoch if stopping training completely or max batches for this epoch reached
70 | if self.should_stop or batch_idx >= limit_batches:
71 | break
72 |
73 | self.fabric.call("on_train_batch_start", batch, batch_idx)
74 |
75 | # check if optimizer should step in gradient accumulation
76 | should_optim_step = self.global_step % self.grad_accum_steps == 0
77 | if should_optim_step:
78 | # currently only supports a single optimizer
79 | self.fabric.call("on_before_optimizer_step", optimizer, 0)
80 |
81 | # optimizer step runs train step internally through closure
82 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx, random_dict_ = random_dict))
83 | self.fabric.call("on_before_zero_grad", optimizer)
84 |
85 | optimizer.zero_grad()
86 |
87 | else:
88 | # gradient accumulation -> no optimizer step
89 | self.training_step(model=model, batch=batch, batch_idx=batch_idx)
90 |
91 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction)
92 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)
93 |
94 | # this guard ensures, we only step the scheduler once per global step
95 | # if should_optim_step:
96 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)
97 |
98 | # add output values to progress bar
99 | self._format_iterable(iterable, self._current_train_return, "train")
100 |
101 | # only increase global step if optimizer stepped
102 | self.global_step += int(should_optim_step)
103 |
104 | self._current_metrics = self.precision_calculator.compute_metrics()
105 | self.fabric.call("on_train_epoch_end")
106 |
107 | def training_step(self, model, batch, batch_idx, random_dict_ ):
108 |
109 | # TODO: make it simpler and easier to extend
110 | criterion = nn.CrossEntropyLoss()
111 | softmax = nn.Softmax(dim=1)
112 | label = batch['label']
113 | label = label.to(model.device)
114 |
115 | _loss_c = 0
116 | modality_num = len(model.modalitys)
117 | modality_list = model.modalitys
118 | key = list(modality_list)
119 | m = {}
120 | if modality_num == 3:
121 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: ######
122 | pad_audio = False
123 | pad_visual = False
124 | pad_text = False
125 | loss_mm = 0
126 | model(batch)
127 | for modality in modality_list:
128 | m[modality] = model.encoder_result[modality]
129 | m['out'] = model.encoder_result['output']
130 | # a, v, t, out = model(batch)
131 | unimodal_result = model.Unimodality_Calculate()
132 | out_s = unimodal_result['output']
133 | # out_a, out_v, out_t = model.AVTCalculate(a, v, t, out)
134 | # out_s = out
135 | random_dict = random_dict_.copy()
136 | for i in range(modality_num - 1):
137 | removed_mm = random_dict.pop()
138 |
139 | out_p = out_s - unimodal_result[removed_mm] +model.fusion_module.fc_out.bias/3
140 |
141 | prediction_s = softmax(out_s)
142 | conf_s, pred_s = torch.max(prediction_s, dim=1)
143 |
144 | prediction_p = softmax(out_p)
145 | conf_p, pred_p = torch.max(prediction_p, dim=1)
146 |
147 | if i ==0 : loss = criterion(out_s, label)
148 |
149 | loss_p = criterion(out_p, label)
150 | loss_pc ,_ = conf_loss(conf_s, pred_s, conf_p, pred_p, label)
151 | loss = loss + loss_p
152 | _loss_c = _loss_c + loss_pc
153 |
154 | out_s = out_p
155 |
156 | loss = (loss) / 3 +self.lam * _loss_c
157 | else:
158 | model(batch)
159 | for modality in modality_list:
160 | m[modality] = model.encoder_result[modality]
161 | m['out'] = model.encoder_result['output']
162 | # a, v, t, out = model(batch)
163 | unimodal_result = model.Unimodality_Calculate()
164 | out_s = unimodal_result['output']
165 |
166 |
167 | loss = criterion(m['out'], label)
168 |
169 | else:
170 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: ######
171 | pad_audio = False
172 | pad_visual = False
173 | pad_text = False
174 | loss_mm = 0
175 | model(batch)
176 | for modality in modality_list:
177 | m[modality] = model.encoder_result[modality]
178 | m['out'] = model.encoder_result['output']
179 | unimodal_result = model.Unimodality_Calculate()
180 | out_s = unimodal_result['output']
181 | random_dict = random_dict_.copy()
182 | for i in range(modality_num - 1):
183 | removed_mm = random_dict.pop()
184 |
185 | out_p = out_s - unimodal_result[removed_mm] +model.fusion_module.fc_out.bias/2
186 |
187 | prediction_s = softmax(out_s)
188 | conf_s, pred_s = torch.max(prediction_s, dim=1)
189 |
190 | prediction_p = softmax(out_p)
191 | conf_p, pred_p = torch.max(prediction_p, dim=1)
192 |
193 | if i ==0 : loss = criterion(out_s, label)
194 |
195 | loss_p = criterion(out_p, label)
196 | loss_pc ,_ = conf_loss(conf_s, pred_s, conf_p, pred_p, label)
197 | loss += loss_p
198 | _loss_c += loss_pc
199 |
200 | out_s = out_p
201 | loss = (loss) / 2 +self.lam * _loss_c
202 | else:
203 | model(batch)
204 | for modality in modality_list:
205 | m[modality] = model.encoder_result[modality]
206 | m['out'] = model.encoder_result['output']
207 | # out_a, out_v = model.AVCalculate(a, v, out)
208 |
209 | loss = criterion(m['out'], label)
210 | loss.backward()
211 |
212 | # # avoid gradients in stored/accumulated values -> prevents potential OOM
213 | # self._current_train_return = apply_to_collection(outputs, dtype=torch.Tensor, function=lambda x: x.detach())
214 |
215 |
216 | return loss
--------------------------------------------------------------------------------
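A toy run of `conf_loss` above (copied verbatim so the sketch is standalone): the loss term is positive only when the prediction with a modality removed is, summed over the batch, more confident than the full multimodal prediction.

```python
# Illustrative only: conf_loss on a synthetic 3-sample batch.
import torch

def conf_loss(conf, pred, conf_x, pred_x, label):             # copied from CML_trainer above
    sign = (~((pred == label) & (pred_x != label))).long()
    return (max(0, torch.sub(conf_x, conf).sum())), sign.sum()

label  = torch.tensor([0, 1, 2])
conf   = torch.tensor([0.9, 0.6, 0.7])    # confidence of the full fusion output
pred   = torch.tensor([0, 1, 2])
conf_x = torch.tensor([0.5, 0.8, 0.7])    # confidence with one modality's contribution removed
pred_x = torch.tensor([0, 0, 2])

loss, sign_sum = conf_loss(conf, pred, conf_x, pred_x, label)
print(float(loss), int(sign_sum))         # 0.0 2  (conf_x.sum() <= conf.sum() here)
```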
/balancemm/trainer/Greedy_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 | from lightning import LightningModule
3 | from torch.optim.optimizer import Optimizer as Optimizer
4 | from .base_trainer import BaseTrainer
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | import os
11 | from collections.abc import Mapping
12 | from functools import partial
13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
14 |
15 | import lightning as L
16 | import torch
17 | import numpy as np
18 | from ..models.avclassify_model import BaseClassifierGreedyModel
19 | class GreedyTrainer(BaseTrainer):
20 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}):
21 | super(GreedyTrainer,self).__init__(fabric,**para_dict)
22 | self.alpha = method_dict['alpha']
23 | self.modulation_starts = method_dict['modulation_starts']
24 | self.modulation_ends = method_dict['modulation_ends']
25 |
26 | self.modality = method_dict['modality']
27 | self.window_size = method_dict['window_size']
28 | self.M_bypass_modal_0 = 0
29 | self.M_bypass_modal_1 = 0
30 | self.M_main_modal_0 = 0
31 | self.M_main_modal_1 = 0
32 | self.curation_mode = False
33 | self.caring_modality = 0
34 | self.curation_step = self.window_size
35 | self.speed = 0
36 |
37 | def train_loop(
38 | self,
39 | model: BaseClassifierGreedyModel,
40 | optimizer: torch.optim.Optimizer,
41 | train_loader: torch.utils.data.DataLoader,
42 | limit_batches: Union[int, float] = float("inf"),
43 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
44 | ):
45 | """The training loop running a single training epoch.
46 |
47 | Args:
48 | model: the LightningModule to train
49 | optimizer: the optimizer, optimizing the LightningModule.
50 | train_loader: The dataloader yielding the training batches.
51 | limit_batches: Limits the batches during this training epoch.
52 | If greater than the number of batches in the ``train_loader``, this has no effect.
53 | scheduler_cfg: The learning rate scheduler configuration.
54 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`
55 | for supported values.
56 |
57 | """
58 | self.fabric.call("on_train_epoch_start")
59 | all_modalitys = list(model.modalitys)
60 | all_modalitys.append('output')
61 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys)
62 | iterable = self.progbar_wrapper(
63 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}"
64 | )
65 |
66 | for batch_idx, batch in enumerate(iterable):
67 | # end epoch if stopping training completely or max batches for this epoch reached
68 | if self.should_stop or batch_idx >= limit_batches:
69 | break
70 | self.fabric.call("on_train_batch_start", batch, batch_idx)
71 |
72 |
73 | # check if optimizer should step in gradient accumulation
74 | should_optim_step = self.global_step % self.grad_accum_steps == 0
75 | if should_optim_step:
76 | # currently only supports a single optimizer
77 | self.fabric.call("on_before_optimizer_step", optimizer, 0)
78 |
79 | # optimizer step runs train step internally through closure
80 | self.training_step(model=model, batch=batch, batch_idx=batch_idx)
81 | if not self.curation_mode:
82 | self.speed = self.compute_learning_speed(model)
83 | if abs(self.speed) > self.alpha:
84 | biased_direction=np.sign(self.speed)
85 | self.curation_mode = True
86 | self.curation_step = 0
87 |
88 |                         if biased_direction==-1: #BDR0<BDR1
89 |                             self.caring_modality = 1
90 |                         elif biased_direction==1: #BDR0>BDR1
91 |                             self.caring_modality = 0
92 | else:
93 | self.curation_mode = False
94 | self.caring_modality = 0
95 | else:
96 | self.curation_step +=1
97 | if self.curation_step==self.window_size:
98 | self.curation_mode=False
99 | optimizer.step()
100 | self.fabric.call("on_before_zero_grad", optimizer)
101 |
102 | optimizer.zero_grad()
103 |
104 | else:
105 | # gradient accumulation -> no optimizer step
106 | self.training_step(model=model, batch=batch, batch_idx=batch_idx)
107 |
108 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction)
109 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)
110 |
111 | # this guard ensures, we only step the scheduler once per global step
112 | # if should_optim_step:
113 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)
114 |
115 | # add output values to progress bar
116 |
117 | self._format_iterable(iterable, self._current_train_return, "train")
118 |
119 | # only increase global step if optimizer stepped
120 | self.global_step += int(should_optim_step)
121 |
122 | self._current_metrics = self.precision_calculator.compute_metrics()
123 | self.fabric.call("on_train_epoch_end")
124 |
125 | def training_step(self, model: BaseClassifierGreedyModel, batch, batch_idx):
126 |
127 | # TODO: make it simpler and easier to extend
128 | modality_list = model.modalitys
129 | criterion = nn.CrossEntropyLoss()
130 | softmax = nn.Softmax(dim=1)
131 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends:
132 | model(batch,self.curation_mode,self.caring_modality)
133 | else:
134 | model(batch,self.curation_mode,self.caring_modality)
135 | model.Unimodality_Calculate()
136 |
137 | label = batch['label']
138 | label = label.to(model.device)
139 | # print(a.shape, v.shape, model.head.weight.shape)
140 |
141 | ## our modality-wise normalization on weight and feature
142 | out = model.encoder_result['output']
143 | loss = criterion(out, label)
144 | loss.backward()
145 | return loss
146 |
147 | def compute_learning_speed(self,model:BaseClassifierGreedyModel):
148 | modality_list = model.modalitys
149 | wn_main, wn_bypass = [0]*len(modality_list), [0]*len(modality_list)
150 | gn_main, gn_bypass = [0]*len(modality_list), [0]*len(modality_list)
151 | for name, parameter in model.named_parameters():
152 | wn = (parameter ** 2).sum().item()
153 | gn = (parameter.grad.data ** 2).sum().item()#(grad ** 2).sum().item()
154 | if 'mmtm_layers' in name:
155 | shared=True
156 | for ind, modal in enumerate(modality_list):
157 | if modal in name:
158 | wn_bypass[ind]+=wn
159 | gn_bypass[ind]+=gn
160 | shared = False
161 | if shared:
162 | for ind, modal in enumerate(modality_list):
163 | wn_bypass[ind]+=wn
164 | gn_bypass[ind]+=gn
165 |
166 | else:
167 | for ind, modal in enumerate(modality_list):
168 | if modal in name:
169 | wn_main[ind]+=wn
170 | gn_main[ind]+=gn
171 |
172 | self.M_bypass_modal_0 += gn_bypass[0]/wn_bypass[0]
173 | self.M_bypass_modal_1 += gn_bypass[1]/wn_bypass[1]
174 | self.M_main_modal_0 += gn_main[0]/wn_main[0]
175 | self.M_main_modal_1 += gn_main[1]/wn_main[1]
176 |
177 | BDR_0 = np.log10(self.M_bypass_modal_0/self.M_main_modal_0)
178 | BDR_1 = np.log10(self.M_bypass_modal_1/self.M_main_modal_1)
179 |
180 | return BDR_0 - BDR_1
--------------------------------------------------------------------------------
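A numeric illustration of the balance signal returned by `compute_learning_speed` above (the ratio values and `alpha` are hypothetical): each modality gets a log-ratio `BDR_m = log10(M_bypass_m / M_main_m)`, and curation is triggered when the two BDRs differ by more than `alpha`.

```python
# Illustrative only: how the BDR difference drives the curation decision.
import numpy as np

M_bypass = [2.0e-3, 5.0e-4]   # hypothetical accumulated grad/weight ratios (MMTM bypass params)
M_main   = [1.0e-4, 1.0e-4]   # hypothetical accumulated grad/weight ratios (modality-specific params)

BDR_0 = np.log10(M_bypass[0] / M_main[0])   # ~1.30
BDR_1 = np.log10(M_bypass[1] / M_main[1])   # ~0.70
speed = BDR_0 - BDR_1                       # ~0.60

alpha = 0.3                                  # hypothetical threshold
if abs(speed) > alpha:
    print(f"|speed|={abs(speed):.2f} > alpha={alpha}: enter curation mode "
          f"(sign {int(np.sign(speed))} selects which modality is favoured)")
```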
/balancemm/trainer/MBSD_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 | from lightning import LightningModule
3 | from torch.optim.optimizer import Optimizer as Optimizer
4 | from .base_trainer import BaseTrainer
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | import os
11 | from collections.abc import Mapping
12 | from functools import partial
13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
14 |
15 | import lightning as L
16 | import torch
17 |
18 |
19 | class MBSDTrainer(BaseTrainer):
20 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}):
21 | super(MBSDTrainer,self).__init__(fabric,**para_dict)
22 | self.modulation_starts = method_dict['modulation_starts']
23 | self.modulation_ends = method_dict['modulation_ends']
24 |
25 | def train_loop(
26 | self,
27 | model: L.LightningModule,
28 | optimizer: torch.optim.Optimizer,
29 | train_loader: torch.utils.data.DataLoader,
30 | limit_batches: Union[int, float] = float("inf"),
31 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
32 | ):
33 | """The training loop running a single training epoch.
34 |
35 | Args:
36 | model: the LightningModule to train
37 | optimizer: the optimizer, optimizing the LightningModule.
38 | train_loader: The dataloader yielding the training batches.
39 | limit_batches: Limits the batches during this training epoch.
40 | If greater than the number of batches in the ``train_loader``, this has no effect.
41 | scheduler_cfg: The learning rate scheduler configuration.
42 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`
43 | for supported values.
44 |
45 | """
46 | self.fabric.call("on_train_epoch_start")
47 | all_modalitys = list(model.modalitys)
48 | all_modalitys.append('output')
49 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys)
50 | iterable = self.progbar_wrapper(
51 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}"
52 | )
53 |
54 |
55 | for batch_idx, batch in enumerate(iterable):
56 | # end epoch if stopping training completely or max batches for this epoch reached
57 | if self.should_stop or batch_idx >= limit_batches:
58 | break
59 |
60 | self.fabric.call("on_train_batch_start", batch, batch_idx)
61 |
62 | # check if optimizer should step in gradient accumulation
63 | should_optim_step = self.global_step % self.grad_accum_steps == 0
64 | if should_optim_step:
65 | # currently only supports a single optimizer
66 | self.fabric.call("on_before_optimizer_step", optimizer, 0)
67 |
68 | # optimizer step runs train step internally through closure
69 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx))
70 | self.fabric.call("on_before_zero_grad", optimizer)
71 |
72 | optimizer.zero_grad()
73 |
74 | else:
75 | # gradient accumulation -> no optimizer step
76 | self.training_step(model=model, batch=batch, batch_idx=batch_idx)
77 |
78 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction)
79 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)
80 |
81 | # this guard ensures, we only step the scheduler once per global step
82 | # if should_optim_step:
83 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)
84 |
85 | # add output values to progress bar
86 |
87 | self._format_iterable(iterable, self._current_train_return, "train")
88 |
89 | # only increase global step if optimizer stepped
90 | self.global_step += int(should_optim_step)
91 |
92 | self._current_metrics = self.precision_calculator.compute_metrics()
93 | self.fabric.call("on_train_epoch_end")
94 |
95 | def training_step(self, model, batch, batch_idx , dependent_modality : str = 'none'):
96 |
97 | # TODO: make it simpler and easier to extend
98 | criterion = nn.CrossEntropyLoss()
99 | softmax = nn.Softmax(dim=1)
100 | modality_list = model.modalitys
101 | key = list(modality_list)
102 | m = {}
103 | loss = {}
104 | loss_modality = {}
105 | prediction = {}
106 | y_pred = {}
107 | model(batch)
108 | for modality in modality_list:
109 | m[modality] = model.encoder_result[modality]
110 | m['out'] = model.encoder_result['output']
111 | model.Unimodality_Calculate()
112 | # out_a, out_v = model.AVCalculate(a, v, out)
113 | label = batch['label']
114 | device = model.device
115 | # print(a.shape, v.shape, model.head.weight.shape)
116 |
117 | ## our modality-wise normalization on weight and feature
118 |
119 | loss['out'] = criterion(m['out'], label)
120 | for modality in modality_list:
121 | loss_modality[modality] = criterion(model.unimodal_result[modality], label)
122 | # loss_v = criterion(unimodal_result[], label)
123 | # loss_a = criterion(out_a, label)
124 |
125 | for modality in modality_list:
126 | prediction[modality] = softmax(model.unimodal_result[modality])
127 | # prediction_a = softmax(out_a)
128 | # prediction_v = softmax(out_v)
129 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends:
130 | if len(modality_list) == 2:
131 |
132 | loss_RS = 1/model.unimodal_result[key[0]].shape[1] * torch.sum((model.unimodal_result[key[0]] - model.unimodal_result[key[1]])**2, dim = 1)
133 |
134 | w = torch.tensor([0.0 for _ in range(len(m['out']))])
135 | w = w.to(device)
136 | for modality in modality_list:
137 | y_pred[modality] = prediction[modality]
138 | y_pred[modality] = y_pred[modality].argmax(dim=-1)
139 | # y_pred_a = prediction_a
140 | # y_pred_a = y_pred_a.argmax(dim = -1)
141 | # y_pred_v = prediction_v
142 | # y_pred_v = y_pred_v.argmax(dim = -1)
143 | ps = torch.tensor([0.0 for _ in range(len(m['out']))])
144 | ps = ps.to(device)
145 | pw = torch.tensor([0.0 for _ in range(len(m['out']))])
146 | pw = pw.to(device)
147 | for i in range(len(m['out'])):
148 | if y_pred[key[0]][i] == label[i] or y_pred[key[1]][i] == label[i]:
149 | w[i] = max(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]]) - min(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]])
150 | ps[i] = max(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]])
151 | pw[i] = min(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]])
152 |
153 | loss_KL = F.kl_div(ps, pw, reduction = 'none')
154 | w = w.reshape(1,-1)
155 | loss_KL = loss_KL.reshape(-1,1)
156 | loss_KL = torch.mm(w, loss_KL) / len(m['out'])
157 | loss_RS = loss_RS.reshape(-1,1)
158 | loss_RS = torch.mm(w, loss_RS) / len(m['out'])
159 | total_loss = loss['out'] + loss_modality[key[0]] + loss_modality[key[1]] + loss_RS.squeeze() + loss_KL.squeeze() ## squeeze() drops the singleton dimension
160 | else:
161 |
162 | w1 = torch.tensor([0.0 for _ in range(len(m['out']))])
163 | w1 = w1.to(device)
164 | w2 = torch.tensor([0.0 for _ in range(len(m['out']))])
165 | w2 = w2.to(device)
166 | w3 = torch.tensor([0.0 for _ in range(len(m['out']))])
167 | w3 = w3.to(device)
168 | ps1 = torch.tensor([0.0 for _ in range(len(m['out']))])
169 | ps2 = torch.tensor([0.0 for _ in range(len(m['out']))])
170 | ps3 = torch.tensor([0.0 for _ in range(len(m['out']))])
171 | ps1 = ps1.to(device)
172 | ps2 = ps2.to(device)
173 | ps3 = ps3.to(device)
174 | pw1 = torch.tensor([0.0 for _ in range(len(m['out']))])
175 | pw2 = torch.tensor([0.0 for _ in range(len(m['out']))])
176 | pw3 = torch.tensor([0.0 for _ in range(len(m['out']))])
177 | pw1 = pw1.to(device)
178 | pw2 = pw2.to(device)
179 | pw3 = pw3.to(device)
180 | for modality in modality_list:
181 | y_pred[modality] = prediction[modality]
182 | y_pred[modality] = y_pred[modality].argmax(dim=-1)
183 |
184 | for i in range(len(m['out'])):
185 | if y_pred[key[0]][i] == label[i] or y_pred[key[1]][i] == label[i]:
186 | w1[i] = max(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]]) - min(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]])
187 | if y_pred[key[0]][i] == label[i] or y_pred[key[2]][i] == label[i]:
188 | w2[i] = max(prediction[key[0]][i][label[i]], prediction[key[2]][i][label[i]]) - min(prediction[key[0]][i][label[i]], prediction[key[2]][i][label[i]])
189 | if y_pred[key[1]][i] == label[i] or y_pred[key[2]][i] == label[i]:
190 | w3[i] = max(prediction[key[1]][i][label[i]], prediction[key[2]][i][label[i]]) - min(prediction[key[1]][i][label[i]], prediction[key[2]][i][label[i]])
191 | ps1[i] = max(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]])
192 | pw1[i] = min(prediction[key[0]][i][label[i]], prediction[key[1]][i][label[i]])
193 | ps2[i] = max(prediction[key[0]][i][label[i]], prediction[key[2]][i][label[i]])
194 | pw2[i] = min(prediction[key[0]][i][label[i]], prediction[key[2]][i][label[i]])
195 | ps3[i] = max(prediction[key[1]][i][label[i]], prediction[key[2]][i][label[i]])
196 | pw3[i] = min(prediction[key[1]][i][label[i]], prediction[key[2]][i][label[i]])
197 | loss_RS1 = 1/model.unimodal_result[key[0]].shape[1] * torch.sum((prediction[key[0]]-prediction[key[1]])**2,dim=1)
198 | loss_RS2 = 1/model.unimodal_result[key[0]].shape[1] * torch.sum((prediction[key[0]]-prediction[key[2]])**2,dim=1)
199 | loss_RS3 = 1/model.unimodal_result[key[0]].shape[1] * torch.sum((prediction[key[1]]-prediction[key[2]])**2,dim=1)
200 |
201 | loss_KL1 = F.kl_div(ps1, pw1, reduction = 'none')
202 | loss_KL2 = F.kl_div(ps2, pw2, reduction = 'none')
203 | loss_KL3 = F.kl_div(ps3, pw3, reduction = 'none')
204 |
205 | w1 = w1.reshape(1,-1)
206 | w2 = w2.reshape(1,-1)
207 | w3 = w3.reshape(1,-1)
208 | loss_KL1 = loss_KL1.reshape(-1,1)
209 | loss_KL1 = torch.mm(w1, loss_KL1) / len(m['out'])
210 | loss_KL2 = loss_KL2.reshape(-1,1)
211 | loss_KL2 = torch.mm(w2, loss_KL2) / len(m['out'])
212 | loss_KL3 = loss_KL3.reshape(-1,1)
213 | loss_KL3 = torch.mm(w3, loss_KL3) / len(m['out'])
214 | loss_KL = (loss_KL1 + loss_KL2 + loss_KL3) / 3
215 |
216 | loss_RS1 = loss_RS1.reshape(-1,1)
217 | loss_RS2 = loss_RS2.reshape(-1,1)
218 | loss_RS3 = loss_RS3.reshape(-1,1)
219 | loss_RS1 = torch.mm(w1, loss_RS1) / len(m['out'])
220 | loss_RS2 = torch.mm(w2, loss_RS2) / len(m['out'])
221 | loss_RS3 = torch.mm(w3, loss_RS3) / len(m['out'])
222 | loss_RS = (loss_RS1 + loss_RS2 + loss_RS3) / 3
223 |
224 | total_loss = loss['out'] + loss_modality[key[0]] + loss_modality[key[1]] + loss_modality[key[2]] + loss_KL.squeeze() + loss_RS.squeeze() ## squeeze() drops the singleton dimension
225 |
226 | else:
227 |
228 | total_loss = loss['out']
229 | total_loss.backward()
230 |
231 | return total_loss
--------------------------------------------------------------------------------
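The per-sample loops in `MBSDTrainer.training_step` compute, for each example, a discrepancy weight `w` and the stronger/weaker true-class confidences `ps`/`pw`. A minimal vectorized sketch of that two-modality weighting follows; `p0`, `p1`, and the function name are illustrative and assumed to be the softmaxed unimodal logits, not part of the toolkit.

```python
import torch

def mbsd_pair_weights(p0: torch.Tensor, p1: torch.Tensor, label: torch.Tensor):
    """Vectorized form of the per-sample weighting in MBSDTrainer.training_step.

    p0, p1: softmax outputs of the two unimodal heads, shape [B, C]
    label:  ground-truth class indices, shape [B]
    Returns (w, ps, pw): discrepancy weight, stronger and weaker true-class confidence.
    """
    idx = torch.arange(label.size(0), device=label.device)
    c0, c1 = p0[idx, label], p1[idx, label]           # true-class confidence per modality
    correct = ((p0.argmax(-1) == label) | (p1.argmax(-1) == label)).float()
    ps = torch.maximum(c0, c1) * correct              # stronger modality's confidence
    pw = torch.minimum(c0, c1) * correct              # weaker modality's confidence
    w = ps - pw                                       # zero when both modalities misclassify
    return w, ps, pw
```

Keeping the masked entries at zero matches the trainer, where `w`, `ps`, and `pw` stay zero whenever both modalities get the sample wrong.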
/balancemm/trainer/MLA_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 | from lightning import LightningModule
3 | from torch.optim.optimizer import Optimizer as Optimizer
4 | from .base_trainer import BaseTrainer
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | import os
11 | from collections.abc import Mapping
12 | from functools import partial
13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
14 | from lightning_utilities import apply_to_collection
15 | import lightning as L
16 | import torch
17 |
18 |
19 |
20 | class GSPlugin():
21 | def __init__(self, device, gs_flag = True ):
22 |
23 | super().__init__()
24 |
25 | # dtype = torch.cuda.FloatTensor # run on GPU
26 | if device != ' ':
27 | device = 'cuda:0'
28 | else:
29 | device = 'cpu'
30 | with torch.no_grad():
31 | # self.Pl = torch.eye(1024).to(device)
32 | # depend on modality_size
33 | self.Pl = torch.eye(512).to(device)
34 | self.exp_count = 0
35 |
36 | # @torch.no_grad()
37 | def before_update(self, model, before_batch_input, batch_index, len_dataloader, train_exp_counter):
38 | lamda = batch_index / len_dataloader + 1
39 | alpha = 1.0 * 0.1 ** lamda
40 |
41 | # x_mean = torch.mean(strategy.mb_x, 0, True)
42 | if train_exp_counter != 0:
43 | for n, w in model.named_parameters():
44 | if n == "weight":
45 |
46 | r = torch.mean(before_batch_input, 0, True)
47 | k = torch.mm(self.Pl, torch.t(r))
48 | self.Pl = torch.sub(self.Pl, torch.mm(k, torch.t(k)) / (alpha + torch.mm(k, r)))
49 |
50 | pnorm2 = torch.norm(self.Pl.data, p='fro')
51 |
52 | self.Pl.data = self.Pl.data / pnorm2
53 | w.grad.add_(torch.mm(w.grad.data, torch.t(self.Pl.data)))
54 |
55 |
56 |
57 |
58 | class MLATrainer(BaseTrainer):
59 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {},args = {}):
60 | super(MLATrainer,self).__init__(fabric,**para_dict)
61 | self.modulation_starts = method_dict['modulation_starts']
62 | self.modulation_ends = method_dict['modulation_ends']
63 | self.device = args.model['device']
64 | self.gs_plugin = GSPlugin(self.device)
65 | self.criterion = nn.CrossEntropyLoss()
66 |
67 |
68 | def train_loop(
69 | self,
70 | model: L.LightningModule,
71 | optimizer: torch.optim.Optimizer,
72 | train_loader: torch.utils.data.DataLoader,
73 | limit_batches: Union[int, float] = float("inf"),
74 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
75 | ):
76 | """The training loop running a single training epoch.
77 |
78 | Args:
79 | model: the LightningModule to train
80 | optimizer: the optimizer, optimizing the LightningModule.
81 | train_loader: The dataloader yielding the training batches.
82 | limit_batches: Limits the batches during this training epoch.
83 | If greater than the number of batches in the ``train_loader``, this has no effect.
84 | scheduler_cfg: The learning rate scheduler configuration.
85 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`
86 | for supported values.
87 |
88 | """
89 | modality_list = model.modalitys
90 | all_modalitys = list(model.modalitys)
91 | all_modalitys.append('output')
92 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys)
93 | iterable = self.progbar_wrapper(
94 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}"
95 | )
96 | len_dataloader = len(train_loader)
97 | self.fabric.call("on_train_epoch_start")
98 |
99 | for batch_idx, batch in enumerate(iterable):
100 | # end epoch if stopping training completely or max batches for this epoch reached
101 | if self.should_stop or batch_idx >= limit_batches:
102 | break
103 |
104 | self.fabric.call("on_train_batch_start", batch, batch_idx)
105 |
106 | # check if optimizer should step in gradient accumulation
107 | should_optim_step = self.global_step % self.grad_accum_steps == 0
108 | if should_optim_step:
109 | self.training_step(model=model, batch=batch, batch_idx=batch_idx,len_dataloader=len_dataloader,optimizer=optimizer)
110 |
111 | else:
112 | # gradient accumulation -> no optimizer step
113 | self.training_step(model=model, batch=batch, batch_idx=batch_idx,len_dataloader=len_dataloader,optimizer=optimizer)
114 | # self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction)
115 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)
116 |
117 | # this guard ensures, we only step the scheduler once per global step
118 | # if should_optim_step:
119 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)
120 |
121 | # add output values to progress bar
122 | self._format_iterable(iterable, self._current_train_return, "train")
123 | # only increase global step if optimizer stepped
124 | self.global_step += int(should_optim_step)
125 |
126 | # self._current_metrics = self.precision_calculator.compute_metrics()
127 | self.fabric.call("on_train_epoch_end")
128 |
129 | def training_step(self, model, batch, batch_idx,len_dataloader,optimizer):
130 |
131 | # TODO: make it simpler and easier to extend
132 | modality_list = model.modalitys
133 | key = list(modality_list)
134 | # out_v,out_a,out = unimodal_result['visual'], unimodal_result['audio'], unimodal_result['output']
135 | label = batch['label']
136 | # label = label.to(model.device)
137 | loss = 0
138 | loss_modality = {}
139 | # feature =model(batch)
140 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends:
141 | # feature = {}
142 | # for modality in modality_list:
143 | # feature[modality] = model.encoder_result[modality].clone().contiguous()
144 | for modality in modality_list:
145 | feature = model.feature_extract(batch, modality = modality)
146 | out = model.fusion_module.fc_out(feature)
147 | loss = self.criterion(out,label)
148 | try:
149 | loss.backward()
150 | except RuntimeError as e:
151 |
152 | print("Computation graph:")
153 | for name, param in model.named_parameters():
154 | if param.grad is not None:
155 | print(f"{name} grad shape: {param.grad.shape}")
156 | raise e
157 | encoder_out = feature.detach()
158 | self.gs_plugin.before_update(model.fusion_module.fc_out, encoder_out,
159 | batch_idx, len_dataloader, self.gs_plugin.exp_count)
160 |
161 | self.fabric.call("on_before_optimizer_step", optimizer, 0)
162 | optimizer.step()
163 | self.fabric.call("on_before_zero_grad", optimizer)
164 | optimizer.zero_grad()
165 | loss_modality[modality] = loss.item()
166 | self.gs_plugin.exp_count += 1
167 |
168 | for n, p in model.named_parameters():
169 | if p.grad is not None:
170 | del p.grad
171 |
172 | loss = self.alpha*loss_modality[key[0]]+(1-self.alpha)*loss_modality[key[1]]
173 |
174 |
175 | else:
176 | loss = self.criterion(model.unimodal_result['output'], label)
177 | loss.backward()
178 |
179 | return loss
180 |
--------------------------------------------------------------------------------
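`GSPlugin.before_update` maintains a projector `Pl` through a rank-one update followed by Frobenius normalization before correcting the shared head's gradient. Below is a standalone sketch of that single update step, assuming a feature batch of shape `[B, D]`; the function name and toy sizes are for illustration only.

```python
import torch

def gs_projector_update(Pl: torch.Tensor, feature_batch: torch.Tensor, alpha: float) -> torch.Tensor:
    """One rank-one projector update, as in GSPlugin.before_update.

    Pl:            current projector, shape [D, D]
    feature_batch: features entering the shared head, shape [B, D]
    alpha:         decay term (1.0 * 0.1 ** lamda in the trainer)
    """
    r = feature_batch.mean(dim=0, keepdim=True)       # [1, D] mean feature
    k = Pl @ r.t()                                    # [D, 1]
    Pl = Pl - (k @ k.t()) / (alpha + k @ r)           # mirrors torch.mm(k, r) in the trainer
    return Pl / torch.norm(Pl, p='fro')               # Frobenius normalization

# toy check: D = 4, batch of 8
Pl = gs_projector_update(torch.eye(4), torch.randn(8, 4), alpha=0.1)
```

Note that the division term mirrors the trainer's `torch.mm(k, r)`, which broadcasts element-wise rather than using the scalar `rᵀk` of textbook recursive least squares.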
/balancemm/trainer/OGM_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 | from torch.optim.optimizer import Optimizer as Optimizer
3 | from .base_trainer import BaseTrainer
4 | import copy
5 | import torch
6 | import torch.nn as nn
7 | from balancemm.models.avclassify_model import BaseClassifierModel
8 | from collections.abc import Mapping
9 | from functools import partial
10 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
11 | import lightning as L
12 | import torch
13 |
14 | class OGMTrainer(BaseTrainer):
15 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}):
16 | super(OGMTrainer,self).__init__(fabric,**para_dict)
17 | self.alpha = method_dict['alpha']
18 | self.method = method_dict['method']
19 | self.modulation_starts = method_dict['modulation_starts']
20 | self.modulation_ends = method_dict['modulation_ends']
21 | # self.modality = method_dict['modality']
22 |
23 | def train_loop(
24 | self,
25 | model: L.LightningModule,
26 | optimizer: torch.optim.Optimizer,
27 | train_loader: torch.utils.data.DataLoader,
28 | limit_batches: Union[int, float] = float("inf"),
29 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
30 | ):
31 | """The training loop running a single training epoch.
32 |
33 | Args:
34 | model: the LightningModule to train
35 | optimizer: the optimizer, optimizing the LightningModule.
36 | train_loader: The dataloader yielding the training batches.
37 | limit_batches: Limits the batches during this training epoch.
38 | If greater than the number of batches in the ``train_loader``, this has no effect.
39 | scheduler_cfg: The learning rate scheduler configuration.
40 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`
41 | for supported values.
42 |
43 | """
44 | self.fabric.call("on_train_epoch_start")
45 | all_modalitys = list(model.modalitys)
46 | all_modalitys.append('output')
47 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys)
48 | iterable = self.progbar_wrapper(
49 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}"
50 | )
51 |
52 | for batch_idx, batch in enumerate(iterable):
53 | # end epoch if stopping training completely or max batches for this epoch reached
54 | if self.should_stop or batch_idx >= limit_batches:
55 | break
56 |
57 | self.fabric.call("on_train_batch_start", batch, batch_idx)
58 |
59 | # check if optimizer should step in gradient accumulation
60 | should_optim_step = self.global_step % self.grad_accum_steps == 0
61 | if should_optim_step:
62 | # currently only supports a single optimizer
63 | self.fabric.call("on_before_optimizer_step", optimizer, 0)
64 |
65 | # optimizer step runs train step internally through closure
66 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx))
67 | self.fabric.call("on_before_zero_grad", optimizer)
68 | # torch.cuda.empty_cache()
69 | optimizer.zero_grad()
70 |
71 | else:
72 | # gradient accumulation -> no optimizer step
73 | self.training_step(model=model, batch=batch, batch_idx=batch_idx)
74 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction)
75 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)
76 |
77 | # this guard ensures, we only step the scheduler once per global step
78 | # if should_optim_step:
79 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)
80 |
81 | # add output values to progress bar
82 |
83 | self._format_iterable(iterable, self._current_train_return, "train")
84 |
85 | # only increase global step if optimizer stepped
86 | self.global_step += int(should_optim_step)
87 | self._current_metrics = self.precision_calculator.compute_metrics()
88 |
89 | def training_step(self, model : BaseClassifierModel, batch, batch_idx):
90 |
91 | # TODO: make it simpler and easier to extend
92 | softmax = nn.Softmax(dim=1)
93 | criterion = nn.CrossEntropyLoss()
94 | relu = nn.ReLU(inplace=True)
95 | tanh = nn.Tanh()
96 | label = batch['label']
97 | label = label.to(model.device)
98 | model(batch)
99 | # model.Unimodality_Calculate()
100 | loss = criterion(model.unimodal_result['output'], label)
101 | loss.backward()
102 | modality_list = model.modalitys
103 |
104 | # Modulation starts here !
105 | modality_nums = len(modality_list)
106 | scores = {}
107 | ratios = {}
108 | coeffs = {}
109 | minscore = float('inf')
110 | #Calculate the scores
111 |
112 | for modality in modality_list:
113 | if modality_nums == 2 or self.method == 'OGM_GE3':
114 | score_modality = sum([softmax(model.unimodal_result[modality])[i][label[i]] for i in range(model.unimodal_result['output'].size(0))])
115 | elif modality_nums == 3:
116 | score_modality = sum([softmax(torch.cos(model.unimodal_result[modality]))[i][label[i]] if label[i] == torch.argmax(model.unimodal_result[modality][i]) else 0 for i in range(model.unimodal_result['output'].size(0))])
117 | else:
118 | raise("Wrong number of modalitys for OGM, it should be 2 or 3, but given {:0}".format(modality_nums))
119 | try:
120 | scores[modality] = score_modality.detach().clone()
121 | except:
122 | continue
123 | minscore = min(score_modality, minscore)
124 | ##Calculate the ratios
125 | if self.method == 'OGM_GE3':
126 | for modality in modality_list:
127 | count = 0
128 | ratios[modality] = 0
129 | for modality_another in modality_list:
130 | if modality_another == modality:
131 | continue
132 | ratios[modality] += scores[modality].detach().clone()/(scores[modality_another]+ 1e-5)
133 | count += 1
134 | ratios[modality]/= count
135 | elif self.method == "OGM_GE":
136 | for modality in modality_list:
137 | ratios[modality] = scores[modality].detach().clone()
138 | if modality_nums == 2:
139 | for modality_another in modality_list:
140 | if modality_another == modality:
141 | continue
142 |
143 | ratios[modality] /= (scores[modality_another]+ 1e-5) # avoid division by zero
144 | if modality_nums == 3:
145 | ratios[modality] /= (minscore + 1e-5)
146 | # Calculate the coefficients
147 | for modality in modality_list:
148 | if ratios[modality] > 1 :
149 | coeffs[modality] = max(1 - tanh(self.alpha * relu(ratios[modality])),0)
150 | else:
151 | coeffs[modality] = 1
152 |
153 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: # bug fixed
154 | for name, parms in model.named_parameters():
155 | layer = str(name)
156 | for modality in modality_list:
157 |
158 | if modality in layer and len(parms.grad.size()) != 1: ## skip 1-D (bias) gradients; only modulate weight gradients
159 | if self.method == 'OGM_GE' or self.method == 'OGM_GE3': # bug fixed
160 | parms.grad = parms.grad * coeffs[modality] + \
161 | torch.zeros_like(parms.grad).normal_(0, parms.grad.std().item() + 1e-8)
162 | elif self.method == 'OGM':
163 | parms.grad *= coeffs[modality]
164 | else:
165 | pass
166 |
167 |
168 | return loss
--------------------------------------------------------------------------------
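The modulation in `OGMTrainer.training_step` reduces to a per-modality coefficient derived from the confidence ratio, plus Gaussian noise in the OGM-GE variants. A compact sketch of those two pieces, assuming scalar per-batch scores; the function names here are illustrative, not toolkit API.

```python
import math
import torch

def ogm_ge_coefficient(score_this: float, score_other: float, alpha: float) -> float:
    """Down-scaling coefficient for the dominant modality (OGM / OGM-GE)."""
    ratio = score_this / (score_other + 1e-5)
    if ratio > 1:                                   # this modality dominates the other
        return max(1 - math.tanh(alpha * ratio), 0.0)
    return 1.0

def modulate_grad(grad: torch.Tensor, coeff: float, generalization: bool = True) -> torch.Tensor:
    """Scale a weight gradient; OGM-GE also adds zero-mean noise with the gradient's own std."""
    grad = grad * coeff
    if generalization:
        grad = grad + torch.zeros_like(grad).normal_(0, grad.std().item() + 1e-8)
    return grad
```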
/balancemm/trainer/OPM_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 | from torch.optim.optimizer import Optimizer as Optimizer
3 | from .base_trainer import BaseTrainer
4 |
5 | import torch
6 | import torch.nn as nn
7 | from balancemm.models.avclassify_model import BaseClassifierModel
8 | from collections.abc import Mapping
9 | from functools import partial
10 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
11 | import lightning as L
12 | import torch
13 | import numpy as np
14 |
15 |
16 | class Modality_drop():
17 | def __init__(self, dim_list, p_exe=0.7, device='cuda'):
18 | self.dim_list = dim_list
19 | self.p_exe = p_exe
20 | self.device = device
21 |
22 | def execute_drop(self, feat_list, q, model):
23 | modality_list = list(model.modalitys)
24 | B = feat_list[modality_list[0]].shape[0] # batch size
25 | exe_drop = torch.tensor(np.random.rand(1)).to(device=self.device) >= 1-self.p_exe
26 | if not exe_drop:
27 | return feat_list, torch.ones([B], dtype=torch.float32, device=self.device)
28 |
29 | d_sum = sum(self.dim_list.values())
30 | q_sum = sum(self.dim_list[m] * q[m] for m in modality_list)
31 | theta = q_sum/d_sum
32 | num_mod = len(modality_list)
33 | q_temp = torch.tensor([q[m] for m in modality_list], device=self.device)
34 | mask = torch.distributions.Bernoulli(1 - q_temp).sample([B, 1]).permute(2, 1, 0).contiguous().reshape(num_mod, B, -1).to(self.device)
35 |
36 | cleaned = {}
37 | for idx, modality in enumerate(modality_list):
38 | D = feat_list[modality].shape[1]
39 | current_mask = mask[idx].expand(-1,D)
40 | cleaned_fea = torch.mul(feat_list[modality], current_mask)
41 | cleaned_fea = cleaned_fea / (1 - theta + 1e-5)
42 | cleaned[modality] = cleaned_fea
43 |
44 | mask = mask.squeeze(-1).transpose(0,1) # [B,num_mod]
45 |
46 | update_flag = torch.sum(mask,dim=1) > 0
47 | for modality in modality_list:
48 | cleaned[modality] = cleaned[modality][update_flag]
49 | return cleaned,update_flag
50 |
51 |
52 | def clip(a, b, c):
53 | if a < b:
54 | return b
55 | elif a > c:
56 | return c
57 | return a
103 | if self.should_stop or batch_idx >= limit_batches:
104 | break
105 |
106 | self.fabric.call("on_train_batch_start", batch, batch_idx)
107 |
108 | # check if optimizer should step in gradient accumulation
109 | should_optim_step = self.global_step % self.grad_accum_steps == 0
110 | if should_optim_step:
111 | # currently only supports a single optimizer
112 | self.fabric.call("on_before_optimizer_step", optimizer, 0)
113 |
114 | # optimizer step runs train step internally through closure
115 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx))
116 | self.fabric.call("on_before_zero_grad", optimizer)
117 | # torch.cuda.empty_cache()
118 | optimizer.zero_grad()
119 |
120 | else:
121 | # gradient accumulation -> no optimizer step
122 | self.training_step(model=model, batch=batch, batch_idx=batch_idx)
123 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction)
124 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)
125 |
126 | # this guard ensures, we only step the scheduler once per global step
127 | # if should_optim_step:
128 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)
129 |
130 | # add output values to progress bar
131 |
132 | self._format_iterable(iterable, self._current_train_return, "train")
133 |
134 | # only increase global step if optimizer stepped
135 | self.global_step += int(should_optim_step)
136 | self._current_metrics = self.precision_calculator.compute_metrics()
137 |
138 | def training_step(self, model : BaseClassifierModel, batch, batch_idx):
139 |
140 | # TODO: make it simpler and easier to extend
141 | softmax = nn.Softmax(dim=1)
142 | criterion = nn.CrossEntropyLoss()
143 | relu = nn.ReLU(inplace=True)
144 | tanh = nn.Tanh()
145 | label = batch['label']
146 | label = label.to(model.device)
147 | model(batch)
148 | model.Unimodality_Calculate()
149 | loss = {}
150 | modality_list = model.modalitys
151 | key = list(model.modalitys)
152 |
153 | # Modulation starts here !
154 | modality_nums = len(modality_list)
155 | scores = {}
156 | ratios = {}
157 | coeffs = {}
158 | #Calculate the scores
159 |
160 |
161 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends: # bug fixed
162 | for modality in modality_list:
163 | scores[modality] = sum([softmax(model.unimodal_result[modality])[i][label[i]] for i in range(model.unimodal_result['output'].shape[0])])
164 | ##Calculate the ratios
165 | for modality in modality_list:
166 | ratios[modality] = scores[modality]
167 | if modality_nums == 2:
168 | for modality_another in modality_list:
169 | if modality_another == modality:
170 | continue
171 |
172 | ratios[modality] /= (scores[modality_another]+ 1e-5) # avoid division by zero
173 | ratios[modality] = tanh(relu(ratios[modality]-1))
174 | if modality_nums == 3:
175 | temp_score = 0.0
176 | for modality_another in modality_list:
177 | if modality_another == modality:
178 | continue
179 | temp_score += scores[modality_another]
180 | ratios[modality] /= (temp_score + 1e-5)
181 | ratios[modality] = tanh(relu(ratios[modality]-1))
182 | #Calculate the coeffs
183 | for modality in modality_list:
184 | coeffs[modality] = self.q_base * (1 + self.alpha * ratios[modality]) if ratios[modality]>0 else 0
185 | coeffs[modality] = clip(coeffs[modality],0.0,1.0)
186 | model.encoder_result.pop('output')
187 |
188 | cleaned_fea,update_flag=self.modality_drop.execute_drop(model.encoder_result,coeffs,model)
189 |
190 | model.unimodal_result['output'] = model.fusion_module(cleaned_fea)
191 | select_mask=update_flag!=0
192 | label=label[select_mask]
193 |
194 |
195 | for modality in modality_list:
196 |
197 | model.unimodal_result[modality]=model.unimodal_result[modality][select_mask]
198 |
199 |
200 | for modality in model.unimodal_result.keys():
201 | loss[modality] = criterion(model.unimodal_result[modality],label)
202 |
203 | loss['output'].backward()
204 |
205 | else:
206 | pass
207 |
208 | return loss['output']
209 |
--------------------------------------------------------------------------------
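`Modality_drop.execute_drop` masks each modality's features with a Bernoulli keep-probability of `1 - q[m]`, rescales the survivors by `1 / (1 - theta)` where `theta` is the dimension-weighted average drop rate, and discards samples that lost every modality. A condensed sketch of the same masking, with invented names and no toolkit dependencies:

```python
import torch

def drop_and_rescale(features: dict, q: dict, dims: dict):
    """Bernoulli-mask each modality's features and rescale, as in Modality_drop.

    features: {modality: tensor of shape [B, D_m]}
    q:        {modality: drop probability in [0, 1]}
    dims:     {modality: D_m}, used for the dimension-weighted rescaling factor
    """
    theta = sum(dims[m] * q[m] for m in features) / sum(dims.values())
    masks, cleaned = {}, {}
    for m, feat in features.items():
        keep = torch.bernoulli(torch.full((feat.shape[0], 1), 1 - q[m], device=feat.device))
        masks[m] = keep
        cleaned[m] = feat * keep / (1 - theta + 1e-5)
    # keep only samples where at least one modality survived
    keep_sample = torch.cat(list(masks.values()), dim=1).sum(dim=1) > 0
    return {m: cleaned[m][keep_sample] for m in cleaned}, keep_sample
```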
/balancemm/trainer/PMR_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 | from lightning import LightningModule
3 | from torch.optim.optimizer import Optimizer as Optimizer
4 | from .base_trainer import BaseTrainer
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | import os
11 | from collections.abc import Mapping
12 | from functools import partial
13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
14 | from lightning_utilities import apply_to_collection
15 | import lightning as L
16 | import torch
17 |
18 | def clip(a, b, c):
19 | if a < b:
20 | return b
21 | elif a > c:
22 | return c
23 | return a
72 | if self.should_stop or batch_idx >= limit_batches:
73 | break
74 |
75 | self.fabric.call("on_train_batch_start", batch, batch_idx)
76 |
77 | # check if optimizer should step in gradient accumulation
78 | should_optim_step = self.global_step % self.grad_accum_steps == 0
79 | if should_optim_step:
80 | # currently only supports a single optimizer
81 | self.fabric.call("on_before_optimizer_step", optimizer, 0)
82 |
83 | # optimizer step runs train step internally through closure
84 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx, proto = self.proto))
85 | self.fabric.call("on_before_zero_grad", optimizer)
86 |
87 | optimizer.zero_grad()
88 |
89 | else:
90 | # gradient accumulation -> no optimizer step
91 | self.training_step(model=model, batch=batch, batch_idx=batch_idx)
92 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction)
93 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)
94 |
95 | # this guard ensures, we only step the scheduler once per global step
96 | # if should_optim_step:
97 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)
98 |
99 | # add output values to progress bar
100 | self._format_iterable(iterable, self._current_train_return, "train")
101 | # only increase global step if optimizer stepped
102 | self.global_step += int(should_optim_step)
103 |
104 | self._current_metrics = self.precision_calculator.compute_metrics()
105 | self.fabric.call("on_train_epoch_end")
106 |
107 | def training_step(self, model, batch, batch_idx, proto):
108 |
109 | # TODO: make it simpler and easier to extend
110 | criterion = nn.CrossEntropyLoss()
111 | softmax = nn.Softmax(dim=1)
112 | log_softmax = nn.LogSoftmax(dim=1)
113 | tanh = nn.Tanh()
114 | modality_list = model.modalitys
115 | key = list(modality_list)
116 | m = {}
117 | model(batch)
118 | for modality in modality_list:
119 | m[modality] = model.encoder_result[modality]
120 | # a, v = model(batch)['audio'], model(batch)['visual']
121 | unimodal_result = model.Unimodality_Calculate()
122 | # out_v,out_a,out = unimodal_result['visual'], unimodal_result['audio'], unimodal_result['output']
123 | label = batch['label']
124 | label = label.to(model.device)
125 | loss_modality = {}
126 | for modality in modality_list:
127 | # print(unimodal_result[modality])
128 | # print(label)
129 | loss_modality[modality] = criterion(unimodal_result[modality],label)
130 |
131 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends:
132 | sim = {}
133 | for modality in modality_list:
134 | sim[modality] = -EU_dist(m[modality],proto[modality])
135 | # audio_sim = -EU_dist(a, audio_proto) # B x n_class
136 | # visual_sim = -EU_dist(v, visual_proto) # B x n_class
137 | # print('sim: ', audio_sim[0][0].data, visual_sim[0][0].data, a[0][0].data, v[0][0].data)
138 |
139 | score_p = {}
140 | # score = {}
141 | loss_proto = {}
142 | for modality in modality_list:
143 | score_p[modality] = sum([softmax(sim[modality])[i][label[i]] for i in range(sim[modality].size(0))])
144 |
145 | # score_a_p = sum([softmax(audio_sim)[i][label[i]] for i in range(audio_sim.size(0))])
146 | # score_v_p = sum([softmax(visual_sim)[i][label[i]] for i in range(visual_sim.size(0))])
147 | if len(modality_list) == 2:
148 | ratio_a_p = score_p[key[0]] / score_p[key[1]]
149 | else:
150 | ratio = {}
151 | min_score = min(score_p.values())
152 | for modality in modality_list:
153 | ratio[modality] = score_p[modality] / min_score
154 |
155 | # for modality in modality_list:
156 | # score[modality] = sum([softmax(unimodal_result[modality])[i][label[i]] for i in range(unimodal_result[modality].size(0))])
157 | # # score_v = sum([softmax(out_v)[i][label[i]] for i in range(out_v.size(0))])
158 | # # score_a = sum([softmax(out_a)[i][label[i]] for i in range(out_a.size(0))])
159 | # ratio_a = score[key[0]] / score[key[1]]
160 |
161 | for modality in modality_list:
162 | loss_proto[modality] = criterion(sim[modality],label)
163 | # loss_proto_a = criterion(audio_sim, label)
164 | # loss_proto_v = criterion(visual_sim, label)
165 | if len(modality_list) == 2:
166 | if ratio_a_p > 1:
167 | beta = 0 # audio coef
168 | lam = 1 * self.alpha # visual coef
169 | elif ratio_a_p < 1:
170 | beta = 1 * self.alpha
171 | lam = 0
172 | else:
173 | beta = 0
174 | lam = 0
175 | loss = criterion(unimodal_result['output'], label) + beta * loss_proto[key[0]] + lam * loss_proto[key[1]]
176 | loss.backward()
177 | else:
178 | loss = criterion(unimodal_result['output'], label)
179 | loss.backward()
180 | k_t = {}
181 | for modality in modality_list:
182 | if ratio[modality] > 1:
183 | k_t[modality] = 1-tanh(self.eta * ratio[modality])
184 | else:
185 | k_t[modality] = 1
186 |
187 | for name, parms in model.named_parameters():
188 | layer = str(name)
189 | for modality in modality_list:
190 | if modality in layer and len(parms.grad.size()) != 1: ## skip 1-D (bias) gradients; only modulate weight gradients
191 | parms.grad = parms.grad * k_t[modality] - \
192 | torch.zeros_like(parms.grad).normal_(0, parms.grad.std().item() + 1e-8)
193 | # loss_a = criterion(out_a, label)
194 | else:
195 | loss = criterion(unimodal_result['output'], label)
196 | loss.backward()
197 | # # avoid gradients in stored/accumulated values -> prevents potential OOM
198 | # self._current_train_return = apply_to_collection(model.encoder_result, dtype=torch.Tensor, function=lambda x: x.detach())
199 | return loss
200 | def calculate_prototype(self, model, dataloader, proto0):
201 | # todo customed output of prototype
202 | n_classes = model.n_classes
203 | device = next(model.parameters()).device
204 | proto = {}
205 | modality_list = model.modalitys
206 | for modality in modality_list:
207 | proto[modality] = torch.zeros(n_classes, model.modality_size[modality]).to(device)
208 | count_class = [0 for _ in range(n_classes)]
209 |
210 | # calculate prototype
211 | model.eval()
212 | with torch.no_grad():
213 | sample_count = 0
214 | all_num = len(dataloader)
215 | m = {}
216 | for batch_idx, batch in enumerate(dataloader):
217 | model(batch)
218 | for modality in modality_list:
219 | m[modality] = model.encoder_result[modality]
220 | label = batch['label']
221 |
222 |
223 | for c, l in enumerate(label):
224 | l = l.long()
225 | count_class[l] += 1
226 | for modality in modality_list:
227 |
228 | proto[modality][l,:] += m[modality][c,:]
229 |
230 |
231 | sample_count += 1
232 |
233 | if sample_count >= all_num // 10:
234 | break
235 | for modality in modality_list:
236 | for c in range(proto[modality].shape[0]):
237 | proto[modality][c,:] /= count_class[c]
238 | # audio_prototypes[c, :] /= count_class[c]
239 | # visual_prototypes[c, :] /= count_class[c]
240 |
241 | if self.current_epoch <= 0:
242 | for modality in modality_list:
243 | proto[modality] = proto[modality]
244 | # audio_prototypes = audio_prototypes
245 | # visual_prototypes = visual_prototypes
246 | else:
247 | for modality in modality_list:
248 | proto[modality] = (1-self.momentum_coef) * proto[modality] + self.momentum_coef * proto0[modality]
249 | # audio_prototypes = (1 - self.momentum_coef) * audio_prototypes + self.momentum_coef * a_proto
250 | # visual_prototypes = (1 - self.momentum_coef) * visual_prototypes + self.momentum_coef * v_proto
251 | return proto
--------------------------------------------------------------------------------
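`PMRTrainer.training_step` relies on an `EU_dist` helper whose negative value is used as a class logit; it is assumed here to be the pairwise Euclidean distance between encoder features `[B, D]` and class prototypes `[C, D]`. A minimal sketch under that assumption:

```python
import torch

def eu_dist(features: torch.Tensor, prototypes: torch.Tensor) -> torch.Tensor:
    """Pairwise Euclidean distance between features [B, D] and class prototypes [C, D].

    Assumed semantics of the EU_dist helper used in PMRTrainer.training_step:
    -eu_dist(feature, proto) is then treated as a similarity logit over classes.
    """
    return torch.cdist(features, prototypes, p=2)    # [B, C]

# toy usage: 4 samples, 3 classes, 8-dim features
feat, proto = torch.randn(4, 8), torch.randn(3, 8)
logits = -eu_dist(feat, proto)                       # higher means closer to the prototype
```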
/balancemm/trainer/UMT_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 | from lightning import LightningModule
3 | from torch.optim.optimizer import Optimizer as Optimizer
4 | from .base_trainer import BaseTrainer
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | import os
11 | from collections.abc import Mapping
12 | from functools import partial
13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
14 |
15 | import lightning as L
16 | import torch
17 | from ..models.avclassify_model import BaseClassifierModel
18 | from ..evaluation.complex import get_flops
19 | from ..models import create_model
20 | import copy
21 | from ..utils.train_utils import get_checkpoint_files, get_newest_path
22 | import os.path as osp
23 | import yaml
24 | from ..models.avclassify_model import MultiModalParallel
25 | class UMTTrainer(BaseTrainer):
26 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}, args = {}):
27 | super(UMTTrainer,self).__init__(fabric,**para_dict)
28 | self.scaling = method_dict['scaling']
29 | self.modulation_starts = method_dict['modulation_starts']
30 | self.modulation_ends = method_dict['modulation_ends']
31 |
32 | self.loaded_model = []
33 | loaded_model = {}
34 | temp_args = copy.deepcopy(args)
35 | # root_path = osp.dirname(osp.dirname(__file__))
36 | # with open(osp.join(root_path ,"configs", "encoder_config.yaml"), 'r') as f:
37 | # encoder_settings = yaml.safe_load(f)
38 | if args.mode == "train_and_test":
39 | out_dir = '_'.join(temp_args.out_dir.split('/')[:-1])
40 | for modality in args.model['encoders'].keys():
41 | temp_args.model['encoders'] = {modality: args.model['encoders'][modality]}
42 | temp_args.model['modality_size'] = {modality: args.model['modality_size'][modality]}
43 | loaded_model[modality] = create_model(temp_args.model)
44 | out_dir = temp_args.out_dir.replace('UMTTrainer', 'unimodalTrainer_' + modality)
45 | out_dir = '/'.join(out_dir.split('/')[:-1])
46 | path = get_newest_path(out_dir)
47 | # loaded_model[modality].load_state_dict(torch.load(get_checkpoint_files(path)[0])['model'])
48 | loaded_model[modality].load_state_dict(torch.load(get_checkpoint_files(path)[0])['model'])
49 | loaded_model[modality] = MultiModalParallel(loaded_model[modality],device_ids=[0,1])
50 | loaded_model[modality] =loaded_model[modality].cuda()
51 | # loaded_model[modality] = torch.load(get_checkpoint_files(path)[0])['model']
52 | # print(type(loaded_model))
53 | self.loaded_model = loaded_model
54 |
55 | def train_loop(
56 | self,
57 | model: L.LightningModule,
58 | optimizer: torch.optim.Optimizer,
59 | train_loader: torch.utils.data.DataLoader,
60 | limit_batches: Union[int, float] = float("inf"),
61 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
62 | ):
63 | """The training loop running a single training epoch.
64 |
65 | Args:
66 | model: the LightningModule to train
67 | optimizer: the optimizer, optimizing the LightningModule.
68 | train_loader: The dataloader yielding the training batches.
69 | limit_batches: Limits the batches during this training epoch.
70 | If greater than the number of batches in the ``train_loader``, this has no effect.
71 | scheduler_cfg: The learning rate scheduler configuration.
72 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`
73 | for supported values.
74 |
75 | """
76 | self.fabric.call("on_train_epoch_start")
77 | all_modalitys = list(model.modalitys)
78 | all_modalitys.append('output')
79 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys)
80 | iterable = self.progbar_wrapper(
81 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}"
82 | )
83 | # if self.current_epoch == 0:
84 | # for batch_idx, batch in enumerate(iterable):
85 | # batch_sample = batch
86 | # break
87 | # print(batch_sample.keys())
88 | # model_flops, _ =get_flops(model = model, input_sample = batch_sample)
89 | # self.FlopsMonitor.update(model_flops / len(batch_sample['label']) * len(train_loader), 'forward')
90 | # self.FlopsMonitor.report(logger = self.logger)
91 | for batch_idx, batch in enumerate(iterable):
92 | # end epoch if stopping training completely or max batches for this epoch reached
93 | if self.should_stop or batch_idx >= limit_batches:
94 | break
95 |
96 | self.fabric.call("on_train_batch_start", batch, batch_idx)
97 |
98 | # check if optimizer should step in gradient accumulation
99 | should_optim_step = self.global_step % self.grad_accum_steps == 0
100 | if should_optim_step:
101 | # currently only supports a single optimizer
102 | self.fabric.call("on_before_optimizer_step", optimizer, 0)
103 |
104 | # optimizer step runs train step internally through closure
105 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx))
106 | self.fabric.call("on_before_zero_grad", optimizer)
107 |
108 | optimizer.zero_grad()
109 |
110 | else:
111 | # gradient accumulation -> no optimizer step
112 | self.training_step(model=model, batch=batch, batch_idx=batch_idx)
113 |
114 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction)
115 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)
116 |
117 | # this guard ensures, we only step the scheduler once per global step
118 | # if should_optim_step:
119 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)
120 |
121 | # add output values to progress bar
122 |
123 | self._format_iterable(iterable, self._current_train_return, "train")
124 |
125 | # only increase global step if optimizer stepped
126 | self.global_step += int(should_optim_step)
127 | self._current_metrics = self.precision_calculator.compute_metrics()
128 | self.fabric.call("on_train_epoch_end")
129 | def training_step(self, model: BaseClassifierModel, batch, batch_idx):
130 |
131 | # TODO: make it simpler and easier to extend
132 | criterion = nn.CrossEntropyLoss()
133 | MSE = nn.MSELoss()
134 | label = batch['label']
135 | label = label.to(model.device)
136 | _ = model(batch= batch)
137 | out = model.encoder_result['output']
138 | loss = criterion(out, label)
139 | if self.modulation_starts <= self.current_epoch <= self.modulation_ends:
140 | for modality in self.loaded_model.keys():
141 | with torch.no_grad():
142 | self.loaded_model[modality](batch)
143 | out_unimodal = self.loaded_model[modality].encoder_result[modality]
144 | loss += self.scaling * MSE(out_unimodal, model.encoder_result[modality])
145 |
146 | loss.backward()
147 | return loss
--------------------------------------------------------------------------------
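The UMT objective combines the fused cross-entropy with a scaled MSE between the multimodal encoders' features and those of frozen unimodal teachers. A condensed sketch of that loss; argument names are illustrative, not the toolkit's API.

```python
import torch
import torch.nn as nn

def umt_loss(fused_logits, labels, student_feats, teacher_feats, scaling):
    """Cross-entropy on the fused output plus per-modality feature distillation.

    fused_logits:  [B, C] output of the multimodal model
    student_feats: {modality: [B, D_m]} features of the model being trained
    teacher_feats: {modality: [B, D_m]} features of the frozen unimodal models
    """
    ce, mse = nn.CrossEntropyLoss(), nn.MSELoss()
    loss = ce(fused_logits, labels)
    for m, s in student_feats.items():
        loss = loss + scaling * mse(s, teacher_feats[m].detach())
    return loss
```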
/balancemm/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | from os import path as osp
4 | from types import SimpleNamespace
5 | import lightning as L
6 | from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
7 |
8 | __all__ = ['create_trainer']
9 |
10 | # automatically scan and import trainer modules
11 | # scan all the files under the data folder with '_trainer' in file names
12 | trainer_folder = osp.dirname(osp.abspath(__file__))
13 | trainer_filenames = [
14 | osp.splitext(v)[0] for v in os.listdir(trainer_folder)
15 | if v.endswith('_trainer.py')
16 | ]
17 |
18 | # import all the trainer modules
19 | _trainer_modules = [
20 | importlib.import_module(f'.{file_name}', package="balancemm.trainer")
21 | for file_name in trainer_filenames
22 | ]
23 |
24 | def create_trainer(fabric: L.Fabric ,trainer_opt:dict, para_opt, args, logger,tb_logger):
25 | # dynamic instantiation
26 | for module in _trainer_modules:
27 | trainer_cls = getattr(module, trainer_opt["trainer"], None)
28 | if trainer_cls is not None:
29 | break
30 | if trainer_cls is None:
31 | raise ValueError(f'trainer {trainer_opt["trainer"]} is not found.')
32 | para_opt['base_para']['logger'] = logger
33 | if args.trainer['name'] not in ('UMT', 'LinearProbe', 'MLA', 'OPM', 'Sample'):
34 | trainer = trainer_cls(fabric, para_opt, para_opt['base_para'])
35 | else:
36 | trainer = trainer_cls(fabric, para_opt, para_opt['base_para'], args)
37 | trainer.checkpoint_dir = args.checkpoint_dir
38 |
39 | print(
40 | f'Trainer {trainer.__class__.__name__} - {trainer_opt["trainer"]} '
41 | 'is created.')
42 | para_opt['name'] = trainer_opt["trainer"]
43 | # logger.info("normal Settings: %s", para_opt)
44 | logger.info("trainer Settings: %s", para_opt)
45 | if isinstance(tb_logger, TensorBoardLogger):
46 | tb_logger.log_hyperparams(trainer_opt)
47 | tb_logger.experiment.add_text("Trainer Setup", f"Trainer: {trainer_opt['trainer']}")
48 | return trainer
--------------------------------------------------------------------------------
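Because the package scans for `*_trainer.py` files and resolves the class named in `trainer_opt["trainer"]`, registering a new method only requires a file that follows the naming convention. A hedged sketch of such a file; `MyMethod` is a placeholder, not an existing trainer.

```python
# balancemm/trainer/MyMethod_trainer.py -- hypothetical example, not part of the toolkit
from .base_trainer import BaseTrainer

class MyMethodTrainer(BaseTrainer):
    def __init__(self, fabric, method_dict: dict = {}, para_dict: dict = {}):
        super().__init__(fabric, **para_dict)
        # read method-specific hyperparameters from trainer_config.yaml
        self.modulation_starts = method_dict['modulation_starts']
        self.modulation_ends = method_dict['modulation_ends']
```

Setting `trainer_opt["trainer"] = "MyMethodTrainer"` then lets `create_trainer` import and instantiate it without any further registration step.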
/balancemm/trainer/baseline_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 | from lightning import LightningModule
3 | from torch.optim.optimizer import Optimizer as Optimizer
4 | from .base_trainer import BaseTrainer
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | import os
11 | from collections.abc import Mapping
12 | from functools import partial
13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
14 |
15 | import lightning as L
16 | import torch
17 | from ..models.avclassify_model import BaseClassifierModel
18 | class baselineTrainer(BaseTrainer):
19 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}):
20 | super(baselineTrainer,self).__init__(fabric,**para_dict)
21 |
22 | self.modulation_starts = method_dict['modulation_starts']
23 | self.modulation_ends = method_dict['modulation_ends']
24 | self.modality = method_dict['modality']
25 |
26 | def train_loop(
27 | self,
28 | model: L.LightningModule,
29 | optimizer: torch.optim.Optimizer,
30 | train_loader: torch.utils.data.DataLoader,
31 | limit_batches: Union[int, float] = float("inf"),
32 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
33 | ):
34 | """The training loop running a single training epoch.
35 |
36 | Args:
37 | model: the LightningModule to train
38 | optimizer: the optimizer, optimizing the LightningModule.
39 | train_loader: The dataloader yielding the training batches.
40 | limit_batches: Limits the batches during this training epoch.
41 | If greater than the number of batches in the ``train_loader``, this has no effect.
42 | scheduler_cfg: The learning rate scheduler configuration.
43 | Have a look at :meth:`~lightning.pytorch.core.LightningModule.configure_optimizers`
44 | for supported values.
45 |
46 | """
47 | self.fabric.call("on_train_epoch_start")
48 | all_modalitys = list(model.modalitys)
49 | all_modalitys.append('output')
50 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys)
51 | iterable = self.progbar_wrapper(
52 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}"
53 | )
54 | for batch_idx, batch in enumerate(iterable):
55 | # end epoch if stopping training completely or max batches for this epoch reached
56 | if self.should_stop or batch_idx >= limit_batches:
57 | break
58 |
59 | self.fabric.call("on_train_batch_start", batch, batch_idx)
60 |
61 | # check if optimizer should step in gradient accumulation
62 | should_optim_step = self.global_step % self.grad_accum_steps == 0
63 | if should_optim_step:
64 | # currently only supports a single optimizer
65 | self.fabric.call("on_before_optimizer_step", optimizer, 0)
66 |
67 | # optimizer step runs train step internally through closure
68 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx))
69 | self.fabric.call("on_before_zero_grad", optimizer)
70 |
71 | optimizer.zero_grad()
72 |
73 | else:
74 | # gradient accumulation -> no optimizer step
75 | self.training_step(model=model, batch=batch, batch_idx=batch_idx)
76 | # print(len(batch['label']), len(model.module.prediction['output']))
77 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction)
78 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)
79 |
80 | self._format_iterable(iterable, self._current_train_return, "train")
81 |
82 | # only increase global step if optimizer stepped
83 | self.global_step += int(should_optim_step)
84 | self._current_metrics = self.precision_calculator.compute_metrics()
85 | self.fabric.call("on_train_epoch_end")
86 | def training_step(self, model: BaseClassifierModel, batch, batch_idx):
87 |
88 | # TODO: make it simpler and easier to extend
89 | criterion = nn.CrossEntropyLoss()
90 | label = batch['label']
91 | label = label.to(model.device)
92 |
93 | model(batch)
94 | out = model.encoder_result['output']
95 | loss = criterion(out, label)
96 | loss.backward()
97 | return loss
--------------------------------------------------------------------------------
/balancemm/trainer/unimodal_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 | from lightning import LightningModule
3 | from torch.optim.optimizer import Optimizer as Optimizer
4 | from .base_trainer import BaseTrainer
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | import os
11 | from collections.abc import Mapping
12 | from functools import partial
13 | from typing import Any, Iterable, List, Literal, Optional, Tuple, Union, cast
14 |
15 | import lightning as L
16 | import torch
17 | from ..models.avclassify_model import BaseClassifierModel
18 | from ..evaluation.complex import get_flops
19 | class unimodalTrainer(BaseTrainer):
20 | def __init__(self,fabric, method_dict: dict = {}, para_dict : dict = {}):
21 | super(unimodalTrainer,self).__init__(fabric,**para_dict)
22 | self.alpha = method_dict['alpha']
23 | self.method = method_dict['method']
24 | self.modulation_starts = method_dict['modulation_starts']
25 | self.modulation_ends = method_dict['modulation_ends']
26 | self.modality = method_dict['modality']
27 |
28 | def train_loop(
29 | self,
30 | model: L.LightningModule,
31 | optimizer: torch.optim.Optimizer,
32 | train_loader: torch.utils.data.DataLoader,
33 | limit_batches: Union[int, float] = float("inf"),
34 | scheduler_cfg: Optional[Mapping[str, Union[L.fabric.utilities.types.LRScheduler, bool, str, int]]] = None,
35 | ):
36 |
37 | self.fabric.call("on_train_epoch_start")
38 | all_modalitys = list(model.modalitys)
39 | all_modalitys.append('output')
40 | self.precision_calculator = self.PrecisionCalculatorType(model.n_classes, all_modalitys)
41 | iterable = self.progbar_wrapper(
42 | train_loader, total=min(len(train_loader), limit_batches), desc=f"Epoch {self.current_epoch}"
43 | )
44 | # if self.current_epoch == 0:
45 | # for batch_idx, batch in enumerate(iterable):
46 | # batch_sample = batch
47 | # break
48 | # print(batch_sample.keys())
49 | # model_flops, _ =get_flops(model = model, input_sample = batch_sample)
50 | # self.FlopsMonitor.update(model_flops / len(batch_sample['label']) * len(train_loader), 'forward')
51 | # self.FlopsMonitor.report(logger = self.logger)
52 | for batch_idx, batch in enumerate(iterable):
53 | # end epoch if stopping training completely or max batches for this epoch reached
54 | if self.should_stop or batch_idx >= limit_batches:
55 | break
56 |
57 | self.fabric.call("on_train_batch_start", batch, batch_idx)
58 |
59 | # check if optimizer should step in gradient accumulation
60 | should_optim_step = self.global_step % self.grad_accum_steps == 0
61 | if should_optim_step:
62 | # currently only supports a single optimizer
63 | self.fabric.call("on_before_optimizer_step", optimizer, 0)
64 |
65 | # optimizer step runs train step internally through closure
66 | optimizer.step(partial(self.training_step, model=model, batch=batch, batch_idx=batch_idx))
67 | self.fabric.call("on_before_zero_grad", optimizer)
68 |
69 | optimizer.zero_grad()
70 |
71 | else:
72 | # gradient accumulation -> no optimizer step
73 | self.training_step(model=model, batch=batch, batch_idx=batch_idx)
74 |
75 | self.precision_calculator.update(y_true = batch['label'].cpu(), y_pred = model.prediction)
76 | self.fabric.call("on_train_batch_end", self._current_train_return, batch, batch_idx)
77 |
78 | # this guard ensures, we only step the scheduler once per global step
79 | # if should_optim_step:
80 | # self.step_scheduler(model, scheduler_cfg, level="step", current_value=self.global_step)
81 |
82 | # add output values to progress bar
83 |
84 | self._format_iterable(iterable, self._current_train_return, "train")
85 |
86 | # only increase global step if optimizer stepped
87 | self.global_step += int(should_optim_step)
88 | self._current_metrics = self.precision_calculator.compute_metrics()
89 | self.fabric.call("on_train_epoch_end")
90 | def training_step(self, model: BaseClassifierModel, batch, batch_idx):
91 |
92 | # TODO: make it simpler and easier to extend
93 | criterion = nn.CrossEntropyLoss()
94 | label = batch['label']
95 | label = label.to(model.device)
96 |
97 | model(batch)
98 | out = model.encoder_result['output']
99 | loss = criterion(out, label)
100 | loss.backward()
101 | return loss
--------------------------------------------------------------------------------
/balancemm/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from torchvision import datasets, transforms
3 | from types import SimpleNamespace
4 | from ..datasets import create_dataset
5 | import lightning as L
6 | import numpy as np
7 | import random
8 | def worker_init_fn(worker_id):
9 | worker_seed = torch.initial_seed() % 2**32
10 | np.random.seed(worker_seed)
11 | random.seed(worker_seed)
12 |
13 | def create_train_val_dataloader(fabric: L.Fabric, config: dict):
14 |
15 | train_dataset = create_dataset(config.dataset, 'train')
16 | if config.dataset.get('validation', False):
17 | val_dataset = create_dataset(config.dataset, 'valid')
18 | test_dataset = create_dataset(config.dataset, 'test')
19 | else:
20 | val_dataset = create_dataset(config.dataset, 'test')
21 | test_dataset = val_dataset
22 |
23 | config_dataloader = SimpleNamespace(**config.dataloader)
24 | if config_dataloader.fast_run == True:
25 | train_dataset = torch.utils.data.Subset(train_dataset, list(range(config_dataloader.eff_batch_size*4)))
26 | val_dataset = torch.utils.data.Subset(val_dataset, list(range(config_dataloader.eff_batch_size*2)))
27 | test_dataset = torch.utils.data.Subset(test_dataset, list(range(config_dataloader.eff_batch_size*2)))
28 | # print len of datasets
29 | fabric.print(f"Train dataset: {train_dataset.__class__.__name__} - {config.Train['dataset']}, {len(train_dataset)} samples")
30 | fabric.print(f"Val dataset: {val_dataset.__class__.__name__} - {config.Val['dataset']}, {len(val_dataset)} samples")
31 |
32 | if (not hasattr(config_dataloader, 'batch_size') or config_dataloader.batch_size == -1):
33 | config_dataloader.batch_size = round(config_dataloader.eff_batch_size/fabric.world_size) # using the effective batch_size to calculate the batch_size per gpu
34 |
35 | train_dataloader = torch.utils.data.DataLoader(train_dataset,
36 | batch_size=config_dataloader.batch_size,
37 | shuffle=config_dataloader.shuffle,
38 | drop_last = config_dataloader.drop_last,
39 | num_workers = config_dataloader.num_workers,
40 | multiprocessing_context='spawn',
41 | pin_memory = config_dataloader.pin_memory,
42 | prefetch_factor=6)
43 | if config.trainer['name'] == 'Sample':
44 | train_val_dataloader = torch.utils.data.DataLoader(train_dataset,
45 | batch_size=config_dataloader.batch_size,
46 | shuffle=False,
47 | drop_last = config_dataloader.drop_last,
48 | num_workers = config_dataloader.num_workers,
49 | multiprocessing_context='spawn',
50 | pin_memory = config_dataloader.pin_memory,
51 | prefetch_factor=6)
52 |
53 | val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=config_dataloader.batch_size, drop_last = config_dataloader.drop_last,
54 | num_workers = config_dataloader.num_workers,
55 | multiprocessing_context='spawn',
56 | pin_memory = config_dataloader.pin_memory,
57 | prefetch_factor=6)
58 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=config_dataloader.batch_size, drop_last = config_dataloader.drop_last,
59 | num_workers = config_dataloader.num_workers,
60 | multiprocessing_context='spawn',
61 | pin_memory = config_dataloader.pin_memory,
62 | prefetch_factor=6)
63 |
64 | if config.Train['dataset'] == 'FOOD101':
65 | g = torch.Generator()
66 | train_dataloader = torch.utils.data.DataLoader(train_dataset,
67 | batch_size=config_dataloader.batch_size,
68 | shuffle=config_dataloader.shuffle,
69 | drop_last = config_dataloader.drop_last,
70 | num_workers = config_dataloader.num_workers,
71 | multiprocessing_context='spawn',
72 | pin_memory = config_dataloader.pin_memory,
73 | worker_init_fn=worker_init_fn,
74 | generator=g)
75 |
76 | if config.trainer['name'] == 'Sample':
77 | train_val_dataloader = torch.utils.data.DataLoader(train_dataset,
78 | batch_size=config_dataloader.batch_size,
79 | shuffle=False,
80 | drop_last = config_dataloader.drop_last,
81 | num_workers = config_dataloader.num_workers,
82 | multiprocessing_context='spawn',
83 | pin_memory = config_dataloader.pin_memory,
84 | worker_init_fn=worker_init_fn,
85 | generator=g)
86 |
87 | val_dataloader = torch.utils.data.DataLoader(val_dataset,
88 | batch_size=config_dataloader.batch_size,
89 | shuffle=config_dataloader.shuffle,
90 | drop_last = config_dataloader.drop_last,
91 | num_workers = config_dataloader.num_workers,
92 | multiprocessing_context='spawn',
93 | pin_memory = config_dataloader.pin_memory,
94 | worker_init_fn=worker_init_fn,
95 | generator=g)
96 | test_dataloader = torch.utils.data.DataLoader(test_dataset,
97 | batch_size=config_dataloader.batch_size,
98 | shuffle=config_dataloader.shuffle,
99 | drop_last = config_dataloader.drop_last,
100 | num_workers = config_dataloader.num_workers,
101 | multiprocessing_context='spawn',
102 | pin_memory = config_dataloader.pin_memory,
103 | worker_init_fn=worker_init_fn,
104 | generator=g)
105 | # print batchsize and len
106 | fabric.print(f"Train dataloader: {len(train_dataloader)} batches, {len(train_dataloader.dataset)} samples")
107 | fabric.print(f"Val dataloader: {len(val_dataloader)} batches, {len(val_dataloader.dataset)} samples")
108 | if config.trainer['name'] == 'Sample':
109 | return train_dataloader, train_val_dataloader, val_dataloader, test_dataloader
110 | return train_dataloader, val_dataloader, test_dataloader
--------------------------------------------------------------------------------
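For orientation, here is a minimal sketch of the config object `create_train_val_dataloader` expects (attribute access for the top-level groups, plain dicts underneath). The field values, the `CREMAD` name, and the CPU Fabric are illustrative assumptions, and the actual call is left commented out because it needs real dataset paths.

```python
from types import SimpleNamespace
import lightning as L

# Hypothetical config mirroring the fields accessed above; all values are placeholders.
config = SimpleNamespace(
    dataset={'name': 'CREMAD', 'validation': False},
    dataloader={'eff_batch_size': 32, 'batch_size': -1, 'num_workers': 4,
                'fast_run': False, 'shuffle': True, 'drop_last': True,
                'pin_memory': True},
    Train={'dataset': 'CREMAD'},
    Val={'dataset': 'CREMAD'},
    trainer={'name': 'baseline'},
)
fabric = L.Fabric(accelerator='cpu', devices=1)
fabric.launch()
# With batch_size == -1, the per-GPU batch size falls back to
# round(eff_batch_size / fabric.world_size) = round(32 / 1) = 32.
# train_loader, val_loader, test_loader = create_train_val_dataloader(fabric, config)
```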
/balancemm/utils/encoder_module.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | 
3 | # NOTE: this helper currently mirrors utils/optimizer.find_module: the argument is
4 | # treated as a type name and looked up in torch.optim.
5 | def find_module(encoder_type):
6 |     if hasattr(optim, encoder_type):
7 |         return getattr(optim, encoder_type)
8 |     raise ValueError(f'{encoder_type} not found in torch.optim or current module.')
--------------------------------------------------------------------------------
/balancemm/utils/logger.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 | import sys
3 |
4 | def setup_logger(log_file: str):
5 | logger.remove()
6 | logger.add(log_file, rotation="10 MB", level="INFO")
7 | logger.add(sys.stdout, level="INFO")
8 | return logger
--------------------------------------------------------------------------------
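A quick usage sketch of `setup_logger`, assuming the `balancemm` package is importable; the log path is an arbitrary example.

```python
from balancemm.utils.logger import setup_logger

# INFO-and-above records go to the rotating log file and are mirrored to stdout.
logger = setup_logger("experiments/demo/train.log")
logger.info("logger initialised")
```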
/balancemm/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 |
3 | def find_module(optimizer_type):
4 | if hasattr(optim, optimizer_type):
5 | return getattr(optim, optimizer_type)
6 | # elif globals().get(optimizer_type + 'Optimizer'):
7 | # return globals()[optimizer_type + 'Optimizer']
8 | else:
9 | raise ValueError(f'Optimizer {optimizer_type} not found in torch.optim or current module.')
10 |
11 | def create_optimizer(model, args: dict, parameter: dict):
12 | if 'type' not in args:
13 | raise ValueError('Optimizer type is required.')
14 | optimizer_type = args['type']
15 | optimizer_cls = find_module(optimizer_type)
16 | optimizer_args = {k: v for k, v in args.items() if k != 'type'}
17 | optimizer_args['lr'] = parameter['base_lr'] if parameter['lr'] == -1 else parameter['lr']
18 | optimizer = optimizer_cls(model.parameters(), **optimizer_args)
19 | print (f'Optimizer {optimizer.__class__.__name__} - {optimizer_type} is created.')
20 | return optimizer
--------------------------------------------------------------------------------
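As an illustration, `create_optimizer` can be driven directly by dictionaries shaped like the `optimizer` and `parameter` sections of the global config; the tiny linear model here is only a stand-in.

```python
import torch.nn as nn
from balancemm.utils.optimizer import create_optimizer

model = nn.Linear(8, 2)                           # stand-in model
args = {'type': 'AdamW', 'weight_decay': 0.01}    # shaped like configs/global_config.yaml
parameter = {'base_lr': 0.001, 'lr': -1}          # lr == -1 falls back to base_lr
optimizer = create_optimizer(model, args, parameter)
# -> torch.optim.AdamW over model.parameters() with lr=0.001, weight_decay=0.01
```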
/balancemm/utils/parser_utils.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from dynaconf import Dynaconf
3 |
4 | __all__ = ['parse_cli_args_to_dict', 'load_config_dict', 'ensure_and_get_config_path', 'find_module']
5 |
6 | def _merge_into_dict(target_dict: dict, key_path: str, value: any) -> None:
7 | """
8 | Merge a hierarchically structured key and its value into the target dictionary.
9 |
10 | :param target_dict: The dictionary to be updated.
11 | :param key_path: A string representing the hierarchical keys, separated by dots.
12 | :param value: The value to be set at the innermost key.
13 | """
14 | keys = key_path.split('.')
15 | for key in keys[:-1]:
16 | if key not in target_dict or not isinstance(target_dict[key], dict):
17 | target_dict[key] = {}
18 | target_dict = target_dict[key]
19 | target_dict[keys[-1]] = value
20 |
21 | def parse_cli_args_to_dict(cli_args: list[str], args_dict: dict | None = None) -> dict:
22 | """
23 | Parses a list of command-line arguments into a dictionary, supporting keys prefixed with '-' or '--'.
24 | Values can be specified either by '=' or as the next argument. Keys without explicit values are assigned
25 | an empty string. When the same key is specified multiple times, the last value takes precedence.
26 | Notice that format '-a.b' or '-a .b' is not supported, use "--a.b" or "-a '.b'" instead.
27 |
28 | Args:
29 | cli_args (list[str]): The list of command-line arguments.
30 | args_dict (dict, optional): An existing dictionary to update with the command-line arguments.
31 | Defaults to None, in which case a new dictionary is created.
32 |
33 | Returns:
34 | dict: A dictionary containing the parsed command-line arguments, merged with any existing
35 | entries in `args_dict`.
36 | """
37 |     args_dict = {} if args_dict is None else args_dict; i = 0  # fresh dict per call avoids a shared mutable default
38 | while i < len(cli_args):
39 | if not cli_args[i].startswith('-'):
40 | i += 1
41 | continue
42 | if not cli_args[i].startswith('--'):
43 | # check for invalid argument format '-a.b' or '-a .b'
44 | if i + 1 < len(cli_args) and cli_args[i + 1].startswith('.'):
45 |                 raise ValueError(f"Invalid argument: '{cli_args[i]} {cli_args[i + 1]}'. Use --a.b or -a '.b' instead.")
46 | arg = cli_args[i].lstrip('-')
47 | value = '' # Default value if no explicit value is provided
48 | # Check for '=' in the current argument
49 | if '=' in arg:
50 | arg, value = arg.split('=', 1)
51 | else:
52 | # Check if the next argument is a value (not another key)
53 | if i + 1 < len(cli_args) and not cli_args[i + 1].startswith('-'):
54 | value = cli_args[i + 1]
55 | i += 1 # Skip the next argument since it's a value
56 |         # Merge the dotted key path into the nested args_dict
57 | _merge_into_dict(args_dict, arg, value)
58 | i += 1
59 |
60 | return args_dict
61 |
62 | def load_config_dict(config_path: str) -> dict:
63 | """
64 |     Load configuration from a YAML file.
65 |
66 | :param config_path: The file path to the YAML configuration file.
67 | :return: A dictionary containing the loaded configuration, or None if an error occurs.
68 | """
69 | try:
70 | with open(config_path, "r") as f:
71 | args = yaml.safe_load(f)
72 | except Exception as e:
73 | print(f"Error reading file {config_path}: \n{e}")
74 | print(f"Loaded config from {config_path}")
75 | return args
76 |
77 | def ensure_and_get_config_path(args: list[str], default_config_path: str) -> str:
78 | """
79 | Ensure that the --config argument is present in args, supporting both '--config=' and '--config ' formats,
80 | as well as the '-c' abbreviation. If not present, extend args with the default config path.
81 |
82 | :param args: List of command-line arguments.
83 | :param default_config_path: The default path to the configuration file if --config or -c is not specified.
84 | :return: config_path (str): The path to the configuration file.
85 | """
86 | config_key_long = '--config'
87 | config_key_short = '-c'
88 | config_path = None
89 |
90 | for i, arg in enumerate(args):
91 | if arg.startswith(config_key_long + '=') or arg.startswith(config_key_short + '='):
92 | config_path = arg.split('=', 1)[1]
93 | break
94 | elif arg == config_key_long or arg == config_key_short:
95 | if i + 1 < len(args) and not args[i + 1].startswith('-'):
96 | config_path = args[i + 1]
97 | break
98 |
99 | if config_path is None:
100 | config_path = default_config_path
101 | args.extend([f'{config_key_long}={config_path}'])
102 |
103 | return config_path
104 |
105 | def find_module(module_list: list, module_name: str, module_type: str = 'Module') -> any:
106 | module_name = module_name + module_type
107 | for module in module_list:
108 | module_cls = getattr(module, module_name, None)
109 | if module_cls is not None:
110 | return module_cls
111 | raise ValueError(f'{module_type} {module_name} is not found.')
--------------------------------------------------------------------------------
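For example, `parse_cli_args_to_dict` turns dotted CLI keys into a nested dictionary, with all values arriving as strings; the keys below are made up for illustration.

```python
from balancemm.utils.parser_utils import parse_cli_args_to_dict

cli = ['--trainer.name', 'OGM', '--dataloader.eff_batch_size=64', '--fast']
print(parse_cli_args_to_dict(cli, {}))
# {'trainer': {'name': 'OGM'}, 'dataloader': {'eff_batch_size': '64'}, 'fast': ''}
```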
/balancemm/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | import torch.optim.lr_scheduler as lr_scheduler
2 |
3 | def find_scheduler(scheduler_type):
4 | if hasattr(lr_scheduler, scheduler_type):
5 | return getattr(lr_scheduler, scheduler_type)
6 | else:
7 | raise ValueError(f'Scheduler {scheduler_type} not found.')
8 |
9 | def create_scheduler(optimizer, args: dict):
10 | if 'type' not in args:
11 | raise ValueError('Scheduler type is required.')
12 | scheduler_type = args['type']
13 | scheduler_cls = find_scheduler(scheduler_type)
14 | scheduler_args = {k: v for k, v in args.items() if k != 'type'}
15 | scheduler = scheduler_cls(optimizer, **scheduler_args)
16 | print (f'Scheduler {scheduler.__class__.__name__} - {scheduler_type} is created.')
17 | return scheduler
18 |
--------------------------------------------------------------------------------
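A short usage sketch pairing `create_scheduler` with the `StepLR` settings from the global config; the SGD optimizer here is only a stand-in.

```python
import torch.nn as nn
import torch.optim as optim
from balancemm.utils.scheduler import create_scheduler

optimizer = optim.SGD(nn.Linear(4, 2).parameters(), lr=0.1)  # stand-in optimizer
scheduler = create_scheduler(optimizer, {'type': 'StepLR', 'step_size': 10, 'gamma': 0.1})
# Calling scheduler.step() once per epoch multiplies the learning rate by 0.1 every 10 epochs.
```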
/balancemm/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import os, glob
2 | from typing import Optional
3 | import subprocess
4 |
5 | import torch.nn as nn
6 |
7 | from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
8 | from lightning.pytorch.loggers import WandbLogger
9 |
10 | import random
11 | import numpy as np
12 | import torch
13 | import os
14 | def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
15 | total = 0
16 | for p in module.parameters():
17 | if requires_grad is None or p.requires_grad == requires_grad:
18 | total += p.numel()
19 | return total
20 |
21 | def choose_logger(logger_name: str, log_dir, project: Optional[str] = None, comment: Optional[str] = None, *args, **kwargs):
22 | if logger_name == "csv":
23 | return CSVLogger(root_dir = log_dir, name = 'csv', *args, **kwargs)
24 | elif logger_name == "tensorboard":
25 | logger = TensorBoardLogger(root_dir=log_dir, name='tensorboard',default_hp_metric=False, *args, **kwargs)
26 | tensorboard_log_dir = os.path.join(log_dir, 'tensorboard')
27 | # subprocess.Popen(['tensorboard', '--logdir', tensorboard_log_dir])
28 |
29 | return logger
30 | elif logger_name == "wandb":
31 | return WandbLogger(project = project, save_dir = log_dir, notes = comment, *args, **kwargs)
32 | else:
33 | raise ValueError(f"`logger={logger_name}` is not a valid option.")
34 |
35 | def get_checkpoint_files(checkpoint_dir):
36 | checkpoint_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*.ckpt")))
37 | print(f'the checkpoint is {checkpoint_files}')
38 | return checkpoint_files
39 |
40 | def get_newest_path(out_dir):
41 |     folders = [f for f in os.listdir(out_dir) if os.path.isdir(os.path.join(out_dir, f)) and len(os.listdir(os.path.join(out_dir, f + '/checkpoints'))) > 0]
42 |     # folder = max(folders, key=lambda f: os.path.getmtime(os.path.join(out_dir, f)))
43 |     if not folders:
44 |         raise ValueError('There is no pretrained model under the given output directory.')
45 |     # pick the lexicographically newest run folder and return its checkpoints directory
46 |     folder = max(folders)
47 |     return os.path.join(out_dir, folder + '/checkpoints')
48 | 
49 | def set_seed(seed):
50 | """
51 | Set random seed for training
52 |
53 | Args:
54 | seed (int): random value
55 | """
56 | random.seed(seed)
57 | np.random.seed(seed)
58 | torch.manual_seed(seed)
59 | torch.cuda.manual_seed_all(seed)
60 |
61 | torch.backends.cudnn.deterministic = True
62 | torch.backends.cudnn.benchmark = False
63 |
64 | os.environ['PYTHONHASHSEED'] = str(seed)
65 |
66 | print(f"Random seed set as {seed}")
--------------------------------------------------------------------------------
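A brief sketch of how these helpers combine at start-up, assuming the `balancemm` package is importable; the seed, model, and log directory are arbitrary examples.

```python
import torch.nn as nn
from balancemm.utils.train_utils import set_seed, num_parameters, choose_logger

set_seed(42)                                      # seeds python/numpy/torch and makes cudnn deterministic
model = nn.Linear(128, 10)                        # stand-in model
print(num_parameters(model))                      # 1290 parameters in total
print(num_parameters(model, requires_grad=True))  # trainable parameters only
logger = choose_logger('csv', log_dir='experiments/demo')  # 'tensorboard' and 'wandb' are also supported
```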
/configs/dataset_config.yaml:
--------------------------------------------------------------------------------
1 |
2 |
3 | dataset:
4 | CMUMOSEI:
5 | dataset_path: None
6 | data : mosei_senti
7 | if_align: false
8 | classes: 2
9 | validation: True
10 | audio:
11 | input_dim: 74
12 | visual:
13 | input_dim: 35
14 | text:
15 | input_dim: 300
16 |
17 | FOOD101:
18 | targ_dir: None
19 | dataset_path: None
20 | classes: 101
21 | validation: false
22 | text:
23 | input_dim: 40
24 | visual:
25 | input_dim: 224
26 |
27 | UCF101:
28 | classes: 101
29 | validation: false
30 | stat_path : None
31 | train_txt : None
32 | test_txt : None
33 | visual_path : None
34 | flow_path_v : None
35 | flow_path_u : None
36 | flow:
37 | input_dim: 224
38 | visual:
39 | input_dim: 224
40 |
41 |
42 | Balance:
43 | csv_path: None
44 | visual_path: None
45 | audio_path: None
46 | validation: false
47 | classes: 30
48 | audio:
49 | input_dim: None
50 | visual:
51 | input_dim: None
52 |
53 | CREMAD:
54 | classes: 6
55 | fps: 2
56 | visual_path : None
57 | audio_path : None
58 | stat_path : None
59 | train_txt : None
60 | test_txt : None
61 | validation: false
62 | audio:
63 | input_dim: None
64 | visual:
65 | input_dim: None
66 |
67 | KineticsSounds:
68 | csv_path_train: None
69 | visual_path_train: None
70 | audio_path_train: None
71 | csv_path_test: None
72 | visual_path_test: None
73 | audio_path_test: None
74 | validation: false
75 | classes: 31
76 | audio:
77 | input_dim: None
78 | visual:
79 | input_dim: None
80 | VGGSound:
81 | validation: false
82 | classes: 309
83 | use_video_frames: 3
84 | csv_root: None
85 | video_train_root: None
86 | video_test_root: None
87 | audio_train_root: None
88 | audio_test_root: None
89 | audio:
90 | input_dim: None
91 | visual:
92 | input_dim: None
93 |
--------------------------------------------------------------------------------
/configs/encoder_config.yaml:
--------------------------------------------------------------------------------
1 |
2 | ViT_B:
3 | if_pretrain: False
4 | pretrain_path: None
5 |   audio:
6 |     type: None
7 |   visual:
8 |     type: None
9 |   text:
10 |     type: None
11 |
12 | Transformer_LA:
13 | if_pretrain: False
14 | pretrain_path: None
15 | output_dim: 1000
16 | audio:
17 | layer: 6
18 | hidden_size: 512
19 | dropout_r: 0.1
20 | multi_head: 4
21 | ff_size: 1024
22 | seq_len: 60
23 | modality: "audio"
24 | visual:
25 | layer: 6
26 | hidden_size: 512
27 | dropout_r: 0.1
28 | multi_head: 4
29 | ff_size: 1024
30 | seq_len: 60
31 | modality: "visual"
32 | text:
33 | layer: 6
34 | hidden_size: 512
35 | dropout_r: 0.1
36 | multi_head: 4
37 | ff_size: 1024
38 | seq_len: 60
39 | modality: "text"
40 | Transformer:
41 | if_pretrain: True
42 | pretrain_path: None
43 | output_dim: 1024
44 | audio:
45 | n_features: 512
46 | dim: 1024
47 | n_head: 4
48 | n_layers: 2
49 | visual:
50 | n_features: 512
51 | dim: 1024
52 | n_head: 4
53 | n_layers: 2
54 | text:
55 | n_features: 512
56 | dim: 1024
57 | n_head: 4
58 | n_layers: 2
59 | Transformer_:
60 | if_pretrain: False
61 | pretrain_path: None
62 | output_dim: 1024
63 | audio:
64 | dim: 1024
65 | n_head: 4
66 | n_layers: 2
67 | visual:
68 | dim: 1024
69 | n_head: 4
70 | n_layers: 2
71 | text:
72 | dim: 1024
73 | n_head: 4
74 | n_layers: 2
75 | ResNet18:
76 | output_dim: 512
77 | if_pretrain: False
78 | pretrain_path: '/data/users/shaoxuan_xu/3_OGM/Pretrained_model/resnet18.pth'
79 | audio:
80 | modality: 'audio'
81 | visual:
82 | modality: 'visual'
83 | flow:
84 | modality: 'flow'
85 | front_view:
86 | modality: 'front_view'
87 | back_view:
88 | modality: 'back_view'
--------------------------------------------------------------------------------
/configs/global_config.yaml:
--------------------------------------------------------------------------------
1 | seed: 42
2 |
3 | dataloader:
4 | eff_batch_size: 32
5 | num_workers: 32 # number of workers for dataloader
6 | fast_run: false # if true, only use part of the dataset
7 | shuffle: true # if true, shuffle the dataset
8 | drop_last: true # if true, drop the last batch if the size is not equal to batch_size
9 | pin_memory: true # if true, use pin_memory for dataloader
10 |
11 | fabric:
12 | accelerator: "gpu" # "cpu", "gpu", "cuda", "mps", "tpu"
13 |   devices: [1] # number of devices or a list of device indices
14 | precision: "32-true" # "32-true", "16-mixed", "bf16-mixed", etc.
15 | strategy: "dp" # "dp", "ddp", "ddp2", "ddp_spawn"
16 |
17 | log:
18 | logger_name: ['tensorboard'] # list of logger name
19 | wandb_name: ''
20 | log_interval: 100 # how many batches to wait before logging training status
21 |   log_per_epoch: -1 # number of logs per epoch; if set, overrides log_interval
22 | comment: '' # comment for the logger
23 |
24 |
25 | train:
26 | parameter:
27 | total_epoch: 30 # number of epochs to train
28 | warmup: 10 # warmup beta and lr
29 |     base_lr: 0.001 # base learning rate. Linear scaling rule: lr = eff_batch_size / base_batch_size * base_lr
30 | lr: -1 # learning rate. If set, override base_lr, set to -1 to use base_lr
31 | base_batch_size: 2 # base batch_size
32 | lr_scaling_rule: "linear" # "sqrt", "linear" learning rate should scale with batch size
33 | checkpoint:
34 | resume: '' # path to the checkpoint
35 | checkpoint_frequency: 1 # number of epoch interval to save the checkpoint
36 | num_checkpoint_keep: 3 # set to -1 to save all checkpoints
37 | # optimizer:
38 | # type: "SGD" # type of optimizer
39 | # momentum: 0.9
40 | # weight_decay: 0.0001
41 | optimizer:
42 | type: "AdamW" # type of optimizer
43 | # momentum: 0.9
44 | weight_decay: 0.01
45 | scheduler:
46 | type: "StepLR" # type of scheduler
47 | step_size: 10 # step size for StepLR
48 | gamma: 0.1 # gamma for StepLR
49 | validation:
50 | frequency: 1 # number of epoch interval to do validation
51 |     select_best: 'last' # 'last' or 'best': which checkpoint to select based on validation loss
52 | loss:
53 | type: "CrossEntropyLoss" # type of loss
54 |
55 |
--------------------------------------------------------------------------------
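For concreteness, the linear scaling rule referenced in the `base_lr` comment works out as follows with the values above (this assumes the training code applies the rule exactly as written):

```python
# Worked example of the linear learning-rate scaling rule
eff_batch_size = 32
base_batch_size = 2
base_lr = 0.001
lr = eff_batch_size / base_batch_size * base_lr
print(lr)  # 0.016
```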
/configs/model_config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | BaseClassifier:
3 | encoders: {audio: Transformer_, visual: Transformer_, text: Transformer_}
4 | fusion: "concat"
5 | modality_size: 1024
6 | BaseClassifier_AMCo:
7 | encoders: {audio: ResNet18, visual: ResNet18}
8 | fusion: "concat"
9 | modality_size: 512
10 | BaseClassifier_Greedy:
11 | encoders: {audio: ResNet18, visual: ResNet18}
12 | fusion: "concat"
13 | modality_size: 512
14 | BaseClassifier_MLA:
15 | encoders: {audio: ResNet18, visual: ResNet18}
16 | fusion: "shared"
17 | modality_size: 1024
18 |   # If using a Transformer encoder, set modality_size to 1024; if using ResNet18, set it to 512.
19 | BaseClassifier_ReconBoost:
20 | encoders: {audio: ResNet18, visual: ResNet18}
21 | fusion: "shared"
22 | modality_size: 1024
23 | # additional configuration for the fusion methods
24 | fusion:
25 | concat:
26 | out_put_dim: 100
27 | shared:
28 | out_put_dim: 100
--------------------------------------------------------------------------------
/configs/trainer_config.yaml:
--------------------------------------------------------------------------------
1 | trainer_para:
2 | base:
3 | grad_accum_steps: 1
4 | max_epochs: 30
5 | baseline:
6 | modulation_starts: 0
7 | modulation_ends: 80
8 | modality: 2
9 | OGM:
10 | alpha: 0.3
11 | method: 'OGM_GE'
12 | modulation_starts: 10 #
13 | modulation_ends: 80
14 | AGM:
15 | alpha: 0.1
16 | method: "None"
17 | modulation_starts: 0
18 | modulation_ends: 80
19 | modality: 2
20 | AMCo:
21 | alpha: 0.1
22 | method: "None"
23 | modulation_starts: 0
24 | modulation_ends: 80
25 | sigma: 0.5
26 | U: 512
27 | eps: 0.3
28 | modality: 2
29 | CML:
30 | modulation_starts: 0
31 | modulation_ends: 80
32 | lam: 0.1
33 | modality: 2
34 | GBlending:
35 | method: "online"
36 | modulation_starts: 0
37 | modulation_ends: 80
38 | super_epoch: 10
39 | modality: 2
40 | PMR:
41 | alpha: 0.6
42 | modulation_starts: 0
43 | modulation_ends: 80
44 | modality: 2
45 | momentum_coef: 0.5
46 | eta: 0.01
47 | MBSD:
48 | modulation_starts: 0
49 | modulation_ends: 80
50 | modality: 2
51 | MMCosine:
52 | modulation_starts: 0
53 | modulation_ends: 80
54 | modality: 2
55 | scaling: 15
56 | Greedy:
57 | alpha: 0.001
58 | modulation_starts: 0
59 | modulation_ends: 80
60 | modality: 2
61 | window_size: 5
62 | UMT:
63 | alpha: 1
64 | modulation_starts: 0
65 | modulation_ends: 80
66 | scaling: 10
67 | MLA:
68 | modulation_starts: 0
69 | modulation_ends: 80
70 | ReconBoost:
71 | alpha: 0.5
72 | modulation_starts: 0
73 | modulation_ends: 80
74 | T_epochs: 1
75 | weight1: 5
76 | weight2: 1
77 | MMPareto:
78 | alpha: 1.5
79 | method: 'None'
80 | modulation_starts: 0
81 | modulation_ends: 80
82 | OPM:
83 | alpha: 0.5
84 | p_exe: 0.7
85 | q_base: 0.4
86 | method: 'None'
87 | modulation_starts: 0
88 | modulation_ends: 80
89 | Sample:
90 | alpha: 0.5
91 | # method: 'Sample-level'
92 | method: 'Modality-level'
93 | modulation_starts: 1
94 | modulation_ends: 80
95 | part_ratio: 0.2
96 | ReLearning:
97 | move_lambda: 3
98 | method: None
99 | modulation_starts: 0
100 | modulation_ends: 80
101 | reinit_epoch: 15
102 | reinit_num: 2
103 | LinearProbe:
104 | alpha: 0.3
105 | method: 'None'
106 | modulation_starts: 0
107 | modulation_ends: 80
108 | modality: 2
109 | trainer_probed: "OGMTrainer"
110 | LFM:
111 | alpha: 1
112 | method: 'learning-fitted'
113 | modulation_starts: 0
114 | modulation_ends: 80
115 | modality: 2
116 | trainer_probed: "OGMTrainer"
117 | lr_alpha: 0.0001
118 |
--------------------------------------------------------------------------------
/configs/user_default.yaml:
--------------------------------------------------------------------------------
1 | mode: "train_and_test"
2 | check_point_path: ''
3 |
4 | seed: 42
5 | Main_config:
6 | model: "BaseClassifier" #
7 | 
8 | tasks: "Classifier" #
9 | trainer: "LinearProbeTrainer"
10 | device: '1'
11 | dataset: CREMAD
12 |
13 | Train:
14 | dataset: CREMAD
15 |
16 |
17 | Test:
18 | dataset: CREMAD
19 |
20 |
--------------------------------------------------------------------------------
/environment:
--------------------------------------------------------------------------------
1 | # conda create -n balancemm python=3.10
2 | # pip install torch==1.12.1+cu113
3 |
4 | # pip install lightning==2.0.0
5 |
6 | # pip install lightning-cloud==0.5.68
7 | # pip install lightning-utilities==0.11.2
8 | #
--------------------------------------------------------------------------------
/images/Algorithms.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/BalanceBenchmark/4c1482fca5aaa8d4278af8e26020daa039e3ccf3/images/Algorithms.jpeg
--------------------------------------------------------------------------------
/images/Results.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/BalanceBenchmark/4c1482fca5aaa8d4278af8e26020daa039e3ccf3/images/Results.jpeg
--------------------------------------------------------------------------------
/images/frame6_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeWu-Lab/BalanceBenchmark/4c1482fca5aaa8d4278af8e26020daa039e3ccf3/images/frame6_00.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | aiohttp==3.9.5
3 | aiosignal==1.3.1
4 | anyio==4.3.0
5 | arrow==1.3.0
6 | async-timeout==4.0.3
7 | attrs==23.2.0
8 | audioread==3.0.1
9 | beautifulsoup4==4.12.3
10 | blessed==1.20.0
11 | boto3==1.34.100
12 | botocore==1.34.100
13 | cachetools==5.3.3
14 | cffi==1.16.0
15 | click==8.1.7
16 | colorama==0.4.6
17 | contourpy==1.3.1
18 | croniter==1.3.15
19 | cycler==0.12.1
20 | dateutils==0.6.12
21 | decorator==5.1.1
22 | deepdiff==7.0.1
23 | dynaconf==3.2.5
24 | editor==1.6.6
25 | exceptiongroup==1.2.1
26 | fastapi==0.88.0
27 | filelock==3.14.0
28 | fonttools==4.55.3
29 | frozenlist==1.4.1
30 | fsspec==2023.12.2
31 | grpcio==1.66.2
32 | h11==0.14.0
33 | h5py==3.11.0
34 | huggingface-hub==0.25.0
35 | inquirer==3.2.4
36 | itsdangerous==2.2.0
37 | jinja2==3.1.4
38 | jmespath==1.0.1
39 | joblib==1.4.2
40 | kiwisolver==1.4.7
41 | lazy-loader==0.4
42 | librosa==0.10.1
43 | llvmlite==0.42.0
44 | loguru==0.7.2
45 | markdown==3.7
46 | markdown-it-py==3.0.0
47 | matplotlib==3.10.0
48 | mdurl==0.1.2
49 | memory-profiler==0.61.0
50 | msgpack==1.0.8
51 | multidict==6.0.5
52 | networkx==3.3
53 | numba==0.59.1
54 | ordered-set==4.1.0
55 | packaging==24.0
56 | pandas==2.2.2
57 | platformdirs==4.2.1
58 | pooch==1.8.1
59 | pretty-errors==1.2.25
60 | protobuf==5.26.1
61 | psutil==5.9.8
62 | pycparser==2.22
63 | pydantic==1.10.15
64 | pygments==2.18.0
65 | pyjwt==2.8.0
66 | pyparsing==3.2.0
67 | python-dateutil==2.9.0.post0
68 | python-multipart==0.0.9
69 | pytz==2024.1
70 | pyyaml==6.0.1
71 | readchar==4.0.6
72 | regex==2024.9.11
73 | rich==13.7.1
74 | runs==1.2.2
75 | s3transfer==0.10.1
76 | safetensors==0.4.5
77 | scikit-learn==1.4.2
78 | scipy==1.13.0
79 | six==1.16.0
80 | sniffio==1.3.1
81 | soundfile==0.12.1
82 | soupsieve==2.5
83 | soxr==0.3.7
84 | starlette==0.22.0
85 | starsessions==1.3.0
86 | tensorboard==2.18.0
87 | tensorboard-data-server==0.7.2
88 | tensorboardx==2.6.2.2
89 | termcolor==2.4.0
90 | thop==0.1.1-2209072238
91 | threadpoolctl==3.5.0
92 | tokenizers==0.19.1
93 | tqdm==4.66.4
94 | traitlets==5.14.3
95 | transformers==4.40.2
96 | triton==2.3.0
97 | types-python-dateutil==2.9.0.20240316
98 | tzdata==2024.1
99 | uvicorn==0.29.0
100 | wcwidth==0.2.13
101 | websocket-client==1.8.0
102 | websockets==11.0.3
103 | werkzeug==3.0.4
104 | xmod==1.8.1
105 | yarl==1.9.4
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Always prefer setuptools over distutils
2 | from setuptools import setup, find_packages
3 | import pathlib
4 |
5 | here = pathlib.Path(__file__).parent.resolve()
6 |
7 | # Get the long description from the README file
8 | long_description = (here / 'README.md').read_text(encoding='utf-8')
9 |
10 | setup(
11 | name='balancemm',
12 | version='1.0',
13 | description='A Benchmark for balanced multimodal learning.',
14 | long_description=long_description,
15 | long_description_content_type='text/markdown',
16 | url=' ',
17 | author=' ',
18 | author_email=' ',
19 | classifiers=[
20 | "License :: OSI Approved :: MIT License",
21 | "Development Status :: 3 - Alpha",
22 | "Programming Language :: Python :: 3",
23 | "Programming Language :: Python :: 3.10",
24 | "Operating System :: POSIX :: Linux",
25 | "Operating System :: Microsoft :: Windows",
26 | ],
27 | keywords='multimodal, benchmark, balanced learning, classification, deep learning',
28 | packages=find_packages(),
29 | install_requires=[
30 | 'pyyaml',
31 | 'docstring-parser',
32 | 'lightning',
33 | 'torch',
34 | 'torchvision',
35 | 'torchaudio',
36 | 'xformers',
37 | ],
38 | )
--------------------------------------------------------------------------------