├── overall.png
├── LICENSE
├── datasets
│   ├── CMUDataset.py
│   ├── dataloader.py
│   ├── BratsDataset.py
│   ├── FOODDataset.py
│   └── IEMODataset.py
├── models
│   ├── foodmodel.py
│   ├── msamodel.py
│   └── segmodel.py
├── README.md
├── run
│   ├── brats_run.py
│   ├── iemo_run.py
│   ├── mosi_run.py
│   └── food_run.py
├── modules
│   ├── position_embedding.py
│   ├── transformer.py
│   └── multihead_attention.py
└── src
    ├── eval_metrics.py
    ├── segtrain.py
    ├── msatrain.py
    └── foodtrain.py
/overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zrguo/CGGM/HEAD/overall.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 zrguo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /datasets/CMUDataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from torch.utils.data.dataset import Dataset 4 | import pickle 5 | import os 6 | import torch 7 | 8 | 9 | class CMUData(Dataset): 10 | def __init__(self, dataset_path, split_type='train'): 11 | super(CMUData, self).__init__() 12 | dataset = pickle.load(open(dataset_path, 'rb')) 13 | 14 | self.vision = torch.tensor(dataset[split_type]['vision'].astype(np.float32)).cpu().detach() 15 | self.text = torch.tensor(dataset[split_type]['text'].astype(np.float32)).cpu().detach() 16 | self.audio = dataset[split_type]['audio'].astype(np.float32) 17 | self.audio[self.audio == -np.inf] = 0 18 | self.audio = torch.tensor(self.audio).cpu().detach() 19 | self.labels = torch.tensor(dataset[split_type]['labels'].astype(np.float32)).cpu().detach() 20 | self.n_modalities = 3 # vision/ text/ audio 21 | 22 | def get_n_modalities(self): 23 | return self.n_modalities 24 | 25 | def get_seq_len(self): 26 | return self.text.shape[1], self.audio.shape[1], self.vision.shape[1] 27 | 28 | def get_dim(self): 29 | return self.text.shape[2], self.audio.shape[2], self.vision.shape[2] 30 | 31 | def get_lbl_info(self): 32 | return self.labels.shape[1], self.labels.shape[2] 33 | 34 | def __len__(self): 35 | return len(self.labels) 36 | 37 | def __getitem__(self, index): 38 | sample = { 39 | 'text': self.text[index], 40 | 'audio': self.audio[index], 41 | 'vision': self.vision[index], 42 | 'labels': self.labels[index] 43 | } 44 | return sample -------------------------------------------------------------------------------- /datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from datasets.CMUDataset import CMUData 4 | from datasets.IEMODataset import IEMOData 5 | from datasets.FOODDataset import Food101 6 | from datasets.BratsDataset import BraTSData 7 | 8 | 9 | class opt: 10 | cvNo = 1 11 | A_type = "comparE" 12 | V_type = "denseface" 13 | L_type = "bert_large" 14 | norm_method = 'trn' 15 | in_mem = False 16 | 17 | 18 | def getdataloader(dataset, batch_size, data_path): 19 | if dataset == 'mosi': 20 | data = { 21 | 'train': CMUData(data_path, 'train'), 22 | 'valid': CMUData(data_path, 'valid'), 23 | 'test': CMUData(data_path, 'test'), 24 | } 25 | orig_dim = data['test'].get_dim() 26 | dataLoader = { 27 | ds: DataLoader(data[ds], 28 | batch_size=batch_size, 29 | num_workers=8) 30 | for ds in data.keys() 31 | } 32 | elif dataset == 'iemo': 33 | data = { 34 | 'train': IEMOData(opt, data_path, set_name='trn'), 35 | 'valid': IEMOData(opt, data_path, set_name='val'), 36 | 'test': IEMOData(opt, data_path, set_name='tst'), 37 | } 38 | orig_dim = data['test'].get_dim() 39 | dataLoader = { 40 | ds: DataLoader(data[ds], 41 | batch_size=batch_size, 42 | drop_last=False, 43 | collate_fn=data['test'].collate_fn) 44 | for ds in data.keys() 45 | } 46 | elif dataset == 'food': 47 | data = { 48 | 'train': Food101(mode='train', dataset_root_dir=data_path), 49 | 'valid': Food101(mode='test', dataset_root_dir=data_path), 50 | 'test': Food101(mode='test', dataset_root_dir=data_path), 51 | } 52 | orig_dim = None 53 | dataLoader = { 54 | ds: DataLoader(data[ds], 55 | batch_size=batch_size, 56 | num_workers=8) 57 | for ds in data.keys() 58 | } 59 | elif dataset == 'brats': 60 | data = { 61 | 'train': 
BraTSData(root=data_path, mode='train'), 62 | 'valid': BraTSData(root=data_path, mode='valid'), 63 | 'test': BraTSData(root=data_path, mode='test'), 64 | } 65 | orig_dim = None 66 | dataLoader = { 67 | ds: DataLoader(data[ds], 68 | batch_size=batch_size, 69 | num_workers=8) 70 | for ds in data.keys() 71 | } 72 | 73 | return dataLoader, orig_dim -------------------------------------------------------------------------------- /models/foodmodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from modules.transformer import TransformerEncoder 4 | import torch.nn.functional as F 5 | 6 | 7 | class FoodModel(nn.Module): 8 | def __init__(self, output_dim=101, num_heads=5, layers=4, 9 | relu_dropout=0.1, embed_dropout=0.3, res_dropout=0.1, out_dropout=0.1, 10 | attn_dropout=0.25): 11 | super(FoodModel, self).__init__() 12 | self.num_mod = 2 13 | self.proj_dim = 40 14 | self.num_heads = num_heads 15 | self.layers = layers 16 | self.attn_dropout = attn_dropout 17 | self.relu_dropout = relu_dropout 18 | self.res_dropout = res_dropout 19 | self.out_dropout = out_dropout 20 | self.embed_dropout = embed_dropout 21 | self.projv = nn.Conv1d(768, self.proj_dim, kernel_size=1, padding=0) 22 | self.projt = nn.Conv1d(768, self.proj_dim, kernel_size=1, padding=0) 23 | self.vision_encoder = TransformerEncoder( 24 | embed_dim=self.proj_dim, num_heads=self.num_heads, 25 | layers=self.layers, attn_dropout=self.attn_dropout, res_dropout=self.res_dropout, 26 | relu_dropout=self.relu_dropout, embed_dropout=self.embed_dropout 27 | ) 28 | self.text_encoder = TransformerEncoder( 29 | embed_dim=self.proj_dim, num_heads=self.num_heads, 30 | layers=self.layers, attn_dropout=self.attn_dropout, res_dropout=self.res_dropout, 31 | relu_dropout=self.relu_dropout, embed_dropout=self.embed_dropout 32 | ) 33 | 34 | self.fusion = TransformerEncoder( 35 | embed_dim=self.proj_dim, num_heads=self.num_heads, 36 | layers=self.layers-2, attn_dropout=self.attn_dropout, res_dropout=self.res_dropout, 37 | relu_dropout=self.relu_dropout, embed_dropout=self.embed_dropout 38 | ) 39 | 40 | # Output layers 41 | self.proj1 = nn.Linear(self.proj_dim, self.proj_dim) 42 | self.proj2 = nn.Linear(self.proj_dim, self.proj_dim) 43 | self.out_layer = nn.Linear(self.proj_dim, output_dim) 44 | 45 | def forward(self, v, t): 46 | v = v.transpose(1, 2) 47 | v = self.projv(v) 48 | v = v.permute(2, 0, 1) 49 | t = t.transpose(1, 2) 50 | t = self.projt(t) 51 | t = t.permute(2, 0, 1) 52 | v = self.vision_encoder(v) 53 | t = self.text_encoder(t) 54 | hs = [v.clone().detach(), t.clone().detach()] 55 | f = torch.cat([v, t], dim=0) 56 | last_hs = self.fusion(f)[0] 57 | last_hs_proj = self.proj2(F.dropout(F.relu(self.proj1(last_hs)), p=self.out_dropout, training=self.training)) 58 | last_hs_proj += last_hs 59 | output = self.out_layer(last_hs_proj) 60 | return output, hs 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Classifier-guided Gradient Modulation for Enhanced Multimodal Learning 2 | 3 | [NeurIPS 2024] Official PyTorch implementation of the paper "Classifier-guided Gradient Modulation for Enhanced Multimodal Learning" 4 | 5 | ## Introduction 6 | 7 | Multimodal learning has developed rapidly in recent years. 
However, during multimodal training, the model tends to rely on the single modality it can learn from fastest, leading to inadequate use of the other modalities. Existing methods for balancing the training process typically place restrictions on the loss functions, optimizers, or the number of modalities, and only modulate the magnitude of the gradients while ignoring their directions. To address these problems, we present a novel method that balances multimodal learning with **C**lassifier-**G**uided **G**radient **M**odulation (CGGM), considering both the magnitude and the directions of the gradients. We conduct extensive experiments on four multimodal datasets covering classification, regression and segmentation tasks. The results show that CGGM consistently outperforms all baselines and other state-of-the-art methods, demonstrating its effectiveness and versatility. 8 | 9 | ![image-20241010151651174](overall.png) 10 | 11 | 12 | 13 | ## Getting Started 14 | 15 | ### Environment 16 | 17 | - Python >= 3.8, PyTorch >= 1.8.0 18 | 19 | ``` 20 | git clone https://github.com/zrguo/CGGM.git 21 | ``` 22 | 23 | ### Pre-trained Model 24 | 25 | For feature extraction on the Food 101 dataset, we use a pre-trained BERT ([google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased)) and ViT model ([google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224)). These pre-trained models are used only for the Food 101 dataset. 26 | 27 | ### Running the Code 28 | 29 | Run the Python file under the `run` directory that corresponds to the dataset. 30 | 31 | For example, if you want to evaluate the performance on the CMU-MOSI dataset, you can run the `mosi_run.py` file: 32 | 33 | ```bash 34 | python mosi_run.py --data_path "mosipath" --modulation "cggm" --batch_size 64 --rou 1.3 --lamda 0.2 35 | ``` 36 | 37 | If you want to evaluate the performance on the Food 101 dataset, you can run the `food_run.py` file: 38 | 39 | ```bash 40 | python food_run.py --data_path "foodpath" --vit "pre-trained vit path" --bert "pre-trained bert path" --modulation "cggm" 41 | ``` 42 | 43 | 44 | 45 | ## Citation 46 | 47 | If you find the repository useful, please cite the following paper: 48 | 49 | ```bibtex 50 | @inproceedings{guo2024classifier, 51 | title={Classifier-guided Gradient Modulation for Enhanced Multimodal Learning}, 52 | author={Guo, Zirun and Jin, Tao and Chen, Jingyuan and Zhao, Zhou}, 53 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, 54 | year={2024} 55 | } 56 | ``` 57 | 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /run/brats_run.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import argparse 4 | from src import segtrain 5 | from datasets.dataloader import getdataloader 6 | import numpy as np 7 | 8 | 9 | def segrun(): 10 | 11 | parser = argparse.ArgumentParser(description='Multimodal Segmentation') 12 | 13 | # Tasks 14 | parser.add_argument('--dataset', type=str, default='brats') 15 | parser.add_argument('--modulation', type=str, default='cggm', 16 | help='strategy to use (none/cggm)') 17 | parser.add_argument('--data_path', type=str, default='') 18 | 19 | # Tuning 20 | parser.add_argument('--weight_decay', type=float, default=3e-4) 21 | parser.add_argument('--warmup_epochs', type=int, default=10) 22 | parser.add_argument('--start_warmup_value', type=float, 
default=4e-4) 23 | parser.add_argument('--base_lr', type=float, default=0.01) 24 | parser.add_argument('--final_lr', type=float, default=0.001) 25 | parser.add_argument('--batch_size', type=int, default=64, metavar='N') 26 | parser.add_argument('--clip', type=float, default=0.8) 27 | parser.add_argument('--momentum', type=float, default=0.9) 28 | parser.add_argument('--cls_lr', type=float, default=6e-3) 29 | parser.add_argument('--num_epochs', type=int, default=75) 30 | parser.add_argument('--when', type=int, default=10) 31 | parser.add_argument('--rou', type=float, default=1.0) 32 | parser.add_argument('--lamda', type=float, default=0.1) 33 | 34 | 35 | # Logistics 36 | parser.add_argument('--log_interval', type=int, default=30, 37 | help='frequency of result logging (default: 30)') 38 | parser.add_argument('--seed', type=int, default=666, 39 | help='random seed') 40 | parser.add_argument('--no_cuda', action='store_true', 41 | help='do not use cuda') 42 | args = parser.parse_args() 43 | 44 | 45 | dataset = str.lower(args.dataset.strip()) 46 | 47 | def setup_seed(seed): 48 | torch.manual_seed(seed) 49 | torch.cuda.manual_seed_all(seed) 50 | np.random.seed(seed) 51 | random.seed(seed) 52 | torch.backends.cudnn.deterministic = True 53 | 54 | 55 | torch.set_default_tensor_type('torch.FloatTensor') 56 | if torch.cuda.is_available(): 57 | if args.no_cuda: 58 | print("WARNING: You have a CUDA device, so you should probably not run with --no_cuda") 59 | else: 60 | use_cuda = True 61 | else: 62 | print('cuda not available!') 63 | 64 | setup_seed(args.seed) 65 | 66 | 67 | 68 | dataloder, orig_dim = getdataloader(args.dataset, args.batch_size, args.data_path) 69 | train_loader = dataloder['train'] 70 | valid_loader = dataloder['valid'] 71 | test_loader = dataloder['test'] 72 | hyp_params = args 73 | hyp_params.orig_dim = orig_dim 74 | hyp_params.use_cuda = use_cuda 75 | hyp_params.dataset = dataset 76 | hyp_params.when = args.when 77 | hyp_params.n_train, hyp_params.n_valid, hyp_params.n_test = len(train_loader), len(valid_loader), len(test_loader) 78 | hyp_params.num_mod = 4 79 | 80 | test_loss = segtrain.initiate(hyp_params, train_loader, valid_loader, test_loader) 81 | 82 | 83 | if __name__ == '__main__': 84 | segrun() 85 | -------------------------------------------------------------------------------- /modules/position_embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | # Code adapted from the fairseq repo. 7 | 8 | def make_positions(tensor, padding_idx, left_pad): 9 | """Replace non-padding symbols with their position numbers. 10 | Position numbers begin at padding_idx+1. 11 | Padding symbols are ignored, but it is necessary to specify whether padding 12 | is added on the left side (left_pad=True) or right side (left_pad=False). 
13 | """ 14 | max_pos = padding_idx + 1 + tensor.size(1) 15 | device = tensor.get_device() 16 | buf_name = f'range_buf_{device}' 17 | if not hasattr(make_positions, buf_name): 18 | setattr(make_positions, buf_name, tensor.new()) 19 | setattr(make_positions, buf_name, getattr(make_positions, buf_name).type_as(tensor)) 20 | if getattr(make_positions, buf_name).numel() < max_pos: 21 | torch.arange(padding_idx + 1, max_pos, out=getattr(make_positions, buf_name)) 22 | mask = tensor.ne(padding_idx) 23 | positions = getattr(make_positions, buf_name)[:tensor.size(1)].expand_as(tensor) 24 | if left_pad: 25 | positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1) 26 | new_tensor = tensor.clone() 27 | return new_tensor.masked_scatter_(mask, positions[mask]).long() 28 | 29 | 30 | class SinusoidalPositionalEmbedding(nn.Module): 31 | """This module produces sinusoidal positional embeddings of any length. 32 | Padding symbols are ignored, but it is necessary to specify whether padding 33 | is added on the left side (left_pad=True) or right side (left_pad=False). 34 | """ 35 | 36 | def __init__(self, embedding_dim, padding_idx=0, left_pad=0, init_size=128): 37 | super().__init__() 38 | self.embedding_dim = embedding_dim 39 | self.padding_idx = padding_idx 40 | self.left_pad = left_pad 41 | self.weights = dict() # device --> actual weight; due to nn.DataParallel :-( 42 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 43 | 44 | @staticmethod 45 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None): 46 | """Build sinusoidal embeddings. 47 | This matches the implementation in tensor2tensor, but differs slightly 48 | from the description in Section 3.5 of "Attention Is All You Need". 49 | """ 50 | half_dim = embedding_dim // 2 51 | emb = math.log(10000) / (half_dim - 1) 52 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 53 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) 54 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) 55 | if embedding_dim % 2 == 1: 56 | # zero pad 57 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 58 | if padding_idx is not None: 59 | emb[padding_idx, :] = 0 60 | return emb 61 | 62 | def forward(self, input): 63 | """Input is expected to be of size [bsz x seqlen].""" 64 | bsz, seq_len = input.size() 65 | max_pos = self.padding_idx + 1 + seq_len 66 | device = input.get_device() 67 | if device not in self.weights or max_pos > self.weights[device].size(0): 68 | # recompute/expand embeddings if needed 69 | self.weights[device] = SinusoidalPositionalEmbedding.get_embedding( 70 | max_pos, 71 | self.embedding_dim, 72 | self.padding_idx, 73 | ) 74 | self.weights[device] = self.weights[device].type_as(self._float_tensor) 75 | positions = make_positions(input, self.padding_idx, self.left_pad) 76 | return self.weights[device].index_select(0, positions.contiguous().view(-1)).view(bsz, seq_len, -1).detach() 77 | 78 | def max_positions(self): 79 | """Maximum number of supported positions.""" 80 | return int(1e5) # an arbitrary large number -------------------------------------------------------------------------------- /run/iemo_run.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import argparse 4 | from src import msatrain 5 | from datasets.dataloader import getdataloader 6 | import numpy as np 7 | 8 | 9 | def iemorun(): 10 | parser = 
argparse.ArgumentParser(description='IEMOCAP') 11 | 12 | # Tasks 13 | parser.add_argument('--modulation', type=str, default='cggm', 14 | help='strategy to use (none/cggm)') 15 | parser.add_argument('--data_path', type=str, default='') 16 | 17 | # Dropouts 18 | parser.add_argument('--attn_dropout', type=float, default=0.15, 19 | help='attention dropout') 20 | parser.add_argument('--relu_dropout', type=float, default=0.15, 21 | help='relu dropout') 22 | parser.add_argument('--embed_dropout', type=float, default=0.2, 23 | help='embedding dropout') 24 | parser.add_argument('--res_dropout', type=float, default=0.15, 25 | help='residual block dropout') 26 | parser.add_argument('--out_dropout', type=float, default=0.2, 27 | help='output layer dropout') 28 | 29 | # Architecture 30 | parser.add_argument('--nlevels', type=int, default=4, 31 | help='number of layers in the network') 32 | parser.add_argument('--cls_layers', type=int, default=2, 33 | help='number of layers in the network') 34 | parser.add_argument('--num_heads', type=int, default=5, 35 | help='number of heads for the transformer network') 36 | parser.add_argument('--proj_dim', type=int, default=40, 37 | help='number of heads for the transformer network') 38 | parser.add_argument('--attn_mask', action='store_false', 39 | help='use attention mask for Transformer') 40 | 41 | 42 | # Tuning 43 | parser.add_argument('--batch_size', type=int, default=64, metavar='N', 44 | help='batch size') 45 | parser.add_argument('--clip', type=float, default=0.8, 46 | help='gradient clip value') 47 | parser.add_argument('--lr', type=float, default=1e-3, 48 | help='initial learning rate') 49 | parser.add_argument('--cls_lr', type=float, default=5e-4, 50 | help='initial learning rate') 51 | parser.add_argument('--optim', type=str, default='Adam') 52 | parser.add_argument('--num_epochs', type=int, default=30) 53 | parser.add_argument('--when', type=int, default=10, 54 | help='when to decay learning rate') 55 | parser.add_argument('--rou', type=float, default=1.3) 56 | parser.add_argument('--lamda', type=float, default=0.2) 57 | 58 | 59 | # Logistics 60 | parser.add_argument('--log_interval', type=int, default=30, 61 | help='frequency of result logging') 62 | parser.add_argument('--seed', type=int, default=666, 63 | help='random seed') 64 | parser.add_argument('--no_cuda', action='store_true', 65 | help='do not use cuda') 66 | args = parser.parse_args() 67 | 68 | dataset = 'iemo' 69 | 70 | 71 | def setup_seed(seed): 72 | torch.manual_seed(seed) 73 | torch.cuda.manual_seed_all(seed) 74 | np.random.seed(seed) 75 | random.seed(seed) 76 | torch.backends.cudnn.deterministic = True 77 | 78 | 79 | torch.set_default_tensor_type('torch.FloatTensor') 80 | if torch.cuda.is_available(): 81 | if args.no_cuda: 82 | print("WARNING: You have a CUDA device, so you should probably not run with --no_cuda") 83 | else: 84 | use_cuda = True 85 | else: 86 | print('cuda not available!') 87 | 88 | setup_seed(args.seed) 89 | 90 | dataloder, orig_dim = getdataloader(dataset, args.batch_size, args.data_path) 91 | train_loader = dataloder['train'] 92 | valid_loader = dataloder['valid'] 93 | test_loader = dataloder['test'] 94 | hyp_params = args 95 | hyp_params.orig_dim = orig_dim 96 | hyp_params.layers = args.nlevels 97 | hyp_params.use_cuda = use_cuda 98 | hyp_params.dataset = dataset 99 | hyp_params.when = args.when 100 | hyp_params.n_train, hyp_params.n_valid, hyp_params.n_test = len(train_loader), len(valid_loader), len(test_loader) 101 | hyp_params.output_dim = 4 102 | 
hyp_params.criterion = 'CrossEntropyLoss' 103 | hyp_params.num_mod = 3 104 | test_loss = msatrain.initiate(hyp_params, train_loader, valid_loader, test_loader) 105 | 106 | 107 | if __name__ == '__main__': 108 | iemorun() 109 | -------------------------------------------------------------------------------- /run/mosi_run.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import argparse 4 | from src import msatrain 5 | from datasets.dataloader import getdataloader 6 | import numpy as np 7 | 8 | 9 | def mosirun(): 10 | 11 | parser = argparse.ArgumentParser(description='CMU-MOSI') 12 | 13 | # Tasks 14 | parser.add_argument('--dataset', type=str, default='mosi') 15 | parser.add_argument('--modulation', type=str, default='cggm', 16 | help='strategy to use (none/cggm)') 17 | parser.add_argument('--data_path', type=str, default='') 18 | 19 | # Dropouts 20 | parser.add_argument('--attn_dropout', type=float, default=0.15, 21 | help='attention dropout') 22 | parser.add_argument('--relu_dropout', type=float, default=0.15, 23 | help='relu dropout') 24 | parser.add_argument('--embed_dropout', type=float, default=0.2, 25 | help='embedding dropout') 26 | parser.add_argument('--res_dropout', type=float, default=0.1, 27 | help='residual block dropout') 28 | parser.add_argument('--out_dropout', type=float, default=0.1, 29 | help='output layer dropout') 30 | 31 | # Architecture 32 | parser.add_argument('--nlevels', type=int, default=5, 33 | help='number of layers in the network') 34 | parser.add_argument('--cls_layers', type=int, default=2, 35 | help='number of layers in the network') 36 | parser.add_argument('--num_heads', type=int, default=5, 37 | help='number of heads for the transformer network') 38 | parser.add_argument('--proj_dim', type=int, default=40, 39 | help='number of heads for the transformer network') 40 | parser.add_argument('--attn_mask', action='store_false', 41 | help='use attention mask for Transformer') 42 | 43 | 44 | # Tuning 45 | parser.add_argument('--batch_size', type=int, default=64, metavar='N', 46 | help='batch size') 47 | parser.add_argument('--clip', type=float, default=0.8, 48 | help='gradient clip value') 49 | parser.add_argument('--lr', type=float, default=1e-3, 50 | help='initial learning rate') 51 | parser.add_argument('--cls_lr', type=float, default=5e-4, 52 | help='classifier learning rate') 53 | parser.add_argument('--optim', type=str, default='Adam') 54 | parser.add_argument('--num_epochs', type=int, default=30, 55 | help='number of epochs') 56 | parser.add_argument('--when', type=int, default=10, 57 | help='when to decay learning rate') 58 | parser.add_argument('--rou', type=float, default=1.3) 59 | parser.add_argument('--lamda', type=float, default=0.2) 60 | 61 | 62 | # Logistics 63 | parser.add_argument('--log_interval', type=int, default=30, 64 | help='frequency of result logging') 65 | parser.add_argument('--seed', type=int, default=666, 66 | help='random seed') 67 | parser.add_argument('--no_cuda', action='store_true', 68 | help='do not use cuda') 69 | 70 | args = parser.parse_args() 71 | 72 | 73 | dataset = str.lower(args.dataset.strip()) 74 | 75 | def setup_seed(seed): 76 | torch.manual_seed(seed) 77 | torch.cuda.manual_seed_all(seed) 78 | np.random.seed(seed) 79 | random.seed(seed) 80 | torch.backends.cudnn.deterministic = True 81 | 82 | 83 | torch.set_default_tensor_type('torch.FloatTensor') 84 | if torch.cuda.is_available(): 85 | if args.no_cuda: 86 | print("WARNING: You have a CUDA 
device, so you should probably not run with --no_cuda") 87 | else: 88 | use_cuda = True 89 | else: 90 | print('cuda not available!') 91 | 92 | 93 | dataloder, orig_dim = getdataloader(args.dataset, args.batch_size, args.data_path) 94 | train_loader = dataloder['train'] 95 | valid_loader = dataloder['valid'] 96 | test_loader = dataloder['test'] 97 | hyp_params = args 98 | hyp_params.orig_dim = orig_dim 99 | hyp_params.layers = args.nlevels 100 | hyp_params.use_cuda = use_cuda 101 | hyp_params.dataset = dataset 102 | hyp_params.when = args.when 103 | hyp_params.n_train, hyp_params.n_valid, hyp_params.n_test = len(train_loader), len(valid_loader), len(test_loader) 104 | hyp_params.output_dim = 1 105 | hyp_params.criterion = 'L1Loss' 106 | hyp_params.num_mod = 3 107 | 108 | setup_seed(args.seed) 109 | test_loss = msatrain.initiate(hyp_params, train_loader, valid_loader, test_loader) 110 | 111 | 112 | if __name__ == '__main__': 113 | mosirun() 114 | -------------------------------------------------------------------------------- /run/food_run.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import argparse 4 | from src import foodtrain 5 | from datasets.dataloader import getdataloader 6 | import numpy as np 7 | 8 | 9 | def foodrun(): 10 | 11 | parser = argparse.ArgumentParser(description='Food') 12 | 13 | # Tasks 14 | parser.add_argument('--dataset', type=str, default='food') 15 | parser.add_argument('--modulation', type=str, default='cggm', 16 | help='strategy to use (none/cggm)') 17 | parser.add_argument('--vit', type=str, default='', 18 | help='pre-trained vit path for visual feature extraction') 19 | parser.add_argument('--bert', type=str, default='', 20 | help='pre-trained bert path for textual feature extraction') 21 | parser.add_argument('--data_path', type=str, default='') 22 | 23 | # Dropouts 24 | parser.add_argument('--attn_dropout', type=float, default=0.2, 25 | help='attention dropout') 26 | parser.add_argument('--relu_dropout', type=float, default=0.15, 27 | help='relu dropout') 28 | parser.add_argument('--embed_dropout', type=float, default=0.2, 29 | help='embedding dropout') 30 | parser.add_argument('--res_dropout', type=float, default=0.15, 31 | help='residual block dropout') 32 | parser.add_argument('--out_dropout', type=float, default=0.1, 33 | help='output layer dropout') 34 | 35 | # Architecture 36 | parser.add_argument('--nlevels', type=int, default=4, 37 | help='number of layers in the network (default: 5)') 38 | parser.add_argument('--cls_layers', type=int, default=2, 39 | help='number of layers in the network (default: 2)') 40 | parser.add_argument('--num_heads', type=int, default=5, 41 | help='number of heads for the transformer network (default: 5)') 42 | parser.add_argument('--proj_dim', type=int, default=40, 43 | help='number of heads for the transformer network (default: 5)') 44 | parser.add_argument('--attn_mask', action='store_false', 45 | help='use attention mask for Transformer (default: true)') 46 | 47 | 48 | # Tuning 49 | parser.add_argument('--batch_size', type=int, default=128, metavar='N', 50 | help='batch size') 51 | parser.add_argument('--clip', type=float, default=0.8, 52 | help='gradient clip value') 53 | parser.add_argument('--lr', type=float, default=1e-3, 54 | help='initial learning rate') 55 | parser.add_argument('--cls_lr', type=float, default=5e-4, 56 | help='classifier learning rate') 57 | parser.add_argument('--optim', type=str, default='AdamW', 58 | help='optimizer to 
use') 59 | parser.add_argument('--num_epochs', type=int, default=60, 60 | help='number of epochs') 61 | parser.add_argument('--when', type=int, default=10, 62 | help='when to decay learning rate') 63 | parser.add_argument('--rou', type=float, default=1.3) 64 | parser.add_argument('--lamda', type=float, default=0.05) 65 | 66 | 67 | # Logistics 68 | parser.add_argument('--log_interval', type=int, default=30, 69 | help='frequency of result logging') 70 | parser.add_argument('--seed', type=int, default=666, 71 | help='random seed') 72 | parser.add_argument('--no_cuda', action='store_true', 73 | help='do not use cuda') 74 | args = parser.parse_args() 75 | 76 | dataset = str.lower(args.dataset.strip()) 77 | 78 | 79 | def setup_seed(seed): 80 | torch.manual_seed(seed) 81 | torch.cuda.manual_seed_all(seed) 82 | np.random.seed(seed) 83 | random.seed(seed) 84 | torch.backends.cudnn.deterministic = True 85 | 86 | 87 | torch.set_default_tensor_type('torch.FloatTensor') 88 | if torch.cuda.is_available(): 89 | if args.no_cuda: 90 | print("WARNING: You have a CUDA device, so you should probably not run with --no_cuda") 91 | else: 92 | use_cuda = True 93 | else: 94 | print('cuda not available!') 95 | 96 | setup_seed(args.seed) 97 | 98 | dataloder, orig_dim = getdataloader(args.dataset, args.batch_size, args.data_path) 99 | train_loader = dataloder['train'] 100 | valid_loader = dataloder['valid'] 101 | test_loader = dataloder['test'] 102 | hyp_params = args 103 | hyp_params.orig_dim = orig_dim 104 | hyp_params.layers = args.nlevels 105 | hyp_params.use_cuda = use_cuda 106 | hyp_params.dataset = dataset 107 | hyp_params.when = args.when 108 | hyp_params.n_train, hyp_params.n_valid, hyp_params.n_test = len(train_loader), len(valid_loader), len(test_loader) 109 | hyp_params.output_dim = 101 110 | hyp_params.criterion = 'CrossEntropyLoss' 111 | hyp_params.num_mod = 2 112 | test_loss = foodtrain.initiate(hyp_params, train_loader, valid_loader, test_loader) 113 | 114 | 115 | if __name__ == '__main__': 116 | foodrun() 117 | -------------------------------------------------------------------------------- /modules/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from modules.position_embedding import SinusoidalPositionalEmbedding 5 | from modules.multihead_attention import MultiheadAttention 6 | import math 7 | 8 | 9 | class TransformerEncoder(nn.Module): 10 | 11 | def __init__(self, embed_dim, num_heads, layers, attn_dropout=0.0, relu_dropout=0.0, res_dropout=0.0, 12 | embed_dropout=0.0, attn_mask=False): 13 | super().__init__() 14 | self.dropout = embed_dropout # Embedding dropout 15 | self.attn_dropout = attn_dropout 16 | self.embed_dim = embed_dim 17 | self.embed_scale = math.sqrt(embed_dim) 18 | self.embed_positions = SinusoidalPositionalEmbedding(embed_dim) 19 | 20 | self.attn_mask = attn_mask 21 | 22 | self.layers = nn.ModuleList([]) 23 | for layer in range(layers): 24 | new_layer = TransformerEncoderLayer(embed_dim, 25 | num_heads=num_heads, 26 | attn_dropout=attn_dropout, 27 | relu_dropout=relu_dropout, 28 | res_dropout=res_dropout, 29 | attn_mask=attn_mask) 30 | self.layers.append(new_layer) 31 | 32 | self.register_buffer('version', torch.Tensor([2])) 33 | self.normalize = True 34 | if self.normalize: 35 | self.layer_norm = LayerNorm(embed_dim) 36 | 37 | def forward(self, x_in): 38 | # embed tokens and positions 39 | x = self.embed_scale * x_in 40 | if self.embed_positions is 
not None: 41 | x += self.embed_positions(x_in.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding 42 | x = F.dropout(x, p=self.dropout, training=self.training) 43 | 44 | # encoder layers 45 | intermediates = [x] 46 | for layer in self.layers: 47 | x = layer(x) 48 | intermediates.append(x) 49 | 50 | if self.normalize: 51 | x = self.layer_norm(x) 52 | 53 | return x 54 | 55 | def max_positions(self): 56 | """Maximum input length supported by the encoder.""" 57 | if self.embed_positions is None: 58 | return self.max_source_positions 59 | return min(self.max_source_positions, self.embed_positions.max_positions()) 60 | 61 | 62 | class TransformerEncoderLayer(nn.Module): 63 | def __init__(self, embed_dim, num_heads=4, attn_dropout=0.1, relu_dropout=0.1, res_dropout=0.1, 64 | attn_mask=False): 65 | super().__init__() 66 | self.embed_dim = embed_dim 67 | self.num_heads = num_heads 68 | 69 | self.self_attn = MultiheadAttention( 70 | embed_dim=self.embed_dim, 71 | num_heads=self.num_heads, 72 | attn_dropout=attn_dropout 73 | ) 74 | self.attn_mask = attn_mask 75 | 76 | self.relu_dropout = relu_dropout 77 | self.res_dropout = res_dropout 78 | self.normalize_before = True 79 | 80 | self.fc1 = Linear(self.embed_dim, 4*self.embed_dim) # The "Add & Norm" part in the paper 81 | self.fc2 = Linear(4*self.embed_dim, self.embed_dim) 82 | self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for _ in range(2)]) 83 | 84 | def forward(self, x): 85 | residual = x 86 | x = self.maybe_layer_norm(0, x, before=True) 87 | x, _ = self.self_attn(query=x, key=x, value=x) 88 | x = F.dropout(x, p=self.res_dropout, training=self.training) 89 | x = residual + x 90 | x = self.maybe_layer_norm(0, x, after=True) 91 | 92 | residual = x 93 | x = self.maybe_layer_norm(1, x, before=True) 94 | x = F.relu(self.fc1(x)) 95 | x = F.dropout(x, p=self.relu_dropout, training=self.training) 96 | x = self.fc2(x) 97 | x = F.dropout(x, p=self.res_dropout, training=self.training) 98 | x = residual + x 99 | x = self.maybe_layer_norm(1, x, after=True) 100 | return x 101 | 102 | def maybe_layer_norm(self, i, x, before=False, after=False): 103 | assert before ^ after 104 | if after ^ self.normalize_before: 105 | return self.layer_norms[i](x) 106 | else: 107 | return x 108 | 109 | def fill_with_neg_inf(t): 110 | """FP16-compatible function that fills a tensor with -inf.""" 111 | return t.float().fill_(float('-inf')).type_as(t) 112 | 113 | 114 | def buffered_future_mask(tensor, tensor2=None): 115 | dim1 = dim2 = tensor.size(0) 116 | if tensor2 is not None: 117 | dim2 = tensor2.size(0) 118 | future_mask = torch.triu(fill_with_neg_inf(torch.ones(dim1, dim2)), 1+abs(dim2-dim1)) 119 | if tensor.is_cuda: 120 | future_mask = future_mask.cuda() 121 | return future_mask[:dim1, :dim2] 122 | 123 | 124 | def Linear(in_features, out_features, bias=True): 125 | m = nn.Linear(in_features, out_features, bias) 126 | nn.init.xavier_uniform_(m.weight) 127 | if bias: 128 | nn.init.constant_(m.bias, 0.) 
129 | return m 130 | 131 | 132 | def LayerNorm(embedding_dim): 133 | m = nn.LayerNorm(embedding_dim) 134 | return m 135 | 136 | -------------------------------------------------------------------------------- /models/msamodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from src.eval_metrics import * 5 | 6 | from modules.transformer import TransformerEncoder 7 | 8 | 9 | class MSAModel(nn.Module): 10 | def __init__(self, output_dim, orig_dim, proj_dim=30, num_heads=5, layers=5, 11 | relu_dropout=0.1, embed_dropout=0.3, res_dropout=0.1, out_dropout=0.1, 12 | attn_dropout=0.25 13 | ): 14 | super(MSAModel, self).__init__() 15 | 16 | self.proj_dim = proj_dim 17 | self.orig_dim = orig_dim 18 | self.num_mod = len(orig_dim) 19 | self.num_heads = num_heads 20 | self.layers = layers 21 | self.attn_dropout = attn_dropout 22 | self.relu_dropout = relu_dropout 23 | self.res_dropout = res_dropout 24 | self.out_dropout = out_dropout 25 | self.embed_dropout = embed_dropout 26 | 27 | # Projection Layers 28 | self.proj = nn.ModuleList([ 29 | nn.Conv1d(self.orig_dim[i], self.proj_dim, kernel_size=1, padding=0) 30 | for i in range(self.num_mod) 31 | ]) 32 | 33 | # Encoders 34 | self.encoders = nn.ModuleList([ 35 | TransformerEncoder(embed_dim=proj_dim, num_heads=self.num_heads, 36 | layers=self.layers, attn_dropout=self.attn_dropout, res_dropout=self.res_dropout, 37 | relu_dropout=self.relu_dropout, embed_dropout=self.embed_dropout) 38 | for _ in range(self.num_mod) 39 | ]) 40 | 41 | # Fusion 42 | self.fusion = TransformerEncoder( 43 | embed_dim=proj_dim, num_heads=self.num_heads, 44 | layers=self.layers-2, attn_dropout=self.attn_dropout, res_dropout=self.res_dropout, 45 | relu_dropout=self.relu_dropout, embed_dropout=self.embed_dropout 46 | ) 47 | 48 | # Output layers 49 | self.proj1 = nn.Linear(self.proj_dim, self.proj_dim) 50 | self.proj2 = nn.Linear(self.proj_dim, self.proj_dim) 51 | self.out_layer = nn.Linear(self.proj_dim, output_dim) 52 | 53 | def forward(self, x): 54 | """ 55 | dimension [batch_size, seq_len, n_features] 56 | """ 57 | hs = list() 58 | hs_detach = list() 59 | 60 | for i in range(self.num_mod): 61 | x[i] = x[i].transpose(1, 2) 62 | x[i] = self.proj[i](x[i]) 63 | x[i] = x[i].permute(2, 0, 1) 64 | h_tmp = self.encoders[i](x[i]) 65 | hs.append(h_tmp) 66 | hs_detach.append(h_tmp.clone().detach()) 67 | 68 | last_hs = self.fusion(torch.cat(hs))[0] 69 | 70 | # A residual block 71 | last_hs_proj = self.proj2(F.dropout(F.relu(self.proj1(last_hs)), p=self.out_dropout, training=self.training)) 72 | last_hs_proj += last_hs 73 | output = self.out_layer(last_hs_proj) 74 | return output, hs_detach 75 | 76 | 77 | class Classifier(nn.Module): 78 | def __init__(self, in_dim, out_dim, num_heads=5, layers=2, 79 | relu_dropout=0.1, embed_dropout=0.3, 80 | attn_dropout=0.25, res_dropout=0.1): 81 | super(Classifier, self).__init__() 82 | self.bone = TransformerEncoder(embed_dim=in_dim, num_heads=num_heads, 83 | layers=layers, attn_dropout=attn_dropout, res_dropout=res_dropout, 84 | relu_dropout=relu_dropout, embed_dropout=embed_dropout) 85 | 86 | self.proj1 = nn.Linear(in_dim, in_dim) 87 | self.out_layer = nn.Linear(in_dim, out_dim) 88 | 89 | def forward(self, x): 90 | x = self.bone(x) 91 | x = self.proj1(x[0]) 92 | x = F.relu(self.proj1(x)) 93 | x = self.out_layer(x) 94 | return x 95 | 96 | 97 | class ClassifierGuided(nn.Module): 98 | def __init__(self, output_dim, num_mod, proj_dim=30, 
num_heads=5, layers=5, 99 | relu_dropout=0.1, embed_dropout=0.3, res_dropout=0.1, attn_dropout=0.25): 100 | super(ClassifierGuided, self).__init__() 101 | # Classifiers 102 | self.num_mod = num_mod 103 | self.classifers = nn.ModuleList([ 104 | Classifier(in_dim=proj_dim, out_dim=output_dim, layers=layers, 105 | num_heads=num_heads, attn_dropout=attn_dropout, res_dropout=res_dropout, 106 | relu_dropout=relu_dropout, embed_dropout=embed_dropout) 107 | for _ in range(self.num_mod) 108 | ]) 109 | 110 | def cal_coeff(self, dataset, y, cls_res): 111 | acc_list = list() 112 | 113 | if dataset in ['mosi', 'mosei']: 114 | for r in cls_res: 115 | acc = train_eval_senti(r, y) 116 | acc_list.append(acc) 117 | elif dataset == 'iemo': 118 | for r in cls_res: 119 | acc = train_eval_iemo(r, y) 120 | acc_list.append(acc) 121 | elif dataset == 'food': 122 | for r in cls_res: 123 | acc = train_eval_food(r, y) 124 | acc_list.append(acc) 125 | 126 | return acc_list 127 | 128 | def forward(self, x): 129 | self.cls_res = list() 130 | for i in range(len(x)): 131 | self.cls_res.append(self.classifers[i](x[i])) 132 | return self.cls_res 133 | -------------------------------------------------------------------------------- /datasets/BratsDataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import h5py 4 | import numpy as np 5 | import random 6 | import torch 7 | 8 | 9 | class RandomCrop(object): 10 | """ 11 | Crop randomly the image in a sample 12 | Args: 13 | output_size (int): Desired output size 14 | """ 15 | 16 | def __init__(self, output_size=(160, 160)): 17 | self.output_size = output_size 18 | 19 | def __call__(self, sample): 20 | image, label = sample['image'], sample['label'] 21 | 22 | (c, w, h) = image.shape 23 | w1 = np.random.randint(0, w - self.output_size[0]) 24 | h1 = np.random.randint(0, h - self.output_size[1]) 25 | 26 | label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1]] 27 | image = image[:, w1:w1 + self.output_size[0], h1:h1 + self.output_size[1]] 28 | return {'image': image, 'label': label} 29 | 30 | 31 | class CenterCrop(object): 32 | def __init__(self, output_size=(160, 160)): 33 | self.output_size = output_size 34 | 35 | def __call__(self, sample): 36 | image, label = sample['image'], sample['label'] 37 | 38 | (c, w, h) = image.shape 39 | 40 | w1 = int(round((w - self.output_size[0]) / 2.)) 41 | h1 = int(round((h - self.output_size[1]) / 2.)) 42 | 43 | label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1]] 44 | image = image[:, w1:w1 + self.output_size[0], h1:h1 + self.output_size[1]] 45 | 46 | return {'image': image, 'label': label} 47 | 48 | 49 | class RandomRotFlip(object): 50 | """ 51 | Crop randomly flip the dataset in a sample 52 | Args: 53 | output_size (int): Desired output size 54 | """ 55 | 56 | def __call__(self, sample): 57 | image, label = sample['image'], sample['label'] 58 | k = np.random.randint(0, 3) 59 | image = np.stack([np.rot90(x, k) for x in image], axis=0) 60 | label = np.rot90(label, k) 61 | axis = np.random.randint(1, 3) 62 | image = np.flip(image, axis=axis).copy() 63 | label = np.flip(label, axis=axis - 1).copy() 64 | 65 | return {'image': image, 'label': label} 66 | 67 | 68 | def augment_gaussian_noise(data_sample, noise_variance=(0, 0.1)): 69 | if noise_variance[0] == noise_variance[1]: 70 | variance = noise_variance[0] 71 | else: 72 | variance = random.uniform(noise_variance[0], noise_variance[1]) 73 | data_sample = data_sample + 
np.random.normal(0.0, variance, size=data_sample.shape) 74 | return data_sample 75 | 76 | 77 | class GaussianNoise(object): 78 | def __init__(self, noise_variance=(0, 0.1), p=0.5): 79 | self.prob = p 80 | self.noise_variance = noise_variance 81 | 82 | def __call__(self, sample): 83 | image = sample 84 | if np.random.uniform() < self.prob: 85 | image = augment_gaussian_noise(image, self.noise_variance) 86 | return image 87 | 88 | 89 | def cutout(image, label, patch_size=(30, 30)): 90 | mod, image_height, image_width = image.shape 91 | patch_height, patch_width = patch_size 92 | 93 | x = np.random.randint(0, image_height - patch_height) 94 | y = np.random.randint(0, image_width - patch_width) 95 | 96 | for i in range(image.shape[0]): 97 | modality = image[i, :, :] 98 | modality[x:x + patch_height, y:y + patch_width] = 0 99 | image[i, :, :] = modality 100 | 101 | label[x:x + patch_height, y:y + patch_width] = 0 102 | 103 | return image, label 104 | 105 | 106 | class ToTensor(object): 107 | """Convert ndarrays in sample to Tensors.""" 108 | 109 | def __call__(self, sample): 110 | image = sample['image'] 111 | label = sample['label'] 112 | 113 | image = torch.from_numpy(image).float() 114 | label = torch.from_numpy(label).long() 115 | 116 | return image, label 117 | 118 | 119 | class BraTSData(Dataset): 120 | def __init__(self, root, mode, size=(160, 160)): 121 | self.root = root 122 | self.mode = mode 123 | self.size = size 124 | self.root = os.path.join(self.root, self.mode) 125 | data = os.listdir(self.root) 126 | self.data = [os.path.join(self.root, d) for d in data] 127 | 128 | def __getitem__(self, item): 129 | id = self.data[item] 130 | h5f = h5py.File(id, 'r') 131 | image = h5f['image'][:] 132 | label = h5f['label'][:] 133 | label[label == 4] = 3 134 | idx = np.random.randint(0, 128) 135 | while np.max(label[:, :, idx]) == 0: 136 | idx = np.random.randint(0, 128) 137 | image = image[:, :, :, idx] 138 | label = label[:, :, idx] 139 | 140 | if self.mode == 'train': 141 | sample = {'image': image, 'label': label} 142 | sample = CenterCrop(self.size)(sample) 143 | sample = RandomRotFlip()(sample) 144 | image, label = sample['image'], sample['label'] 145 | image = GaussianNoise(p=0.1)(image) 146 | sample = {'image': image, 'label': label} 147 | image, label = ToTensor()(sample) 148 | else: 149 | sample = {'image': image, 'label': label} 150 | sample = CenterCrop(self.size)(sample) 151 | image, label = ToTensor()(sample) 152 | 153 | flair, t1ce, t1, t2 = image[0].unsqueeze(dim=0), image[1].unsqueeze(dim=0), image[2].unsqueeze(dim=0), image[3].unsqueeze(dim=0) 154 | 155 | return flair, t1ce, t1, t2, label 156 | 157 | def __len__(self): 158 | return len(self.data) 159 | 160 | 161 | 162 | -------------------------------------------------------------------------------- /modules/multihead_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | import torch.nn.functional as F 5 | import sys 6 | 7 | # Code adapted from the fairseq repo. 8 | 9 | class MultiheadAttention(nn.Module): 10 | """Multi-headed attention. 11 | See "Attention Is All You Need" for more details. 
12 | """ 13 | 14 | def __init__(self, embed_dim, num_heads, attn_dropout=0., 15 | bias=True, add_bias_kv=False, add_zero_attn=False): 16 | super().__init__() 17 | self.embed_dim = embed_dim 18 | self.num_heads = num_heads 19 | self.attn_dropout = attn_dropout 20 | self.head_dim = embed_dim // num_heads 21 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 22 | self.scaling = self.head_dim ** -0.5 23 | 24 | self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) 25 | self.register_parameter('in_proj_bias', None) 26 | if bias: 27 | self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) 28 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 29 | 30 | if add_bias_kv: 31 | self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) 32 | self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) 33 | else: 34 | self.bias_k = self.bias_v = None 35 | 36 | self.add_zero_attn = add_zero_attn 37 | 38 | self.reset_parameters() 39 | 40 | def reset_parameters(self): 41 | nn.init.xavier_uniform_(self.in_proj_weight) 42 | nn.init.xavier_uniform_(self.out_proj.weight) 43 | if self.in_proj_bias is not None: 44 | nn.init.constant_(self.in_proj_bias, 0.) 45 | nn.init.constant_(self.out_proj.bias, 0.) 46 | if self.bias_k is not None: 47 | nn.init.xavier_normal_(self.bias_k) 48 | if self.bias_v is not None: 49 | nn.init.xavier_normal_(self.bias_v) 50 | 51 | def forward(self, query, key, value, attn_mask=None): 52 | """Input shape: Time x Batch x Channel 53 | Self-attention can be implemented by passing in the same arguments for 54 | query, key and value. Timesteps can be masked by supplying a T x T mask in the 55 | `attn_mask` argument. Padding elements can be excluded from 56 | the key by passing a binary ByteTensor (`key_padding_mask`) with shape: 57 | batch x src_len, where padding elements are indicated by 1s. 
58 | """ 59 | qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() 60 | kv_same = key.data_ptr() == value.data_ptr() 61 | 62 | tgt_len, bsz, embed_dim = query.size() 63 | assert embed_dim == self.embed_dim 64 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 65 | assert key.size() == value.size() 66 | 67 | aved_state = None 68 | 69 | if qkv_same: 70 | # self-attention 71 | q, k, v = self.in_proj_qkv(query) 72 | elif kv_same: 73 | # encoder-decoder attention 74 | q = self.in_proj_q(query) 75 | 76 | if key is None: 77 | assert value is None 78 | k = v = None 79 | else: 80 | k, v = self.in_proj_kv(key) 81 | else: 82 | q = self.in_proj_q(query) 83 | k = self.in_proj_k(key) 84 | v = self.in_proj_v(value) 85 | q = q * self.scaling 86 | 87 | if self.bias_k is not None: 88 | assert self.bias_v is not None 89 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) 90 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) 91 | if attn_mask is not None: 92 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 93 | 94 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 95 | if k is not None: 96 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 97 | if v is not None: 98 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 99 | 100 | src_len = k.size(1) 101 | 102 | if self.add_zero_attn: 103 | src_len += 1 104 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) 105 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) 106 | if attn_mask is not None: 107 | attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) 108 | 109 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 110 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 111 | 112 | if attn_mask is not None: 113 | try: 114 | attn_weights += attn_mask.unsqueeze(0) 115 | except: 116 | print(attn_weights.shape) 117 | print(attn_mask.unsqueeze(0).shape) 118 | assert False 119 | 120 | attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights) 121 | # attn_weights = F.relu(attn_weights) 122 | # attn_weights = attn_weights / torch.max(attn_weights) 123 | attn_weights = F.dropout(attn_weights, p=self.attn_dropout, training=self.training) 124 | 125 | attn = torch.bmm(attn_weights, v) 126 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 127 | 128 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 129 | attn = self.out_proj(attn) 130 | 131 | # average attention weights over heads 132 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 133 | attn_weights = attn_weights.sum(dim=1) / self.num_heads 134 | return attn, attn_weights 135 | 136 | def in_proj_qkv(self, query): 137 | return self._in_proj(query).chunk(3, dim=-1) 138 | 139 | def in_proj_kv(self, key): 140 | return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1) 141 | 142 | def in_proj_q(self, query, **kwargs): 143 | return self._in_proj(query, end=self.embed_dim, **kwargs) 144 | 145 | def in_proj_k(self, key): 146 | return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) 147 | 148 | def in_proj_v(self, value): 149 | return self._in_proj(value, start=2 * self.embed_dim) 150 | 151 | def _in_proj(self, input, start=0, end=None, **kwargs): 152 | weight = kwargs.get('weight', self.in_proj_weight) 153 | bias = kwargs.get('bias', self.in_proj_bias) 154 | weight = 
weight[start:end, :] 155 | if bias is not None: 156 | bias = bias[start:end] 157 | return F.linear(input, weight, bias) 158 | -------------------------------------------------------------------------------- /datasets/FOODDataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | from torchvision import transforms 5 | import random 6 | from torch.utils.data import Dataset 7 | 8 | from PIL import Image 9 | from transformers import BertTokenizer 10 | 11 | from os.path import join 12 | import json 13 | 14 | CLASS_NAME = ['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 15 | 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 16 | 'caprese_salad', 'carrot_cake', 'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry', 17 | 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 18 | 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 19 | 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 20 | 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 21 | 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 22 | 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 23 | 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 24 | 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 25 | 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 26 | 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 27 | 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 28 | 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles'] 29 | TEXT_MAX_LENGTH = 50 30 | MIN_FREQ = 3 31 | NUMBER_OF_SAMPLES_PER_CLASS = None 32 | 33 | 34 | class Datapool(Dataset): 35 | def __init__(self, all_ids, mode): 36 | self.all_ids = all_ids 37 | self.mode = mode 38 | assert len(list(set(self.all_ids))) == len(self.all_ids), "dataset has duplicated ids" 39 | self.unlabeled_ids = self.all_ids.copy() 40 | self.labeled_ids = [] 41 | self.sample_ids = self.all_ids.copy() # default for val and test datapool 42 | 43 | def initialize(self, query_budget: int): 44 | # query_budget is the number of labels been queried each round 45 | # random initialization for first batch of labels 46 | self.labeled_ids = self.unlabeled_ids[:query_budget] 47 | self.unlabeled_ids = [id for id in self.all_ids if id not in self.labeled_ids] 48 | 49 | def query_for_label(self, queried_ids: list): 50 | # queried_ids are generated from query strategy 51 | self.labeled_ids += queried_ids 52 | self.unlabeled_ids = [id for id in self.all_ids if id not in self.labeled_ids] 53 | assert len(self.labeled_ids) + len(self.unlabeled_ids) == len(self.all_ids) 54 | 55 | def query(self): 56 | # prepare unlabeled data index for label querying 57 | self.mode = "query" 58 | print("dataset for querying") 59 | self.sample_ids = self.unlabeled_ids 60 | 61 | def train(self): 62 | # prepare labeled queried data index for model training 63 | self.mode = "train" 64 | 
print("dataset for training") 65 | self.sample_ids = self.labeled_ids 66 | 67 | def __len__(self): 68 | return len(self.sample_ids) 69 | 70 | def __getitem__(self, idx): 71 | pass 72 | 73 | 74 | class Food101(Datapool): 75 | def __init__(self, 76 | mode="train", 77 | dataset_root_dir=r"", 78 | ): 79 | self.dataset_root_dir = dataset_root_dir 80 | self.mode = mode 81 | assert self.mode in ["train", "dev", "test"] 82 | with open(join(dataset_root_dir, f"{mode}.json")) as file: 83 | data_list = json.load(file) 84 | self.data = {x["id"]: x for x in data_list} 85 | 86 | self.all_ids = list(self.data.keys()) 87 | random.Random(0).shuffle(self.all_ids) 88 | 89 | super(Food101, self).__init__(self.all_ids, self.mode) 90 | color_distort_strength = 0.5 91 | color_jitter = transforms.ColorJitter( 92 | brightness=0.8 * color_distort_strength, 93 | contrast=0.8 * color_distort_strength, 94 | saturation=0.8 * color_distort_strength, 95 | hue=0.2 * color_distort_strength 96 | ) 97 | gaussian_kernel_size = 21 98 | self.train_transform = transforms.Compose([ 99 | transforms.Resize([224, 224]), 100 | transforms.RandomHorizontalFlip(p=0.5), 101 | transforms.RandomApply([color_jitter], p=0.8), 102 | transforms.RandomGrayscale(p=0.2), 103 | transforms.GaussianBlur(kernel_size=gaussian_kernel_size), 104 | transforms.ToTensor(), 105 | ]) 106 | 107 | self.val_transform = transforms.Compose([ 108 | transforms.Resize([224, 224]), 109 | transforms.ToTensor(), 110 | ]) 111 | 112 | self.sentence_max_len = TEXT_MAX_LENGTH 113 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 114 | 115 | def load_bert_tokens(self, sample_id): 116 | text_tokens = ' '.join(self.data[sample_id]["text_tokens"]) 117 | text_input = self.tokenizer( 118 | text_tokens, 119 | add_special_tokens=True, 120 | padding="max_length", 121 | max_length=self.sentence_max_len, 122 | truncation=True, 123 | return_tensors='pt' 124 | ) 125 | for k, v in text_input.items(): 126 | text_input[k] = v.squeeze(0) 127 | return text_input 128 | 129 | def load_image(self, sample_id): 130 | image_path = join(self.dataset_root_dir, self.data[sample_id]["img_path"]) 131 | with open(image_path, "rb") as f: 132 | image = Image.open(f) 133 | image = image.convert("RGB") 134 | if self.mode == "train": 135 | image = self.train_transform(image) 136 | else: 137 | image = self.val_transform(image) 138 | 139 | return image 140 | 141 | def __len__(self): 142 | return len(self.sample_ids) 143 | 144 | def __getitem__(self, idx): 145 | sample_id = self.sample_ids[idx] 146 | text_input = self.load_bert_tokens(sample_id) 147 | class_name = self.data[sample_id]["label"] 148 | image = self.load_image(sample_id) 149 | label = torch.tensor(CLASS_NAME.index(class_name), dtype=torch.long) 150 | return text_input, image, label 151 | 152 | -------------------------------------------------------------------------------- /src/eval_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import accuracy_score, f1_score 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | 8 | 9 | def multiclass_acc(preds, truths): 10 | return np.sum(np.round(preds) == np.round(truths)) / float(len(truths)) 11 | 12 | 13 | def weighted_accuracy(test_preds_emo, test_truth_emo): 14 | true_label = (test_truth_emo > 0) 15 | predicted_label = (test_preds_emo > 0) 16 | tp = float(np.sum((true_label == 1) & (predicted_label == 1))) 17 | tn = 
float(np.sum((true_label == 0) & (predicted_label == 0))) 18 | p = float(np.sum(true_label == 1)) 19 | n = float(np.sum(true_label == 0)) 20 | 21 | return (tp * (n / p) + tn) / (2 * n) 22 | 23 | 24 | def eval_senti(results, truths, exclude_zero=False): 25 | test_preds = results.view(-1).cpu().detach().numpy() 26 | test_truth = truths.view(-1).cpu().detach().numpy() 27 | 28 | non_zeros = np.array([i for i, e in enumerate(test_truth) if e != 0 or (not exclude_zero)]) 29 | 30 | test_preds_a7 = np.clip(test_preds, a_min=-3., a_max=3.) 31 | test_truth_a7 = np.clip(test_truth, a_min=-3., a_max=3.) 32 | test_preds_a5 = np.clip(test_preds, a_min=-2., a_max=2.) 33 | test_truth_a5 = np.clip(test_truth, a_min=-2., a_max=2.) 34 | 35 | mae = np.mean(np.absolute(test_preds - test_truth)) # Average L1 distance between preds and truths 36 | corr = np.corrcoef(test_preds, test_truth)[0][1] 37 | mult_a7 = multiclass_acc(test_preds_a7, test_truth_a7) 38 | mult_a5 = multiclass_acc(test_preds_a5, test_truth_a5) 39 | f_score = f1_score((test_preds[non_zeros] > 0), (test_truth[non_zeros] > 0), average='weighted') 40 | binary_truth = (test_truth[non_zeros] > 0) 41 | binary_preds = (test_preds[non_zeros] > 0) 42 | acc = accuracy_score(binary_truth, binary_preds) 43 | return acc 44 | 45 | 46 | def train_eval_senti(results, truths, exclude_zero=False): 47 | test_preds = results.view(-1).cpu().detach().numpy() 48 | test_truth = truths.view(-1).cpu().detach().numpy() 49 | non_zeros = np.array([i for i, e in enumerate(test_truth) if e != 0 or (not exclude_zero)]) 50 | binary_truth = (test_truth[non_zeros] > 0) 51 | binary_preds = (test_preds[non_zeros] > 0) 52 | acc = accuracy_score(binary_truth, binary_preds) 53 | return acc 54 | 55 | 56 | def train_eval_iemo(results, truths): 57 | test_preds = results.view(-1, 4).cpu().detach().numpy() 58 | test_truth = truths.view(-1).cpu().detach().numpy() 59 | 60 | test_preds_i = np.argmax(test_preds, axis=1) 61 | test_truth_i = test_truth 62 | acc = accuracy_score(test_truth_i, test_preds_i) 63 | return acc 64 | 65 | 66 | def train_eval_food(results, truths): 67 | test_preds = results.view(-1, 101).cpu().detach().numpy() 68 | test_truth = truths.view(-1).cpu().detach().numpy() 69 | 70 | test_preds_i = np.argmax(test_preds, axis=1) 71 | test_truth_i = test_truth 72 | acc = accuracy_score(test_truth_i, test_preds_i) 73 | return acc 74 | 75 | def eval_food(results, truths): 76 | test_preds = results.view(-1, 101).cpu().detach().numpy() 77 | test_truth = truths.view(-1).cpu().detach().numpy() 78 | 79 | test_preds_i = np.argmax(test_preds, axis=1) 80 | test_truth_i = test_truth 81 | f1 = f1_score(test_truth_i, test_preds_i, average='weighted') 82 | acc = accuracy_score(test_truth_i, test_preds_i) 83 | return acc 84 | 85 | 86 | def eval_iemocap(results, truths): 87 | test_preds = results.view(-1, 4).cpu().detach().numpy() 88 | test_truth = truths.view(-1).cpu().detach().numpy() 89 | 90 | test_preds_i = np.argmax(test_preds, axis=1) 91 | test_truth_i = test_truth 92 | f1 = f1_score(test_truth_i, test_preds_i, average='weighted') 93 | acc = accuracy_score(test_truth_i, test_preds_i) 94 | return acc 95 | 96 | 97 | class SegLoss(nn.Module): 98 | def __init__(self, n_classes=4, weight=None, alpha=0.5): 99 | # dice_loss_plus_cetr_weighted 100 | super(SegLoss, self).__init__() 101 | self.n_classes = n_classes 102 | self.weight = weight.cuda() 103 | # self.weight = weight 104 | self.alpha = alpha 105 | 106 | def forward(self, input, target): 107 | smooth = 0.01 108 | input1 = F.softmax(input, 
dim=1) 109 | target1 = F.one_hot(target, self.n_classes) 110 | input1 = rearrange(input1, 'b n h w -> b n (h w)') 111 | target1 = rearrange(target1, 'b h w n -> b n (h w)') 112 | input1 = input1[:, 1:, :] 113 | target1 = target1[:, 1:, :].float() 114 | inter = torch.sum(input1 * target1) 115 | union = torch.sum(input1) + torch.sum(target1) + smooth 116 | dice = 2.0 * inter / union 117 | 118 | loss = F.cross_entropy(input, target, weight=self.weight) 119 | 120 | total_loss = (1 - self.alpha) * loss + (1 - dice) * self.alpha 121 | 122 | return total_loss 123 | 124 | 125 | def Dice(output, target, eps=1e-3): 126 | inter = torch.sum(output * target,dim=(1,2)) + eps 127 | union = torch.sum(output,dim=(1,2)) + torch.sum(target,dim=(1,2)) + eps * 2 128 | x = 2 * inter / union 129 | dice = torch.mean(x) 130 | return dice 131 | 132 | 133 | def train_eval_seg(output, target): 134 | output = torch.argmax(output, dim=1) 135 | dice1 = Dice((output == 3).float(), (target == 3).float()) 136 | dice2 = Dice(((output == 1) | (output == 3)).float(), ((target == 1) | (target == 3)).float()) 137 | dice3 = Dice((output != 0).float(), (target != 0).float()) 138 | return (dice1 + dice2 + dice3) / 3 139 | 140 | 141 | def cal_dice(output, target): 142 | output = torch.argmax(output, dim=1) 143 | dice1 = Dice((output == 3).float(), (target == 3).float()) 144 | dice2 = Dice(((output == 1) | (output == 3)).float(), ((target == 1) | (target == 3)).float()) 145 | dice3 = Dice((output != 0).float(), (target != 0).float()) 146 | print(f'ET: {dice1.item()}, TC: {dice2.item()}, WT: {dice3.item()}.') 147 | 148 | return dice1, dice2, dice3 149 | 150 | 151 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0.): 152 | warmup_schedule = np.array([]) 153 | warmup_iters = warmup_epochs * niter_per_ep 154 | if warmup_epochs > 0: 155 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 156 | 157 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 158 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) 159 | 160 | schedule = np.concatenate((warmup_schedule, schedule)) 161 | assert len(schedule) == epochs * niter_per_ep 162 | return schedule 163 | 164 | 165 | def cal_cos(cls_grad, fusion_grad): 166 | fgn = fusion_grad.clone().view(-1) 167 | loss = list() 168 | for i in range(len(cls_grad)): 169 | tmp = cls_grad[i].clone().view(-1) 170 | l = F.cosine_similarity(tmp, fgn, dim=0) 171 | loss.append(l) 172 | 173 | return loss 174 | 175 | -------------------------------------------------------------------------------- /src/segtrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.optim as optim 4 | import time 5 | from models.segmodel import DeepLabMultiInput, SegClassifier 6 | from src.eval_metrics import * 7 | 8 | 9 | def initiate(hyp_params, train_loader, valid_loader, test_loader): 10 | model = DeepLabMultiInput(output_stride=16, num_classes=4) 11 | 12 | if hyp_params.modulation != 'none': 13 | classifier = SegClassifier(num_classes=4) 14 | else: 15 | classifier = None 16 | cls_optimizer = None 17 | 18 | if hyp_params.use_cuda: 19 | model = model.cuda() 20 | if hyp_params.modulation != 'none': 21 | classifier = classifier.cuda() 22 | 23 | optimizer = optim.SGD(model.parameters(), lr=0, weight_decay=hyp_params.weight_decay, momentum=hyp_params.momentum) 24 | if hyp_params.modulation != 'none': 25 | 
cls_optimizer = optim.SGD(classifier.parameters(), lr=hyp_params.cls_lr, weight_decay=hyp_params.weight_decay, momentum=hyp_params.momentum) 26 | criterion = SegLoss(n_classes=4, weight=torch.tensor([0.2, 0.3, 0.25, 0.25])).cuda() 27 | scheduler = cosine_scheduler(base_value=hyp_params.base_lr, final_value=hyp_params.final_lr, epochs=hyp_params.num_epochs, 28 | niter_per_ep=len(train_loader), warmup_epochs=hyp_params.warmup_epochs, start_warmup_value=hyp_params.start_warmup_value) 29 | settings = {'model': model, 'optimizer': optimizer, 'criterion': criterion, 'scheduler': scheduler, 30 | 'classifier': classifier, 'cls_optimizer': cls_optimizer} 31 | return train_model(settings, hyp_params, train_loader, valid_loader, test_loader) 32 | 33 | 34 | def train_model(settings, hyp_params, train_loader, valid_loader, test_loader): 35 | model = settings['model'] 36 | optimizer = settings['optimizer'] 37 | criterion = settings['criterion'] 38 | scheduler = settings['scheduler'] 39 | classifier = settings['classifier'] 40 | cls_optimizer = settings['cls_optimizer'] 41 | acc1 = [0] * hyp_params.num_mod 42 | l_gm = None 43 | 44 | def train(model, classifier, optimizer, cls_optimizer, criterion, epoch): 45 | nonlocal acc1, l_gm 46 | epoch_loss = 0 47 | model.train() 48 | num_batches = hyp_params.n_train // hyp_params.batch_size 49 | proc_loss, proc_size = 0, 0 50 | start_time = time.time() 51 | for i_batch, batch in enumerate(train_loader): 52 | it = len(train_loader) * (epoch-1) + i_batch 53 | param_group = optimizer.param_groups[0] 54 | param_group['lr'] = scheduler[it] 55 | 56 | flair, t1ce, t1, t2, batch_Y = batch 57 | model.zero_grad() 58 | 59 | if hyp_params.use_cuda: 60 | with torch.cuda.device(0): 61 | flair, t1ce, t1, t2, batch_Y = flair.cuda(), t1ce.cuda(), t1.cuda(), t2.cuda(), batch_Y.cuda() 62 | 63 | batch_size = flair.size(0) 64 | net = nn.DataParallel(model) if batch_size > 10 else model 65 | preds, hf, lf = net(flair, t1ce, t1, t2) 66 | raw_loss = criterion(preds, batch_Y) 67 | if hyp_params.modulation == 'cggm' and l_gm is not None: 68 | raw_loss += hyp_params.lamda * l_gm 69 | 70 | raw_loss.backward() 71 | 72 | if hyp_params.modulation == 'cggm': 73 | cls_optimizer.zero_grad() 74 | net2 = nn.DataParallel(classifier) if batch_size > 10 else classifier 75 | cls_res = net2(hf, lf) 76 | cls_loss = criterion(cls_res[0], batch_Y) 77 | 78 | for name, para in net.named_parameters(): 79 | if 'decoder.last_conv.7.weight' in name: 80 | fusion_grad = para 81 | 82 | for i in range(1, hyp_params.num_mod): 83 | cls_loss += criterion(cls_res[i], batch_Y) 84 | cls_loss.backward() 85 | 86 | cls_grad = [] 87 | for name, para in net2.named_parameters(): 88 | if 'last_conv.7.weight' in name: 89 | cls_grad.append(para) 90 | 91 | llist = cal_cos(cls_grad, fusion_grad) 92 | 93 | acc2 = classifier.cal_coeff(cls_res, batch_Y) 94 | diff = [acc2[i] - acc1[i] for i in range(hyp_params.num_mod)] 95 | 96 | diff_sum = sum(diff) + 1e-8 97 | coeff = list() 98 | 99 | for d in diff: 100 | coeff.append((diff_sum - d) / diff_sum) 101 | acc1 = acc2 102 | 103 | l_gm = np.sum(np.abs(coeff)) - (coeff[0] * llist[0] + coeff[1] * llist[1] + coeff[2] * llist[2] + coeff[3] * llist[3]) 104 | l_gm /= hyp_params.num_mod 105 | 106 | for name, params in net.named_parameters(): 107 | if 'vision_encoder' in name: 108 | params.grad *= (coeff[0] * hyp_params.rou) 109 | if 'text_encoder' in name: 110 | params.grad *= (coeff[1] * hyp_params.rou) 111 | 112 | cls_optimizer.step() 113 | 114 | optimizer.step() 115 | 116 | proc_loss += 
raw_loss.item() * batch_size 117 | proc_size += batch_size 118 | epoch_loss += raw_loss.item() * batch_size 119 | 120 | return epoch_loss / hyp_params.n_train 121 | 122 | def evaluate(model, criterion, test=False): 123 | model.eval() 124 | loader = test_loader if test else valid_loader 125 | total_loss = 0.0 126 | 127 | results = [] 128 | truths = [] 129 | 130 | with torch.no_grad(): 131 | for i_batch, batch in enumerate(loader): 132 | flair, t1ce, t1, t2, batch_Y = batch 133 | 134 | if hyp_params.use_cuda: 135 | with torch.cuda.device(0): 136 | flair, t1ce, t1, t2, batch_Y = flair.cuda(), t1ce.cuda(), t1.cuda(), t2.cuda(), batch_Y.cuda() 137 | 138 | net = model 139 | preds, _, _ = net(flair, t1ce, t1, t2) 140 | 141 | total_loss += criterion(preds, batch_Y).item() 142 | 143 | # Collect the predictions and the ground truths 144 | results.append(preds) 145 | truths.append(batch_Y) 146 | 147 | avg_loss = total_loss / (hyp_params.n_test if test else hyp_params.n_valid) 148 | 149 | results = torch.cat(results) 150 | truths = torch.cat(truths) 151 | return avg_loss, results, truths 152 | 153 | best_dice = 0 154 | for epoch in range(1, hyp_params.num_epochs + 1): 155 | start = time.time() 156 | train(model, classifier, optimizer, cls_optimizer, criterion, epoch) 157 | val_loss, r, t = evaluate(model, criterion, test=False) 158 | d1, d2, d3 = cal_dice(r, t) 159 | dice = (d1 + d2 + d3) / 3 160 | 161 | end = time.time() 162 | duration = end - start 163 | 164 | print("-" * 50) 165 | print( 166 | 'Epoch {:2d} | Time {:5.4f} sec | Valid Loss {:5.4f}'.format(epoch, duration, val_loss)) 167 | print("-" * 50) 168 | 169 | if dice > best_dice: 170 | best_dice = dice 171 | 172 | print('Dice: ', best_dice) -------------------------------------------------------------------------------- /src/msatrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.optim as optim 4 | import numpy as np 5 | import time 6 | from torch.optim.lr_scheduler import ReduceLROnPlateau 7 | from models.msamodel import MSAModel, ClassifierGuided 8 | from src.eval_metrics import eval_iemocap, eval_senti, cal_cos 9 | 10 | 11 | def initiate(hyp_params, train_loader, valid_loader, test_loader): 12 | model = MSAModel(hyp_params.output_dim, hyp_params.orig_dim, hyp_params.proj_dim, 13 | hyp_params.num_heads, hyp_params.layers, hyp_params.relu_dropout, 14 | hyp_params.embed_dropout, hyp_params.res_dropout, hyp_params.out_dropout, 15 | hyp_params.attn_dropout) 16 | 17 | if hyp_params.modulation != 'none': 18 | classifier = ClassifierGuided(hyp_params.output_dim, hyp_params.num_mod, hyp_params.proj_dim, hyp_params.num_heads, 19 | hyp_params.cls_layers, hyp_params.relu_dropout, hyp_params.embed_dropout, 20 | hyp_params.res_dropout, hyp_params.attn_dropout) 21 | cls_optimizer = getattr(optim, hyp_params.optim)(classifier.parameters(), lr=hyp_params.cls_lr) 22 | else: 23 | classifier, cls_optimizer = None, None 24 | 25 | if hyp_params.use_cuda: 26 | model = model.cuda() 27 | if hyp_params.modulation != 'none': 28 | classifier = classifier.cuda() 29 | 30 | optimizer = getattr(optim, hyp_params.optim)(model.parameters(), lr=hyp_params.lr) 31 | 32 | criterion = getattr(nn, hyp_params.criterion)() 33 | scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=hyp_params.when, factor=0.1, verbose=True) 34 | settings = {'model': model, 'optimizer': optimizer, 'criterion': criterion, 'scheduler': scheduler, 35 | 'classifier': classifier, 'cls_optimizer': cls_optimizer}
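    # Note on the 'cggm' path wired up above: train_model() below uses cls_optimizer to
    # train the ClassifierGuided heads on the per-modality features, compares each head's
    # out_layer.weight with the fusion model's out_layer.weight via cal_cos, converts the
    # change in per-modality accuracy into coefficients, rescales the gradients of
    # encoders.{i} by coeff[i] * rou, and adds the resulting l_gm term (scaled by lamda)
    # to the next batch's loss.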
36 | return train_model(settings, hyp_params, train_loader, valid_loader, test_loader) 37 | 38 | 39 | def train_model(settings, hyp_params, train_loader, valid_loader, test_loader): 40 | model = settings['model'] 41 | optimizer = settings['optimizer'] 42 | criterion = settings['criterion'] 43 | scheduler = settings['scheduler'] 44 | classifier = settings['classifier'] 45 | cls_optimizer = settings['cls_optimizer'] 46 | acc1 = [0] * hyp_params.num_mod 47 | l_gm = None 48 | 49 | def train(model, classifier, optimizer, cls_optimizer, criterion): 50 | nonlocal acc1, l_gm 51 | epoch_loss = 0 52 | model.train() 53 | num_batches = hyp_params.n_train // hyp_params.batch_size 54 | proc_loss, proc_size = 0, 0 55 | start_time = time.time() 56 | for i_batch, batch in enumerate(train_loader): 57 | text, audio, vision, batch_Y = batch['text'], batch['audio'], batch['vision'], batch['labels'] 58 | eval_attr = batch_Y.squeeze(-1) # if num of labels is 1 59 | model.zero_grad() 60 | 61 | if hyp_params.use_cuda: 62 | with torch.cuda.device(0): 63 | text, audio, vision, eval_attr = text.cuda(), audio.cuda(), vision.cuda(), eval_attr.cuda() 64 | if hyp_params.dataset == 'iemo': 65 | eval_attr = eval_attr.long() 66 | 67 | batch_size = text.size(0) 68 | net = nn.DataParallel(model) if batch_size > 10 else model 69 | preds, hs = net([text, audio, vision]) 70 | 71 | if hyp_params.dataset == 'iemo': 72 | preds = preds.view(-1, 4) 73 | eval_attr = eval_attr.view(-1) 74 | 75 | raw_loss = criterion(preds, eval_attr) 76 | if hyp_params.modulation == 'cggm' and l_gm is not None: 77 | raw_loss += hyp_params.lamda * l_gm 78 | raw_loss.backward() 79 | 80 | if hyp_params.modulation == 'cggm': 81 | cls_optimizer.zero_grad() 82 | net2 = nn.DataParallel(classifier) if batch_size > 10 else classifier 83 | cls_res = net2(hs) 84 | 85 | for name, para in net.named_parameters(): 86 | if 'out_layer.weight' in name: 87 | fusion_grad = para 88 | 89 | if hyp_params.dataset == 'iemo': 90 | for i in range(len(cls_res)): 91 | cls_res[i] = cls_res[i].view(-1, 4) 92 | 93 | cls_loss = criterion(cls_res[0], eval_attr) 94 | for i in range(1, hyp_params.num_mod): 95 | cls_loss += criterion(cls_res[i], eval_attr) 96 | 97 | cls_loss.backward() 98 | 99 | cls_grad = [] 100 | for name, para in net2.named_parameters(): 101 | if 'out_layer.weight' in name: 102 | cls_grad.append(para) 103 | 104 | llist = cal_cos(cls_grad, fusion_grad) 105 | 106 | acc2 = classifier.cal_coeff(hyp_params.dataset, eval_attr, cls_res) 107 | diff = [acc2[i] - acc1[i] for i in range(hyp_params.num_mod)] 108 | 109 | diff_sum = sum(diff) + 1e-8 110 | coeff = list() 111 | 112 | for d in diff: 113 | coeff.append((diff_sum - d) / diff_sum) 114 | 115 | acc1 = acc2 116 | l_gm = np.sum(np.abs(coeff)) - (coeff[0] * llist[0] + coeff[1] * llist[1] + coeff[2] * llist[2]) 117 | l_gm /= hyp_params.num_mod 118 | 119 | for i in range(hyp_params.num_mod): 120 | for name, params in net.named_parameters(): 121 | if f'encoders.{i}' in name: 122 | params.grad *= (coeff[i] * hyp_params.rou) 123 | 124 | cls_optimizer.step() 125 | 126 | torch.nn.utils.clip_grad_norm_(model.parameters(), hyp_params.clip) 127 | optimizer.step() 128 | 129 | proc_loss += raw_loss.item() * batch_size 130 | proc_size += batch_size 131 | epoch_loss += raw_loss.item() * batch_size 132 | 133 | return epoch_loss / hyp_params.n_train 134 | 135 | def evaluate(model, criterion, test=False): 136 | model.eval() 137 | loader = test_loader if test else valid_loader 138 | total_loss = 0.0 139 | 140 | results = [] 141 | truths = [] 
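        # Predictions and labels are accumulated over the whole split, concatenated below,
        # and scored outside this function with eval_senti / eval_iemocap.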
142 | 143 | with torch.no_grad(): 144 | for i_batch, batch in enumerate(loader): 145 | text, audio, vision, batch_Y = batch['text'], batch['audio'], batch['vision'], batch['labels'] 146 | eval_attr = batch_Y.squeeze(dim=-1) # if num of labels is 1 147 | 148 | if hyp_params.use_cuda: 149 | with torch.cuda.device(0): 150 | text, audio, vision, eval_attr = text.cuda(), audio.cuda(), vision.cuda(), eval_attr.cuda() 151 | if hyp_params.dataset == 'iemo': 152 | eval_attr = eval_attr.long() 153 | 154 | net = model 155 | preds, _ = net([text, audio, vision]) 156 | if hyp_params.dataset == 'iemo': 157 | preds = preds.view(-1, 4) 158 | eval_attr = eval_attr.view(-1) 159 | 160 | total_loss += criterion(preds, eval_attr).item() 161 | results.append(preds) 162 | truths.append(eval_attr) 163 | 164 | avg_loss = total_loss / (hyp_params.n_test if test else hyp_params.n_valid) 165 | 166 | results = torch.cat(results) 167 | truths = torch.cat(truths) 168 | return avg_loss, results, truths 169 | 170 | best_acc = 0 171 | for epoch in range(1, hyp_params.num_epochs + 1): 172 | start = time.time() 173 | train(model, classifier, optimizer, cls_optimizer, criterion) 174 | val_loss, val_res, val_truth = evaluate(model, criterion, test=False) 175 | 176 | if hyp_params.dataset == 'iemo': 177 | acc = eval_iemocap(val_res, val_truth) 178 | else: 179 | acc = eval_senti(val_res, val_truth) 180 | 181 | end = time.time() 182 | duration = end - start 183 | scheduler.step(val_loss) # Decay learning rate by validation loss 184 | 185 | print("-" * 50) 186 | print( 187 | 'Epoch {:2d} | Time {:5.4f} sec | Valid Loss {:5.4f}'.format(epoch, duration, val_loss)) 188 | print("-" * 50) 189 | 190 | if best_acc < acc: 191 | best_acc = acc 192 | 193 | print("Accuracy: ", best_acc) -------------------------------------------------------------------------------- /src/foodtrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.optim as optim 5 | import time 6 | from torch.optim.lr_scheduler import ReduceLROnPlateau 7 | from models.msamodel import ClassifierGuided 8 | from models.foodmodel import FoodModel 9 | from src.eval_metrics import eval_food, cal_cos 10 | from transformers import ViTModel, BertModel 11 | 12 | 13 | def initiate(hyp_params, train_loader, valid_loader, test_loader): 14 | model = FoodModel(101, 5, hyp_params.layers, hyp_params.relu_dropout, 15 | hyp_params.embed_dropout, hyp_params.res_dropout, hyp_params.out_dropout, hyp_params.attn_dropout) 16 | 17 | if hyp_params.modulation != 'none': 18 | classifier = ClassifierGuided(101, 2, hyp_params.proj_dim, 5, 19 | hyp_params.cls_layers, hyp_params.relu_dropout, hyp_params.embed_dropout, 20 | hyp_params.res_dropout, hyp_params.attn_dropout) 21 | cls_optimizer = getattr(optim, hyp_params.optim)(classifier.parameters(), lr=hyp_params.cls_lr, weight_decay=3e-3) 22 | else: 23 | classifier = None 24 | cls_optimizer = None 25 | 26 | vit = ViTModel.from_pretrained(hyp_params.vit) 27 | bert = BertModel.from_pretrained(hyp_params.bert) 28 | 29 | if hyp_params.use_cuda: 30 | model = model.cuda() 31 | if hyp_params.modulation != 'none': 32 | classifier = classifier.cuda() 33 | bert = bert.cuda() 34 | vit = vit.cuda() 35 | 36 | optimizer = getattr(optim, hyp_params.optim)(model.parameters(), lr=hyp_params.lr, weight_decay=6e-3) 37 | 38 | criterion = getattr(nn, hyp_params.criterion)() 39 | scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=hyp_params.when, 
factor=0.1, verbose=True) 40 | settings = {'model': model, 'optimizer': optimizer, 'criterion': criterion, 'scheduler': scheduler, 41 | 'classifier': classifier, 'cls_optimizer': cls_optimizer, 'vit': vit, 'bert': bert} 42 | return train_model(settings, hyp_params, train_loader, valid_loader, test_loader) 43 | 44 | 45 | def train_model(settings, hyp_params, train_loader, valid_loader, test_loader): 46 | model = settings['model'] 47 | optimizer = settings['optimizer'] 48 | criterion = settings['criterion'] 49 | scheduler = settings['scheduler'] 50 | classifier = settings['classifier'] 51 | cls_optimizer = settings['cls_optimizer'] 52 | bert = settings['bert'] 53 | vit = settings['vit'] 54 | acc1 = [0] * hyp_params.num_mod 55 | l_gm = None 56 | 57 | def train(model, classifier, optimizer, cls_optimizer, criterion, bert, vit): 58 | nonlocal acc1, l_gm 59 | epoch_loss = 0 60 | model.train() 61 | num_batches = hyp_params.n_train // hyp_params.batch_size 62 | proc_loss, proc_size = 0, 0 63 | start_time = time.time() 64 | for i_batch, batch in enumerate(train_loader): 65 | text, image, batch_Y = batch 66 | eval_attr = batch_Y.squeeze(-1) # if num of labels is 1 67 | model.zero_grad() 68 | if hyp_params.use_cuda: 69 | with torch.cuda.device(0): 70 | ti, ta, tt = text['input_ids'].cuda(), text['attention_mask'].cuda(), text['token_type_ids'].cuda() 71 | image, eval_attr = image.cuda(), eval_attr.cuda() 72 | eval_attr = eval_attr.long() 73 | 74 | with torch.no_grad(): 75 | v = vit(image)['last_hidden_state'] 76 | t = bert(ti, ta, tt)['last_hidden_state'] 77 | 78 | batch_size = image.size(0) 79 | net = nn.DataParallel(model) if batch_size > 10 else model 80 | 81 | preds, hs = net(v, t) 82 | preds = preds.view(-1, 101) 83 | eval_attr = eval_attr.view(-1) 84 | 85 | raw_loss = criterion(preds, eval_attr) 86 | if hyp_params.modulation == 'cggm' and l_gm is not None: 87 | raw_loss += hyp_params.lamda * l_gm 88 | raw_loss.backward() 89 | 90 | if hyp_params.modulation == 'cggm': 91 | cls_optimizer.zero_grad() 92 | net2 = nn.DataParallel(classifier) if batch_size > 10 else classifier 93 | cls_res = net2(hs) 94 | for i in range(len(cls_res)): 95 | cls_res[i] = cls_res[i].view(-1, 101) 96 | 97 | for name, para in net.named_parameters(): 98 | if 'out_layer.weight' in name: 99 | fusion_grad = para 100 | cls_loss = criterion(cls_res[0], eval_attr) 101 | for i in range(1, hyp_params.num_mod): 102 | cls_loss += criterion(cls_res[i], eval_attr) 103 | 104 | cls_loss.backward() 105 | cls_grad = [] 106 | for name, para in net2.named_parameters(): 107 | if 'out_layer.weight' in name: 108 | cls_grad.append(para) 109 | 110 | llist = cal_cos(cls_grad, fusion_grad) 111 | 112 | acc2 = classifier.cal_coeff(hyp_params.dataset, eval_attr, cls_res) 113 | diff = [acc2[i] - acc1[i] for i in range(hyp_params.num_mod)] 114 | 115 | diff_sum = sum(diff) + 1e-8 116 | coeff = list() 117 | 118 | for d in diff: 119 | coeff.append((diff_sum - d) / diff_sum) 120 | acc1 = acc2 121 | 122 | l_gm = np.sum(np.abs(coeff)) - (coeff[0] * llist[0] + coeff[1] * llist[1]) 123 | l_gm /= hyp_params.num_mod 124 | 125 | for name, params in net.named_parameters(): 126 | if 'vision_encoder' in name: 127 | params.grad *= (coeff[0] * hyp_params.rou) 128 | if 'text_encoder' in name: 129 | params.grad *= (coeff[1] * hyp_params.rou) 130 | 131 | cls_optimizer.step() 132 | 133 | torch.nn.utils.clip_grad_norm_(model.parameters(), hyp_params.clip) 134 | 135 | optimizer.step() 136 | 137 | proc_loss += raw_loss.item() * batch_size 138 | proc_size += batch_size 139 | 
epoch_loss += raw_loss.item() * batch_size 140 | 141 | return epoch_loss / hyp_params.n_train 142 | 143 | def evaluate(model, criterion, bert, vit, test=False): 144 | model.eval() 145 | loader = test_loader if test else valid_loader 146 | total_loss = 0.0 147 | 148 | results = [] 149 | truths = [] 150 | 151 | with torch.no_grad(): 152 | for i_batch, batch in enumerate(loader): 153 | text, image, batch_Y = batch 154 | eval_attr = batch_Y.squeeze(dim=-1) # if num of labels is 1 155 | 156 | if hyp_params.use_cuda: 157 | with torch.cuda.device(0): 158 | ti, ta, tt = text['input_ids'].cuda(), text['attention_mask'].cuda(), text['token_type_ids'].cuda() 159 | image, eval_attr = image.cuda(), eval_attr.cuda() 160 | eval_attr = eval_attr.long() 161 | 162 | t = bert(ti, ta, tt)['last_hidden_state'] 163 | v = vit(image)['last_hidden_state'] 164 | 165 | net = model 166 | preds, _ = net(v, t) 167 | preds = preds.view(-1, 101) 168 | eval_attr = eval_attr.view(-1) 169 | total_loss += criterion(preds, eval_attr).item() 170 | 171 | results.append(preds) 172 | truths.append(eval_attr) 173 | 174 | avg_loss = total_loss / (hyp_params.n_test if test else hyp_params.n_valid) 175 | 176 | results = torch.cat(results) 177 | truths = torch.cat(truths) 178 | return avg_loss, results, truths 179 | 180 | best_acc = 0 181 | 182 | for epoch in range(1, hyp_params.num_epochs + 1): 183 | start = time.time() 184 | train(model, classifier, optimizer, cls_optimizer, criterion, bert, vit) 185 | val_loss, r, t = evaluate(model, criterion, bert, vit, test=False) 186 | acc = eval_food(r, t) 187 | 188 | end = time.time() 189 | duration = end - start 190 | scheduler.step(val_loss) # Decay learning rate by validation loss 191 | 192 | print("-" * 50) 193 | print( 194 | 'Epoch {:2d} | Time {:5.4f} sec | Valid Loss {:5.4f}'.format(epoch, duration, val_loss)) 195 | print("-" * 50) 196 | 197 | if acc > best_acc: 198 | best_acc = acc 199 | 200 | print("Accuracy: ", best_acc) -------------------------------------------------------------------------------- /datasets/IEMODataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import numpy as np 5 | import torch 6 | from torch.nn.utils.rnn import pad_sequence 7 | 8 | import random 9 | import torch.utils.data as data 10 | from PIL import Image 11 | import torchvision.transforms as transforms 12 | from abc import ABC, abstractmethod 13 | 14 | 15 | class BaseDataset(data.Dataset, ABC): 16 | """This class is an abstract base class (ABC) for datasets. 17 | 18 | To create a subclass, you need to implement the following four functions: 19 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 20 | -- <__len__>: return the size of dataset. 21 | -- <__getitem__>: get a data point. 22 | -- : (optionally) add dataset-specific options and set default options. 23 | """ 24 | 25 | def __init__(self, opt): 26 | """Initialize the class; save the options in the class 27 | 28 | Parameters: 29 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 30 | """ 31 | self.opt = opt 32 | self.manual_collate_fn = False 33 | 34 | @staticmethod 35 | def modify_commandline_options(parser, is_train): 36 | """Add new dataset-specific options, and rewrite default values for existing options. 37 | 38 | Parameters: 39 | parser -- original option parser 40 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 
41 | 42 | Returns: 43 | the modified parser. 44 | """ 45 | return parser 46 | 47 | @abstractmethod 48 | def __len__(self): 49 | """Return the total number of images in the dataset.""" 50 | return 0 51 | 52 | @abstractmethod 53 | def __getitem__(self, index): 54 | """Return a data point and its metadata information. 55 | 56 | Parameters: 57 | index - - a random integer for data indexing 58 | 59 | Returns: 60 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 61 | """ 62 | pass 63 | 64 | 65 | def get_params(opt, size): 66 | w, h = size 67 | new_h = h 68 | new_w = w 69 | if opt.preprocess == 'resize_and_crop': 70 | new_h = new_w = opt.load_size 71 | elif opt.preprocess == 'scale_width_and_crop': 72 | new_w = opt.load_size 73 | new_h = opt.load_size * h // w 74 | 75 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 76 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 77 | 78 | flip = random.random() > 0.5 79 | 80 | return {'crop_pos': (x, y), 'flip': flip} 81 | 82 | 83 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 84 | transform_list = [] 85 | if grayscale: 86 | transform_list.append(transforms.Grayscale(1)) 87 | if 'resize' in opt.preprocess: 88 | osize = [opt.load_size, opt.load_size] 89 | transform_list.append(transforms.Resize(osize, method)) 90 | elif 'scale_width' in opt.preprocess: 91 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 92 | 93 | if 'crop' in opt.preprocess: 94 | if params is None: 95 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 96 | else: 97 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 98 | 99 | if opt.preprocess == 'none': 100 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 101 | 102 | if not opt.no_flip: 103 | if params is None: 104 | transform_list.append(transforms.RandomHorizontalFlip()) 105 | elif params['flip']: 106 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 107 | 108 | if convert: 109 | transform_list += [transforms.ToTensor()] 110 | if grayscale: 111 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 112 | else: 113 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 114 | return transforms.Compose(transform_list) 115 | 116 | 117 | def __make_power_2(img, base, method=Image.BICUBIC): 118 | ow, oh = img.size 119 | h = int(round(oh / base) * base) 120 | w = int(round(ow / base) * base) 121 | if (h == oh) and (w == ow): 122 | return img 123 | 124 | __print_size_warning(ow, oh, w, h) 125 | return img.resize((w, h), method) 126 | 127 | 128 | def __scale_width(img, target_width, method=Image.BICUBIC): 129 | ow, oh = img.size 130 | if (ow == target_width): 131 | return img 132 | w = target_width 133 | h = int(target_width * oh / ow) 134 | return img.resize((w, h), method) 135 | 136 | 137 | def __crop(img, pos, size): 138 | ow, oh = img.size 139 | x1, y1 = pos 140 | tw = th = size 141 | if (ow > tw or oh > th): 142 | return img.crop((x1, y1, x1 + tw, y1 + th)) 143 | return img 144 | 145 | 146 | def __flip(img, flip): 147 | if flip: 148 | return img.transpose(Image.FLIP_LEFT_RIGHT) 149 | return img 150 | 151 | 152 | def __print_size_warning(ow, oh, w, h): 153 | """Print warning information about image size(only print once)""" 154 | if not hasattr(__print_size_warning, 'has_printed'): 155 | print("The 
image size needs to be a multiple of 4. " 156 | "The loaded image size was (%d, %d), so it was adjusted to " 157 | "(%d, %d). This adjustment will be done to all images " 158 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 159 | __print_size_warning.has_printed = True 160 | 161 | 162 | 163 | class IEMOData(BaseDataset): 164 | def __init__(self, opt, data_path, set_name): 165 | ''' IEMOCAP dataset reader 166 | set_name in ['trn', 'val', 'tst'] 167 | ''' 168 | super().__init__(opt) 169 | 170 | # record & load basic settings 171 | cvNo = opt.cvNo 172 | self.set_name = set_name 173 | 174 | config = { 175 | "target_root": os.path.join(data_path, "target"), 176 | "feature_root": data_path, 177 | } 178 | 179 | self.norm_method = opt.norm_method 180 | # load feature 181 | self.A_type = opt.A_type 182 | self.all_A = \ 183 | h5py.File(os.path.join(config['feature_root'], 'A', f'{self.A_type}.h5'), 'r') 184 | if self.A_type == 'comparE': 185 | self.mean_std = h5py.File(os.path.join(config['feature_root'], 'A', 'comparE_mean_std.h5'), 'r') 186 | self.mean = torch.from_numpy(self.mean_std[str(cvNo)]['mean'][()]).unsqueeze(0).float() 187 | self.std = torch.from_numpy(self.mean_std[str(cvNo)]['std'][()]).unsqueeze(0).float() 188 | elif self.A_type == 'comparE_raw': 189 | self.mean, self.std = self.calc_mean_std() 190 | 191 | self.V_type = opt.V_type 192 | self.all_V = \ 193 | h5py.File(os.path.join(config['feature_root'], 'V', f'{self.V_type}.h5'), 'r') 194 | self.L_type = opt.L_type 195 | self.all_L = \ 196 | h5py.File(os.path.join(config['feature_root'], 'L', f'{self.L_type}.h5'), 'r') 197 | 198 | # load dataset in memory 199 | if opt.in_mem: 200 | self.all_A = self.h5_to_dict(self.all_A) 201 | self.all_V = self.h5_to_dict(self.all_V) 202 | self.all_L = self.h5_to_dict(self.all_L) 203 | 204 | # load target 205 | label_path = os.path.join(config['target_root'], f'{cvNo}', f"{set_name}_label.npy") 206 | int2name_path = os.path.join(config['target_root'], f'{cvNo}', f"{set_name}_int2name.npy") 207 | self.label = np.load(label_path) 208 | 209 | self.label = np.argmax(self.label, axis=1) 210 | self.int2name = np.load(int2name_path) 211 | self.manual_collate_fn = True 212 | self.tshape = (22, 1024) 213 | self.ashape = (350, 130) 214 | self.vshape = (50, 342) 215 | 216 | def __getitem__(self, index): 217 | int2name = self.int2name[index] 218 | int2name = int2name[0].decode() 219 | label = torch.tensor(self.label[index]) 220 | # process A_feat 221 | A_feat = torch.from_numpy(self.all_A[int2name][()]).float() 222 | if self.A_type == 'comparE' or self.A_type == 'comparE_raw': 223 | A_feat = self.normalize_on_utt(A_feat) if self.norm_method == 'utt' else self.normalize_on_trn(A_feat) 224 | # process V_feat 225 | V_feat = torch.from_numpy(self.all_V[int2name][()]).float() 226 | # process L_feat 227 | L_feat = torch.from_numpy(self.all_L[int2name][()]).float() 228 | X = [L_feat, A_feat, V_feat] 229 | 230 | return X, label 231 | 232 | def __len__(self): 233 | return len(self.label) 234 | 235 | def h5_to_dict(self, h5f): 236 | ret = {} 237 | for key in h5f.keys(): 238 | ret[key] = h5f[key][()] 239 | return ret 240 | 241 | def normalize_on_utt(self, features): 242 | mean_f = torch.mean(features, dim=0).unsqueeze(0).float() 243 | std_f = torch.std(features, dim=0).unsqueeze(0).float() 244 | std_f[std_f == 0.0] = 1.0 245 | features = (features - mean_f) / std_f 246 | return features 247 | 248 | def normalize_on_trn(self, features): 249 | features = (features - self.mean) / self.std 250 | return features 251 | 
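    # Fixed per-modality shapes exposed to the run scripts: text 22 x 1024, audio 350 x 130,
    # vision 50 x 342 (see self.tshape / self.ashape / self.vshape set in __init__);
    # get_dim() and get_seq_len() below return the feature dims and sequence lengths.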
252 | def get_dim(self): 253 | return [1024, 130, 342] 254 | 255 | def get_seq_len(self): 256 | return [22, 350, 50] 257 | 258 | def calc_mean_std(self): 259 | utt_ids = [utt_id for utt_id in self.all_A.keys()] 260 | feats = np.array([self.all_A[utt_id] for utt_id in utt_ids]) 261 | _feats = feats.reshape(-1, feats.shape[2]) 262 | mean = np.mean(_feats, axis=0) 263 | std = np.std(_feats, axis=0) 264 | std[std == 0.0] = 1.0 265 | return mean, std 266 | 267 | def collate_fn(self, batch): 268 | max_length = 350 269 | A = [torch.cat( 270 | [sample[0][1], torch.zeros((max_length - len(sample[0][1]), sample[0][1].shape[1]), device='cpu')]) for 271 | sample in batch] 272 | V = [sample[0][2] for sample in batch] 273 | L = [sample[0][0] for sample in batch] 274 | A = pad_sequence(A, batch_first=True, padding_value=0) 275 | V = pad_sequence(V, batch_first=True, padding_value=0) 276 | L = pad_sequence(L, batch_first=True, padding_value=0) 277 | label = torch.tensor([sample[1] for sample in batch]) 278 | 279 | s = { 280 | 'text': L, 281 | 'audio': A, 282 | 'vision': V, 283 | 'labels': label 284 | } 285 | 286 | return s -------------------------------------------------------------------------------- /models/segmodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from src.eval_metrics import train_eval_seg 6 | 7 | 8 | class Decoder(nn.Module): 9 | def __init__(self, num_classes, BatchNorm, input_heads=1): 10 | super(Decoder, self).__init__() 11 | 12 | low_level_inplanes = 256 * input_heads 13 | last_conv_input = 256 * input_heads + 48 14 | 15 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 16 | self.bn1 = BatchNorm(48) 17 | self.relu = nn.ReLU() 18 | self.condconv1 = nn.Conv2d(last_conv_input, 256, kernel_size=3, stride=1, padding=1, bias=False) 19 | self.last_conv = nn.Sequential( 20 | BatchNorm(256), 21 | nn.ReLU(), 22 | nn.Dropout(0.5), 23 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 24 | BatchNorm(256), 25 | nn.ReLU(), 26 | nn.Dropout(0.1), 27 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 28 | self._init_weight() 29 | 30 | def forward(self, x, low_level_feat): 31 | 32 | low_level_feat = self.conv1(low_level_feat) 33 | low_level_feat = self.bn1(low_level_feat) 34 | low_level_feat = self.relu(low_level_feat) 35 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 36 | x = torch.cat((x, low_level_feat), dim=1) 37 | x = self.condconv1(x) 38 | x = self.last_conv(x) 39 | 40 | return x 41 | 42 | def _init_weight(self): 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d): 45 | torch.nn.init.kaiming_normal_(m.weight) 46 | elif isinstance(m, nn.BatchNorm2d): 47 | m.weight.data.fill_(1) 48 | m.bias.data.zero_() 49 | 50 | 51 | def build_decoder(num_classes, BatchNorm, input_heads=1): 52 | return Decoder(num_classes, BatchNorm, input_heads) 53 | 54 | 55 | class _ASPPModule(nn.Module): 56 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 57 | super(_ASPPModule, self).__init__() 58 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 59 | stride=1, padding=padding, dilation=dilation, bias=False) 60 | self.bn = BatchNorm(planes) 61 | self.relu = nn.ReLU() 62 | 63 | self._init_weight() 64 | 65 | def forward(self, x): 66 | x = self.atrous_conv(x) 67 | x = self.bn(x) 68 | 69 | return self.relu(x) 70 | 71 | def _init_weight(self): 
72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | torch.nn.init.kaiming_normal_(m.weight) 75 | elif isinstance(m, nn.BatchNorm2d): 76 | m.weight.data.fill_(1) 77 | m.bias.data.zero_() 78 | 79 | class ASPP(nn.Module): 80 | def __init__(self, output_stride, BatchNorm): 81 | super(ASPP, self).__init__() 82 | inplanes = 2048 83 | 84 | if output_stride == 16: 85 | #dilations = [1, 6, 12, 18] 86 | dilations = [1, 2, 4, 6] 87 | elif output_stride == 8: 88 | dilations = [1, 12, 24, 36] 89 | else: 90 | raise NotImplementedError 91 | 92 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 93 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 94 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 95 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 96 | 97 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 98 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 99 | BatchNorm(256), 100 | nn.ReLU()) 101 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 102 | self.bn1 = BatchNorm(256) 103 | self.relu = nn.ReLU() 104 | self.dropout = nn.Dropout(0.5) 105 | self._init_weight() 106 | 107 | def forward(self, x): 108 | x1 = self.aspp1(x) 109 | x2 = self.aspp2(x) 110 | x3 = self.aspp3(x) 111 | x4 = self.aspp4(x) 112 | x5 = self.global_avg_pool(x) 113 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 114 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 115 | 116 | x = self.conv1(x) 117 | x = self.bn1(x) 118 | x = self.relu(x) 119 | 120 | return self.dropout(x) 121 | 122 | def _init_weight(self): 123 | for m in self.modules(): 124 | if isinstance(m, nn.Conv2d): 125 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 126 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 127 | torch.nn.init.kaiming_normal_(m.weight) 128 | elif isinstance(m, nn.BatchNorm2d): 129 | m.weight.data.fill_(1) 130 | m.bias.data.zero_() 131 | 132 | 133 | def build_aspp(output_stride, BatchNorm): 134 | return ASPP(output_stride, BatchNorm) 135 | 136 | 137 | class Bottleneck(nn.Module): 138 | expansion = 4 139 | 140 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 141 | super(Bottleneck, self).__init__() 142 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 143 | self.bn1 = BatchNorm(planes) 144 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 145 | dilation=dilation, padding=dilation, bias=False) 146 | self.bn2 = BatchNorm(planes) 147 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 148 | self.bn3 = BatchNorm(planes * 4) 149 | self.relu = nn.ReLU(inplace=True) 150 | self.downsample = downsample 151 | self.stride = stride 152 | self.dilation = dilation 153 | 154 | def forward(self, x): 155 | residual = x 156 | 157 | out = self.conv1(x) 158 | out = self.bn1(out) 159 | out = self.relu(out) 160 | 161 | out = self.conv2(out) 162 | out = self.bn2(out) 163 | out = self.relu(out) 164 | 165 | out = self.conv3(out) 166 | out = self.bn3(out) 167 | 168 | if self.downsample is not None: 169 | residual = self.downsample(x) 170 | 171 | out += residual 172 | out = self.relu(out) 173 | 174 | return out 175 | 176 | 177 | class ResNet(nn.Module): 178 | 179 | def __init__(self, block, layers, output_stride, BatchNorm, input_dim=1): 180 | self.inplanes = 64 181 | super(ResNet, self).__init__() 182 | blocks = [1, 2, 4] 183 | if output_stride == 16: 184 | strides = [1, 2, 2, 1] 185 | dilations = [1, 1, 1, 2] 186 | elif output_stride == 8: 187 | strides = [1, 2, 1, 1] 188 | dilations = [1, 1, 2, 4] 189 | else: 190 | raise NotImplementedError 191 | # Modules 192 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3, 193 | bias=False) 194 | self.bn1 = BatchNorm(64) 195 | self.relu = nn.ReLU(inplace=True) 196 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 197 | 198 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], 199 | BatchNorm=BatchNorm) 200 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], 201 | BatchNorm=BatchNorm) 202 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], 203 | BatchNorm=BatchNorm) 204 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], 205 | BatchNorm=BatchNorm) 206 | self._init_weight() 207 | 208 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 209 | downsample = None 210 | if stride != 1 or self.inplanes != planes * block.expansion: 211 | downsample = nn.Sequential( 212 | nn.Conv2d(self.inplanes, planes * block.expansion, 213 | kernel_size=1, stride=stride, bias=False), 214 | BatchNorm(planes * block.expansion), 215 | ) 216 | 217 | layers = [] 218 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 219 | self.inplanes = planes * block.expansion 220 | for i in range(1, blocks): 221 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 222 | 223 | return nn.Sequential(*layers) 224 | 225 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 226 | downsample = None 227 | if stride != 1 or self.inplanes != planes 
* block.expansion: 228 | downsample = nn.Sequential( 229 | nn.Conv2d(self.inplanes, planes * block.expansion, 230 | kernel_size=1, stride=stride, bias=False), 231 | BatchNorm(planes * block.expansion), 232 | ) 233 | 234 | layers = [] 235 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0] * dilation, 236 | downsample=downsample, BatchNorm=BatchNorm)) 237 | self.inplanes = planes * block.expansion 238 | for i in range(1, len(blocks)): 239 | layers.append(block(self.inplanes, planes, stride=1, 240 | dilation=blocks[i] * dilation, BatchNorm=BatchNorm)) 241 | 242 | return nn.Sequential(*layers) 243 | 244 | def forward(self, input): 245 | 246 | x = self.conv1(input) 247 | x = self.bn1(x) 248 | x = self.relu(x) 249 | x = self.maxpool(x) 250 | 251 | x = self.layer1(x) 252 | low_level_feat = x 253 | x = self.layer2(x) 254 | x = self.layer3(x) 255 | x = self.layer4(x) 256 | return x, low_level_feat 257 | 258 | def _init_weight(self): 259 | for m in self.modules(): 260 | if isinstance(m, nn.Conv2d): 261 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 262 | m.weight.data.normal_(0, math.sqrt(2. / n)) 263 | elif isinstance(m, nn.BatchNorm2d): 264 | m.weight.data.fill_(1) 265 | m.bias.data.zero_() 266 | 267 | 268 | def ResNet101(output_stride, BatchNorm, input_dim=1): 269 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, input_dim=input_dim) 270 | return model 271 | 272 | class DeepLabMultiInput(nn.Module): 273 | def __init__(self, output_stride=8, num_classes=4): 274 | super(DeepLabMultiInput, self).__init__() 275 | 276 | BatchNorm = nn.BatchNorm2d 277 | 278 | self.backbone1 = ResNet101(output_stride, BatchNorm) # flair 279 | self.aspp1 = build_aspp(output_stride, BatchNorm) 280 | self.backbone2 = ResNet101(output_stride, BatchNorm) # t1ce 281 | self.aspp2 = build_aspp(output_stride, BatchNorm) 282 | self.backbone3 = ResNet101(output_stride, BatchNorm) # t1 283 | self.aspp3 = build_aspp(output_stride, BatchNorm) 284 | self.backbone4 = ResNet101(output_stride, BatchNorm) # t2 285 | self.aspp4 = build_aspp(output_stride, BatchNorm) 286 | 287 | self.decoder = build_decoder(num_classes, BatchNorm, input_heads=4) 288 | 289 | def forward(self, input1, input2=None, input3=None, input4=None): 290 | x1, low_level_feat1 = self.backbone1(input1) 291 | x1 = self.aspp1(x1) 292 | # flair 293 | if input2 is not None: 294 | x2, low_level_feat2 = self.backbone2(input2) 295 | x2 = self.aspp2(x2) 296 | else: 297 | x2 = torch.zeros_like(x1) 298 | low_level_feat2 = torch.zeros_like(low_level_feat1) 299 | # DoLP 300 | if input3 is not None: 301 | x3, low_level_feat3 = self.backbone3(input3) 302 | x3 = self.aspp3(x3) 303 | else: 304 | x3 = torch.zeros_like(x1) 305 | low_level_feat3 = torch.zeros_like(low_level_feat1) 306 | # NIR 307 | if input4 is not None: 308 | x4, low_level_feat4 = self.backbone4(input4) 309 | x4 = self.aspp4(x4) 310 | else: 311 | x4 = torch.zeros_like(x1) 312 | low_level_feat4 = torch.zeros_like(low_level_feat1) 313 | 314 | x = torch.cat([x1, x2, x3, x4], dim=1) 315 | hf = [x1, x2, x3, x4] 316 | lf = [low_level_feat1, low_level_feat2, low_level_feat3, low_level_feat4] 317 | low_level_feat = torch.cat([low_level_feat1, low_level_feat2, low_level_feat3, low_level_feat4], dim=1) 318 | x = self.decoder(x, low_level_feat) 319 | x = F.interpolate(x, size=input1.size()[2:], mode='bilinear', align_corners=True) 320 | 321 | hff, lff = [], [] 322 | for h in hf: 323 | hff.append(h.clone().detach()) 324 | for l in lf: 325 | lff.append(l.clone().detach()) 326 | 327 
| return x, hff, lff 328 | 329 | 330 | class Unimodal(nn.Module): 331 | def __init__(self, output_stride=8, num_classes=4): 332 | super(Unimodal, self).__init__() 333 | 334 | BatchNorm = nn.BatchNorm2d 335 | 336 | self.backbone1 = ResNet101(output_stride, BatchNorm) # flair 337 | self.aspp1 = build_aspp(output_stride, BatchNorm) 338 | self.decoder = build_decoder(num_classes, BatchNorm, input_heads=1) 339 | 340 | def forward(self, input1): 341 | x1, low_level_feat1 = self.backbone1(input1) 342 | x1 = self.aspp1(x1) 343 | 344 | x = self.decoder(x1, low_level_feat1) 345 | x = F.interpolate(x, size=input1.size()[2:], mode='bilinear', align_corners=True) 346 | return x 347 | 348 | 349 | class SegClassifier(nn.Module): 350 | def __init__(self, num_classes=4): 351 | super(SegClassifier, self).__init__() 352 | self.decoder1 = build_decoder(num_classes, nn.BatchNorm2d, input_heads=1) 353 | self.decoder2 = build_decoder(num_classes, nn.BatchNorm2d, input_heads=1) 354 | self.decoder3 = build_decoder(num_classes, nn.BatchNorm2d, input_heads=1) 355 | self.decoder4 = build_decoder(num_classes, nn.BatchNorm2d, input_heads=1) 356 | 357 | def cal_coeff(self, cls_res, label): 358 | acc_list = list() 359 | for r in cls_res: 360 | acc = train_eval_seg(r, label) 361 | acc_list.append(acc) 362 | return acc_list 363 | 364 | def forward(self, hf, lf): 365 | x1 = self.decoder1(hf[0], lf[0]) 366 | x2 = self.decoder2(hf[1], lf[1]) 367 | x3 = self.decoder3(hf[2], lf[2]) 368 | x4 = self.decoder4(hf[3], lf[3]) 369 | x1 = F.interpolate(x1, size=(160, 160), mode='bilinear', align_corners=True) 370 | x2 = F.interpolate(x2, size=(160, 160), mode='bilinear', align_corners=True) 371 | x3 = F.interpolate(x3, size=(160, 160), mode='bilinear', align_corners=True) 372 | x4 = F.interpolate(x4, size=(160, 160), mode='bilinear', align_corners=True) 373 | res = [x1, x2, x3, x4] 374 | return res --------------------------------------------------------------------------------
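The classifier-guided gradient-modulation arithmetic is the same across src/segtrain.py, src/msatrain.py and src/foodtrain.py, so it can be exercised in isolation. The minimal sketch below replays one update of the coefficient and l_gm computation; num_mod, the accuracies and the stand-in weight vectors are illustrative assumptions, not values produced by the repository's models.

import torch
import torch.nn.functional as F

num_mod = 3                    # e.g. text / audio / vision (assumed for illustration)
acc1 = [0.50, 0.40, 0.30]      # per-modality accuracy after the previous batch (toy values)
acc2 = [0.55, 0.42, 0.31]      # per-modality accuracy on the current batch (toy values)

# Stand-ins for the flattened out_layer weights that cal_cos compares.
cls_w = [torch.randn(16) for _ in range(num_mod)]
fusion_w = torch.randn(16)
llist = [float(F.cosine_similarity(w, fusion_w, dim=0)) for w in cls_w]

# Modalities whose accuracy improved less receive a larger coefficient.
diff = [acc2[i] - acc1[i] for i in range(num_mod)]
diff_sum = sum(diff) + 1e-8
coeff = [(diff_sum - d) / diff_sum for d in diff]

# Guided-modulation term added to the next batch's task loss, scaled by lamda.
l_gm = sum(abs(c) for c in coeff) - sum(coeff[i] * llist[i] for i in range(num_mod))
l_gm /= num_mod

print(coeff, l_gm)
# In the training loops each modality encoder's gradients are then rescaled in place
# (params.grad *= coeff[i] * rou) before optimizer.step().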