├── AwA1_RN.py ├── AwA2_RN.py ├── CUB_RN.py └── README.md /AwA1_RN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.optim.lr_scheduler import StepLR 6 | from torch.utils.data import DataLoader,TensorDataset 7 | import numpy as np 8 | import scipy.io as sio 9 | import math 10 | import argparse 11 | import random 12 | import os 13 | from sklearn.metrics import accuracy_score 14 | import pdb 15 | 16 | # '/home/lz/Workspace/ZSL/data/Animals_with_Attributes2', 17 | 18 | parser = argparse.ArgumentParser(description="Zero Shot Learning") 19 | parser.add_argument("-b","--batch_size",type = int, default = 32) 20 | parser.add_argument("-e","--episode",type = int, default= 500000) 21 | parser.add_argument("-t","--test_episode", type = int, default = 1000) 22 | parser.add_argument("-l","--learning_rate", type = float, default = 1e-5) 23 | parser.add_argument("-g","--gpu",type=int, default=0) 24 | args = parser.parse_args() 25 | 26 | 27 | # Hyper Parameters 28 | 29 | BATCH_SIZE = args.batch_size 30 | EPISODE = args.episode 31 | TEST_EPISODE = args.test_episode 32 | LEARNING_RATE = args.learning_rate 33 | GPU = args.gpu 34 | 35 | class AttributeNetwork(nn.Module): 36 | """docstring for RelationNetwork""" 37 | def __init__(self,input_size,hidden_size,output_size): 38 | super(AttributeNetwork, self).__init__() 39 | self.fc1 = nn.Linear(input_size,hidden_size) 40 | self.fc2 = nn.Linear(hidden_size,output_size) 41 | 42 | def forward(self,x): 43 | 44 | x = F.relu(self.fc1(x)) 45 | x = F.relu(self.fc2(x)) 46 | 47 | return x 48 | 49 | class RelationNetwork(nn.Module): 50 | """docstring for RelationNetwork""" 51 | def __init__(self,input_size,hidden_size,): 52 | super(RelationNetwork, self).__init__() 53 | self.fc1 = nn.Linear(input_size,hidden_size) 54 | self.fc2 = nn.Linear(hidden_size,1) 55 | 56 | def forward(self,x): 57 | 58 | x = F.relu(self.fc1(x)) 59 | x = F.sigmoid(self.fc2(x)) 60 | return x 61 | 62 | 63 | 64 | def main(): 65 | # step 1: init dataset 66 | print("init dataset") 67 | 68 | dataroot = './data' 69 | dataset = 'AwA1_data' 70 | image_embedding = 'res101' 71 | class_embedding = 'original_att' 72 | 73 | matcontent = sio.loadmat(dataroot + "/" + dataset + "/" + image_embedding + ".mat") 74 | feature = matcontent['features'].T 75 | label = matcontent['labels'].astype(int).squeeze() - 1 76 | matcontent = sio.loadmat(dataroot + "/" + dataset + "/" + class_embedding + "_splits.mat") 77 | # numpy array index starts from 0, matlab starts from 1 78 | trainval_loc = matcontent['trainval_loc'].squeeze() - 1 79 | test_seen_loc = matcontent['test_seen_loc'].squeeze() - 1 80 | test_unseen_loc = matcontent['test_unseen_loc'].squeeze() - 1 81 | 82 | attribute = matcontent['att'].T 83 | 84 | x = feature[trainval_loc] # train_features 85 | train_label = label[trainval_loc].astype(int) # train_label 86 | att = attribute[train_label] # train attributes 87 | 88 | x_test = feature[test_unseen_loc] # test_feature 89 | test_label = label[test_unseen_loc].astype(int) # test_label 90 | x_test_seen = feature[test_seen_loc] #test_seen_feature 91 | test_label_seen = label[test_seen_loc].astype(int) # test_seen_label 92 | test_id = np.unique(test_label) # test_id 93 | att_pro = attribute[test_id] # test_attribute 94 | 95 | 96 | # train set 97 | train_features=torch.from_numpy(x) 98 | print(train_features.shape) 99 | 100 | 
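# note on shapes: the res101.mat features are 2048-d ResNet-101 vectors, the AwA1
# attribute matrix is 50 x 85 (one 85-d vector per class), and the GBU split indices
# select the trainval (seen), test_seen and test_unseen images from those arrays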
train_label=torch.from_numpy(train_label).unsqueeze(1) 101 | print(train_label.shape) 102 | 103 | # attributes 104 | all_attributes=np.array(attribute) 105 | print(all_attributes.shape) 106 | 107 | attributes = torch.from_numpy(attribute) 108 | # test set 109 | 110 | test_features=torch.from_numpy(x_test) 111 | print(test_features.shape) 112 | 113 | test_label=torch.from_numpy(test_label).unsqueeze(1) 114 | print(test_label.shape) 115 | 116 | testclasses_id = np.array(test_id) 117 | print(testclasses_id.shape) 118 | 119 | test_attributes = torch.from_numpy(att_pro).float() 120 | print(test_attributes.shape) 121 | 122 | 123 | test_seen_features = torch.from_numpy(x_test_seen) 124 | print(test_seen_features.shape) 125 | 126 | test_seen_label = torch.from_numpy(test_label_seen) 127 | 128 | 129 | 130 | train_data = TensorDataset(train_features,train_label) 131 | 132 | 133 | # init network 134 | print("init networks") 135 | attribute_network = AttributeNetwork(85,1024,2048) 136 | relation_network = RelationNetwork(4096,400) 137 | 138 | attribute_network.cuda(GPU) 139 | relation_network.cuda(GPU) 140 | 141 | attribute_network_optim = torch.optim.Adam(attribute_network.parameters(),lr=LEARNING_RATE,weight_decay=1e-5) 142 | attribute_network_scheduler = StepLR(attribute_network_optim,step_size=200000,gamma=0.5) 143 | relation_network_optim = torch.optim.Adam(relation_network.parameters(),lr=LEARNING_RATE) 144 | relation_network_scheduler = StepLR(relation_network_optim,step_size=200000,gamma=0.5) 145 | 146 | 147 | print("training...") 148 | last_accuracy = 0.0 149 | 150 | for episode in range(EPISODE): 151 | attribute_network_scheduler.step(episode) 152 | relation_network_scheduler.step(episode) 153 | 154 | train_loader = DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True) 155 | 156 | batch_features,batch_labels = train_loader.__iter__().next() 157 | 158 | sample_labels = [] 159 | for label in batch_labels.numpy(): 160 | if label not in sample_labels: 161 | sample_labels.append(label) 162 | # pdb.set_trace() 163 | 164 | sample_attributes = torch.Tensor([all_attributes[i] for i in sample_labels]).squeeze(1) 165 | class_num = sample_attributes.shape[0] 166 | 167 | batch_features = Variable(batch_features).cuda(GPU).float() # 32*1024 168 | sample_features = attribute_network(Variable(sample_attributes).cuda(GPU)) #k*312 169 | 170 | 171 | sample_features_ext = sample_features.unsqueeze(0).repeat(BATCH_SIZE,1,1) 172 | batch_features_ext = batch_features.unsqueeze(0).repeat(class_num,1,1) 173 | batch_features_ext = torch.transpose(batch_features_ext,0,1) 174 | 175 | #print(sample_features_ext) 176 | #print(batch_features_ext) 177 | relation_pairs = torch.cat((sample_features_ext,batch_features_ext),2).view(-1,4096) 178 | # pdb.set_trace() 179 | relations = relation_network(relation_pairs).view(-1,class_num) 180 | #print(relations) 181 | 182 | # re-build batch_labels according to sample_labels 183 | sample_labels = np.array(sample_labels) 184 | re_batch_labels = [] 185 | for label in batch_labels.numpy(): 186 | index = np.argwhere(sample_labels==label) 187 | re_batch_labels.append(index[0][0]) 188 | re_batch_labels = torch.LongTensor(re_batch_labels) 189 | # pdb.set_trace() 190 | 191 | 192 | # loss 193 | mse = nn.MSELoss().cuda(GPU) 194 | one_hot_labels = Variable(torch.zeros(BATCH_SIZE, class_num).scatter_(1, re_batch_labels.view(-1,1), 1)).cuda(GPU) 195 | loss = mse(relations,one_hot_labels) 196 | # pdb.set_trace() 197 | 198 | # update 199 | attribute_network.zero_grad() 200 | 
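# joint update for this episode: gradients are cleared on both sub-networks, the MSE
# loss between relation scores and one-hot class targets is backpropagated, and both
# Adam optimizers take a step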
relation_network.zero_grad() 201 | 202 | loss.backward() 203 | 204 | attribute_network_optim.step() 205 | relation_network_optim.step() 206 | 207 | if (episode+1)%100 == 0: 208 | print("episode:",episode+1,"loss",loss.data[0]) 209 | 210 | if (episode+1)%2000 == 0: 211 | # test 212 | print("Testing...") 213 | 214 | def compute_accuracy(test_features,test_label,test_id,test_attributes): 215 | 216 | test_data = TensorDataset(test_features,test_label) 217 | test_batch = 32 218 | test_loader = DataLoader(test_data,batch_size=test_batch,shuffle=False) 219 | total_rewards = 0 220 | # fetch attributes 221 | # pdb.set_trace() 222 | 223 | sample_labels = test_id 224 | sample_attributes = test_attributes 225 | class_num = sample_attributes.shape[0] 226 | test_size = test_features.shape[0] 227 | 228 | print("class num:",class_num) 229 | predict_labels_total = [] 230 | re_batch_labels_total = [] 231 | 232 | for batch_features,batch_labels in test_loader: 233 | 234 | batch_size = batch_labels.shape[0] 235 | 236 | batch_features = Variable(batch_features).cuda(GPU).float() # 32*1024 237 | sample_features = attribute_network(Variable(sample_attributes).cuda(GPU).float()) 238 | 239 | sample_features_ext = sample_features.unsqueeze(0).repeat(batch_size,1,1) 240 | batch_features_ext = batch_features.unsqueeze(0).repeat(class_num,1,1) 241 | batch_features_ext = torch.transpose(batch_features_ext,0,1) 242 | 243 | relation_pairs = torch.cat((sample_features_ext,batch_features_ext),2).view(-1,4096) 244 | relations = relation_network(relation_pairs).view(-1,class_num) 245 | 246 | # re-build batch_labels according to sample_labels 247 | 248 | re_batch_labels = [] 249 | for label in batch_labels.numpy(): 250 | index = np.argwhere(sample_labels==label) 251 | re_batch_labels.append(index[0][0]) 252 | re_batch_labels = torch.LongTensor(re_batch_labels) 253 | # pdb.set_trace() 254 | 255 | 256 | _,predict_labels = torch.max(relations.data,1) 257 | predict_labels = predict_labels.cpu().numpy() 258 | re_batch_labels = re_batch_labels.cpu().numpy() 259 | 260 | predict_labels_total = np.append(predict_labels_total, predict_labels) 261 | re_batch_labels_total = np.append(re_batch_labels_total, re_batch_labels) 262 | 263 | # compute averaged per class accuracy 264 | predict_labels_total = np.array(predict_labels_total, dtype='int') 265 | re_batch_labels_total = np.array(re_batch_labels_total, dtype='int') 266 | unique_labels = np.unique(re_batch_labels_total) 267 | acc = 0 268 | for l in unique_labels: 269 | idx = np.nonzero(re_batch_labels_total == l)[0] 270 | acc += accuracy_score(re_batch_labels_total[idx], predict_labels_total[idx]) 271 | acc = acc / unique_labels.shape[0] 272 | return acc 273 | 274 | zsl_accuracy = compute_accuracy(test_features,test_label,test_id,test_attributes) 275 | gzsl_unseen_accuracy = compute_accuracy(test_features,test_label,np.arange(50),attributes) 276 | gzsl_seen_accuracy = compute_accuracy(test_seen_features,test_seen_label,np.arange(50),attributes) 277 | 278 | H = 2 * gzsl_seen_accuracy * gzsl_unseen_accuracy / (gzsl_unseen_accuracy + gzsl_seen_accuracy) 279 | 280 | print('zsl:', zsl_accuracy) 281 | print('gzsl: seen=%.4f, unseen=%.4f, h=%.4f' % (gzsl_seen_accuracy, gzsl_unseen_accuracy, H)) 282 | 283 | 284 | if zsl_accuracy > last_accuracy: 285 | 286 | # save networks 287 | torch.save(attribute_network.state_dict(),"./models/zsl_awa1_attribute_network_v33.pkl") 288 | torch.save(relation_network.state_dict(),"./models/zsl_awa1_relation_network_v33.pkl") 289 | 290 | print("save networks for 
episode:",episode) 291 | 292 | last_accuracy = zsl_accuracy 293 | 294 | 295 | 296 | if __name__ == '__main__': 297 | main() -------------------------------------------------------------------------------- /AwA2_RN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.optim.lr_scheduler import StepLR 6 | from torch.utils.data import DataLoader,TensorDataset 7 | import numpy as np 8 | import scipy.io as sio 9 | import math 10 | import argparse 11 | import random 12 | import os 13 | from sklearn.metrics import accuracy_score 14 | import pdb 15 | 16 | # '/home/lz/Workspace/ZSL/data/Animals_with_Attributes2', 17 | 18 | parser = argparse.ArgumentParser(description="Zero Shot Learning") 19 | parser.add_argument("-b","--batch_size",type = int, default = 32) 20 | parser.add_argument("-e","--episode",type = int, default= 500000) 21 | parser.add_argument("-t","--test_episode", type = int, default = 1000) 22 | parser.add_argument("-l","--learning_rate", type = float, default = 1e-5) 23 | parser.add_argument("-g","--gpu",type=int, default=0) 24 | args = parser.parse_args() 25 | 26 | 27 | # Hyper Parameters 28 | 29 | BATCH_SIZE = args.batch_size 30 | EPISODE = args.episode 31 | TEST_EPISODE = args.test_episode 32 | LEARNING_RATE = args.learning_rate 33 | GPU = args.gpu 34 | 35 | class AttributeNetwork(nn.Module): 36 | """docstring for RelationNetwork""" 37 | def __init__(self,input_size,hidden_size,output_size): 38 | super(AttributeNetwork, self).__init__() 39 | self.fc1 = nn.Linear(input_size,hidden_size) 40 | self.fc2 = nn.Linear(hidden_size,output_size) 41 | 42 | def forward(self,x): 43 | 44 | x = F.relu(self.fc1(x)) 45 | x = F.relu(self.fc2(x)) 46 | 47 | return x 48 | 49 | class RelationNetwork(nn.Module): 50 | """docstring for RelationNetwork""" 51 | def __init__(self,input_size,hidden_size,): 52 | super(RelationNetwork, self).__init__() 53 | self.fc1 = nn.Linear(input_size,hidden_size) 54 | self.fc2 = nn.Linear(hidden_size,1) 55 | 56 | def forward(self,x): 57 | 58 | x = F.relu(self.fc1(x)) 59 | x = F.sigmoid(self.fc2(x)) 60 | return x 61 | 62 | 63 | 64 | def main(): 65 | # step 1: init dataset 66 | print("init dataset") 67 | 68 | dataroot = './data' 69 | dataset = 'AwA2_data' 70 | image_embedding = 'res101' 71 | class_embedding = 'att' 72 | 73 | matcontent = sio.loadmat(dataroot + "/" + dataset + "/" + image_embedding + ".mat") 74 | feature = matcontent['features'].T 75 | label = matcontent['labels'].astype(int).squeeze() - 1 76 | matcontent = sio.loadmat(dataroot + "/" + dataset + "/" + class_embedding + "_splits.mat") 77 | # numpy array index starts from 0, matlab starts from 1 78 | trainval_loc = matcontent['trainval_loc'].squeeze() - 1 79 | test_seen_loc = matcontent['test_seen_loc'].squeeze() - 1 80 | test_unseen_loc = matcontent['test_unseen_loc'].squeeze() - 1 81 | 82 | attribute = matcontent['original_att'].T 83 | 84 | x = feature[trainval_loc] # train_features 85 | train_label = label[trainval_loc].astype(int) # train_label 86 | att = attribute[train_label] # train attributes 87 | 88 | x_test = feature[test_unseen_loc] # test_feature 89 | test_label = label[test_unseen_loc].astype(int) # test_label 90 | x_test_seen = feature[test_seen_loc] #test_seen_feature 91 | test_label_seen = label[test_seen_loc].astype(int) # test_seen_label 92 | test_id = np.unique(test_label) # test_id 93 | att_pro = attribute[test_id] # test_attribute 
94 | 95 | 96 | # train set 97 | train_features=torch.from_numpy(x) 98 | print(train_features.shape) 99 | 100 | train_label=torch.from_numpy(train_label).unsqueeze(1) 101 | print(train_label.shape) 102 | 103 | # attributes 104 | all_attributes=np.array(attribute) 105 | print(all_attributes.shape) 106 | 107 | attributes = torch.from_numpy(attribute) 108 | # test set 109 | 110 | test_features=torch.from_numpy(x_test) 111 | print(test_features.shape) 112 | 113 | test_label=torch.from_numpy(test_label).unsqueeze(1) 114 | print(test_label.shape) 115 | 116 | testclasses_id = np.array(test_id) 117 | print(testclasses_id.shape) 118 | 119 | test_attributes = torch.from_numpy(att_pro).float() 120 | print(test_attributes.shape) 121 | 122 | 123 | test_seen_features = torch.from_numpy(x_test_seen) 124 | print(test_seen_features.shape) 125 | 126 | test_seen_label = torch.from_numpy(test_label_seen) 127 | 128 | 129 | 130 | train_data = TensorDataset(train_features,train_label) 131 | 132 | 133 | # init network 134 | print("init networks") 135 | attribute_network = AttributeNetwork(85,1024,2048) 136 | relation_network = RelationNetwork(4096,400) 137 | 138 | attribute_network.cuda(GPU) 139 | relation_network.cuda(GPU) 140 | 141 | attribute_network_optim = torch.optim.Adam(attribute_network.parameters(),lr=LEARNING_RATE,weight_decay=1e-5) 142 | attribute_network_scheduler = StepLR(attribute_network_optim,step_size=200000,gamma=0.5) 143 | relation_network_optim = torch.optim.Adam(relation_network.parameters(),lr=LEARNING_RATE) 144 | relation_network_scheduler = StepLR(relation_network_optim,step_size=200000,gamma=0.5) 145 | 146 | # if os.path.exists("./models/zsl_awa2_attribute_network_v33.pkl"): 147 | # attribute_network.load_state_dict(torch.load("./models/zsl_awa2_attribute_network_v33.pkl")) 148 | # print("load attribute network success") 149 | # if os.path.exists("./models/zsl_awa2_relation_network_v33.pkl"): 150 | # relation_network.load_state_dict(torch.load("./models/zsl_awa2_relation_network_v33.pkl")) 151 | # print("load relation network success") 152 | 153 | print("training...") 154 | last_accuracy = 0.0 155 | 156 | for episode in range(EPISODE): 157 | attribute_network_scheduler.step(episode) 158 | relation_network_scheduler.step(episode) 159 | 160 | train_loader = DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True) 161 | 162 | batch_features,batch_labels = train_loader.__iter__().next() 163 | 164 | sample_labels = [] 165 | for label in batch_labels.numpy(): 166 | if label not in sample_labels: 167 | sample_labels.append(label) 168 | # pdb.set_trace() 169 | 170 | sample_attributes = torch.Tensor([all_attributes[i] for i in sample_labels]).squeeze(1) 171 | class_num = sample_attributes.shape[0] 172 | 173 | batch_features = Variable(batch_features).cuda(GPU).float() # 32*1024 174 | sample_features = attribute_network(Variable(sample_attributes).cuda(GPU)) #k*312 175 | 176 | 177 | sample_features_ext = sample_features.unsqueeze(0).repeat(BATCH_SIZE,1,1) 178 | batch_features_ext = batch_features.unsqueeze(0).repeat(class_num,1,1) 179 | batch_features_ext = torch.transpose(batch_features_ext,0,1) 180 | 181 | #print(sample_features_ext) 182 | #print(batch_features_ext) 183 | relation_pairs = torch.cat((sample_features_ext,batch_features_ext),2).view(-1,4096) 184 | # pdb.set_trace() 185 | relations = relation_network(relation_pairs).view(-1,class_num) 186 | #print(relations) 187 | 188 | # re-build batch_labels according to sample_labels 189 | sample_labels = np.array(sample_labels) 190 | 
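# map each original class id onto its index within sample_labels, so the targets line
# up with the columns of the (batch_size x class_num) relation score matrix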
re_batch_labels = [] 191 | for label in batch_labels.numpy(): 192 | index = np.argwhere(sample_labels==label) 193 | re_batch_labels.append(index[0][0]) 194 | re_batch_labels = torch.LongTensor(re_batch_labels) 195 | # pdb.set_trace() 196 | 197 | 198 | # loss 199 | mse = nn.MSELoss().cuda(GPU) 200 | one_hot_labels = Variable(torch.zeros(BATCH_SIZE, class_num).scatter_(1, re_batch_labels.view(-1,1), 1)).cuda(GPU) 201 | loss = mse(relations,one_hot_labels) 202 | # pdb.set_trace() 203 | 204 | # update 205 | attribute_network.zero_grad() 206 | relation_network.zero_grad() 207 | 208 | loss.backward() 209 | 210 | attribute_network_optim.step() 211 | relation_network_optim.step() 212 | 213 | if (episode+1)%100 == 0: 214 | print("episode:",episode+1,"loss",loss.data[0]) 215 | 216 | if (episode+1)%2000 == 0: 217 | # test 218 | print("Testing...") 219 | 220 | def compute_accuracy(test_features,test_label,test_id,test_attributes): 221 | 222 | test_data = TensorDataset(test_features,test_label) 223 | test_batch = 32 224 | test_loader = DataLoader(test_data,batch_size=test_batch,shuffle=False) 225 | total_rewards = 0 226 | # fetch attributes 227 | # pdb.set_trace() 228 | 229 | sample_labels = test_id 230 | sample_attributes = test_attributes 231 | class_num = sample_attributes.shape[0] 232 | test_size = test_features.shape[0] 233 | 234 | print("class num:",class_num) 235 | predict_labels_total = [] 236 | re_batch_labels_total = [] 237 | 238 | for batch_features,batch_labels in test_loader: 239 | 240 | batch_size = batch_labels.shape[0] 241 | 242 | batch_features = Variable(batch_features).cuda(GPU).float() # 32*1024 243 | sample_features = attribute_network(Variable(sample_attributes).cuda(GPU).float()) 244 | 245 | sample_features_ext = sample_features.unsqueeze(0).repeat(batch_size,1,1) 246 | batch_features_ext = batch_features.unsqueeze(0).repeat(class_num,1,1) 247 | batch_features_ext = torch.transpose(batch_features_ext,0,1) 248 | 249 | relation_pairs = torch.cat((sample_features_ext,batch_features_ext),2).view(-1,4096) 250 | relations = relation_network(relation_pairs).view(-1,class_num) 251 | 252 | # re-build batch_labels according to sample_labels 253 | 254 | re_batch_labels = [] 255 | for label in batch_labels.numpy(): 256 | index = np.argwhere(sample_labels==label) 257 | re_batch_labels.append(index[0][0]) 258 | re_batch_labels = torch.LongTensor(re_batch_labels) 259 | # pdb.set_trace() 260 | 261 | 262 | _,predict_labels = torch.max(relations.data,1) 263 | predict_labels = predict_labels.cpu().numpy() 264 | re_batch_labels = re_batch_labels.cpu().numpy() 265 | 266 | predict_labels_total = np.append(predict_labels_total, predict_labels) 267 | re_batch_labels_total = np.append(re_batch_labels_total, re_batch_labels) 268 | 269 | # compute averaged per class accuracy 270 | predict_labels_total = np.array(predict_labels_total, dtype='int') 271 | re_batch_labels_total = np.array(re_batch_labels_total, dtype='int') 272 | unique_labels = np.unique(re_batch_labels_total) 273 | acc = 0 274 | for l in unique_labels: 275 | idx = np.nonzero(re_batch_labels_total == l)[0] 276 | acc += accuracy_score(re_batch_labels_total[idx], predict_labels_total[idx]) 277 | acc = acc / unique_labels.shape[0] 278 | return acc 279 | 280 | zsl_accuracy = compute_accuracy(test_features,test_label,test_id,test_attributes) 281 | gzsl_unseen_accuracy = compute_accuracy(test_features,test_label,np.arange(50),attributes) 282 | gzsl_seen_accuracy = compute_accuracy(test_seen_features,test_seen_label,np.arange(50),attributes) 283 | 
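# H is the harmonic mean of the GZSL seen and unseen per-class accuracies, the
# standard summary metric of the GBU protocol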
284 | H = 2 * gzsl_seen_accuracy * gzsl_unseen_accuracy / (gzsl_unseen_accuracy + gzsl_seen_accuracy) 285 | 286 | print('zsl:', zsl_accuracy) 287 | print('gzsl: seen=%.4f, unseen=%.4f, h=%.4f' % (gzsl_seen_accuracy, gzsl_unseen_accuracy, H)) 288 | 289 | 290 | if zsl_accuracy > last_accuracy: 291 | 292 | # save networks 293 | torch.save(attribute_network.state_dict(),"./models/zsl_awa2_attribute_network_v33.pkl") 294 | torch.save(relation_network.state_dict(),"./models/zsl_awa2_relation_network_v33.pkl") 295 | 296 | print("save networks for episode:",episode) 297 | 298 | last_accuracy = zsl_accuracy 299 | 300 | 301 | 302 | if __name__ == '__main__': 303 | main() -------------------------------------------------------------------------------- /CUB_RN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from torch.optim.lr_scheduler import StepLR 6 | from torch.utils.data import DataLoader,TensorDataset 7 | import numpy as np 8 | import scipy.io as sio 9 | import math 10 | import argparse 11 | import random 12 | import os 13 | from sklearn.metrics import accuracy_score 14 | 15 | 16 | parser = argparse.ArgumentParser(description="Zero Shot Learning") 17 | parser.add_argument("-b","--batch_size",type = int, default = 32) 18 | parser.add_argument("-e","--episode",type = int, default= 200000) 19 | parser.add_argument("-t","--test_episode", type = int, default = 1000) 20 | parser.add_argument("-l","--learning_rate", type = float, default = 1e-5) 21 | parser.add_argument("-g","--gpu",type=int, default=2) 22 | args = parser.parse_args() 23 | 24 | 25 | # Hyper Parameters 26 | 27 | BATCH_SIZE = args.batch_size 28 | EPISODE = args.episode 29 | TEST_EPISODE = args.test_episode 30 | LEARNING_RATE = args.learning_rate 31 | GPU = args.gpu 32 | 33 | class AttributeNetwork(nn.Module): 34 | """docstring for RelationNetwork""" 35 | def __init__(self,input_size,hidden_size,output_size): 36 | super(AttributeNetwork, self).__init__() 37 | self.fc1 = nn.Linear(input_size,hidden_size) 38 | self.fc2 = nn.Linear(hidden_size,output_size) 39 | 40 | def forward(self,x): 41 | 42 | x = F.relu(self.fc1(x)) 43 | x = F.relu(self.fc2(x)) 44 | 45 | return x 46 | 47 | class RelationNetwork(nn.Module): 48 | """docstring for RelationNetwork""" 49 | def __init__(self,input_size,hidden_size,): 50 | super(RelationNetwork, self).__init__() 51 | self.fc1 = nn.Linear(input_size,hidden_size) 52 | self.fc2 = nn.Linear(hidden_size,1) 53 | 54 | def forward(self,x): 55 | 56 | x = F.relu(self.fc1(x)) 57 | x = F.sigmoid(self.fc2(x)) 58 | return x 59 | 60 | 61 | 62 | def main(): 63 | # step 1: init dataset 64 | print("init dataset") 65 | 66 | dataroot = './data' 67 | dataset = 'CUB1_data' 68 | image_embedding = 'res101' 69 | class_embedding = 'original_att_splits' 70 | 71 | matcontent = sio.loadmat(dataroot + "/" + dataset + "/" + image_embedding + ".mat") 72 | feature = matcontent['features'].T 73 | label = matcontent['labels'].astype(int).squeeze() - 1 74 | matcontent = sio.loadmat(dataroot + "/" + dataset + "/" + class_embedding + ".mat") 75 | # numpy array index starts from 0, matlab starts from 1 76 | trainval_loc = matcontent['trainval_loc'].squeeze() - 1 77 | test_seen_loc = matcontent['test_seen_loc'].squeeze() - 1 78 | test_unseen_loc = matcontent['test_unseen_loc'].squeeze() - 1 79 | 80 | attribute = matcontent['att'].T 81 | 82 | x = feature[trainval_loc] # train_features 83 | 
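# CUB uses 312-d class attributes and 200 classes (150 seen / 50 unseen under the GBU
# split); the pipeline mirrors the AwA scripts, with wider hidden layers and a shorter
# learning-rate schedule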
train_label = label[trainval_loc].astype(int) # train_label 84 | att = attribute[train_label] # train attributes 85 | 86 | x_test = feature[test_unseen_loc] # test_feature 87 | test_label = label[test_unseen_loc].astype(int) # test_label 88 | x_test_seen = feature[test_seen_loc] #test_seen_feature 89 | test_label_seen = label[test_seen_loc].astype(int) # test_seen_label 90 | test_id = np.unique(test_label) # test_id 91 | att_pro = attribute[test_id] # test_attribute 92 | 93 | 94 | # train set 95 | train_features=torch.from_numpy(x) 96 | print(train_features.shape) 97 | 98 | train_label=torch.from_numpy(train_label).unsqueeze(1) 99 | print(train_label.shape) 100 | 101 | # attributes 102 | all_attributes=np.array(attribute) 103 | print(all_attributes.shape) 104 | 105 | attributes = torch.from_numpy(attribute) 106 | # test set 107 | 108 | test_features=torch.from_numpy(x_test) 109 | print(test_features.shape) 110 | 111 | test_label=torch.from_numpy(test_label).unsqueeze(1) 112 | print(test_label.shape) 113 | 114 | testclasses_id = np.array(test_id) 115 | print(testclasses_id.shape) 116 | 117 | test_attributes = torch.from_numpy(att_pro).float() 118 | print(test_attributes.shape) 119 | 120 | 121 | test_seen_features = torch.from_numpy(x_test_seen) 122 | print(test_seen_features.shape) 123 | 124 | test_seen_label = torch.from_numpy(test_label_seen) 125 | 126 | 127 | 128 | train_data = TensorDataset(train_features,train_label) 129 | 130 | 131 | # init network 132 | print("init networks") 133 | attribute_network = AttributeNetwork(312,1200,2048) 134 | relation_network = RelationNetwork(4096,1200) 135 | 136 | attribute_network.cuda(GPU) 137 | relation_network.cuda(GPU) 138 | 139 | attribute_network_optim = torch.optim.Adam(attribute_network.parameters(),lr=LEARNING_RATE,weight_decay=1e-5) 140 | attribute_network_scheduler = StepLR(attribute_network_optim,step_size=30000,gamma=0.5) 141 | relation_network_optim = torch.optim.Adam(relation_network.parameters(),lr=LEARNING_RATE) 142 | relation_network_scheduler = StepLR(relation_network_optim,step_size=30000,gamma=0.5) 143 | 144 | # if os.path.exists("./models/zsl_cub2_attribute_network_v35.pkl"): 145 | # attribute_network.load_state_dict(torch.load("./models/zsl_cub_attribute_network_v35.pkl")) 146 | # print("load attribute network success") 147 | # if os.path.exists("./models/zsl_cub2_relation_network_v35.pkl"): 148 | # relation_network.load_state_dict(torch.load("./models/zsl_cub_relation_network_v35.pkl")) 149 | # print("load relation network success") 150 | 151 | print("training...") 152 | last_accuracy = 0.0 153 | 154 | for episode in range(EPISODE): 155 | attribute_network_scheduler.step(episode) 156 | relation_network_scheduler.step(episode) 157 | 158 | train_loader = DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True) 159 | 160 | batch_features,batch_labels = train_loader.__iter__().next() 161 | 162 | sample_labels = [] 163 | for label in batch_labels.numpy(): 164 | if label not in sample_labels: 165 | sample_labels.append(label) 166 | 167 | sample_attributes = torch.Tensor([all_attributes[i] for i in sample_labels]).squeeze(1) 168 | class_num = sample_attributes.shape[0] 169 | 170 | batch_features = Variable(batch_features).cuda(GPU).float() # 32*1024 171 | sample_features = attribute_network(Variable(sample_attributes).cuda(GPU)) #k*312 172 | 173 | sample_features_ext = sample_features.unsqueeze(0).repeat(BATCH_SIZE,1,1) 174 | batch_features_ext = batch_features.unsqueeze(0).repeat(class_num,1,1) 175 | batch_features_ext = 
torch.transpose(batch_features_ext,0,1) 176 | 177 | #print(sample_features_ext) 178 | #print(batch_features_ext) 179 | relation_pairs = torch.cat((sample_features_ext,batch_features_ext),2).view(-1,4096) 180 | relations = relation_network(relation_pairs).view(-1,class_num) 181 | #print(relations) 182 | 183 | # re-build batch_labels according to sample_labels 184 | sample_labels = np.array(sample_labels) 185 | re_batch_labels = [] 186 | for label in batch_labels.numpy(): 187 | index = np.argwhere(sample_labels==label) 188 | re_batch_labels.append(index[0][0]) 189 | re_batch_labels = torch.LongTensor(re_batch_labels) 190 | 191 | # loss 192 | mse = nn.MSELoss().cuda(GPU) 193 | one_hot_labels = Variable(torch.zeros(BATCH_SIZE, class_num).scatter_(1, re_batch_labels.view(-1,1), 1)).cuda(GPU) 194 | loss = mse(relations,one_hot_labels) 195 | 196 | # update 197 | attribute_network.zero_grad() 198 | relation_network.zero_grad() 199 | 200 | loss.backward() 201 | 202 | attribute_network_optim.step() 203 | relation_network_optim.step() 204 | 205 | if (episode+1)%100 == 0: 206 | print("episode:",episode+1,"loss",loss.data[0]) 207 | 208 | if (episode+1)%2000 == 0: 209 | # test 210 | print("Testing...") 211 | 212 | def compute_accuracy(test_features,test_label,test_id,test_attributes): 213 | 214 | test_data = TensorDataset(test_features,test_label) 215 | test_batch = 32 216 | test_loader = DataLoader(test_data,batch_size=test_batch,shuffle=False) 217 | total_rewards = 0 218 | # fetch attributes 219 | sample_labels = test_id 220 | sample_attributes = test_attributes 221 | class_num = sample_attributes.shape[0] 222 | test_size = test_features.shape[0] 223 | 224 | print("class num:",class_num) 225 | predict_labels_total = [] 226 | re_batch_labels_total = [] 227 | 228 | for batch_features,batch_labels in test_loader: 229 | 230 | batch_size = batch_labels.shape[0] 231 | 232 | batch_features = Variable(batch_features).cuda(GPU).float() # 32*1024 233 | sample_features = attribute_network(Variable(sample_attributes).cuda(GPU).float()) 234 | 235 | sample_features_ext = sample_features.unsqueeze(0).repeat(batch_size,1,1) 236 | batch_features_ext = batch_features.unsqueeze(0).repeat(class_num,1,1) 237 | batch_features_ext = torch.transpose(batch_features_ext,0,1) 238 | 239 | relation_pairs = torch.cat((sample_features_ext,batch_features_ext),2).view(-1,4096) 240 | relations = relation_network(relation_pairs).view(-1,class_num) 241 | 242 | # re-build batch_labels according to sample_labels 243 | 244 | re_batch_labels = [] 245 | for label in batch_labels.numpy(): 246 | index = np.argwhere(sample_labels==label) 247 | re_batch_labels.append(index[0][0]) 248 | re_batch_labels = torch.LongTensor(re_batch_labels) 249 | 250 | _,predict_labels = torch.max(relations.data,1) 251 | predict_labels = predict_labels.cpu().numpy() 252 | re_batch_labels = re_batch_labels.cpu().numpy() 253 | 254 | predict_labels_total = np.append(predict_labels_total, predict_labels) 255 | re_batch_labels_total = np.append(re_batch_labels_total, re_batch_labels) 256 | 257 | # compute averaged per class accuracy 258 | predict_labels_total = np.array(predict_labels_total, dtype='int') 259 | re_batch_labels_total = np.array(re_batch_labels_total, dtype='int') 260 | unique_labels = np.unique(re_batch_labels_total) 261 | acc = 0 262 | for l in unique_labels: 263 | idx = np.nonzero(re_batch_labels_total == l)[0] 264 | acc += accuracy_score(re_batch_labels_total[idx], predict_labels_total[idx]) 265 | acc = acc / unique_labels.shape[0] 266 | return acc 267 
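# ZSL accuracy searches over the 50 unseen classes only; the two GZSL calls score the
# unseen and seen test images against all 200 CUB classes before taking the harmonic mean H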
| 268 | zsl_accuracy = compute_accuracy(test_features,test_label,test_id,test_attributes) 269 | gzsl_unseen_accuracy = compute_accuracy(test_features,test_label,np.arange(200),attributes) 270 | gzsl_seen_accuracy = compute_accuracy(test_seen_features,test_seen_label,np.arange(200),attributes) 271 | 272 | H = 2 * gzsl_seen_accuracy * gzsl_unseen_accuracy / (gzsl_unseen_accuracy + gzsl_seen_accuracy) 273 | 274 | print('zsl:', zsl_accuracy) 275 | print('gzsl: seen=%.4f, unseen=%.4f, h=%.4f' % (gzsl_seen_accuracy, gzsl_unseen_accuracy, H)) 276 | 277 | if zsl_accuracy > last_accuracy: 278 | 279 | 280 | # save networks 281 | torch.save(attribute_network.state_dict(),"./models/zsl_cub_attribute_network_v35.pkl") 282 | torch.save(relation_network.state_dict(),"./models/zsl_cub_relation_network_v35.pkl") 283 | 284 | print("save networks for episode:",episode) 285 | 286 | last_accuracy = zsl_accuracy 287 | 288 | 289 | 290 | 291 | if __name__ == '__main__': 292 | main() 293 | --------------------------------------------------------------------------------
/README.md: --------------------------------------------------------------------------------

# LearningToCompare_ZSL

PyTorch code for the CVPR 2018 paper [Learning to Compare: Relation Network for Few-Shot Learning](https://arxiv.org/abs/1711.06025) (Zero-Shot Learning part).

For the Few-Shot Learning part, please visit [here](https://github.com/songrotek/LearningToCompare_FSL).

# Requirements

Python 2.7

PyTorch 0.3

# Data

Download the data from [here](http://www.robots.ox.ac.uk/~lz/DEM_cvpr2017/data.zip) and unzip it with `unzip data.zip`.

# Run

ZSL and GZSL performance is evaluated under the GBU setting [1]: ResNet-101 features, the GBU splits, and averaged per-class accuracy.

`AwA1_RN.py` reports ZSL and GZSL performance on AwA1 with attributes under the GBU setting [1].

`AwA2_RN.py` reports ZSL and GZSL performance on AwA2 with attributes under the GBU setting [1].

`CUB_RN.py` reports ZSL and GZSL performance on CUB with attributes under the GBU setting [1].
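In the tables below, T1 is the ZSL top-1 accuracy over the unseen classes only, and u, s, H are the GZSL unseen accuracy, seen accuracy and their harmonic mean; all numbers are averaged per-class accuracies. The sketch below condenses how the scripts compute these numbers (per-class averaging inside `compute_accuracy`, harmonic mean afterwards); the helper names `per_class_accuracy` and `harmonic_mean` are illustrative and do not appear in the code.

```python
import numpy as np

def per_class_accuracy(y_true, y_pred):
    # average of the per-class accuracies, as required by the GBU protocol [1]
    classes = np.unique(y_true)
    return float(np.mean([np.mean(y_pred[y_true == c] == c) for c in classes]))

def harmonic_mean(seen_acc, unseen_acc):
    # the H column in the GZSL results below
    return 2.0 * seen_acc * unseen_acc / (seen_acc + unseen_acc)
```

Each script takes its hyper-parameters from the command-line flags defined at the top of the file (e.g. `-g` for the GPU index).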
| Model | AwA1 T1 | u | s | H | CUB T1 | u | s | H |
|------------|---------|---------|---------|---------|---------|---------|---------|---------|
| DAP [2] | 44.1 | 0.0 | 88.7 | 0.0 | 40.0 | 1.7 | 67.9 | 3.3 |
| CONSE [3] | 45.6 | 0.4 | 88.6 | 0.8 | 34.3 | 1.6 | **72.2** | 3.1 |
| SSE [4] | 60.1 | 7.0 | 80.5 | 12.9 | 43.9 | 8.5 | 46.9 | 14.4 |
| DEVISE [5] | 54.2 | 13.4 | 68.7 | 22.4 | 52.0 | 23.8 | 53.0 | 32.8 |
| SJE [6] | 65.6 | 11.3 | 74.6 | 19.6 | 53.9 | 23.5 | 59.2 | 33.6 |
| LATEM [7] | 55.1 | 7.3 | 71.7 | 13.3 | 49.3 | 15.2 | 57.3 | 24.0 |
| ESZSL [8] | 58.2 | 6.6 | 75.6 | 12.1 | 53.9 | 12.6 | 63.8 | 21.0 |
| ALE [9] | 59.9 | 16.8 | 76.1 | 27.5 | 54.9 | 23.7 | 62.8 | 34.4 |
| SYNC [10] | 54.0 | 8.9 | 87.3 | 16.2 | 55.6 | 11.5 | 70.9 | 19.8 |
| SAE [11] | 53.0 | 1.8 | 77.1 | 3.5 | 33.3 | 7.8 | 54.0 | 13.6 |
| [DEM](https://github.com/lzrobots/DeepEmbeddingModel_ZSL) [12] | **68.4** | **32.8** | 84.7 | **47.3** | 51.7 | 19.6 | 57.9 | 29.2 |
| **RN (OURS)** | 68.2 | 31.4 | **91.3** | 46.7 | **55.6** | **38.1** | 61.4 | **47.0** |


| Model | AwA2 T1 | u | s | H |
|------------|---------|---------|---------|---------|
| DAP [2] | 46.1 | 0.0 | 84.7 | 0.0 |
| CONSE [3] | 44.5 | 0.5 | 90.6 | 1.0 |
| SSE [4] | 61.0 | 8.1 | 82.5 | 14.8 |
| DEVISE [5] | 59.7 | 17.1 | 74.7 | 27.8 |
| SJE [6] | 61.9 | 8.0 | 73.9 | 14.4 |
| LATEM [7] | 55.8 | 11.5 | 77.3 | 20.0 |
| ESZSL [8] | 58.6 | 5.9 | 77.8 | 11.0 |
| ALE [9] | 62.5 | 14.0 | 81.8 | 23.9 |
| SYNC [10] | 46.6 | 10.0 | 90.5 | 18.0 |
| SAE [11] | 54.1 | 1.1 | 82.2 | 2.2 |
| [DEM](https://github.com/lzrobots/DeepEmbeddingModel_ZSL) [12] | **67.1** | **30.5** | 86.4 | 45.1 |
| **RN (OURS)** | 64.2 | 30.0 | **93.4** | **45.3** |


## Citing

If you use this code in your research, please use the following BibTeX entry.

```
@inproceedings{sung2018learning,
  title={Learning to Compare: Relation Network for Few-Shot Learning},
  author={Sung, Flood and Yang, Yongxin and Zhang, Li and Xiang, Tao and Torr, Philip HS and Hospedales, Timothy M},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2018}
}
```

## References

- [1] [Zero-Shot Learning - A Comprehensive Evaluation of the Good, the Bad and the Ugly](https://arxiv.org/abs/1707.00600).
Yongqin Xian, Christoph H. Lampert, Bernt Schiele, Zeynep Akata.
arXiv, 2017.
- [2] [Attribute-Based Classification for Zero-Shot Visual Object Categorization](https://cvml.ist.ac.at/papers/lampert-pami2013.pdf).
Christoph H. Lampert, Hannes Nickisch and Stefan Harmeling.
PAMI, 2014.
- [3] [Zero-Shot Learning by Convex Combination of Semantic Embeddings](https://arxiv.org/abs/1312.5650).
Mohammad Norouzi, Tomas Mikolov, Samy Bengio, Yoram Singer, Jonathon Shlens, Andrea Frome, Greg S. Corrado, Jeffrey Dean.
arXiv, 2013.
- [4] [Zero-Shot Learning via Semantic Similarity Embedding](https://arxiv.org/abs/1509.04767).
Ziming Zhang, Venkatesh Saligrama.
ICCV, 2015.
- [5] [DeViSE: A Deep Visual-Semantic Embedding Model](http://papers.nips.cc/paper/5204-devise-a-deep-visual-semantic-embedding-model.pdf).
Andrea Frome*, Greg S. Corrado*, Jonathon Shlens*, Samy Bengio, Jeffrey Dean, Marc’Aurelio Ranzato, Tomas Mikolov.
NIPS, 2013.
- [6] [Evaluation of Output Embeddings for Fine-Grained Image Classification](https://arxiv.org/abs/1409.8403).
Zeynep Akata, Scott Reed, Daniel Walter, Honglak Lee, Bernt Schiele.
CVPR, 2015.
- [7] [Latent Embeddings for Zero-shot Classification](https://arxiv.org/abs/1603.08895).
Yongqin Xian, Zeynep Akata, Gaurav Sharma, Quynh Nguyen, Matthias Hein, Bernt Schiele.
CVPR, 2016.
- [8] [An embarrassingly simple approach to zero-shot learning](http://proceedings.mlr.press/v37/romera-paredes15.pdf).
Bernardino Romera-Paredes, Philip H. S. Torr.
ICML, 2015.
- [9] [Label-Embedding for Image Classification](https://arxiv.org/abs/1503.08677).
Zeynep Akata, Florent Perronnin, Zaid Harchaoui, Cordelia Schmid.
PAMI, 2016.
- [10] [Synthesized Classifiers for Zero-Shot Learning](https://arxiv.org/abs/1603.00550).
Soravit Changpinyo, Wei-Lun Chao, Boqing Gong, Fei Sha.
CVPR, 2016.
- [11] [Semantic Autoencoder for Zero-Shot Learning](https://arxiv.org/abs/1704.08345).
Elyor Kodirov, Tao Xiang, Shaogang Gong.
CVPR, 2017.
- [12] [Learning a Deep Embedding Model for Zero-Shot Learning](https://arxiv.org/abs/1611.05088).
Li Zhang, Tao Xiang, Shaogang Gong.
CVPR, 2017.
--------------------------------------------------------------------------------