├── LICENSE ├── README.md ├── attention_map.py ├── dataset.py ├── figs └── pipeline.png ├── helper_func.py ├── model.py ├── preprocessing.py ├── requirements.txt ├── train_awa2.py ├── train_cub.py ├── train_sun.py ├── w2v ├── AWA2_attribute.pkl ├── CUB_attribute.pkl └── SUN_attribute.pkl └── wandb_config ├── awa2_czsl.yaml ├── awa2_gzsl.yaml ├── cub_czsl.yaml ├── cub_gzsl.yaml ├── sun_czsl.yaml └── sun_gzsl.yaml /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Shiming Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransZero [[arXiv]](https://arxiv.org/pdf/2112.01683.pdf) 2 | 3 | 4 | This repository contains the training and [testing](https://github.com/shiming-chen/TransZero/tree/test) code for the paper "***TransZero: Attribute-guided Transformer for Zero-Shot Learning***", accepted to AAAI 2022. 5 | 6 | ![](figs/pipeline.png) 7 | 8 | 9 | ## Running Environment 10 | The implementation of **TransZero** is mainly based on Python 3.8.8 and [PyTorch](https://pytorch.org/) 1.8.0. To install all required dependencies: 11 | ``` 12 | $ pip install -r requirements.txt 13 | ``` 14 | 15 | Additionally, we use [Weights & Biases](https://wandb.ai/site) (W&B) to keep track of and organize the results of our experiments. You may need to follow the W&B [quickstart documentation](https://docs.wandb.ai/quickstart) to get started. To run this code, either [sign up](https://app.wandb.ai/login?signup=true) for an online account to track experiments or create a [local wandb server](https://hub.docker.com/r/wandb/local) using Docker (recommended). 16 | 17 | 18 | ## Download Dataset 19 | 20 | We trained the model on three popular ZSL benchmarks: [CUB](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), [SUN](http://cs.brown.edu/~gmpatter/sunattributes.html) and [AWA2](http://cvml.ist.ac.at/AwA2/), following the data split of [xlsa17](http://datasets.d2.mpi-inf.mpg.de/xian/xlsa17.zip). To train **TransZero**, first download these datasets as well as xlsa17, then decompress and organize them as follows: 21 | ``` 22 | . 23 | ├── data 24 | │ ├── CUB/CUB_200_2011/... 25 | │ ├── SUN/images/... 26 | │ ├── AWA2/Animals_with_Attributes2/... 27 | │ └── xlsa17/data/...
28 | └── ··· 29 | ``` 30 | 31 | 32 | ## Visual Features Preprocessing 33 | 34 | In this step, run the following commands to extract the visual features of the three datasets: 35 | 36 | ``` 37 | $ python preprocessing.py --dataset CUB --compression --device cuda:0 38 | $ python preprocessing.py --dataset SUN --compression --device cuda:0 39 | $ python preprocessing.py --dataset AWA2 --compression --device cuda:0 40 | ``` 41 | 42 | ## Training TransZero from Scratch 43 | In `./wandb_config`, we provide our parameter settings for the conventional ZSL (CZSL) and generalized ZSL (GZSL) tasks on CUB, SUN, and AWA2. You can run the following commands to train **TransZero** from scratch: 44 | 45 | ``` 46 | $ python train_cub.py # CUB 47 | $ python train_sun.py # SUN 48 | $ python train_awa2.py # AWA2 49 | ``` 50 | **Note**: Please load the corresponding CZSL configuration (e.g., `cub_czsl.yaml`) when targeting the CZSL task. 51 | 52 | ## Results 53 | 54 | We also provide trained models ([Google Drive](https://drive.google.com/drive/folders/1WK9pm2eX2Rl4rWqXqe_EZiAM8wWB8yqG?usp=sharing)) for the three datasets. You can download these `.pth` files and validate the results reported in our paper. Please refer to the [test branch](https://github.com/shiming-chen/TransZero/tree/test) for the testing code and usage; a minimal loading-and-evaluation sketch is also given at the end of this README. 55 | The following table shows the results of our released models under various evaluation protocols on the three datasets, in both the CZSL and GZSL settings: 56 | 57 | | Dataset | Acc(CZSL) | U(GZSL) | S(GZSL) | H(GZSL) | 58 | | :-----: | :-----: | :-----: | :-----: | :-----: | 59 | | CUB | 76.8 | 69.3 | 68.3 | 68.8 | 60 | | SUN | 65.6 | 52.6 | 33.4 | 40.8 | 61 | | AWA2 | 70.1 | 61.3 | 82.3 | 70.2 | 62 | 63 | **Note**: Our models were trained, and all of the above results obtained, on a server with an AMD Ryzen 7 5800X CPU, 128GB memory, and an NVIDIA RTX A6000 GPU (48GB). 64 | 65 | ## Citation 66 | If this work is helpful for you, please cite our paper. 67 | 68 | ``` 69 | @InProceedings{Chen2022TransZero, 70 | author = {Chen, Shiming and Hong, Ziming and Liu, Yang and Xie, Guo-Sen and Sun, Baigui and Li, Hao and Peng, Qinmu and Lu, Ke and You, Xinge}, 71 | title = {TransZero: Attribute-guided Transformer for Zero-Shot Learning}, 72 | booktitle = {Proceedings of the Thirty-Sixth AAAI Conference on Artificial Intelligence (AAAI)}, 73 | year = {2022} 74 | } 75 | ``` 76 | 77 | ## References 78 | Parts of our code are based on: 79 | * [hbdat/cvpr20_DAZLE](https://github.com/hbdat/cvpr20_DAZLE) 80 | * [zhangxuying1004/RSTNet](https://github.com/zhangxuying1004/RSTNet) 81 | 82 | ## Contact 83 | If you have any questions about the code, please don't hesitate to contact us at gchenshiming@gmail.com or hoongzm@gmail.com.
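
## Evaluating a Released Model (Sketch)

The snippet below is a minimal, unofficial sketch of how a downloaded checkpoint could be evaluated with the components in this repository (`dataset.py`, `model.py`, `helper_func.py`); the [test branch](https://github.com/shiming-chen/TransZero/tree/test) remains the reference for testing. It assumes the visual features were already extracted with `preprocessing.py`, that `path/to/CUB_TransZero.pth` is a placeholder for whichever `.pth` file you downloaded, and that the hyperparameters read by `model.py` (e.g. `dim_f`, `dim_v`, `num_class`, the `tf_*` options) are all provided by `wandb_config/cub_gzsl.yaml`; if the training script sets some of them itself, set them on `config` accordingly.

```
# Hedged sketch: evaluate a downloaded TransZero checkpoint on CUB.
# Assumptions (not part of the original repo): the placeholder checkpoint path,
# and that wandb_config/cub_gzsl.yaml supplies every hyperparameter used by model.py.
import torch
import wandb

from dataset import CUBDataLoader
from model import TransZero
from helper_func import eval_zs_gzsl

wandb.init(project='TransZero_eval', config='wandb_config/cub_gzsl.yaml',
           mode='offline', allow_val_change=True)
config = wandb.config

device = 'cuda:0'
# Expects ./data/CUB/feature_map_ResNet_101_CUB.hdf5 produced by preprocessing.py.
dataloader = CUBDataLoader('.', device)

model = TransZero(config, dataloader.att, dataloader.w2v_att,
                  dataloader.seenclasses, dataloader.unseenclasses,
                  is_bias=True, bias=1.0).to(device)
# The calibration bias and all learned weights are overwritten by the checkpoint.
model.load_state_dict(torch.load('path/to/CUB_TransZero.pth', map_location=device))

# eval_zs_gzsl returns seen accuracy, unseen accuracy, their harmonic mean, and CZSL accuracy.
acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(dataloader, model, device,
                                              bias_seen=0, bias_unseen=0)
print('CZSL Acc = {:.1%} | GZSL U = {:.1%}, S = {:.1%}, H = {:.1%}'.format(
    acc_zs, acc_novel, acc_seen, H))
```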
84 | -------------------------------------------------------------------------------- /attention_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | from core.DAZLE_plot import DAZLE 6 | from core.CUBDataLoader import CUBDataLoader 7 | from core.helper_func import eval_zs_gzsl 8 | from global_setting import NFS_path 9 | import numpy as np 10 | import wandb 11 | from get_gpu_info import get_gpu_info 12 | from PIL import Image 13 | import matplotlib.pyplot as plt 14 | import skimage 15 | from sklearn.manifold import TSNE 16 | from torchvision import transforms 17 | 18 | 19 | data_transforms = transforms.Compose([ 20 | transforms.Resize(448), 21 | transforms.CenterCrop(448), 22 | transforms.ToTensor()]) 23 | 24 | 25 | def dazle_visualize_attention_np_global_448(img_ids,alphas_1,alphas_2,attr_name,save_path=None): 26 | # alphas_1: [bir] alphas_2: [bi] 27 | n = img_ids.shape[0] 28 | image_size = 448 #one side of the img 29 | assert alphas_1.shape[1] == alphas_2.shape[1] == len(attr_name) 30 | r = alphas_1.shape[2] 31 | h = w = int(np.sqrt(r)) 32 | for i in range(n): 33 | fig=plt.figure(i,figsize=(33, 5)) 34 | file_path=img_ids[i]#.decode('utf-8') 35 | img_name = file_path.split("/")[-1] 36 | alpha_1 = alphas_1[i] #[ir] 37 | alpha_2 = alphas_2[i] #[i] 38 | # score = S[i] 39 | # Plot original image 40 | image = Image.open(file_path) 41 | if image.mode == 'L': 42 | image=image.convert('RGB') 43 | image = data_transforms(image) 44 | image = image.permute(1,2,0) #[224,244,3] <== [3,224,224] 45 | idx = 1 46 | ax = plt.subplot(1, 11, 1) 47 | idx += 1 48 | plt.imshow(image) 49 | # ax.set_title(os.path.splitext(img_name)[0],{'fontsize': 13}) 50 | plt.axis('off') 51 | 52 | idxs_top_p=np.argsort(-alpha_2)[:10] 53 | idxs_top_g=np.argsort(-alpha_2)[:200] 54 | # idxs_top_n=np.argsort(alpha_2)[:3] 55 | 56 | #pdb.set_trace() 57 | for idx_ctxt,idx_attr in enumerate(idxs_top_p): 58 | ax=plt.subplot(1, 11, idx) 59 | idx += 1 60 | plt.imshow(image) 61 | alp_curr = alpha_1[idx_attr,:].reshape(14,14) 62 | alp_img = skimage.transform.pyramid_expand(alp_curr, upscale=image_size/h, sigma=10,multichannel=False) 63 | plt.imshow(alp_img, alpha=0.5, cmap='jet') 64 | # ax.set_title("{}\n{}\n{}-{}".format(attr_name[idx_attr],alpha_2[idx_attr],score[idx_attr],attr[idx_attr]),{'fontsize': 10}) 65 | # ax.set_title("{}\n(Score = {:.2f})".format(attr_name[idx_attr].title().replace( 66 | # ' ', ''), alpha_2[idx_attr]), {'fontsize': 19}) 67 | ax.set_title("{}\n(Score = {:.1f})".format(' '.join(attr_name[idx_attr].split()[:2]).title( 68 | ) + '\n' + ' '.join(attr_name[idx_attr].split()[2:]).title(), alpha_2[idx_attr]), {'fontsize': 25}) 69 | 70 | plt.axis('off') 71 | 72 | fig.tight_layout() 73 | if save_path is not None: 74 | plt.savefig(save_path+img_name,dpi=200) 75 | plt.close() 76 | 77 | 78 | def dazle_visualize_attention_np_global_448_small(img_ids,alphas_1,alphas_2,attr_name,save_path=None): 79 | # alphas_1: [bir] alphas_2: [bi] 80 | n = img_ids.shape[0] 81 | image_size = 448 #one side of the img 82 | assert alphas_1.shape[1] == alphas_2.shape[1] == len(attr_name) 83 | r = alphas_1.shape[2] 84 | h = w = int(np.sqrt(r)) 85 | for i in range(n): 86 | fig=plt.figure(i,figsize=(33, 4)) 87 | file_path=img_ids[i]#.decode('utf-8') 88 | img_name = file_path.split("/")[-1] 89 | alpha_1 = alphas_1[i] #[ir] 90 | alpha_2 = alphas_2[i] #[i] 91 | # score = S[i] 92 | # Plot original image 93 | image = 
Image.open(file_path) 94 | if image.mode == 'L': 95 | image=image.convert('RGB') 96 | image = data_transforms(image) 97 | image = image.permute(1,2,0) #[448,448,3] <== [3,448,448] 98 | idx = 1 99 | ax = plt.subplot(1, 11, 1) 100 | idx += 1 101 | plt.imshow(image) 102 | # ax.set_title(os.path.splitext(img_name)[0],{'fontsize': 13}) 103 | plt.axis('off') 104 | 105 | idxs_top_p=np.argsort(-alpha_2)[:10] 106 | idxs_top_g=np.argsort(-alpha_2)[:200] 107 | # idxs_top_n=np.argsort(alpha_2)[:3] 108 | 109 | #pdb.set_trace() 110 | for idx_ctxt,idx_attr in enumerate(idxs_top_p): 111 | ax=plt.subplot(1, 11, idx) 112 | idx += 1 113 | plt.imshow(image) 114 | alp_curr = alpha_1[idx_attr,:].reshape(14,14) 115 | alp_img = skimage.transform.pyramid_expand(alp_curr, upscale=image_size/h, sigma=10,multichannel=False) 116 | plt.imshow(alp_img, alpha=0.5, cmap='jet') 117 | # ax.set_title("{}\n{}\n{}-{}".format(attr_name[idx_attr],alpha_2[idx_attr],score[idx_attr],attr[idx_attr]),{'fontsize': 10}) 118 | ax.set_title("{}\n(Score = {:.2f})".format(attr_name[idx_attr].title().replace( 119 | ' ', ''), alpha_2[idx_attr]), {'fontsize': 18}) 120 | # ax.set_title("{}\n(Score = {:.1f})".format(' '.join(attr_name[idx_attr].split()[:2]).title( 121 | # ) + '\n' + ' '.join(attr_name[idx_attr].split()[2:]).title(), alpha_2[idx_attr]), {'fontsize': 20}) 122 | 123 | plt.axis('off') 124 | 125 | 126 | fig.tight_layout() 127 | if save_path is not None: 128 | plt.savefig(save_path+img_name,dpi=200) 129 | plt.close() 130 | 131 | 132 | def plot_att(config): 133 | model_path = 'saved_model/CUB_weights_H-0.688.pth' 134 | 135 | config.dataset = 'CUB' 136 | config.num_class = 200 137 | config.num_attribute = 312 138 | if config.img_size == 224: config.resnet_region = 49 139 | elif config.img_size == 448: config.resnet_region = 196 140 | 141 | print('Config file from wandb:', config) 142 | 143 | if config.device == 'auto': 144 | device = get_gpu_info() 145 | else: 146 | device = config.device 147 | dataloader = CUBDataLoader(NFS_path, device, 148 | is_unsupervised_attr=False, is_balance=False, 149 | img_size=config.img_size, use_unzip=config.use_unzip) 150 | dataloader.augment_img_path() 151 | torch.backends.cudnn.benchmark = True 152 | 153 | def get_lr(optimizer): 154 | lr = [] 155 | for param_group in optimizer.param_groups: 156 | lr.append(param_group['lr']) 157 | return lr 158 | 159 | seed = config.random_seed 160 | torch.manual_seed(seed) 161 | torch.cuda.manual_seed_all(seed) 162 | np.random.seed(seed) 163 | 164 | batch_size = config.batch_size 165 | nepoches = config.epochs 166 | niters = dataloader.ntrain * nepoches//batch_size 167 | dim_f = 2048 168 | dim_v = 300 169 | init_w2v_att = dataloader.w2v_att 170 | att = dataloader.att 171 | normalize_att = dataloader.normalize_att 172 | 173 | trainable_w2v = config.trainable_w2v 174 | # hyperparameters for the CE loss and calibration (cal) loss 175 | lambda_ = config.lambda_ 176 | bias = 0 177 | prob_prune = 0 178 | # options for uniform DAZLE attention 179 | uniform_att_1 = False 180 | uniform_att_2 = False 181 | 182 | seenclass = dataloader.seenclasses 183 | unseenclass = dataloader.unseenclasses 184 | desired_mass = 1 185 | report_interval = niters//nepoches 186 | 187 | model = DAZLE(config, dim_f,dim_v,init_w2v_att,att,normalize_att, 188 | seenclass,unseenclass, 189 | lambda_, 190 | trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True, 191 | uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2, 192 | prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False, 193 | is_bias=config.is_bias) 194 | 
model.load_state_dict(torch.load(model_path)) 195 | model.to(device) 196 | num_parameters = sum([p.numel() for p in model.parameters()]) * 1e-6 197 | print('model parameters: %.3fM' % num_parameters) 198 | 199 | file_list = [ 200 | 'Acadian_Flycatcher_0008_795599', 201 | 'American_Goldfinch_0092_32910', 202 | 'Canada_Warbler_0117_162394', 203 | 'Carolina_Wren_0006_186742', 204 | 'Vesper_Sparrow_0090_125690', 205 | 'Western_Gull_0058_53882', 206 | 'White_Throated_Sparrow_0128_128956', 207 | 'Winter_Wren_0118_189805', 208 | 'Yellow_Breasted_Chat_0044_22106', 209 | 'Elegant_Tern_0085_151091', 210 | 'European_Goldfinch_0025_794647', 211 | 'Florida_Jay_0008_64482', 212 | 'Fox_Sparrow_0025_114555', 213 | 'Grasshopper_Sparrow_0053_115991', 214 | 'Grasshopper_Sparrow_0107_116286', 215 | 'Gray_Crowned_Rosy_Finch_0036_797287' 216 | ] 217 | 218 | for filename in file_list: 219 | for i, id in enumerate(dataloader.seenclasses): 220 | # if i == 5: 221 | # raise Exception 222 | id = id.item() 223 | (batch_label, batch_feature, batch_files, batch_att) = dataloader.next_batch_img( 224 | batch_size=10, class_id=id, is_trainset=False) 225 | 226 | if filename not in str(batch_files): 227 | continue 228 | 229 | idx = [filename in str(f) for f in batch_files] 230 | batch_feature = batch_feature[idx] 231 | batch_files = batch_files[idx] 232 | 233 | model.eval() 234 | with torch.no_grad(): 235 | out_package = model(batch_feature) 236 | 237 | # attention map of DAZLE 238 | dazle_visualize_attention_np_global_448_small(batch_files, 239 | out_package['att'].cpu().numpy(), 240 | out_package['dazle_embed'].cpu().numpy(), 241 | dataloader.attr_name, 242 | 'plot/atten_fig/') 243 | 244 | 245 | 246 | if __name__ == '__main__': 247 | 248 | wandb.init(project='ZSL_DALZE_Transformer_GA', config='config_cub.yaml', allow_val_change=True) 249 | config = wandb.config 250 | plot_att(config) 251 | 252 | 253 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os, sys, torch, pickle, h5py 3 | import numpy as np 4 | import scipy.io as sio 5 | import pandas as pd 6 | from PIL import Image 7 | from sklearn import preprocessing 8 | from torchvision import transforms 9 | from torch.utils.data import Dataset, Subset, DataLoader 10 | 11 | 12 | class BaseDataset(Dataset): 13 | def __init__(self, dataset_path, image_files, labels, transform=None): 14 | super(BaseDataset, self).__init__() 15 | self.dataset_path = dataset_path 16 | self.image_files = image_files 17 | self.labels = labels 18 | self.transform = transform 19 | 20 | def __len__(self): 21 | return len(self.image_files) 22 | 23 | def __getitem__(self, idx): 24 | label = self.labels[idx] 25 | image_file = self.image_files[idx] 26 | image_file = os.path.join(self.dataset_path, image_file) 27 | image = Image.open(image_file) 28 | if image.mode != 'RGB': 29 | image = image.convert('RGB') 30 | if self.transform: 31 | image = self.transform(image) 32 | return image, label 33 | 34 | 35 | class UNIDataloader(): 36 | def __init__(self, config): 37 | self.config = config 38 | with open(config.pkl_path, 'rb') as f: 39 | self.info = pickle.load(f) 40 | 41 | self.seenclasses = self.info['seenclasses'].to(config.device) 42 | self.unseenclasses = self.info['unseenclasses'].to(config.device) 43 | 44 | (self.train_set, 45 | self.test_seen_set, 46 | self.test_unseen_set) = self.torch_dataset() 47 | 48 | self.train_loader = DataLoader(self.train_set, 49 | 
batch_size=config.batch_size, 50 | shuffle=True, 51 | num_workers=config.num_workers) 52 | self.test_seen_loader = DataLoader(self.test_seen_set, 53 | batch_size=config.batch_size, 54 | shuffle=False, 55 | num_workers=config.num_workers) 56 | self.test_unseen_loader = DataLoader(self.test_unseen_set, 57 | batch_size=config.batch_size, 58 | shuffle=False, 59 | num_workers=config.num_workers) 60 | 61 | def torch_dataset(self): 62 | data_transforms = transforms.Compose([ 63 | transforms.Resize(self.config.img_size), 64 | transforms.CenterCrop(self.config.img_size), 65 | transforms.ToTensor(), 66 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 67 | baseset = BaseDataset(self.config.dataset_path, 68 | self.info['image_files'], 69 | self.info['labels'], 70 | data_transforms) 71 | 72 | train_set = Subset(baseset, self.info['trainval_loc']) 73 | test_seen_set = Subset(baseset, self.info['test_seen_loc']) 74 | test_unseen_set = Subset(baseset, self.info['test_unseen_loc']) 75 | 76 | return train_set, test_seen_set, test_unseen_set 77 | 78 | 79 | class CUBDataLoader(): 80 | def __init__(self, data_path, device, is_scale=False, 81 | is_unsupervised_attr=False, is_balance=True): 82 | print(data_path) 83 | sys.path.append(data_path) 84 | self.data_path = data_path 85 | self.device = device 86 | self.dataset = 'CUB' 87 | # print('$'*30) 88 | # print(self.dataset) 89 | # print('$'*30) 90 | self.datadir = os.path.join(self.data_path, 'data/{}/'.format(self.dataset)) 91 | self.index_in_epoch = 0 92 | self.epochs_completed = 0 93 | self.is_scale = is_scale 94 | self.is_balance = is_balance 95 | if self.is_balance: 96 | print('Balance dataloader') 97 | self.is_unsupervised_attr = is_unsupervised_attr 98 | self.read_matdataset() 99 | self.get_idx_classes() 100 | 101 | def next_batch(self, batch_size): 102 | if self.is_balance: 103 | idx = [] 104 | n_samples_class = max(batch_size //self.ntrain_class,1) 105 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 106 | for i_c in sampled_idx_c: 107 | idxs = self.idxs_list[i_c] 108 | idx.append(np.random.choice(idxs,n_samples_class)) 109 | idx = np.concatenate(idx) 110 | idx = torch.from_numpy(idx) 111 | else: 112 | idx = torch.randperm(self.ntrain)[0:batch_size] 113 | 114 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 115 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 116 | batch_att = self.att[batch_label].to(self.device) 117 | return batch_label, batch_feature, batch_att 118 | 119 | def get_idx_classes(self): 120 | n_classes = self.seenclasses.size(0) 121 | self.idxs_list = [] 122 | train_label = self.data['train_seen']['labels'] 123 | for i in range(n_classes): 124 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 125 | idx_c = np.squeeze(idx_c) 126 | self.idxs_list.append(idx_c) 127 | return self.idxs_list 128 | 129 | def read_matdataset(self): 130 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 131 | print('_____') 132 | print(path) 133 | # tic = time.time() 134 | hf = h5py.File(path, 'r') 135 | features = np.array(hf.get('feature_map')) 136 | # shape = features.shape 137 | # features = features.reshape(shape[0],shape[1],shape[2]*shape[3]) 138 | # pdb.set_trace() 139 | labels = np.array(hf.get('labels')) 140 | trainval_loc = np.array(hf.get('trainval_loc')) 141 | # train_loc = np.array(hf.get('train_loc')) #--> train_feature = TRAIN SEEN 142 | # 
val_unseen_loc = np.array(hf.get('val_unseen_loc')) #--> test_unseen_feature = TEST UNSEEN 143 | test_seen_loc = np.array(hf.get('test_seen_loc')) 144 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 145 | 146 | if self.is_unsupervised_attr: 147 | print('Unsupervised Attr') 148 | class_path = './w2v/{}_class.pkl'.format(self.dataset) 149 | with open(class_path,'rb') as f: 150 | w2v_class = pickle.load(f) 151 | temp = np.array(hf.get('att')) 152 | print(w2v_class.shape,temp.shape) 153 | # assert w2v_class.shape == temp.shape 154 | w2v_class = torch.tensor(w2v_class).float() 155 | 156 | U, s, V = torch.svd(w2v_class) 157 | reconstruct = torch.mm(torch.mm(U,torch.diag(s)),torch.transpose(V,1,0)) 158 | print('sanity check: {}'.format(torch.norm(reconstruct-w2v_class).item())) 159 | 160 | print('shape U:{} V:{}'.format(U.size(),V.size())) 161 | print('s: {}'.format(s)) 162 | 163 | self.w2v_att = torch.transpose(V,1,0).to(self.device) 164 | self.att = torch.mm(U,torch.diag(s)).to(self.device) 165 | self.normalize_att = torch.mm(U,torch.diag(s)).to(self.device) 166 | 167 | else: 168 | print('Expert Attr') 169 | att = np.array(hf.get('att')) 170 | self.att = torch.from_numpy(att).float().to(self.device) 171 | 172 | original_att = np.array(hf.get('original_att')) 173 | self.original_att = torch.from_numpy(original_att).float().to(self.device) 174 | 175 | w2v_att = np.array(hf.get('w2v_att')) 176 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 177 | 178 | self.normalize_att = self.original_att/100 179 | 180 | train_feature = features[trainval_loc] 181 | test_seen_feature = features[test_seen_loc] 182 | test_unseen_feature = features[test_unseen_loc] 183 | if self.is_scale: 184 | scaler = preprocessing.MinMaxScaler() 185 | 186 | train_feature = scaler.fit_transform(train_feature) 187 | test_seen_feature = scaler.fit_transform(test_seen_feature) 188 | test_unseen_feature = scaler.fit_transform(test_unseen_feature) 189 | 190 | train_feature = torch.from_numpy(train_feature).float() #.to(self.device) 191 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 192 | test_unseen_feature = torch.from_numpy(test_unseen_feature) #.float().to(self.device) 193 | 194 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 195 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 196 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) #.long().to(self.device) 197 | 198 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 199 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 200 | self.ntrain = train_feature.size()[0] 201 | self.ntrain_class = self.seenclasses.size(0) 202 | self.ntest_class = self.unseenclasses.size(0) 203 | self.train_class = self.seenclasses.clone() 204 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 205 | 206 | self.data = {} 207 | self.data['train_seen'] = {} 208 | self.data['train_seen']['resnet_features'] = train_feature 209 | self.data['train_seen']['labels']= train_label 210 | 211 | self.data['train_unseen'] = {} 212 | self.data['train_unseen']['resnet_features'] = None 213 | self.data['train_unseen']['labels'] = None 214 | 215 | self.data['test_seen'] = {} 216 | self.data['test_seen']['resnet_features'] = test_seen_feature 217 | self.data['test_seen']['labels'] = test_seen_label 218 | 219 | self.data['test_unseen'] = {} 220 | 
self.data['test_unseen']['resnet_features'] = test_unseen_feature 221 | self.data['test_unseen']['labels'] = test_unseen_label 222 | 223 | 224 | class SUNDataLoader(): 225 | def __init__(self, data_path, device, is_scale=False, 226 | is_unsupervised_attr=False, is_balance=True): 227 | print(data_path) 228 | sys.path.append(data_path) 229 | self.data_path = data_path 230 | self.device = device 231 | self.dataset = 'SUN' 232 | print('$'*30) 233 | print(self.dataset) 234 | print('$'*30) 235 | self.datadir = os.path.join(self.data_path, 'data/{}/'.format(self.dataset)) 236 | self.index_in_epoch = 0 237 | self.epochs_completed = 0 238 | self.is_scale = is_scale 239 | self.is_balance = is_balance 240 | if self.is_balance: 241 | print('Balance dataloader') 242 | self.is_unsupervised_attr = is_unsupervised_attr 243 | self.read_matdataset() 244 | self.get_idx_classes() 245 | self.I = torch.eye(self.allclasses.size(0)).to(device) 246 | 247 | def next_batch(self, batch_size): 248 | if self.is_balance: 249 | idx = [] 250 | n_samples_class = max(batch_size //self.ntrain_class,1) 251 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 252 | for i_c in sampled_idx_c: 253 | idxs = self.idxs_list[i_c] 254 | idx.append(np.random.choice(idxs,n_samples_class)) 255 | idx = np.concatenate(idx) 256 | idx = torch.from_numpy(idx) 257 | else: 258 | idx = torch.randperm(self.ntrain)[0:batch_size] 259 | 260 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 261 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 262 | batch_att = self.att[batch_label].to(self.device) 263 | return batch_label, batch_feature, batch_att 264 | 265 | def get_idx_classes(self): 266 | n_classes = self.seenclasses.size(0) 267 | self.idxs_list = [] 268 | train_label = self.data['train_seen']['labels'] 269 | for i in range(n_classes): 270 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 271 | idx_c = np.squeeze(idx_c) 272 | self.idxs_list.append(idx_c) 273 | return self.idxs_list 274 | 275 | def read_matdataset(self): 276 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 277 | 278 | print('_____') 279 | print(path) 280 | # tic = time.time() 281 | hf = h5py.File(path, 'r') 282 | features = np.array(hf.get('feature_map')) 283 | labels = np.array(hf.get('labels')) 284 | trainval_loc = np.array(hf.get('trainval_loc')) 285 | test_seen_loc = np.array(hf.get('test_seen_loc')) 286 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 287 | 288 | if self.is_unsupervised_attr: 289 | print('Unsupervised Attr') 290 | class_path = './w2v/{}_class.pkl'.format(self.dataset) 291 | with open(class_path,'rb') as f: 292 | w2v_class = pickle.load(f) 293 | assert w2v_class.shape == (50,300) 294 | w2v_class = torch.tensor(w2v_class).float() 295 | 296 | U, s, V = torch.svd(w2v_class) 297 | reconstruct = torch.mm(torch.mm(U,torch.diag(s)),torch.transpose(V,1,0)) 298 | print('sanity check: {}'.format(torch.norm(reconstruct-w2v_class).item())) 299 | 300 | print('shape U:{} V:{}'.format(U.size(),V.size())) 301 | print('s: {}'.format(s)) 302 | 303 | self.w2v_att = torch.transpose(V,1,0).to(self.device) 304 | self.att = torch.mm(U,torch.diag(s)).to(self.device) 305 | self.normalize_att = torch.mm(U,torch.diag(s)).to(self.device) 306 | 307 | else: 308 | print('Expert Attr') 309 | att = np.array(hf.get('att')) 310 | self.att = torch.from_numpy(att).float().to(self.device) 311 | 312 | original_att 
= np.array(hf.get('original_att')) 313 | self.original_att = torch.from_numpy(original_att).float().to(self.device) 314 | 315 | w2v_att = np.array(hf.get('w2v_att')) 316 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 317 | 318 | self.normalize_att = self.original_att/100 319 | 320 | train_feature = features[trainval_loc] 321 | test_seen_feature = features[test_seen_loc] 322 | test_unseen_feature = features[test_unseen_loc] 323 | if self.is_scale: 324 | scaler = preprocessing.MinMaxScaler() 325 | 326 | train_feature = scaler.fit_transform(train_feature) 327 | test_seen_feature = scaler.fit_transform(test_seen_feature) 328 | test_unseen_feature = scaler.fit_transform(test_unseen_feature) 329 | 330 | train_feature = torch.from_numpy(train_feature).float() #.to(self.device) 331 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 332 | test_unseen_feature = torch.from_numpy(test_unseen_feature) #.float().to(self.device) 333 | 334 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 335 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 336 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) #.long().to(self.device) 337 | 338 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 339 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 340 | self.ntrain = train_feature.size()[0] 341 | self.ntrain_class = self.seenclasses.size(0) 342 | self.ntest_class = self.unseenclasses.size(0) 343 | self.train_class = self.seenclasses.clone() 344 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 345 | 346 | self.data = {} 347 | self.data['train_seen'] = {} 348 | self.data['train_seen']['resnet_features'] = train_feature 349 | self.data['train_seen']['labels']= train_label 350 | 351 | self.data['train_unseen'] = {} 352 | self.data['train_unseen']['resnet_features'] = None 353 | self.data['train_unseen']['labels'] = None 354 | 355 | self.data['test_seen'] = {} 356 | self.data['test_seen']['resnet_features'] = test_seen_feature 357 | self.data['test_seen']['labels'] = test_seen_label 358 | 359 | self.data['test_unseen'] = {} 360 | self.data['test_unseen']['resnet_features'] = test_unseen_feature 361 | self.data['test_unseen']['labels'] = test_unseen_label 362 | 363 | 364 | class AWA2DataLoader(): 365 | def __init__(self, data_path, device, is_scale=False, 366 | is_unsupervised_attr=False, is_balance=True): 367 | print(data_path) 368 | sys.path.append(data_path) 369 | self.data_path = data_path 370 | self.device = device 371 | self.dataset = 'AWA2' 372 | print('$'*30) 373 | print(self.dataset) 374 | print('$'*30) 375 | self.datadir = os.path.join(self.data_path, 'data/{}/'.format(self.dataset)) 376 | self.index_in_epoch = 0 377 | self.epochs_completed = 0 378 | self.is_scale = is_scale 379 | self.is_balance = is_balance 380 | if self.is_balance: 381 | print('Balance dataloader') 382 | self.is_unsupervised_attr = is_unsupervised_attr 383 | self.read_matdataset() 384 | self.get_idx_classes() 385 | 386 | def next_batch(self, batch_size): 387 | if self.is_balance: 388 | idx = [] 389 | n_samples_class = max(batch_size //self.ntrain_class,1) 390 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 391 | for i_c in sampled_idx_c: 392 | idxs = self.idxs_list[i_c] 393 | idx.append(np.random.choice(idxs,n_samples_class)) 
394 | idx = np.concatenate(idx) 395 | idx = torch.from_numpy(idx) 396 | else: 397 | idx = torch.randperm(self.ntrain)[0:batch_size] 398 | 399 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 400 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 401 | batch_att = self.att[batch_label].to(self.device) 402 | return batch_label, batch_feature, batch_att 403 | 404 | def get_idx_classes(self): 405 | n_classes = self.seenclasses.size(0) 406 | self.idxs_list = [] 407 | train_label = self.data['train_seen']['labels'] 408 | for i in range(n_classes): 409 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 410 | idx_c = np.squeeze(idx_c) 411 | self.idxs_list.append(idx_c) 412 | return self.idxs_list 413 | 414 | def read_matdataset(self): 415 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 416 | print('_____') 417 | print(path) 418 | # tic = time.clock() 419 | hf = h5py.File(path, 'r') 420 | features = np.array(hf.get('feature_map')) 421 | labels = np.array(hf.get('labels')) 422 | trainval_loc = np.array(hf.get('trainval_loc')) 423 | test_seen_loc = np.array(hf.get('test_seen_loc')) 424 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 425 | 426 | if self.is_unsupervised_attr: 427 | print('Unsupervised Attr') 428 | class_path = './w2v/{}_class.pkl'.format(self.dataset) 429 | with open(class_path,'rb') as f: 430 | w2v_class = pickle.load(f) 431 | assert w2v_class.shape == (50,300) 432 | w2v_class = torch.tensor(w2v_class).float() 433 | 434 | U, s, V = torch.svd(w2v_class) 435 | reconstruct = torch.mm(torch.mm(U,torch.diag(s)),torch.transpose(V,1,0)) 436 | print('sanity check: {}'.format(torch.norm(reconstruct-w2v_class).item())) 437 | 438 | print('shape U:{} V:{}'.format(U.size(),V.size())) 439 | print('s: {}'.format(s)) 440 | 441 | self.w2v_att = torch.transpose(V,1,0).to(self.device) 442 | self.att = torch.mm(U,torch.diag(s)).to(self.device) 443 | self.normalize_att = torch.mm(U,torch.diag(s)).to(self.device) 444 | else: 445 | print('Expert Attr') 446 | att = np.array(hf.get('att')) 447 | 448 | print("threshold at zero attribute with negative value") 449 | att[att<0]=0 450 | 451 | self.att = torch.from_numpy(att).float().to(self.device) 452 | 453 | original_att = np.array(hf.get('original_att')) 454 | self.original_att = torch.from_numpy(original_att).float().to(self.device) 455 | 456 | w2v_att = np.array(hf.get('w2v_att')) 457 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 458 | 459 | self.normalize_att = self.original_att/100 460 | 461 | train_feature = features[trainval_loc] 462 | test_seen_feature = features[test_seen_loc] 463 | test_unseen_feature = features[test_unseen_loc] 464 | if self.is_scale: 465 | scaler = preprocessing.MinMaxScaler() 466 | 467 | train_feature = scaler.fit_transform(train_feature) 468 | test_seen_feature = scaler.fit_transform(test_seen_feature) 469 | test_unseen_feature = scaler.fit_transform(test_unseen_feature) 470 | 471 | train_feature = torch.from_numpy(train_feature).float() #.to(self.device) 472 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 473 | test_unseen_feature = torch.from_numpy(test_unseen_feature) #.float().to(self.device) 474 | 475 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 476 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 477 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) 
#.long().to(self.device) 478 | 479 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 480 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 481 | self.ntrain = train_feature.size()[0] 482 | self.ntrain_class = self.seenclasses.size(0) 483 | self.ntest_class = self.unseenclasses.size(0) 484 | self.train_class = self.seenclasses.clone() 485 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 486 | 487 | self.data = {} 488 | self.data['train_seen'] = {} 489 | self.data['train_seen']['resnet_features'] = train_feature 490 | self.data['train_seen']['labels']= train_label 491 | 492 | self.data['train_unseen'] = {} 493 | self.data['train_unseen']['resnet_features'] = None 494 | self.data['train_unseen']['labels'] = None 495 | 496 | self.data['test_seen'] = {} 497 | self.data['test_seen']['resnet_features'] = test_seen_feature 498 | self.data['test_seen']['labels'] = test_seen_label 499 | 500 | self.data['test_unseen'] = {} 501 | self.data['test_unseen']['resnet_features'] = test_unseen_feature 502 | self.data['test_unseen']['labels'] = test_unseen_label -------------------------------------------------------------------------------- /figs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero/7c10cdbe76d36ec209b59127a6b1bb02fbd93852/figs/pipeline.png -------------------------------------------------------------------------------- /helper_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def val_gzsl(test_X, test_label, target_classes,in_package,bias = 0): 6 | 7 | batch_size = in_package['batch_size'] 8 | model = in_package['model'] 9 | device = in_package['device'] 10 | with torch.no_grad(): 11 | start = 0 12 | ntest = test_X.size()[0] 13 | predicted_label = torch.LongTensor(test_label.size()) 14 | for i in range(0, ntest, batch_size): 15 | 16 | end = min(ntest, start+batch_size) 17 | 18 | input = test_X[start:end].to(device) 19 | 20 | out_package = model(input) 21 | 22 | output = out_package['S_pp'] 23 | output[:,target_classes] = output[:,target_classes]+bias 24 | predicted_label[start:end] = torch.argmax(output.data, 1) 25 | 26 | start = end 27 | 28 | acc = compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package) 29 | return acc 30 | 31 | 32 | def map_label(label, classes): 33 | mapped_label = torch.LongTensor(label.size()).fill_(-1) 34 | for i in range(classes.size(0)): 35 | mapped_label[label==classes[i]] = i 36 | 37 | return mapped_label 38 | 39 | 40 | def val_zs_gzsl(test_X, test_label, unseen_classes,in_package,bias = 0): 41 | batch_size = in_package['batch_size'] 42 | model = in_package['model'] 43 | device = in_package['device'] 44 | with torch.no_grad(): 45 | start = 0 46 | ntest = test_X.size()[0] 47 | predicted_label_gzsl = torch.LongTensor(test_label.size()) 48 | predicted_label_zsl = torch.LongTensor(test_label.size()) 49 | predicted_label_zsl_t = torch.LongTensor(test_label.size()) 50 | for i in range(0, ntest, batch_size): 51 | 52 | end = min(ntest, start+batch_size) 53 | 54 | input = test_X[start:end].to(device) 55 | 56 | out_package = model(input) 57 | output = out_package['S_pp'] 58 | 59 | output_t = output.clone() 60 | output_t[:,unseen_classes] = output_t[:,unseen_classes]+torch.max(output)+1 61 | predicted_label_zsl[start:end] = 
torch.argmax(output_t.data, 1) 62 | predicted_label_zsl_t[start:end] = torch.argmax(output.data[:,unseen_classes], 1) 63 | 64 | output[:,unseen_classes] = output[:,unseen_classes]+bias 65 | predicted_label_gzsl[start:end] = torch.argmax(output.data, 1) 66 | 67 | 68 | start = end 69 | acc_gzsl = compute_per_class_acc_gzsl(test_label, predicted_label_gzsl, unseen_classes, in_package) 70 | acc_zs = compute_per_class_acc_gzsl(test_label, predicted_label_zsl, unseen_classes, in_package) 71 | acc_zs_t = compute_per_class_acc(map_label(test_label, unseen_classes), predicted_label_zsl_t, unseen_classes.size(0)) 72 | 73 | return acc_gzsl,acc_zs_t 74 | 75 | 76 | def compute_per_class_acc(test_label, predicted_label, nclass): 77 | acc_per_class = torch.FloatTensor(nclass).fill_(0) 78 | for i in range(nclass): 79 | idx = (test_label == i) 80 | acc_per_class[i] = torch.sum(test_label[idx]==predicted_label[idx]).float() / torch.sum(idx).float() 81 | return acc_per_class.mean().item() 82 | 83 | 84 | def compute_per_class_acc_gzsl(test_label, predicted_label, target_classes, in_package): 85 | 86 | device = in_package['device'] 87 | per_class_accuracies = torch.zeros(target_classes.size()[0]).float().to(device).detach() 88 | 89 | predicted_label = predicted_label.to(device) 90 | 91 | for i in range(target_classes.size()[0]): 92 | 93 | is_class = test_label == target_classes[i] 94 | 95 | per_class_accuracies[i] = torch.div((predicted_label[is_class]==test_label[is_class]).sum().float(),is_class.sum().float()) 96 | return per_class_accuracies.mean().item() 97 | 98 | 99 | def eval_zs_gzsl(dataloader,model,device,bias_seen=0, bias_unseen=0, batch_size=50): 100 | model.eval() 101 | # print('bias_seen {} bias_unseen {}'.format(bias_seen,bias_unseen)) 102 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 103 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 104 | 105 | test_unseen_feature = dataloader.data['test_unseen']['resnet_features'] 106 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 107 | 108 | seenclasses = dataloader.seenclasses 109 | unseenclasses = dataloader.unseenclasses 110 | 111 | batch_size = batch_size 112 | 113 | in_package = {'model':model,'device':device, 'batch_size':batch_size} 114 | 115 | with torch.no_grad(): 116 | acc_seen = val_gzsl(test_seen_feature, test_seen_label, seenclasses, in_package,bias=bias_seen) 117 | acc_novel,acc_zs = val_zs_gzsl(test_unseen_feature, test_unseen_label, unseenclasses, in_package,bias = bias_unseen) 118 | 119 | if (acc_seen+acc_novel)>0: 120 | H = (2*acc_seen*acc_novel) / (acc_seen+acc_novel) 121 | else: 122 | H = 0 123 | 124 | return acc_seen, acc_novel, H, acc_zs 125 | 126 | 127 | def val_gzsl_k(k,test_X, test_label, target_classes,in_package,bias = 0,is_detect=False): 128 | batch_size = in_package['batch_size'] 129 | model = in_package['model'] 130 | device = in_package['device'] 131 | n_classes = in_package["num_class"] 132 | 133 | with torch.no_grad(): 134 | start = 0 135 | ntest = test_X.size()[0] 136 | test_label = F.one_hot(test_label, num_classes=n_classes) 137 | predicted_label = torch.LongTensor(test_label.size()).fill_(0).to(test_label.device) 138 | for i in range(0, ntest, batch_size): 139 | 140 | end = min(ntest, start+batch_size) 141 | 142 | input = test_X[start:end].to(device) 143 | 144 | out_package = model(input) 145 | 146 | output = out_package['S_pp'] 147 | output[:,target_classes] = output[:,target_classes]+bias 148 | _,idx_k = torch.topk(output,k,dim=1) 149 | if 
is_detect: 150 | assert k == 1 151 | detection_mask=in_package["detection_mask"] 152 | predicted_label[start:end] = detection_mask[torch.argmax(output.data, 1)] 153 | else: 154 | predicted_label[start:end] = predicted_label[start:end].scatter_(1,idx_k,1) 155 | start = end 156 | 157 | acc = compute_per_class_acc_gzsl_k(test_label, predicted_label, target_classes, in_package) 158 | return acc 159 | 160 | 161 | def val_zs_gzsl_k(k,test_X, test_label, unseen_classes,in_package,bias = 0,is_detect=False): 162 | batch_size = in_package['batch_size'] 163 | model = in_package['model'] 164 | device = in_package['device'] 165 | n_classes = in_package["num_class"] 166 | with torch.no_grad(): 167 | start = 0 168 | ntest = test_X.size()[0] 169 | 170 | test_label_gzsl = F.one_hot(test_label, num_classes=n_classes) 171 | predicted_label_gzsl = torch.LongTensor(test_label_gzsl.size()).fill_(0).to(test_label.device) 172 | 173 | predicted_label_zsl = torch.LongTensor(test_label.size()) 174 | predicted_label_zsl_t = torch.LongTensor(test_label.size()) 175 | for i in range(0, ntest, batch_size): 176 | 177 | end = min(ntest, start+batch_size) 178 | 179 | input = test_X[start:end].to(device) 180 | 181 | out_package = model(input) 182 | output = out_package['S_pp'] 183 | 184 | output_t = output.clone() 185 | output_t[:,unseen_classes] = output_t[:,unseen_classes]+torch.max(output)+1 186 | predicted_label_zsl[start:end] = torch.argmax(output_t.data, 1) 187 | predicted_label_zsl_t[start:end] = torch.argmax(output.data[:,unseen_classes], 1) 188 | 189 | output[:,unseen_classes] = output[:,unseen_classes]+bias 190 | _,idx_k = torch.topk(output,k,dim=1) 191 | if is_detect: 192 | assert k == 1 193 | detection_mask=in_package["detection_mask"] 194 | predicted_label_gzsl[start:end] = detection_mask[torch.argmax(output.data, 1)] 195 | else: 196 | predicted_label_gzsl[start:end] = predicted_label_gzsl[start:end].scatter_(1,idx_k,1) 197 | 198 | start = end 199 | 200 | acc_gzsl = compute_per_class_acc_gzsl_k(test_label_gzsl, predicted_label_gzsl, unseen_classes, in_package) 201 | #print('acc_zs: {} acc_zs_t: {}'.format(acc_zs,acc_zs_t)) 202 | return acc_gzsl,-1 203 | 204 | 205 | def compute_per_class_acc_k(test_label, predicted_label, nclass): 206 | acc_per_class = torch.FloatTensor(nclass).fill_(0) 207 | for i in range(nclass): 208 | idx = (test_label == i) 209 | acc_per_class[i] = torch.sum(test_label[idx]==predicted_label[idx]).float() / torch.sum(idx).float() 210 | return acc_per_class.mean().item() 211 | 212 | 213 | def compute_per_class_acc_gzsl_k(test_label, predicted_label, target_classes, in_package): 214 | device = in_package['device'] 215 | per_class_accuracies = torch.zeros(target_classes.size()[0]).float().to(device).detach() 216 | 217 | predicted_label = predicted_label.to(device) 218 | 219 | hit = test_label*predicted_label 220 | for i in range(target_classes.size()[0]): 221 | 222 | target = target_classes[i] 223 | n_pos = torch.sum(hit[:,target]) 224 | n_gt = torch.sum(test_label[:,target]) 225 | per_class_accuracies[i] = torch.div(n_pos.float(),n_gt.float()) 226 | #pdb.set_trace() 227 | return per_class_accuracies.mean().item() 228 | 229 | 230 | def eval_zs_gzsl_k(k,dataloader,model,device,bias_seen,bias_unseen,is_detect=False): 231 | model.eval() 232 | print('bias_seen {} bias_unseen {}'.format(bias_seen,bias_unseen)) 233 | test_seen_feature = dataloader.data['test_seen']['resnet_features'] 234 | test_seen_label = dataloader.data['test_seen']['labels'].to(device) 235 | 236 | test_unseen_feature = 
dataloader.data['test_unseen']['resnet_features'] 237 | test_unseen_label = dataloader.data['test_unseen']['labels'].to(device) 238 | 239 | seenclasses = dataloader.seenclasses 240 | unseenclasses = dataloader.unseenclasses 241 | 242 | batch_size = 100 243 | n_classes = dataloader.ntrain_class+dataloader.ntest_class 244 | in_package = {'model':model,'device':device, 'batch_size':batch_size,'num_class':n_classes} 245 | 246 | if is_detect: 247 | print("Measure novelty detection k: {}".format(k)) 248 | 249 | detection_mask = torch.zeros((n_classes,n_classes)).long().to(dataloader.device) 250 | detect_label = torch.zeros(n_classes).long().to(dataloader.device) 251 | detect_label[seenclasses]=1 252 | detection_mask[seenclasses,:] = detect_label 253 | 254 | detect_label = torch.zeros(n_classes).long().to(dataloader.device) 255 | detect_label[unseenclasses]=1 256 | detection_mask[unseenclasses,:]=detect_label 257 | in_package["detection_mask"]=detection_mask 258 | 259 | with torch.no_grad(): 260 | acc_seen = val_gzsl_k(k,test_seen_feature, test_seen_label, seenclasses, in_package,bias=bias_seen,is_detect=is_detect) 261 | acc_novel,acc_zs = val_zs_gzsl_k(k,test_unseen_feature, test_unseen_label, unseenclasses, in_package,bias = bias_unseen,is_detect=is_detect) 262 | 263 | if (acc_seen+acc_novel)>0: 264 | H = (2*acc_seen*acc_novel) / (acc_seen+acc_novel) 265 | else: 266 | H = 0 267 | 268 | return acc_seen, acc_novel, H, acc_zs 269 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class TransZero(nn.Module): 8 | def __init__(self, config, att, init_w2v_att, seenclass, unseenclass, 9 | is_bias=True, bias=1, is_conservative=True): 10 | super(TransZero, self).__init__() 11 | self.config = config 12 | self.dim_f = config.dim_f 13 | self.dim_v = config.dim_v 14 | self.nclass = config.num_class 15 | self.seenclass = seenclass 16 | self.unseenclass = unseenclass 17 | self.is_bias = is_bias 18 | self.is_conservative = is_conservative 19 | # class-level semantic vectors 20 | self.att = nn.Parameter(F.normalize(att), requires_grad=False) 21 | # GloVe features for attributes name 22 | self.V = nn.Parameter(F.normalize(init_w2v_att), requires_grad=True) 23 | # for self-calibration 24 | self.bias = nn.Parameter(torch.tensor(bias), requires_grad=False) 25 | mask_bias = np.ones((1, self.nclass)) 26 | mask_bias[:, self.seenclass.cpu().numpy()] *= -1 27 | self.mask_bias = nn.Parameter(torch.tensor( 28 | mask_bias, dtype=torch.float), requires_grad=False) 29 | # mapping 30 | self.W_1 = nn.Parameter(nn.init.normal_( 31 | torch.empty(self.dim_v, config.tf_common_dim)), requires_grad=True) 32 | # transformer 33 | self.transformer = Transformer( 34 | ec_layer=config.tf_ec_layer, 35 | dc_layer=config.tf_dc_layer, 36 | dim_com=config.tf_common_dim, 37 | dim_feedforward=config.tf_dim_feedforward, 38 | dropout=config.tf_dropout, 39 | SAtt=config.tf_SAtt, 40 | heads=config.tf_heads, 41 | aux_embed=config.tf_aux_embed) 42 | # for loss computation 43 | self.log_softmax_func = nn.LogSoftmax(dim=1) 44 | self.weight_ce = nn.Parameter(torch.eye(self.nclass), requires_grad=False) 45 | 46 | def forward(self, input, from_img=False): 47 | Fs = self.resnet101(input) if from_img else input 48 | # transformer-based visual-to-semantic embedding 49 | v2s_embed = self.forward_feature_transformer(Fs) 50 | # 
classification 51 | package = {'pred': self.forward_attribute(v2s_embed), 52 | 'embed': v2s_embed} 53 | package['S_pp'] = package['pred'] 54 | return package 55 | 56 | def forward_feature_transformer(self, Fs): 57 | # visual 58 | if len(Fs.shape) == 4: 59 | shape = Fs.shape 60 | Fs = Fs.reshape(shape[0], shape[1], shape[2] * shape[3]) 61 | Fs = F.normalize(Fs, dim=1) 62 | # attributes 63 | V_n = F.normalize(self.V) if self.config.normalize_V else self.V 64 | # locality-augmented visual features 65 | Trans_out = self.transformer(Fs, V_n) 66 | # embedding to semantic space 67 | embed = torch.einsum('iv,vf,bif->bi', V_n, self.W_1, Trans_out) 68 | return embed 69 | 70 | def forward_attribute(self, embed): 71 | embed = torch.einsum('ki,bi->bk', self.att, embed) 72 | self.vec_bias = self.mask_bias*self.bias 73 | embed = embed + self.vec_bias 74 | return embed 75 | 76 | def compute_loss_Self_Calibrate(self, in_package): 77 | S_pp = in_package['pred'] 78 | Prob_all = F.softmax(S_pp, dim=-1) 79 | Prob_unseen = Prob_all[:, self.unseenclass] 80 | assert Prob_unseen.size(1) == len(self.unseenclass) 81 | mass_unseen = torch.sum(Prob_unseen, dim=1) 82 | loss_pmp = -torch.log(torch.mean(mass_unseen)) 83 | return loss_pmp 84 | 85 | def compute_aug_cross_entropy(self, in_package): 86 | Labels = in_package['batch_label'] 87 | S_pp = in_package['pred'] 88 | 89 | if self.is_bias: 90 | S_pp = S_pp - self.vec_bias 91 | 92 | if not self.is_conservative: 93 | S_pp = S_pp[:, self.seenclass] 94 | Labels = Labels[:, self.seenclass] 95 | assert S_pp.size(1) == len(self.seenclass) 96 | 97 | Prob = self.log_softmax_func(S_pp) 98 | 99 | loss = -torch.einsum('bk,bk->b', Prob, Labels) 100 | loss = torch.mean(loss) 101 | return loss 102 | 103 | def compute_reg_loss(self, in_package): 104 | tgt = torch.matmul(in_package['batch_label'], self.att) 105 | embed = in_package['embed'] 106 | loss_reg = F.mse_loss(embed, tgt, reduction='mean') 107 | return loss_reg 108 | 109 | def compute_loss(self, in_package): 110 | if len(in_package['batch_label'].size()) == 1: 111 | in_package['batch_label'] = self.weight_ce[in_package['batch_label']] 112 | 113 | loss_CE = self.compute_aug_cross_entropy(in_package) 114 | loss_cal = self.compute_loss_Self_Calibrate(in_package) 115 | loss_reg = self.compute_reg_loss(in_package) 116 | 117 | loss = loss_CE + self.config.lambda_ * \ 118 | loss_cal + self.config.lambda_reg * loss_reg 119 | out_package = {'loss': loss, 'loss_CE': loss_CE, 120 | 'loss_cal': loss_cal, 'loss_reg': loss_reg} 121 | return out_package 122 | 123 | 124 | class Transformer(nn.Module): 125 | def __init__(self, ec_layer=1, dc_layer=1, dim_com=300, 126 | dim_feedforward=2048, dropout=0.1, heads=1, 127 | in_dim_cv=2048, in_dim_attr=300, SAtt=True, 128 | aux_embed=True): 129 | super(Transformer, self).__init__() 130 | # input embedding 131 | self.embed_cv = nn.Sequential(nn.Linear(in_dim_cv, dim_com)) 132 | if aux_embed: 133 | self.embed_cv_aux = nn.Sequential(nn.Linear(in_dim_cv, dim_com)) 134 | self.embed_attr = nn.Sequential(nn.Linear(in_dim_attr, dim_com)) 135 | # transformer encoder 136 | self.transformer_encoder = MultiLevelEncoder_woPad(N=ec_layer, 137 | d_model=dim_com, 138 | h=1, 139 | d_k=dim_com, 140 | d_v=dim_com, 141 | d_ff=dim_feedforward, 142 | dropout=dropout) 143 | # transformer decoder 144 | decoder_layer = TransformerDecoderLayer(d_model=dim_com, 145 | nhead=heads, 146 | dim_feedforward=dim_feedforward, 147 | dropout=dropout, 148 | SAtt=SAtt) 149 | self.transformer_decoder = nn.TransformerDecoder( 150 | 
decoder_layer, num_layers=dc_layer) 151 | 152 | def forward(self, f_cv, f_attr): 153 | # linearly map to common dim 154 | h_cv = self.embed_cv(f_cv.permute(0, 2, 1)) 155 | h_attr = self.embed_attr(f_attr) 156 | h_attr_batch = h_attr.unsqueeze(0).repeat(f_cv.shape[0], 1, 1) 157 | # visual encoder 158 | memory = self.transformer_encoder(h_cv).permute(1, 0, 2) 159 | # attribute-visual decoder 160 | out = self.transformer_decoder(h_attr_batch.permute(1, 0, 2), memory) 161 | return out.permute(1, 0, 2) 162 | 163 | 164 | class EncoderLayer(nn.Module): 165 | def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, 166 | dropout=.1, identity_map_reordering=False, 167 | attention_module=None, attention_module_kwargs=None): 168 | super(EncoderLayer, self).__init__() 169 | self.identity_map_reordering = identity_map_reordering 170 | self.mhatt = MultiHeadGeometryAttention(d_model, d_k, d_v, h, dropout, 171 | identity_map_reordering=identity_map_reordering, 172 | attention_module=attention_module, 173 | attention_module_kwargs=attention_module_kwargs) 174 | self.dropout = nn.Dropout(dropout) 175 | self.lnorm = nn.LayerNorm(d_model) 176 | self.pwff = PositionWiseFeedForward( 177 | d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering) 178 | 179 | def forward(self, queries, keys, values, relative_geometry_weights, 180 | attention_mask=None, attention_weights=None, pos=None): 181 | q, k = (queries + pos, keys + 182 | pos) if pos is not None else (queries, keys) 183 | att = self.mhatt(q, k, values, relative_geometry_weights, 184 | attention_mask, attention_weights) 185 | att = self.lnorm(queries + self.dropout(att)) 186 | ff = self.pwff(att) 187 | return ff 188 | 189 | 190 | class MultiLevelEncoder_woPad(nn.Module): 191 | def __init__(self, N, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, 192 | dropout=.1, identity_map_reordering=False, 193 | attention_module=None, attention_module_kwargs=None): 194 | super(MultiLevelEncoder_woPad, self).__init__() 195 | self.d_model = d_model 196 | self.dropout = dropout 197 | self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout, 198 | identity_map_reordering=identity_map_reordering, 199 | attention_module=attention_module, 200 | attention_module_kwargs=attention_module_kwargs) 201 | for _ in range(N)]) 202 | 203 | self.WGs = nn.ModuleList( 204 | [nn.Linear(64, 1, bias=True) for _ in range(h)]) 205 | 206 | def forward(self, input, attention_mask=None, attention_weights=None, pos=None): 207 | relative_geometry_embeddings = BoxRelationalEmbedding( 208 | input, grid_size=(14, 14)) 209 | flatten_relative_geometry_embeddings = relative_geometry_embeddings.view( 210 | -1, 64) 211 | box_size_per_head = list(relative_geometry_embeddings.shape[:3]) 212 | box_size_per_head.insert(1, 1) 213 | relative_geometry_weights_per_head = [layer( 214 | flatten_relative_geometry_embeddings).view(box_size_per_head) for layer in self.WGs] 215 | relative_geometry_weights = torch.cat( 216 | (relative_geometry_weights_per_head), 1) 217 | relative_geometry_weights = F.relu(relative_geometry_weights) 218 | out = input 219 | for layer in self.layers: 220 | out = layer(out, out, out, relative_geometry_weights, 221 | attention_mask, attention_weights, pos=pos) 222 | return out 223 | 224 | 225 | class TransformerDecoderLayer(nn.TransformerDecoderLayer): 226 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 227 | activation="relu", SAtt=True): 228 | super(TransformerDecoderLayer, self).__init__(d_model, nhead, 229 | 
dim_feedforward=dim_feedforward, 230 | dropout=dropout, 231 | activation=activation) 232 | self.SAtt = SAtt 233 | 234 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, 235 | tgt_key_padding_mask=None, memory_key_padding_mask=None): 236 | if self.SAtt: 237 | tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, 238 | key_padding_mask=tgt_key_padding_mask)[0] 239 | tgt = self.norm1(tgt + self.dropout1(tgt2)) 240 | tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, 241 | key_padding_mask=memory_key_padding_mask)[0] 242 | tgt = tgt + self.dropout2(tgt2) 243 | tgt = self.norm2(tgt) 244 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 245 | tgt = tgt + self.dropout3(tgt2) 246 | tgt = self.norm3(tgt) 247 | return tgt 248 | 249 | 250 | def get_relative_pos(x, batch_size, norm_len): 251 | x = x.view(1, -1, 1).expand(batch_size, -1, -1) 252 | return x / norm_len 253 | 254 | 255 | def get_grids_pos(batch_size, seq_len, grid_size=(7, 7)): 256 | assert seq_len == grid_size[0] * grid_size[1] 257 | x = torch.arange(0, grid_size[0]).float().cuda() 258 | y = torch.arange(0, grid_size[1]).float().cuda() 259 | px_min = x.view(-1, 1).expand(-1, grid_size[0]).contiguous().view(-1) 260 | py_min = y.view(1, -1).expand(grid_size[1], -1).contiguous().view(-1) 261 | px_max = px_min + 1 262 | py_max = py_min + 1 263 | rpx_min = get_relative_pos(px_min, batch_size, grid_size[0]) 264 | rpy_min = get_relative_pos(py_min, batch_size, grid_size[1]) 265 | rpx_max = get_relative_pos(px_max, batch_size, grid_size[0]) 266 | rpy_max = get_relative_pos(py_max, batch_size, grid_size[1]) 267 | return rpx_min, rpy_min, rpx_max, rpy_max 268 | 269 | 270 | def BoxRelationalEmbedding(f_g, dim_g=64, wave_len=1000, trignometric_embedding=True, 271 | grid_size=(7, 7)): 272 | batch_size, seq_len = f_g.size(0), f_g.size(1) 273 | x_min, y_min, x_max, y_max = get_grids_pos(batch_size, seq_len, grid_size) 274 | cx = (x_min + x_max) * 0.5 275 | cy = (y_min + y_max) * 0.5 276 | w = (x_max - x_min) + 1. 277 | h = (y_max - y_min) + 1. 278 | delta_x = cx - cx.view(batch_size, 1, -1) 279 | delta_x = torch.clamp(torch.abs(delta_x / w), min=1e-3) 280 | delta_x = torch.log(delta_x) 281 | delta_y = cy - cy.view(batch_size, 1, -1) 282 | delta_y = torch.clamp(torch.abs(delta_y / h), min=1e-3) 283 | delta_y = torch.log(delta_y) 284 | delta_w = torch.log(w / w.view(batch_size, 1, -1)) 285 | delta_h = torch.log(h / h.view(batch_size, 1, -1)) 286 | matrix_size = delta_h.size() 287 | delta_x = delta_x.view(batch_size, matrix_size[1], matrix_size[2], 1) 288 | delta_y = delta_y.view(batch_size, matrix_size[1], matrix_size[2], 1) 289 | delta_w = delta_w.view(batch_size, matrix_size[1], matrix_size[2], 1) 290 | delta_h = delta_h.view(batch_size, matrix_size[1], matrix_size[2], 1) 291 | position_mat = torch.cat((delta_x, delta_y, delta_w, delta_h), -1) 292 | if trignometric_embedding == True: 293 | feat_range = torch.arange(dim_g / 8).cuda() 294 | dim_mat = feat_range / (dim_g / 8) 295 | dim_mat = 1. / (torch.pow(wave_len, dim_mat)) 296 | dim_mat = dim_mat.view(1, 1, 1, -1) 297 | position_mat = position_mat.view( 298 | batch_size, matrix_size[1], matrix_size[2], 4, -1) 299 | position_mat = 100. 
* position_mat 300 | mul_mat = position_mat * dim_mat 301 | mul_mat = mul_mat.view(batch_size, matrix_size[1], matrix_size[2], -1) 302 | sin_mat = torch.sin(mul_mat) 303 | cos_mat = torch.cos(mul_mat) 304 | embedding = torch.cat((sin_mat, cos_mat), -1) 305 | else: 306 | embedding = position_mat 307 | return (embedding) 308 | 309 | 310 | class ScaledDotProductGeometryAttention(nn.Module): 311 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, comment=None): 312 | super(ScaledDotProductGeometryAttention, self).__init__() 313 | self.fc_q = nn.Linear(d_model, h * d_k) 314 | self.fc_k = nn.Linear(d_model, h * d_k) 315 | self.fc_v = nn.Linear(d_model, h * d_v) 316 | self.fc_o = nn.Linear(h * d_v, d_model) 317 | self.dropout = nn.Dropout(dropout) 318 | self.d_model = d_model 319 | self.d_k = d_k 320 | self.d_v = d_v 321 | self.h = h 322 | self.init_weights() 323 | self.comment = comment 324 | 325 | def init_weights(self): 326 | nn.init.xavier_uniform_(self.fc_q.weight) 327 | nn.init.xavier_uniform_(self.fc_k.weight) 328 | nn.init.xavier_uniform_(self.fc_v.weight) 329 | nn.init.xavier_uniform_(self.fc_o.weight) 330 | nn.init.constant_(self.fc_q.bias, 0) 331 | nn.init.constant_(self.fc_k.bias, 0) 332 | nn.init.constant_(self.fc_v.bias, 0) 333 | nn.init.constant_(self.fc_o.bias, 0) 334 | 335 | def forward(self, queries, keys, values, box_relation_embed_matrix, 336 | attention_mask=None, attention_weights=None): 337 | b_s, nq = queries.shape[:2] 338 | nk = keys.shape[1] 339 | q = self.fc_q(queries).view(b_s, nq, self.h, 340 | self.d_k).permute(0, 2, 1, 3) 341 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) 342 | v = self.fc_v(values).view(b_s, nk, self.h, 343 | self.d_v).permute(0, 2, 1, 3) 344 | att = torch.matmul(q, k) / np.sqrt(self.d_k) 345 | if attention_weights is not None: 346 | att = att * attention_weights 347 | if attention_mask is not None: 348 | att = att.masked_fill(attention_mask, -np.inf) 349 | w_g = box_relation_embed_matrix 350 | w_a = att 351 | w_mn = - w_g + w_a 352 | w_mn = torch.softmax(w_mn, -1) 353 | att = self.dropout(w_mn) 354 | out = torch.matmul(att, v).permute( 355 | 0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) 356 | out = self.fc_o(out) 357 | return out 358 | 359 | 360 | class MultiHeadGeometryAttention(nn.Module): 361 | def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, 362 | can_be_stateful=False, attention_module=None, 363 | attention_module_kwargs=None, comment=None): 364 | super(MultiHeadGeometryAttention, self).__init__() 365 | self.identity_map_reordering = identity_map_reordering 366 | self.attention = ScaledDotProductGeometryAttention( 367 | d_model=d_model, d_k=d_k, d_v=d_v, h=h, comment=comment) 368 | self.dropout = nn.Dropout(p=dropout) 369 | self.layer_norm = nn.LayerNorm(d_model) 370 | self.can_be_stateful = can_be_stateful 371 | if self.can_be_stateful: 372 | self.register_state('running_keys', torch.zeros((0, d_model))) 373 | self.register_state('running_values', torch.zeros((0, d_model))) 374 | 375 | def forward(self, queries, keys, values, relative_geometry_weights, 376 | attention_mask=None, attention_weights=None): 377 | if self.can_be_stateful and self._is_stateful: 378 | self.running_keys = torch.cat([self.running_keys, keys], 1) 379 | keys = self.running_keys 380 | self.running_values = torch.cat([self.running_values, values], 1) 381 | values = self.running_values 382 | if self.identity_map_reordering: 383 | q_norm = self.layer_norm(queries) 384 | k_norm = 
self.layer_norm(keys) 385 | v_norm = self.layer_norm(values) 386 | out = self.attention(q_norm, k_norm, v_norm, relative_geometry_weights, 387 | attention_mask, attention_weights) 388 | out = queries + self.dropout(torch.relu(out)) 389 | else: 390 | out = self.attention(queries, keys, values, relative_geometry_weights, 391 | attention_mask, attention_weights) 392 | out = self.dropout(out) 393 | out = self.layer_norm(queries + out) 394 | return out 395 | 396 | 397 | class PositionWiseFeedForward(nn.Module): 398 | def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False): 399 | super(PositionWiseFeedForward, self).__init__() 400 | self.identity_map_reordering = identity_map_reordering 401 | self.fc1 = nn.Linear(d_model, d_ff) 402 | self.fc2 = nn.Linear(d_ff, d_model) 403 | self.dropout = nn.Dropout(p=dropout) 404 | self.dropout_2 = nn.Dropout(p=dropout) 405 | self.layer_norm = nn.LayerNorm(d_model) 406 | 407 | def forward(self, input): 408 | if self.identity_map_reordering: 409 | out = self.layer_norm(input) 410 | out = self.fc2(self.dropout_2(F.relu(self.fc1(out)))) 411 | out = input + self.dropout(torch.relu(out)) 412 | else: 413 | out = self.fc2(self.dropout_2(F.relu(self.fc1(input)))) 414 | out = self.dropout(out) 415 | out = self.layer_norm(input + out) 416 | return out 417 | 418 | 419 | if __name__ == '__main__': 420 | pass 421 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import h5py 3 | import argparse 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import scipy.io as sio 9 | import torchvision.models.resnet as models 10 | from torchvision import datasets, transforms 11 | from torch.utils.data import Dataset, DataLoader 12 | from PIL import Image 13 | 14 | 15 | class CustomedDataset(Dataset): 16 | def __init__(self, dataset, img_dir, file_paths, transform=None): 17 | self.dataset = dataset 18 | self.matcontent = sio.loadmat(file_paths) 19 | self.image_files = np.squeeze(self.matcontent['image_files']) 20 | self.img_dir = img_dir 21 | self.transform = transform 22 | 23 | def __len__(self): 24 | return len(self.image_files) 25 | 26 | def __getitem__(self, idx): 27 | image_file = self.image_files[idx][0] 28 | if self.dataset == 'CUB': 29 | split_idx = 6 30 | elif self.dataset == 'SUN': 31 | split_idx = 7 32 | elif self.dataset == 'AWA2': 33 | split_idx = 5 34 | image_file = os.path.join(self.img_dir, 35 | '/'.join(image_file.split('/')[split_idx:])) 36 | image = Image.open(image_file) 37 | if image.mode != 'RGB': 38 | image = image.convert('RGB') 39 | if self.transform: 40 | image = self.transform(image) 41 | return image 42 | 43 | 44 | def extract_features(config): 45 | 46 | img_dir = f'data/{config.dataset}' 47 | file_paths = f'data/xlsa17/data/{config.dataset}/res101.mat' 48 | save_path = f'data/{config.dataset}/feature_map_ResNet_101_{config.dataset}.hdf5' 49 | attribute_path = f'w2v/{config.dataset}_attribute.pkl' 50 | 51 | # region feature extractor 52 | resnet101 = models.resnet101(pretrained=True).to(config.device) 53 | resnet101 = nn.Sequential(*list(resnet101.children())[:-2]).eval() 54 | 55 | data_transforms = transforms.Compose([ 56 | transforms.Resize(448), 57 | transforms.CenterCrop(448), 58 | transforms.ToTensor(), 59 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 60 | 61 | Dataset = CustomedDataset(config.dataset, img_dir, 
file_paths, data_transforms) 62 | dataset_loader = torch.utils.data.DataLoader(Dataset, 63 | batch_size=config.batch_size, 64 | shuffle=False, 65 | num_workers=config.nun_workers) 66 | 67 | with torch.no_grad(): 68 | all_features = [] 69 | for _, imgs in enumerate(dataset_loader): 70 | imgs = imgs.to(config.device) 71 | features = resnet101(imgs) 72 | all_features.append(features.cpu().numpy()) 73 | all_features = np.concatenate(all_features, axis=0) 74 | 75 | # get remaining metadata 76 | matcontent = Dataset.matcontent 77 | labels = matcontent['labels'].astype(int).squeeze() - 1 78 | 79 | split_path = os.path.join(f'data/xlsa17/data/{config.dataset}/att_splits.mat') 80 | matcontent = sio.loadmat(split_path) 81 | trainval_loc = matcontent['trainval_loc'].squeeze() - 1 82 | # train_loc = matcontent['train_loc'].squeeze() - 1 83 | # val_unseen_loc = matcontent['val_loc'].squeeze() - 1 84 | test_seen_loc = matcontent['test_seen_loc'].squeeze() - 1 85 | test_unseen_loc = matcontent['test_unseen_loc'].squeeze() - 1 86 | att = matcontent['att'].T 87 | original_att = matcontent['original_att'].T 88 | 89 | # construct attribute w2v 90 | with open(attribute_path,'rb') as f: 91 | w2v_att = pickle.load(f) 92 | if config.dataset == 'CUB': 93 | assert w2v_att.shape == (312,300) 94 | elif config.dataset == 'SUN': 95 | assert w2v_att.shape == (102,300) 96 | elif config.dataset == 'AWA2': 97 | assert w2v_att.shape == (85,300) 98 | 99 | compression = 'gzip' if config.compression else None 100 | f = h5py.File(save_path, 'w') 101 | f.create_dataset('feature_map', data=all_features,compression=compression) 102 | f.create_dataset('labels', data=labels,compression=compression) 103 | f.create_dataset('trainval_loc', data=trainval_loc,compression=compression) 104 | # f.create_dataset('train_loc', data=train_loc,compression=compression) 105 | # f.create_dataset('val_unseen_loc', data=val_unseen_loc,compression=compression) 106 | f.create_dataset('test_seen_loc', data=test_seen_loc,compression=compression) 107 | f.create_dataset('test_unseen_loc', data=test_unseen_loc,compression=compression) 108 | f.create_dataset('att', data=att,compression=compression) 109 | f.create_dataset('original_att', data=original_att,compression=compression) 110 | f.create_dataset('w2v_att', data=w2v_att,compression=compression) 111 | f.close() 112 | 113 | 114 | if __name__=='__main__': 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('--dataset', '-d', type=str, default='AWA2') 117 | parser.add_argument('--compression', '-c', action='store_true', default=False) 118 | parser.add_argument('--batch_size', '-b', type=int, default=200) 119 | parser.add_argument('--device', '-g', type=str, default='cuda:0') 120 | parser.add_argument('--nun_workers', '-n', type=int, default='16') 121 | config = parser.parse_args() 122 | extract_features(config) 123 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.6.2 2 | h5py==3.2.1 3 | numpy==1.20.2 4 | pandas==1.2.3 5 | torchvision==0.9.0 6 | torch==1.8.0 7 | Pillow==9.1.0 8 | scikit_learn==1.0.2 9 | -------------------------------------------------------------------------------- /train_awa2.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import numpy as np 6 | from model import TransZero 7 | from dataset import 
AWA2DataLoader 8 | from helper_func import eval_zs_gzsl 9 | 10 | # init wandb from config file 11 | wandb.init(project='TransZero', config='wandb_config/awa2_gzsl.yaml') 12 | config = wandb.config 13 | print('Config file from wandb:', config) 14 | 15 | # load dataset 16 | dataloader = AWA2DataLoader('.', config.device) 17 | 18 | # set random seed 19 | seed = config.random_seed 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | np.random.seed(seed) 23 | 24 | init_w2v_att = dataloader.w2v_att 25 | att = dataloader.att 26 | att[att<0] = 0 27 | normalize_att = dataloader.normalize_att 28 | 29 | # TransZero model 30 | model = TransZero(config, att, init_w2v_att, 31 | dataloader.seenclasses, dataloader.unseenclasses).to(config.device) 32 | optimizer = optim.SGD(model.parameters(), lr=0.0001, weight_decay=0.0001, momentum=0.) 33 | 34 | # main loop 35 | niters = dataloader.ntrain * config.epochs//config.batch_size 36 | report_interval = niters//config.epochs 37 | best_performance = [0, 0, 0, 0] 38 | best_performance_zsl = 0 39 | best_performance_zsl = 0 40 | for i in range(0, niters): 41 | model.train() 42 | optimizer.zero_grad() 43 | 44 | batch_label, batch_feature, batch_att = dataloader.next_batch(config.batch_size) 45 | out_package = model(batch_feature) 46 | 47 | in_package = out_package 48 | in_package['batch_label'] = batch_label 49 | 50 | out_package = model.compute_loss(in_package) 51 | loss, loss_CE, loss_cal, loss_reg = out_package['loss'], out_package[ 52 | 'loss_CE'], out_package['loss_cal'], out_package['loss_reg'] 53 | 54 | loss.backward() 55 | optimizer.step() 56 | 57 | # report result 58 | if i % report_interval == 0: 59 | print('-'*30) 60 | acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl( 61 | dataloader, model, config.device, bias_seen=0, bias_unseen=0) 62 | 63 | if H > best_performance[2]: 64 | best_performance = [acc_novel, acc_seen, H, acc_zs] 65 | if acc_zs > best_performance_zsl: 66 | best_performance_zsl = acc_zs 67 | 68 | print('iter/epoch=%d/%d | loss=%.3f, loss_CE=%.3f, loss_cal=%.3f, ' 69 | 'loss_reg=%.3f | acc_unseen=%.3f, acc_seen=%.3f, H=%.3f | ' 70 | 'acc_zs=%.3f' % ( 71 | i, int(i//report_interval), 72 | loss.item(), loss_CE.item(), loss_cal.item(), 73 | loss_reg.item(), 74 | best_performance[0], best_performance[1], 75 | best_performance[2], best_performance_zsl)) 76 | 77 | wandb.log({ 78 | 'iter': i, 79 | 'loss': loss.item(), 80 | 'loss_CE': loss_CE.item(), 81 | 'loss_cal': loss_cal.item(), 82 | 'loss_reg': loss_reg.item(), 83 | 'acc_unseen': acc_novel, 84 | 'acc_seen': acc_seen, 85 | 'H': H, 86 | 'acc_zs': acc_zs, 87 | 'best_acc_unseen': best_performance[0], 88 | 'best_acc_seen': best_performance[1], 89 | 'best_H': best_performance[2], 90 | 'best_acc_zs': best_performance_zsl 91 | }) 92 | 93 | -------------------------------------------------------------------------------- /train_cub.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn as nn 4 | from model import TransZero 5 | from dataset import CUBDataLoader 6 | from helper_func import eval_zs_gzsl 7 | import numpy as np 8 | import wandb 9 | 10 | # init wandb from config file 11 | wandb.init(project='TransZero', config='wandb_config/cub_gzsl.yaml') 12 | config = wandb.config 13 | print('Config file from wandb:', config) 14 | 15 | # load dataset 16 | dataloader = CUBDataLoader('.', config.device, is_balance=False) 17 | 18 | # set random seed 19 | seed = config.random_seed 20 | 
torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | np.random.seed(seed) 23 | 24 | # TransZero model 25 | model = TransZero(config, dataloader.att, dataloader.w2v_att, 26 | dataloader.seenclasses, dataloader.unseenclasses).to(config.device) 27 | optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001) 28 | 29 | # main loop 30 | niters = dataloader.ntrain * config.epochs//config.batch_size 31 | report_interval = niters//config.epochs 32 | best_performance = [0, 0, 0, 0] 33 | best_performance_zsl = 0 34 | for i in range(0, niters): 35 | model.train() 36 | optimizer.zero_grad() 37 | 38 | batch_label, batch_feature, batch_att = dataloader.next_batch( 39 | config.batch_size) 40 | out_package = model(batch_feature) 41 | 42 | in_package = out_package 43 | in_package['batch_label'] = batch_label 44 | 45 | out_package = model.compute_loss(in_package) 46 | loss, loss_CE, loss_cal, loss_reg = out_package['loss'], out_package[ 47 | 'loss_CE'], out_package['loss_cal'], out_package['loss_reg'] 48 | 49 | loss.backward() 50 | optimizer.step() 51 | 52 | # report result 53 | if i % report_interval == 0: 54 | print('-'*30) 55 | acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl( 56 | dataloader, model, config.device, batch_size=config.batch_size) 57 | 58 | if H > best_performance[2]: 59 | best_performance = [acc_novel, acc_seen, H, acc_zs] 60 | if acc_zs > best_performance_zsl: 61 | best_performance_zsl = acc_zs 62 | 63 | print('iter/epoch=%d/%d | loss=%.3f, loss_CE=%.3f, loss_cal=%.3f, ' 64 | 'loss_reg=%.3f | acc_unseen=%.3f, acc_seen=%.3f, H=%.3f | ' 65 | 'acc_zs=%.3f' % ( 66 | i, int(i//report_interval), 67 | loss.item(), loss_CE.item(), loss_cal.item(), 68 | loss_reg.item(), 69 | best_performance[0], best_performance[1], 70 | best_performance[2], best_performance_zsl)) 71 | 72 | wandb.log({ 73 | 'iter': i, 74 | 'loss': loss.item(), 75 | 'loss_CE': loss_CE.item(), 76 | 'loss_cal': loss_cal.item(), 77 | 'loss_reg': loss_reg.item(), 78 | 'acc_unseen': acc_novel, 79 | 'acc_seen': acc_seen, 80 | 'H': H, 81 | 'acc_zs': acc_zs, 82 | 'best_acc_unseen': best_performance[0], 83 | 'best_acc_seen': best_performance[1], 84 | 'best_H': best_performance[2], 85 | 'best_acc_zs': best_performance_zsl 86 | }) 87 | -------------------------------------------------------------------------------- /train_sun.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import numpy as np 6 | from model import TransZero 7 | from dataset import SUNDataLoader 8 | from helper_func import eval_zs_gzsl 9 | 10 | # init wandb from config file 11 | wandb.init(project='TransZero', config='wandb_config/sun_gzsl.yaml') 12 | config = wandb.config 13 | print('Config file from wandb:', config) 14 | 15 | # load dataset 16 | dataloader = SUNDataLoader('.', config.device) 17 | 18 | # set random seed 19 | seed = config.random_seed 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | np.random.seed(seed) 23 | 24 | # TransZero model 25 | model = TransZero(config, dataloader.att, dataloader.w2v_att, 26 | dataloader.seenclasses, dataloader.unseenclasses).to(config.device) 27 | optimizer = optim.SGD(model.parameters(), lr=0.0001, weight_decay=0.0001, momentum=0.9) 28 | 29 | # main loop 30 | niters = dataloader.ntrain * config.epochs//config.batch_size 31 | report_interval = niters//config.epochs 32 | best_performance = [0, 0, 0, 0] 33 | best_performance_zsl = 0 34 | for i in range(0, 
niters): 35 | model.train() 36 | optimizer.zero_grad() 37 | 38 | batch_label, batch_feature, batch_att = dataloader.next_batch(config.batch_size) 39 | out_package = model(batch_feature) 40 | 41 | in_package = out_package 42 | in_package['batch_label'] = batch_label 43 | 44 | out_package=model.compute_loss(in_package) 45 | loss, loss_CE, loss_cal, loss_reg = out_package['loss'], out_package[ 46 | 'loss_CE'], out_package['loss_cal'], out_package['loss_reg'] 47 | 48 | loss.backward() 49 | optimizer.step() 50 | 51 | # report result 52 | if i % report_interval==0: 53 | print('-'*30) 54 | acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl( 55 | dataloader, model, config.device, bias_seen=0, bias_unseen=0) 56 | 57 | if H > best_performance[2]: 58 | best_performance = [acc_novel, acc_seen, H, acc_zs] 59 | if acc_zs > best_performance_zsl: 60 | best_performance_zsl = acc_zs 61 | 62 | print('iter/epoch=%d/%d | loss=%.3f, loss_CE=%.3f, loss_cal=%.3f, ' 63 | 'loss_reg=%.3f | acc_unseen=%.3f, acc_seen=%.3f, H=%.3f | ' 64 | 'acc_zs=%.3f' % ( 65 | i, int(i//report_interval), 66 | loss.item(), loss_CE.item(), loss_cal.item(), 67 | loss_reg.item(), 68 | best_performance[0], best_performance[1], 69 | best_performance[2], best_performance_zsl)) 70 | 71 | wandb.log({ 72 | 'iter': i, 73 | 'loss': loss.item(), 74 | 'loss_CE': loss_CE.item(), 75 | 'loss_cal': loss_cal.item(), 76 | 'loss_reg': loss_reg.item(), 77 | 'acc_unseen': acc_novel, 78 | 'acc_seen': acc_seen, 79 | 'H': H, 80 | 'acc_zs': acc_zs, 81 | 'best_acc_unseen': best_performance[0], 82 | 'best_acc_seen': best_performance[1], 83 | 'best_H': best_performance[2], 84 | 'best_acc_zs': best_performance_zsl 85 | }) 86 | -------------------------------------------------------------------------------- /w2v/AWA2_attribute.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero/7c10cdbe76d36ec209b59127a6b1bb02fbd93852/w2v/AWA2_attribute.pkl -------------------------------------------------------------------------------- /w2v/CUB_attribute.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero/7c10cdbe76d36ec209b59127a6b1bb02fbd93852/w2v/CUB_attribute.pkl -------------------------------------------------------------------------------- /w2v/SUN_attribute.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shiming-chen/TransZero/7c10cdbe76d36ec209b59127a6b1bb02fbd93852/w2v/SUN_attribute.pkl -------------------------------------------------------------------------------- /wandb_config/awa2_czsl.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | value: AWA2 3 | num_class: 4 | value: 50 5 | num_attribute: 6 | value: 85 7 | img_size: 8 | value: 448 9 | resnet_region: 10 | value: 196 11 | dim_f: 12 | value: 2048 13 | dim_v: 14 | value: 300 15 | device: 16 | value: cuda:0 17 | batch_size: 18 | value: 50 19 | epochs: 20 | value: 200 21 | use_unzip: 22 | value: True 23 | random_seed: 24 | value: 9093 25 | lambda_: 26 | value: 1.2 27 | lambda_reg: 28 | value: 0.00005 29 | normalize_V: 30 | value: True 31 | tf_SAtt: 32 | value: True 33 | tf_ec_layer: 34 | value: 1 35 | tf_dc_layer: 36 | value: 1 37 | tf_heads: 38 | value: 1 39 | tf_common_dim: 40 | value: 512 41 | tf_aux_embed: 42 | value: True 43 | tf_dim_feedforward: 44 | value: 512 45 | tf_dropout: 46 | value: 0.3 
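# NOTE (comments appended here for clarity; not part of the original awa2_czsl.yaml):
# each entry follows the W&B config-file layout `key:` / `  value: ...`, which is the
# format train_awa2.py consumes via wandb.init(config='wandb_config/awa2_gzsl.yaml').
# This CZSL setting differs from awa2_gzsl.yaml only in random_seed (9093 vs. 17),
# lambda_ (1.2 vs. 2), lambda_reg (0.00005 vs. 0.0005), tf_common_dim (512 vs. 300)
# and tf_dropout (0.3 vs. 0.5). A hypothetical one-line switch for an AWA2 CZSL run
# would be to point that call at this file instead:
#   wandb.init(project='TransZero', config='wandb_config/awa2_czsl.yaml')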
-------------------------------------------------------------------------------- /wandb_config/awa2_gzsl.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | value: AWA2 3 | num_class: 4 | value: 50 5 | num_attribute: 6 | value: 85 7 | img_size: 8 | value: 448 9 | resnet_region: 10 | value: 196 11 | dim_f: 12 | value: 2048 13 | dim_v: 14 | value: 300 15 | device: 16 | value: cuda:0 17 | batch_size: 18 | value: 50 19 | epochs: 20 | value: 200 21 | use_unzip: 22 | value: True 23 | random_seed: 24 | value: 17 25 | lambda_: 26 | value: 2 27 | lambda_reg: 28 | value: 0.0005 29 | normalize_V: 30 | value: True 31 | tf_SAtt: 32 | value: True 33 | tf_ec_layer: 34 | value: 1 35 | tf_dc_layer: 36 | value: 1 37 | tf_heads: 38 | value: 1 39 | tf_common_dim: 40 | value: 300 41 | tf_aux_embed: 42 | value: True 43 | tf_dim_feedforward: 44 | value: 512 45 | tf_dropout: 46 | value: 0.5 -------------------------------------------------------------------------------- /wandb_config/cub_czsl.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | value: CUB 3 | num_class: 4 | value: 200 5 | num_attribute: 6 | value: 312 7 | img_size: 8 | value: 448 9 | resnet_region: 10 | value: 196 11 | dim_f: 12 | value: 2048 13 | dim_v: 14 | value: 300 15 | device: 16 | value: cuda:0 17 | batch_size: 18 | value: 50 19 | epochs: 20 | value: 200 21 | use_unzip: 22 | value: True 23 | random_seed: 24 | value: 5 25 | lambda_: 26 | value: 0.3 27 | lambda_reg: 28 | value: 0.005 29 | normalize_V: 30 | value: False 31 | tf_SAtt: 32 | value: True 33 | tf_ec_layer: 34 | value: 1 35 | tf_dc_layer: 36 | value: 1 37 | tf_heads: 38 | value: 1 39 | tf_common_dim: 40 | value: 300 41 | tf_aux_embed: 42 | value: True 43 | tf_dim_feedforward: 44 | value: 512 45 | tf_dropout: 46 | value: 0.4 -------------------------------------------------------------------------------- /wandb_config/cub_gzsl.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | value: CUB 3 | num_class: 4 | value: 200 5 | num_attribute: 6 | value: 312 7 | img_size: 8 | value: 448 9 | resnet_region: 10 | value: 196 11 | dim_f: 12 | value: 2048 13 | dim_v: 14 | value: 300 15 | device: 16 | value: cuda:0 17 | batch_size: 18 | value: 50 19 | epochs: 20 | value: 200 21 | use_unzip: 22 | value: True 23 | random_seed: 24 | value: 5 25 | lambda_: 26 | value: 0.3 27 | lambda_reg: 28 | value: 0.005 29 | normalize_V: 30 | value: False 31 | tf_SAtt: 32 | value: True 33 | tf_ec_layer: 34 | value: 1 35 | tf_dc_layer: 36 | value: 1 37 | tf_heads: 38 | value: 1 39 | tf_common_dim: 40 | value: 300 41 | tf_aux_embed: 42 | value: True 43 | tf_dim_feedforward: 44 | value: 512 45 | tf_dropout: 46 | value: 0.4 -------------------------------------------------------------------------------- /wandb_config/sun_czsl.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | value: SUN 3 | num_class: 4 | value: 717 5 | num_attribute: 6 | value: 102 7 | img_size: 8 | value: 448 9 | resnet_region: 10 | value: 196 11 | dim_f: 12 | value: 2048 13 | dim_v: 14 | value: 300 15 | device: 16 | value: cuda:0 17 | batch_size: 18 | value: 50 19 | epochs: 20 | value: 400 21 | use_unzip: 22 | value: True 23 | random_seed: 24 | value: 6 25 | lambda_: 26 | value: 0.25 27 | lambda_reg: 28 | value: 0.005 29 | normalize_V: 30 | value: False 31 | tf_SAtt: 32 | value: False 33 | tf_ec_layer: 34 | value: 1 35 | tf_dc_layer: 36 | 
value: 1 37 | tf_heads: 38 | value: 1 39 | tf_common_dim: 40 | value: 128 41 | tf_aux_embed: 42 | value: False 43 | tf_dim_feedforward: 44 | value: 2048 45 | tf_dropout: 46 | value: 0.3 -------------------------------------------------------------------------------- /wandb_config/sun_gzsl.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | value: SUN 3 | num_class: 4 | value: 717 5 | num_attribute: 6 | value: 102 7 | img_size: 8 | value: 448 9 | resnet_region: 10 | value: 196 11 | dim_f: 12 | value: 2048 13 | dim_v: 14 | value: 300 15 | device: 16 | value: cuda:0 17 | batch_size: 18 | value: 50 19 | epochs: 20 | value: 400 21 | use_unzip: 22 | value: True 23 | random_seed: 24 | value: 6 25 | lambda_: 26 | value: 0.25 27 | lambda_reg: 28 | value: 0.001 29 | normalize_V: 30 | value: False 31 | tf_SAtt: 32 | value: False 33 | tf_ec_layer: 34 | value: 1 35 | tf_dc_layer: 36 | value: 1 37 | tf_heads: 38 | value: 1 39 | tf_common_dim: 40 | value: 128 41 | tf_aux_embed: 42 | value: False 43 | tf_dim_feedforward: 44 | value: 2048 45 | tf_dropout: 46 | value: 0.3 --------------------------------------------------------------------------------
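The training scripts above read these `wandb_config/*.yaml` files through `wandb.init(config=...)`. For quick offline inspection of a setting (without a W&B account or local server), the sketch below flattens the `key: {value: ...}` layout into plain attribute access. It is a minimal helper under stated assumptions, not part of the repository: the file name `offline_config.py` and the function `load_wandb_yaml` are hypothetical, and it relies on PyYAML, which is not listed in `requirements.txt`.

```
# offline_config.py -- hypothetical helper, not part of this repository.
import yaml                       # assumption: PyYAML is installed (pip install pyyaml)
from types import SimpleNamespace


def load_wandb_yaml(path):
    """Flatten a W&B-style config file ({key: {value: v}}) into a namespace."""
    with open(path) as f:
        raw = yaml.safe_load(f)
    return SimpleNamespace(**{key: entry['value'] for key, entry in raw.items()})


if __name__ == '__main__':
    # Example: inspect the SUN GZSL setting shown above.
    config = load_wandb_yaml('wandb_config/sun_gzsl.yaml')
    print(config.dataset, config.num_class, config.num_attribute)  # SUN 717 102
```

In principle, such a namespace could stand in for `wandb.config` in the training loops, since they only read plain attributes from it (`config.device`, `config.random_seed`, `config.batch_size`, `config.epochs`, and the `tf_*` model settings), although the `wandb.log(...)` calls would then need to be removed or guarded.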