├── README.md ├── attention.py ├── contrast_loss.py ├── eval.py ├── figures └── framework.png ├── src ├── .DS_Store ├── __init__.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── helpers.py │ └── vocab.py ├── models │ ├── __init__.py │ ├── bert.py │ ├── image.py │ ├── late_fusion.py │ └── tool.py └── utils │ ├── __init__.py │ ├── logger.py │ └── utils.py ├── train.py └── util.py /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This repository contains the implementation code for our paper _"Embracing Unimodal Aleatoric Uncertainty for Robust Multimodal Fusion"_ (CVPR 2024) and its extended version (under review). 4 | 5 | 6 | ![model framework](figures/framework.png) 7 | 8 | 9 | # Model Training 10 | Please download the BERT pre-trained weights and put the corresponding files under _prebert/_. 11 | The MVSA-Single dataset is organized as follows: 12 | 13 | 14 | ``` 15 | data 16 | |-- MVSA_Single 17 | | |-- train.jsonl 18 | | |-- dev.jsonl 19 | | |-- test.jsonl 20 | | |-- labelResultAll.txt 21 | ``` 22 | 23 | Run the following script to train `URMF` on the MVSA-Single dataset. 24 | 25 | ``` 26 | python train.py 27 | ``` 28 | 29 | The complete implementation on other datasets will be released after the peer review of our extended manuscript. 30 | 31 | # Acknowledgements 32 | The code is adapted from [QMF](https://github.com/QingyangZhang/QMF/tree/main) and [OGM-GE](https://github.com/GeWu-Lab/OGM-GE_CVPR2022/tree/main). 33 | If you find our code helpful, please consider citing: 34 | ``` 35 | @inproceedings{gao2024embracing, 36 | title={Embracing Unimodal Aleatoric Uncertainty for Robust Multimodal Fusion}, 37 | author={Gao, Zixian and Jiang, Xun and Xu, Xing and Shen, Fumin and Li, Yujie and Shen, Heng Tao}, 38 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 39 | pages={26876--26885}, 40 | year={2024} 41 | } 42 | 43 | @inproceedings{zhang2023provable, 44 | title={Provable dynamic fusion for low-quality multimodal data}, 45 | author={Zhang, Qingyang and Wu, Haitao and Zhang, Changqing and Hu, Qinghua and Fu, Huazhu and Zhou, Joey Tianyi and Peng, Xi}, 46 | booktitle={International Conference on Machine Learning}, 47 | pages={41753--41769}, 48 | year={2023} 49 | } 50 | 51 | @inproceedings{peng2022balanced, 52 | title={Balanced multimodal learning via on-the-fly gradient modulation}, 53 | author={Peng, Xiaokang and Wei, Yake and Deng, Andong and Wang, Dong and Hu, Di}, 54 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 55 | pages={8238--8247}, 56 | year={2022} 57 | } 58 | 59 | 60 | ``` -------------------------------------------------------------------------------- /attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import math 5 | from torch.nn import init 6 | class Flatten(nn.Module): 7 | def forward(self, x): 8 | return x.view(x.size(0), -1) 9 | 10 | 11 | class ChannelGate(nn.Module): 12 | def __init__(self, gate_channels, reduction_ratio=16, pool_types='avg'): 13 | super(ChannelGate, self).__init__() 14 | self.gate_channels = gate_channels 15 | # self.mlp = nn.Sequential( 16 | # Flatten(), 17 | # nn.Linear(gate_channels, gate_channels // reduction_ratio), 18 | # nn.ReLU(), 19 | # nn.Linear(gate_channels // reduction_ratio, gate_channels) 20 | # ) 21 |
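# NOTE: the commented-out MLP gate above is kept for reference; the active gate is the Conv1d defined below, which scores the pooled per-channel statistics.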
self.pool_types = pool_types 22 | 23 | self.con=nn.Conv1d(2,2,kernel_size=3,stride=1,padding=1,bias=False) 24 | 25 | # init.kaiming_uniform_(self.mlp[1].weight, mode='fan_in', nonlinearity='relu') 26 | # init.kaiming_uniform_(self.mlp[3].weight, mode='fan_in', nonlinearity='relu') 27 | 28 | # init.xavier_uniform_(self.mlp[1].weight, gain=init.calculate_gain('relu')) 29 | # init.xavier_uniform_(self.mlp[3].weight, gain=init.calculate_gain('relu')) 30 | def forward(self, x): 31 | channel_att_sum = None 32 | # for pool_type in self.pool_types: 33 | if self.pool_types == 'avg': 34 | # average pool along the feature dimension 35 | pool = x.mean(dim=2).unsqueeze(2) 36 | elif self.pool_types == 'max': 37 | # max pool along the feature dimension 38 | pool = x.max(dim=2).values.unsqueeze(2) 39 | elif self.pool_types == 'lp': 40 | # L2 pool along the feature dimension 41 | pool = x.norm(2, dim=2).unsqueeze(2) 42 | elif self.pool_types == 'lse': 43 | # log-sum-exp pool along the feature dimension 44 | pool = torch.logsumexp(x, dim=2).unsqueeze(2) 45 | # all pooling variants share the Conv1d gate (the MLP gate is disabled above) 46 | channel_att_raw = self.con(pool) 47 | 48 | if channel_att_sum is None: 49 | channel_att_sum = channel_att_raw 50 | else: 51 | channel_att_sum = channel_att_sum + channel_att_raw 52 | 53 | # scale = torch.sigmoid(channel_att_sum).unsqueeze(2).expand_as(x) 54 | score = pool + channel_att_raw 55 | # x = x + x * scale 56 | x = x * score + x 57 | # fusion_output = torch.cat((x[:, 0, :], x[:, 1, :], x[:, 2, :]), dim=1) 58 | fusion_output = (x[:, 0, :] + x[:, 1, :]) 59 | 60 | return fusion_output 61 | if __name__ == "__main__": 62 | # smoke test with two modality channels, matching the Conv1d(2, 2, ...) gate 63 | a = torch.randn(50, 2, 72) 64 | fusion = ChannelGate(2, 2, 'avg') 65 | b = fusion(a) 66 | print(b) -------------------------------------------------------------------------------- /contrast_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Contrastive_loss(nn.Module): 6 | def __init__(self,tau): 7 | super(Contrastive_loss,self).__init__() 8 | self.tau=tau 9 | 10 | def sim(self,z1:torch.Tensor,z2:torch.Tensor): 11 | z1 = F.normalize(z1) 12 | z2 = F.normalize(z2) 13 | return torch.mm(z1,z2.t()) 14 | 15 | def semi_loss(self,z1:torch.Tensor,z2:torch.Tensor): 16 | f=lambda x: torch.exp(x/self.tau) 17 | refl_sim = f(self.sim(z1,z1)) # intra-view similarities 18 | between_sim=f(self.sim(z1,z2)) # cross-view similarities 19 | 20 | return -torch.log(between_sim.diag()/(refl_sim.sum(1)+between_sim.sum(1)-refl_sim.diag())) 21 | 22 | def forward(self,z1:torch.Tensor,z2:torch.Tensor,mean:bool=True): 23 | l1=self.semi_loss(z1,z2) 24 | l2=self.semi_loss(z2,z1) 25 | ret=(l1+l2)*0.5 26 | ret=ret.mean() if mean else ret.sum() 27 | return ret -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree.
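# Evaluation entry point: restores the best checkpoint saved by train.py and reports loss, accuracy, and F1 on the (optionally noise-corrupted) test loaders.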
8 | # 9 | 10 | 11 | import argparse 12 | import os.path 13 | 14 | from sklearn.metrics import f1_score, accuracy_score 15 | from tqdm import tqdm 16 | 17 | import json 18 | import torch 19 | import torch.nn as nn 20 | import torch.optim as optim 21 | from pytorch_pretrained_bert import BertAdam 22 | import torch.nn.functional as F 23 | from contrast_loss import Contrastive_loss 24 | from sklearn.metrics import f1_score, accuracy_score 25 | device = torch.device("cuda:0") 26 | torch.cuda.set_device(device) 27 | from src.data.helpers import get_data_loaders 28 | from src.models import get_model 29 | from src.utils.logger import create_logger 30 | from src.utils.utils import * 31 | 32 | recoreds=[] 33 | 34 | def get_args(parser): 35 | parser.add_argument("--batch_sz", type=int, default=16) 36 | parser.add_argument("--bert_model", type=str, default="./prebert") 37 | parser.add_argument("--data_path", type=str, default="./datasets/MVSA_Single/") 38 | parser.add_argument("--drop_img_percent", type=float, default=0.0) 39 | parser.add_argument("--dropout", type=float, default=0.1) 40 | parser.add_argument("--embed_sz", type=int, default=300) 41 | parser.add_argument("--freeze_img", type=int, default=3) 42 | parser.add_argument("--freeze_txt", type=int, default=5) 43 | parser.add_argument("--glove_path", type=str, default="datasets/glove_embeds/glove.840B.300d.txt") 44 | parser.add_argument("--gradient_accumulation_steps", type=int, default=40) 45 | parser.add_argument("--hidden", nargs="*", type=int, default=[]) 46 | parser.add_argument("--hidden_sz", type=int, default=768) 47 | parser.add_argument("--img_embed_pool_type", type=str, default="avg", choices=["max", "avg"]) 48 | parser.add_argument("--img_hidden_sz", type=int, default=2048) 49 | parser.add_argument("--include_bn", type=int, default=True) 50 | parser.add_argument("--lr", type=float, default=5e-05) 51 | parser.add_argument("--lr_factor", type=float, default=0.5) 52 | parser.add_argument("--lr_patience", type=int, default=2) 53 | parser.add_argument("--max_epochs", type=int, default=100) 54 | parser.add_argument("--max_seq_len", type=int, default=512) 55 | parser.add_argument("--model", type=str, default="latefusion", choices=["bow", "img", "bert", "concatbow", "concatbert", "mmbt","latefusion"]) 56 | parser.add_argument("--n_workers", type=int, default=12) 57 | parser.add_argument("--name", type=str, default="MVSA_Single_latefusion_model_run_df_1") 58 | parser.add_argument("--num_image_embeds", type=int, default=3) 59 | parser.add_argument("--patience", type=int, default=5) 60 | parser.add_argument("--savedir", type=str, default="./saved/MVSA_Single") 61 | parser.add_argument("--seed", type=int, default=1701) 62 | parser.add_argument("--task", type=str, default="MVSA_Single", choices=["mmimdb", "vsnli", "food101","MVSA_Single"]) 63 | parser.add_argument("--task_type", type=str, default="classification", choices=["multilabel", "classification"]) 64 | parser.add_argument("--warmup", type=float, default=0.1) 65 | parser.add_argument("--weight_classes", type=int, default=1) 66 | parser.add_argument("--df", type=bool, default=True) 67 | parser.add_argument("--noise", type=float, default=5) 68 | parser.add_argument("--noise_type", type=str, default="Salt") 69 | 70 | def get_criterion(args): 71 | if args.task_type == "multilabel": 72 | if args.weight_classes: 73 | freqs = [args.label_freqs[l] for l in args.labels] 74 | label_weights = (torch.FloatTensor(freqs) / args.train_data_len) ** -1 75 | criterion = 
nn.BCEWithLogitsLoss(pos_weight=label_weights.cuda()) 76 | else: 77 | criterion = nn.BCEWithLogitsLoss() 78 | else: 79 | criterion = nn.CrossEntropyLoss() 80 | 81 | return criterion 82 | 83 | 84 | def model_eval(i_epoch, data, model, args, criterion, store_preds=False): 85 | with torch.no_grad(): 86 | losses, preds, tgts = [], [], [] 87 | for batch in data: 88 | loss, out, tgt = model_forward(i_epoch, model, args, criterion, batch,mode='eval') 89 | losses.append(loss.item()) 90 | 91 | if args.task_type == "multilabel": 92 | pred = torch.sigmoid(out).cpu().detach().numpy() > 0.5 93 | else: 94 | pred = torch.nn.functional.softmax(out, dim=1).argmax(dim=1).cpu().detach().numpy() 95 | 96 | preds.append(pred) 97 | tgt = tgt.cpu().detach().numpy() 98 | tgts.append(tgt) 99 | 100 | metrics = {"loss": np.mean(losses)} 101 | if args.task_type == "multilabel": 102 | tgts = np.vstack(tgts) 103 | preds = np.vstack(preds) 104 | metrics["macro_f1"] = f1_score(tgts, preds, average="macro") 105 | metrics["micro_f1"] = f1_score(tgts, preds, average="micro") 106 | else: 107 | tgts = [l for sl in tgts for l in sl] 108 | preds = [l for sl in preds for l in sl] 109 | metrics["acc"] = accuracy_score(tgts, preds) 110 | metrics["micro_f1"] = f1_score(tgts, preds, average="micro") 111 | 112 | 113 | if store_preds: 114 | store_preds_to_disk(tgts, preds, args) 115 | 116 | return metrics 117 | 118 | def get_optimizer(model, args): 119 | if args.model in ["bert", "concatbert", "mmbt"]: 120 | total_steps = ( 121 | args.train_data_len 122 | / args.batch_sz 123 | / args.gradient_accumulation_steps 124 | * args.max_epochs 125 | ) 126 | param_optimizer = list(model.named_parameters()) 127 | no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] 128 | optimizer_grouped_parameters = [ 129 | {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.01}, 130 | {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0,}, 131 | ] 132 | optimizer = BertAdam( 133 | optimizer_grouped_parameters, 134 | lr=args.lr, 135 | warmup=args.warmup, 136 | t_total=total_steps, 137 | ) 138 | else: 139 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 140 | 141 | return optimizer 142 | 143 | 144 | def get_scheduler(optimizer, args): 145 | return optim.lr_scheduler.ReduceLROnPlateau( 146 | optimizer, "max", patience=args.lr_patience, verbose=True, factor=args.lr_factor 147 | ) 148 | 149 | 150 | def totolloss(txt_img_logits, txt_logits,tgt,img_logits,txt_mu,txt_logvar,img_mu,img_logvar,mu,logvar,z): 151 | txt_kl_loss = -(1 + txt_logvar - txt_mu.pow(2) - txt_logvar.exp()) / 2 152 | txt_kl_loss = txt_kl_loss.sum(dim=1).mean() 153 | 154 | img_kl_loss = -(1 + img_logvar - img_mu.pow(2) - img_logvar.exp()) / 2 155 | img_kl_loss = img_kl_loss.sum(dim=1).mean() 156 | 157 | kl_loss = -(1 + logvar - mu.pow(2) - logvar.exp()) / 2 158 | kl_loss = kl_loss.sum(dim=1).mean() 159 | IB_loss=F.cross_entropy(z,tgt) 160 | 161 | fusion_cls_loss=F.cross_entropy(txt_img_logits,tgt) 162 | 163 | 164 | totol_loss=fusion_cls_loss+1e-3*kl_loss+1e-3*txt_kl_loss+1e-3*img_kl_loss+1e-3*IB_loss 165 | return totol_loss 166 | 167 | def KL_regular(mu_1,logvar_1,mu_2,logvar_2): 168 | var_1=torch.exp(logvar_1) 169 | var_2=torch.exp(logvar_2) 170 | KL_loss=logvar_2-logvar_1+((var_1.pow(2)+(mu_1-mu_2).pow(2))/(2*var_2.pow(2)))-0.5 171 | KL_loss=KL_loss.sum(dim=1).mean() 172 | return KL_loss 173 | 174 | def reparameterise(mu, std): 175 | """ 176 | mu : [batch_size,z_dim] 177 | std : 
[batch_size,z_dim] 178 | """ 179 | # get epsilon from standard normal 180 | eps = torch.randn_like(std) 181 | return mu + std*eps 182 | 183 | def con_loss(txt_mu,txt_logvar,img_mu,img_logvar): 184 | Conloss=Contrastive_loss(0.5) 185 | while True: 186 | t_z1 = reparameterise(txt_mu, txt_logvar) 187 | t_z2 = reparameterise(txt_mu, txt_logvar) 188 | 189 | if not torch.equal(t_z1, t_z2): 190 | break 191 | while True: 192 | i_z1=reparameterise(img_mu,img_logvar) 193 | i_z2=reparameterise(img_mu,img_logvar) 194 | 195 | if not torch.equal(i_z1, i_z2): # compare the image-modality pair 196 | break 197 | 198 | 199 | loss_t=Conloss(t_z1,t_z2) 200 | loss_i=Conloss(i_z1,i_z2) 201 | 202 | return loss_t+loss_i 203 | 204 | def model_forward(i_epoch, model, args, criterion, batch,txt_history=None,img_history=None,mode='eval'): 205 | txt, segment, mask, img, tgt,_ = batch 206 | txt, img = txt.to(device), img.to(device) 207 | mask, segment = mask.to(device), segment.to(device) 208 | # the model also returns a cognitive-uncertainty dict, which is unused during evaluation 209 | txt_img_logits, txt_logits, img_logits,txt_mu,txt_logvar,img_mu,img_logvar,mu,logvar,z,cog_un=model(txt, mask, segment, img) 210 | 211 | 212 | tgt = tgt.to(device) 213 | 214 | conloss=con_loss(txt_mu,torch.exp(txt_logvar),img_mu,torch.exp(img_logvar)) 215 | loss=totolloss(txt_img_logits, txt_logits,tgt,img_logits,txt_mu,txt_logvar,img_mu,img_logvar,mu,logvar,z) 216 | loss=loss+1e-5*KL_regular(txt_mu,txt_logvar,img_mu,img_logvar)+conloss*1e-3 217 | 218 | return loss,txt_img_logits,tgt 219 | 220 | 221 | def train(args): 222 | 223 | set_seed(args.seed) 224 | args.savedir = os.path.join(args.savedir, args.name) 225 | os.makedirs(args.savedir, exist_ok=True) 226 | 227 | train_loader, val_loader, test_loaders = get_data_loaders(args) 228 | 229 | logger = create_logger("%s/eval_logfile.log" % args.savedir, args) 230 | 231 | # the checkpoint stores the full module under 'state_dict' (see save_checkpoint in src/utils/utils.py) 232 | model=torch.load(os.path.join(args.savedir, "model_best.pth")) 233 | model=model['state_dict'] 234 | model.cuda() 235 | 236 | model.eval() 237 | 238 | 239 | accList=[] 240 | for test_name, test_loader in test_loaders.items(): 241 | test_metrics = model_eval( 242 | np.inf, test_loader, model, args, None, store_preds=True 243 | ) 244 | 245 | log_metrics(f"Test - {test_name}", test_metrics, args, logger) 246 | accList.append(test_metrics['acc']) 247 | 248 | info = f"name:{args.name} seed:{args.seed} noise:{args.noise} test_acc: {accList[0]:0.5f}\n" 249 | logger.info(info) 250 | 251 | result_json={ 252 | 'name':args.name, 253 | 'method': args.model+'_df', 254 | 'seed':args.seed, 255 | 'noise':args.noise, 256 | 'test_acc':accList[0], 257 | } 258 | 259 | path = f"eval_data/{args.task}_result_{args.noise_type}.json" 260 | if os.path.exists(path): 261 | with open(path, encoding='utf-8') as f: 262 | exist_json = json.load(f) 263 | else: 264 | exist_json = [] 265 | 266 | # append this run's result and write the updated list back 267 | exist_json.append(result_json) 268 | os.makedirs(os.path.dirname(path), exist_ok=True) 269 | with open(path, 'w', encoding='utf-8') as f: 270 | json.dump(exist_json, f, ensure_ascii=False, indent=2) 271 | 272 | 273 | def cli_main(): 274 | parser = argparse.ArgumentParser(description="Train Models") 275 | get_args(parser) 276 | args, remaining_args = parser.parse_known_args() 277 | assert remaining_args == [], remaining_args 278 | train(args) 279 | 280 | 281 | if __name__ == "__main__": 282 | import warnings 283 | 284 | warnings.filterwarnings("ignore") 285 | 286 | cli_main() 287 | -------------------------------------------------------------------------------- /figures/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CFM-MSG/Code_URMF/a9d1dcaffc21382682809a7953f674914c07cf00/figures/framework.png -------------------------------------------------------------------------------- /src/.DS_Store:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CFM-MSG/Code_URMF/a9d1dcaffc21382682809a7953f674914c07cf00/src/.DS_Store -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CFM-MSG/Code_URMF/a9d1dcaffc21382682809a7953f674914c07cf00/src/__init__.py -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CFM-MSG/Code_URMF/a9d1dcaffc21382682809a7953f674914c07cf00/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | import json 11 | import numpy as np 12 | import os 13 | from PIL import Image 14 | import sys 15 | sys.path.append('src') 16 | import torch 17 | from torch.utils.data import Dataset 18 | from utils.utils import truncate_seq_pair, numpy_seed 19 | 20 | import random 21 | class JsonlDataset(Dataset): 22 | def __init__(self, data_path, tokenizer, transforms, vocab, args): 23 | self.data = [json.loads(l) for l in open(data_path)] 24 | self.data_dir = os.path.dirname(data_path) 25 | self.tokenizer = tokenizer 26 | self.args = args 27 | self.vocab = vocab 28 | self.n_classes = len(args.labels) 29 | self.text_start_token = ["[CLS]"] if args.model != "mmbt" else ["[SEP]"] 30 | 31 | with numpy_seed(0): 32 | for row in self.data: 33 | if np.random.random() < args.drop_img_percent: 34 | row["img"] = None 35 | 36 | self.max_seq_len = args.max_seq_len 37 | if args.model == "mmbt": 38 | self.max_seq_len -= args.num_image_embeds 39 | 40 | self.transforms = transforms 41 | 42 | def __len__(self): 43 | return len(self.data) 44 | 45 | def __getitem__(self, index): 46 | 47 | 48 | 49 | 50 | if self.args.task == "vsnli": 51 | sent1 = self.tokenizer(self.data[index]["sentence1"]) 52 | sent2 = self.tokenizer(self.data[index]["sentence2"]) 53 | truncate_seq_pair(sent1, sent2, self.args.max_seq_len - 3) 54 | sentence = self.text_start_token + sent1 + ["[SEP]"] + sent2 + ["[SEP]"] 55 | segment = torch.cat( 56 | [torch.zeros(2 + len(sent1)), torch.ones(len(sent2) + 1)] 57 | ) 58 | else: 59 | 60 | _ = self.tokenizer(self.data[index]["text"]) 61 | if self.args.noise > 0.0: 62 | p = [0.5, 0.5] 63 | flag = np.random.choice([0, 1], p=p) 64 | if flag: 65 | wordlist=self.data[index]["text"].split(' ') 66 | for i in range(len(wordlist)): 67 | replace_p=1/10*self.args.noise 68 | replace_flag = np.random.choice([0, 1], p=[1-replace_p, replace_p]) 69 | if replace_flag: 70 | # pass 71 | wordlist[i]='_' 72 | _=' '.join(wordlist) 73 | _=self.tokenizer(_) 74 | 75 | sentence = ( 76 | self.text_start_token 77 | + _[:(self.args.max_seq_len - 1)] 78 | ) 79 | segment = torch.zeros(len(sentence)) 80 | 81 | sentence = torch.LongTensor( #ids 82 | [ 83 | self.vocab.stoi[w] if w in self.vocab.stoi else self.vocab.stoi["[UNK]"] 84 | for w in sentence 85 | ] 86 | ) 87 | 88 | 89 | if self.args.task_type == "multilabel": 90 | label = 
torch.zeros(self.n_classes) 91 | label[ 92 | [self.args.labels.index(tgt) for tgt in self.data[index]["label"]] 93 | ] = 1 94 | else: 95 | label = torch.LongTensor( 96 | [self.args.labels.index(self.data[index]["label"])] 97 | ) 98 | 99 | image = None 100 | if self.args.model in ["img", "concatbow", "concatbert", "mmbt","latefusion","tmc"]: 101 | if self.data[index]["img"]: 102 | image = Image.open( 103 | os.path.join(self.data_dir, self.data[index]["img"]) 104 | ).convert("RGB") 105 | else: 106 | image = Image.fromarray(128 * np.ones((256, 256, 3), dtype=np.uint8)) 107 | image = self.transforms(image) 108 | if self.args.model == "mmbt": 109 | # The first SEP is part of the image token. 110 | segment = segment[1:] 111 | sentence = sentence[1:] 112 | # The first segment (0) is for images. 113 | segment += 1 114 | 115 | return sentence, segment, image, label, torch.LongTensor([index]) 116 | 117 | class AddGaussianNoise(object): 118 | 119 | ''' 120 | mean: mean of the Gaussian noise 121 | variance: scale (standard deviation) passed to np.random.normal 122 | amplitude: amplitude factor applied to the sampled noise 123 | ''' 124 | def __init__(self, mean=0.0, variance=1.0, amplitude=1.0): 125 | 126 | self.mean = mean 127 | self.variance = variance 128 | self.amplitude = amplitude 129 | 130 | def __call__(self, img): 131 | 132 | img = np.array(img) 133 | h, w, c = img.shape 134 | np.random.seed(0) # fixed seed: the same noise field is drawn for every image 135 | N = self.amplitude * np.random.normal(loc=self.mean, scale=self.variance, size=(h, w, 1)) 136 | N = np.repeat(N, c, axis=2) 137 | img = N + img 138 | img[img > 255] = 255 139 | img[img < 0] = 0 # clamp below as well so the uint8 conversion does not wrap around 140 | img = Image.fromarray(img.astype('uint8')).convert('RGB') 141 | return img 142 | 143 | class AddSaltPepperNoise(object): 144 | 145 | def __init__(self, density=0, p=0.5): 146 | self.density = density 147 | self.p = p 148 | 149 | def __call__(self, img): 150 | if random.uniform(0, 1) < self.p: 151 | img = np.array(img) 152 | h, w, c = img.shape 153 | Nd = self.density 154 | Sd = 1 - Nd 155 | mask = np.random.choice((0, 1, 2), size=(h, w, 1), p=[Nd / 2.0, Nd / 2.0, Sd]) 156 | mask = np.repeat(mask, c, axis=2) 157 | img[mask == 0] = 0 158 | img[mask == 1] = 255 159 | img = Image.fromarray(img.astype('uint8')).convert('RGB') 160 | return img 161 | else: 162 | return img -------------------------------------------------------------------------------- /src/data/helpers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree.
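# Data pipeline helpers: image transforms (clean, Gaussian-noise, and salt-and-pepper variants), label/vocab construction, batch collation, and DataLoader assembly.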
8 | # 9 | 10 | import functools 11 | import json 12 | import os 13 | from collections import Counter 14 | 15 | import torch 16 | import torchvision.transforms as transforms 17 | from pytorch_pretrained_bert import BertTokenizer 18 | from torch.utils.data import DataLoader 19 | import sys 20 | sys.path.append('src/data') 21 | from dataset import JsonlDataset,AddGaussianNoise,AddSaltPepperNoise 22 | from vocab import Vocab 23 | 24 | 25 | def get_transforms(): 26 | return transforms.Compose( 27 | [ 28 | transforms.Resize(256), 29 | transforms.CenterCrop(224), 30 | transforms.ToTensor(), 31 | transforms.Normalize( 32 | mean=[0.46777044, 0.44531429, 0.40661017], 33 | std=[0.12221994, 0.12145835, 0.14380469], 34 | ), 35 | ] 36 | ) 37 | 38 | def get_GaussianNoisetransforms(rgb_severity): 39 | return transforms.Compose( 40 | [ 41 | transforms.Resize(256), 42 | transforms.RandomApply([AddGaussianNoise(amplitude=rgb_severity * 10)], p=0.5), 43 | transforms.CenterCrop(224), 44 | transforms.ToTensor(), 45 | transforms.Normalize( 46 | mean=[0.46777044, 0.44531429, 0.40661017], 47 | std=[0.12221994, 0.12145835, 0.14380469], 48 | ), 49 | ] 50 | ) 51 | 52 | def get_SaltNoisetransforms(rgb_severity): 53 | return transforms.Compose( 54 | [ 55 | transforms.Resize(256), 56 | transforms.RandomApply([AddSaltPepperNoise(density=0.1, p=rgb_severity/10)], p=0.5), 57 | transforms.CenterCrop(224), 58 | transforms.ToTensor(), 59 | transforms.Normalize( 60 | mean=[0.46777044, 0.44531429, 0.40661017], 61 | std=[0.12221994, 0.12145835, 0.14380469], 62 | ), 63 | ] 64 | ) 65 | def get_labels_and_frequencies(path): 66 | label_freqs = Counter() 67 | data_labels = [json.loads(line)["label"] for line in open(path)] 68 | if type(data_labels[0]) == list: 69 | for label_row in data_labels: 70 | label_freqs.update(label_row) 71 | else: 72 | label_freqs.update(data_labels) 73 | 74 | return list(label_freqs.keys()), label_freqs 75 | 76 | 77 | def get_glove_words(path): 78 | word_list = [] 79 | for line in open(path): 80 | w, _ = line.split(" ", 1) 81 | word_list.append(w) 82 | return word_list 83 | 84 | 85 | def get_vocab(args): 86 | vocab = Vocab() 87 | if args.model in ["bert", "mmbt", "concatbert","latefusion",'tmc']: 88 | bert_tokenizer = BertTokenizer.from_pretrained( 89 | './prebert/', do_lower_case=True 90 | ) 91 | vocab.stoi = bert_tokenizer.vocab 92 | vocab.itos = bert_tokenizer.ids_to_tokens 93 | vocab.vocab_sz = len(vocab.itos) 94 | 95 | else: 96 | word_list = get_glove_words(args.glove_path) 97 | vocab.add(word_list) 98 | 99 | return vocab 100 | 101 | 102 | def collate_fn(batch, args): 103 | lens = [len(row[0]) for row in batch] 104 | bsz, max_seq_len = len(batch), max(lens) 105 | 106 | mask_tensor = torch.zeros(bsz, max_seq_len).long() 107 | text_tensor = torch.zeros(bsz, max_seq_len).long() 108 | segment_tensor = torch.zeros(bsz, max_seq_len).long() 109 | 110 | img_tensor = None 111 | if args.model in ["img", "concatbow", "concatbert", "mmbt","latefusion",'tmc']: 112 | img_tensor = torch.stack([row[2] for row in batch]) 113 | 114 | if args.task_type == "multilabel": 115 | # Multilabel case 116 | tgt_tensor = torch.stack([row[3] for row in batch]) 117 | else: 118 | # Single Label case 119 | tgt_tensor = torch.cat([row[3] for row in batch]).long() 120 | 121 | for i_batch, (input_row, length) in enumerate(zip(batch, lens)): 122 | tokens, segment = input_row[:2] 123 | text_tensor[i_batch, :length] = tokens 124 | segment_tensor[i_batch, :length] = segment 125 | mask_tensor[i_batch, :length] = 1 126 | 127 | 
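    # also return each sample's original dataset index so per-sample records (e.g. the History tracker in src/utils/utils.py) can be indexed during training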
idx=torch.cat([row[4] for row in batch]).long() 128 | return text_tensor, segment_tensor, mask_tensor, img_tensor, tgt_tensor,idx 129 | 130 | 131 | def get_data_loaders(args): 132 | tokenizer = ( 133 | # BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True).tokenize 134 | BertTokenizer.from_pretrained('./prebert/', do_lower_case=True).tokenize 135 | if args.model in ["bert", "mmbt", "concatbert","latefusion","tmc"] #tmc 136 | else str.split 137 | ) 138 | 139 | transforms = get_transforms() 140 | 141 | args.labels, args.label_freqs = get_labels_and_frequencies( 142 | os.path.join(args.data_path, "train.jsonl") 143 | ) 144 | vocab = get_vocab(args) 145 | args.vocab = vocab 146 | args.vocab_sz = vocab.vocab_sz 147 | args.n_classes = len(args.labels) 148 | 149 | train = JsonlDataset( 150 | os.path.join(args.data_path, "train.jsonl"), 151 | tokenizer, 152 | transforms, 153 | vocab, 154 | args, 155 | ) 156 | args.train_data_len = len(train) 157 | 158 | dev = JsonlDataset( 159 | os.path.join(args.data_path, "dev.jsonl"), 160 | tokenizer, 161 | transforms, 162 | vocab, 163 | args, 164 | ) 165 | 166 | collate = functools.partial(collate_fn, args=args) 167 | train_loader = DataLoader( 168 | train, 169 | batch_size=args.batch_sz, 170 | shuffle=True, 171 | num_workers=args.n_workers, 172 | collate_fn=collate, 173 | ) 174 | 175 | val_loader = DataLoader( #batchsize=128 176 | dev, 177 | batch_size=args.batch_sz, 178 | shuffle=False, 179 | num_workers=args.n_workers, 180 | collate_fn=collate, 181 | ) 182 | 183 | if args.noise>0.0: 184 | if args.noise_type=='Gaussian': 185 | print('Gaussian') 186 | test_transforms=get_GaussianNoisetransforms(args.noise) 187 | elif args.noise_type=='Salt': 188 | print("Salt") 189 | test_transforms = get_SaltNoisetransforms(args.noise) 190 | else: 191 | test_transforms=transforms 192 | 193 | 194 | test_set = JsonlDataset( 195 | os.path.join(args.data_path, "test.jsonl"), 196 | tokenizer, 197 | test_transforms, 198 | vocab, 199 | args, 200 | ) 201 | 202 | test_loader = DataLoader( 203 | test_set, 204 | batch_size=args.batch_sz, 205 | shuffle=False, 206 | num_workers=args.n_workers, 207 | collate_fn=collate, 208 | ) 209 | 210 | if args.task == "vsnli": 211 | test_hard = JsonlDataset( 212 | os.path.join(args.data_path, args.task, "test_hard.jsonl"), 213 | tokenizer, 214 | transforms, 215 | vocab, 216 | args, 217 | ) 218 | 219 | test_hard_loader = DataLoader( 220 | test_hard, 221 | batch_size=args.batch_sz, 222 | shuffle=False, 223 | num_workers=args.n_workers, 224 | collate_fn=collate, 225 | ) 226 | 227 | test = {"test": test_loader, "test_hard": test_hard_loader} 228 | elif args.task == "MVSA_Single": 229 | test = {"test": test_loader} 230 | 231 | elif args.task == "food101": 232 | test = {"test": test_loader} 233 | else: 234 | test_gt = JsonlDataset( 235 | os.path.join(args.data_path, args.task, "test_hard_gt.jsonl"), 236 | tokenizer, 237 | test_transforms, 238 | vocab, 239 | args, 240 | ) 241 | 242 | test_gt_loader = DataLoader( 243 | test_gt, 244 | batch_size=args.batch_sz, 245 | shuffle=False, 246 | num_workers=args.n_workers, 247 | collate_fn=collate, 248 | ) 249 | 250 | 251 | test = { 252 | "test": test_loader, 253 | "test_gt": test_gt_loader, 254 | } 255 | 256 | return train_loader, val_loader, test 257 | -------------------------------------------------------------------------------- /src/data/vocab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. 
and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | 11 | class Vocab(object): 12 | def __init__(self, emptyInit=False): 13 | if emptyInit: 14 | self.stoi, self.itos, self.vocab_sz = {}, [], 0 15 | else: 16 | self.stoi = { #pad:0 unk:1 cls:2 sep:3 mask:4 17 | w: i 18 | for i, w in enumerate(["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]) 19 | } 20 | self.itos = [w for w in self.stoi] 21 | self.vocab_sz = len(self.itos) 22 | 23 | def add(self, words): 24 | cnt = len(self.itos) 25 | for w in words: 26 | if w in self.stoi: 27 | continue 28 | self.stoi[w] = cnt 29 | self.itos.append(w) 30 | cnt += 1 31 | self.vocab_sz = len(self.itos) 32 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | from models.bert import BertClf 11 | 12 | from models.image import ImageClf 13 | 14 | from models.late_fusion import MultimodalLateFusionClf 15 | 16 | MODELS = { 17 | "bert": BertClf, 18 | "img": ImageClf, 19 | 'latefusion':MultimodalLateFusionClf, 20 | } 21 | 22 | 23 | def get_model(args): 24 | # print(args.model) 25 | return MODELS[args.model](args) 26 | -------------------------------------------------------------------------------- /src/models/bert.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
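# Text branch: a BERT encoder followed by Gaussian heads; forward() returns (mu, logvar, logits), where the logits are computed from a reparameterized sample of the latent code.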
8 | # 9 | import torch 10 | import torch.nn as nn 11 | from pytorch_pretrained_bert.modeling import BertModel 12 | 13 | 14 | class BertEncoder(nn.Module): 15 | def __init__(self, args): 16 | super(BertEncoder, self).__init__() 17 | self.args = args 18 | self.bert = BertModel.from_pretrained(args.bert_model) 19 | 20 | def forward(self, txt, mask, segment): 21 | _, out = self.bert( 22 | txt, 23 | token_type_ids=segment, 24 | attention_mask=mask, 25 | output_all_encoded_layers=False, 26 | ) 27 | return out #16*768 28 | 29 | class Flatten(nn.Module): 30 | def forward(self, input): 31 | return input.view(input.size(0), -1) 32 | 33 | class BertClf(nn.Module): 34 | def __init__(self, args): 35 | super(BertClf, self).__init__() 36 | self.args = args 37 | self.enc = BertEncoder(args) 38 | self.clf = nn.Linear(128, args.n_classes) 39 | self.mu=nn.Sequential( 40 | # nn.BatchNorm1d(args.hidden_sz, eps=2e-5,affine=False), 41 | # nn.Dropout(p=0.4), 42 | # Flatten(), 43 | nn.Linear(args.hidden_sz, 128)) 44 | # nn.BatchNorm1d(128,eps=2e-5)) 45 | self.logvar=nn.Sequential( 46 | # nn.BatchNorm1d(args.hidden_sz, eps=2e-5,affine=False), 47 | # nn.Dropout(p=0.4), 48 | # Flatten(), 49 | nn.Linear(args.hidden_sz, 128)) 50 | # nn.BatchNorm1d(128,eps=2e-5)) 51 | self.clf.apply(self.enc.bert.init_bert_weights) 52 | 53 | def forward(self, txt, mask, segment): 54 | x = self.enc(txt, mask, segment) #x.shape=batch_size*768 55 | mu=self.mu(x) #batch_size*200 56 | logvar=self.logvar(x) #batch_size*200 57 | x=self._reparameterize(mu,logvar) 58 | out=self.clf(x) 59 | return mu,logvar,out 60 | 61 | def _reparameterize(self, mu, logvar): 62 | std = torch.exp(logvar).sqrt() 63 | epsilon = torch.randn_like(std) 64 | sampler = epsilon * std 65 | return mu + sampler 66 | 67 | 68 | -------------------------------------------------------------------------------- /src/models/image.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
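# Image branch: a ResNet-152 encoder pooled to num_image_embeds tokens; like the text branch, forward() returns (mu, logvar, logits) via the reparameterization trick.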
8 | # 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torchvision 13 | 14 | 15 | class ImageEncoder(nn.Module): 16 | def __init__(self, args): 17 | super(ImageEncoder, self).__init__() 18 | self.args = args 19 | model = torchvision.models.resnet152(pretrained=True) 20 | modules = list(model.children())[:-2] 21 | self.model = nn.Sequential(*modules) 22 | 23 | pool_func = ( 24 | nn.AdaptiveAvgPool2d 25 | if args.img_embed_pool_type == "avg" 26 | else nn.AdaptiveMaxPool2d 27 | ) 28 | 29 | if args.num_image_embeds in [1, 2, 3, 5, 7]: 30 | self.pool = pool_func((args.num_image_embeds, 1)) 31 | elif args.num_image_embeds == 4: 32 | self.pool = pool_func((2, 2)) 33 | elif args.num_image_embeds == 6: 34 | self.pool = pool_func((3, 2)) 35 | elif args.num_image_embeds == 8: 36 | self.pool = pool_func((4, 2)) 37 | elif args.num_image_embeds == 9: 38 | self.pool = pool_func((3, 3)) 39 | 40 | def forward(self, x): 41 | # Bx3x224x224 -> Bx2048x7x7 -> Bx2048xN -> BxNx2048 42 | out = self.pool(self.model(x)) 43 | out = torch.flatten(out, start_dim=2) #batchsize*2048*3 44 | out = out.transpose(1, 2).contiguous() #batchsize*3*2048 45 | return out # BxNx2048 46 | 47 | class Flatten(nn.Module): 48 | def forward(self, input): 49 | return input.view(input.size(0), -1) 50 | 51 | class ImageClf(nn.Module): 52 | def __init__(self, args): 53 | super(ImageClf, self).__init__() 54 | self.args = args 55 | self.img_encoder = ImageEncoder(args) 56 | self.clf=nn.Linear(128,args.n_classes) 57 | 58 | self.mu=nn.Sequential( 59 | # nn.BatchNorm1d(args.img_hidden_sz* args.num_image_embeds, eps=2e-5, affine=False), 60 | # nn.Dropout(p=0.4), 61 | # Flatten(), 62 | nn.Linear(args.img_hidden_sz* args.num_image_embeds,128)) 63 | # nn.BatchNorm1d(128,eps=2e-5)) 64 | self.logvar=nn.Sequential( 65 | # nn.BatchNorm1d(args.img_hidden_sz* args.num_image_embeds, eps=2e-5,affine=False), 66 | # nn.Dropout(p=0.4), 67 | # Flatten(), 68 | nn.Linear(args.img_hidden_sz* args.num_image_embeds, 128)) 69 | # nn.BatchNorm1d(128,eps=2e-5)) 70 | 71 | def forward(self, x): 72 | x = self.img_encoder(x) 73 | x = torch.flatten(x, start_dim=1) 74 | mu=self.mu(x) #batch_size 75 | logvar=self.logvar(x) #batch_size 76 | x=self._reparameterize(mu,logvar) 77 | out=self.clf(x) 78 | return mu,logvar,out 79 | 80 | def _reparameterize(self, mu, logvar): 81 | std = torch.exp(logvar).sqrt() 82 | epsilon = torch.randn_like(std) 83 | sampler = epsilon * std 84 | return mu + sampler -------------------------------------------------------------------------------- /src/models/late_fusion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
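# Late fusion: the unimodal Gaussian heads supply (mu, logvar); Monte-Carlo samples estimate per-modality cognitive uncertainty, variance-based softmax weights rescale the unimodal features, and a ChannelGate plus information-bottleneck head produces the fused prediction.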
8 | # 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from attention import ChannelGate 14 | from models.bert import BertEncoder,BertClf 15 | from models.image import ImageEncoder,ImageClf 16 | from src.models.tool import cog_uncertainty_normal, cog_uncertainty_sample 17 | 18 | 19 | 20 | def reparameterise(mu, std): 21 | """ 22 | mu : [batch_size,z_dim] 23 | std : [batch_size,z_dim] 24 | """ 25 | # get epsilon from standard normal 26 | eps = torch.randn_like(std) 27 | return mu + std*eps 28 | 29 | 30 | class MultimodalLateFusionClf(nn.Module): 31 | def __init__(self, args): 32 | super(MultimodalLateFusionClf, self).__init__() 33 | self.args = args 34 | 35 | self.fusion=ChannelGate(3,3,'avg') 36 | self.txtclf = BertClf(args) 37 | self.imgclf= ImageClf(args) 38 | self.mu=nn.Linear(128,128) 39 | self.logvar=nn.Linear(128,128) 40 | self.IB_classifier=nn.Linear(128,3) 41 | self.fc_fusion1=nn.Sequential(nn.Linear(128,3)) 42 | 43 | 44 | def forward(self, txt, mask, segment, img): 45 | txt_mu,txt_logvar,txt_out = self.txtclf(txt, mask, segment) 46 | img_mu,img_logvar,img_out = self.imgclf(img) 47 | 48 | txt_var=torch.exp(txt_logvar) 49 | img_var=torch.exp(img_logvar) 50 | 51 | def get_supp_mod(key): 52 | if key == "l": 53 | return img_mu 54 | elif key == "v": 55 | return txt_mu 56 | else: 57 | raise KeyError 58 | 59 | # pass the variances of both modalities when drawing Monte-Carlo samples 60 | l_sample, v_sample = cog_uncertainty_sample(txt_mu, txt_var, img_mu, img_var, sample_times=10) 61 | sample_dict = { 62 | "l": l_sample, 63 | "v": v_sample 64 | } 65 | cog_uncertainty_dict = {} 66 | with torch.no_grad(): 67 | for key, sample_tensor in sample_dict.items(): 68 | bsz, sample_times, dim = sample_tensor.shape 69 | sample_tensor = sample_tensor.reshape(bsz * sample_times, dim) 70 | sample_tensor = sample_tensor.unsqueeze(1) 71 | supp_mod = get_supp_mod(key) 72 | supp_mod = supp_mod.unsqueeze(1) 73 | supp_mod = supp_mod.unsqueeze(1).repeat(1, sample_times, 1, 1) 74 | supp_mod = supp_mod.reshape(bsz * sample_times, 1, dim) 75 | feature = torch.cat([supp_mod, sample_tensor], dim=1) 76 | 77 | feature_fusion=self.fusion(feature) 78 | mu=self.mu(feature_fusion) 79 | logvar=self.logvar(feature_fusion) 80 | z=reparameterise(mu,torch.exp(logvar)) 81 | z=self.IB_classifier(z) 82 | txt_img_out=self.fc_fusion1(mu) 83 | 84 | cog_un = torch.var(txt_img_out, dim=-1) 85 | cog_uncertainty_dict[key] = cog_un 86 | 87 | cog_uncertainty_dict = cog_uncertainty_normal(cog_uncertainty_dict) 88 | 89 | 90 | weight=torch.softmax(torch.stack([img_var, txt_var]), dim=0) 91 | img_w=weight[1] 92 | txt_w=weight[0] 93 | 94 | feature_txt=txt_mu*txt_w 95 | feature_img=img_mu*img_w 96 | 97 | 98 | feature=torch.stack((feature_txt,feature_img),dim=1) 99 | feature_fusion=self.fusion(feature) 100 | mu=self.mu(feature_fusion) 101 | logvar=self.logvar(feature_fusion) 102 | z=reparameterise(mu,torch.exp(logvar)) 103 | z=self.IB_classifier(z) 104 | txt_img_out=self.fc_fusion1(mu) 105 | 106 | 107 | 108 | return [txt_img_out,txt_out,img_out,txt_mu,txt_logvar,img_mu,img_logvar,mu,logvar,z,cog_uncertainty_dict] -------------------------------------------------------------------------------- /src/models/tool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from contrast_loss import Contrastive_loss 5 | import numpy as np 6 | 7 | 8 | def KL_regular(mu_1, var_1, mu_2, var_2, mu_3, var_3): 9 | 10 |
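    # Closed-form KL between diagonal Gaussians, summed over latent dimensions and averaged over the batch: KL(N(mu1, s1^2) || N(mu2, s2^2)) = log(s2 / s1) + (s1^2 + (mu1 - mu2)^2) / (2 * s2^2) - 1/2; the var_* arguments are treated as the standard deviations s_* in the expressions below.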
KL_loss_1=var_2.log()-var_1.log()+((var_1.pow(2)+(mu_1-mu_2).pow(2))/(2*var_2.pow(2)))-0.5 11 | KL_loss_1=KL_loss_1.sum(dim=1).mean() 12 | 13 | KL_loss_2=var_3.log()-var_1.log()+((var_1.pow(2)+(mu_1-mu_3).pow(2))/(2*var_3.pow(2)))-0.5 14 | KL_loss_2=KL_loss_2.sum(dim=1).mean() 15 | 16 | sub_kl_loss_1 = -(1 + var_1.log() - mu_1.pow(2) - var_1) / 2 17 | sub_kl_loss_1 = sub_kl_loss_1.sum(dim=1).mean() 18 | 19 | sub_kl_loss_2 = -(1 + var_2.log() - mu_2.pow(2) - var_2) / 2 20 | sub_kl_loss_2 = sub_kl_loss_2.sum(dim=1).mean() 21 | 22 | sub_kl_loss_3 = -(1 + var_3.log() - mu_3.pow(2) - var_3) / 2 23 | sub_kl_loss_3 = sub_kl_loss_3.sum(dim=1).mean() 24 | 25 | sub_kl_loss=sub_kl_loss_1 + sub_kl_loss_2 + sub_kl_loss_3 26 | 27 | return KL_loss_1 + KL_loss_2 + sub_kl_loss*1e3 28 | 29 | def reparameterise(mu, std): 30 | """ 31 | mu : [batch_size,z_dim] 32 | std : [batch_size,z_dim] 33 | """ 34 | # get epsilon from standard normal 35 | eps = torch.randn_like(std) 36 | return mu + std*eps 37 | 38 | def con_loss(txt_mu, txt_logvar, img_mu, img_logvar, aou_mu, aou_logvar): 39 | Conloss=Contrastive_loss(0.5) 40 | 41 | 42 | while True: 43 | t_z1 = reparameterise(txt_mu, txt_logvar) 44 | t_z2 = reparameterise(txt_mu, txt_logvar) 45 | 46 | if not torch.equal(t_z1, t_z2): 47 | break 48 | 49 | 50 | while True: 51 | i_z1=reparameterise(img_mu,img_logvar) 52 | i_z2=reparameterise(img_mu,img_logvar) 53 | 54 | if not torch.equal(i_z1, i_z2): 55 | break 56 | 57 | 58 | 59 | while True: 60 | a_z1=reparameterise(aou_mu,aou_logvar) 61 | a_z2=reparameterise(aou_mu,aou_logvar) 62 | 63 | if not torch.equal(a_z1, a_z2): 64 | break 65 | 66 | loss_t=Conloss(t_z1,t_z2) 67 | loss_i=Conloss(i_z1,i_z2) 68 | loss_a=Conloss(a_z1,a_z2) 69 | 70 | return loss_t + loss_i + loss_a 71 | 72 | 73 | def cog_uncertainty_sample(mu_l, var_l, mu_v, var_v, sample_times=10): 74 | 75 | l_list = [] 76 | for _ in range(sample_times): 77 | l_list.append(reparameterise(mu_l, var_l)) 78 | l_sample = torch.stack(l_list, dim=1) 79 | 80 | v_list = [] 81 | for _ in range(sample_times): 82 | v_list.append(reparameterise(mu_v, var_v)) 83 | v_sample = torch.stack(v_list, dim=1) 84 | 85 | return l_sample, v_sample 86 | 87 | 88 | def cog_uncertainty_normal(unc_dict, normal_type="None"): 89 | 90 | key_list = [k for k, _ in unc_dict.items()] 91 | comb_list = [t for _, t in unc_dict.items()] 92 | comb_t = torch.stack(comb_list, dim=1) 93 | mat = torch.exp(torch.reciprocal(comb_t)) 94 | mat_sum = mat.sum(dim=-1, keepdim=True) 95 | weight = mat / mat_sum 96 | 97 | if normal_type == "minmax": 98 | weight = weight / torch.max(weight, dim=1)[0].unsqueeze(-1) # [bsz, mod_num] 99 | for i, key in enumerate(key_list): 100 | unc_dict[key] = weight[:, i] 101 | else: 102 | pass 103 | # raise TypeError("Unsupported Operations at cog_uncertainty_normal!") 104 | 105 | return unc_dict 106 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CFM-MSG/Code_URMF/a9d1dcaffc21382682809a7953f674914c07cf00/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved.
5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | import logging 11 | import time 12 | from datetime import timedelta 13 | 14 | 15 | class LogFormatter: 16 | def __init__(self): 17 | self.start_time = time.time() 18 | 19 | def format(self, record): 20 | elapsed_seconds = round(record.created - self.start_time) 21 | 22 | prefix = "%s - %s - %s" % ( 23 | record.levelname, 24 | time.strftime("%x %X"), 25 | timedelta(seconds=elapsed_seconds), 26 | ) 27 | message = record.getMessage() 28 | message = message.replace("\n", "\n" + " " * (len(prefix) + 3)) 29 | return "%s - %s" % (prefix, message) 30 | 31 | 32 | def create_logger(filepath, args): 33 | # create log formatter 34 | log_formatter = LogFormatter() 35 | 36 | # create file handler and set level to debug 37 | file_handler = logging.FileHandler(filepath, "a") 38 | file_handler.setLevel(logging.DEBUG) 39 | file_handler.setFormatter(log_formatter) 40 | 41 | # create console handler and set level to info 42 | console_handler = logging.StreamHandler() 43 | console_handler.setLevel(logging.INFO) 44 | console_handler.setFormatter(log_formatter) 45 | 46 | # create logger and set level to debug 47 | logger = logging.getLogger() 48 | logger.handlers = [] 49 | logger.setLevel(logging.INFO) 50 | logger.propagate = False 51 | logger.addHandler(file_handler) 52 | logger.addHandler(console_handler) 53 | 54 | # reset logger elapsed time 55 | def reset_time(): 56 | log_formatter.start_time = time.time() 57 | 58 | logger.reset_time = reset_time 59 | 60 | logger.info( 61 | "\n".join( 62 | "%s: %s" % (k, str(v)) 63 | for k, v in sorted(dict(vars(args)).items(), key=lambda x: x[0]) 64 | ) 65 | ) 66 | 67 | return logger 68 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | import contextlib 11 | import numpy as np 12 | import random 13 | import shutil 14 | import os 15 | 16 | import torch 17 | 18 | 19 | def set_seed(seed): 20 | random.seed(seed) 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = False 27 | 28 | 29 | def save_checkpoint(state, is_best, checkpoint_path, filename="checkpoint.pt"): 30 | # filename = os.path.join(checkpoint_path, "self"+filename) 31 | # torch.save(state, filename) 32 | if is_best: 33 | # shutil.copyfile(filename, os.path.join(checkpoint_path, "model_best.pt")) 34 | torch.save(state,os.path.join(checkpoint_path, "model_best.pth")) 35 | 36 | 37 | def load_checkpoint(model, path): 38 | best_checkpoint = torch.load(path) 39 | model.load_state_dict(best_checkpoint["state_dict"]) 40 | 41 | 42 | def truncate_seq_pair(tokens_a, tokens_b, max_length): 43 | """Truncates a sequence pair in place to the maximum length. 
44 | Copied from https://github.com/huggingface/pytorch-pretrained-BERT 45 | """ 46 | while True: 47 | total_length = len(tokens_a) + len(tokens_b) 48 | if total_length <= max_length: 49 | break 50 | if len(tokens_a) > len(tokens_b): 51 | tokens_a.pop() 52 | else: 53 | tokens_b.pop() 54 | 55 | 56 | def store_preds_to_disk(tgts, preds, args): 57 | if args.task_type == "multilabel": 58 | with open(os.path.join(args.savedir, "test_labels_pred.txt"), "w") as fw: 59 | fw.write( 60 | "\n".join([" ".join(["1" if x else "0" for x in p]) for p in preds]) 61 | ) 62 | with open(os.path.join(args.savedir, "test_labels_gold.txt"), "w") as fw: 63 | fw.write( 64 | "\n".join([" ".join(["1" if x else "0" for x in t]) for t in tgts]) 65 | ) 66 | with open(os.path.join(args.savedir, "test_labels.txt"), "w") as fw: 67 | fw.write(" ".join([l for l in args.labels])) 68 | 69 | else: 70 | with open(os.path.join(args.savedir, "test_labels_pred.txt"), "w") as fw: 71 | fw.write("\n".join([str(x) for x in preds])) 72 | with open(os.path.join(args.savedir, "test_labels_gold.txt"), "w") as fw: 73 | fw.write("\n".join([str(x) for x in tgts])) 74 | with open(os.path.join(args.savedir, "test_labels.txt"), "w") as fw: 75 | fw.write(" ".join([str(l) for l in args.labels])) 76 | 77 | 78 | def log_metrics(set_name, metrics, args, logger): 79 | if args.task_type == "multilabel": 80 | logger.info( 81 | "{}: Loss: {:.5f} | Macro F1 {:.5f} | Micro F1: {:.5f}".format( 82 | set_name, metrics["loss"], metrics["macro_f1"], metrics["micro_f1"] 83 | ) 84 | ) 85 | else: 86 | logger.info( 87 | "{}: Loss: {:.5f} | Acc: {:.5f} | F1: {:.5f}".format( 88 | set_name, metrics["loss"], metrics["acc"], metrics["micro_f1"] 89 | ) 90 | ) 91 | 92 | 93 | @contextlib.contextmanager 94 | def numpy_seed(seed, *addl_seeds): 95 | """Context manager which seeds the NumPy PRNG with the specified seed and 96 | restores the state afterward""" 97 | if seed is None: 98 | yield 99 | return 100 | if len(addl_seeds) > 0: 101 | seed = int(hash((seed, *addl_seeds)) % 1e6) 102 | state = np.random.get_state() 103 | np.random.seed(seed) 104 | try: 105 | yield 106 | finally: 107 | np.random.set_state(state) 108 | 109 | 110 | 111 | 112 | class History(object): 113 | def __init__(self, n_data): 114 | self.correctness = np.zeros((n_data)) 115 | self.confidence = np.zeros((n_data)) 116 | self.max_correctness = 1 117 | 118 | # correctness update 119 | def correctness_update(self, data_idx, correctness, confidence): 120 | #probs = torch.nn.functional.softmax(output, dim=1) 121 | #confidence, _ = probs.max(dim=1) 122 | data_idx = data_idx.cpu().numpy() 123 | 124 | self.correctness[data_idx] += correctness.cpu().numpy() 125 | self.confidence[data_idx] = confidence.cpu().detach().numpy() 126 | 127 | # max correctness update 128 | def max_correctness_update(self, epoch): 129 | if epoch > 1: 130 | self.max_correctness += 1 131 | 132 | # correctness normalize (0 ~ 1) range 133 | def correctness_normalize(self, data): 134 | data_min = self.correctness.min() 135 | #data_max = float(self.max_correctness) 136 | data_max = float(self.correctness.max()) 137 | 138 | return (data - data_min) / (data_max - data_min) 139 | 140 | # get target & margin 141 | def get_target_margin(self, data_idx1, data_idx2): 142 | data_idx1, data_idx2 = data_idx1.cpu().numpy(), data_idx2.cpu().numpy() 143 | cum_correctness1 = self.correctness[data_idx1] # cumulative correctness of the corresponding samples 144 | cum_correctness2 = self.correctness[data_idx2] 145 | # normalize correctness values 146 | cum_correctness1 =
self.correctness_normalize(cum_correctness1) 147 | cum_correctness2 = self.correctness_normalize(cum_correctness2) 148 | # make target pair 149 | n_pair = len(data_idx1) 150 | target1 = cum_correctness1[:n_pair] 151 | target2 = cum_correctness2[:n_pair] 152 | # calc target 153 | greater = np.array(target1 > target2, dtype='float') 154 | less = np.array(target1 < target2, dtype='float') * (-1) 155 | 156 | target = greater + less 157 | target = torch.from_numpy(target).float().cuda() 158 | # calc margin 159 | margin = abs(target1 - target2) 160 | margin = torch.from_numpy(margin).float().cuda() 161 | 162 | return target, margin 163 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | 11 | import argparse 12 | from sklearn.metrics import f1_score, accuracy_score 13 | from tqdm import tqdm 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.optim as optim 18 | from pytorch_pretrained_bert import BertAdam 19 | import torch.nn.functional as F 20 | 21 | device = torch.device("cuda:0") 22 | torch.cuda.set_device(device) 23 | 24 | from util import Contrastive_loss,totolloss,con_loss,KL_regular 25 | from src.data.helpers import get_data_loaders 26 | from src.models import get_model 27 | from src.utils.logger import create_logger 28 | from src.utils.utils import * 29 | 30 | import time 31 | 32 | def get_args(parser): 33 | parser.add_argument("--batch_sz", type=int, default=8) 34 | parser.add_argument("--bert_model", type=str, default="./prebert") 35 | parser.add_argument("--data_path", type=str, default="./datasets/MVSA_Single/") 36 | parser.add_argument("--drop_img_percent", type=float, default=0.0) 37 | parser.add_argument("--dropout", type=float, default=0.1) 38 | parser.add_argument("--embed_sz", type=int, default=300) 39 | parser.add_argument("--freeze_img", type=int, default=3) 40 | parser.add_argument("--freeze_txt", type=int, default=5) 41 | parser.add_argument("--glove_path", type=str, default="./datasets/glove_embeds/glove.840B.300d.txt") 42 | parser.add_argument("--gradient_accumulation_steps", type=int, default=40) 43 | parser.add_argument("--hidden", nargs="*", type=int, default=[]) 44 | parser.add_argument("--hidden_sz", type=int, default=768) 45 | parser.add_argument("--img_embed_pool_type", type=str, default="avg", choices=["max", "avg"]) 46 | parser.add_argument("--img_hidden_sz", type=int, default=2048) 47 | parser.add_argument("--include_bn", type=int, default=True) 48 | parser.add_argument("--lr", type=float, default=5e-05) 49 | parser.add_argument("--lr_factor", type=float, default=0.5) 50 | parser.add_argument("--lr_patience", type=int, default=2) 51 | parser.add_argument("--max_epochs", type=int, default=100) 52 | parser.add_argument("--max_seq_len", type=int, default=512) 53 | parser.add_argument("--model", type=str, default="latefusion", choices=["bow", "img", "bert", "concatbow", "concatbert", "mmbt","latefusion"]) 54 | parser.add_argument("--n_workers", type=int, default=4) 55 | parser.add_argument("--name", type=str, default="URMF") 56 | parser.add_argument("--num_image_embeds", type=int, default=3) 57 | parser.add_argument("--patience", type=int, default=5) 58 | 
parser.add_argument("--savedir", type=str, default="./saved/MVSA_Single") 59 | parser.add_argument("--seed", type=int, default=1699) 60 | parser.add_argument("--task", type=str, default="MVSA_Single", choices=["MVSA_Single"]) 61 | parser.add_argument("--task_type", type=str, default="classification", choices=["multilabel", "classification"]) 62 | parser.add_argument("--warmup", type=float, default=0.1) 63 | parser.add_argument("--weight_classes", type=int, default=1) 64 | parser.add_argument("--df", type=bool, default=True) 65 | parser.add_argument("--noise", type=float, default=0.0) 66 | parser.add_argument("--log_marker", type=str, default='') 67 | parser.add_argument("--modulation_starts", type=int, default=5) 68 | parser.add_argument("--modulation_ends", type=int, default=100) 69 | parser.add_argument("--zeta", type=float, default=0.01) 70 | 71 | 72 | 73 | def get_criterion(args): 74 | 75 | criterion = nn.CrossEntropyLoss() 76 | 77 | return criterion 78 | 79 | 80 | def get_optimizer(model, args): 81 | 82 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 83 | 84 | return optimizer 85 | 86 | 87 | def get_scheduler(optimizer, args): 88 | return optim.lr_scheduler.ReduceLROnPlateau( 89 | optimizer, "max", patience=args.lr_patience, verbose=True, factor=args.lr_factor 90 | ) 91 | 92 | 93 | def model_eval(i_epoch, data, model, args, criterion, store_preds=False): 94 | with torch.no_grad(): 95 | losses, preds, tgts = [], [], [] 96 | for batch in data: 97 | loss, out, tgt, cog_un = model_forward(i_epoch, model, args, criterion, batch,mode='eval') 98 | losses.append(loss.item()) 99 | 100 | if args.task_type == "multilabel": 101 | pred = torch.sigmoid(out).cpu().detach().numpy() > 0.5 102 | else: 103 | pred = torch.nn.functional.softmax(out, dim=1).argmax(dim=1).cpu().detach().numpy() 104 | 105 | preds.append(pred) 106 | tgt = tgt.cpu().detach().numpy() 107 | tgts.append(tgt) 108 | 109 | metrics = {"loss": np.mean(losses)} 110 | 111 | tgts = [l for sl in tgts for l in sl] 112 | preds = [l for sl in preds for l in sl] 113 | metrics["acc"] = accuracy_score(tgts, preds) 114 | metrics["micro_f1"] = f1_score(tgts, preds, average="weighted") 115 | 116 | 117 | if store_preds: 118 | store_preds_to_disk(tgts, preds, args) 119 | 120 | return metrics 121 | 122 | 123 | def model_forward(i_epoch, model, args, criterion, batch,txt_history=None,img_history=None,mode='eval'): 124 | txt, segment, mask, img, tgt,idx = batch 125 | # print(txt) 126 | # print(img) 127 | freeze_img = i_epoch < args.freeze_img 128 | freeze_txt = i_epoch < args.freeze_txt 129 | 130 | txt, img = txt.to(device), img.to(device) 131 | mask, segment = mask.to(device), segment.to(device) 132 | # out = model(txt, mask, segment, img) 133 | txt_img_logits, txt_logits, img_logits,txt_mu,txt_logvar,img_mu,img_logvar,mu,logvar,z, cog_un=model(txt, mask, segment, img) 134 | 135 | 136 | tgt = tgt.to(device) 137 | # loss = criterion(out, tgt) 138 | 139 | conloss=con_loss(txt_mu,torch.exp(txt_logvar),img_mu,torch.exp(img_logvar)) 140 | loss=totolloss(txt_img_logits, txt_logits,tgt,img_logits,txt_mu,txt_logvar,img_mu,img_logvar,mu,logvar,z) 141 | loss=loss+1e-5*KL_regular(txt_mu,txt_logvar,img_mu,img_logvar)+conloss*1e-3 142 | # loss=loss+0.01*conloss 143 | 144 | return loss,txt_img_logits,tgt, cog_un 145 | 146 | def train(args): 147 | 148 | set_seed(args.seed) 149 | args.savedir = os.path.join(args.savedir, args.name) 150 | os.makedirs(args.savedir, exist_ok=True) 151 | 152 | train_loader, val_loader, test_loaders = get_data_loaders(args) 153 | 
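    # build the late-fusion model, cross-entropy criterion, Adam optimizer, and plateau LR scheduler from the parsed args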
def train(args):
    set_seed(args.seed)
    args.savedir = os.path.join(args.savedir, args.name)
    os.makedirs(args.savedir, exist_ok=True)

    train_loader, val_loader, test_loaders = get_data_loaders(args)

    model = get_model(args)
    criterion = get_criterion(args)
    optimizer = get_optimizer(model, args)
    scheduler = get_scheduler(optimizer, args)

    current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
    dataset = "MVSA"
    if not os.path.exists(f"./log/{dataset}"):
        os.makedirs(f"./log/{dataset}")
    if args.log_marker != "":
        log_name = str(current_time) + "_" + args.log_marker
    else:
        log_name = str(current_time)
    log_path = os.path.join(f"./log/{dataset}", f"{log_name}.log")

    logger = create_logger(log_path, args)

    argsDict = args.__dict__
    for eachArg, value in argsDict.items():
        logger.info("{}:{}".format(eachArg, value))
    logger.info("==============================================")

    model.cuda()

    torch.save(args, os.path.join(args.savedir, "args.pt"))

    start_epoch, global_step, n_no_improve, best_metric = 0, 0, 0, -np.inf

    # Resume from the latest checkpoint if one exists in the save directory.
    if os.path.exists(os.path.join(args.savedir, "checkpoint.pt")):
        checkpoint = torch.load(os.path.join(args.savedir, "checkpoint.pt"))
        start_epoch = checkpoint["epoch"]
        n_no_improve = checkpoint["n_no_improve"]
        best_metric = checkpoint["best_metric"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

    logger.info("Training..")
    txt_history = History(len(train_loader.dataset))
    img_history = History(len(train_loader.dataset))

    for i_epoch in range(start_epoch, args.max_epochs):
        train_losses = []
        model.train()
        optimizer.zero_grad()
        preds, tgts = [], []
        for batch in tqdm(train_loader, total=len(train_loader)):
            loss, out, tgt, cog_uncertainty_dict = model_forward(
                i_epoch, model, args, criterion, batch, txt_history, img_history, mode="train"
            )
            train_losses.append(loss.item())

            loss.backward()

            # Uncertainty-aware gradient modulation: each unimodal classifier's
            # gradients are rescaled by a coefficient derived from the *other*
            # branch's mean uncertainty ('l' and 'v' index the two unimodal branches).
            if args.modulation_starts <= i_epoch <= args.modulation_ends:
                coeff_l = args.zeta * cog_uncertainty_dict["l"].mean()
                coeff_v = args.zeta * cog_uncertainty_dict["v"].mean()
                for name, parms in model.named_parameters():
                    if parms.grad is None:
                        continue
                    if "txtclf" in name:
                        parms.grad = parms.grad * (1 + coeff_v)
                    elif "imgclf" in name:
                        parms.grad = parms.grad * (1 + coeff_l)

            global_step += 1
            if global_step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            pred = torch.nn.functional.softmax(out, dim=1).argmax(dim=1).cpu().detach().numpy()

            preds.append(pred)
            tgt = tgt.cpu().detach().numpy()
            tgts.append(tgt)

        tgts = [l for sl in tgts for l in sl]
        preds = [l for sl in preds for l in sl]
        train_acc = accuracy_score(tgts, preds)

        model.eval()
        metrics = model_eval(i_epoch, val_loader, model, args, criterion)
        logger.info("Train Loss: {:.5f} Acc: {:.5f}".format(np.mean(train_losses), train_acc))
        log_metrics("Val", metrics, args, logger)

        tuning_metric = (
            metrics["micro_f1"] if args.task_type == "multilabel" else metrics["acc"]
        )
        scheduler.step(tuning_metric)
        is_improvement = tuning_metric > best_metric
        if is_improvement:
            best_metric = tuning_metric
            n_no_improve = 0
        else:
            n_no_improve += 1

        save_checkpoint(
            {
                "epoch": i_epoch + 1,
                # Save the parameters rather than the full module so that
                # load_state_dict() above can restore the checkpoint.
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "n_no_improve": n_no_improve,
                "best_metric": best_metric,
            },
            is_improvement,
            args.savedir,
        )

    model.eval()
    for test_name, test_loader in test_loaders.items():
        test_metrics = model_eval(
            np.inf, test_loader, model, args, criterion, store_preds=True
        )
        log_metrics(f"Test - {test_name}", test_metrics, args, logger)
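

# Editor's illustration (a minimal sketch of the modulation in train(); the numbers
# are hypothetical): with the default zeta = 0.01 and a mean unimodal uncertainty of
# 0.5, the opposite branch's gradients are multiplied by 1 + 0.01 * 0.5 = 1.005, so
# the correction nudges rather than overwhelms the raw gradients.
def _modulation_scale_sketch(zeta=0.01, mean_uncertainty=0.5):
    return 1 + zeta * mean_uncertainty  # -> 1.005 for the defaults above

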
def cli_main():
    parser = argparse.ArgumentParser(description="Train Models")
    get_args(parser)
    args, remaining_args = parser.parse_known_args()
    assert remaining_args == [], remaining_args
    train(args)


if __name__ == "__main__":
    import warnings

    warnings.filterwarnings("ignore")

    cli_main()
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Contrastive_loss(nn.Module):
    # NB: duplicates the Contrastive_loss class in contrast_loss.py.
    def __init__(self, tau):
        super(Contrastive_loss, self).__init__()
        self.tau = tau

    def sim(self, z1: torch.Tensor, z2: torch.Tensor):
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())

    def semi_loss(self, z1: torch.Tensor, z2: torch.Tensor):
        f = lambda x: torch.exp(x / self.tau)
        refl_sim = f(self.sim(z1, z1))     # intra-view ("reflexive") similarities
        between_sim = f(self.sim(z1, z2))  # cross-view similarities

        return -torch.log(between_sim.diag() / (refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag()))

    def forward(self, z1: torch.Tensor, z2: torch.Tensor, mean: bool = True):
        l1 = self.semi_loss(z1, z2)
        l2 = self.semi_loss(z2, z1)
        ret = (l1 + l2) * 0.5
        ret = ret.mean() if mean else ret.sum()
        return ret


def totolloss(txt_img_logits, txt_logits, tgt, img_logits, txt_mu, txt_logvar, img_mu, img_logvar, mu, logvar, z):
    # Total objective: fusion cross-entropy plus down-weighted unimodal, fused, and
    # information-bottleneck regularisers (the unimodal logits are currently unused
    # here). The KL terms assume logvar = log(sigma^2), the standard VAE convention.
    txt_kl_loss = -(1 + txt_logvar - txt_mu.pow(2) - txt_logvar.exp()) / 2
    txt_kl_loss = txt_kl_loss.sum(dim=1).mean()

    img_kl_loss = -(1 + img_logvar - img_mu.pow(2) - img_logvar.exp()) / 2
    img_kl_loss = img_kl_loss.sum(dim=1).mean()

    kl_loss = -(1 + logvar - mu.pow(2) - logvar.exp()) / 2
    kl_loss = kl_loss.sum(dim=1).mean()
    IB_loss = F.cross_entropy(z, tgt)

    fusion_cls_loss = F.cross_entropy(txt_img_logits, tgt)

    totol_loss = fusion_cls_loss + 1e-3 * kl_loss + 1e-3 * txt_kl_loss + 1e-3 * img_kl_loss + 1e-3 * IB_loss
    return totol_loss


def KL_regular(mu_1, logvar_1, mu_2, logvar_2):
    # Closed-form KL( N(mu_1, var_1) || N(mu_2, var_2) ) with var_i = exp(logvar_i),
    # following the same log-variance convention as totolloss above.
    var_1 = torch.exp(logvar_1)
    var_2 = torch.exp(logvar_2)
    KL_loss = 0.5 * (logvar_2 - logvar_1) + (var_1 + (mu_1 - mu_2).pow(2)) / (2 * var_2) - 0.5
    KL_loss = KL_loss.sum(dim=1).mean()
    return KL_loss
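

# Editor's illustration (an optional sanity check; a minimal sketch assuming only
# torch.distributions): KL_regular's closed form can be verified numerically.
# Normal takes the standard deviation as its scale, i.e. exp(0.5 * logvar).
def _kl_regular_check(mu_1, logvar_1, mu_2, logvar_2):
    from torch.distributions import Normal, kl_divergence
    p = Normal(mu_1, torch.exp(0.5 * logvar_1))
    q = Normal(mu_2, torch.exp(0.5 * logvar_2))
    return kl_divergence(p, q).sum(dim=1).mean()  # should match KL_regular(...)

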
def reparameterise(mu, std):
    """
    mu : [batch_size, z_dim]
    std : [batch_size, z_dim]
    """
    # Draw epsilon from a standard normal, then shift and scale it.
    eps = torch.randn_like(std)
    return mu + std * eps


def con_loss(txt_mu, txt_std, img_mu, img_std):
    # Build two stochastic "views" per modality via the reparameterisation trick
    # and pull them together with the contrastive loss (temperature tau = 0.5).
    # The loops only guard against the degenerate case of two identical draws.
    Conloss = Contrastive_loss(0.5)
    while True:
        t_z1 = reparameterise(txt_mu, txt_std)
        t_z2 = reparameterise(txt_mu, txt_std)
        if not torch.equal(t_z1, t_z2):
            break
    while True:
        i_z1 = reparameterise(img_mu, img_std)
        i_z2 = reparameterise(img_mu, img_std)
        if not torch.equal(i_z1, i_z2):
            break

    loss_t = Conloss(t_z1, t_z2)
    loss_i = Conloss(i_z1, i_z2)

    return loss_t + loss_i
--------------------------------------------------------------------------------