├── MSDN.py ├── MSDN_awa2.py ├── MSDN_cub.py ├── MSDN_sun.py ├── README.md ├── Test_AWA2.py ├── Test_CUB.py ├── Test_SUN.py ├── config ├── test_AWA2.json ├── test_CUB.json └── test_SUN.json ├── core ├── 1 ├── AWA2DataLoader.py ├── CUBDataLoader.py ├── MSDN.py ├── SUNDataLoader.py ├── helper_MSDN_AWA2.py ├── helper_MSDN_CUB.py └── helper_MSDN_SUN.py ├── data ├── AWA2 ├── AWA2.pkl ├── CUB ├── CUB.pkl ├── SUN └── SUN.pkl ├── dataset.py ├── global_setting.py ├── images ├── t-v │ ├── Acadian_Flycatcher_0008_795599.jpg │ ├── American_Goldfinch_0092_32910.jpg │ ├── Canada_Warbler_0117_162394.jpg │ ├── Carolina_Wren_0006_186742.jpg │ ├── Elegant_Tern_0085_151091.jpg │ ├── European_Goldfinch_0025_794647.jpg │ ├── Florida_Jay_0008_64482.jpg │ ├── Fox_Sparrow_0025_114555.jpg │ ├── Grasshopper_Sparrow_0053_115991.jpg │ ├── Grasshopper_Sparrow_0107_116286.jpg │ ├── Gray_Crowned_Rosy_Finch_0036_797287.jpg │ ├── Vesper_Sparrow_0090_125690.jpg │ ├── Western_Gull_0058_53882.jpg │ ├── White_Throated_Sparrow_0128_128956.jpg │ ├── Winter_Wren_0118_189805.jpg │ └── Yellow_Breasted_Chat_0044_22106.jpg ├── tsne │ ├── awa2_tsne_test_unseen.png │ ├── awa2_tsne_train_seen.png │ ├── cub_tsne_test_unseen.png │ ├── cub_tsne_train_seen.png │ ├── sun_tsne_test_unseen.png │ └── sun_tsne_train_seen.png └── v-t │ ├── 1 │ ├── Acadian_Flycatcher_0008_795599.jpg │ ├── American_Goldfinch_0092_32910.jpg │ ├── Canada_Warbler_0117_162394.jpg │ ├── Carolina_Wren_0006_186742.jpg │ ├── Elegant_Tern_0085_151091.jpg │ ├── European_Goldfinch_0025_794647.jpg │ ├── Vesper_Sparrow_0090_125690.jpg │ ├── Western_Gull_0058_53882.jpg │ ├── White_Throated_Sparrow_0128_128956.jpg │ ├── Winter_Wren_0118_189805.jpg │ └── Yellow_Breasted_Chat_0044_22106.jpg ├── requirements.txt └── utils.py /MSDN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchvision 6 | 7 | 8 | class MSDN(nn.Module): 9 | ##### 10 | # einstein sum notation 11 | # b: Batch size \ f: dim feature=2048 \ v: dim w2v=300 \ r: number of region=196 \ k: number of classes 12 | # i: number of attribute=312 13 | ##### 14 | def __init__(self, config, normalize_V = False, normalize_F = False, is_conservative = False, 15 | prob_prune=0,uniform_att_1 = False,uniform_att_2 = False, is_conv = False, 16 | is_bias = False,bias = 1,non_linear_act=False, 17 | loss_type = 'CE',non_linear_emb = False, 18 | is_sigmoid = False): 19 | super(MSDN, self).__init__() 20 | self.config = config 21 | self.dim_f = config.dim_f 22 | self.dim_v = config.dim_v 23 | self.nclass = config.num_class 24 | self.dim_att = config.num_attribute 25 | self.hidden = self.dim_att//2 26 | self.non_linear_act = non_linear_act 27 | self.loss_type = loss_type 28 | self.w1 = config.w1 29 | self.w2 = config.w2 30 | 31 | self.att = nn.Parameter(torch.empty( 32 | self.nclass, self.dim_att), requires_grad=False) 33 | self.V = nn.Parameter(torch.empty( 34 | self.dim_att, self.dim_v), requires_grad=True) 35 | 36 | self.W_1 = nn.Parameter(nn.init.normal_( 37 | torch.empty(self.dim_v, self.dim_f)), requires_grad=True) 38 | self.W_2 = nn.Parameter(nn.init.zeros_( 39 | torch.empty(self.dim_v, self.dim_f)), requires_grad=True) 40 | self.W_3 = nn.Parameter(nn.init.zeros_( 41 | torch.empty(self.dim_v, self.dim_f)), requires_grad=True) 42 | 43 | self.W_1_1 = nn.Parameter(nn.init.zeros_( 44 | torch.empty(self.dim_f, self.dim_v)), requires_grad=True) 45 | self.W_2_1 = nn.Parameter(nn.init.zeros_( 
46 | torch.empty(self.dim_v, self.dim_f)), requires_grad=True) 47 | self.W_3_1 = nn.Parameter(nn.init.zeros_( 48 | torch.empty(self.dim_f, self.dim_v)), requires_grad=True) 49 | 50 | self.normalize_V = normalize_V 51 | self.normalize_F = normalize_F 52 | self.is_conservative = is_conservative 53 | self.is_conv = is_conv 54 | self.is_bias = is_bias 55 | 56 | if is_bias: 57 | self.bias = nn.Parameter(torch.tensor(1), requires_grad=False) 58 | self.mask_bias = nn.Parameter(torch.empty( 59 | 1, self.nclass), requires_grad=False) 60 | 61 | self.prob_prune = nn.Parameter(torch.tensor(prob_prune),requires_grad = False) 62 | 63 | self.uniform_att_1 = uniform_att_1 64 | self.uniform_att_2 = uniform_att_2 65 | 66 | self.non_linear_emb = non_linear_emb 67 | if self.non_linear_emb: 68 | self.emb_func = torch.nn.Sequential( 69 | torch.nn.Linear(self.dim_att, self.dim_att//2), 70 | torch.nn.ReLU(), 71 | torch.nn.Linear(self.dim_att//2, 1),) 72 | self.is_sigmoid = is_sigmoid 73 | 74 | # bakcbone 75 | resnet101 = torchvision.models.resnet101(pretrained=True) 76 | self.resnet101 = nn.Sequential(*list(resnet101.children())[:-2]) 77 | 78 | 79 | def compute_V(self): 80 | if self.normalize_V: 81 | V_n = F.normalize(self.V) 82 | else: 83 | V_n = self.V 84 | return V_n 85 | 86 | def get_global_feature(self, x): 87 | 88 | N, C, W, H = x.shape 89 | global_feat = F.avg_pool2d(x, kernel_size=(W, H)) 90 | global_feat = global_feat.view(N, C) 91 | 92 | return global_feat 93 | 94 | 95 | def forward(self, imgs): 96 | 97 | Fs = self.resnet101(imgs) 98 | 99 | if self.is_conv: 100 | Fs = self.conv1(Fs) 101 | Fs = self.conv1_bn(Fs) 102 | Fs = F.relu(Fs) 103 | 104 | shape = Fs.shape 105 | 106 | visualf_ori = self.get_global_feature(Fs) 107 | 108 | 109 | Fs = Fs.reshape(shape[0],shape[1],shape[2]*shape[3]) # batch x 2048 x 49 110 | 111 | R = Fs.size(2) # 49 112 | B = Fs.size(0) # batch 113 | V_n = self.compute_V() # 312x300 114 | 115 | if self.normalize_F and not self.is_conv: # true 116 | Fs = F.normalize(Fs,dim = 1) 117 | 118 | 119 | ##########################Text-Image################################ 120 | 121 | ## Compute attribute score on each image region 122 | S = torch.einsum('iv,vf,bfr->bir',V_n,self.W_1,Fs) # batchx312x49 123 | 124 | if self.is_sigmoid: 125 | S=torch.sigmoid(S) 126 | 127 | ## Ablation setting 128 | A_b = Fs.new_full((B,self.dim_att,R),1/R) 129 | A_b_p = self.att.new_full((B,self.dim_att),fill_value = 1) 130 | S_b_p = torch.einsum('bir,bir->bi',A_b,S) 131 | S_b_pp = torch.einsum('ki,bi,bi->bk',self.att,A_b_p,S_b_p) 132 | ## 133 | 134 | ## compute Dense Attention 135 | A = torch.einsum('iv,vf,bfr->bir',V_n,self.W_2,Fs) 136 | A = F.softmax(A,dim = -1) 137 | 138 | F_p = torch.einsum('bir,bfr->bif',A,Fs) 139 | if self.uniform_att_1: # false 140 | S_p = torch.einsum('bir,bir->bi',A_b,S) 141 | else: 142 | S_p = torch.einsum('bir,bir->bi',A,S) 143 | 144 | if self.non_linear_act: # false 145 | S_p = F.relu(S_p) 146 | ## 147 | 148 | ## compute Attention over Attribute 149 | A_p = torch.einsum('iv,vf,bif->bi',V_n,self.W_3,F_p) #eq. 
6 150 | A_p = torch.sigmoid(A_p) 151 | ## 152 | 153 | if self.uniform_att_2: # true 154 | S_pp = torch.einsum('ki,bi,bi->bik',self.att,A_b_p,S_p) 155 | else: 156 | # S_pp = torch.einsum('ki,bi,bi->bik',self.att,A_p,S_p) 157 | S_pp = torch.einsum('ki,bi->bik',self.att,S_p) 158 | 159 | S_attr = torch.einsum('bi,bi->bi',A_b_p,S_p) 160 | 161 | if self.non_linear_emb: 162 | S_pp = torch.transpose(S_pp,2,1) #[bki] <== [bik] 163 | S_pp = self.emb_func(S_pp) #[bk1] <== [bki] 164 | S_pp = S_pp[:,:,0] #[bk] <== [bk1] 165 | else: 166 | S_pp = torch.sum(S_pp,axis=1) #[bk] <== [bik] 167 | 168 | # augment prediction scores by adding a margin of 1 to unseen classes and -1 to seen classes 169 | if self.is_bias: 170 | self.vec_bias = self.mask_bias*self.bias 171 | S_pp = S_pp + self.vec_bias 172 | 173 | ## spatial attention supervision 174 | Pred_att = torch.einsum('iv,vf,bif->bi',V_n,self.W_1,F_p) 175 | package1 = {'S_pp':S_pp,'Pred_att':Pred_att,'S_p':S_p,'S_b_pp':S_b_pp,'A_p':A_p,'A':A,'S_attr':S_attr,'visualf_ori':visualf_ori,'visualf_a_v':F_p} 176 | 177 | ##########################Image-Text################################ 178 | 179 | ## Compute attribute score on each image region 180 | 181 | S = torch.einsum('bfr,fv,iv->bri',Fs,self.W_1_1,V_n) # batchx49x312 182 | # S = torch.einsum('iv,vf,bfr->bir',V_n,self.W_1_1,Fs) 183 | if self.is_sigmoid: 184 | S=torch.sigmoid(S) 185 | 186 | 187 | 188 | ## compute Dense Attention 189 | A = torch.einsum('iv,vf,bfr->bir',V_n,self.W_2_1,Fs) 190 | A = F.softmax(A,dim = 1) 191 | 192 | v_a = torch.einsum('bir,iv->brv',A,V_n) 193 | 194 | S_p = torch.einsum('bir,bri->bi',A,S) 195 | 196 | if self.non_linear_act: # false 197 | S_p = F.relu(S_p) 198 | 199 | 200 | 201 | S_pp = torch.einsum('ki,bi->bik',self.att,S_p) 202 | 203 | S_attr = 0#torch.einsum('bi,bi->bi',A_b_p,S_p) 204 | 205 | if self.non_linear_emb: 206 | S_pp = torch.transpose(S_pp,2,1) #[bki] <== [bik] 207 | S_pp = self.emb_func(S_pp) #[bk1] <== [bki] 208 | S_pp = S_pp[:,:,0] #[bk] <== [bk1] 209 | else: 210 | S_pp = torch.sum(S_pp,axis=1) #[bk] <== [bik] 211 | 212 | # augment prediction scores by adding a margin of 1 to unseen classes and -1 to seen classes 213 | if self.is_bias: 214 | self.vec_bias = self.mask_bias*self.bias 215 | S_pp = S_pp + self.vec_bias 216 | 217 | ## spatial attention supervision 218 | package2 = {'S_pp':S_pp,'visualf_v_a':v_a, 'S_p':S_p, 'A':A} 219 | 220 | package = {'embed': self.w1 * package1['S_pp']+self.w2 * package2['S_pp']} 221 | 222 | return package 223 | -------------------------------------------------------------------------------- /MSDN_awa2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | import pandas as pd 5 | from core.MSDN import MSDN 6 | from core.AWA2DataLoader import AWA2DataLoader 7 | from core.helper_MSDN_AWA2 import eval_zs_gzsl 8 | from global_setting import NFS_path 9 | import importlib 10 | import pdb 11 | import numpy as np 12 | 13 | idx_GPU = 0 14 | device = torch.device("cuda:{}".format(idx_GPU) if torch.cuda.is_available() else "cpu") 15 | dataloader = AWA2DataLoader(NFS_path,device) 16 | torch.backends.cudnn.benchmark = True 17 | 18 | def get_lr(optimizer): 19 | lr = [] 20 | for param_group in optimizer.param_groups: 21 | lr.append(param_group['lr']) 22 | return lr 23 | 24 | seed = 87778 25 | # seed = 6379 # for czsl 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | np.random.seed(seed) 29 | 30 | batch_size = 50 31 | 
nepoches = 50 32 | niters = dataloader.ntrain * nepoches//batch_size 33 | dim_f = 2048 34 | dim_v = 300 35 | init_w2v_att = dataloader.w2v_att 36 | att = dataloader.att#dataloader.normalize_att# 37 | att[att<0] = 0 38 | normalize_att = dataloader.normalize_att 39 | #assert (att.min().item() == 0 and att.max().item() == 1) 40 | 41 | trainable_w2v = True 42 | lambda_ = 0.12#0.1 ,0.12 for T-I in GZSL, 0.3 for T-I in CZSL, 0.13 for I-T,0.3 for baseline 43 | bias = 0 44 | prob_prune = 0 45 | uniform_att_1 = False 46 | uniform_att_2 = False 47 | 48 | seenclass = dataloader.seenclasses 49 | unseenclass = dataloader.unseenclasses 50 | desired_mass = 1#unseenclass.size(0)/(seenclass.size(0)+unseenclass.size(0)) 51 | report_interval = niters//nepoches#10000//batch_size# 52 | 53 | model = MSDN(dim_f,dim_v,init_w2v_att,att,normalize_att, 54 | seenclass,unseenclass, 55 | lambda_, 56 | trainable_w2v,normalize_V=True,normalize_F=True,is_conservative=True, 57 | uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2, 58 | prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False, 59 | is_bias=True) 60 | model.to(device) 61 | 62 | setup = {'pmp':{'init_lambda':0.1,'final_lambda':0.1,'phase':0.8}, 63 | 'desired_mass':{'init_lambda':-1,'final_lambda':-1,'phase':0.8}} 64 | print(setup) 65 | 66 | params_to_update = [] 67 | params_names = [] 68 | for name,param in model.named_parameters(): 69 | if param.requires_grad == True: 70 | params_to_update.append(param) 71 | params_names.append(name) 72 | print("\t",name) 73 | #%% 74 | lr = 0.0001 75 | weight_decay = 0.0001#0.000#0.# 76 | momentum = 0.#0.# 77 | #%% 78 | lr_separator = 1 79 | lr_factor = 1 80 | print('default lr {} {}x lr {}'.format(params_names[:lr_separator],lr_factor,params_names[lr_separator:])) 81 | optimizer = optim.RMSprop( params_to_update ,lr=lr,weight_decay=weight_decay, momentum=momentum) 82 | print('-'*30) 83 | print('learning rate {}'.format(lr)) 84 | print('trainable V {}'.format(trainable_w2v)) 85 | print('lambda_ {}'.format(lambda_)) 86 | print('optimized seen only') 87 | print('optimizer: RMSProp with momentum = {} and weight_decay = {}'.format(momentum,weight_decay)) 88 | print('-'*30) 89 | 90 | best_performance = [0,0,0] 91 | best_acc = 0 92 | for i in range(0,niters): 93 | model.train() 94 | optimizer.zero_grad() 95 | 96 | batch_label, batch_feature, batch_att = dataloader.next_batch(batch_size) 97 | 98 | out_package1, out_package2= model(batch_feature) 99 | 100 | in_package1 = out_package1 101 | in_package2 = out_package2 102 | in_package1['batch_label'] = batch_label 103 | in_package2['batch_label'] = batch_label 104 | 105 | 106 | out_package1=model.compute_loss(in_package1) 107 | out_package2=model.compute_loss(in_package2) 108 | loss,loss_CE,loss_cal = out_package1['loss']+out_package2['loss'],out_package1['loss_CE']+out_package2['loss_CE'],out_package1['loss_cal']+out_package2['loss_cal'] 109 | contrastive_loss=model.compute_contrastive_loss(in_package1, in_package2) 110 | 111 | loss=loss + 0.0000001*contrastive_loss 112 | 113 | loss.backward() 114 | optimizer.step() 115 | if i%report_interval==0: 116 | print('-'*30) 117 | acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(dataloader,model,device,bias_seen=-bias,bias_unseen=bias) 118 | 119 | if H > best_performance[2]: 120 | best_performance = [acc_novel, acc_seen, H] 121 | if acc_zs > best_acc: 122 | best_acc = acc_zs 123 | print('iter=%d, loss=%.3f, loss_CE=%.3f, loss_cal=%.3f, acc_unseen=%.3f, acc_seen=%.3f, H=%.3f, acc_zs=%.3f'%(i,loss.item(),loss_CE.item(),loss_cal.item(),best_performance[0],best_performance[1],best_performance[2],best_acc))
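# A minimal final-evaluation sketch (added for illustration, not part of the
# original script): it assumes the dataloader, model, device, and bias objects
# defined above and reuses the eval_zs_gzsl helper already imported at the top
# of this file, mirroring the call made inside the training loop.
acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(
    dataloader, model, device, bias_seen=-bias, bias_unseen=bias)
print('Final GZSL: S=%.3f U=%.3f H=%.3f | CZSL acc=%.3f'
      % (acc_seen, acc_novel, H, acc_zs))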
124 | -------------------------------------------------------------------------------- /MSDN_cub.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | import pandas as pd 5 | from core.MSDN import MSDN 6 | from core.CUBDataLoader import CUBDataLoader 7 | from core.helper_MSDN_CUB import eval_zs_gzsl 8 | # from global_setting import NFS_path 9 | import importlib 10 | import pdb 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | 14 | NFS_path = './' 15 | 16 | idx_GPU = 0 17 | device = torch.device("cuda:{}".format(idx_GPU) if torch.cuda.is_available() else "cpu") 18 | dataloader = CUBDataLoader(NFS_path,device,is_unsupervised_attr=False,is_balance=False) 19 | torch.backends.cudnn.benchmark = True 20 | 21 | def get_lr(optimizer): 22 | lr = [] 23 | for param_group in optimizer.param_groups: 24 | lr.append(param_group['lr']) 25 | return lr 26 | 27 | seed = 214#215# 28 | torch.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) 30 | np.random.seed(seed) 31 | 32 | batch_size = 50 33 | nepoches = 30#22 34 | niters = dataloader.ntrain * nepoches//batch_size 35 | dim_f = 2048 36 | dim_v = 300 37 | init_w2v_att = dataloader.w2v_att 38 | att = dataloader.att 39 | normalize_att = dataloader.normalize_att 40 | 41 | trainable_w2v = True 42 | lambda_ = 0.1#0.1 for GZSL, 0.18 for CZSL 43 | bias = 0 44 | prob_prune = 0 45 | uniform_att_1 = False 46 | uniform_att_2 = False 47 | 48 | seenclass = dataloader.seenclasses 49 | unseenclass = dataloader.unseenclasses 50 | desired_mass = 1 51 | report_interval = niters//nepoches 52 | 53 | model = MSDN(dim_f,dim_v,init_w2v_att,att,normalize_att, 54 | seenclass,unseenclass, 55 | lambda_, 56 | trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True, 57 | uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2, 58 | prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False, 59 | is_bias=True) 60 | model.to(device) 61 | 62 | setup = {'pmp':{'init_lambda':0.1,'final_lambda':0.1,'phase':0.8}, 63 | 'desired_mass':{'init_lambda':-1,'final_lambda':-1,'phase':0.8}} 64 | print(setup) 65 | #scheduler = Scheduler(model,niters,batch_size,report_interval,setup) 66 | 67 | params_to_update = [] 68 | params_names = [] 69 | for name,param in model.named_parameters(): 70 | if param.requires_grad == True: 71 | params_to_update.append(param) 72 | params_names.append(name) 73 | print("\t",name) 74 | #%% 75 | lr = 0.0001 76 | weight_decay = 0.0001#0.000#0.# 77 | momentum = 0.9#0.# 78 | #%% 79 | lr_separator = 1 80 | lr_factor = 1 81 | print('default lr {} {}x lr {}'.format(params_names[:lr_separator],lr_factor,params_names[lr_separator:])) 82 | optimizer = optim.RMSprop( params_to_update ,lr=lr,weight_decay=weight_decay, momentum=momentum) 83 | 84 | print('-'*30) 85 | print('learning rate {}'.format(lr)) 86 | print('trainable V {}'.format(trainable_w2v)) 87 | print('lambda_ {}'.format(lambda_)) 88 | print('optimized seen only') 89 | print('optimizer: RMSProp with momentum = {} and weight_decay = {}'.format(momentum,weight_decay)) 90 | print('-'*30) 91 | 92 | iter_x = [] 93 | best_H = [] 94 | best_ACC =[] 95 | 96 | best_performance = [0,0,0] 97 | best_acc = 0 98 | for i in range(0,niters): 99 | model.train() 100 | optimizer.zero_grad() 101 | 102 | batch_label, batch_feature, batch_att = dataloader.next_batch(batch_size)
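# NOTE (descriptive comment, not part of the original script):
# model(batch_feature) below returns one score package per branch:
# out_package1 from the attribute->visual (text-to-image) branch and
# out_package2 from the visual->attribute (image-to-text) branch. Each
# carries class scores 'S_pp' of shape [batch, num_class]; both branches
# receive the same CE + self-calibration loss, and compute_contrastive_loss
# ties the two score sets together (weighted 0.001 for CUB here, vs 1e-7
# for AWA2 and 0.01 for SUN).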
103 | 104 | out_package1, out_package2= model(batch_feature) 105 | 106 | in_package1 = out_package1 107 | in_package2 = out_package2 108 | in_package1['batch_label'] = batch_label 109 | in_package2['batch_label'] = batch_label 110 | 111 | out_package1=model.compute_loss(in_package1) 112 | out_package2=model.compute_loss(in_package2) 113 | loss,loss_CE,loss_cal = out_package1['loss']+out_package2['loss'],out_package1['loss_CE']+out_package2['loss_CE'],out_package1['loss_cal']+out_package2['loss_cal'] 114 | contrastive_loss1=model.compute_contrastive_loss(in_package1, in_package2) 115 | 116 | loss=loss + 0.001*contrastive_loss1##0.001 117 | 118 | 119 | loss.backward() 120 | optimizer.step() 121 | if i%report_interval==0: 122 | print('-'*30) 123 | acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(dataloader,model,device,bias_seen=-bias,bias_unseen=bias) 124 | 125 | if H > best_performance[2]: 126 | best_performance = [acc_novel, acc_seen, H] 127 | if acc_zs > best_acc: 128 | best_acc = acc_zs 129 | print('iter=%d, loss=%.3f, loss_CE=%.3f, loss_cal=%.3f, acc_unseen=%.3f, acc_seen=%.3f, H=%.3f, acc_zs=%.3f'%(i,loss.item(),loss_CE.item(),loss_cal.item(),best_performance[0],best_performance[1],best_performance[2],best_acc)) -------------------------------------------------------------------------------- /MSDN_sun.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | import pandas as pd 5 | from core.MSDN import MSDN 6 | from core.SUNDataLoader import SUNDataLoader 7 | from core.helper_MSDN_SUN import eval_zs_gzsl 8 | from global_setting import NFS_path 9 | import importlib 10 | import pdb 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | 14 | 15 | idx_GPU = 0 16 | device = torch.device("cuda:{}".format(idx_GPU) if torch.cuda.is_available() else "cpu") 17 | dataloader = SUNDataLoader(NFS_path,device,is_scale=False,is_balance = True) 18 | torch.backends.cudnn.benchmark = True 19 | 20 | seed = 2339#214 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed) 24 | 25 | print('Randomize seed {}'.format(seed)) 26 | #%% 27 | batch_size = 50 28 | nepoches = 70 29 | niters = dataloader.ntrain * nepoches//batch_size 30 | dim_f = 2048 31 | dim_v = 300 32 | init_w2v_att = dataloader.w2v_att 33 | att = dataloader.att#dataloader.normalize_att# 34 | normalize_att = dataloader.normalize_att 35 | #assert (att.min().item() == 0 and att.max().item() == 1) 36 | 37 | trainable_w2v = True 38 | lambda_ = 0.0001 #0.0 39 | bias = 0.
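# NOTE (descriptive comment, not part of the original script):
# lambda_ weights the self-calibration loss (see compute_loss_Self_Calibrate
# in core/MSDN.py), which encourages probability mass on unseen classes:
#   loss_cal = -log( mean_over_batch( sum_{c in unseen} softmax(S_pp)[c] ) )
# SUN uses a much smaller weight (1e-4) than CUB (0.1) or AWA2 (0.12).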
40 | prob_prune = 0 41 | uniform_att_1 = False 42 | uniform_att_2 = True 43 | 44 | seenclass = dataloader.seenclasses 45 | unseenclass = dataloader.unseenclasses 46 | desired_mass = 1#unseenclass.size(0)/(seenclass.size(0)+unseenclass.size(0)) 47 | report_interval = niters//nepoches 48 | #%% 49 | model = MSDN(dim_f,dim_v,init_w2v_att,att,normalize_att, 50 | seenclass,unseenclass, 51 | lambda_, 52 | trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True, 53 | uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2, 54 | prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False, 55 | is_bias=True,non_linear_act=False) 56 | model.to(device) 57 | #%% 58 | params_to_update = [] 59 | for name,param in model.named_parameters(): 60 | if param.requires_grad == True: 61 | params_to_update.append(param) 62 | print("\t",name) 63 | #%% 64 | lr = 0.0001 65 | weight_decay = 0.0001#0.000#0.# 66 | momentum = 0.9#0.# 67 | optimizer = optim.RMSprop( params_to_update ,lr=lr,weight_decay=weight_decay, momentum=momentum) 68 | #%% 69 | print('-'*30) 70 | print('learning rate {}'.format(lr)) 71 | print('trainable V {}'.format(trainable_w2v)) 72 | print('lambda_ {}'.format(lambda_)) 73 | print('optimized seen only') 74 | print('optimizer: RMSProp with momentum = {} and weight_decay = {}'.format(momentum,weight_decay)) 75 | print('-'*30) 76 | 77 | iter_x = [] 78 | best_H = [] 79 | best_ACC =[] 80 | best_performance = [0,0,0] 81 | best_acc = 0 82 | for i in range(0,niters): 83 | model.train() 84 | optimizer.zero_grad() 85 | 86 | batch_label, batch_feature, batch_att = dataloader.next_batch(batch_size) 87 | 88 | out_package1, out_package2= model(batch_feature) 89 | 90 | in_package1 = out_package1 91 | in_package2 = out_package2 92 | in_package1['batch_label'] = batch_label 93 | in_package2['batch_label'] = batch_label 94 | 95 | out_package1=model.compute_loss(in_package1) 96 | out_package2=model.compute_loss(in_package2) 97 | loss,loss_CE,loss_cal = out_package1['loss']+out_package2['loss'],out_package1['loss_CE']+out_package2['loss_CE'],out_package1['loss_cal']+out_package2['loss_cal'] 98 | contrastive_loss=model.compute_contrastive_loss(in_package1, in_package2) 99 | loss=loss + 0.01*contrastive_loss 100 | 101 | loss.backward() 102 | optimizer.step() 103 | if i%report_interval==0: 104 | print('-'*30) 105 | acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(dataloader,model,device,bias_seen=-bias,bias_unseen=bias) 106 | 107 | if H > best_performance[2]: 108 | best_performance = [acc_novel, acc_seen, H] 109 | if acc_zs > best_acc: 110 | best_acc = acc_zs 111 | print('iter=%d, loss=%.3f, loss_CE=%.3f, loss_cal=%.3f, acc_unseen=%.3f, acc_seen=%.3f, H=%.3f, acc_zs=%.3f'%(i,loss.item(),loss_CE.item(),loss_cal.item(),best_performance[0],best_performance[1],best_performance[2],best_acc)) 112 | if i%500==0: 113 | iter_x.append(i) 114 | best_H.append(best_performance[2]) 115 | best_ACC.append(best_acc) 116 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MSDN 2 | 3 | This is the complete code for the paper "**MSDN: Mutually Semantic Distillation Network for Zero-Shot Learning**", accepted to *CVPR'22*. This repository includes the following materials for testing and checking the results reported in the paper: 4 | 5 | 1. The training code 6 | 2. The testing code 7 | 3. The trained models
8 | 9 | ### Requirements 10 | The code implementation of **MSDN** is mainly based on [PyTorch](https://pytorch.org/). All of our experiments were run and tested with Python 3.8.8. To install all required dependencies: 11 | ``` 12 | $ pip install -r requirements.txt 13 | ``` 14 | 15 | ## Training 16 | 17 | We trained the model on three popular ZSL benchmarks: [CUB](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), [SUN](http://cs.brown.edu/~gmpatter/sunattributes.html) and [AWA2](http://cvml.ist.ac.at/AwA2/), following the data split of [xlsa17](http://datasets.d2.mpi-inf.mpg.de/xian/xlsa17.zip). 18 | Please follow [TransZero](https://github.com/shiming-chen/TransZero) to prepare the datasets and extract visual features. 19 | 20 | ### Training Script 21 | 22 | ``` 23 | $ python MSDN_cub.py 24 | $ python MSDN_sun.py 25 | $ python MSDN_awa2.py 26 | ``` 27 | **Note**: Please load the corresponding settings when targeting the CZSL task. 28 | 29 | ### Results 30 | We also upload the trained models in the [test branch](https://github.com/shiming-chen/MSDN). Below are the results of our released models under various evaluation protocols on the three datasets, in both the conventional ZSL (CZSL) and generalized ZSL (GZSL) settings. 31 | 32 | | Dataset | Acc(CZSL) | U(GZSL) | S(GZSL) | H(GZSL) | 33 | | :-----: | :-----: | :-----: | :-----: | :-----: | 34 | | CUB | 76.1 | 68.7 | 67.5 | 68.1 | 35 | | SUN | 65.8 | 52.2 | 34.2 | 41.3 | 36 | | AWA2 | 70.1 | 62.0 | 74.5 | 67.7 | 37 | 38 | **Note**: All of the above results were obtained on a server with an AMD Ryzen 7 5800X CPU and an NVIDIA RTX A6000 GPU. 39 | 40 | ## Testing 41 | 42 | ### Preparing Dataset and Model 43 | 44 | We provide trained models ([Google Drive](https://drive.google.com/drive/folders/1IBGfPXleu4E2BLTI4TlUL1jYSuwahbYC?usp=sharing)) for three different datasets, [CUB](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), [SUN](http://cs.brown.edu/~gmpatter/sunattributes.html) and [AWA2](http://cvml.ist.ac.at/AwA2/), in both the CZSL and GZSL settings. You can download the model files as well as the corresponding datasets, and organize them as follows: 45 | ``` 46 | . 47 | ├── saved_model 48 | │ ├── CUB_MSDN_CZSL.pth 49 | │ ├── CUB_MSDN_GZSL.pth 50 | │ ├── SUN_MSDN_CZSL.pth 51 | │ ├── SUN_MSDN_GZSL.pth 52 | │ ├── AWA2_MSDN_CZSL.pth 53 | │ └── AWA2_MSDN_GZSL.pth 54 | ├── data 55 | │ ├── CUB/ 56 | │ ├── SUN/ 57 | │ └── AWA2/ 58 | └── ··· 59 | ``` 60 | 61 | 62 | ### Testing Script 63 | Run the following commands to test **MSDN** on the different datasets: 64 | 65 | CUB Dataset: 66 | ``` 67 | $ python Test_CUB.py 68 | ``` 69 | SUN Dataset: 70 | ``` 71 | $ python Test_SUN.py 72 | ``` 73 | AWA2 Dataset: 74 | ``` 75 | $ python Test_AWA2.py 76 | ``` 77 | 78 | ### Results 79 | Results of our released models under various evaluation protocols on the three datasets, in both the CZSL and GZSL settings: 80 | 81 | | Dataset | Acc(CZSL) | U(GZSL) | S(GZSL) | H(GZSL) | 82 | | :-----: | :-----: | :-----: | :-----: | :-----: | 83 | | CUB | 76.1 | 68.7 | 67.5 | 68.1 | 84 | | SUN | 65.8 | 52.2 | 34.2 | 41.3 | 85 | | AWA2 | 70.1 | 62.0 | 74.5 | 67.7 | 86 | 87 | **Note**: All of the above results were obtained on a server with an AMD Ryzen 7 5800X CPU and an NVIDIA RTX A6000 GPU.
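For reference, H is the harmonic mean of the seen accuracy S and the unseen accuracy U, as is standard in the GZSL protocol. A quick sanity check of the tables above (plain Python; the inputs are the U/S values reported for each dataset):
```
# H = 2 * U * S / (U + S), the harmonic mean used in the GZSL protocol
def harmonic_mean(u, s):
    return 2 * u * s / (u + s)

print(round(harmonic_mean(68.7, 67.5), 1))  # CUB:  68.1
print(round(harmonic_mean(52.2, 34.2), 1))  # SUN:  41.3
print(round(harmonic_mean(62.0, 74.5), 1))  # AWA2: 67.7
```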
88 | 89 | ## Citation 90 | If this work is helpful for you, please cite our paper. 91 | 92 | ``` 93 | @InProceedings{Chen2022MSDN, 94 | author = {Chen, Shiming and Hong, Ziming and Xie, Guo-Sen and Yang, Wenhan and Peng, Qinmu and Wang, Kai and Zhao, Jian and You, Xinge}, 95 | title = {MSDN: Mutually Semantic Distillation Network for Zero-Shot Learning}, 96 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 97 | year = {2022} 98 | } 99 | ``` 100 | 101 | 102 | ## References 103 | Parts of our code are based on: 104 | * [hbdat/cvpr20_DAZLE](https://github.com/hbdat/cvpr20_DAZLE) 105 | * [shiming-chen/TransZero](https://github.com/shiming-chen/TransZero) 106 | -------------------------------------------------------------------------------- /Test_AWA2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from MSDN import MSDN 3 | from dataset import UNIDataloader 4 | import argparse 5 | import json 6 | from utils import evaluation 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--config', type=str, default='config/test_AWA2.json') 11 | config = parser.parse_args() 12 | with open(config.config, 'r') as f: 13 | config.__dict__ = json.load(f) 14 | 15 | dataloader = UNIDataloader(config) 16 | 17 | model_gzsl = MSDN(config, normalize_V=True, normalize_F=True, is_conservative=True, 18 | uniform_att_1=False, uniform_att_2=False, 19 | is_conv=False, is_bias=True).to(config.device) 20 | model_dict = model_gzsl.state_dict() 21 | saved_dict = torch.load('saved_model/AWA2_MSDN_GZSL.pth') 22 | saved_dict = {k: v for k, v in saved_dict.items() if k in model_dict} 23 | model_dict.update(saved_dict) 24 | model_gzsl.load_state_dict(model_dict) 25 | 26 | model_czsl = MSDN(config, normalize_V=True, normalize_F=True, is_conservative=True, 27 | uniform_att_1=False, uniform_att_2=False, 28 | is_conv=False, is_bias=True).to(config.device) 29 | model_dict = model_czsl.state_dict() 30 | saved_dict = torch.load('saved_model/AWA2_MSDN_CZSL.pth') 31 | saved_dict = {k: v for k, v in saved_dict.items() if k in model_dict} 32 | model_dict.update(saved_dict) 33 | model_czsl.load_state_dict(model_dict) 34 | 35 | evaluation(config.batch_size, config.device, 36 | dataloader, model_gzsl, model_czsl) 37 | -------------------------------------------------------------------------------- /Test_CUB.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from MSDN import MSDN 3 | from dataset import UNIDataloader 4 | import argparse 5 | import json 6 | from utils import evaluation 7 | 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--config', type=str, default='config/test_CUB.json') 11 | config = parser.parse_args() 12 | with open(config.config, 'r') as f: 13 | config.__dict__ = json.load(f) 14 | 15 | dataloader = UNIDataloader(config) 16 | 17 | model_gzsl = MSDN(config, normalize_V=False, normalize_F=True, is_conservative=True, 18 | uniform_att_1=False, uniform_att_2=False, 19 | is_conv=False, is_bias=True).to(config.device) 20 | model_dict = model_gzsl.state_dict() 21 | saved_dict = torch.load('saved_model/CUB_MSDN_GZSL.pth') 22 | saved_dict = {k: v for k, v in saved_dict.items() if k in model_dict} 23 | model_dict.update(saved_dict) 24 | model_gzsl.load_state_dict(model_dict) 25 | 26 | model_czsl = MSDN(config, normalize_V=False, normalize_F=True, is_conservative=True, 27 | uniform_att_1=False, uniform_att_2=False, 28 | is_conv=False, is_bias=True).to(config.device) 29 | model_dict = model_czsl.state_dict() 30 | saved_dict =
torch.load('saved_model/CUB_MSDN_CZSL.pth') 31 | saved_dict = {k: v for k, v in saved_dict.items() if k in model_dict} 32 | model_dict.update(saved_dict) 33 | model_czsl.load_state_dict(model_dict) 34 | 35 | evaluation(config.batch_size, config.device, 36 | dataloader, model_gzsl, model_czsl) 37 | -------------------------------------------------------------------------------- /Test_SUN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from MSDN import MSDN 3 | from dataset import UNIDataloader 4 | import argparse 5 | import json 6 | from utils import evaluation 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--config', type=str, default='config/test_SUN.json') 10 | config = parser.parse_args() 11 | with open(config.config, 'r') as f: 12 | config.__dict__ = json.load(f) 13 | 14 | dataloader = UNIDataloader(config) 15 | 16 | model_gzsl = MSDN(config, normalize_V=False, normalize_F=True, is_conservative=True, 17 | uniform_att_1=False, uniform_att_2=True, 18 | is_conv=False, is_bias=True, non_linear_act=False).to(config.device) 19 | model_dict = model_gzsl.state_dict() 20 | saved_dict = torch.load('saved_model/SUN_MSDN_GZSL.pth') 21 | saved_dict = {k: v for k, v in saved_dict.items() if k in model_dict} 22 | model_dict.update(saved_dict) 23 | model_gzsl.load_state_dict(model_dict) 24 | 25 | model_czsl = MSDN(config, normalize_V=False, normalize_F=True, is_conservative=True, 26 | uniform_att_1=False, uniform_att_2=True, 27 | is_conv=False, is_bias=True, non_linear_act=False).to(config.device) 28 | model_dict = model_czsl.state_dict() 29 | saved_dict = torch.load('saved_model/SUN_MSDN_CZSL.pth') 30 | saved_dict = {k: v for k, v in saved_dict.items() if k in model_dict} 31 | model_dict.update(saved_dict) 32 | model_czsl.load_state_dict(model_dict) 33 | 34 | evaluation(config.batch_size, config.device, 35 | dataloader, model_gzsl, model_czsl) 36 | -------------------------------------------------------------------------------- /config/test_AWA2.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "AWA2", 3 | "dataset_path": "./data/AWA2", 4 | "pkl_path": "./data/AWA2.pkl", 5 | "device": "cuda:0", 6 | "num_workers": 16, 7 | "batch_size": 50, 8 | "num_attribute": 85, 9 | "num_class": 50, 10 | "resnet_region": 196, 11 | "dim_f": 2048, 12 | "dim_v": 300, 13 | "img_size": 448, 14 | "w1": 1.0, 15 | "w2": 0.0 16 | } -------------------------------------------------------------------------------- /config/test_CUB.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "CUB", 3 | "dataset_path": "./data/CUB", 4 | "pkl_path": "./data/CUB.pkl", 5 | "device": "cuda:0", 6 | "num_workers": 16, 7 | "batch_size": 50, 8 | "num_attribute": 312, 9 | "num_class": 200, 10 | "resnet_region": 196, 11 | "dim_f": 2048, 12 | "dim_v": 300, 13 | "img_size": 448, 14 | "w1": 0.9, 15 | "w2": 0.1 16 | } -------------------------------------------------------------------------------- /config/test_SUN.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": "SUN", 3 | "dataset_path": "./data/SUN", 4 | "pkl_path": "./data/SUN.pkl", 5 | "device": "cuda:0", 6 | "num_workers": 16, 7 | "batch_size": 50, 8 | "num_attribute": 102, 9 | "num_class": 717, 10 | "resnet_region": 196, 11 | "dim_f": 2048, 12 | "dim_v": 300, 13 | "img_size": 448, 14 | "w1": 0.7, 15 | "w2": 0.3 16 | } 
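In the configs above, `w1` and `w2` weight MSDN's two branches at test time: `MSDN.py` returns `embed = w1 * S_pp1 + w2 * S_pp2`, so AWA2 (1.0/0.0) relies only on the text-to-image branch, while CUB (0.9/0.1) and SUN (0.7/0.3) mix both. A minimal sketch of this fusion, with random tensors standing in for the two branch scores:
```
import torch

def fuse_scores(S_pp1, S_pp2, w1, w2):
    # Weighted sum of text-to-image and image-to-text class scores,
    # mirroring the 'embed' output of MSDN.forward.
    return w1 * S_pp1 + w2 * S_pp2

B, K = 4, 717                                  # e.g. SUN: 4 images, 717 classes
S1, S2 = torch.randn(B, K), torch.randn(B, K)  # stand-ins for the two branches
embed = fuse_scores(S1, S2, w1=0.7, w2=0.3)    # weights from test_SUN.json
pred = embed.argmax(dim=1)                     # predicted class per image
```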
-------------------------------------------------------------------------------- /core/1: -------------------------------------------------------------------------------- 1 | 1111 2 | -------------------------------------------------------------------------------- /core/AWA2DataLoader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Jul 20 21:23:18 2019 4 | 5 | @author: badat 6 | """ 7 | 8 | import os,sys 9 | #import scipy.io as sio 10 | import torch 11 | import numpy as np 12 | import h5py 13 | import time 14 | import pickle 15 | from sklearn import preprocessing 16 | from global_setting import NFS_path 17 | #%% 18 | import scipy.io as sio 19 | import pandas as pd 20 | #%% 21 | import pdb 22 | #%% 23 | dataset = 'AWA2' 24 | img_dir = os.path.join(NFS_path,'data/{}/'.format(dataset)) 25 | mat_path = os.path.join(NFS_path,'data/xlsa17/data/{}/res101.mat'.format(dataset)) 26 | attr_path = '/data2/shimingchen/BCA/attribute/{}/new_des.csv'.format(dataset) 27 | 28 | 29 | class AWA2DataLoader(): 30 | def __init__(self, data_path, device, is_scale = False, is_unsupervised_attr = False,is_balance =True): 31 | 32 | print(data_path) 33 | sys.path.append(data_path) 34 | 35 | self.data_path = data_path 36 | self.device = device 37 | self.dataset = 'AWA2' 38 | print('$'*30) 39 | print(self.dataset) 40 | print('$'*30) 41 | self.datadir = self.data_path + 'data/{}/'.format(self.dataset) 42 | self.index_in_epoch = 0 43 | self.epochs_completed = 0 44 | self.is_scale = is_scale 45 | self.is_balance = is_balance 46 | if self.is_balance: 47 | print('Balance dataloader') 48 | self.is_unsupervised_attr = is_unsupervised_attr 49 | self.read_matdataset() 50 | self.get_idx_classes() 51 | 52 | def next_batch_img(self, batch_size,class_id,is_trainset = False): 53 | features = None 54 | labels = None 55 | img_files = None 56 | if class_id in self.seenclasses: 57 | if is_trainset: 58 | features = self.data['train_seen']['resnet_features'] 59 | labels = self.data['train_seen']['labels'] 60 | img_files = self.data['train_seen']['img_path'] 61 | else: 62 | features = self.data['test_seen']['resnet_features'] 63 | labels = self.data['test_seen']['labels'] 64 | img_files = self.data['test_seen']['img_path'] 65 | elif class_id in self.unseenclasses: 66 | features = self.data['test_unseen']['resnet_features'] 67 | labels = self.data['test_unseen']['labels'] 68 | img_files = self.data['test_unseen']['img_path'] 69 | else: 70 | raise Exception("Cannot find this class {}".format(class_id)) 71 | 72 | #note that img_files is numpy type !!!!! 
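# NOTE (descriptive comment, not part of the original file):
# the lines below keep only the samples whose label equals class_id and
# return the first batch_size of them, together with their image paths and
# class attribute vectors. next_batch (further down) is the training
# sampler: with is_balance=True it samples up to batch_size distinct seen
# classes and draws max(batch_size // ntrain_class, 1) examples from each.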
73 | 74 | idx_c = torch.squeeze(torch.nonzero(labels == class_id)) 75 | 76 | features = features[idx_c] 77 | labels = labels[idx_c] 78 | img_files = img_files[idx_c.cpu().numpy()] 79 | 80 | batch_label = labels[:batch_size].to(self.device) 81 | batch_feature = features[:batch_size].to(self.device) 82 | batch_files = img_files[:batch_size] 83 | batch_att = self.att[batch_label].to(self.device) 84 | 85 | return batch_label, batch_feature,batch_files, batch_att 86 | 87 | def next_batch(self, batch_size): 88 | if self.is_balance: 89 | idx = [] 90 | n_samples_class = max(batch_size //self.ntrain_class,1) 91 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 92 | for i_c in sampled_idx_c: 93 | idxs = self.idxs_list[i_c] 94 | idx.append(np.random.choice(idxs,n_samples_class)) 95 | idx = np.concatenate(idx) 96 | idx = torch.from_numpy(idx) 97 | else: 98 | idx = torch.randperm(self.ntrain)[0:batch_size] 99 | 100 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 101 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 102 | batch_att = self.att[batch_label].to(self.device) 103 | return batch_label, batch_feature, batch_att 104 | 105 | def get_idx_classes(self): 106 | n_classes = self.seenclasses.size(0) 107 | self.idxs_list = [] 108 | train_label = self.data['train_seen']['labels'] 109 | for i in range(n_classes): 110 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 111 | idx_c = np.squeeze(idx_c) 112 | self.idxs_list.append(idx_c) 113 | return self.idxs_list 114 | 115 | def read_matdataset(self): 116 | 117 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 118 | print('_____') 119 | print(path) 120 | # tic = time.clock() 121 | hf = h5py.File(path, 'r') 122 | features = np.array(hf.get('feature_map')) 123 | # shape = features.shape 124 | # features = features.reshape(shape[0],shape[1],shape[2]*shape[3]) 125 | labels = np.array(hf.get('labels')) 126 | trainval_loc = np.array(hf.get('trainval_loc')) 127 | # train_loc = np.array(hf.get('train_loc')) #--> train_feature = TRAIN SEEN 128 | # val_unseen_loc = np.array(hf.get('val_unseen_loc')) #--> test_unseen_feature = TEST UNSEEN 129 | test_seen_loc = np.array(hf.get('test_seen_loc')) 130 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 131 | 132 | if self.is_unsupervised_attr: 133 | print('Unsupervised Attr') 134 | class_path = './w2v/{}_class.pkl'.format(self.dataset) 135 | with open(class_path,'rb') as f: 136 | w2v_class = pickle.load(f) 137 | assert w2v_class.shape == (50,300) 138 | w2v_class = torch.tensor(w2v_class).float() 139 | 140 | U, s, V = torch.svd(w2v_class) 141 | reconstruct = torch.mm(torch.mm(U,torch.diag(s)),torch.transpose(V,1,0)) 142 | print('sanity check: {}'.format(torch.norm(reconstruct-w2v_class).item())) 143 | 144 | print('shape U:{} V:{}'.format(U.size(),V.size())) 145 | print('s: {}'.format(s)) 146 | 147 | self.w2v_att = torch.transpose(V,1,0).to(self.device) 148 | self.att = torch.mm(U,torch.diag(s)).to(self.device) 149 | self.normalize_att = torch.mm(U,torch.diag(s)).to(self.device) 150 | 151 | else: 152 | print('Expert Attr') 153 | att = np.array(hf.get('att')) 154 | 155 | print("threshold at zero attribute with negative value") 156 | att[att<0]=0 157 | 158 | self.att = torch.from_numpy(att).float().to(self.device) 159 | 160 | original_att = np.array(hf.get('original_att')) 161 | self.original_att = 
torch.from_numpy(original_att).float().to(self.device) 162 | 163 | w2v_att = np.array(hf.get('w2v_att')) 164 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 165 | 166 | self.normalize_att = self.original_att/100 167 | 168 | # print('Finish loading data in ',time.clock()-tic) 169 | 170 | train_feature = features[trainval_loc] 171 | test_seen_feature = features[test_seen_loc] 172 | test_unseen_feature = features[test_unseen_loc] 173 | if self.is_scale: 174 | scaler = preprocessing.MinMaxScaler() 175 | 176 | train_feature = scaler.fit_transform(train_feature) 177 | test_seen_feature = scaler.fit_transform(test_seen_feature) 178 | test_unseen_feature = scaler.fit_transform(test_unseen_feature) 179 | 180 | train_feature = torch.from_numpy(train_feature).float() #.to(self.device) 181 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 182 | test_unseen_feature = torch.from_numpy(test_unseen_feature) #.float().to(self.device) 183 | 184 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 185 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 186 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) #.long().to(self.device) 187 | 188 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 189 | 190 | 191 | 192 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 193 | self.ntrain = train_feature.size()[0] 194 | self.ntrain_class = self.seenclasses.size(0) 195 | self.ntest_class = self.unseenclasses.size(0) 196 | self.train_class = self.seenclasses.clone() 197 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 198 | 199 | # self.train_mapped_label = map_label(train_label, self.seenclasses) 200 | 201 | self.data = {} 202 | self.data['train_seen'] = {} 203 | self.data['train_seen']['resnet_features'] = train_feature 204 | self.data['train_seen']['labels']= train_label 205 | 206 | 207 | self.data['train_unseen'] = {} 208 | self.data['train_unseen']['resnet_features'] = None 209 | self.data['train_unseen']['labels'] = None 210 | 211 | self.data['test_seen'] = {} 212 | self.data['test_seen']['resnet_features'] = test_seen_feature 213 | self.data['test_seen']['labels'] = test_seen_label 214 | 215 | self.data['test_unseen'] = {} 216 | self.data['test_unseen']['resnet_features'] = test_unseen_feature 217 | self.data['test_unseen']['labels'] = test_unseen_label 218 | -------------------------------------------------------------------------------- /core/CUBDataLoader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jul 4 11:53:09 2019 4 | 5 | @author: badat 6 | """ 7 | import os,sys 8 | #import scipy.io as sio 9 | import torch 10 | import numpy as np 11 | import h5py 12 | import time 13 | import pickle 14 | import pdb 15 | from sklearn import preprocessing 16 | from global_setting import NFS_path 17 | #%% 18 | import scipy.io as sio 19 | import pandas as pd 20 | #%% 21 | #import pdb 22 | #%% 23 | 24 | img_dir = os.path.join(NFS_path,'data/CUB/') 25 | 26 | class CUBDataLoader(): 27 | def __init__(self, data_path, device, is_scale = False,is_unsupervised_attr = False,is_balance=True): 28 | 29 | print(data_path) 30 | sys.path.append(data_path) 31 | 32 | self.data_path = data_path 33 | self.device = device 34 | self.dataset = 'CUB' 35 | print('$'*30) 36 | print(self.dataset) 37 | print('$'*30) 
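# NOTE (descriptive comment, not part of the original file):
# read_matdataset (below) expects data_path to contain
# data/CUB/feature_map_ResNet_101_CUB.hdf5 with precomputed ResNet-101
# feature maps and attributes stored under the HDF5 keys 'feature_map',
# 'labels', 'trainval_loc', 'test_seen_loc', 'test_unseen_loc', 'att',
# 'original_att', and 'w2v_att' (prepared following the TransZero
# instructions referenced in the README).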
38 | self.datadir = self.data_path + 'data/{}/'.format(self.dataset) 39 | self.index_in_epoch = 0 40 | self.epochs_completed = 0 41 | self.is_scale = is_scale 42 | self.is_balance = is_balance 43 | if self.is_balance: 44 | print('Balance dataloader') 45 | self.is_unsupervised_attr = is_unsupervised_attr 46 | self.read_matdataset() 47 | self.get_idx_classes() 48 | 49 | def next_batch_img(self, batch_size,class_id,is_trainset = False): 50 | features = None 51 | labels = None 52 | img_files = None 53 | if class_id in self.seenclasses: 54 | if is_trainset: 55 | features = self.data['train_seen']['resnet_features'] 56 | labels = self.data['train_seen']['labels'] 57 | img_files = self.data['train_seen']['img_path'] 58 | else: 59 | features = self.data['test_seen']['resnet_features'] 60 | labels = self.data['test_seen']['labels'] 61 | img_files = self.data['test_seen']['img_path'] 62 | elif class_id in self.unseenclasses: 63 | features = self.data['test_unseen']['resnet_features'] 64 | labels = self.data['test_unseen']['labels'] 65 | img_files = self.data['test_unseen']['img_path'] 66 | else: 67 | raise Exception("Cannot find this class {}".format(class_id)) 68 | 69 | #note that img_files is numpy type !!!!! 70 | 71 | idx_c = torch.squeeze(torch.nonzero(labels == class_id)) 72 | 73 | features = features[idx_c] 74 | labels = labels[idx_c] 75 | img_files = img_files[idx_c.cpu().numpy()] 76 | 77 | batch_label = labels[:batch_size].to(self.device) 78 | batch_feature = features[:batch_size].to(self.device) 79 | batch_files = img_files[:batch_size] 80 | batch_att = self.att[batch_label].to(self.device) 81 | 82 | return batch_label, batch_feature,batch_files, batch_att 83 | 84 | 85 | def next_batch(self, batch_size): 86 | if self.is_balance: 87 | idx = [] 88 | n_samples_class = max(batch_size //self.ntrain_class,1) 89 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 90 | for i_c in sampled_idx_c: 91 | idxs = self.idxs_list[i_c] 92 | idx.append(np.random.choice(idxs,n_samples_class)) 93 | idx = np.concatenate(idx) 94 | idx = torch.from_numpy(idx) 95 | else: 96 | idx = torch.randperm(self.ntrain)[0:batch_size] 97 | 98 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 99 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 100 | batch_att = self.att[batch_label].to(self.device) 101 | return batch_label, batch_feature, batch_att 102 | 103 | def get_idx_classes(self): 104 | n_classes = self.seenclasses.size(0) 105 | self.idxs_list = [] 106 | train_label = self.data['train_seen']['labels'] 107 | for i in range(n_classes): 108 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 109 | idx_c = np.squeeze(idx_c) 110 | self.idxs_list.append(idx_c) 111 | return self.idxs_list 112 | 113 | 114 | def read_matdataset(self): 115 | 116 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 117 | print('_____') 118 | print(path) 119 | # tic = time.time() 120 | hf = h5py.File(path, 'r') 121 | features = np.array(hf.get('feature_map')) 122 | # shape = features.shape 123 | # features = features.reshape(shape[0],shape[1],shape[2]*shape[3]) 124 | # pdb.set_trace() 125 | labels = np.array(hf.get('labels')) 126 | trainval_loc = np.array(hf.get('trainval_loc')) 127 | # train_loc = np.array(hf.get('train_loc')) #--> train_feature = TRAIN SEEN 128 | # val_unseen_loc = np.array(hf.get('val_unseen_loc')) #--> test_unseen_feature = TEST UNSEEN 129 | test_seen_loc = 
np.array(hf.get('test_seen_loc')) 130 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 131 | 132 | if self.is_unsupervised_attr: 133 | print('Unsupervised Attr') 134 | class_path = './w2v/{}_class.pkl'.format(self.dataset) 135 | with open(class_path,'rb') as f: 136 | w2v_class = pickle.load(f) 137 | temp = np.array(hf.get('att')) 138 | print(w2v_class.shape,temp.shape) 139 | # assert w2v_class.shape == temp.shape 140 | w2v_class = torch.tensor(w2v_class).float() 141 | 142 | U, s, V = torch.svd(w2v_class) 143 | reconstruct = torch.mm(torch.mm(U,torch.diag(s)),torch.transpose(V,1,0)) 144 | print('sanity check: {}'.format(torch.norm(reconstruct-w2v_class).item())) 145 | 146 | print('shape U:{} V:{}'.format(U.size(),V.size())) 147 | print('s: {}'.format(s)) 148 | 149 | self.w2v_att = torch.transpose(V,1,0).to(self.device) 150 | self.att = torch.mm(U,torch.diag(s)).to(self.device) 151 | self.normalize_att = torch.mm(U,torch.diag(s)).to(self.device) 152 | 153 | else: 154 | print('Expert Attr') 155 | att = np.array(hf.get('att')) 156 | self.att = torch.from_numpy(att).float().to(self.device) 157 | 158 | original_att = np.array(hf.get('original_att')) 159 | self.original_att = torch.from_numpy(original_att).float().to(self.device) 160 | 161 | w2v_att = np.array(hf.get('w2v_att')) 162 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 163 | 164 | self.normalize_att = self.original_att/100 165 | 166 | # print('Finish loading data in ',time.time()-tic) 167 | 168 | train_feature = features[trainval_loc] 169 | test_seen_feature = features[test_seen_loc] 170 | test_unseen_feature = features[test_unseen_loc] 171 | if self.is_scale: 172 | scaler = preprocessing.MinMaxScaler() 173 | 174 | train_feature = scaler.fit_transform(train_feature) 175 | test_seen_feature = scaler.fit_transform(test_seen_feature) 176 | test_unseen_feature = scaler.fit_transform(test_unseen_feature) 177 | 178 | train_feature = torch.from_numpy(train_feature).float() #.to(self.device) 179 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 180 | test_unseen_feature = torch.from_numpy(test_unseen_feature) #.float().to(self.device) 181 | 182 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 183 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 184 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) #.long().to(self.device) 185 | 186 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 187 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 188 | self.ntrain = train_feature.size()[0] 189 | self.ntrain_class = self.seenclasses.size(0) 190 | self.ntest_class = self.unseenclasses.size(0) 191 | self.train_class = self.seenclasses.clone() 192 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 193 | 194 | # self.train_mapped_label = map_label(train_label, self.seenclasses) 195 | 196 | self.data = {} 197 | self.data['train_seen'] = {} 198 | self.data['train_seen']['resnet_features'] = train_feature 199 | self.data['train_seen']['labels']= train_label 200 | 201 | 202 | self.data['train_unseen'] = {} 203 | self.data['train_unseen']['resnet_features'] = None 204 | self.data['train_unseen']['labels'] = None 205 | 206 | self.data['test_seen'] = {} 207 | self.data['test_seen']['resnet_features'] = test_seen_feature 208 | self.data['test_seen']['labels'] = test_seen_label 209 | 210 | self.data['test_unseen'] 
= {} 211 | self.data['test_unseen']['resnet_features'] = test_unseen_feature 212 | self.data['test_unseen']['labels'] = test_unseen_label 213 | -------------------------------------------------------------------------------- /core/MSDN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jul 4 17:39:45 2019 4 | 5 | @author: badat 6 | """ 7 | import tensorflow as tf 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import numpy as np 12 | #%% 13 | 14 | class MSDN(nn.Module): 15 | ##### 16 | # einstein sum notation 17 | # b: Batch size \ f: dim feature=2048 \ v: dim w2v=300 \ r: number of region=49 \ k: number of classes 18 | # i: number of attribute=312 \ h : hidden attention dim 19 | ##### 20 | def __init__(self,dim_f,dim_v, 21 | init_w2v_att,att,normalize_att, 22 | seenclass,unseenclass, 23 | lambda_, 24 | trainable_w2v = False, normalize_V = False, normalize_F = False, is_conservative = False, 25 | prob_prune=0,desired_mass = -1,uniform_att_1 = False,uniform_att_2 = False, is_conv = False, 26 | is_bias = False,bias = 1,non_linear_act=False, 27 | loss_type = 'CE',non_linear_emb = False, 28 | is_sigmoid = False): 29 | super(MSDN, self).__init__() 30 | self.dim_f = dim_f 31 | self.dim_v = dim_v 32 | self.dim_att = att.shape[1] 33 | self.nclass = att.shape[0] 34 | self.hidden = self.dim_att//2 35 | self.init_w2v_att = init_w2v_att 36 | self.non_linear_act = non_linear_act 37 | self.loss_type = loss_type 38 | if is_conv: 39 | r_dim = dim_f//2 40 | self.conv1 = nn.Conv2d(dim_f, r_dim, 2) #[2x2] kernel with same input and output dims 41 | print('***Reduce dim {} -> {}***'.format(self.dim_f,r_dim)) 42 | self.dim_f = r_dim 43 | self.conv1_bn = nn.BatchNorm2d(self.dim_f) 44 | 45 | 46 | if init_w2v_att is None: 47 | self.V = nn.Parameter(nn.init.normal_(torch.empty(self.dim_att,self.dim_v)),requires_grad = True) 48 | else: 49 | self.init_w2v_att = F.normalize(torch.tensor(init_w2v_att)) 50 | self.V = nn.Parameter(self.init_w2v_att.clone(),requires_grad = trainable_w2v) 51 | 52 | self.att = nn.Parameter(F.normalize(torch.tensor(att)),requires_grad = False) 53 | 54 | self.W_1 = nn.Parameter(nn.init.normal_(torch.empty(self.dim_v,self.dim_f)),requires_grad = True) #nn.utils.weight_norm(nn.Linear(self.dim_v,self.dim_f,bias=False))# 55 | self.W_2 = nn.Parameter(nn.init.zeros_(torch.empty(self.dim_v,self.dim_f)),requires_grad = True) #nn.utils.weight_norm(nn.Linear(self.dim_v,self.dim_f,bias=False))# 56 | ## second layer attenion conditioned on image features 57 | self.W_3 = nn.Parameter(nn.init.zeros_(torch.empty(self.dim_v,self.dim_f)),requires_grad = True) 58 | 59 | self.W_1_1 = nn.Parameter(nn.init.zeros_(torch.empty(self.dim_f,self.dim_v)),requires_grad = True)#nn.utils.weight_norm(nn.Linear(self.dim_v,self.dim_f,bias=False))# 60 | self.W_2_1 = nn.Parameter(nn.init.zeros_(torch.empty(self.dim_v,self.dim_f)),requires_grad = True) 61 | self.W_3_1 = nn.Parameter(nn.init.zeros_(torch.empty(self.dim_f,self.dim_v)),requires_grad = True) 62 | 63 | ## Compute the similarity between classes 64 | self.P = torch.mm(self.att,torch.transpose(self.att,1,0)) 65 | assert self.P.size(1)==self.P.size(0) and self.P.size(0)==self.nclass 66 | self.weight_ce = nn.Parameter(torch.eye(self.nclass).float(),requires_grad = False)#nn.Parameter(torch.tensor(weight_ce).float(),requires_grad = False) 67 | 68 | self.normalize_V = normalize_V 69 | self.normalize_F = normalize_F 70 | self.is_conservative = 
is_conservative 71 | self.is_conv = is_conv 72 | self.is_bias = is_bias 73 | 74 | self.seenclass = seenclass 75 | self.unseenclass = unseenclass 76 | self.normalize_att = normalize_att 77 | 78 | if is_bias: 79 | self.bias = nn.Parameter(torch.tensor(bias),requires_grad = False) 80 | mask_bias = np.ones((1,self.nclass)) 81 | mask_bias[:,self.seenclass.cpu().numpy()] *= -1 82 | self.mask_bias = nn.Parameter(torch.tensor(mask_bias).float(),requires_grad = False) 83 | 84 | if desired_mass == -1: 85 | self.desired_mass = self.unseenclass.size(0)/self.nclass#nn.Parameter(torch.tensor(self.unseenclass.size(0)/self.nclass),requires_grad = False)#nn.Parameter(torch.tensor(0.1),requires_grad = False)# 86 | else: 87 | self.desired_mass = desired_mass#nn.Parameter(torch.tensor(desired_mass),requires_grad = False)#nn.Parameter(torch.tensor(self.unseenclass.size(0)/self.nclass),requires_grad = False)# 88 | self.prob_prune = nn.Parameter(torch.tensor(prob_prune),requires_grad = False) 89 | 90 | self.lambda_ = lambda_ 91 | self.loss_att_func = nn.BCEWithLogitsLoss() 92 | self.log_softmax_func = nn.LogSoftmax(dim=1) 93 | self.uniform_att_1 = uniform_att_1 94 | self.uniform_att_2 = uniform_att_2 95 | 96 | self.non_linear_emb = non_linear_emb 97 | 98 | 99 | print('-'*30) 100 | print('Configuration') 101 | 102 | print('loss_type {}'.format(loss_type)) 103 | 104 | if self.is_conv: 105 | print('Learn CONV layer correct') 106 | 107 | if self.normalize_V: 108 | print('normalize V') 109 | else: 110 | print('no constraint V') 111 | 112 | if self.normalize_F: 113 | print('normalize F') 114 | else: 115 | print('no constraint F') 116 | 117 | if self.is_conservative: 118 | print('training to exclude unseen class [seen upperbound]') 119 | if init_w2v_att is None: 120 | print('Learning word2vec from scratch with dim {}'.format(self.V.size())) 121 | else: 122 | print('Init word2vec') 123 | 124 | if self.non_linear_act: 125 | print('Non-linear relu model') 126 | else: 127 | print('Linear model') 128 | 129 | print('loss_att {}'.format(self.loss_att_func)) 130 | print('Bilinear attention module') 131 | print('*'*30) 132 | print('Measure w2v deviation') 133 | if self.uniform_att_1: 134 | print('WARNING: UNIFORM ATTENTION LEVEL 1') 135 | if self.uniform_att_2: 136 | print('WARNING: UNIFORM ATTENTION LEVEL 2') 137 | print('Compute Pruning loss {}'.format(self.prob_prune)) 138 | if self.is_bias: 139 | print('Add one smoothing') 140 | print('Second layer attenion conditioned on image features') 141 | print('-'*30) 142 | 143 | if self.non_linear_emb: 144 | print('non_linear embedding') 145 | self.emb_func = torch.nn.Sequential( 146 | torch.nn.Linear(self.dim_att, self.dim_att//2), 147 | torch.nn.ReLU(), 148 | torch.nn.Linear(self.dim_att//2, 1), 149 | ) 150 | 151 | self.is_sigmoid = is_sigmoid 152 | if self.is_sigmoid: 153 | print("Sigmoid on attr score!!!") 154 | else: 155 | print("No sigmoid on attr score") 156 | 157 | 158 | def compute_loss_rank(self,in_package): 159 | # this is pairwise ranking loss 160 | batch_label = in_package['batch_label'] 161 | S_pp = in_package['S_pp'] 162 | 163 | batch_label_idx = torch.argmax(batch_label,dim = 1) 164 | 165 | s_c = torch.gather(S_pp,1,batch_label_idx.view(-1,1)) 166 | if self.is_conservative: 167 | S_seen = S_pp 168 | else: 169 | S_seen = S_pp[:,self.seenclass] 170 | assert S_seen.size(1) == len(self.seenclass) 171 | 172 | margin = 1-(s_c-S_seen) 173 | loss_rank = torch.max(margin,torch.zeros_like(margin)) 174 | loss_rank = torch.mean(loss_rank) 175 | return loss_rank 176 | 177 | def 
compute_loss_Self_Calibrate(self,in_package): 178 | S_pp = in_package['S_pp'] 179 | Prob_all = F.softmax(S_pp,dim=-1) 180 | Prob_unseen = Prob_all[:,self.unseenclass] 181 | assert Prob_unseen.size(1) == len(self.unseenclass) 182 | mass_unseen = torch.sum(Prob_unseen,dim=1) 183 | loss_pmp = -torch.log(torch.mean(mass_unseen)) 184 | return loss_pmp 185 | 186 | def compute_V(self): 187 | if self.normalize_V: 188 | V_n = F.normalize(self.V) 189 | else: 190 | V_n = self.V 191 | return V_n 192 | 193 | def compute_aug_cross_entropy(self,in_package): 194 | batch_label = in_package['batch_label'] 195 | S_pp = in_package['S_pp'] 196 | 197 | Labels = batch_label 198 | 199 | if self.is_bias: 200 | S_pp = S_pp - self.vec_bias # remove the margin +1/-1 from prediction scores 201 | 202 | if not self.is_conservative: 203 | S_pp = S_pp[:,self.seenclass] 204 | Labels = Labels[:,self.seenclass] 205 | assert S_pp.size(1) == len(self.seenclass) 206 | 207 | Prob = self.log_softmax_func(S_pp) 208 | 209 | loss = -torch.einsum('bk,bk->b',Prob,Labels) 210 | loss = torch.mean(loss) 211 | return loss 212 | 213 | def compute_loss(self,in_package): 214 | 215 | if len(in_package['batch_label'].size()) == 1: 216 | in_package['batch_label'] = self.weight_ce[in_package['batch_label']] 217 | 218 | ## loss rank 219 | if self.loss_type == 'CE': 220 | loss_CE = self.compute_aug_cross_entropy(in_package) 221 | elif self.loss_type == 'rank': 222 | loss_CE = self.compute_loss_rank(in_package) 223 | else: 224 | raise Exception('Unknown loss type') 225 | 226 | ## loss self-calibration 227 | loss_cal = self.compute_loss_Self_Calibrate(in_package) 228 | 229 | ## total loss 230 | loss = loss_CE + self.lambda_*loss_cal 231 | 232 | out_package = {'loss':loss,'loss_CE':loss_CE, 233 | 'loss_cal':loss_cal} 234 | 235 | return out_package 236 | 237 | def compute_contrastive_loss(self, in_package1, in_package2): 238 | S_pp1,S_pp2=in_package1['S_pp'], in_package2['S_pp'] 239 | wt = (S_pp1-S_pp2).pow(2) 240 | wt /= wt.sum(1).sqrt().unsqueeze(1).expand(wt.size(0),wt.size(1)) 241 | loss = wt * (S_pp1-S_pp2).abs() 242 | loss= (loss.sum()/loss.size(0)) 243 | 244 | # JSD: symmetrized KL divergence of the two score distributions to their mean 245 | KLDivLoss = nn.KLDivLoss(reduction='batchmean') 246 | p_output = F.softmax(S_pp1, dim=1) 247 | q_output = F.softmax(S_pp2, dim=1) 248 | log_mean_output = ((p_output + q_output )/2).log() 249 | loss+=(KLDivLoss(log_mean_output, q_output) + KLDivLoss(log_mean_output, p_output))/2 250 | 251 | return loss 252 | 253 | 254 | 255 | 256 | 257 | 258 | def get_global_feature(self, x): 259 | 260 | N, C, W, H = x.shape 261 | global_feat = F.avg_pool2d(x, kernel_size=(W, H)) 262 | global_feat = global_feat.view(N, C) 263 | 264 | return global_feat 265 | 266 | 267 | def forward(self,Fs): 268 | 269 | if self.is_conv: 270 | Fs = self.conv1(Fs) 271 | Fs = self.conv1_bn(Fs) 272 | Fs = F.relu(Fs) 273 | 274 | shape = Fs.shape 275 | 276 | visualf_ori = self.get_global_feature(Fs) 277 | 278 | ##########################base-model################################ 279 | # global_feature = self.get_global_feature(Fs) 280 | # temp_norm = torch.norm(self.att, p=2, dim=1).unsqueeze(1).expand_as(self.att) 281 | # seen_att_normalized = self.att.div(temp_norm + 1e-5) 282 | 283 | # S_pp = torch.einsum('bi,ki->bk', global_feature, seen_att_normalized) 284 | # package0 = {'S_pp':S_pp} 285 | 286 | # return package0 287 | ##########################base-model################################ 288 | 289 | Fs = Fs.reshape(shape[0],shape[1],shape[2]*shape[3]) # batch x 2048 x 49 290 | 291 | R = Fs.size(2) # 49 292 | B = Fs.size(0) # 
batch 293 | V_n = self.compute_V() # 312x300 294 | 295 | if self.normalize_F and not self.is_conv: # true 296 | Fs = F.normalize(Fs,dim = 1) 297 | 298 | 299 | ##########################Text-Image################################ 300 | 301 | ## Compute attribute score on each image region 302 | S = torch.einsum('iv,vf,bfr->bir',V_n,self.W_1,Fs) # batchx312x49 303 | 304 | if self.is_sigmoid: 305 | S=torch.sigmoid(S) 306 | 307 | ## Ablation setting 308 | A_b = Fs.new_full((B,self.dim_att,R),1/R) 309 | A_b_p = self.att.new_full((B,self.dim_att),fill_value = 1) 310 | S_b_p = torch.einsum('bir,bir->bi',A_b,S) 311 | S_b_pp = torch.einsum('ki,bi,bi->bk',self.att,A_b_p,S_b_p) 312 | ## 313 | 314 | ## compute Dense Attention 315 | A = torch.einsum('iv,vf,bfr->bir',V_n,self.W_2,Fs) # batchx312x49 316 | A = F.softmax(A,dim = -1) # compute an attention map for each attribute 317 | 318 | F_p = torch.einsum('bir,bfr->bif',A,Fs) # compute attribute-based features 319 | if self.uniform_att_1: # false 320 | S_p = torch.einsum('bir,bir->bi',A_b,S) # ablation: compute attribute score using average image region features 321 | else: 322 | S_p = torch.einsum('bir,bir->bi',A,S) # compute attribute scores from attribute attention maps 323 | 324 | if self.non_linear_act: # false 325 | S_p = F.relu(S_p) 326 | ## 327 | 328 | ## compute Attention over Attribute 329 | A_p = torch.einsum('iv,vf,bif->bi',V_n,self.W_3,F_p) #eq. 6 330 | A_p = torch.sigmoid(A_p) 331 | ## 332 | 333 | if self.uniform_att_2: # true 334 | S_pp = torch.einsum('ki,bi,bi->bik',self.att,A_b_p,S_p) # ablation: setting attention over attribute to 1 335 | else: 336 | # S_pp = torch.einsum('ki,bi,bi->bik',self.att,A_p,S_p) # compute the final prediction as the product of semantic scores, attribute scores, and attention over attribute scores 337 | S_pp = torch.einsum('ki,bi->bik',self.att,S_p) 338 | 339 | S_attr = torch.einsum('bi,bi->bi',A_b_p,S_p) 340 | 341 | if self.non_linear_emb: 342 | S_pp = torch.transpose(S_pp,2,1) #[bki] <== [bik] 343 | S_pp = self.emb_func(S_pp) #[bk1] <== [bki] 344 | S_pp = S_pp[:,:,0] #[bk] <== [bk1] 345 | else: 346 | S_pp = torch.sum(S_pp,axis=1) #[bk] <== [bik] 347 | 348 | # augment prediction scores by adding a margin of 1 to unseen classes and -1 to seen classes 349 | if self.is_bias: 350 | self.vec_bias = self.mask_bias*self.bias 351 | S_pp = S_pp + self.vec_bias 352 | 353 | ## spatial attention supervision 354 | Pred_att = torch.einsum('iv,vf,bif->bi',V_n,self.W_1,F_p) 355 | package1 = {'S_pp':S_pp,'Pred_att':Pred_att,'S_b_pp':S_b_pp,'A_p':A_p,'A':A,'S_attr':S_attr,'visualf_ori':visualf_ori,'a_v':F_p} 356 | 357 | ##########################Image-Text################################ 358 | 359 | ## Compute attribute score on each image region 360 | 361 | S = torch.einsum('bfr,fv,iv->bri',Fs,self.W_1_1,V_n) # batchx49x312 362 | # S = torch.einsum('iv,vf,bfr->bir',V_n,self.W_1_1,Fs) 363 | if self.is_sigmoid: 364 | S=torch.sigmoid(S) 365 | 366 | ## Ablation setting 367 | # A_b = Fs.new_full((B,self.dim_att,R),1/R) 368 | # A_b_p = self.att.new_full((B,self.dim_att),fill_value = 1) 369 | # S_b_p = torch.einsum('bir,bir->bi',A_b,S) 370 | # S_b_pp = torch.einsum('ki,bi,bi->bk',self.att,A_b_p,S_b_p) 371 | ## 372 | 373 | ## compute Dense Attention 374 | # A = torch.einsum('bfr,fv,iv->bri',Fs,self.W_1_1,V_n) # batchx49x312 375 | A = torch.einsum('iv,vf,bfr->bir',V_n,self.W_2_1,Fs) 376 | A = F.softmax(A,dim = 1) # compute an attention map for each attribute 377 | 378 | v_a = torch.einsum('bir,iv->brv',A,V_n) # compute attribute-based 
features 379 | 380 | S_p = torch.einsum('bir,bri->bi',A,S) # compute attribute scores from attribute attention maps 381 | 382 | if self.non_linear_act: # false 383 | S_p = F.relu(S_p) 384 | ## 385 | 386 | ## compute Attention over Attribute 387 | # A_p = torch.einsum('bfr,fv,brv->br',Fs,self.W_3_1,F_p) #eq. 6 388 | # A_p = torch.sigmoid(A_p) 389 | ## 390 | 391 | 392 | # S_pp = torch.einsum('ki,br,br->brk',self.att,A_p,S_p) # compute the final prediction as the product of semantic scores, attribute scores, and attention over attribute scores 393 | S_pp = torch.einsum('ki,bi->bik',self.att,S_p) # compute the final prediction as the product of semantic scores, attribute scores, and attention over attribute scores 394 | 395 | S_attr = 0#torch.einsum('bi,bi->bi',A_b_p,S_p) 396 | 397 | if self.non_linear_emb: 398 | S_pp = torch.transpose(S_pp,2,1) #[bki] <== [bik] 399 | S_pp = self.emb_func(S_pp) #[bk1] <== [bki] 400 | S_pp = S_pp[:,:,0] #[bk] <== [bk1] 401 | else: 402 | S_pp = torch.sum(S_pp,axis=1) #[bk] <== [bik] 403 | 404 | # augment prediction scores by adding a margin of 1 to unseen classes and -1 to seen classes 405 | if self.is_bias: 406 | self.vec_bias = self.mask_bias*self.bias 407 | S_pp = S_pp + self.vec_bias 408 | 409 | ## spatial attention supervision 410 | Pred_att = 0#torch.einsum('brv,fv,iv->br',F_p,self.W_1_1,V_n) 411 | package2 = {'S_pp':S_pp,'v_a':v_a} 412 | 413 | return package1, package2 414 | 415 | # %% 416 | # -------------------------------------------------------------------------------- /core/SUNDataLoader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Aug 1 12:11:40 2019 5 | 6 | @author: war-machince 7 | """ 8 | 9 | import os,sys 10 | #import scipy.io as sio 11 | import torch 12 | import numpy as np 13 | import h5py 14 | import time 15 | import pickle 16 | from sklearn import preprocessing 17 | #%% 18 | import pdb 19 | #%% 20 | 21 | class SUNDataLoader(): 22 | def __init__(self, data_path, device, is_scale = False, is_unsupervised_attr = False,is_balance=True): 23 | 24 | print(data_path) 25 | sys.path.append(data_path) 26 | 27 | self.data_path = data_path 28 | self.device = device 29 | self.dataset = 'SUN' 30 | print('$'*30) 31 | print(self.dataset) 32 | print('$'*30) 33 | self.datadir = self.data_path + 'data/{}/'.format(self.dataset) 34 | self.index_in_epoch = 0 35 | self.epochs_completed = 0 36 | self.is_scale = is_scale 37 | self.is_balance = is_balance 38 | if self.is_balance: 39 | print('Balance dataloader') 40 | self.is_unsupervised_attr = is_unsupervised_attr 41 | self.read_matdataset() 42 | self.get_idx_classes() 43 | self.I = torch.eye(self.allclasses.size(0)).to(device) 44 | 45 | def next_batch(self, batch_size): 46 | if self.is_balance: 47 | idx = [] 48 | n_samples_class = max(batch_size //self.ntrain_class,1) 49 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 50 | for i_c in sampled_idx_c: 51 | idxs = self.idxs_list[i_c] 52 | idx.append(np.random.choice(idxs,n_samples_class)) 53 | idx = np.concatenate(idx) 54 | idx = torch.from_numpy(idx) 55 | else: 56 | idx = torch.randperm(self.ntrain)[0:batch_size] 57 | 58 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 59 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 60 | batch_att = self.att[batch_label].to(self.device) 61 | return batch_label, batch_feature, batch_att 
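For orientation, here is a minimal sketch of how this class-balanced sampler is typically driven in a training loop. The optimizer choice, learning rate, `n_iters`, the batch size of 50, and the equal weighting of the two stream losses are illustrative assumptions, not code from this repository; `SUNDataLoader.next_batch`, the two-package forward signature, and `compute_loss` follow core/MSDN.py and this file.

    # Hypothetical training-loop sketch -- assumptions noted above.
    loader = SUNDataLoader('./', device)
    model = MSDN(...).to(device)   # constructor arguments as defined in core/MSDN.py
    optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4)  # assumed optimizer/lr
    for it in range(n_iters):      # n_iters: placeholder iteration budget
        batch_label, batch_feature, batch_att = loader.next_batch(50)
        package1, package2 = model(batch_feature.to(device))  # text->image and image->text streams
        loss1 = model.compute_loss({'batch_label': batch_label, 'S_pp': package1['S_pp']})['loss']
        loss2 = model.compute_loss({'batch_label': batch_label, 'S_pp': package2['S_pp']})['loss']
        loss = loss1 + loss2       # assumed equal weighting of the two streams
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
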
62 | 63 | def get_idx_classes(self): 64 | n_classes = self.seenclasses.size(0) 65 | self.idxs_list = [] 66 | train_label = self.data['train_seen']['labels'] 67 | for i in range(n_classes): 68 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 69 | idx_c = np.squeeze(idx_c) 70 | self.idxs_list.append(idx_c) 71 | return self.idxs_list 72 | 73 | # def next_batch_mix_up(self,batch_size): 74 | # Y1,S1,_=self.next_batch(batch_size) 75 | # Y2,S2,_=self.next_batch(batch_size) 76 | # S,Y=mix_up(S1,S2,Y1,Y2) 77 | # return Y,S,None 78 | 79 | def read_matdataset(self): 80 | 81 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 82 | print('_____') 83 | print(path) 84 | # tic = time.time() 85 | hf = h5py.File(path, 'r') 86 | features = np.array(hf.get('feature_map')) 87 | # shape = features.shape 88 | # features = features.reshape(shape[0],shape[1],shape[2]*shape[3]) 89 | labels = np.array(hf.get('labels')) 90 | trainval_loc = np.array(hf.get('trainval_loc')) 91 | # train_loc = np.array(hf.get('train_loc')) #--> train_feature = TRAIN SEEN 92 | # val_unseen_loc = np.array(hf.get('val_unseen_loc')) #--> test_unseen_feature = TEST UNSEEN 93 | test_seen_loc = np.array(hf.get('test_seen_loc')) 94 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 95 | 96 | if self.is_unsupervised_attr: 97 | print('Unsupervised Attr') 98 | class_path = './w2v/{}_class.pkl'.format(self.dataset) 99 | with open(class_path,'rb') as f: 100 | w2v_class = pickle.load(f) 101 | assert w2v_class.shape == (50,300) 102 | w2v_class = torch.tensor(w2v_class).float() 103 | 104 | U, s, V = torch.svd(w2v_class) 105 | reconstruct = torch.mm(torch.mm(U,torch.diag(s)),torch.transpose(V,1,0)) 106 | print('sanity check: {}'.format(torch.norm(reconstruct-w2v_class).item())) 107 | 108 | print('shape U:{} V:{}'.format(U.size(),V.size())) 109 | print('s: {}'.format(s)) 110 | 111 | self.w2v_att = torch.transpose(V,1,0).to(self.device) 112 | self.att = torch.mm(U,torch.diag(s)).to(self.device) 113 | self.normalize_att = torch.mm(U,torch.diag(s)).to(self.device) 114 | 115 | else: 116 | print('Expert Attr') 117 | att = np.array(hf.get('att')) 118 | self.att = torch.from_numpy(att).float().to(self.device) 119 | 120 | original_att = np.array(hf.get('original_att')) 121 | self.original_att = torch.from_numpy(original_att).float().to(self.device) 122 | 123 | w2v_att = np.array(hf.get('w2v_att')) 124 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 125 | 126 | self.normalize_att = self.original_att/100 127 | 128 | # print('Finish loading data in ',time.time()-tic) 129 | 130 | train_feature = features[trainval_loc] 131 | test_seen_feature = features[test_seen_loc] 132 | test_unseen_feature = features[test_unseen_loc] 133 | if self.is_scale: 134 | scaler = preprocessing.MinMaxScaler() 135 | 136 | train_feature = scaler.fit_transform(train_feature) 137 | test_seen_feature = scaler.fit_transform(test_seen_feature) 138 | test_unseen_feature = scaler.fit_transform(test_unseen_feature) 139 | 140 | train_feature = torch.from_numpy(train_feature).float() #.to(self.device) 141 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 142 | test_unseen_feature = torch.from_numpy(test_unseen_feature) #.float().to(self.device) 143 | 144 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 145 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 146 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) 
#.long().to(self.device) 147 | 148 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 149 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 150 | self.ntrain = train_feature.size()[0] 151 | self.ntrain_class = self.seenclasses.size(0) 152 | self.ntest_class = self.unseenclasses.size(0) 153 | self.train_class = self.seenclasses.clone() 154 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 155 | 156 | # self.train_mapped_label = map_label(train_label, self.seenclasses) 157 | 158 | self.data = {} 159 | self.data['train_seen'] = {} 160 | self.data['train_seen']['resnet_features'] = train_feature 161 | self.data['train_seen']['labels']= train_label 162 | 163 | 164 | self.data['train_unseen'] = {} 165 | self.data['train_unseen']['resnet_features'] = None 166 | self.data['train_unseen']['labels'] = None 167 | 168 | self.data['test_seen'] = {} 169 | self.data['test_seen']['resnet_features'] = test_seen_feature 170 | self.data['test_seen']['labels'] = test_seen_label 171 | 172 | self.data['test_unseen'] = {} 173 | self.data['test_unseen']['resnet_features'] = test_unseen_feature 174 | self.data['test_unseen']['labels'] = test_unseen_label 175 | -------------------------------------------------------------------------------- /core/helper_MSDN_AWA2.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | #%% visualization package 5 | from scipy import ndimage 6 | from torchvision import transforms 7 | from PIL import Image 8 | import matplotlib.pyplot as plt 9 | import skimage.transform 10 | import torch.nn.functional as F 11 | #%% 12 | import pandas as pd 13 | #%% 14 | import pdb 15 | #%% 16 | 17 | def mix_up(S1,S2,Y1,Y2): # S: bdwh Y: bk 18 | device = S1.device 19 | n = S1.size(0) 20 | m = torch.empty(n).uniform_().to(device) 21 | S = torch.einsum('bdwh,b-> bdwh',S1,m) + torch.einsum('bdwh,b-> bdwh',S2,1-m) 22 | Y = torch.einsum('bk,b-> bk',Y1,m) + torch.einsum('bk,b-> bk',Y2,1-m) 23 | return S,Y 24 | 25 | def val_gzsl(test_X, test_label, target_classes,in_package,bias = 0): 26 | 27 | batch_size = in_package['batch_size'] 28 | model = in_package['model'] 29 | device = in_package['device'] 30 | with torch.no_grad(): 31 | start = 0 32 | ntest = test_X.size()[0] 33 | predicted_label = torch.LongTensor(test_label.size()) 34 | for i in range(0, ntest, batch_size): 35 | 36 | end = min(ntest, start+batch_size) 37 | 38 | input = test_X[start:end].to(device) 39 | 40 | out_package1, out_package2= model(input) 41 | 42 | # if type(output) == tuple: # if model return multiple output, take the first one 43 | # output = output[0] 44 | #output = out_package1['S_pp'] 45 | output = out_package1['S_pp'] 46 | output[:,target_classes] = output[:,target_classes]+bias 47 | predicted_label[start:end] = torch.argmax(output.data, 1) 48 | 49 | start = end 50 | 51 | acc = compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package) 52 | return acc 53 | 54 | def map_label(label, classes): 55 | mapped_label = torch.LongTensor(label.size()).fill_(-1) 56 | for i in range(classes.size(0)): 57 | mapped_label[label==classes[i]] = i 58 | 59 | return mapped_label 60 | 61 | def val_zs_gzsl(test_X, test_label, unseen_classes,in_package,bias = 0): 62 | batch_size = in_package['batch_size'] 63 | model = in_package['model'] 64 | device = in_package['device'] 65 | with torch.no_grad(): 66 | start = 0 67 | ntest = 
test_X.size()[0] 68 | predicted_label_gzsl = torch.LongTensor(test_label.size()) 69 | predicted_label_zsl = torch.LongTensor(test_label.size()) 70 | predicted_label_zsl_t = torch.LongTensor(test_label.size()) 71 | for i in range(0, ntest, batch_size): 72 | 73 | end = min(ntest, start+batch_size) 74 | 75 | input = test_X[start:end].to(device) 76 | 77 | out_package1,out_package2 = model(input) 78 | 79 | # if type(output) == tuple: # if model return multiple output, take the first one 80 | # output = output[0] 81 | # 82 | #output = out_package1['S_pp'] 83 | output = out_package1['S_pp'] 84 | 85 | output_t = output.clone() 86 | output_t[:,unseen_classes] = output_t[:,unseen_classes]+torch.max(output)+1 87 | predicted_label_zsl[start:end] = torch.argmax(output_t.data, 1) 88 | predicted_label_zsl_t[start:end] = torch.argmax(output.data[:,unseen_classes], 1) 89 | 90 | output[:,unseen_classes] = output[:,unseen_classes]+bias 91 | predicted_label_gzsl[start:end] = torch.argmax(output.data, 1) 92 | 93 | 94 | start = end 95 | acc_gzsl = compute_per_class_acc_gzsl(test_label, predicted_label_gzsl, unseen_classes, in_package) 96 | acc_zs = compute_per_class_acc_gzsl(test_label, predicted_label_zsl, unseen_classes, in_package) 97 | acc_zs_t = compute_per_class_acc(map_label(test_label, unseen_classes), predicted_label_zsl_t, unseen_classes.size(0)) 98 | 99 | # assert np.abs(acc_zs - acc_zs_t) < 0.001 100 | #print('acc_zs: {} acc_zs_t: {}'.format(acc_zs,acc_zs_t)) 101 | return acc_gzsl,acc_zs_t 102 | 103 | def compute_per_class_acc(test_label, predicted_label, nclass): 104 | acc_per_class = torch.FloatTensor(nclass).fill_(0) 105 | for i in range(nclass): 106 | idx = (test_label == i) 107 | acc_per_class[i] = torch.sum(test_label[idx]==predicted_label[idx]).float() / torch.sum(idx).float() 108 | return acc_per_class.mean().item() 109 | 110 | def compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package): 111 | 112 | device = in_package['device'] 113 | per_class_accuracies = torch.zeros(target_classes.size()[0]).float().to(device).detach() 114 | 115 | predicted_label = predicted_label.to(device) 116 | 117 | for i in range(target_classes.size()[0]): 118 | 119 | is_class = test_label == target_classes[i] 120 | 121 | per_class_accuracies[i] = torch.div((predicted_label[is_class]==test_label[is_class]).sum().float(),is_class.sum().float()) 122 | # pdb.set_trace() 123 | return per_class_accuracies.mean().item() 124 | 125 | def eval_zs_gzsl(dataloader,model,device,bias_seen,bias_unseen): 126 | model.eval() 127 | # print('bias_seen {} bias_unseen {}'.format(bias_seen,bias_unseen)) 128 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 129 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 130 | 131 | test_unseen_feature = dataloader.data['test_unseen']['resnet_features'] 132 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 133 | 134 | seenclasses = dataloader.seenclasses 135 | unseenclasses = dataloader.unseenclasses 136 | 137 | batch_size = 100 138 | 139 | in_package = {'model':model,'device':device, 'batch_size':batch_size} 140 | 141 | with torch.no_grad(): 142 | acc_seen = val_gzsl(test_seen_feature, test_seen_label, seenclasses, in_package,bias=bias_seen) 143 | acc_novel,acc_zs = val_zs_gzsl(test_unseen_feature, test_unseen_label, unseenclasses, in_package,bias = bias_unseen) 144 | 145 | if (acc_seen+acc_novel)>0: 146 | H = (2*acc_seen*acc_novel) / (acc_seen+acc_novel) 147 | else: 148 | H = 0 149 | 150 | return acc_seen, 
acc_novel, H, acc_zs 151 | 152 | def get_heatmap(dataloader,model,device): 153 | model.eval() 154 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 155 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 156 | 157 | test_unseen_feature = dataloader.data['test_unseen']['resnet_features'] 158 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 159 | 160 | seenclasses = dataloader.seenclasses 161 | unseenclasses = dataloader.unseenclasses 162 | 163 | eval_size = 100 164 | n_classes = model.nclass 165 | n_atts = model.dim_att 166 | 167 | heatmap_seen = torch.zeros((n_classes,n_atts)) 168 | heatmap_unseen = torch.zeros((n_classes,n_atts)) 169 | 170 | with torch.no_grad(): 171 | for c in seenclasses: 172 | idx_c = torch.squeeze(torch.nonzero(test_seen_label == c))[:eval_size] 173 | 174 | batch_c_samples = test_seen_feature[idx_c].to(device) 175 | out_package = model(batch_c_samples) 176 | A_p = out_package['A_p'] 177 | heatmap_seen[c] += torch.mean(A_p,dim=0).cpu() 178 | 179 | for c in unseenclasses: 180 | idx_c = torch.squeeze(torch.nonzero(test_unseen_label == c))[:eval_size] 181 | 182 | batch_c_samples = test_unseen_feature[idx_c].to(device) 183 | out_package = model(batch_c_samples) 184 | A_p = out_package['A_p'] 185 | heatmap_unseen[c] += torch.mean(A_p,dim=0).cpu() 186 | 187 | return heatmap_seen.cpu().numpy(),heatmap_unseen.cpu().numpy() 188 | 189 | def val_gzsl_k(k,test_X, test_label, target_classes,in_package,bias = 0,is_detect=False): 190 | batch_size = in_package['batch_size'] 191 | model = in_package['model'] 192 | device = in_package['device'] 193 | n_classes = in_package["num_class"] 194 | 195 | with torch.no_grad(): 196 | start = 0 197 | ntest = test_X.size()[0] 198 | test_label = F.one_hot(test_label, num_classes=n_classes) 199 | predicted_label = torch.LongTensor(test_label.size()).fill_(0).to(test_label.device) 200 | for i in range(0, ntest, batch_size): 201 | 202 | end = min(ntest, start+batch_size) 203 | 204 | input = test_X[start:end].to(device) 205 | 206 | out_package1, out_package2= model(input) 207 | 208 | # if type(output) == tuple: # if model return multiple output, take the first one 209 | # output = output[0] 210 | # 211 | #output = out_package1['S_pp'] 212 | output = out_package1['S_pp'] 213 | output[:,target_classes] = output[:,target_classes]+bias 214 | # predicted_label[start:end] = torch.argmax(output.data, 1) 215 | _,idx_k = torch.topk(output,k,dim=1) 216 | if is_detect: 217 | assert k == 1 218 | detection_mask=in_package["detection_mask"] 219 | predicted_label[start:end] = detection_mask[torch.argmax(output.data, 1)] 220 | else: 221 | predicted_label[start:end] = predicted_label[start:end].scatter_(1,idx_k,1) 222 | start = end 223 | 224 | acc = compute_per_class_acc_gzsl_k(test_label, predicted_label, target_classes, in_package) 225 | return acc 226 | 227 | def val_zs_gzsl_k(k,test_X, test_label, unseen_classes,in_package,bias = 0,is_detect=False): 228 | batch_size = in_package['batch_size'] 229 | model = in_package['model'] 230 | device = in_package['device'] 231 | n_classes = in_package["num_class"] 232 | with torch.no_grad(): 233 | start = 0 234 | ntest = test_X.size()[0] 235 | 236 | test_label_gzsl = F.one_hot(test_label, num_classes=n_classes) 237 | predicted_label_gzsl = torch.LongTensor(test_label_gzsl.size()).fill_(0).to(test_label.device) 238 | 239 | predicted_label_zsl = torch.LongTensor(test_label.size()) 240 | predicted_label_zsl_t = torch.LongTensor(test_label.size()) 241 | for i in 
range(0, ntest, batch_size): 242 | 243 | end = min(ntest, start+batch_size) 244 | 245 | input = test_X[start:end].to(device) 246 | 247 | out_package1,out_package2 = model(input) 248 | 249 | # if type(output) == tuple: # if model return multiple output, take the first one 250 | # output = output[0] 251 | # 252 | #output = out_package1['S_pp'] 253 | output = out_package1['S_pp'] 254 | output_t = output.clone() 255 | output_t[:,unseen_classes] = output_t[:,unseen_classes]+torch.max(output)+1 256 | predicted_label_zsl[start:end] = torch.argmax(output_t.data, 1) 257 | predicted_label_zsl_t[start:end] = torch.argmax(output.data[:,unseen_classes], 1) 258 | 259 | output[:,unseen_classes] = output[:,unseen_classes]+bias 260 | # predicted_label_gzsl[start:end] = torch.argmax(output.data, 1) 261 | _,idx_k = torch.topk(output,k,dim=1) 262 | if is_detect: 263 | assert k == 1 264 | detection_mask=in_package["detection_mask"] 265 | predicted_label_gzsl[start:end] = detection_mask[torch.argmax(output.data, 1)] 266 | else: 267 | predicted_label_gzsl[start:end] = predicted_label_gzsl[start:end].scatter_(1,idx_k,1) 268 | 269 | start = end 270 | 271 | acc_gzsl = compute_per_class_acc_gzsl_k(test_label_gzsl, predicted_label_gzsl, unseen_classes, in_package) 272 | #print('acc_zs: {} acc_zs_t: {}'.format(acc_zs,acc_zs_t)) 273 | return acc_gzsl,-1 274 | 275 | def compute_per_class_acc_k(test_label, predicted_label, nclass): 276 | acc_per_class = torch.FloatTensor(nclass).fill_(0) 277 | for i in range(nclass): 278 | idx = (test_label == i) 279 | acc_per_class[i] = torch.sum(test_label[idx]==predicted_label[idx]).float() / torch.sum(idx).float() 280 | return acc_per_class.mean().item() 281 | 282 | def compute_per_class_acc_gzsl_k(test_label, predicted_label, target_classes, in_package): 283 | device = in_package['device'] 284 | per_class_accuracies = torch.zeros(target_classes.size()[0]).float().to(device).detach() 285 | 286 | predicted_label = predicted_label.to(device) 287 | 288 | hit = test_label*predicted_label 289 | for i in range(target_classes.size()[0]): 290 | 291 | # is_class = test_label == target_classes[i] 292 | target = target_classes[i] 293 | n_pos = torch.sum(hit[:,target]) 294 | n_gt = torch.sum(test_label[:,target]) 295 | per_class_accuracies[i] = torch.div(n_pos.float(),n_gt.float()) 296 | #pdb.set_trace() 297 | return per_class_accuracies.mean().item() 298 | 299 | def eval_zs_gzsl_k(k,dataloader,model,device,bias_seen,bias_unseen,is_detect=False): 300 | model.eval() 301 | print('bias_seen {} bias_unseen {}'.format(bias_seen,bias_unseen)) 302 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 303 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 304 | 305 | test_unseen_feature = dataloader.data['test_unseen']['resnet_features'] 306 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 307 | 308 | seenclasses = dataloader.seenclasses 309 | unseenclasses = dataloader.unseenclasses 310 | 311 | batch_size = 100 312 | n_classes = dataloader.ntrain_class+dataloader.ntest_class 313 | in_package = {'model':model,'device':device, 'batch_size':batch_size,'num_class':n_classes} 314 | 315 | if is_detect: 316 | print("Measure novelty detection k: {}".format(k)) 317 | 318 | detection_mask = torch.zeros((n_classes,n_classes)).long().to(dataloader.device) 319 | detect_label = torch.zeros(n_classes).long().to(dataloader.device) 320 | detect_label[seenclasses]=1 321 | detection_mask[seenclasses,:] = detect_label 322 | 323 | detect_label = 
torch.zeros(n_classes).long().to(dataloader.device) 324 | detect_label[unseenclasses]=1 325 | detection_mask[unseenclasses,:]=detect_label 326 | in_package["detection_mask"]=detection_mask 327 | 328 | with torch.no_grad(): 329 | acc_seen = val_gzsl_k(k,test_seen_feature, test_seen_label, seenclasses, in_package,bias=bias_seen,is_detect=is_detect) 330 | acc_novel,acc_zs = val_zs_gzsl_k(k,test_unseen_feature, test_unseen_label, unseenclasses, in_package,bias = bias_unseen,is_detect=is_detect) 331 | 332 | if (acc_seen+acc_novel)>0: 333 | H = (2*acc_seen*acc_novel) / (acc_seen+acc_novel) 334 | else: 335 | H = 0 336 | 337 | return acc_seen, acc_novel, H, acc_zs 338 | 339 | def compute_entropy(V): 340 | eps = 1e-7 341 | mass = torch.sum(V,dim = 1, keepdim = True) 342 | att_n = torch.div(V,mass) 343 | e = att_n * torch.log(att_n+eps) 344 | e = -1.0 * torch.sum(e,dim=1) 345 | # e = torch.mean(e) 346 | return e 347 | 348 | def get_lr(optimizer): 349 | lr = [] 350 | for param_group in optimizer.param_groups: 351 | lr.append(param_group['lr']) 352 | return lr 353 | 354 | input_size = 224 355 | data_transforms = transforms.Compose([ 356 | transforms.Resize(input_size), 357 | transforms.CenterCrop(input_size), 358 | transforms.ToTensor() 359 | ]) 360 | 361 | def visualize_attention(img_ids,alphas_1,alphas_2,S,n_top_attr,attr_name,attr,save_path=None,is_top=True): #alphas_1: [bir] alphas_2: [bi] 362 | n = img_ids.shape[0] 363 | image_size = 14*16 #one side of the img 364 | assert alphas_1.shape[1] == alphas_2.shape[1] == len(attr_name) 365 | r = alphas_1.shape[2] 366 | h = w = int(np.sqrt(r)) 367 | for i in range(n): 368 | fig=plt.figure(i,figsize=(20, 10)) 369 | file_path=img_ids[i]#.decode('utf-8') 370 | img_name = file_path.split("/")[-1] 371 | # file_path = img_path+str_id+'.jpg' 372 | alpha_1 = alphas_1[i] #[ir] 373 | alpha_2 = alphas_2[i] #[i] 374 | score = S[i] 375 | # Plot original image 376 | image = Image.open(file_path) 377 | if image.mode == 'L': 378 | image=image.convert('RGB') 379 | image = data_transforms(image) 380 | image = image.permute(1,2,0) #[224,244,3] <== [3,224,224] 381 | ax = plt.subplot(4, 5, 1) 382 | plt.imshow(image) 383 | ax.set_title(img_name,{'fontsize': 10}) 384 | # plt.axis('off') 385 | 386 | if is_top: 387 | idxs_top=np.argsort(-alpha_2)[:n_top_attr] 388 | else: 389 | idxs_top=np.argsort(alpha_2)[:n_top_attr] 390 | 391 | #pdb.set_trace() 392 | for idx_ctxt,idx_attr in enumerate(idxs_top): 393 | ax=plt.subplot(4, 5, idx_ctxt+2) 394 | plt.imshow(image) 395 | alp_curr = alpha_1[idx_attr,:].reshape(7,7) 396 | alp_img = skimage.transform.pyramid_expand(alp_curr, upscale=image_size/h, sigma=10,multichannel=False) 397 | plt.imshow(alp_img, alpha=0.7) 398 | ax.set_title("{}\n{}\n{}-{}".format(attr_name[idx_attr],alpha_2[idx_attr],score[idx_attr],attr[idx_attr]),{'fontsize': 10}) 399 | # plt.axis('off') 400 | fig.tight_layout() 401 | if save_path is not None: 402 | plt.savefig(save_path+img_name,dpi=500) 403 | plt.close() 404 | 405 | class Logger: 406 | def __init__(self,filename,cols,is_save=True): 407 | self.df = pd.DataFrame() 408 | self.cols = cols 409 | self.filename=filename 410 | self.is_save=is_save 411 | def add(self,values): 412 | self.df=self.df.append(pd.DataFrame([values],columns=self.cols),ignore_index=True) 413 | def save(self): 414 | if self.is_save: 415 | self.df.to_csv(self.filename) 416 | def get_max(self,col): 417 | return np.max(self.df[col]) 418 | 419 | def is_max(self,col): 420 | return self.df[col].iloc[-1] >= np.max(self.df[col]) 421 | 422 | def 
get_attr_entropy(att): #the lower the more discriminative it is 423 | eps = 1e-8 424 | mass=np.sum(att,axis = 0,keepdims=True) 425 | att_n = np.divide(att,mass+eps) 426 | entropy = np.sum(-att_n*np.log(att_n+eps),axis=0) 427 | assert len(entropy.shape)==1 428 | return entropy -------------------------------------------------------------------------------- /core/helper_MSDN_CUB.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | #%% visualization package 5 | from scipy import ndimage 6 | from torchvision import transforms 7 | from PIL import Image 8 | import matplotlib.pyplot as plt 9 | import skimage.transform 10 | import torch.nn.functional as F 11 | #%% 12 | import pandas as pd 13 | #%% 14 | import pdb 15 | #%% 16 | 17 | def mix_up(S1,S2,Y1,Y2): # S: bdwh Y: bk 18 | device = S1.device 19 | n = S1.size(0) 20 | m = torch.empty(n).uniform_().to(device) 21 | S = torch.einsum('bdwh,b-> bdwh',S1,m) + torch.einsum('bdwh,b-> bdwh',S2,1-m) 22 | Y = torch.einsum('bk,b-> bk',Y1,m) + torch.einsum('bk,b-> bk',Y2,1-m) 23 | return S,Y 24 | 25 | def val_gzsl(test_X, test_label, target_classes,in_package,bias = 0): 26 | 27 | batch_size = in_package['batch_size'] 28 | model = in_package['model'] 29 | device = in_package['device'] 30 | with torch.no_grad(): 31 | start = 0 32 | ntest = test_X.size()[0] 33 | predicted_label = torch.LongTensor(test_label.size()) 34 | for i in range(0, ntest, batch_size): 35 | 36 | end = min(ntest, start+batch_size) 37 | 38 | input = test_X[start:end].to(device) 39 | 40 | out_package1, out_package2= model(input) 41 | 42 | # if type(output) == tuple: # if model return multiple output, take the first one 43 | # output = output[0] 44 | #output = out_package1['S_pp'] 45 | output = 0.9*out_package1['S_pp']+0.1*out_package2['S_pp'] 46 | output[:,target_classes] = output[:,target_classes]+bias 47 | predicted_label[start:end] = torch.argmax(output.data, 1) 48 | 49 | start = end 50 | 51 | acc = compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package) 52 | return acc 53 | 54 | def map_label(label, classes): 55 | mapped_label = torch.LongTensor(label.size()).fill_(-1) 56 | for i in range(classes.size(0)): 57 | mapped_label[label==classes[i]] = i 58 | 59 | return mapped_label 60 | 61 | def val_zs_gzsl(test_X, test_label, unseen_classes,in_package,bias = 0): 62 | batch_size = in_package['batch_size'] 63 | model = in_package['model'] 64 | device = in_package['device'] 65 | with torch.no_grad(): 66 | start = 0 67 | ntest = test_X.size()[0] 68 | predicted_label_gzsl = torch.LongTensor(test_label.size()) 69 | predicted_label_zsl = torch.LongTensor(test_label.size()) 70 | predicted_label_zsl_t = torch.LongTensor(test_label.size()) 71 | for i in range(0, ntest, batch_size): 72 | 73 | end = min(ntest, start+batch_size) 74 | 75 | input = test_X[start:end].to(device) 76 | 77 | out_package1,out_package2 = model(input) 78 | 79 | # if type(output) == tuple: # if model return multiple output, take the first one 80 | # output = output[0] 81 | # 82 | #output = out_package1['S_pp'] 83 | output = 0.9*out_package1['S_pp']+0.1*out_package2['S_pp'] 84 | 85 | output_t = output.clone() 86 | output_t[:,unseen_classes] = output_t[:,unseen_classes]+torch.max(output)+1 87 | predicted_label_zsl[start:end] = torch.argmax(output_t.data, 1) 88 | predicted_label_zsl_t[start:end] = torch.argmax(output.data[:,unseen_classes], 1) 89 | 90 | output[:,unseen_classes] = output[:,unseen_classes]+bias 91 | 
predicted_label_gzsl[start:end] = torch.argmax(output.data, 1) 92 | 93 | 94 | start = end 95 | acc_gzsl = compute_per_class_acc_gzsl(test_label, predicted_label_gzsl, unseen_classes, in_package) 96 | acc_zs = compute_per_class_acc_gzsl(test_label, predicted_label_zsl, unseen_classes, in_package) 97 | acc_zs_t = compute_per_class_acc(map_label(test_label, unseen_classes), predicted_label_zsl_t, unseen_classes.size(0)) 98 | 99 | # assert np.abs(acc_zs - acc_zs_t) < 0.001 100 | #print('acc_zs: {} acc_zs_t: {}'.format(acc_zs,acc_zs_t)) 101 | return acc_gzsl,acc_zs_t 102 | 103 | def compute_per_class_acc(test_label, predicted_label, nclass): 104 | acc_per_class = torch.FloatTensor(nclass).fill_(0) 105 | for i in range(nclass): 106 | idx = (test_label == i) 107 | acc_per_class[i] = torch.sum(test_label[idx]==predicted_label[idx]).float() / torch.sum(idx).float() 108 | return acc_per_class.mean().item() 109 | 110 | def compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package): 111 | 112 | device = in_package['device'] 113 | per_class_accuracies = torch.zeros(target_classes.size()[0]).float().to(device).detach() 114 | 115 | predicted_label = predicted_label.to(device) 116 | 117 | for i in range(target_classes.size()[0]): 118 | 119 | is_class = test_label == target_classes[i] 120 | 121 | per_class_accuracies[i] = torch.div((predicted_label[is_class]==test_label[is_class]).sum().float(),is_class.sum().float()) 122 | # pdb.set_trace() 123 | return per_class_accuracies.mean().item() 124 | 125 | def eval_zs_gzsl(dataloader,model,device,bias_seen,bias_unseen): 126 | model.eval() 127 | # print('bias_seen {} bias_unseen {}'.format(bias_seen,bias_unseen)) 128 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 129 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 130 | 131 | test_unseen_feature = dataloader.data['test_unseen']['resnet_features'] 132 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 133 | 134 | seenclasses = dataloader.seenclasses 135 | unseenclasses = dataloader.unseenclasses 136 | 137 | batch_size = 100 138 | 139 | in_package = {'model':model,'device':device, 'batch_size':batch_size} 140 | 141 | with torch.no_grad(): 142 | acc_seen = val_gzsl(test_seen_feature, test_seen_label, seenclasses, in_package,bias=bias_seen) 143 | acc_novel,acc_zs = val_zs_gzsl(test_unseen_feature, test_unseen_label, unseenclasses, in_package,bias = bias_unseen) 144 | 145 | if (acc_seen+acc_novel)>0: 146 | H = (2*acc_seen*acc_novel) / (acc_seen+acc_novel) 147 | else: 148 | H = 0 149 | 150 | return acc_seen, acc_novel, H, acc_zs 151 | 152 | def get_heatmap(dataloader,model,device): 153 | model.eval() 154 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 155 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 156 | 157 | test_unseen_feature = dataloader.data['test_unseen']['resnet_features'] 158 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 159 | 160 | seenclasses = dataloader.seenclasses 161 | unseenclasses = dataloader.unseenclasses 162 | 163 | eval_size = 100 164 | n_classes = model.nclass 165 | n_atts = model.dim_att 166 | 167 | heatmap_seen = torch.zeros((n_classes,n_atts)) 168 | heatmap_unseen = torch.zeros((n_classes,n_atts)) 169 | 170 | with torch.no_grad(): 171 | for c in seenclasses: 172 | idx_c = torch.squeeze(torch.nonzero(test_seen_label == c))[:eval_size] 173 | 174 | batch_c_samples = test_seen_feature[idx_c].to(device) 175 | out_package = 
model(batch_c_samples) 176 | A_p = out_package['A_p'] 177 | heatmap_seen[c] += torch.mean(A_p,dim=0).cpu() 178 | 179 | for c in unseenclasses: 180 | idx_c = torch.squeeze(torch.nonzero(test_unseen_label == c))[:eval_size] 181 | 182 | batch_c_samples = test_unseen_feature[idx_c].to(device) 183 | out_package = model(batch_c_samples) 184 | A_p = out_package['A_p'] 185 | heatmap_unseen[c] += torch.mean(A_p,dim=0).cpu() 186 | 187 | return heatmap_seen.cpu().numpy(),heatmap_unseen.cpu().numpy() 188 | 189 | def val_gzsl_k(k,test_X, test_label, target_classes,in_package,bias = 0,is_detect=False): 190 | batch_size = in_package['batch_size'] 191 | model = in_package['model'] 192 | device = in_package['device'] 193 | n_classes = in_package["num_class"] 194 | 195 | with torch.no_grad(): 196 | start = 0 197 | ntest = test_X.size()[0] 198 | test_label = F.one_hot(test_label, num_classes=n_classes) 199 | predicted_label = torch.LongTensor(test_label.size()).fill_(0).to(test_label.device) 200 | for i in range(0, ntest, batch_size): 201 | 202 | end = min(ntest, start+batch_size) 203 | 204 | input = test_X[start:end].to(device) 205 | 206 | out_package1, out_package2= model(input) 207 | 208 | # if type(output) == tuple: # if model return multiple output, take the first one 209 | # output = output[0] 210 | # 211 | #output = out_package1['S_pp'] 212 | output = 0.9*out_package1['S_pp']+0.1*out_package2['S_pp'] 213 | output[:,target_classes] = output[:,target_classes]+bias 214 | # predicted_label[start:end] = torch.argmax(output.data, 1) 215 | _,idx_k = torch.topk(output,k,dim=1) 216 | if is_detect: 217 | assert k == 1 218 | detection_mask=in_package["detection_mask"] 219 | predicted_label[start:end] = detection_mask[torch.argmax(output.data, 1)] 220 | else: 221 | predicted_label[start:end] = predicted_label[start:end].scatter_(1,idx_k,1) 222 | start = end 223 | 224 | acc = compute_per_class_acc_gzsl_k(test_label, predicted_label, target_classes, in_package) 225 | return acc 226 | 227 | def val_zs_gzsl_k(k,test_X, test_label, unseen_classes,in_package,bias = 0,is_detect=False): 228 | batch_size = in_package['batch_size'] 229 | model = in_package['model'] 230 | device = in_package['device'] 231 | n_classes = in_package["num_class"] 232 | with torch.no_grad(): 233 | start = 0 234 | ntest = test_X.size()[0] 235 | 236 | test_label_gzsl = F.one_hot(test_label, num_classes=n_classes) 237 | predicted_label_gzsl = torch.LongTensor(test_label_gzsl.size()).fill_(0).to(test_label.device) 238 | 239 | predicted_label_zsl = torch.LongTensor(test_label.size()) 240 | predicted_label_zsl_t = torch.LongTensor(test_label.size()) 241 | for i in range(0, ntest, batch_size): 242 | 243 | end = min(ntest, start+batch_size) 244 | 245 | input = test_X[start:end].to(device) 246 | 247 | out_package1,out_package2 = model(input) 248 | 249 | # if type(output) == tuple: # if model return multiple output, take the first one 250 | # output = output[0] 251 | # 252 | #output = out_package1['S_pp'] 253 | output = 0.9*out_package1['S_pp']+0.1*out_package2['S_pp'] 254 | output_t = output.clone() 255 | output_t[:,unseen_classes] = output_t[:,unseen_classes]+torch.max(output)+1 256 | predicted_label_zsl[start:end] = torch.argmax(output_t.data, 1) 257 | predicted_label_zsl_t[start:end] = torch.argmax(output.data[:,unseen_classes], 1) 258 | 259 | output[:,unseen_classes] = output[:,unseen_classes]+bias 260 | # predicted_label_gzsl[start:end] = torch.argmax(output.data, 1) 261 | _,idx_k = torch.topk(output,k,dim=1) 262 | if is_detect: 263 | assert k == 1 
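                    # `detection_mask` (constructed in eval_zs_gzsl_k below) maps an argmax
                    # class prediction to a binary row over all classes: rows indexed by seen
                    # classes flag every seen class, rows indexed by unseen classes flag every
                    # unseen class. Under is_detect, the per-class accuracy therefore measures
                    # seen-vs-unseen novelty detection rather than class-level accuracy.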
264 | detection_mask=in_package["detection_mask"] 265 | predicted_label_gzsl[start:end] = detection_mask[torch.argmax(output.data, 1)] 266 | else: 267 | predicted_label_gzsl[start:end] = predicted_label_gzsl[start:end].scatter_(1,idx_k,1) 268 | 269 | start = end 270 | 271 | acc_gzsl = compute_per_class_acc_gzsl_k(test_label_gzsl, predicted_label_gzsl, unseen_classes, in_package) 272 | #print('acc_zs: {} acc_zs_t: {}'.format(acc_zs,acc_zs_t)) 273 | return acc_gzsl,-1 274 | 275 | def compute_per_class_acc_k(test_label, predicted_label, nclass): 276 | acc_per_class = torch.FloatTensor(nclass).fill_(0) 277 | for i in range(nclass): 278 | idx = (test_label == i) 279 | acc_per_class[i] = torch.sum(test_label[idx]==predicted_label[idx]).float() / torch.sum(idx).float() 280 | return acc_per_class.mean().item() 281 | 282 | def compute_per_class_acc_gzsl_k(test_label, predicted_label, target_classes, in_package): 283 | device = in_package['device'] 284 | per_class_accuracies = torch.zeros(target_classes.size()[0]).float().to(device).detach() 285 | 286 | predicted_label = predicted_label.to(device) 287 | 288 | hit = test_label*predicted_label 289 | for i in range(target_classes.size()[0]): 290 | 291 | # is_class = test_label == target_classes[i] 292 | target = target_classes[i] 293 | n_pos = torch.sum(hit[:,target]) 294 | n_gt = torch.sum(test_label[:,target]) 295 | per_class_accuracies[i] = torch.div(n_pos.float(),n_gt.float()) 296 | #pdb.set_trace() 297 | return per_class_accuracies.mean().item() 298 | 299 | def eval_zs_gzsl_k(k,dataloader,model,device,bias_seen,bias_unseen,is_detect=False): 300 | model.eval() 301 | print('bias_seen {} bias_unseen {}'.format(bias_seen,bias_unseen)) 302 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 303 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 304 | 305 | test_unseen_feature = dataloader.data['test_unseen']['resnet_features'] 306 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 307 | 308 | seenclasses = dataloader.seenclasses 309 | unseenclasses = dataloader.unseenclasses 310 | 311 | batch_size = 100 312 | n_classes = dataloader.ntrain_class+dataloader.ntest_class 313 | in_package = {'model':model,'device':device, 'batch_size':batch_size,'num_class':n_classes} 314 | 315 | if is_detect: 316 | print("Measure novelty detection k: {}".format(k)) 317 | 318 | detection_mask = torch.zeros((n_classes,n_classes)).long().to(dataloader.device) 319 | detect_label = torch.zeros(n_classes).long().to(dataloader.device) 320 | detect_label[seenclasses]=1 321 | detection_mask[seenclasses,:] = detect_label 322 | 323 | detect_label = torch.zeros(n_classes).long().to(dataloader.device) 324 | detect_label[unseenclasses]=1 325 | detection_mask[unseenclasses,:]=detect_label 326 | in_package["detection_mask"]=detection_mask 327 | 328 | with torch.no_grad(): 329 | acc_seen = val_gzsl_k(k,test_seen_feature, test_seen_label, seenclasses, in_package,bias=bias_seen,is_detect=is_detect) 330 | acc_novel,acc_zs = val_zs_gzsl_k(k,test_unseen_feature, test_unseen_label, unseenclasses, in_package,bias = bias_unseen,is_detect=is_detect) 331 | 332 | if (acc_seen+acc_novel)>0: 333 | H = (2*acc_seen*acc_novel) / (acc_seen+acc_novel) 334 | else: 335 | H = 0 336 | 337 | return acc_seen, acc_novel, H, acc_zs 338 | 339 | def compute_entropy(V): 340 | eps = 1e-7 341 | mass = torch.sum(V,dim = 1, keepdim = True) 342 | att_n = torch.div(V,mass) 343 | e = att_n * torch.log(att_n+eps) 344 | e = -1.0 * torch.sum(e,dim=1) 345 | # e = 
torch.mean(e) 346 | return e 347 | 348 | def get_lr(optimizer): 349 | lr = [] 350 | for param_group in optimizer.param_groups: 351 | lr.append(param_group['lr']) 352 | return lr 353 | 354 | input_size = 224 355 | data_transforms = transforms.Compose([ 356 | transforms.Resize(input_size), 357 | transforms.CenterCrop(input_size), 358 | transforms.ToTensor() 359 | ]) 360 | 361 | def visualize_attention(img_ids,alphas_1,alphas_2,S,n_top_attr,attr_name,attr,save_path=None,is_top=True): #alphas_1: [bir] alphas_2: [bi] 362 | n = img_ids.shape[0] 363 | image_size = 14*16 #one side of the img 364 | assert alphas_1.shape[1] == alphas_2.shape[1] == len(attr_name) 365 | r = alphas_1.shape[2] 366 | h = w = int(np.sqrt(r)) 367 | for i in range(n): 368 | fig=plt.figure(i,figsize=(20, 10)) 369 | file_path=img_ids[i]#.decode('utf-8') 370 | img_name = file_path.split("/")[-1] 371 | # file_path = img_path+str_id+'.jpg' 372 | alpha_1 = alphas_1[i] #[ir] 373 | alpha_2 = alphas_2[i] #[i] 374 | score = S[i] 375 | # Plot original image 376 | image = Image.open(file_path) 377 | if image.mode == 'L': 378 | image=image.convert('RGB') 379 | image = data_transforms(image) 380 | image = image.permute(1,2,0) #[224,244,3] <== [3,224,224] 381 | ax = plt.subplot(4, 5, 1) 382 | plt.imshow(image) 383 | ax.set_title(img_name,{'fontsize': 10}) 384 | # plt.axis('off') 385 | 386 | if is_top: 387 | idxs_top=np.argsort(-alpha_2)[:n_top_attr] 388 | else: 389 | idxs_top=np.argsort(alpha_2)[:n_top_attr] 390 | 391 | #pdb.set_trace() 392 | for idx_ctxt,idx_attr in enumerate(idxs_top): 393 | ax=plt.subplot(4, 5, idx_ctxt+2) 394 | plt.imshow(image) 395 | alp_curr = alpha_1[idx_attr,:].reshape(7,7) 396 | alp_img = skimage.transform.pyramid_expand(alp_curr, upscale=image_size/h, sigma=10,multichannel=False) 397 | plt.imshow(alp_img, alpha=0.7) 398 | ax.set_title("{}\n{}\n{}-{}".format(attr_name[idx_attr],alpha_2[idx_attr],score[idx_attr],attr[idx_attr]),{'fontsize': 10}) 399 | # plt.axis('off') 400 | fig.tight_layout() 401 | if save_path is not None: 402 | plt.savefig(save_path+img_name,dpi=500) 403 | plt.close() 404 | 405 | class Logger: 406 | def __init__(self,filename,cols,is_save=True): 407 | self.df = pd.DataFrame() 408 | self.cols = cols 409 | self.filename=filename 410 | self.is_save=is_save 411 | def add(self,values): 412 | self.df=self.df.append(pd.DataFrame([values],columns=self.cols),ignore_index=True) 413 | def save(self): 414 | if self.is_save: 415 | self.df.to_csv(self.filename) 416 | def get_max(self,col): 417 | return np.max(self.df[col]) 418 | 419 | def is_max(self,col): 420 | return self.df[col].iloc[-1] >= np.max(self.df[col]) 421 | 422 | def get_attr_entropy(att): #the lower the more discriminative it is 423 | eps = 1e-8 424 | mass=np.sum(att,axis = 0,keepdims=True) 425 | att_n = np.divide(att,mass+eps) 426 | entropy = np.sum(-att_n*np.log(att_n+eps),axis=0) 427 | assert len(entropy.shape)==1 428 | return entropy -------------------------------------------------------------------------------- /core/helper_MSDN_SUN.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import numpy as np 5 | #%% visualization package 6 | from scipy import ndimage 7 | from torchvision import transforms 8 | from PIL import Image 9 | import matplotlib.pyplot as plt 10 | import skimage.transform 11 | import torch.nn.functional as F 12 | #%% 13 | import pandas as pd 14 | #%% 15 | import pdb 16 | #%% 17 | 18 | def mix_up(S1,S2,Y1,Y2): # S: bdwh Y: bk 19 | device = S1.device 20 | n = 
S1.size(0) 21 | m = torch.empty(n).uniform_().to(device) 22 | S = torch.einsum('bdwh,b-> bdwh',S1,m) + torch.einsum('bdwh,b-> bdwh',S2,1-m) 23 | Y = torch.einsum('bk,b-> bk',Y1,m) + torch.einsum('bk,b-> bk',Y2,1-m) 24 | return S,Y 25 | 26 | def val_gzsl(test_X, test_label, target_classes,in_package,bias = 0): 27 | 28 | batch_size = in_package['batch_size'] 29 | model = in_package['model'] 30 | device = in_package['device'] 31 | with torch.no_grad(): 32 | start = 0 33 | ntest = test_X.size()[0] 34 | predicted_label = torch.LongTensor(test_label.size()) 35 | for i in range(0, ntest, batch_size): 36 | 37 | end = min(ntest, start+batch_size) 38 | 39 | input = test_X[start:end].to(device) 40 | 41 | out_package1, out_package2= model(input) 42 | 43 | # if type(output) == tuple: # if model return multiple output, take the first one 44 | # output = output[0] 45 | #output = out_package1['S_pp'] 46 | output = 0.7*out_package1['S_pp']+0.3*out_package2['S_pp'] 47 | output[:,target_classes] = output[:,target_classes]+bias 48 | predicted_label[start:end] = torch.argmax(output.data, 1) 49 | 50 | start = end 51 | 52 | acc = compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package) 53 | return acc 54 | 55 | def map_label(label, classes): 56 | mapped_label = torch.LongTensor(label.size()).fill_(-1) 57 | for i in range(classes.size(0)): 58 | mapped_label[label==classes[i]] = i 59 | 60 | return mapped_label 61 | 62 | def val_zs_gzsl(test_X, test_label, unseen_classes,in_package,bias = 0): 63 | batch_size = in_package['batch_size'] 64 | model = in_package['model'] 65 | device = in_package['device'] 66 | with torch.no_grad(): 67 | start = 0 68 | ntest = test_X.size()[0] 69 | predicted_label_gzsl = torch.LongTensor(test_label.size()) 70 | predicted_label_zsl = torch.LongTensor(test_label.size()) 71 | predicted_label_zsl_t = torch.LongTensor(test_label.size()) 72 | for i in range(0, ntest, batch_size): 73 | 74 | end = min(ntest, start+batch_size) 75 | 76 | input = test_X[start:end].to(device) 77 | 78 | out_package1,out_package2 = model(input) 79 | 80 | # if type(output) == tuple: # if model return multiple output, take the first one 81 | # output = output[0] 82 | # 83 | #output = out_package1['S_pp'] 84 | output = 0.7*out_package1['S_pp']+0.3*out_package2['S_pp'] 85 | 86 | output_t = output.clone() 87 | output_t[:,unseen_classes] = output_t[:,unseen_classes]+torch.max(output)+1 88 | predicted_label_zsl[start:end] = torch.argmax(output_t.data, 1) 89 | predicted_label_zsl_t[start:end] = torch.argmax(output.data[:,unseen_classes], 1) 90 | 91 | output[:,unseen_classes] = output[:,unseen_classes]+bias 92 | predicted_label_gzsl[start:end] = torch.argmax(output.data, 1) 93 | 94 | 95 | start = end 96 | acc_gzsl = compute_per_class_acc_gzsl(test_label, predicted_label_gzsl, unseen_classes, in_package) 97 | acc_zs = compute_per_class_acc_gzsl(test_label, predicted_label_zsl, unseen_classes, in_package) 98 | acc_zs_t = compute_per_class_acc(map_label(test_label, unseen_classes), predicted_label_zsl_t, unseen_classes.size(0)) 99 | 100 | # assert np.abs(acc_zs - acc_zs_t) < 0.001 101 | #print('acc_zs: {} acc_zs_t: {}'.format(acc_zs,acc_zs_t)) 102 | return acc_gzsl,acc_zs_t 103 | 104 | def compute_per_class_acc(test_label, predicted_label, nclass): 105 | acc_per_class = torch.FloatTensor(nclass).fill_(0) 106 | for i in range(nclass): 107 | idx = (test_label == i) 108 | acc_per_class[i] = torch.sum(test_label[idx]==predicted_label[idx]).float() / torch.sum(idx).float() 109 | return 
acc_per_class.mean().item() 110 | 111 | def compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package): 112 | 113 | device = in_package['device'] 114 | per_class_accuracies = torch.zeros(target_classes.size()[0]).float().to(device).detach() 115 | 116 | predicted_label = predicted_label.to(device) 117 | 118 | for i in range(target_classes.size()[0]): 119 | 120 | is_class = test_label == target_classes[i] 121 | 122 | per_class_accuracies[i] = torch.div((predicted_label[is_class]==test_label[is_class]).sum().float(),is_class.sum().float()) 123 | # pdb.set_trace() 124 | return per_class_accuracies.mean().item() 125 | 126 | def eval_zs_gzsl(dataloader,model,device,bias_seen,bias_unseen): 127 | model.eval() 128 | # print('bias_seen {} bias_unseen {}'.format(bias_seen,bias_unseen)) 129 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 130 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 131 | 132 | test_unseen_feature = dataloader.data['test_unseen']['resnet_features'] 133 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 134 | 135 | seenclasses = dataloader.seenclasses 136 | unseenclasses = dataloader.unseenclasses 137 | 138 | batch_size = 100 139 | 140 | in_package = {'model':model,'device':device, 'batch_size':batch_size} 141 | 142 | with torch.no_grad(): 143 | acc_seen = val_gzsl(test_seen_feature, test_seen_label, seenclasses, in_package,bias=bias_seen) 144 | acc_novel,acc_zs = val_zs_gzsl(test_unseen_feature, test_unseen_label, unseenclasses, in_package,bias = bias_unseen) 145 | 146 | if (acc_seen+acc_novel)>0: 147 | H = (2*acc_seen*acc_novel) / (acc_seen+acc_novel) 148 | else: 149 | H = 0 150 | 151 | return acc_seen, acc_novel, H, acc_zs 152 | 153 | def get_heatmap(dataloader,model,device): 154 | model.eval() 155 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 156 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 157 | 158 | test_unseen_feature = dataloader.data['test_unseen']['resnet_features'] 159 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 160 | 161 | seenclasses = dataloader.seenclasses 162 | unseenclasses = dataloader.unseenclasses 163 | 164 | eval_size = 100 165 | n_classes = model.nclass 166 | n_atts = model.dim_att 167 | 168 | heatmap_seen = torch.zeros((n_classes,n_atts)) 169 | heatmap_unseen = torch.zeros((n_classes,n_atts)) 170 | 171 | with torch.no_grad(): 172 | for c in seenclasses: 173 | idx_c = torch.squeeze(torch.nonzero(test_seen_label == c))[:eval_size] 174 | 175 | batch_c_samples = test_seen_feature[idx_c].to(device) 176 | out_package = model(batch_c_samples) 177 | A_p = out_package['A_p'] 178 | heatmap_seen[c] += torch.mean(A_p,dim=0).cpu() 179 | 180 | for c in unseenclasses: 181 | idx_c = torch.squeeze(torch.nonzero(test_unseen_label == c))[:eval_size] 182 | 183 | batch_c_samples = test_unseen_feature[idx_c].to(device) 184 | out_package = model(batch_c_samples) 185 | A_p = out_package['A_p'] 186 | heatmap_unseen[c] += torch.mean(A_p,dim=0).cpu() 187 | 188 | return heatmap_seen.cpu().numpy(),heatmap_unseen.cpu().numpy() 189 | 190 | def val_gzsl_k(k,test_X, test_label, target_classes,in_package,bias = 0,is_detect=False): 191 | batch_size = in_package['batch_size'] 192 | model = in_package['model'] 193 | device = in_package['device'] 194 | n_classes = in_package["num_class"] 195 | 196 | with torch.no_grad(): 197 | start = 0 198 | ntest = test_X.size()[0] 199 | test_label = F.one_hot(test_label, 
num_classes=n_classes) 200 | predicted_label = torch.LongTensor(test_label.size()).fill_(0).to(test_label.device) 201 | for i in range(0, ntest, batch_size): 202 | 203 | end = min(ntest, start+batch_size) 204 | 205 | input = test_X[start:end].to(device) 206 | 207 | out_package1, out_package2= model(input) 208 | 209 | # if type(output) == tuple: # if model return multiple output, take the first one 210 | # output = output[0] 211 | # 212 | #output = out_package1['S_pp'] 213 | output = 0.7*out_package1['S_pp']+0.3*out_package2['S_pp'] 214 | output[:,target_classes] = output[:,target_classes]+bias 215 | # predicted_label[start:end] = torch.argmax(output.data, 1) 216 | _,idx_k = torch.topk(output,k,dim=1) 217 | if is_detect: 218 | assert k == 1 219 | detection_mask=in_package["detection_mask"] 220 | predicted_label[start:end] = detection_mask[torch.argmax(output.data, 1)] 221 | else: 222 | predicted_label[start:end] = predicted_label[start:end].scatter_(1,idx_k,1) 223 | start = end 224 | 225 | acc = compute_per_class_acc_gzsl_k(test_label, predicted_label, target_classes, in_package) 226 | return acc 227 | 228 | def val_zs_gzsl_k(k,test_X, test_label, unseen_classes,in_package,bias = 0,is_detect=False): 229 | batch_size = in_package['batch_size'] 230 | model = in_package['model'] 231 | device = in_package['device'] 232 | n_classes = in_package["num_class"] 233 | with torch.no_grad(): 234 | start = 0 235 | ntest = test_X.size()[0] 236 | 237 | test_label_gzsl = F.one_hot(test_label, num_classes=n_classes) 238 | predicted_label_gzsl = torch.LongTensor(test_label_gzsl.size()).fill_(0).to(test_label.device) 239 | 240 | predicted_label_zsl = torch.LongTensor(test_label.size()) 241 | predicted_label_zsl_t = torch.LongTensor(test_label.size()) 242 | for i in range(0, ntest, batch_size): 243 | 244 | end = min(ntest, start+batch_size) 245 | 246 | input = test_X[start:end].to(device) 247 | 248 | out_package1,out_package2 = model(input) 249 | 250 | # if type(output) == tuple: # if model return multiple output, take the first one 251 | # output = output[0] 252 | # 253 | #output = out_package1['S_pp'] 254 | output = 0.5*out_package1['S_pp']+0.5*out_package2['S_pp'] 255 | output_t = output.clone() 256 | output_t[:,unseen_classes] = output_t[:,unseen_classes]+torch.max(output)+1 257 | predicted_label_zsl[start:end] = torch.argmax(output_t.data, 1) 258 | predicted_label_zsl_t[start:end] = torch.argmax(output.data[:,unseen_classes], 1) 259 | 260 | output[:,unseen_classes] = output[:,unseen_classes]+bias 261 | # predicted_label_gzsl[start:end] = torch.argmax(output.data, 1) 262 | _,idx_k = torch.topk(output,k,dim=1) 263 | if is_detect: 264 | assert k == 1 265 | detection_mask=in_package["detection_mask"] 266 | predicted_label_gzsl[start:end] = detection_mask[torch.argmax(output.data, 1)] 267 | else: 268 | predicted_label_gzsl[start:end] = predicted_label_gzsl[start:end].scatter_(1,idx_k,1) 269 | 270 | start = end 271 | 272 | acc_gzsl = compute_per_class_acc_gzsl_k(test_label_gzsl, predicted_label_gzsl, unseen_classes, in_package) 273 | #print('acc_zs: {} acc_zs_t: {}'.format(acc_zs,acc_zs_t)) 274 | return acc_gzsl,-1 275 | 276 | def compute_per_class_acc_k(test_label, predicted_label, nclass): 277 | acc_per_class = torch.FloatTensor(nclass).fill_(0) 278 | for i in range(nclass): 279 | idx = (test_label == i) 280 | acc_per_class[i] = torch.sum(test_label[idx]==predicted_label[idx]).float() / torch.sum(idx).float() 281 | return acc_per_class.mean().item() 282 | 283 | def compute_per_class_acc_gzsl_k(test_label, 
predicted_label, target_classes, in_package): 284 | device = in_package['device'] 285 | per_class_accuracies = torch.zeros(target_classes.size()[0]).float().to(device).detach() 286 | 287 | predicted_label = predicted_label.to(device) 288 | 289 | hit = test_label*predicted_label 290 | for i in range(target_classes.size()[0]): 291 | 292 | # n_pos: correct top-k hits for this class; n_gt: its ground-truth sample count 293 | target = target_classes[i] 294 | n_pos = torch.sum(hit[:,target]) 295 | n_gt = torch.sum(test_label[:,target]) 296 | per_class_accuracies[i] = torch.div(n_pos.float(),n_gt.float()) 297 | 298 | return per_class_accuracies.mean().item() 299 | 300 | def eval_zs_gzsl_k(k,dataloader,model,device,bias_seen,bias_unseen,is_detect=False): 301 | model.eval() 302 | print('bias_seen {} bias_unseen {}'.format(bias_seen,bias_unseen)) 303 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 304 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 305 | 306 | test_unseen_feature = dataloader.data['test_unseen']['resnet_features'] 307 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 308 | 309 | seenclasses = dataloader.seenclasses 310 | unseenclasses = dataloader.unseenclasses 311 | 312 | batch_size = 100 313 | n_classes = dataloader.ntrain_class+dataloader.ntest_class 314 | in_package = {'model':model,'device':device, 'batch_size':batch_size,'num_class':n_classes} 315 | 316 | if is_detect: 317 | print("Measure novelty detection k: {}".format(k)) 318 | 319 | detection_mask = torch.zeros((n_classes,n_classes)).long().to(dataloader.device) 320 | detect_label = torch.zeros(n_classes).long().to(dataloader.device) 321 | detect_label[seenclasses]=1 322 | detection_mask[seenclasses,:] = detect_label 323 | 324 | detect_label = torch.zeros(n_classes).long().to(dataloader.device) 325 | detect_label[unseenclasses]=1 326 | detection_mask[unseenclasses,:]=detect_label 327 | in_package["detection_mask"]=detection_mask 328 | 329 | with torch.no_grad(): 330 | acc_seen = val_gzsl_k(k,test_seen_feature, test_seen_label, seenclasses, in_package,bias=bias_seen,is_detect=is_detect) 331 | acc_novel,acc_zs = val_zs_gzsl_k(k,test_unseen_feature, test_unseen_label, unseenclasses, in_package,bias = bias_unseen,is_detect=is_detect) 332 | 333 | if (acc_seen+acc_novel)>0: 334 | H = (2*acc_seen*acc_novel) / (acc_seen+acc_novel) 335 | else: 336 | H = 0 337 | 338 | return acc_seen, acc_novel, H, acc_zs 339 | 340 | def compute_entropy(V): 341 | eps = 1e-7 342 | mass = torch.sum(V,dim = 1, keepdim = True) 343 | att_n = torch.div(V,mass) 344 | e = att_n * torch.log(att_n+eps) 345 | e = -1.0 * torch.sum(e,dim=1) 346 | # returns the per-sample entropy; no reduction over the batch 347 | return e 348 | 349 | def get_lr(optimizer): 350 | lr = [] 351 | for param_group in optimizer.param_groups: 352 | lr.append(param_group['lr']) 353 | return lr 354 | 355 | input_size = 224 356 | data_transforms = transforms.Compose([ 357 | transforms.Resize(input_size), 358 | transforms.CenterCrop(input_size), 359 | transforms.ToTensor() 360 | ]) 361 | 362 | def visualize_attention(img_ids,alphas_1,alphas_2,S,n_top_attr,attr_name,attr,save_path=None,is_top=True): #alphas_1: [bir] alphas_2: [bi] 363 | n = img_ids.shape[0] 364 | image_size = 14*16 # side length in pixels of the upsampled attention map (14*16 = 224) 365 | assert alphas_1.shape[1] == alphas_2.shape[1] == len(attr_name) 366 | r = alphas_1.shape[2] 367 | h = w = int(np.sqrt(r)) 368 | for i in range(n): 369 | fig=plt.figure(i,figsize=(20, 10)) 370 | file_path=img_ids[i] # may need .decode('utf-8') if ids are stored as bytes 371 | img_name = file_path.split("/")[-1] 372 |
373 | alpha_1 = alphas_1[i] #[ir] 374 | alpha_2 = alphas_2[i] #[i] 375 | score = S[i] 376 | # Plot original image 377 | image = Image.open(file_path) 378 | if image.mode == 'L': 379 | image=image.convert('RGB') 380 | image = data_transforms(image) 381 | image = image.permute(1,2,0) #[224,224,3] <== [3,224,224] 382 | ax = plt.subplot(4, 5, 1) 383 | plt.imshow(image) 384 | ax.set_title(img_name,{'fontsize': 10}) 385 | # plt.axis('off') 386 | 387 | if is_top: 388 | idxs_top=np.argsort(-alpha_2)[:n_top_attr] 389 | else: 390 | idxs_top=np.argsort(alpha_2)[:n_top_attr] 391 | 392 | 393 | for idx_ctxt,idx_attr in enumerate(idxs_top): 394 | ax=plt.subplot(4, 5, idx_ctxt+2) 395 | plt.imshow(image) 396 | alp_curr = alpha_1[idx_attr,:].reshape(h,w) # region grid, e.g. 7x7 when r=49 397 | alp_img = skimage.transform.pyramid_expand(alp_curr, upscale=image_size/h, sigma=10,multichannel=False) 398 | plt.imshow(alp_img, alpha=0.7) 399 | ax.set_title("{}\n{}\n{}-{}".format(attr_name[idx_attr],alpha_2[idx_attr],score[idx_attr],attr[idx_attr]),{'fontsize': 10}) 400 | # plt.axis('off') 401 | fig.tight_layout() 402 | if save_path is not None: 403 | plt.savefig(save_path+img_name,dpi=500) 404 | plt.close() 405 | 406 | class Logger: 407 | def __init__(self,filename,cols,is_save=True): 408 | self.df = pd.DataFrame() 409 | self.cols = cols 410 | self.filename=filename 411 | self.is_save=is_save 412 | def add(self,values): 413 | self.df=self.df.append(pd.DataFrame([values],columns=self.cols),ignore_index=True) # DataFrame.append was removed in pandas 2.0; use pd.concat on newer versions 414 | def save(self): 415 | if self.is_save: 416 | self.df.to_csv(self.filename) 417 | def get_max(self,col): 418 | return np.max(self.df[col]) 419 | 420 | def is_max(self,col): 421 | return self.df[col].iloc[-1] >= np.max(self.df[col]) 422 | 423 | def get_attr_entropy(att): # lower entropy means a more class-discriminative attribute 424 | eps = 1e-8 425 | mass=np.sum(att,axis = 0,keepdims=True) 426 | att_n = np.divide(att,mass+eps) 427 | entropy = np.sum(-att_n*np.log(att_n+eps),axis=0) 428 | assert len(entropy.shape)==1 429 | return entropy -------------------------------------------------------------------------------- /data/AWA2: -------------------------------------------------------------------------------- 1 | ../../../data/AWA2 -------------------------------------------------------------------------------- /data/AWA2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/data/AWA2.pkl -------------------------------------------------------------------------------- /data/CUB: -------------------------------------------------------------------------------- 1 | ../../../data/CUB -------------------------------------------------------------------------------- /data/CUB.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/data/CUB.pkl -------------------------------------------------------------------------------- /data/SUN: -------------------------------------------------------------------------------- 1 | ../../../data/SUN -------------------------------------------------------------------------------- /data/SUN.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/data/SUN.pkl
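A note on the metrics used throughout the evaluation helpers above: GZSL performance is reported as the harmonic mean H of the seen (S) and unseen (U) per-class accuracies, H = 2*S*U/(S+U), after calibrated stacking shifts one split's class scores by a bias term. The following is a minimal, self-contained sketch of just this bookkeeping; the tensors and numbers are illustrative assumptions, not repository data:

import torch

def harmonic_mean(acc_seen, acc_novel):
    # same formula as eval_zs_gzsl(): H = 2*S*U / (S+U), with 0 when both are 0
    if (acc_seen + acc_novel) > 0:
        return (2 * acc_seen * acc_novel) / (acc_seen + acc_novel)
    return 0.0

scores = torch.randn(4, 10)           # hypothetical class scores: 4 samples, 10 classes
seen = torch.tensor([0, 1, 2, 3, 4])  # hypothetical seen-class indices
scores[:, seen] += 0.7                # calibrated-stacking bias, as applied in val_gzsl()
print(harmonic_mean(0.6, 0.5))        # -> 0.5454...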
-------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from torchvision import transforms 4 | from torch.utils.data import Dataset, Subset, DataLoader 5 | from PIL import Image 6 | 7 | 8 | class BaseDataset(Dataset): 9 | def __init__(self, dataset_path, image_files, labels, transform=None): 10 | super(BaseDataset, self).__init__() 11 | self.dataset_path = dataset_path 12 | self.image_files = image_files 13 | self.labels = labels 14 | self.transform = transform 15 | 16 | def __len__(self): 17 | return len(self.image_files) 18 | 19 | def __getitem__(self, idx): 20 | label = self.labels[idx] 21 | image_file = self.image_files[idx] 22 | image_file = os.path.join(self.dataset_path, image_file) 23 | image = Image.open(image_file) 24 | if image.mode != 'RGB': 25 | image = image.convert('RGB') 26 | if self.transform: 27 | image = self.transform(image) 28 | return image, label 29 | 30 | 31 | class UNIDataloader(): 32 | def __init__(self, config): 33 | self.config = config 34 | with open(config.pkl_path, 'rb') as f: 35 | self.info = pickle.load(f) 36 | 37 | self.seenclasses = self.info['seenclasses'].to(config.device) 38 | self.unseenclasses = self.info['unseenclasses'].to(config.device) 39 | 40 | (self.train_set, 41 | self.test_seen_set, 42 | self.test_unseen_set) = self.torch_dataset() 43 | 44 | self.train_loader = DataLoader(self.train_set, 45 | batch_size=config.batch_size, 46 | shuffle=True, 47 | num_workers=config.num_workers) 48 | self.test_seen_loader = DataLoader(self.test_seen_set, 49 | batch_size=config.batch_size, 50 | shuffle=False, 51 | num_workers=config.num_workers) 52 | self.test_unseen_loader = DataLoader(self.test_unseen_set, 53 | batch_size=config.batch_size, 54 | shuffle=False, 55 | num_workers=config.num_workers) 56 | 57 | def torch_dataset(self): 58 | data_transforms = transforms.Compose([ 59 | transforms.Resize(self.config.img_size), 60 | transforms.CenterCrop(self.config.img_size), 61 | transforms.ToTensor(), 62 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 63 | baseset = BaseDataset(self.config.dataset_path, 64 | self.info['image_files'], 65 | self.info['labels'], 66 | data_transforms) 67 | 68 | train_set = Subset(baseset, self.info['trainval_loc']) 69 | test_seen_set = Subset(baseset, self.info['test_seen_loc']) 70 | test_unseen_set = Subset(baseset, self.info['test_unseen_loc']) 71 | 72 | return train_set, test_seen_set, test_unseen_set 73 | -------------------------------------------------------------------------------- /global_setting.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Jul 3 18:59:49 2019 5 | 6 | @author: war-machince 7 | """ 8 | 9 | NFS_path = './' -------------------------------------------------------------------------------- /images/t-v/Acadian_Flycatcher_0008_795599.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Acadian_Flycatcher_0008_795599.jpg -------------------------------------------------------------------------------- /images/t-v/American_Goldfinch_0092_32910.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/American_Goldfinch_0092_32910.jpg -------------------------------------------------------------------------------- /images/t-v/Canada_Warbler_0117_162394.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Canada_Warbler_0117_162394.jpg -------------------------------------------------------------------------------- /images/t-v/Carolina_Wren_0006_186742.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Carolina_Wren_0006_186742.jpg -------------------------------------------------------------------------------- /images/t-v/Elegant_Tern_0085_151091.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Elegant_Tern_0085_151091.jpg -------------------------------------------------------------------------------- /images/t-v/European_Goldfinch_0025_794647.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/European_Goldfinch_0025_794647.jpg -------------------------------------------------------------------------------- /images/t-v/Florida_Jay_0008_64482.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Florida_Jay_0008_64482.jpg -------------------------------------------------------------------------------- /images/t-v/Fox_Sparrow_0025_114555.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Fox_Sparrow_0025_114555.jpg -------------------------------------------------------------------------------- /images/t-v/Grasshopper_Sparrow_0053_115991.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Grasshopper_Sparrow_0053_115991.jpg -------------------------------------------------------------------------------- /images/t-v/Grasshopper_Sparrow_0107_116286.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Grasshopper_Sparrow_0107_116286.jpg -------------------------------------------------------------------------------- /images/t-v/Gray_Crowned_Rosy_Finch_0036_797287.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Gray_Crowned_Rosy_Finch_0036_797287.jpg -------------------------------------------------------------------------------- /images/t-v/Vesper_Sparrow_0090_125690.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Vesper_Sparrow_0090_125690.jpg -------------------------------------------------------------------------------- /images/t-v/Western_Gull_0058_53882.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Western_Gull_0058_53882.jpg -------------------------------------------------------------------------------- /images/t-v/White_Throated_Sparrow_0128_128956.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/White_Throated_Sparrow_0128_128956.jpg -------------------------------------------------------------------------------- /images/t-v/Winter_Wren_0118_189805.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Winter_Wren_0118_189805.jpg -------------------------------------------------------------------------------- /images/t-v/Yellow_Breasted_Chat_0044_22106.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/t-v/Yellow_Breasted_Chat_0044_22106.jpg -------------------------------------------------------------------------------- /images/tsne/awa2_tsne_test_unseen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/tsne/awa2_tsne_test_unseen.png -------------------------------------------------------------------------------- /images/tsne/awa2_tsne_train_seen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/tsne/awa2_tsne_train_seen.png -------------------------------------------------------------------------------- /images/tsne/cub_tsne_test_unseen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/tsne/cub_tsne_test_unseen.png -------------------------------------------------------------------------------- /images/tsne/cub_tsne_train_seen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/tsne/cub_tsne_train_seen.png -------------------------------------------------------------------------------- /images/tsne/sun_tsne_test_unseen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/tsne/sun_tsne_test_unseen.png -------------------------------------------------------------------------------- /images/tsne/sun_tsne_train_seen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/tsne/sun_tsne_train_seen.png 
-------------------------------------------------------------------------------- /images/v-t/1: -------------------------------------------------------------------------------- 1 | 11 2 | -------------------------------------------------------------------------------- /images/v-t/Acadian_Flycatcher_0008_795599.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/v-t/Acadian_Flycatcher_0008_795599.jpg -------------------------------------------------------------------------------- /images/v-t/American_Goldfinch_0092_32910.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/v-t/American_Goldfinch_0092_32910.jpg -------------------------------------------------------------------------------- /images/v-t/Canada_Warbler_0117_162394.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/v-t/Canada_Warbler_0117_162394.jpg -------------------------------------------------------------------------------- /images/v-t/Carolina_Wren_0006_186742.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/v-t/Carolina_Wren_0006_186742.jpg -------------------------------------------------------------------------------- /images/v-t/Elegant_Tern_0085_151091.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/v-t/Elegant_Tern_0085_151091.jpg -------------------------------------------------------------------------------- /images/v-t/European_Goldfinch_0025_794647.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/v-t/European_Goldfinch_0025_794647.jpg -------------------------------------------------------------------------------- /images/v-t/Vesper_Sparrow_0090_125690.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/v-t/Vesper_Sparrow_0090_125690.jpg -------------------------------------------------------------------------------- /images/v-t/Western_Gull_0058_53882.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/v-t/Western_Gull_0058_53882.jpg -------------------------------------------------------------------------------- /images/v-t/White_Throated_Sparrow_0128_128956.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/v-t/White_Throated_Sparrow_0128_128956.jpg -------------------------------------------------------------------------------- /images/v-t/Winter_Wren_0118_189805.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/v-t/Winter_Wren_0118_189805.jpg -------------------------------------------------------------------------------- /images/v-t/Yellow_Breasted_Chat_0044_22106.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/MSDN/ec3598bc6639732dfc46205de0e4cb4958774a1e/images/v-t/Yellow_Breasted_Chat_0044_22106.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.20.2 2 | torchvision==0.9.0 3 | torch==1.8.0 4 | Pillow==8.3.2 5 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | 5 | 6 | def get_gpu_info(): 7 | gpuinfolist = os.popen('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free').readlines() 8 | freemem = [int(gpuinfo.split()[2]) for gpuinfo in gpuinfolist] 9 | gpuidx = len(freemem) - 1 - np.argmax(list(reversed(freemem))) 10 | return f'cuda:{gpuidx}' 11 | 12 | 13 | def map_label(label, classes): 14 | mapped_label = torch.LongTensor(label.size()).fill_(-1) 15 | for i in range(classes.size(0)): 16 | mapped_label[label == classes[i]] = i 17 | return mapped_label 18 | 19 | 20 | def compute_per_class_acc(test_label, predicted_label, nclass): 21 | acc_per_class = torch.FloatTensor(nclass).fill_(0) 22 | for i in range(nclass): 23 | idx = (test_label == i) 24 | acc_per_class[i] = torch.sum( 25 | test_label[idx] == predicted_label[idx]).float() / torch.sum(idx).float() 26 | return acc_per_class.mean().item() 27 | 28 | 29 | def compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package): 30 | device = in_package['device'] 31 | per_class_accuracies = torch.zeros( 32 | target_classes.size()[0]).float().to(device).detach() 33 | predicted_label = predicted_label.to(device) 34 | for i in range(target_classes.size()[0]): 35 | is_class = test_label == target_classes[i] 36 | per_class_accuracies[i] = torch.div( 37 | (predicted_label[is_class] == test_label[is_class]).sum().float(), 38 | is_class.sum().float()) 39 | return per_class_accuracies.mean().item() 40 | 41 | 42 | def val_gzsl(test_seen_loader, target_classes, in_package, bias=0): 43 | batch_size = in_package['batch_size'] 44 | model = in_package['model'] 45 | device = in_package['device'] 46 | test_label = [] 47 | predicted_label = [] 48 | with torch.no_grad(): 49 | for batch, (imgs, labels) in enumerate(test_seen_loader): 50 | imgs, labels = imgs.to(device), labels.to(device) 51 | out_package = model(imgs) 52 | output = out_package['embed'] 53 | output[:, target_classes] = output[:, target_classes]+bias 54 | predicted_label.append(torch.argmax(output.data, 1)) 55 | test_label.append(labels) 56 | test_label = torch.cat(test_label, dim=0) 57 | predicted_label = torch.cat(predicted_label, dim=0) 58 | acc = compute_per_class_acc_gzsl( 59 | test_label, predicted_label, target_classes, in_package) 60 | return acc 61 | 62 | 63 | def val_zs_gzsl(test_unseen_loader, unseen_classes, in_package, bias=0): 64 | batch_size = in_package['batch_size'] 65 | model = in_package['model'] 66 | device = in_package['device'] 67 | test_label = [] 68 | predicted_label_gzsl = [] 69 | predicted_label_zsl = [] 70 | predicted_label_zsl_t = [] 71 | with 
torch.no_grad(): 72 | for batch, (imgs, labels) in enumerate(test_unseen_loader): 73 | imgs, labels = imgs.to(device), labels.to(device) 74 | out_package = model(imgs) 75 | output = out_package['embed'] 76 | output_t = output.clone() # boosting unseen-class scores below forces argmax into the unseen set (CZSL) 77 | output_t[:, unseen_classes] = output_t[:, 78 | unseen_classes] + torch.max(output) + 1 79 | predicted_label_zsl.append(torch.argmax(output_t.data, 1)) 80 | predicted_label_zsl_t.append( 81 | torch.argmax(output.data[:, unseen_classes], 1)) 82 | output[:, unseen_classes] = output[:, unseen_classes]+bias 83 | predicted_label_gzsl.append(torch.argmax(output.data, 1)) 84 | test_label.append(labels) 85 | test_label = torch.cat(test_label, dim=0) 86 | predicted_label_gzsl = torch.cat(predicted_label_gzsl, dim=0) 87 | predicted_label_zsl = torch.cat(predicted_label_zsl, dim=0) 88 | predicted_label_zsl_t = torch.cat(predicted_label_zsl_t, dim=0) 89 | acc_gzsl = compute_per_class_acc_gzsl( 90 | test_label, predicted_label_gzsl, unseen_classes, in_package) 91 | acc_zs = compute_per_class_acc_gzsl( 92 | test_label, predicted_label_zsl, unseen_classes, in_package) 93 | acc_zs_t = compute_per_class_acc(map_label(test_label, unseen_classes).to( 94 | device), predicted_label_zsl_t, unseen_classes.size(0)) 95 | return acc_gzsl, acc_zs_t # acc_zs above measures the same CZSL accuracy via global class indices; acc_zs_t is returned 96 | 97 | 98 | def eval_zs_gzsl(batch_size, device, zsl_task, dataloader, model, bias_seen, bias_unseen): 99 | model.eval() 100 | test_seen_loader = dataloader.test_seen_loader 101 | test_unseen_loader = dataloader.test_unseen_loader 102 | seenclasses = dataloader.seenclasses 103 | unseenclasses = dataloader.unseenclasses 104 | in_package = {'model': model, 'device': device, 'batch_size': batch_size} 105 | if zsl_task == 'CZSL': 106 | with torch.no_grad(): 107 | _, acc_zs = val_zs_gzsl( 108 | test_unseen_loader, unseenclasses, in_package, bias=bias_unseen) 109 | return acc_zs 110 | elif zsl_task == 'GZSL': 111 | with torch.no_grad(): 112 | acc_seen = val_gzsl(test_seen_loader, seenclasses, 113 | in_package, bias=bias_seen) 114 | acc_novel, _ = val_zs_gzsl( 115 | test_unseen_loader, unseenclasses, in_package, bias=bias_unseen) 116 | if (acc_seen+acc_novel) > 0: 117 | H = (2*acc_seen*acc_novel) / (acc_seen+acc_novel) # harmonic mean of seen/unseen accuracy 118 | else: 119 | H = 0 120 | return acc_seen, acc_novel, H 121 | 122 | 123 | def evaluation(batch_size, device, dataloader, model_gzsl, model_czsl, bias_seen=0, bias_unseen=0): 124 | acc_zs = eval_zs_gzsl(batch_size, device, 'CZSL', 125 | dataloader, model_czsl, bias_seen, bias_unseen) 126 | print('CZSL Results: Acc_ZSL={:.3f}'.format(acc_zs)) 127 | acc_seen, acc_novel, H = eval_zs_gzsl(batch_size, device, 'GZSL', 128 | dataloader, model_gzsl, bias_seen, bias_unseen) 129 | print('GZSL Results: Acc_Unseen={:.3f}, Acc_Seen={:.3f}, H={:.3f}'.format( 130 | acc_novel, acc_seen, H)) 131 | return 0 132 | 133 | --------------------------------------------------------------------------------
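A closing note on the CZSL trick shared by val_zs_gzsl() in utils.py and the _k helpers earlier: adding max(output)+1 to the unseen-class columns guarantees that the argmax falls on an unseen class, so conventional ZSL predictions can be read off the same GZSL score matrix. A toy sketch with illustrative values only (nothing below comes from the repository):

import torch

output = torch.tensor([[0.9, 0.1, 0.4, 0.2]])  # toy scores over 4 classes
unseen_classes = torch.tensor([2, 3])          # hypothetical unseen-class indices
output_t = output.clone()
output_t[:, unseen_classes] += output.max() + 1
print(torch.argmax(output_t, dim=1))           # tensor([2]) -- prediction forced into the unseen set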