├── README.md
├── dataset.py
├── figs
│   ├── Acadian_Flycatcher_0008_795599.jpg
│   ├── American_Goldfinch_0092_32910.jpg
│   ├── Canada_Warbler_0117_162394.jpg
│   ├── Elegant_Tern_0085_151091.jpg
│   ├── European_Goldfinch_0025_794647.jpg
│   ├── Florida_Jay_0008_64482.jpg
│   ├── Fox_Sparrow_0025_114555.jpg
│   ├── Grasshopper_Sparrow_0053_115991.jpg
│   ├── Grasshopper_Sparrow_0107_116286.jpg
│   ├── Gray_Crowned_Rosy_Finch_0036_797287.jpg
│   ├── Vesper_Sparrow_0090_125690.jpg
│   ├── Western_Gull_0058_53882.jpg
│   ├── White_Throated_Sparrow_0128_128956.jpg
│   ├── Winter_Wren_0118_189805.jpg
│   ├── Yellow_Breasted_Chat_0044_22106.jpg
│   ├── pipeline.png
│   ├── tsne_awa2_seen.png
│   ├── tsne_awa2_unseen.png
│   ├── tsne_cub_seen.png
│   ├── tsne_cub_unseen.png
│   ├── tsne_sun_seen.png
│   ├── tsne_sun_unseen.png
│   └── tzpp.png
├── helper_func.py
├── model_tzpp.py
├── preprocessing.py
├── requirements.txt
├── train_awa2.py
├── train_cub.py
├── train_sun.py
├── w2v
│   ├── AWA2_attribute.pkl
│   ├── CUB_attribute.pkl
│   └── SUN_attribute.pkl
└── wandb_config
    ├── awa2_czsl.yaml
    ├── awa2_gzsl.yaml
    ├── cub_czsl.yaml
    ├── cub_gzsl.yaml
    ├── sun_czsl.yaml
    └── sun_gzsl.yaml

/README.md:
--------------------------------------------------------------------------------
 1 | # TransZero++
 2 | 
 3 | This repository contains the training and testing code for the paper "***TransZero++: Cross Attribute-guided Transformer for Zero-Shot Learning***", accepted to TPAMI.
 4 | 
 5 | ![](figs/tzpp.png)
 6 | 
 7 | ## Running Environment
 8 | The implementation of **TransZero++** is mainly based on Python 3.8.8 and [PyTorch](https://pytorch.org/) 1.8.0. To install all required dependencies:
 9 | ```
10 | $ pip install -r requirements.txt
11 | ```
12 | 
13 | We use [Weights & Biases](https://wandb.ai/site) (W&B) to keep track of and organize our experimental results. You may need to follow the W&B [quickstart documentation](https://docs.wandb.ai/quickstart) to get started.
14 | 
15 | 
16 | ## Download Dataset
17 | 
18 | We trained the model on three popular ZSL benchmarks: [CUB](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), [SUN](http://cs.brown.edu/~gmpatter/sunattributes.html) and [AWA2](http://cvml.ist.ac.at/AwA2/), following the data split of [xlsa17](http://datasets.d2.mpi-inf.mpg.de/xian/xlsa17.zip). To train **TransZero++**, first download these datasets as well as xlsa17, then decompress and organize them as follows:
19 | ```
20 | .
21 | ├── data
22 | │   ├── CUB/CUB_200_2011/...
23 | │   ├── SUN/images/...
24 | │   ├── AWA2/Animals_with_Attributes2/...
25 | │   └── xlsa17/data/...
26 | └── ···
27 | ```
28 | 
29 | 
30 | ## Visual Features Preprocessing
31 | 
32 | In this step, run the following commands to extract the visual features of the three datasets:
33 | 
34 | ```
35 | $ python preprocessing.py --dataset CUB --compression --device cuda:0
36 | $ python preprocessing.py --dataset SUN --compression --device cuda:0
37 | $ python preprocessing.py --dataset AWA2 --compression --device cuda:0
38 | ```
39 | 
40 | ## Training TransZero++ from Scratch
41 | In `./wandb_config`, we provide our parameter settings for the conventional ZSL (CZSL) and generalized ZSL (GZSL) tasks on CUB, SUN, and AWA2. Run the following commands to train **TransZero++** from scratch:
42 | 
43 | ```
44 | $ python train_cub.py # CUB
45 | $ python train_sun.py # SUN
46 | $ python train_awa2.py # AWA2
47 | ```
48 | **Note**: Please load the corresponding CZSL configuration when targeting the CZSL task; a minimal config-loading sketch is shown below.
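
For reference, here is a minimal sketch of how one of these configurations could be attached to a W&B run (the YAML file names come from `./wandb_config`; the project name `TransZeroPP` and the exact YAML keys are assumptions, and the training scripts may consume the config differently):

```
# Minimal sketch, not taken from the repo's training scripts: read one of the
# provided YAML files and attach it to a W&B run. The project name
# 'TransZeroPP' is an assumption for illustration.
import yaml
import wandb

with open('wandb_config/cub_czsl.yaml') as f:
    cfg = yaml.safe_load(f)

run = wandb.init(project='TransZeroPP', config=cfg)
print(dict(wandb.config))  # hyperparameters are now available as wandb.config.<key>
run.finish()
```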
49 | 
50 | 
51 | 
52 | ## Results
53 | 
54 | We also provide trained models ([Google Drive](https://drive.google.com/drive/folders/1rNHCglaSD_Q5se1rs5qIh6QNtMDCZokc?usp=sharing)) on CUB/SUN/AWA2. You can download these `.pth` files and validate the results in our paper; please refer to [here](https://github.com/shiming-chen/TransZero_pp/tree/f8251f2991c31775d6eb0367986321a681429713) for the testing code and usage (a minimal evaluation sketch is also given at the end of this README).
55 | The following table shows the results of our released models under various evaluation protocols on the three datasets, in both the CZSL and GZSL settings:
56 | 
57 | | Dataset | Acc(CZSL) | U(GZSL) | S(GZSL) | H(GZSL) |
58 | | :-----: | :-----: | :-----: | :-----: | :-----: |
59 | | CUB | 78.3 | 67.5 | 73.6 | 70.4 |
60 | | SUN | 67.6 | 48.6 | 37.8 | 42.5 |
61 | | AWA2 | 72.6 | 64.6 | 82.7 | 72.5 |
62 | 
63 | **Note**: Our models were trained, and all of the above results obtained, on a server with an AMD Ryzen 7 5800X CPU, 128GB of memory, and an NVIDIA RTX A6000 GPU (48GB).
64 | 
65 | ## Citation
66 | If this work is helpful for you, please cite our paper:
67 | 
68 | ```
69 | @article{Chen2022TransZeropp,
70 |   author  = {Chen, Shiming and Hong, Ziming and Hou, Wenjin and Xie, Guo-Sen and Song, Yibing and Zhao, Jian and You, Xinge and Yan, Shuicheng and Shao, Ling},
71 |   title   = {TransZero++: Cross Attribute-guided Transformer for Zero-Shot Learning},
72 |   journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
73 |   year    = {2022}
74 | }
75 | ```
76 | 
77 | ## References
78 | Parts of our code are based on:
79 | * [hbdat/cvpr20_DAZLE](https://github.com/hbdat/cvpr20_DAZLE)
80 | * [zhangxuying1004/RSTNet](https://github.com/zhangxuying1004/RSTNet)
81 | * [shiming-chen/TransZero](https://github.com/shiming-chen/TransZero)
82 | 
83 | ## Contact
84 | If you have any questions about the code, please don't hesitate to contact us at gchenshiming@gmail.com or hoongzm@gmail.com.
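
For convenience, a minimal evaluation sketch using the helpers in this repository follows. `CUBDataLoader`, `TransZeroPP`, and `eval_zs_gzsl` come from `dataset.py`, `model_tzpp.py`, and `helper_func.py`; the checkpoint file name and the assumption that the YAML is a flat key-value mapping carrying the fields `model_tzpp.py` expects (`dim_f`, `dim_v`, `num_class`, the `tf_*` settings, etc.) are ours, not the repository's:

```
# Minimal sketch for validating a downloaded checkpoint; 'CUB_GZSL.pth' is a
# hypothetical file name, and the config handling below is an assumption.
import torch
import yaml
from types import SimpleNamespace

from dataset import CUBDataLoader
from helper_func import eval_zs_gzsl
from model_tzpp import TransZeroPP

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Assumes the YAML is a flat key: value mapping; adjust if it is nested.
with open('wandb_config/cub_gzsl.yaml') as f:
    config = SimpleNamespace(**yaml.safe_load(f))

# Expects the HDF5 features produced by preprocessing.py under ./data/CUB/.
dataloader = CUBDataLoader('.', device)
model = TransZeroPP(config, dataloader.att, dataloader.w2v_att,
                    dataloader.seenclasses, dataloader.unseenclasses).to(device)
model.load_state_dict(torch.load('CUB_GZSL.pth', map_location=device))

acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(dataloader, model, device)
print('CZSL Acc: {:.1%} | GZSL U: {:.1%} S: {:.1%} H: {:.1%}'.format(
    acc_zs, acc_novel, acc_seen, H))
```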
85 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os, sys, torch, pickle, h5py 3 | import numpy as np 4 | import scipy.io as sio 5 | import pandas as pd 6 | from PIL import Image 7 | from sklearn import preprocessing 8 | from torchvision import transforms 9 | from torch.utils.data import Dataset, Subset, DataLoader 10 | 11 | 12 | class BaseDataset(Dataset): 13 | def __init__(self, dataset_path, image_files, labels, transform=None): 14 | super(BaseDataset, self).__init__() 15 | self.dataset_path = dataset_path 16 | self.image_files = image_files 17 | self.labels = labels 18 | self.transform = transform 19 | 20 | def __len__(self): 21 | return len(self.image_files) 22 | 23 | def __getitem__(self, idx): 24 | label = self.labels[idx] 25 | image_file = self.image_files[idx] 26 | image_file = os.path.join(self.dataset_path, image_file) 27 | image = Image.open(image_file) 28 | if image.mode != 'RGB': 29 | image = image.convert('RGB') 30 | if self.transform: 31 | image = self.transform(image) 32 | return image, label 33 | 34 | 35 | class UNIDataloader(): 36 | def __init__(self, config): 37 | self.config = config 38 | with open(config.pkl_path, 'rb') as f: 39 | self.info = pickle.load(f) 40 | 41 | self.seenclasses = self.info['seenclasses'].to(config.device) 42 | self.unseenclasses = self.info['unseenclasses'].to(config.device) 43 | 44 | (self.train_set, 45 | self.test_seen_set, 46 | self.test_unseen_set) = self.torch_dataset() 47 | 48 | self.train_loader = DataLoader(self.train_set, 49 | batch_size=config.batch_size, 50 | shuffle=True, 51 | num_workers=config.num_workers) 52 | self.test_seen_loader = DataLoader(self.test_seen_set, 53 | batch_size=config.batch_size, 54 | shuffle=False, 55 | num_workers=config.num_workers) 56 | self.test_unseen_loader = DataLoader(self.test_unseen_set, 57 | batch_size=config.batch_size, 58 | shuffle=False, 59 | num_workers=config.num_workers) 60 | 61 | def torch_dataset(self): 62 | data_transforms = transforms.Compose([ 63 | transforms.Resize(self.config.img_size), 64 | transforms.CenterCrop(self.config.img_size), 65 | transforms.ToTensor(), 66 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 67 | baseset = BaseDataset(self.config.dataset_path, 68 | self.info['image_files'], 69 | self.info['labels'], 70 | data_transforms) 71 | 72 | train_set = Subset(baseset, self.info['trainval_loc']) 73 | test_seen_set = Subset(baseset, self.info['test_seen_loc']) 74 | test_unseen_set = Subset(baseset, self.info['test_unseen_loc']) 75 | 76 | return train_set, test_seen_set, test_unseen_set 77 | 78 | 79 | class CUBDataLoader(): 80 | def __init__(self, data_path, device, is_scale=False, 81 | is_unsupervised_attr=False, is_balance=True): 82 | print(data_path) 83 | sys.path.append(data_path) 84 | self.data_path = data_path 85 | self.device = device 86 | self.dataset = 'CUB' 87 | # print('$'*30) 88 | # print(self.dataset) 89 | # print('$'*30) 90 | self.datadir = os.path.join(self.data_path, 'data/{}/'.format(self.dataset)) 91 | self.index_in_epoch = 0 92 | self.epochs_completed = 0 93 | self.is_scale = is_scale 94 | self.is_balance = is_balance 95 | if self.is_balance: 96 | print('Balance dataloader') 97 | self.is_unsupervised_attr = is_unsupervised_attr 98 | self.read_matdataset() 99 | self.get_idx_classes() 100 | 101 | def next_batch(self, batch_size): 102 | if self.is_balance: 103 | idx = [] 104 | n_samples_class = 
max(batch_size //self.ntrain_class,1) 105 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 106 | for i_c in sampled_idx_c: 107 | idxs = self.idxs_list[i_c] 108 | idx.append(np.random.choice(idxs,n_samples_class)) 109 | idx = np.concatenate(idx) 110 | idx = torch.from_numpy(idx) 111 | else: 112 | idx = torch.randperm(self.ntrain)[0:batch_size] 113 | 114 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 115 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 116 | batch_att = self.att[batch_label].to(self.device) 117 | return batch_label, batch_feature, batch_att 118 | 119 | def get_idx_classes(self): 120 | n_classes = self.seenclasses.size(0) 121 | self.idxs_list = [] 122 | train_label = self.data['train_seen']['labels'] 123 | for i in range(n_classes): 124 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 125 | idx_c = np.squeeze(idx_c) 126 | self.idxs_list.append(idx_c) 127 | return self.idxs_list 128 | 129 | def read_matdataset(self): 130 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 131 | print('_____') 132 | print(path) 133 | # tic = time.time() 134 | hf = h5py.File(path, 'r') 135 | features = np.array(hf.get('feature_map')) 136 | # shape = features.shape 137 | # features = features.reshape(shape[0],shape[1],shape[2]*shape[3]) 138 | # pdb.set_trace() 139 | labels = np.array(hf.get('labels')) 140 | trainval_loc = np.array(hf.get('trainval_loc')) 141 | # train_loc = np.array(hf.get('train_loc')) #--> train_feature = TRAIN SEEN 142 | # val_unseen_loc = np.array(hf.get('val_unseen_loc')) #--> test_unseen_feature = TEST UNSEEN 143 | test_seen_loc = np.array(hf.get('test_seen_loc')) 144 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 145 | 146 | if self.is_unsupervised_attr: 147 | print('Unsupervised Attr') 148 | class_path = './w2v/{}_class.pkl'.format(self.dataset) 149 | with open(class_path,'rb') as f: 150 | w2v_class = pickle.load(f) 151 | temp = np.array(hf.get('att')) 152 | print(w2v_class.shape,temp.shape) 153 | # assert w2v_class.shape == temp.shape 154 | w2v_class = torch.tensor(w2v_class).float() 155 | 156 | U, s, V = torch.svd(w2v_class) 157 | reconstruct = torch.mm(torch.mm(U,torch.diag(s)),torch.transpose(V,1,0)) 158 | print('sanity check: {}'.format(torch.norm(reconstruct-w2v_class).item())) 159 | 160 | print('shape U:{} V:{}'.format(U.size(),V.size())) 161 | print('s: {}'.format(s)) 162 | 163 | self.w2v_att = torch.transpose(V,1,0).to(self.device) 164 | self.att = torch.mm(U,torch.diag(s)).to(self.device) 165 | self.normalize_att = torch.mm(U,torch.diag(s)).to(self.device) 166 | 167 | else: 168 | print('Expert Attr') 169 | att = np.array(hf.get('att')) 170 | self.att = torch.from_numpy(att).float().to(self.device) 171 | 172 | original_att = np.array(hf.get('original_att')) 173 | self.original_att = torch.from_numpy(original_att).float().to(self.device) 174 | 175 | w2v_att = np.array(hf.get('w2v_att')) 176 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 177 | 178 | self.normalize_att = self.original_att/100 179 | 180 | train_feature = features[trainval_loc] 181 | test_seen_feature = features[test_seen_loc] 182 | test_unseen_feature = features[test_unseen_loc] 183 | if self.is_scale: 184 | scaler = preprocessing.MinMaxScaler() 185 | 186 | train_feature = scaler.fit_transform(train_feature) 187 | test_seen_feature = scaler.fit_transform(test_seen_feature) 188 | 
test_unseen_feature = scaler.fit_transform(test_unseen_feature) 189 | 190 | train_feature = torch.from_numpy(train_feature).float() #.to(self.device) 191 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 192 | test_unseen_feature = torch.from_numpy(test_unseen_feature) #.float().to(self.device) 193 | 194 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 195 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 196 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) #.long().to(self.device) 197 | 198 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 199 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 200 | self.ntrain = train_feature.size()[0] 201 | self.ntrain_class = self.seenclasses.size(0) 202 | self.ntest_class = self.unseenclasses.size(0) 203 | self.train_class = self.seenclasses.clone() 204 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 205 | 206 | self.data = {} 207 | self.data['train_seen'] = {} 208 | self.data['train_seen']['resnet_features'] = train_feature 209 | self.data['train_seen']['labels']= train_label 210 | 211 | self.data['train_unseen'] = {} 212 | self.data['train_unseen']['resnet_features'] = None 213 | self.data['train_unseen']['labels'] = None 214 | 215 | self.data['test_seen'] = {} 216 | self.data['test_seen']['resnet_features'] = test_seen_feature 217 | self.data['test_seen']['labels'] = test_seen_label 218 | 219 | self.data['test_unseen'] = {} 220 | self.data['test_unseen']['resnet_features'] = test_unseen_feature 221 | self.data['test_unseen']['labels'] = test_unseen_label 222 | 223 | 224 | class SUNDataLoader(): 225 | def __init__(self, data_path, device, is_scale=False, 226 | is_unsupervised_attr=False, is_balance=True): 227 | print(data_path) 228 | sys.path.append(data_path) 229 | self.data_path = data_path 230 | self.device = device 231 | self.dataset = 'SUN' 232 | print('$'*30) 233 | print(self.dataset) 234 | print('$'*30) 235 | self.datadir = os.path.join(self.data_path, 'data/{}/'.format(self.dataset)) 236 | self.index_in_epoch = 0 237 | self.epochs_completed = 0 238 | self.is_scale = is_scale 239 | self.is_balance = is_balance 240 | if self.is_balance: 241 | print('Balance dataloader') 242 | self.is_unsupervised_attr = is_unsupervised_attr 243 | self.read_matdataset() 244 | self.get_idx_classes() 245 | self.I = torch.eye(self.allclasses.size(0)).to(device) 246 | 247 | def next_batch(self, batch_size): 248 | if self.is_balance: 249 | idx = [] 250 | n_samples_class = max(batch_size //self.ntrain_class,1) 251 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 252 | for i_c in sampled_idx_c: 253 | idxs = self.idxs_list[i_c] 254 | idx.append(np.random.choice(idxs,n_samples_class)) 255 | idx = np.concatenate(idx) 256 | idx = torch.from_numpy(idx) 257 | else: 258 | idx = torch.randperm(self.ntrain)[0:batch_size] 259 | 260 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 261 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 262 | batch_att = self.att[batch_label].to(self.device) 263 | return batch_label, batch_feature, batch_att 264 | 265 | def get_idx_classes(self): 266 | n_classes = self.seenclasses.size(0) 267 | self.idxs_list = [] 268 | train_label = self.data['train_seen']['labels'] 269 | for i in 
range(n_classes): 270 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 271 | idx_c = np.squeeze(idx_c) 272 | self.idxs_list.append(idx_c) 273 | return self.idxs_list 274 | 275 | def read_matdataset(self): 276 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 277 | 278 | print('_____') 279 | print(path) 280 | # tic = time.time() 281 | hf = h5py.File(path, 'r') 282 | features = np.array(hf.get('feature_map')) 283 | labels = np.array(hf.get('labels')) 284 | trainval_loc = np.array(hf.get('trainval_loc')) 285 | test_seen_loc = np.array(hf.get('test_seen_loc')) 286 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 287 | 288 | if self.is_unsupervised_attr: 289 | print('Unsupervised Attr') 290 | class_path = './w2v/{}_class.pkl'.format(self.dataset) 291 | with open(class_path,'rb') as f: 292 | w2v_class = pickle.load(f) 293 | assert w2v_class.shape == (50,300) 294 | w2v_class = torch.tensor(w2v_class).float() 295 | 296 | U, s, V = torch.svd(w2v_class) 297 | reconstruct = torch.mm(torch.mm(U,torch.diag(s)),torch.transpose(V,1,0)) 298 | print('sanity check: {}'.format(torch.norm(reconstruct-w2v_class).item())) 299 | 300 | print('shape U:{} V:{}'.format(U.size(),V.size())) 301 | print('s: {}'.format(s)) 302 | 303 | self.w2v_att = torch.transpose(V,1,0).to(self.device) 304 | self.att = torch.mm(U,torch.diag(s)).to(self.device) 305 | self.normalize_att = torch.mm(U,torch.diag(s)).to(self.device) 306 | 307 | else: 308 | print('Expert Attr') 309 | att = np.array(hf.get('att')) 310 | self.att = torch.from_numpy(att).float().to(self.device) 311 | 312 | original_att = np.array(hf.get('original_att')) 313 | self.original_att = torch.from_numpy(original_att).float().to(self.device) 314 | 315 | w2v_att = np.array(hf.get('w2v_att')) 316 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 317 | 318 | self.normalize_att = self.original_att/100 319 | 320 | train_feature = features[trainval_loc] 321 | test_seen_feature = features[test_seen_loc] 322 | test_unseen_feature = features[test_unseen_loc] 323 | if self.is_scale: 324 | scaler = preprocessing.MinMaxScaler() 325 | 326 | train_feature = scaler.fit_transform(train_feature) 327 | test_seen_feature = scaler.fit_transform(test_seen_feature) 328 | test_unseen_feature = scaler.fit_transform(test_unseen_feature) 329 | 330 | train_feature = torch.from_numpy(train_feature).float() #.to(self.device) 331 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 332 | test_unseen_feature = torch.from_numpy(test_unseen_feature) #.float().to(self.device) 333 | 334 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 335 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 336 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) #.long().to(self.device) 337 | 338 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 339 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 340 | self.ntrain = train_feature.size()[0] 341 | self.ntrain_class = self.seenclasses.size(0) 342 | self.ntest_class = self.unseenclasses.size(0) 343 | self.train_class = self.seenclasses.clone() 344 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 345 | 346 | self.data = {} 347 | self.data['train_seen'] = {} 348 | self.data['train_seen']['resnet_features'] = train_feature 349 | self.data['train_seen']['labels']= 
train_label 350 | 351 | self.data['train_unseen'] = {} 352 | self.data['train_unseen']['resnet_features'] = None 353 | self.data['train_unseen']['labels'] = None 354 | 355 | self.data['test_seen'] = {} 356 | self.data['test_seen']['resnet_features'] = test_seen_feature 357 | self.data['test_seen']['labels'] = test_seen_label 358 | 359 | self.data['test_unseen'] = {} 360 | self.data['test_unseen']['resnet_features'] = test_unseen_feature 361 | self.data['test_unseen']['labels'] = test_unseen_label 362 | 363 | 364 | class AWA2DataLoader(): 365 | def __init__(self, data_path, device, is_scale=False, 366 | is_unsupervised_attr=False, is_balance=True): 367 | print(data_path) 368 | sys.path.append(data_path) 369 | self.data_path = data_path 370 | self.device = device 371 | self.dataset = 'AWA2' 372 | print('$'*30) 373 | print(self.dataset) 374 | print('$'*30) 375 | self.datadir = os.path.join(self.data_path, 'data/{}/'.format(self.dataset)) 376 | self.index_in_epoch = 0 377 | self.epochs_completed = 0 378 | self.is_scale = is_scale 379 | self.is_balance = is_balance 380 | if self.is_balance: 381 | print('Balance dataloader') 382 | self.is_unsupervised_attr = is_unsupervised_attr 383 | self.read_matdataset() 384 | self.get_idx_classes() 385 | 386 | def next_batch(self, batch_size): 387 | if self.is_balance: 388 | idx = [] 389 | n_samples_class = max(batch_size //self.ntrain_class,1) 390 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 391 | for i_c in sampled_idx_c: 392 | idxs = self.idxs_list[i_c] 393 | idx.append(np.random.choice(idxs,n_samples_class)) 394 | idx = np.concatenate(idx) 395 | idx = torch.from_numpy(idx) 396 | else: 397 | idx = torch.randperm(self.ntrain)[0:batch_size] 398 | 399 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 400 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 401 | batch_att = self.att[batch_label].to(self.device) 402 | return batch_label, batch_feature, batch_att 403 | 404 | def get_idx_classes(self): 405 | n_classes = self.seenclasses.size(0) 406 | self.idxs_list = [] 407 | train_label = self.data['train_seen']['labels'] 408 | for i in range(n_classes): 409 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 410 | idx_c = np.squeeze(idx_c) 411 | self.idxs_list.append(idx_c) 412 | return self.idxs_list 413 | 414 | def read_matdataset(self): 415 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 416 | print('_____') 417 | print(path) 418 | # tic = time.clock() 419 | hf = h5py.File(path, 'r') 420 | features = np.array(hf.get('feature_map')) 421 | labels = np.array(hf.get('labels')) 422 | trainval_loc = np.array(hf.get('trainval_loc')) 423 | test_seen_loc = np.array(hf.get('test_seen_loc')) 424 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 425 | 426 | if self.is_unsupervised_attr: 427 | print('Unsupervised Attr') 428 | class_path = './w2v/{}_class.pkl'.format(self.dataset) 429 | with open(class_path,'rb') as f: 430 | w2v_class = pickle.load(f) 431 | assert w2v_class.shape == (50,300) 432 | w2v_class = torch.tensor(w2v_class).float() 433 | 434 | U, s, V = torch.svd(w2v_class) 435 | reconstruct = torch.mm(torch.mm(U,torch.diag(s)),torch.transpose(V,1,0)) 436 | print('sanity check: {}'.format(torch.norm(reconstruct-w2v_class).item())) 437 | 438 | print('shape U:{} V:{}'.format(U.size(),V.size())) 439 | print('s: {}'.format(s)) 440 | 441 | self.w2v_att = 
torch.transpose(V,1,0).to(self.device) 442 | self.att = torch.mm(U,torch.diag(s)).to(self.device) 443 | self.normalize_att = torch.mm(U,torch.diag(s)).to(self.device) 444 | else: 445 | print('Expert Attr') 446 | att = np.array(hf.get('att')) 447 | 448 | print("threshold at zero attribute with negative value") 449 | att[att<0]=0 450 | 451 | self.att = torch.from_numpy(att).float().to(self.device) 452 | 453 | original_att = np.array(hf.get('original_att')) 454 | self.original_att = torch.from_numpy(original_att).float().to(self.device) 455 | 456 | w2v_att = np.array(hf.get('w2v_att')) 457 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 458 | 459 | self.normalize_att = self.original_att/100 460 | 461 | train_feature = features[trainval_loc] 462 | test_seen_feature = features[test_seen_loc] 463 | test_unseen_feature = features[test_unseen_loc] 464 | if self.is_scale: 465 | scaler = preprocessing.MinMaxScaler() 466 | 467 | train_feature = scaler.fit_transform(train_feature) 468 | test_seen_feature = scaler.fit_transform(test_seen_feature) 469 | test_unseen_feature = scaler.fit_transform(test_unseen_feature) 470 | 471 | train_feature = torch.from_numpy(train_feature).float() #.to(self.device) 472 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 473 | test_unseen_feature = torch.from_numpy(test_unseen_feature) #.float().to(self.device) 474 | 475 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 476 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 477 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) #.long().to(self.device) 478 | 479 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 480 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 481 | self.ntrain = train_feature.size()[0] 482 | self.ntrain_class = self.seenclasses.size(0) 483 | self.ntest_class = self.unseenclasses.size(0) 484 | self.train_class = self.seenclasses.clone() 485 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 486 | 487 | self.data = {} 488 | self.data['train_seen'] = {} 489 | self.data['train_seen']['resnet_features'] = train_feature 490 | self.data['train_seen']['labels']= train_label 491 | 492 | self.data['train_unseen'] = {} 493 | self.data['train_unseen']['resnet_features'] = None 494 | self.data['train_unseen']['labels'] = None 495 | 496 | self.data['test_seen'] = {} 497 | self.data['test_seen']['resnet_features'] = test_seen_feature 498 | self.data['test_seen']['labels'] = test_seen_label 499 | 500 | self.data['test_unseen'] = {} 501 | self.data['test_unseen']['resnet_features'] = test_unseen_feature 502 | self.data['test_unseen']['labels'] = test_unseen_label -------------------------------------------------------------------------------- /figs/Acadian_Flycatcher_0008_795599.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/Acadian_Flycatcher_0008_795599.jpg -------------------------------------------------------------------------------- /figs/American_Goldfinch_0092_32910.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/American_Goldfinch_0092_32910.jpg 
-------------------------------------------------------------------------------- /figs/Canada_Warbler_0117_162394.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/Canada_Warbler_0117_162394.jpg -------------------------------------------------------------------------------- /figs/Elegant_Tern_0085_151091.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/Elegant_Tern_0085_151091.jpg -------------------------------------------------------------------------------- /figs/European_Goldfinch_0025_794647.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/European_Goldfinch_0025_794647.jpg -------------------------------------------------------------------------------- /figs/Florida_Jay_0008_64482.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/Florida_Jay_0008_64482.jpg -------------------------------------------------------------------------------- /figs/Fox_Sparrow_0025_114555.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/Fox_Sparrow_0025_114555.jpg -------------------------------------------------------------------------------- /figs/Grasshopper_Sparrow_0053_115991.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/Grasshopper_Sparrow_0053_115991.jpg -------------------------------------------------------------------------------- /figs/Grasshopper_Sparrow_0107_116286.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/Grasshopper_Sparrow_0107_116286.jpg -------------------------------------------------------------------------------- /figs/Gray_Crowned_Rosy_Finch_0036_797287.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/Gray_Crowned_Rosy_Finch_0036_797287.jpg -------------------------------------------------------------------------------- /figs/Vesper_Sparrow_0090_125690.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/Vesper_Sparrow_0090_125690.jpg -------------------------------------------------------------------------------- /figs/Western_Gull_0058_53882.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/Western_Gull_0058_53882.jpg -------------------------------------------------------------------------------- /figs/White_Throated_Sparrow_0128_128956.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/White_Throated_Sparrow_0128_128956.jpg -------------------------------------------------------------------------------- /figs/Winter_Wren_0118_189805.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/Winter_Wren_0118_189805.jpg -------------------------------------------------------------------------------- /figs/Yellow_Breasted_Chat_0044_22106.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/Yellow_Breasted_Chat_0044_22106.jpg -------------------------------------------------------------------------------- /figs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/pipeline.png -------------------------------------------------------------------------------- /figs/tsne_awa2_seen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/tsne_awa2_seen.png -------------------------------------------------------------------------------- /figs/tsne_awa2_unseen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/tsne_awa2_unseen.png -------------------------------------------------------------------------------- /figs/tsne_cub_seen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/tsne_cub_seen.png -------------------------------------------------------------------------------- /figs/tsne_cub_unseen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/tsne_cub_unseen.png -------------------------------------------------------------------------------- /figs/tsne_sun_seen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/tsne_sun_seen.png -------------------------------------------------------------------------------- /figs/tsne_sun_unseen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/tsne_sun_unseen.png -------------------------------------------------------------------------------- /figs/tzpp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/figs/tzpp.png -------------------------------------------------------------------------------- /helper_func.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def val_gzsl(test_X, test_label, target_classes,in_package,bias = 0): 6 | 7 | batch_size = in_package['batch_size'] 8 | model = in_package['model'] 9 | device = in_package['device'] 10 | with torch.no_grad(): 11 | start = 0 12 | ntest = test_X.size()[0] 13 | predicted_label = torch.LongTensor(test_label.size()) 14 | for i in range(0, ntest, batch_size): 15 | 16 | end = min(ntest, start+batch_size) 17 | 18 | input = test_X[start:end].to(device) 19 | 20 | out_package = model(input) 21 | 22 | output = out_package['S_pp'] 23 | output[:,target_classes] = output[:,target_classes]+bias 24 | predicted_label[start:end] = torch.argmax(output.data, 1) 25 | 26 | start = end 27 | 28 | acc = compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package) 29 | return acc 30 | 31 | 32 | def map_label(label, classes): 33 | mapped_label = torch.LongTensor(label.size()).fill_(-1) 34 | for i in range(classes.size(0)): 35 | mapped_label[label==classes[i]] = i 36 | 37 | return mapped_label 38 | 39 | 40 | def val_zs_gzsl(test_X, test_label, unseen_classes,in_package,bias = 0): 41 | batch_size = in_package['batch_size'] 42 | model = in_package['model'] 43 | device = in_package['device'] 44 | with torch.no_grad(): 45 | start = 0 46 | ntest = test_X.size()[0] 47 | predicted_label_gzsl = torch.LongTensor(test_label.size()) 48 | predicted_label_zsl = torch.LongTensor(test_label.size()) 49 | predicted_label_zsl_t = torch.LongTensor(test_label.size()) 50 | for i in range(0, ntest, batch_size): 51 | 52 | end = min(ntest, start+batch_size) 53 | 54 | input = test_X[start:end].to(device) 55 | 56 | out_package = model(input) 57 | output = out_package['S_pp'] 58 | 59 | output_t = output.clone() 60 | output_t[:,unseen_classes] = output_t[:,unseen_classes]+torch.max(output)+1 61 | predicted_label_zsl[start:end] = torch.argmax(output_t.data, 1) 62 | predicted_label_zsl_t[start:end] = torch.argmax(output.data[:,unseen_classes], 1) 63 | 64 | output[:,unseen_classes] = output[:,unseen_classes]+bias 65 | predicted_label_gzsl[start:end] = torch.argmax(output.data, 1) 66 | 67 | 68 | start = end 69 | acc_gzsl = compute_per_class_acc_gzsl(test_label, predicted_label_gzsl, unseen_classes, in_package) 70 | acc_zs = compute_per_class_acc_gzsl(test_label, predicted_label_zsl, unseen_classes, in_package) 71 | acc_zs_t = compute_per_class_acc(map_label(test_label, unseen_classes), predicted_label_zsl_t, unseen_classes.size(0)) 72 | 73 | return acc_gzsl,acc_zs_t 74 | 75 | 76 | def compute_per_class_acc(test_label, predicted_label, nclass): 77 | acc_per_class = torch.FloatTensor(nclass).fill_(0) 78 | for i in range(nclass): 79 | idx = (test_label == i) 80 | acc_per_class[i] = torch.sum(test_label[idx]==predicted_label[idx]).float() / torch.sum(idx).float() 81 | return acc_per_class.mean().item() 82 | 83 | 84 | def compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package): 85 | 86 | device = in_package['device'] 87 | per_class_accuracies = torch.zeros(target_classes.size()[0]).float().to(device).detach() 88 | 89 | predicted_label = predicted_label.to(device) 90 | 91 | for i in range(target_classes.size()[0]): 92 | 93 | is_class = test_label == target_classes[i] 94 | 95 | per_class_accuracies[i] = torch.div((predicted_label[is_class]==test_label[is_class]).sum().float(),is_class.sum().float()) 96 | return per_class_accuracies.mean().item() 97 | 98 | 99 | 
def eval_zs_gzsl(dataloader,model,device,bias_seen=0, bias_unseen=0, batch_size=50): 100 | model.eval() 101 | # print('bias_seen {} bias_unseen {}'.format(bias_seen,bias_unseen)) 102 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 103 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 104 | 105 | test_unseen_feature = dataloader.data['test_unseen']['resnet_features'] 106 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 107 | 108 | seenclasses = dataloader.seenclasses 109 | unseenclasses = dataloader.unseenclasses 110 | 111 | batch_size = batch_size 112 | 113 | in_package = {'model':model,'device':device, 'batch_size':batch_size} 114 | 115 | with torch.no_grad(): 116 | acc_seen = val_gzsl(test_seen_feature, test_seen_label, seenclasses, in_package,bias=bias_seen) 117 | acc_novel,acc_zs = val_zs_gzsl(test_unseen_feature, test_unseen_label, unseenclasses, in_package,bias = bias_unseen) 118 | 119 | if (acc_seen+acc_novel)>0: 120 | H = (2*acc_seen*acc_novel) / (acc_seen+acc_novel) 121 | else: 122 | H = 0 123 | 124 | return acc_seen, acc_novel, H, acc_zs 125 | 126 | 127 | def val_gzsl_k(k,test_X, test_label, target_classes,in_package,bias = 0,is_detect=False): 128 | batch_size = in_package['batch_size'] 129 | model = in_package['model'] 130 | device = in_package['device'] 131 | n_classes = in_package["num_class"] 132 | 133 | with torch.no_grad(): 134 | start = 0 135 | ntest = test_X.size()[0] 136 | test_label = F.one_hot(test_label, num_classes=n_classes) 137 | predicted_label = torch.LongTensor(test_label.size()).fill_(0).to(test_label.device) 138 | for i in range(0, ntest, batch_size): 139 | 140 | end = min(ntest, start+batch_size) 141 | 142 | input = test_X[start:end].to(device) 143 | 144 | out_package = model(input) 145 | 146 | output = out_package['S_pp'] 147 | output[:,target_classes] = output[:,target_classes]+bias 148 | _,idx_k = torch.topk(output,k,dim=1) 149 | if is_detect: 150 | assert k == 1 151 | detection_mask=in_package["detection_mask"] 152 | predicted_label[start:end] = detection_mask[torch.argmax(output.data, 1)] 153 | else: 154 | predicted_label[start:end] = predicted_label[start:end].scatter_(1,idx_k,1) 155 | start = end 156 | 157 | acc = compute_per_class_acc_gzsl_k(test_label, predicted_label, target_classes, in_package) 158 | return acc 159 | 160 | 161 | def val_zs_gzsl_k(k,test_X, test_label, unseen_classes,in_package,bias = 0,is_detect=False): 162 | batch_size = in_package['batch_size'] 163 | model = in_package['model'] 164 | device = in_package['device'] 165 | n_classes = in_package["num_class"] 166 | with torch.no_grad(): 167 | start = 0 168 | ntest = test_X.size()[0] 169 | 170 | test_label_gzsl = F.one_hot(test_label, num_classes=n_classes) 171 | predicted_label_gzsl = torch.LongTensor(test_label_gzsl.size()).fill_(0).to(test_label.device) 172 | 173 | predicted_label_zsl = torch.LongTensor(test_label.size()) 174 | predicted_label_zsl_t = torch.LongTensor(test_label.size()) 175 | for i in range(0, ntest, batch_size): 176 | 177 | end = min(ntest, start+batch_size) 178 | 179 | input = test_X[start:end].to(device) 180 | 181 | out_package = model(input) 182 | output = out_package['S_pp'] 183 | 184 | output_t = output.clone() 185 | output_t[:,unseen_classes] = output_t[:,unseen_classes]+torch.max(output)+1 186 | predicted_label_zsl[start:end] = torch.argmax(output_t.data, 1) 187 | predicted_label_zsl_t[start:end] = torch.argmax(output.data[:,unseen_classes], 1) 188 | 189 | output[:,unseen_classes] = 
output[:,unseen_classes]+bias 190 | _,idx_k = torch.topk(output,k,dim=1) 191 | if is_detect: 192 | assert k == 1 193 | detection_mask=in_package["detection_mask"] 194 | predicted_label_gzsl[start:end] = detection_mask[torch.argmax(output.data, 1)] 195 | else: 196 | predicted_label_gzsl[start:end] = predicted_label_gzsl[start:end].scatter_(1,idx_k,1) 197 | 198 | start = end 199 | 200 | acc_gzsl = compute_per_class_acc_gzsl_k(test_label_gzsl, predicted_label_gzsl, unseen_classes, in_package) 201 | #print('acc_zs: {} acc_zs_t: {}'.format(acc_zs,acc_zs_t)) 202 | return acc_gzsl,-1 203 | 204 | 205 | def compute_per_class_acc_k(test_label, predicted_label, nclass): 206 | acc_per_class = torch.FloatTensor(nclass).fill_(0) 207 | for i in range(nclass): 208 | idx = (test_label == i) 209 | acc_per_class[i] = torch.sum(test_label[idx]==predicted_label[idx]).float() / torch.sum(idx).float() 210 | return acc_per_class.mean().item() 211 | 212 | 213 | def compute_per_class_acc_gzsl_k(test_label, predicted_label, target_classes, in_package): 214 | device = in_package['device'] 215 | per_class_accuracies = torch.zeros(target_classes.size()[0]).float().to(device).detach() 216 | 217 | predicted_label = predicted_label.to(device) 218 | 219 | hit = test_label*predicted_label 220 | for i in range(target_classes.size()[0]): 221 | 222 | target = target_classes[i] 223 | n_pos = torch.sum(hit[:,target]) 224 | n_gt = torch.sum(test_label[:,target]) 225 | per_class_accuracies[i] = torch.div(n_pos.float(),n_gt.float()) 226 | #pdb.set_trace() 227 | return per_class_accuracies.mean().item() 228 | 229 | 230 | def eval_zs_gzsl_k(k,dataloader,model,device,bias_seen,bias_unseen,is_detect=False): 231 | model.eval() 232 | print('bias_seen {} bias_unseen {}'.format(bias_seen,bias_unseen)) 233 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 234 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 235 | 236 | test_unseen_feature = dataloader.data['test_unseen']['resnet_features'] 237 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 238 | 239 | seenclasses = dataloader.seenclasses 240 | unseenclasses = dataloader.unseenclasses 241 | 242 | batch_size = 100 243 | n_classes = dataloader.ntrain_class+dataloader.ntest_class 244 | in_package = {'model':model,'device':device, 'batch_size':batch_size,'num_class':n_classes} 245 | 246 | if is_detect: 247 | print("Measure novelty detection k: {}".format(k)) 248 | 249 | detection_mask = torch.zeros((n_classes,n_classes)).long().to(dataloader.device) 250 | detect_label = torch.zeros(n_classes).long().to(dataloader.device) 251 | detect_label[seenclasses]=1 252 | detection_mask[seenclasses,:] = detect_label 253 | 254 | detect_label = torch.zeros(n_classes).long().to(dataloader.device) 255 | detect_label[unseenclasses]=1 256 | detection_mask[unseenclasses,:]=detect_label 257 | in_package["detection_mask"]=detection_mask 258 | 259 | with torch.no_grad(): 260 | acc_seen = val_gzsl_k(k,test_seen_feature, test_seen_label, seenclasses, in_package,bias=bias_seen,is_detect=is_detect) 261 | acc_novel,acc_zs = val_zs_gzsl_k(k,test_unseen_feature, test_unseen_label, unseenclasses, in_package,bias = bias_unseen,is_detect=is_detect) 262 | 263 | if (acc_seen+acc_novel)>0: 264 | H = (2*acc_seen*acc_novel) / (acc_seen+acc_novel) 265 | else: 266 | H = 0 267 | 268 | return acc_seen, acc_novel, H, acc_zs 269 | -------------------------------------------------------------------------------- /model_tzpp.py: 
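
# A minimal usage sketch of compute_per_class_acc_gzsl above (toy data for
# illustration only). Accuracy is averaged per class, so class imbalance in
# the test split does not skew the metric.
if __name__ == '__main__':
    _device = torch.device('cpu')
    _labels = torch.tensor([0, 0, 1, 1, 2, 2])   # two samples per class
    _preds = torch.tensor([0, 1, 1, 1, 2, 0])    # class hits: 1/2, 2/2, 1/2
    _classes = torch.tensor([0, 1, 2])
    _acc = compute_per_class_acc_gzsl(_labels, _preds, _classes,
                                      {'device': _device})
    print(_acc)  # (0.5 + 1.0 + 0.5) / 3 = 0.6667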
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class TransZeroPP(nn.Module): 8 | def __init__(self, config, att, init_w2v_att, seenclass, unseenclass, 9 | is_bias=True, bias=1, is_conservative=True): 10 | super(TransZeroPP, self).__init__() 11 | self.config = config 12 | self.dim_f = config.dim_f 13 | self.dim_v = config.dim_v 14 | self.nclass = config.num_class 15 | self.seenclass = seenclass 16 | self.unseenclass = unseenclass 17 | self.is_bias = is_bias 18 | self.is_conservative = is_conservative 19 | # class-level semantic vectors 20 | self.att = nn.Parameter(F.normalize(att), requires_grad=False) 21 | # GloVe features for attributes 22 | self.V = nn.Parameter(F.normalize(init_w2v_att), requires_grad=True) 23 | # self-calibration 24 | self.bias = nn.Parameter(torch.tensor(bias), requires_grad=False) 25 | mask_bias = np.ones((1, self.nclass)) 26 | mask_bias[:, self.seenclass.cpu().numpy()] *= -1 27 | self.mask_bias = nn.Parameter(torch.tensor( 28 | mask_bias, dtype=torch.float), requires_grad=False) 29 | # mapping 30 | # s2v 31 | self.W_1_s2v = nn.Parameter(nn.init.normal_(torch.empty( 32 | self.dim_v, config.tf_common_dim)), requires_grad=True) 33 | self.W_3_s2v = nn.Parameter(nn.init.zeros_(torch.empty( 34 | self.dim_v, config.tf_common_dim)), requires_grad=True) 35 | # v2s 36 | if config.tf_v2s_init == 'zeros_': 37 | self.W_1_v2s = nn.Parameter(nn.init.zeros_(torch.empty( 38 | config.tf_common_dim, config.tf_common_dim)), requires_grad=True) 39 | elif config.tf_v2s_init == 'normal_': 40 | self.W_1_v2s = nn.Parameter(nn.init.normal_(torch.empty( 41 | config.tf_common_dim, config.tf_common_dim)), requires_grad=True) 42 | self.W_3_v2s = nn.Parameter(nn.init.zeros_(torch.empty( 43 | self.dim_f, config.tf_common_dim)), requires_grad=True) 44 | self.W_4_v2s = nn.Parameter(nn.init.normal_(torch.empty( 45 | config.tf_common_dim, config.tf_common_dim)), requires_grad=True) 46 | # transformer semantic -> visual 47 | self.transformer_s2v = TransformerPP( 48 | ec_layer=config.tf_ec_layer, 49 | dc_layer=config.tf_dc_layer, 50 | dim_com=config.tf_common_dim, 51 | dim_feedforward=config.tf_dim_feedforward, 52 | dropout=config.tf_dropout, 53 | SAtt=config.tf_SAtt, 54 | heads=config.tf_heads, 55 | aux_embed=config.tf_aux_embed) 56 | # transformer visual -> semantic 57 | self.transformer_v2s = TransformerPP( 58 | ec_layer=config.tf_ec_layer, 59 | dc_layer=config.tf_dc_layer, 60 | dim_com=config.tf_common_dim, 61 | dim_feedforward=config.tf_dim_feedforward, 62 | dropout=config.tf_dropout, 63 | SAtt=config.tf_SAtt, 64 | heads=config.tf_heads, 65 | aux_embed=config.tf_aux_embed) 66 | # loss 67 | self.log_softmax_func = nn.LogSoftmax(dim=1) 68 | self.weight_ce = nn.Parameter(torch.eye(self.nclass), requires_grad=False) 69 | 70 | def forward(self, input, from_img=False): 71 | Fs = self.resnet101(input) if from_img else input 72 | # transformer-based visual-to-semantic embedding 73 | embed_s2v, embed_v2s = self.forward_feature_transformer(Fs) 74 | # classification 75 | package_s2v = {} 76 | package_s2v['pred'] = self.forward_attribute(embed_s2v) 77 | package_s2v['embed'] = embed_s2v 78 | package_v2s = {} 79 | package_v2s['pred'] = self.forward_attribute(embed_v2s) 80 | package_v2s['embed'] = embed_v2s 81 | out_package = {} 82 | out_package['package_s2v'] = package_s2v 83 | out_package['package_v2s'] = package_v2s 84 | out_package['pred'] = self.config.weight_s2v 
* package_s2v['pred'] + \ 85 | (1 - self.config.weight_s2v) * package_v2s['pred'] 86 | out_package['S_pp'] = out_package['pred'] 87 | return out_package 88 | 89 | def forward_feature_transformer(self, Fs): 90 | if len(Fs.shape) == 4: 91 | shape = Fs.shape 92 | Fs = Fs.reshape(shape[0], shape[1], shape[2] * shape[3]) 93 | Fs = F.normalize(Fs, dim=1) 94 | Fs_pmt = Fs.permute(0, 2, 1) 95 | V_n = F.normalize(self.V) if self.config.normalize_V else self.V 96 | V_n_batch = V_n.unsqueeze(0).repeat(shape[0], 1, 1) 97 | # semantic-2-visual 98 | memory_s2v, _, emb_att_s2v = self.transformer_s2v.forward_encoder( 99 | Fs_pmt, V_n_batch) 100 | F_p_s2v = self.transformer_s2v.forward_decoder( 101 | memory_s2v, emb_att_s2v, type='s2v') 102 | S_p_s2v = torch.einsum('biv,vc,bic->bi', V_n_batch, self.W_1_s2v, F_p_s2v) 103 | embed_s2v = S_p_s2v 104 | # visual-2-semantic 105 | memory_v2s, emb_vis_v2s, emb_att_v2s = self.transformer_v2s.forward_encoder( 106 | Fs_pmt, V_n_batch) 107 | F_p_v2s = self.transformer_v2s.forward_decoder( 108 | memory_v2s, emb_att_v2s, type='v2s') 109 | S_p_v2s = torch.einsum('rbf,fc,brc->br', memory_v2s, self.W_1_v2s, F_p_v2s) 110 | E_v2s = torch.einsum('brc,cc,bic->bir', emb_vis_v2s, self.W_4_v2s, emb_att_v2s) 111 | embed_v2s = torch.einsum('bir,br->bi', E_v2s, S_p_v2s) 112 | return embed_s2v, embed_v2s 113 | 114 | def forward_attribute(self, embed): 115 | embed = torch.einsum('ki,bi->bk', self.att, embed) 116 | self.vec_bias = self.mask_bias*self.bias 117 | embed = embed + self.vec_bias 118 | return embed 119 | 120 | def compute_loss_Self_Calibrate(self, in_package): 121 | S_pp = in_package['pred'] 122 | Prob_all = F.softmax(S_pp, dim=-1) 123 | Prob_unseen = Prob_all[:, self.unseenclass] 124 | assert Prob_unseen.size(1) == len(self.unseenclass) 125 | mass_unseen = torch.sum(Prob_unseen, dim=1) 126 | loss_pmp = -torch.log(torch.mean(mass_unseen)) 127 | return loss_pmp 128 | 129 | def compute_aug_cross_entropy(self, in_package): 130 | Labels = in_package['batch_label'] 131 | S_pp = in_package['pred'] 132 | 133 | if self.is_bias: 134 | S_pp = S_pp - self.vec_bias 135 | 136 | if not self.is_conservative: 137 | S_pp = S_pp[:, self.seenclass] 138 | Labels = Labels[:, self.seenclass] 139 | assert S_pp.size(1) == len(self.seenclass) 140 | 141 | Prob = self.log_softmax_func(S_pp) 142 | 143 | loss = -torch.einsum('bk,bk->b', Prob, Labels) 144 | loss = torch.mean(loss) 145 | return loss 146 | 147 | def compute_reg_loss(self, in_package): 148 | tgt = torch.matmul(in_package['batch_label'], self.att) 149 | embed = in_package['embed'] 150 | loss_reg = F.mse_loss(embed, tgt, reduction='mean') 151 | return loss_reg 152 | 153 | def compute_loss(self, in_package): 154 | if len(in_package['batch_label'].size()) == 1: 155 | in_package['batch_label'] = self.weight_ce[in_package['batch_label']] 156 | 157 | loss_CE = self.compute_aug_cross_entropy(in_package) 158 | loss_cal = self.compute_loss_Self_Calibrate(in_package) 159 | loss_reg = self.compute_reg_loss(in_package) 160 | 161 | loss = loss_CE + self.config.lambda_ * \ 162 | loss_cal + self.config.lambda_reg * loss_reg 163 | out_package = {'loss': loss, 'loss_CE': loss_CE, 164 | 'loss_cal': loss_cal, 'loss_reg': loss_reg} 165 | return out_package 166 | 167 | def WeightedL2(self, pred, gt): 168 | loss = F.mse_loss(pred, gt, reduction='mean') 169 | return loss 170 | 171 | def compute_contrastive_loss(self, in_package1, in_package2): 172 | reg_func = self.WeightedL2 173 | loss_att = reg_func(in_package1['embed'], in_package2['embed']) 174 | loss_cls = 
reg_func(in_package1['pred'], in_package2['pred']) 175 | return loss_att, loss_cls 176 | 177 | 178 | class Transformer(nn.Module): 179 | def __init__(self, ec_layer=1, dc_layer=1, dim_com=300, 180 | dim_feedforward=2048, dropout=0.1, heads=1, 181 | in_dim_cv=2048, in_dim_attr=300, SAtt=True, 182 | aux_embed=True): 183 | super(Transformer, self).__init__() 184 | # input embedding 185 | self.embed_cv = nn.Sequential(nn.Linear(in_dim_cv, dim_com)) 186 | if aux_embed: 187 | self.embed_cv_aux = nn.Sequential(nn.Linear(in_dim_cv, dim_com)) 188 | self.embed_attr = nn.Sequential(nn.Linear(in_dim_attr, dim_com)) 189 | # transformer encoder 190 | self.transformer_encoder = MultiLevelEncoder_woPad(N=ec_layer, 191 | d_model=dim_com, 192 | h=1, 193 | d_k=dim_com, 194 | d_v=dim_com, 195 | d_ff=dim_feedforward, 196 | dropout=dropout) 197 | # transformer decoder 198 | decoder_layer = TransformerDecoderLayer(d_model=dim_com, 199 | nhead=heads, 200 | dim_feedforward=dim_feedforward, 201 | dropout=dropout, 202 | SAtt=SAtt) 203 | self.transformer_decoder = nn.TransformerDecoder( 204 | decoder_layer, num_layers=dc_layer) 205 | 206 | def forward(self, f_cv, f_attr): 207 | # linearly map to common dim 208 | h_cv = self.embed_cv(f_cv.permute(0, 2, 1)) 209 | h_attr = self.embed_attr(f_attr) 210 | h_attr_batch = h_attr.unsqueeze(0).repeat(f_cv.shape[0], 1, 1) 211 | # visual encoder 212 | memory = self.transformer_encoder(h_cv).permute(1, 0, 2) 213 | # attribute-visual decoder 214 | out = self.transformer_decoder(h_attr_batch.permute(1, 0, 2), memory) 215 | return out.permute(1, 0, 2) 216 | 217 | 218 | class TransformerPP(Transformer): 219 | def __init__(self, ec_layer=1, dc_layer=1, dim_com=300, 220 | dim_feedforward=2048, dropout=0.1, heads=1, 221 | in_dim_cv=2048, in_dim_attr=300, SAtt=True, 222 | aux_embed=True): 223 | 224 | super(TransformerPP, self).__init__( 225 | ec_layer, dc_layer, dim_com, dim_feedforward, dropout, 226 | heads, in_dim_cv, in_dim_attr, SAtt, aux_embed) 227 | 228 | def forward_encoder(self, f_cv, f_attr, pre_embed=True, is_enc=True): 229 | if pre_embed: 230 | h_cv = self.embed_cv(f_cv) 231 | h_attr = self.embed_attr(f_attr) 232 | else: 233 | h_cv = f_cv 234 | h_attr = f_attr 235 | if is_enc: 236 | memory = self.transformer_encoder(h_cv).permute(1, 0, 2) 237 | return memory, h_cv, h_attr 238 | 239 | def forward_decoder(self, memory, h_attr, type='s2v'): 240 | if type == 's2v': 241 | out = self.transformer_decoder(h_attr.permute(1, 0, 2), memory) 242 | elif type == 'v2s': 243 | out = self.transformer_decoder(memory, h_attr.permute(1, 0, 2)) 244 | return out.permute(1, 0, 2) 245 | 246 | 247 | class EncoderLayer(nn.Module): 248 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, 249 | dropout=.1, identity_map_reordering=False, 250 | attention_module=None, attention_module_kwargs=None): 251 | super(EncoderLayer, self).__init__() 252 | self.identity_map_reordering = identity_map_reordering 253 | self.mhatt = MultiHeadGeometryAttention(d_model, d_k, d_v, h, dropout, 254 | identity_map_reordering=identity_map_reordering, 255 | attention_module=attention_module, 256 | attention_module_kwargs=attention_module_kwargs) 257 | self.dropout = nn.Dropout(dropout) 258 | self.lnorm = nn.LayerNorm(d_model) 259 | self.pwff = PositionWiseFeedForward( 260 | d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 261 | 262 | def forward(self, queries, keys, values, relative_geometry_weights, 263 | attention_mask=None, attention_weights=None, pos=None): 264 | q, k = (queries + pos, 
keys + 265 | pos) if pos is not None else (queries, keys) 266 | att = self.mhatt(q, k, values, relative_geometry_weights, 267 | attention_mask, attention_weights) 268 | att = self.lnorm(queries + self.dropout(att)) 269 | ff = self.pwff(att) 270 | return ff 271 | 272 | 273 | class MultiLevelEncoder_woPad(nn.Module): 274 | def __init__(self, N, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, 275 | dropout=.1, identity_map_reordering=False, 276 | attention_module=None, attention_module_kwargs=None): 277 | super(MultiLevelEncoder_woPad, self).__init__() 278 | self.d_model = d_model 279 | self.dropout = dropout 280 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 281 | identity_map_reordering=identity_map_reordering, 282 | attention_module=attention_module, 283 | attention_module_kwargs=attention_module_kwargs) 284 | for _ in range(N)]) 285 | 286 | self.WGs = nn.ModuleList( 287 | [nn.Linear(64, 1, bias=True) for _ in range(h)]) 288 | 289 | def forward(self, input, attention_mask=None, attention_weights=None, pos=None): 290 | relative_geometry_embeddings = BoxRelationalEmbedding( 291 | input, grid_size=(14, 14)) 292 | flatten_relative_geometry_embeddings = relative_geometry_embeddings.view( 293 | -1, 64) 294 | box_size_per_head = list(relative_geometry_embeddings.shape[:3]) 295 | box_size_per_head.insert(1, 1) 296 | relative_geometry_weights_per_head = [layer( 297 | flatten_relative_geometry_embeddings).view(box_size_per_head) for layer in self.WGs] 298 | relative_geometry_weights = torch.cat( 299 | (relative_geometry_weights_per_head), 1) 300 | relative_geometry_weights = F.relu(relative_geometry_weights) 301 | out = input 302 | for layer in self.layers: 303 | out = layer(out, out, out, relative_geometry_weights, 304 | attention_mask, attention_weights, pos=pos) 305 | return out 306 | 307 | 308 | class TransformerDecoderLayer(nn.TransformerDecoderLayer): 309 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 310 | activation="relu", SAtt=True): 311 | super(TransformerDecoderLayer, self).__init__(d_model, nhead, 312 | dim_feedforward=dim_feedforward, 313 | dropout=dropout, 314 | activation=activation) 315 | self.SAtt = SAtt 316 | 317 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, 318 | tgt_key_padding_mask=None, memory_key_padding_mask=None): 319 | if self.SAtt: 320 | tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, 321 | key_padding_mask=tgt_key_padding_mask)[0] 322 | tgt = self.norm1(tgt + self.dropout1(tgt2)) 323 | tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, 324 | key_padding_mask=memory_key_padding_mask)[0] 325 | tgt = tgt + self.dropout2(tgt2) 326 | tgt = self.norm2(tgt) 327 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 328 | tgt = tgt + self.dropout3(tgt2) 329 | tgt = self.norm3(tgt) 330 | return tgt 331 | 332 | 333 | def get_relative_pos(x, batch_size, norm_len): 334 | x = x.view(1, -1, 1).expand(batch_size, -1, -1) 335 | return x / norm_len 336 | 337 | 338 | def get_grids_pos(batch_size, seq_len, grid_size=(7, 7)): 339 | assert seq_len == grid_size[0] * grid_size[1] 340 | x = torch.arange(0, grid_size[0]).float().cuda() 341 | y = torch.arange(0, grid_size[1]).float().cuda() 342 | px_min = x.view(-1, 1).expand(-1, grid_size[0]).contiguous().view(-1) 343 | py_min = y.view(1, -1).expand(grid_size[1], -1).contiguous().view(-1) 344 | px_max = px_min + 1 345 | py_max = py_min + 1 346 | rpx_min = get_relative_pos(px_min, batch_size, grid_size[0]) 347 | rpy_min = 
get_relative_pos(py_min, batch_size, grid_size[1]) 348 | rpx_max = get_relative_pos(px_max, batch_size, grid_size[0]) 349 | rpy_max = get_relative_pos(py_max, batch_size, grid_size[1]) 350 | return rpx_min, rpy_min, rpx_max, rpy_max 351 | 352 | 353 | def BoxRelationalEmbedding(f_g, dim_g=64, wave_len=1000, trignometric_embedding=True, 354 | grid_size=(7, 7)): 355 | batch_size, seq_len = f_g.size(0), f_g.size(1) 356 | x_min, y_min, x_max, y_max = get_grids_pos(batch_size, seq_len, grid_size) 357 | cx = (x_min + x_max) * 0.5 358 | cy = (y_min + y_max) * 0.5 359 | w = (x_max - x_min) + 1. 360 | h = (y_max - y_min) + 1. 361 | delta_x = cx - cx.view(batch_size, 1, -1) 362 | delta_x = torch.clamp(torch.abs(delta_x / w), min=1e-3) 363 | delta_x = torch.log(delta_x) 364 | delta_y = cy - cy.view(batch_size, 1, -1) 365 | delta_y = torch.clamp(torch.abs(delta_y / h), min=1e-3) 366 | delta_y = torch.log(delta_y) 367 | delta_w = torch.log(w / w.view(batch_size, 1, -1)) 368 | delta_h = torch.log(h / h.view(batch_size, 1, -1)) 369 | matrix_size = delta_h.size() 370 | delta_x = delta_x.view(batch_size, matrix_size[1], matrix_size[2], 1) 371 | delta_y = delta_y.view(batch_size, matrix_size[1], matrix_size[2], 1) 372 | delta_w = delta_w.view(batch_size, matrix_size[1], matrix_size[2], 1) 373 | delta_h = delta_h.view(batch_size, matrix_size[1], matrix_size[2], 1) 374 | position_mat = torch.cat((delta_x, delta_y, delta_w, delta_h), -1) 375 | if trignometric_embedding: 376 | feat_range = torch.arange(dim_g / 8).cuda() 377 | dim_mat = feat_range / (dim_g / 8) 378 | dim_mat = 1. / (torch.pow(wave_len, dim_mat)) 379 | dim_mat = dim_mat.view(1, 1, 1, -1) 380 | position_mat = position_mat.view( 381 | batch_size, matrix_size[1], matrix_size[2], 4, -1) 382 | position_mat = 100.
* position_mat 383 | mul_mat = position_mat * dim_mat 384 | mul_mat = mul_mat.view(batch_size, matrix_size[1], matrix_size[2], -1) 385 | sin_mat = torch.sin(mul_mat) 386 | cos_mat = torch.cos(mul_mat) 387 | embedding = torch.cat((sin_mat, cos_mat), -1) 388 | else: 389 | embedding = position_mat 390 | return (embedding) 391 | 392 | 393 | class ScaledDotProductGeometryAttention(nn.Module): 394 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, comment=None): 395 | super(ScaledDotProductGeometryAttention, self).__init__() 396 | self.fc_q = nn.Linear(d_model, h * d_k) 397 | self.fc_k = nn.Linear(d_model, h * d_k) 398 | self.fc_v = nn.Linear(d_model, h * d_v) 399 | self.fc_o = nn.Linear(h * d_v, d_model) 400 | self.dropout = nn.Dropout(dropout) 401 | self.d_model = d_model 402 | self.d_k = d_k 403 | self.d_v = d_v 404 | self.h = h 405 | self.init_weights() 406 | self.comment = comment 407 | 408 | def init_weights(self): 409 | nn.init.xavier_uniform_(self.fc_q.weight) 410 | nn.init.xavier_uniform_(self.fc_k.weight) 411 | nn.init.xavier_uniform_(self.fc_v.weight) 412 | nn.init.xavier_uniform_(self.fc_o.weight) 413 | nn.init.constant_(self.fc_q.bias, 0) 414 | nn.init.constant_(self.fc_k.bias, 0) 415 | nn.init.constant_(self.fc_v.bias, 0) 416 | nn.init.constant_(self.fc_o.bias, 0) 417 | 418 | def forward(self, queries, keys, values, box_relation_embed_matrix, 419 | attention_mask=None, attention_weights=None): 420 | b_s, nq = queries.shape[:2] 421 | nk = keys.shape[1] 422 | q = self.fc_q(queries).view(b_s, nq, self.h, 423 | self.d_k).permute(0, 2, 1, 3) 424 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) 425 | v = self.fc_v(values).view(b_s, nk, self.h, 426 | self.d_v).permute(0, 2, 1, 3) 427 | att = torch.matmul(q, k) / np.sqrt(self.d_k) 428 | if attention_weights is not None: 429 | att = att * attention_weights 430 | if attention_mask is not None: 431 | att = att.masked_fill(attention_mask, -np.inf) 432 | w_g = box_relation_embed_matrix 433 | w_a = att 434 | w_mn = - w_g + w_a 435 | w_mn = torch.softmax(w_mn, -1) 436 | att = self.dropout(w_mn) 437 | out = torch.matmul(att, v).permute( 438 | 0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) 439 | out = self.fc_o(out) 440 | return out 441 | 442 | 443 | class MultiHeadGeometryAttention(nn.Module): 444 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, 445 | can_be_stateful=False, attention_module=None, 446 | attention_module_kwargs=None, comment=None): 447 | super(MultiHeadGeometryAttention, self).__init__() 448 | self.identity_map_reordering = identity_map_reordering 449 | self.attention = ScaledDotProductGeometryAttention( 450 | d_model=d_model, d_k=d_k, d_v=d_v, h=h, comment=comment) 451 | self.dropout = nn.Dropout(p=dropout) 452 | self.layer_norm = nn.LayerNorm(d_model) 453 | self.can_be_stateful = can_be_stateful 454 | if self.can_be_stateful: 455 | self.register_state('running_keys', torch.zeros((0, d_model))) 456 | self.register_state('running_values', torch.zeros((0, d_model))) 457 | 458 | def forward(self, queries, keys, values, relative_geometry_weights, 459 | attention_mask=None, attention_weights=None): 460 | if self.can_be_stateful and self._is_stateful: 461 | self.running_keys = torch.cat([self.running_keys, keys], 1) 462 | keys = self.running_keys 463 | self.running_values = torch.cat([self.running_values, values], 1) 464 | values = self.running_values 465 | if self.identity_map_reordering: 466 | q_norm = self.layer_norm(queries) 467 | k_norm = 
self.layer_norm(keys) 468 | v_norm = self.layer_norm(values) 469 | out = self.attention(q_norm, k_norm, v_norm, relative_geometry_weights, 470 | attention_mask, attention_weights) 471 | out = queries + self.dropout(torch.relu(out)) 472 | else: 473 | out = self.attention(queries, keys, values, relative_geometry_weights, 474 | attention_mask, attention_weights) 475 | out = self.dropout(out) 476 | out = self.layer_norm(queries + out) 477 | return out 478 | 479 | 480 | class PositionWiseFeedForward(nn.Module): 481 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 482 | super(PositionWiseFeedForward, self).__init__() 483 | self.identity_map_reordering = identity_map_reordering 484 | self.fc1 = nn.Linear(d_model, d_ff) 485 | self.fc2 = nn.Linear(d_ff, d_model) 486 | self.dropout = nn.Dropout(p=dropout) 487 | self.dropout_2 = nn.Dropout(p=dropout) 488 | self.layer_norm = nn.LayerNorm(d_model) 489 | 490 | def forward(self, input): 491 | if self.identity_map_reordering: 492 | out = self.layer_norm(input) 493 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 494 | out = input + self.dropout(torch.relu(out)) 495 | else: 496 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 497 | out = self.dropout(out) 498 | out = self.layer_norm(input + out) 499 | return out 500 | 501 | 502 | if __name__ == '__main__': 503 | pass 504 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import h5py 3 | import argparse 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import scipy.io as sio 9 | import torchvision.models.resnet as models 10 | from torchvision import datasets, transforms 11 | from torch.utils.data import Dataset, DataLoader 12 | from PIL import Image 13 | 14 | 15 | class CustomedDataset(Dataset): 16 | def __init__(self, dataset, img_dir, file_paths, transform=None): 17 | self.dataset = dataset 18 | self.matcontent = sio.loadmat(file_paths) 19 | self.image_files = np.squeeze(self.matcontent['image_files']) 20 | self.img_dir = img_dir 21 | self.transform = transform 22 | 23 | def __len__(self): 24 | return len(self.image_files) 25 | 26 | def __getitem__(self, idx): 27 | image_file = self.image_files[idx][0] 28 | if self.dataset == 'CUB': 29 | split_idx = 6 30 | elif self.dataset == 'SUN': 31 | split_idx = 7 32 | elif self.dataset == 'AWA2': 33 | split_idx = 5 34 | image_file = os.path.join(self.img_dir, 35 | '/'.join(image_file.split('/')[split_idx:])) 36 | image = Image.open(image_file) 37 | if image.mode != 'RGB': 38 | image = image.convert('RGB') 39 | if self.transform: 40 | image = self.transform(image) 41 | return image 42 | 43 | 44 | def extract_features(config): 45 | 46 | img_dir = f'data/{config.dataset}' 47 | file_paths = f'data/xlsa17/data/{config.dataset}/res101.mat' 48 | save_path = f'data/{config.dataset}/feature_map_ResNet_101_{config.dataset}.hdf5' 49 | attribute_path = f'w2v/{config.dataset}_attribute.pkl' 50 | 51 | # region feature extractor 52 | resnet101 = models.resnet101(pretrained=True).to(config.device) 53 | resnet101 = nn.Sequential(*list(resnet101.children())[:-2]).eval() 54 | 55 | data_transforms = transforms.Compose([ 56 | transforms.Resize(448), 57 | transforms.CenterCrop(448), 58 | transforms.ToTensor(), 59 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 60 | 61 | Dataset = CustomedDataset(config.dataset, img_dir, 
file_paths, data_transforms) 62 | dataset_loader = torch.utils.data.DataLoader(Dataset, 63 | batch_size=config.batch_size, 64 | shuffle=False, 65 | num_workers=config.num_workers) 66 | 67 | with torch.no_grad(): 68 | all_features = [] 69 | for imgs in dataset_loader: 70 | imgs = imgs.to(config.device) 71 | features = resnet101(imgs) 72 | all_features.append(features.cpu().numpy()) 73 | all_features = np.concatenate(all_features, axis=0) 74 | 75 | # get remaining metadata 76 | matcontent = Dataset.matcontent 77 | labels = matcontent['labels'].astype(int).squeeze() - 1 78 | 79 | split_path = f'data/xlsa17/data/{config.dataset}/att_splits.mat' 80 | matcontent = sio.loadmat(split_path) 81 | trainval_loc = matcontent['trainval_loc'].squeeze() - 1 82 | # train_loc = matcontent['train_loc'].squeeze() - 1 83 | # val_unseen_loc = matcontent['val_loc'].squeeze() - 1 84 | test_seen_loc = matcontent['test_seen_loc'].squeeze() - 1 85 | test_unseen_loc = matcontent['test_unseen_loc'].squeeze() - 1 86 | att = matcontent['att'].T 87 | original_att = matcontent['original_att'].T 88 | 89 | # construct attribute w2v 90 | with open(attribute_path, 'rb') as f: 91 | w2v_att = pickle.load(f) 92 | if config.dataset == 'CUB': 93 | assert w2v_att.shape == (312, 300) 94 | elif config.dataset == 'SUN': 95 | assert w2v_att.shape == (102, 300) 96 | elif config.dataset == 'AWA2': 97 | assert w2v_att.shape == (85, 300) 98 | 99 | compression = 'gzip' if config.compression else None 100 | f = h5py.File(save_path, 'w') 101 | f.create_dataset('feature_map', data=all_features, compression=compression) 102 | f.create_dataset('labels', data=labels, compression=compression) 103 | f.create_dataset('trainval_loc', data=trainval_loc, compression=compression) 104 | # f.create_dataset('train_loc', data=train_loc, compression=compression) 105 | # f.create_dataset('val_unseen_loc', data=val_unseen_loc, compression=compression) 106 | f.create_dataset('test_seen_loc', data=test_seen_loc, compression=compression) 107 | f.create_dataset('test_unseen_loc', data=test_unseen_loc, compression=compression) 108 | f.create_dataset('att', data=att, compression=compression) 109 | f.create_dataset('original_att', data=original_att, compression=compression) 110 | f.create_dataset('w2v_att', data=w2v_att, compression=compression) 111 | f.close() 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('--dataset', '-d', type=str, default='AWA2') 117 | parser.add_argument('--compression', '-c', action='store_true') 118 | parser.add_argument('--batch_size', '-b', type=int, default=200) 119 | parser.add_argument('--device', '-g', type=str, default='cuda:0') 120 | parser.add_argument('--num_workers', '-n', type=int, default=16) 121 | config = parser.parse_args() 122 | extract_features(config) 123 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision==0.9.0 2 | h5py==3.2.1 3 | torch==1.8.0 4 | numpy==1.20.2 5 | pandas==1.2.3 6 | scipy==1.6.2 7 | Pillow==10.0.0 8 | scikit_learn==1.3.0 -------------------------------------------------------------------------------- /train_awa2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | from model_tzpp import TransZeroPP 5 | from dataset import AWA2DataLoader 6 | from
helper_func import eval_zs_gzsl 7 | import numpy as np 8 | import wandb 9 | 10 | # init wandb from config file 11 | wandb.init(project='TransZeroPP', config='wandb_config/awa2_gzsl.yaml') 12 | config = wandb.config 13 | print('Config file from wandb:', config) 14 | 15 | # load dataset 16 | dataloader = AWA2DataLoader('.', config.device) 17 | 18 | # set random seed 19 | seed = config.random_seed 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | np.random.seed(seed) 23 | 24 | init_w2v_att = dataloader.w2v_att 25 | att = dataloader.att 26 | att[att<0] = 0 27 | normalize_att = dataloader.normalize_att 28 | 29 | # TransZero model 30 | model = TransZeroPP(config, dataloader.att, dataloader.w2v_att, 31 | dataloader.seenclasses, dataloader.unseenclasses).to(config.device) 32 | optimizer = optim.Adam(model.parameters(), lr=0.00001, weight_decay=0.0001) 33 | 34 | # main loop 35 | niters = dataloader.ntrain * config.epochs//config.batch_size 36 | report_interval = niters//config.epochs 37 | best_performance = [0, 0, 0, 0] 38 | best_performance_zsl = 0 39 | for i in range(0, niters): 40 | model.train() 41 | optimizer.zero_grad() 42 | 43 | batch_label, batch_feature, batch_att = dataloader.next_batch( 44 | config.batch_size) 45 | out_package = model(batch_feature) 46 | 47 | in_package1 = out_package['package_s2v'] 48 | in_package2 = out_package['package_v2s'] 49 | in_package1['batch_label'] = batch_label 50 | in_package2['batch_label'] = batch_label 51 | 52 | out_package1=model.compute_loss(in_package1) 53 | out_package2=model.compute_loss(in_package2) 54 | 55 | loss = out_package1['loss'] + config.lambda_v2s * out_package2['loss'] 56 | loss_CE = out_package1['loss_CE'] + out_package2['loss_CE'] 57 | loss_cal = out_package1['loss_cal'] + out_package2['loss_cal'] 58 | loss_reg = out_package1['loss_reg'] + out_package2['loss_reg'] 59 | 60 | loss_att, loss_cls = model.compute_contrastive_loss( 61 | in_package1, in_package2) 62 | loss += config.lambda_cst_reg_att * loss_att 63 | loss += config.lambda_cst_reg_cls * loss_cls 64 | 65 | loss.backward() 66 | optimizer.step() 67 | 68 | # report result 69 | if i % report_interval == 0: 70 | print('-'*30) 71 | acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl( 72 | dataloader, model, config.device, batch_size=config.batch_size) 73 | 74 | if H > best_performance[2]: 75 | best_performance = [acc_novel, acc_seen, H, acc_zs] 76 | if acc_zs > best_performance_zsl: 77 | best_performance_zsl = acc_zs 78 | 79 | print('iter/epoch=%d/%d | loss=%.3f, loss_CE=%.3f, loss_cal=%.3f, ' 80 | 'loss_reg=%.3f, loss_cst_att=%.3f, loss_cst_cls=%.3f | ' % ( 81 | i, int(i//report_interval), 82 | loss.item(), loss_CE.item(), loss_cal.item(), 83 | loss_reg.item(), loss_att.item(), loss_cls.item())) 84 | print('Current GZSL: acc_unseen=%.3f, acc_seen=%.3f, H=%.3f | Current CZSL: acc_zs=%.3f' % ( 85 | acc_novel, acc_seen, H, acc_zs)) 86 | print('BEST GZSL: acc_unseen=%.3f, acc_seen=%.3f, H=%.3f, acc_zs=%.3f' 87 | ' | BEST CZSL: acc_zs=%.3f' % ( 88 | best_performance[0], best_performance[1], 89 | best_performance[2], best_performance[3], 90 | best_performance_zsl)) 91 | 92 | wandb.log({ 93 | 'iter': i, 94 | 'loss': loss.item(), 95 | 'loss_CE': loss_CE.item(), 96 | 'loss_cal': loss_cal.item(), 97 | 'loss_reg': loss_reg.item(), 98 | 'loss_att': loss_att.item(), 99 | 'loss_cls': loss_cls.item(), 100 | 'acc_unseen': acc_novel, 101 | 'acc_seen': acc_seen, 102 | 'H': H, 103 | 'acc_zs': acc_zs, 104 | 'best_acc_unseen': best_performance[0], 105 | 'best_acc_seen': 
best_performance[1], 106 | 'best_H': best_performance[2], 107 | 'best_acc_zs': best_performance_zsl 108 | }) 109 | -------------------------------------------------------------------------------- /train_cub.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | from model_tzpp import TransZeroPP 5 | from dataset import CUBDataLoader 6 | from helper_func import eval_zs_gzsl 7 | import numpy as np 8 | import wandb 9 | 10 | # init wandb from config file 11 | wandb.init(project='TransZeroPP', config='wandb_config/cub_gzsl.yaml') 12 | config = wandb.config 13 | print('Config file from wandb:', config) 14 | 15 | # load dataset 16 | dataloader = CUBDataLoader('.', config.device, is_balance=False) 17 | 18 | # set random seed 19 | seed = config.random_seed 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | np.random.seed(seed) 23 | 24 | # TransZero model 25 | model = TransZeroPP(config, dataloader.att, dataloader.w2v_att, 26 | dataloader.seenclasses, dataloader.unseenclasses).to(config.device) 27 | optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001) 28 | 29 | # main loop 30 | niters = dataloader.ntrain * config.epochs//config.batch_size 31 | report_interval = niters//config.epochs 32 | best_performance = [0, 0, 0, 0] 33 | best_performance_zsl = 0 34 | for i in range(0, niters): 35 | model.train() 36 | optimizer.zero_grad() 37 | 38 | batch_label, batch_feature, batch_att = dataloader.next_batch( 39 | config.batch_size) 40 | out_package = model(batch_feature) 41 | 42 | in_package1 = out_package['package_s2v'] 43 | in_package2 = out_package['package_v2s'] 44 | in_package1['batch_label'] = batch_label 45 | in_package2['batch_label'] = batch_label 46 | 47 | out_package1=model.compute_loss(in_package1) 48 | out_package2=model.compute_loss(in_package2) 49 | 50 | loss = out_package1['loss'] + config.lambda_v2s * out_package2['loss'] 51 | loss_CE = out_package1['loss_CE'] + out_package2['loss_CE'] 52 | loss_cal = out_package1['loss_cal'] + out_package2['loss_cal'] 53 | loss_reg = out_package1['loss_reg'] + out_package2['loss_reg'] 54 | 55 | loss_att, loss_cls = model.compute_contrastive_loss( 56 | in_package1, in_package2) 57 | loss += config.lambda_cst_reg_att * loss_att 58 | loss += config.lambda_cst_reg_cls * loss_cls 59 | 60 | loss.backward() 61 | optimizer.step() 62 | 63 | # report result 64 | if i % report_interval == 0: 65 | print('-'*30) 66 | acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl( 67 | dataloader, model, config.device, batch_size=config.batch_size) 68 | 69 | if H > best_performance[2]: 70 | best_performance = [acc_novel, acc_seen, H, acc_zs] 71 | if acc_zs > best_performance_zsl: 72 | best_performance_zsl = acc_zs 73 | 74 | print('iter/epoch=%d/%d | loss=%.3f, loss_CE=%.3f, loss_cal=%.3f, ' 75 | 'loss_reg=%.3f, loss_cst_att=%.3f, loss_cst_cls=%.3f | ' % ( 76 | i, int(i//report_interval), 77 | loss.item(), loss_CE.item(), loss_cal.item(), 78 | loss_reg.item(), loss_att.item(), loss_cls.item())) 79 | print('Current GZSL: acc_unseen=%.3f, acc_seen=%.3f, H=%.3f | Current CZSL: acc_zs=%.3f' % ( 80 | acc_novel, acc_seen, H, acc_zs)) 81 | print('BEST GZSL: acc_unseen=%.3f, acc_seen=%.3f, H=%.3f, acc_zs=%.3f' 82 | ' | BEST CZSL: acc_zs=%.3f' % ( 83 | best_performance[0], best_performance[1], 84 | best_performance[2], best_performance[3], 85 | best_performance_zsl)) 86 | 87 | wandb.log({ 88 | 'iter': i, 89 | 'loss': loss.item(), 90 | 'loss_CE': 
loss_CE.item(), 91 | 'loss_cal': loss_cal.item(), 92 | 'loss_reg': loss_reg.item(), 93 | 'loss_att': loss_att.item(), 94 | 'loss_cls': loss_cls.item(), 95 | 'acc_unseen': acc_novel, 96 | 'acc_seen': acc_seen, 97 | 'H': H, 98 | 'acc_zs': acc_zs, 99 | 'best_acc_unseen': best_performance[0], 100 | 'best_acc_seen': best_performance[1], 101 | 'best_H': best_performance[2], 102 | 'best_acc_zs': best_performance_zsl 103 | }) 104 | -------------------------------------------------------------------------------- /train_sun.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | from model_tzpp import TransZeroPP 5 | from dataset import SUNDataLoader 6 | from helper_func import eval_zs_gzsl 7 | import numpy as np 8 | import wandb 9 | 10 | # init wandb from config file 11 | wandb.init(project='TransZeroPP', config='wandb_config/sun_gzsl.yaml') 12 | # wandb.init(project='TransZeroPP', config='wandb_config/sun_czsl.yaml') 13 | config = wandb.config 14 | print('Config file from wandb:', config) 15 | 16 | # load dataset 17 | dataloader = SUNDataLoader('.', config.device) 18 | 19 | # set random seed 20 | seed = config.random_seed 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed) 24 | 25 | # TransZero model 26 | model = TransZeroPP(config, dataloader.att, dataloader.w2v_att, 27 | dataloader.seenclasses, dataloader.unseenclasses).to(config.device) 28 | optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001) 29 | 30 | # main loop 31 | niters = dataloader.ntrain * config.epochs//config.batch_size 32 | report_interval = niters//config.epochs 33 | best_performance = [0, 0, 0, 0] 34 | best_performance_zsl = 0 35 | for i in range(0, niters): 36 | model.train() 37 | optimizer.zero_grad() 38 | 39 | batch_label, batch_feature, batch_att = dataloader.next_batch( 40 | config.batch_size) 41 | out_package = model(batch_feature) 42 | 43 | in_package1 = out_package['package_s2v'] 44 | in_package2 = out_package['package_v2s'] 45 | in_package1['batch_label'] = batch_label 46 | in_package2['batch_label'] = batch_label 47 | 48 | out_package1=model.compute_loss(in_package1) 49 | out_package2=model.compute_loss(in_package2) 50 | 51 | loss = out_package1['loss'] + config.lambda_v2s * out_package2['loss'] 52 | loss_CE = out_package1['loss_CE'] + out_package2['loss_CE'] 53 | loss_cal = out_package1['loss_cal'] + out_package2['loss_cal'] 54 | loss_reg = out_package1['loss_reg'] + out_package2['loss_reg'] 55 | 56 | loss_att, loss_cls = model.compute_contrastive_loss( 57 | in_package1, in_package2) 58 | loss += config.lambda_cst_reg_att * loss_att 59 | loss += config.lambda_cst_reg_cls * loss_cls 60 | 61 | loss.backward() 62 | optimizer.step() 63 | 64 | # report result 65 | if i % report_interval == 0: 66 | print('-'*30) 67 | acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl( 68 | dataloader, model, config.device, batch_size=config.batch_size) 69 | 70 | if H > best_performance[2]: 71 | best_performance = [acc_novel, acc_seen, H, acc_zs] 72 | if acc_zs > best_performance_zsl: 73 | best_performance_zsl = acc_zs 74 | 75 | print('iter/epoch=%d/%d | loss=%.3f, loss_CE=%.3f, loss_cal=%.3f, ' 76 | 'loss_reg=%.3f, loss_cst_att=%.3f, loss_cst_cls=%.3f | ' % ( 77 | i, int(i//report_interval), 78 | loss.item(), loss_CE.item(), loss_cal.item(), 79 | loss_reg.item(), loss_att.item(), loss_cls.item())) 80 | print('Current GZSL: acc_unseen=%.3f, acc_seen=%.3f, H=%.3f | Current CZSL: 
acc_zs=%.3f' % ( 81 | acc_novel, acc_seen, H, acc_zs)) 82 | print('BEST GZSL: acc_unseen=%.3f, acc_seen=%.3f, H=%.3f, acc_zs=%.3f' 83 | ' | BEST CZSL: acc_zs=%.3f' % ( 84 | best_performance[0], best_performance[1], 85 | best_performance[2], best_performance[3], 86 | best_performance_zsl)) 87 | 88 | wandb.log({ 89 | 'iter': i, 90 | 'loss': loss.item(), 91 | 'loss_CE': loss_CE.item(), 92 | 'loss_cal': loss_cal.item(), 93 | 'loss_reg': loss_reg.item(), 94 | 'loss_att': loss_att.item(), 95 | 'loss_cls': loss_cls.item(), 96 | 'acc_unseen': acc_novel, 97 | 'acc_seen': acc_seen, 98 | 'H': H, 99 | 'acc_zs': acc_zs, 100 | 'best_acc_unseen': best_performance[0], 101 | 'best_acc_seen': best_performance[1], 102 | 'best_H': best_performance[2], 103 | 'best_acc_zs': best_performance_zsl 104 | }) 105 | -------------------------------------------------------------------------------- /w2v/AWA2_attribute.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/w2v/AWA2_attribute.pkl -------------------------------------------------------------------------------- /w2v/CUB_attribute.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/w2v/CUB_attribute.pkl -------------------------------------------------------------------------------- /w2v/SUN_attribute.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero_pp/c045f73d5133d14774393b9c2ab17d890a5eb35f/w2v/SUN_attribute.pkl -------------------------------------------------------------------------------- /wandb_config/awa2_czsl.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | value: AWA2 3 | num_class: 4 | value: 50 5 | num_attribute: 6 | value: 85 7 | img_size: 8 | value: 448 9 | resnet_region: 10 | value: 196 11 | dim_f: 12 | value: 2048 13 | dim_v: 14 | value: 300 15 | device: 16 | value: cuda:0 17 | batch_size: 18 | value: 50 19 | epochs: 20 | value: 200 21 | use_unzip: 22 | value: True 23 | random_seed: 24 | value: 8746 25 | normalize_V: 26 | # value: False 27 | value: True 28 | tf_SAtt: 29 | value: True 30 | tf_ec_layer: 31 | value: 1 32 | tf_dc_layer: 33 | value: 1 34 | tf_heads: 35 | value: 1 36 | tf_common_dim: 37 | value: 300 38 | tf_aux_embed: 39 | value: True 40 | tf_dim_feedforward: 41 | value: 512 42 | tf_dropout: 43 | value: 0.5 44 | tf_v2s_init: 45 | value: zeros_ 46 | weight_s2v: 47 | value: 0.5 48 | lambda_: 49 | value: 2 50 | lambda_reg: 51 | value: 0.00005 52 | lambda_v2s: 53 | value: 0.01 54 | lambda_cst_reg_att: 55 | value: 0.00001 56 | lambda_cst_reg_cls: 57 | value: 0.00001 -------------------------------------------------------------------------------- /wandb_config/awa2_gzsl.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | value: AWA2 3 | num_class: 4 | value: 50 5 | num_attribute: 6 | value: 85 7 | img_size: 8 | value: 448 9 | resnet_region: 10 | value: 196 11 | dim_f: 12 | value: 2048 13 | dim_v: 14 | value: 300 15 | device: 16 | value: cuda:0 17 | batch_size: 18 | value: 50 19 | epochs: 20 | value: 200 21 | use_unzip: 22 | value: True 23 | random_seed: 24 | value: 8746 25 | normalize_V: 26 | # value: False 27 | value: True 28 | tf_SAtt: 29 | value: True 30 | tf_ec_layer: 31 | 
value: 1 32 | tf_dc_layer: 33 | value: 1 34 | tf_heads: 35 | value: 1 36 | tf_common_dim: 37 | value: 300 38 | tf_aux_embed: 39 | value: True 40 | tf_dim_feedforward: 41 | value: 512 42 | tf_dropout: 43 | value: 0.5 44 | tf_v2s_init: 45 | value: zeros_ 46 | weight_s2v: 47 | value: 0.6 48 | lambda_: 49 | value: 2 50 | lambda_reg: 51 | value: 0.00005 52 | lambda_v2s: 53 | value: 0.01 54 | lambda_cst_reg_att: 55 | value: 0.00001 56 | lambda_cst_reg_cls: 57 | value: 0.00001 -------------------------------------------------------------------------------- /wandb_config/cub_czsl.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | value: CUB 3 | num_class: 4 | value: 200 5 | num_attribute: 6 | value: 312 7 | img_size: 8 | value: 448 9 | resnet_region: 10 | value: 196 11 | dim_f: 12 | value: 2048 13 | dim_v: 14 | value: 300 15 | device: 16 | value: cuda:0 17 | batch_size: 18 | value: 50 19 | epochs: 20 | value: 200 21 | use_unzip: 22 | value: True 23 | random_seed: 24 | value: 3423 25 | normalize_V: 26 | value: False 27 | tf_SAtt: 28 | value: True 29 | tf_ec_layer: 30 | value: 1 31 | tf_dc_layer: 32 | value: 1 33 | tf_heads: 34 | value: 1 35 | tf_common_dim: 36 | value: 300 37 | tf_aux_embed: 38 | value: True 39 | tf_dim_feedforward: 40 | value: 512 41 | tf_dropout: 42 | value: 0.4 43 | tf_v2s_init: 44 | value: zeros_ 45 | weight_s2v: 46 | value: 0.9 47 | lambda_: 48 | value: 0.2 49 | lambda_reg: 50 | value: 0.0001 51 | lambda_v2s: 52 | value: 0.1 53 | lambda_cst_reg_att: 54 | value: 0.001 55 | lambda_cst_reg_cls: 56 | value: 0.01 -------------------------------------------------------------------------------- /wandb_config/cub_gzsl.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | value: CUB 3 | num_class: 4 | value: 200 5 | num_attribute: 6 | value: 312 7 | img_size: 8 | value: 448 9 | resnet_region: 10 | value: 196 11 | dim_f: 12 | value: 2048 13 | dim_v: 14 | value: 300 15 | device: 16 | value: cuda:0 17 | batch_size: 18 | value: 50 19 | epochs: 20 | value: 200 21 | use_unzip: 22 | value: True 23 | random_seed: 24 | value: 4590 25 | normalize_V: 26 | value: False 27 | tf_SAtt: 28 | value: True 29 | tf_ec_layer: 30 | value: 1 31 | tf_dc_layer: 32 | value: 1 33 | tf_heads: 34 | value: 1 35 | tf_common_dim: 36 | value: 300 37 | tf_aux_embed: 38 | value: True 39 | tf_dim_feedforward: 40 | value: 512 41 | tf_dropout: 42 | value: 0.4 43 | tf_v2s_init: 44 | value: zeros_ 45 | weight_s2v: 46 | value: 0.9 47 | lambda_: 48 | value: 0.2 49 | lambda_reg: 50 | value: 0.0001 51 | lambda_v2s: 52 | value: 0.1 53 | lambda_cst_reg_att: 54 | value: 0.001 55 | lambda_cst_reg_cls: 56 | value: 0.01 -------------------------------------------------------------------------------- /wandb_config/sun_czsl.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | value: SUN 3 | num_class: 4 | value: 717 5 | num_attribute: 6 | value: 102 7 | img_size: 8 | value: 448 9 | resnet_region: 10 | value: 196 11 | dim_f: 12 | value: 2048 13 | dim_v: 14 | value: 300 15 | device: 16 | value: cuda:0 17 | batch_size: 18 | value: 50 19 | epochs: 20 | value: 200 21 | use_unzip: 22 | value: True 23 | random_seed: 24 | value: 5348 25 | normalize_V: 26 | value: False 27 | tf_SAtt: 28 | value: False 29 | tf_ec_layer: 30 | value: 1 31 | tf_dc_layer: 32 | value: 1 33 | tf_heads: 34 | value: 1 35 | tf_common_dim: 36 | value: 128 37 | tf_aux_embed: 38 | value: False 39 | tf_dim_feedforward: 40 | 
value: 2048 41 | tf_dropout: 42 | value: 0.3 43 | tf_v2s_init: 44 | value: normal_ 45 | weight_s2v: 46 | value: 0.2 47 | lambda_: 48 | value: 0.1 49 | lambda_reg: 50 | value: 0.01 51 | lambda_v2s: 52 | value: 0.1 53 | lambda_cst_reg_att: 54 | value: 0.01 55 | lambda_cst_reg_cls: 56 | value: 0. -------------------------------------------------------------------------------- /wandb_config/sun_gzsl.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | value: SUN 3 | num_class: 4 | value: 717 5 | num_attribute: 6 | value: 102 7 | img_size: 8 | value: 448 9 | resnet_region: 10 | value: 196 11 | dim_f: 12 | value: 2048 13 | dim_v: 14 | value: 300 15 | device: 16 | value: cuda:0 17 | batch_size: 18 | value: 50 19 | epochs: 20 | value: 200 21 | use_unzip: 22 | value: True 23 | random_seed: 24 | value: 1089 25 | normalize_V: 26 | value: False 27 | tf_SAtt: 28 | value: False 29 | tf_ec_layer: 30 | value: 1 31 | tf_dc_layer: 32 | value: 1 33 | tf_heads: 34 | value: 1 35 | tf_common_dim: 36 | value: 128 37 | tf_aux_embed: 38 | value: False 39 | tf_dim_feedforward: 40 | value: 2048 41 | tf_dropout: 42 | value: 0.3 43 | tf_v2s_init: 44 | value: normal_ 45 | weight_s2v: 46 | value: 0.6 47 | lambda_: 48 | value: 0.1 49 | lambda_reg: 50 | value: 0.01 51 | lambda_v2s: 52 | value: 1.0 53 | lambda_cst_reg_att: 54 | value: 0.01 55 | lambda_cst_reg_cls: 56 | value: 0.001 --------------------------------------------------------------------------------
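As a quick sanity check of `model_tzpp.py` above, the minimal sketch below instantiates the attribute-visual `Transformer` with its default dimensions and pushes dummy inputs through it. The shapes follow the preprocessing pipeline (2048-channel ResNet-101 feature maps over a 14x14 grid, i.e. 196 regions) and the CUB attribute set (312 attributes with 300-d word vectors); the batch size and the random tensors are arbitrary stand-ins, not part of the repository. Since the grid-geometry embedding calls `.cuda()` internally, a CUDA device is assumed.

```python
# Minimal smoke test (a sketch, not repository code): assumes the repo root
# is on PYTHONPATH and a CUDA device is available, because the relative
# geometry embedding in model_tzpp.py is hard-coded to .cuda().
import torch
from model_tzpp import Transformer

model = Transformer(ec_layer=1, dc_layer=1, dim_com=300,
                    in_dim_cv=2048, in_dim_attr=300).cuda().eval()

f_cv = torch.randn(2, 2048, 196).cuda()  # (batch, channels, 14*14 grid regions)
f_attr = torch.randn(312, 300).cuda()    # (num_attributes, w2v dim), e.g. CUB

with torch.no_grad():
    out = model(f_cv, f_attr)            # attribute-guided visual features
print(out.shape)                         # expected: torch.Size([2, 312, 300])
```

The decoder returns one `dim_com`-dimensional embedding per attribute query, which the s2v/v2s branches of `TransZeroPP` then consume.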