├── .gitignore ├── README.md ├── configs.py ├── main.py ├── resnet.py ├── test.py ├── tmp └── ms-coco-fs288 │ ├── main.py │ ├── ms_coco.py │ ├── ratio.mat │ ├── resnet.py │ └── test.py ├── utils.py └── wider.py /.gitignore: -------------------------------------------------------------------------------- 1 | flip 2 | ignored 3 | __pycache__ 4 | preds 5 | _pycache_/* 6 | ignored/* 7 | preds/* 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visual Attention Consistency 2 | This repository is for the following paper: 3 | ``` 4 | @InProceedings{Guo_2019_CVPR, 5 | author = {Guo, Hao and Zheng, Kang and Fan, Xiaochuan and Yu, Hongkai and Wang, Song}, 6 | title = {Visual Attention Consistency Under Image Transforms for Multi-Label Image Classification}, 7 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 8 | month = {June}, 9 | year = {2019} 10 | } 11 | ``` 12 | 13 | 14 | ## Note: 15 | ***This is just a preliminary version for early access. I will complete it as soon as I can.*** 16 | 17 | ### WIDER Attribute Dataset 18 | To run this code, you need to specify the `WIDER_DATA_DIR` (WIDER dataset "Image"), `WIDER_ANNO_DIR` (WIDER dataset annotations) in "configs.py"and the argument of `model_dir` (path to save checkpoints). Then, run the command (with PyTorch installed): 19 | `python main.py`. 20 | 21 | Note that you may need to use the specific PyTorch version: 0.3.1. 22 | 23 | ### PA-100K Dataset 24 | (To be integrated) 25 | 26 | ### MS-COCO 27 | (To be integrated) 28 | 29 | *For those who want to test the proposed method on MS-COCO dataset, the source codes which I used for experiments are temporarily uploaded in the `./tmp/` folder. I will do code cleanup later.* 30 | 31 | **Select the checkpoint that produces the best mAP**. 32 | 33 | You can also evaluate the predictions with code at: https://github.com/zhufengx/SRN_multilabel, which would produce a slightly better (~0.1\%) performance. 34 | 35 | ## Datasets 36 | 37 | 1. WIDER Attribute Dataset: http://mmlab.ie.cuhk.edu.hk/projects/WIDERAttribute.html 38 | 2. PA-100K: (will be supported later) 39 | 3. MS-COCO: (will be supported later) 40 | 41 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | 2 | # Dataset configurations 3 | 4 | datasets = ["wider", "pa-100k", "ms-coco"] 5 | 6 | 7 | WIDER_DATA_DIR = "/HGUO/WIDER_DATASET/Image" 8 | WIDER_ANNO_DIR = "/HGUO/WIDER_DATASET/wider_attribute_annotation" 9 | 10 | PA100K_DATA_dir = "/data/hguo/Datasets/PA-100K/release_data" 11 | PA100K_ANNO_FILE = "/data/hguo/Datasets/PA-100K/annotation/annotation.mat" 12 | 13 | MSCOCO_DATA_DIR = "/data/hguo/Datasets/MS-COCO" 14 | 15 | # pre-calculated weights to balance positive and negative samples of each label 16 | # as defined in Li et al. ACPR'15 17 | # WIDER dataset 18 | wider_pos_ratio = [0.5669, 0.2244, 0.0502, 0.2260, 0.2191, 0.4647, 0.0699, 0.1542, \ 19 | 0.0816, 0.3621, 0.1005, 0.0330, 0.2682, 0.0543] 20 | 21 | # PA-100K dataset 22 | pa100k_pos_ratio = [0.460444, 0.013456, 0.924378, 0.062167, 0.352667, 0.294622, \ 23 | 0.352711, 0.043544, 0.179978, 0.185000, 0.192733, 0.160100, 0.009522, \ 24 | 0.583400, 0.416600, 0.049478, 0.151044, 0.107756, 0.041911, 0.004722, \ 25 | 0.016889, 0.032411, 0.711711, 0.173444, 0.114844, 0.006000] 26 | 27 | # MS-COCO dataset 28 | 29 | def get_configs(dataset): 30 | opts = {} 31 | if not dataset in datasets: 32 | raise Exception("Not supported dataset!") 33 | else: 34 | if dataset == "wider": 35 | opts["dataset"] = "WIDER" 36 | opts["num_labels"] = 14 37 | opts["data_dir"] = WIDER_DATA_DIR 38 | opts["anno_dir"] = WIDER_ANNO_DIR 39 | opts["pos_ratio"] = wider_pos_ratio 40 | else: 41 | if dataset == "pa-100k": 42 | # will be added later 43 | pass 44 | else: 45 | pass 46 | return opts -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import torchvision.models as models 10 | 11 | from test import test 12 | from configs import get_configs 13 | from utils import get_dataset, adjust_learning_rate, SigmoidCrossEntropyLoss, \ 14 | generate_flip_grid 15 | 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | 19 | import sys 20 | import argparse 21 | import math 22 | import time 23 | import os 24 | 25 | 26 | def get_parser(): 27 | parser = argparse.ArgumentParser(description = 'CNN Attention Consistency') 28 | parser.add_argument("--dataset", default="wider", type=str, 29 | help="select a dataset to train models") 30 | parser.add_argument("--arch", default="resnet50", type=str, 31 | help="ResNet architecture") 32 | 33 | parser.add_argument('--train_batch_size', default = 16, type = int, 34 | help = 'default training batch size') 35 | parser.add_argument('--train_workers', default = 4, type = int, 36 | help = '# of workers used to load training samples') 37 | parser.add_argument('--test_batch_size', default = 8, type = int, 38 | help = 'default test batch size') 39 | parser.add_argument('--test_workers', default = 4, type = int, 40 | help = '# of workers used to load testing samples') 41 | 42 | parser.add_argument('--learning_rate', default = 0.001, type = float, 43 | help = 'base learning rate') 44 | parser.add_argument('--momentum', default = 0.9, type = float, 45 | help = "set the momentum") 46 | parser.add_argument('--weight_decay', default = 0.0005, type = float, 47 | help = 'set the weight_decay') 48 | parser.add_argument('--stepsize', default = 3, type = int, 49 | help = 'lr decay each # of epoches') 50 | parser.add_argument('--decay', default=0.5, type=float, 51 | help = 'update learning rate by a factor') 52 | 53 | parser.add_argument('--model_dir', 54 | default = '/Storage/models/tmp', 55 | type = str, 56 | help = 'path to save checkpoints') 57 | parser.add_argument('--model_prefix', 58 | default = 'model', 59 | type = str, 60 | help = 'model file name starts with') 61 | 62 | # optimizer 63 | parser.add_argument('--optimizer', 64 | default = 'SGD', 65 | type = str, 66 | help = 'Select an optimizer: TBD') 67 | 68 | # general parameters 69 | parser.add_argument('--epoch_max', default = 12, type = int, 70 | help = 'max # of epcoh') 71 | parser.add_argument('--display', default = 200, type = int, 72 | help = 'display') 73 | parser.add_argument('--snapshot', default = 1, type = int, 74 | help = 'snapshot') 75 | parser.add_argument('--start_epoch', default = 0, type = int, 76 | help = 'resume training from specified epoch') 77 | parser.add_argument('--resume', default = '', type = str, 78 | help = 'resume training from specified model state') 79 | 80 | parser.add_argument('--test', default = True, type = bool, 81 | help = 'conduct testing after each checkpoint being saved') 82 | 83 | return parser 84 | 85 | 86 | def main(): 87 | parser = get_parser() 88 | print(parser) 89 | args = parser.parse_args() 90 | print(args) 91 | 92 | # load data 93 | opts = get_configs(args.dataset) 94 | print(opts) 95 | pos_ratio = torch.FloatTensor(opts["pos_ratio"]) 96 | w_p = (1 - pos_ratio).exp().cuda() 97 | w_n = pos_ratio.exp().cuda() 98 | 99 | trainset, testset = get_dataset(opts) 100 | 101 | train_loader = torch.utils.data.DataLoader(trainset, 102 | batch_size = args.train_batch_size, 103 | shuffle = True, 104 | num_workers = args.train_workers) 105 | test_loader = torch.utils.data.DataLoader(testset, 106 | batch_size = args.test_batch_size, 107 | shuffle = False, 108 | num_workers = args.test_workers) 109 | 110 | 111 | # path to save models 112 | if not os.path.isdir(args.model_dir): 113 | print("Make directory: " + args.model_dir) 114 | os.makedirs(args.model_dir) 115 | 116 | # prefix of saved checkpoint 117 | model_prefix = args.model_dir + '/' + args.model_prefix 118 | 119 | 120 | # define the model: use ResNet50 as an example 121 | if args.arch == "resnet50": 122 | from resnet import resnet50 123 | model = resnet50(pretrained=True, num_labels=opts["num_labels"]) 124 | model_prefix = model_prefix + "_resnet50" 125 | elif args.arch == "resnet101": 126 | from resnet import resnet101 127 | model = resnet101(pretrained=True, num_labels=opts["num_labels"]) 128 | model_prefix = model_prefix + "_resnet101" 129 | else: 130 | raise NotImplementedError("To be implemented!") 131 | 132 | if args.start_epoch != 0: 133 | resume_model = torch.load(args.resume) 134 | resume_dict = resume_model.state_dict() 135 | model_dict = model.state_dict() 136 | resume_dict = {k:v for k,v in resume_dict.items() if k in model_dict} 137 | model_dict.update(resume_dict) 138 | model.load_state_dict(model_dict) 139 | 140 | # print(model) 141 | model.cuda() 142 | 143 | if args.optimizer == 'Adam': 144 | optimizer = optim.Adam( 145 | model.parameters(), 146 | lr = args.learning_rate 147 | ) 148 | elif args.optimizer == 'SGD': 149 | optimizer = optim.SGD( 150 | model.parameters(), 151 | lr = args.learning_rate, 152 | momentum = args.momentum, 153 | weight_decay = args.weight_decay 154 | ) 155 | else: 156 | raise NotImplementedError("Not supported yet!") 157 | 158 | # training the network 159 | model.train() 160 | 161 | # attention map size 162 | w1 = 7 163 | h1 = 7 164 | grid_l = generate_flip_grid(w1, h1) 165 | 166 | w2 = 6 167 | h2 = 6 168 | grid_s = generate_flip_grid(w2, h2) 169 | 170 | # least common multiple 171 | lcm = w1 * w2 172 | 173 | 174 | criterion = SigmoidCrossEntropyLoss 175 | criterion_mse = nn.MSELoss(size_average = True) 176 | for epoch in range(args.start_epoch, args.epoch_max): 177 | epoch_start = time.clock() 178 | if not args.stepsize == 0: 179 | adjust_learning_rate(optimizer, epoch, args) 180 | for step, batch_data in enumerate(train_loader): 181 | batch_images_lo = batch_data[0] 182 | batch_images_lf = batch_data[1] 183 | batch_images_so = batch_data[2] 184 | batch_images_sf = batch_data[3] 185 | batch_labels = batch_data[4] 186 | 187 | batch_labels[batch_labels == -1] = 0 188 | 189 | batch_images_l = torch.cat((batch_images_lo, batch_images_lf)) 190 | batch_images_s = torch.cat((batch_images_so, batch_images_sf)) 191 | batch_labels = torch.cat((batch_labels, batch_labels, batch_labels, batch_labels)) 192 | 193 | batch_images_l = batch_images_l.cuda() 194 | batch_images_s = batch_images_s.cuda() 195 | batch_labels = batch_labels.cuda() 196 | 197 | inputs_l = Variable(batch_images_l) 198 | inputs_s = Variable(batch_images_s) 199 | labels = Variable(batch_labels) 200 | 201 | output_l, hm_l = model(inputs_l) 202 | output_s, hm_s = model(inputs_s) 203 | 204 | output = torch.cat((output_l, output_s)) 205 | loss = criterion(output, labels, w_p, w_n) 206 | 207 | # flip 208 | num = hm_l.size(0) // 2 209 | 210 | hm1, hm2 = hm_l.split(num) 211 | flip_grid_large = grid_l.expand(num, -1, -1, -1) 212 | flip_grid_large = Variable(flip_grid_large, requires_grad = False) 213 | flip_grid_large = flip_grid_large.permute(0, 2, 3, 1) 214 | hm2_flip = F.grid_sample(hm2, flip_grid_large, mode = 'bilinear', 215 | padding_mode = 'border') 216 | flip_loss_l = F.mse_loss(hm1, hm2_flip) 217 | 218 | hm1_small, hm2_small = hm_s.split(num) 219 | flip_grid_small = grid_s.expand(num, -1, -1, -1) 220 | flip_grid_small = Variable(flip_grid_small, requires_grad = False) 221 | flip_grid_small = flip_grid_small.permute(0, 2, 3, 1) 222 | hm2_small_flip = F.grid_sample(hm2_small, flip_grid_small, mode = 'bilinear', 223 | padding_mode = 'border') 224 | flip_loss_s = F.mse_loss(hm1_small, hm2_small_flip) 225 | 226 | # scale loss 227 | num = hm_l.size(0) 228 | hm_l = F.upsample(hm_l, lcm) 229 | hm_s = F.upsample(hm_s, lcm) 230 | scale_loss = F.mse_loss(hm_l, hm_s) 231 | 232 | losses = loss + flip_loss_l + flip_loss_s + scale_loss 233 | 234 | optimizer.zero_grad() 235 | losses.backward() 236 | optimizer.step() 237 | 238 | if (step) % args.display == 0: 239 | print( 240 | 'epoch: {},\ttrain step: {}\tLoss: {:.6f}'.format(epoch+1, 241 | step, losses.data[0]) 242 | ) 243 | print( 244 | '\tcls loss: {:.4f};\tflip_loss_l: {:.4f}' 245 | '\tflip_loss_s: {:.4f};\tscale_loss: {:.4f}'.format( 246 | loss.data[0], 247 | flip_loss_l.data[0], 248 | flip_loss_s.data[0], 249 | scale_loss.data[0] 250 | ) 251 | ) 252 | 253 | epoch_end = time.clock() 254 | elapsed = epoch_end - epoch_start 255 | print("Epoch time: ", elapsed) 256 | 257 | # test 258 | if (epoch+1) % args.snapshot == 0: 259 | 260 | model_file = model_prefix + '_epoch{}.pth' 261 | print("Saving model to " + model_file.format(epoch+1)) 262 | torch.save(model, model_file.format(epoch+1)) 263 | 264 | if args.test: 265 | model.eval() 266 | test_start = time.clock() 267 | test(model, test_loader, epoch+1) 268 | test_time = (time.clock() - test_start) 269 | print("test time: ", test_time) 270 | model.train() 271 | 272 | final_model =model_prefix + '_final.pth' 273 | print("Saving model to " + final_model) 274 | torch.save(model, final_model) 275 | model.eval() 276 | test(model, test_loader, epoch+1) 277 | 278 | 279 | 280 | if __name__ == '__main__': 281 | main() 282 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152'] 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | "3x3 convolution with padding" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(BasicBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | residual = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(Bottleneck, self).__init__() 65 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(planes) 67 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 68 | padding=1, bias=False) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 71 | self.bn3 = nn.BatchNorm2d(planes * 4) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | 101 | def __init__(self, block, layers, num_labels=14): 102 | self.inplanes = 64 103 | self.num_labels = num_labels 104 | super(ResNet, self).__init__() 105 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 106 | bias=False) 107 | self.bn1 = nn.BatchNorm2d(64) 108 | self.relu = nn.ReLU(inplace=True) 109 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | self.layer1 = self._make_layer(block, 64, layers[0]) 111 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 112 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 113 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 114 | self.avgpool = nn.AdaptiveAvgPool2d(1) 115 | 116 | self.fc_all = nn.Linear(512 * block.expansion, self.num_labels) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 121 | m.weight.data.normal_(0, math.sqrt(2. / n)) 122 | elif isinstance(m, nn.BatchNorm2d): 123 | m.weight.data.fill_(1) 124 | m.bias.data.zero_() 125 | 126 | def _make_layer(self, block, planes, blocks, stride=1): 127 | downsample = None 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | nn.Conv2d(self.inplanes, planes * block.expansion, 131 | kernel_size=1, stride=stride, bias=False), 132 | nn.BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample)) 137 | self.inplanes = planes * block.expansion 138 | for i in range(1, blocks): 139 | layers.append(block(self.inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | # modify the forward function 145 | x = self.conv1(x) 146 | x = self.bn1(x) 147 | x = self.relu(x) 148 | x = self.maxpool(x) 149 | 150 | x = self.layer1(x) 151 | x = self.layer2(x) 152 | x = self.layer3(x) 153 | x = self.layer4(x) 154 | feat = x 155 | N, C, H, W = feat.shape 156 | # global 157 | x = self.avgpool(x) 158 | x = x.view(x.size(0), -1) 159 | y = self.fc_all(x) 160 | 161 | # local 162 | # get the FC parameters 163 | params = list(self.parameters()) 164 | fc_weights = params[-2].data 165 | fc_weights = fc_weights.view(1, self.num_labels, C, 1, 1) 166 | fc_weights = Variable(fc_weights, requires_grad = False) 167 | 168 | # attention 169 | feat = feat.unsqueeze(1) # N * 1 * C * H * W 170 | hm = feat * fc_weights 171 | hm = hm.sum(2) # N * self.num_labels * H * W 172 | 173 | heatmap = hm 174 | 175 | return y, heatmap 176 | 177 | 178 | def resnet18(pretrained=False, **kwargs): 179 | """Constructs a ResNet-18 model. 180 | 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | """ 184 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 185 | if pretrained: 186 | pretrained_dict = model_zoo.load_url(model_urls['resnet18']) 187 | model_dict = model.state_dict() 188 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict} 189 | model_dict.update(pretrained_dict) 190 | model.load_state_dict(model_dict) 191 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 192 | return model 193 | 194 | 195 | def resnet34(pretrained=False, **kwargs): 196 | """Constructs a ResNet-34 model. 197 | 198 | Args: 199 | pretrained (bool): If True, returns a model pre-trained on ImageNet 200 | """ 201 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 202 | if pretrained: 203 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 204 | return model 205 | 206 | 207 | def resnet50(pretrained=False, **kwargs): 208 | """Constructs a ResNet-50 model. 209 | 210 | Args: 211 | pretrained (bool): If True, returns a model pre-trained on ImageNet 212 | """ 213 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 214 | if pretrained: 215 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 216 | model_dict = model.state_dict() 217 | # for k,v in pretrained_dict.items(): 218 | # if k in model_dict: 219 | # print(k) 220 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict} 221 | model_dict.update(pretrained_dict) 222 | model.load_state_dict(model_dict) 223 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 224 | return model 225 | 226 | 227 | def resnet101(pretrained=False, **kwargs): 228 | """Constructs a ResNet-101 model. 229 | 230 | Args: 231 | pretrained (bool): If True, returns a model pre-trained on ImageNet 232 | """ 233 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 234 | if pretrained: 235 | pretrained_dict = model_zoo.load_url(model_urls['resnet101']) 236 | model_dict = model.state_dict() 237 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict} 238 | model_dict.update(pretrained_dict) 239 | model.load_state_dict(model_dict) 240 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 241 | return model 242 | 243 | 244 | def resnet152(pretrained=False, **kwargs): 245 | """Constructs a ResNet-152 model. 246 | 247 | Args: 248 | pretrained (bool): If True, returns a model pre-trained on ImageNet 249 | """ 250 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 251 | if pretrained: 252 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 253 | return model 254 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torchvision.models as models 5 | 6 | from torch.autograd import Variable 7 | 8 | import numpy as np 9 | import sys 10 | import math 11 | import time 12 | import matplotlib.pyplot as plt 13 | 14 | import sklearn.metrics as metrics 15 | 16 | from sklearn.metrics import average_precision_score 17 | from wider import get_subsets, imshow 18 | 19 | num_attr = 14 20 | 21 | def calc_average_precision(y_true, y_score): 22 | aps = np.zeros(num_attr) 23 | for i in range(num_attr): 24 | true = y_true[i] 25 | score = y_score[i] 26 | 27 | non_index = np.where(true == 0) 28 | score = np.delete(score, non_index) 29 | true = np.delete(true, non_index) 30 | 31 | true[true == -1.] = 0 32 | 33 | ap = average_precision_score(true, score) 34 | aps[i] = ap 35 | 36 | return aps 37 | 38 | def calc_acc_pr_f1(y_true, y_pred): 39 | precision = np.zeros(num_attr) 40 | recall = np.zeros(num_attr) 41 | accuracy = np.zeros(num_attr) 42 | f1 = np.zeros(num_attr) 43 | for i in range(num_attr): 44 | true = y_true[i] 45 | pred = y_pred[i] 46 | 47 | true[true == -1.] = 0 48 | 49 | precision[i] = metrics.precision_score(true, pred) 50 | recall[i] = metrics.recall_score(true, pred) 51 | accuracy[i] = metrics.accuracy_score(true, pred) 52 | f1[i] = metrics.f1_score(true, pred) 53 | 54 | return precision, recall, accuracy, f1 55 | 56 | def calc_mean_acc(y_true, y_pred): 57 | macc = np.zeros(num_attr) 58 | for i in range(num_attr): 59 | true = y_true[i] # -1, 0, 1 60 | pred = y_pred[i] # 0, 1 61 | 62 | true[true == -1.] = 0 63 | 64 | temp = true + pred 65 | tp = (temp[temp == 2]).size 66 | tn = (temp[temp == 0]).size 67 | p = (true[true == 1]).size 68 | n = (true[true == 0]).size 69 | 70 | macc[i] = .5 * tp / (p) + .5 * tn / (n) 71 | 72 | return macc 73 | 74 | def calc_acc_pr_f1_overall(y_true, y_pred): 75 | 76 | true = y_true 77 | pred = y_pred 78 | 79 | true[true == -1.] = 0 80 | 81 | precision = metrics.precision_score(true, pred) 82 | recall = metrics.recall_score(true, pred) 83 | accuracy = metrics.accuracy_score(true, pred) 84 | f1 = metrics.f1_score(true, pred) 85 | 86 | return precision, recall, accuracy, f1 87 | 88 | def calc_mean_acc_overall(y_true, y_pred): 89 | 90 | true = y_true # 0, 1 91 | pred = y_pred # 0, 1 92 | 93 | true[true == -1.] = 0 94 | 95 | temp = true + pred 96 | tp = (temp[temp == 2]).size 97 | tn = (temp[temp == 0]).size 98 | p = (true[true == 1]).size 99 | n = (true[true == 0]).size 100 | macc = .5 * tp / (p) + .5 * tn / (n) 101 | 102 | return macc 103 | 104 | def eval_example(y_true, y_pred): 105 | # example-based metrics 106 | N = y_true.shape[1] 107 | 108 | acc = 0. 109 | prec = 0. 110 | rec = 0. 111 | f1 = 0. 112 | 113 | for i in range(N): 114 | true_exam = y_true[:,i] # column: labels for an example 115 | pred_exam = y_pred[:,i] 116 | 117 | temp = true_exam + pred_exam 118 | 119 | yi = true_exam.sum() # number of attributes for i 120 | fi = pred_exam.sum() # number of predicted attributes for i 121 | ui = (temp > 0).sum() # temp == 1 or 2 means the union of attributes in yi and fi 122 | ii = (temp == 2).sum() # temp == 2 means the intersection 123 | 124 | if ui != 0: 125 | acc += 1.0 * ii / ui 126 | if fi != 0: 127 | prec += 1.0 * ii / fi 128 | if yi != 0: 129 | rec += 1.0 * ii / yi 130 | 131 | acc /= N 132 | prec /= N 133 | rec /= N 134 | f1 = 2.0 * prec * rec / (prec + rec) 135 | return acc, prec, rec, f1 136 | 137 | 138 | 139 | def test(model, test_loader, epoch): 140 | print("testing ... ") 141 | 142 | probs = torch.FloatTensor() 143 | gtruth = torch.FloatTensor() 144 | probs = probs.cuda() 145 | gtruth = gtruth.cuda() 146 | for i, sample in enumerate(test_loader): 147 | images = sample[0] # test just large 148 | labels = sample[4] 149 | labels = labels.type(torch.FloatTensor) 150 | 151 | images = images.cuda() 152 | labels = labels.cuda() 153 | 154 | test_input = Variable(images) 155 | y, _ = model(test_input) 156 | 157 | probs = torch.cat((probs, y.data.transpose(1, 0)), 1) 158 | gtruth = torch.cat((gtruth, labels.transpose(1, 0)), 1) 159 | 160 | print('prediction finished ....') 161 | 162 | preds = np.zeros((probs.size(0), probs.size(1))) 163 | temp = probs.cpu().numpy() 164 | preds[temp > 0.] = 1 165 | 166 | import scipy.io 167 | import os 168 | if not os.path.isdir('./preds'): 169 | os.mkdir('./preds') 170 | scipy.io.savemat('./preds/prediction_e{}.mat'.format(epoch), dict(gt = gtruth.cpu().numpy(), \ 171 | prob = probs.cpu().numpy(), pred = preds)) 172 | 173 | aps = calc_average_precision(gtruth.cpu().numpy(), probs.cpu().numpy()) 174 | print('>>>>>>>>>>>>>>>>>>>>>>>> Average for Each Attribute >>>>>>>>>>>>>>>>>>>>>>>>>>>') 175 | print("APs") 176 | print(aps) 177 | precision, recall, accuracy, f1 = calc_acc_pr_f1(gtruth.cpu().numpy(), preds) 178 | print('precision scores') 179 | print(precision) 180 | print('recall scores') 181 | print(recall) 182 | print('f1 scores') 183 | print(f1) 184 | print('') 185 | 186 | print("AP: {}".format(aps.mean())) 187 | print('F1-C: {}'.format(f1.mean())) 188 | print('P-C: {}'.format(precision.mean())) 189 | print('R-C: {}'.format(recall.mean())) 190 | print('') 191 | 192 | 193 | print('>>>>>>>>>>>>>>>>>>>>>>>> Overall Sample-Label Pairs >>>>>>>>>>>>>>>>>>>>>>>>>>>') 194 | precision, recall, accuracy, f1 = calc_acc_pr_f1_overall(gtruth.cpu().numpy().flatten(), 195 | preds.flatten()) 196 | 197 | print('F1_O: {}'.format(f1)) 198 | print('P_O: {}'.format(precision)) 199 | print('R_O: {}'.format(recall)) 200 | print('\n') 201 | 202 | 203 | macc = calc_mean_acc(gtruth.cpu().numpy(), preds) 204 | print('mA scores') 205 | print(macc) 206 | print('mean mA') 207 | print(macc.mean()) 208 | 209 | print('\n') 210 | 211 | if __name__ == '__main__': 212 | anno_dir = '/path/to/wider_attribute_annotation' 213 | data_dir = '/path/to/Image' 214 | trainset, testset = get_subsets(anno_dir, data_dir) 215 | test_loader = torch.utils.data.DataLoader(testset, 216 | batch_size = 16, 217 | shuffle = False, 218 | num_workers = 4) 219 | 220 | # modify to test multiple checkpoints continuously 221 | for i in range(11, 12): 222 | model_file = '/path/to/model_resnet50_{}.pth'.format(i) 223 | model = torch.load(model_file) 224 | print(model_file) 225 | model.eval() 226 | start_time = time.clock() 227 | test(model, test_loader, i) 228 | end_time = time.clock() 229 | print('Time: ', end_time - start_time) 230 | print('\n') 231 | -------------------------------------------------------------------------------- /tmp/ms-coco-fs288/main.py: -------------------------------------------------------------------------------- 1 | # this script should be the entrance for following coding 2 | # together with wider.py test.py, and specific resnet.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | import torchvision.models as models 12 | 13 | from torch.autograd import Variable 14 | from ms_coco import getSubsets 15 | from test import test 16 | 17 | import matplotlib.pyplot as plt 18 | import numpy as np 19 | 20 | import sys 21 | import argparse 22 | import math 23 | import time 24 | import os 25 | 26 | 27 | parser = argparse.ArgumentParser(description = 'Attribute') 28 | 29 | parser.add_argument('--data_dir', 30 | default = '/work/hguo/datasets/MS-COCO', 31 | type = str, 32 | help = 'path to "Image" folder of WIDER dataset') 33 | # parser.add_argument('--anno_file', 34 | # default = '/data/hguo/Datasets/PA-100K/annotation/annotation.mat', 35 | # type = str, 36 | # help = 'annotation file') 37 | 38 | parser.add_argument('--train_batch_size', default = 24, type = int, 39 | help = 'default training batch size') 40 | parser.add_argument('--train_workers', default = 12, type = int, 41 | help = '# of workers used to load training samples') 42 | parser.add_argument('--test_batch_size', default = 8, type = int, 43 | help = 'default test batch size') 44 | parser.add_argument('--test_workers', default = 12, type = int, 45 | help = '# of workers used to load testing samples') 46 | 47 | parser.add_argument('--learning_rate', default = 0.001, type = float, 48 | help = 'base learning rate') 49 | parser.add_argument('--momentum', default = 0.9, type = float, 50 | help = "set the momentum") 51 | parser.add_argument('--weight_decay', default = 0.0005, type = float, 52 | help = 'set the weight_decay') 53 | parser.add_argument('--stepsize', default = 6, type = int, 54 | help = 'lr decay each # of epoches') 55 | 56 | parser.add_argument('--model_dir', 57 | default = '/work/hguo/models/ac_mscoco/mscoco_fs288', 58 | type = str, 59 | help = 'path to save trained models') 60 | parser.add_argument('--model_prefix', 61 | default = 'model', 62 | type = str, 63 | help = 'model file name starts with') 64 | 65 | # optimizer 66 | parser.add_argument('--optimizer', 67 | default = 'SGD', 68 | type = str, 69 | help = 'Select an optimizer: TBD') 70 | 71 | 72 | # general parameters 73 | parser.add_argument('--epoch_max', default = 14, type = int, 74 | help = 'max # of epcoh') 75 | parser.add_argument('--display', default = 200, type = int, 76 | help = 'display') 77 | parser.add_argument('--snapshot', default = 1, type = int, 78 | help = 'snapshot') 79 | parser.add_argument('--resume', default = 0, type = int, 80 | help = 'resume training from specified epoch') 81 | 82 | 83 | import scipy.io 84 | ratio_file = '../ratio.mat' 85 | temp = scipy.io.loadmat(ratio_file) 86 | ratio = temp['ratio'] 87 | ratio = torch.from_numpy(ratio).squeeze().type(torch.FloatTensor) 88 | w_p = (1-ratio).exp().cuda() 89 | w_n = ratio.exp().cuda() 90 | 91 | 92 | def adjust_learning_rate(optimizer, epoch, stepsize): 93 | """Sets the learning rate to the initial LR decayed by 2 every 30 epochs""" 94 | lr = args.learning_rate * (0.1 ** (epoch // stepsize)) 95 | # print('6, 12, ...') 96 | # if epoch < 6: 97 | # lr = 0.001 98 | # else: 99 | # if epoch < 12: 100 | # lr = 0.0001 101 | # else: 102 | # lr = 0.00001 103 | print("Current learning rate is: {:.6f}".format(lr)) 104 | for param_group in optimizer.param_groups: 105 | param_group['lr'] = lr 106 | 107 | def imshow(inp, title=None): 108 | """Imshow for Tensor.""" 109 | inp = inp.numpy().transpose((1, 2, 0)) 110 | mean = np.array([0.485, 0.456, 0.406]) 111 | std = np.array([0.229, 0.224, 0.225]) 112 | inp = std * inp + mean 113 | plt.imshow(inp) 114 | plt.show() 115 | if title is not None: 116 | plt.title(title) 117 | 118 | def multi_label_classification_loss(x, y): 119 | 120 | loss = 0.0 121 | if not x.size() == y.size(): 122 | print("x and y must have the same size") 123 | else: 124 | 125 | for i in range(y.size(0)): 126 | temp = -(y[i]*y[i] * ( 1 / (1+(-x[i]*y[i]).exp()) ).log()) 127 | loss += temp.sum() 128 | 129 | loss = loss / y.size(0) 130 | return loss 131 | 132 | 133 | def SigmoidCrossEntropyLoss(x, y): 134 | loss = 0.0 135 | if not x.size() == y.size(): 136 | print("x and y must have the same size") 137 | else: 138 | N = y.size(0) 139 | L = y.size(1) 140 | for i in range(N): 141 | w = torch.zeros(L).cuda() 142 | # print(y[i].data) 143 | # print(w_p) 144 | # print(w_n) 145 | w[y[i].data == 1] = w_p[y[i].data == 1] 146 | w[y[i].data == 0] = w_n[y[i].data == 0] 147 | # sys.exit() 148 | 149 | w = Variable(w, requires_grad = False) 150 | temp = - w * ( y[i] * (1 / (1 + (-x[i]).exp())).log() + \ 151 | (1 - y[i]) * ( (-x[i]).exp() / (1 + (-x[i]).exp()) ).log() ) 152 | loss += temp.sum() 153 | 154 | loss = loss / N 155 | return loss 156 | 157 | def generate_flip_grid(w, h): 158 | 159 | x_ = torch.arange(w).view(1, -1).expand(h, -1) 160 | y_ = torch.arange(h).view(-1, 1).expand(-1, w) 161 | grid = torch.stack([x_, y_], dim=0).float().cuda() 162 | grid = grid.unsqueeze(0).expand(1, -1, -1, -1) 163 | grid[:, 0, :, :] = 2 * grid[:, 0, :, :] / (w - 1) - 1 164 | grid[:, 1, :, :] = 2 * grid[:, 1, :, :] / (h - 1) - 1 165 | 166 | grid[:, 0, :, :] = -grid[:, 0, :, :] 167 | return grid 168 | 169 | 170 | def main(): 171 | global args 172 | args = parser.parse_args() 173 | print(args) 174 | 175 | num_cls = 80 176 | 177 | # load data 178 | data_dir = args.data_dir 179 | 180 | trainset, valset = getSubsets(data_dir) 181 | train_loader = torch.utils.data.DataLoader(trainset, 182 | batch_size = args.train_batch_size, 183 | shuffle = True, 184 | num_workers = args.train_workers) 185 | test_loader = torch.utils.data.DataLoader(valset, 186 | batch_size = args.test_batch_size, 187 | shuffle = False, 188 | num_workers = args.test_workers) 189 | 190 | 191 | # path to save models 192 | if not os.path.isdir(args.model_dir): 193 | print("Make directory: " + args.model_dir) 194 | os.makedirs(args.model_dir) 195 | 196 | model_prefix = args.model_dir + '/' + args.model_prefix 197 | 198 | 199 | # define the model 200 | from resnet import resnet101 201 | model = resnet101(pretrained = True) 202 | # resume_model_file = '/data/userdata/hguo/models/AttrFlow/mscoco_flow_scale_wsce_288_b16/model_resnet101_4.pth' 203 | # resume_model = torch.load(resume_model_file) 204 | # resume_dict = resume_model.state_dict() 205 | # model_dict = model.state_dict() 206 | # resume_dict = {k:v for k,v in resume_dict.items() if k in model_dict} 207 | # model_dict.update(resume_dict) 208 | # model.load_state_dict(model_dict) 209 | resume_epoch = 0 210 | if torch.cuda.device_count() > 1: 211 | model = nn.DataParallel(model) 212 | 213 | model.cuda() 214 | 215 | 216 | if args.optimizer == 'Adam': 217 | optimizer = optim.Adam( 218 | model.parameters(), 219 | lr = args.learning_rate) 220 | elif args.optimizer == 'SGD': 221 | optimizer = optim.SGD( 222 | model.parameters(), 223 | lr = args.learning_rate, 224 | momentum = args.momentum, 225 | weight_decay = args.weight_decay) 226 | else: 227 | pass 228 | 229 | # training the network 230 | model.train() 231 | 232 | w = 9 233 | h = 9 234 | grid_l = generate_flip_grid(w, h) 235 | 236 | w = 8 237 | h = 8 238 | grid_s = generate_flip_grid(w, h) 239 | 240 | # criterion = multi_label_classification_loss 241 | # criterion = nn.MultiLabelSoftMarginLoss() 242 | criterion = SigmoidCrossEntropyLoss 243 | 244 | for epoch in range(resume_epoch, args.epoch_max): 245 | epoch_start = time.clock() 246 | if not args.stepsize == 0: 247 | adjust_learning_rate(optimizer, epoch, args.stepsize) 248 | for step, batch_data in enumerate(train_loader): 249 | batch_images_lo = batch_data[0] 250 | batch_images_lf = batch_data[1] 251 | batch_images_so = batch_data[2] 252 | batch_images_sf = batch_data[3] 253 | batch_labels = batch_data[4] 254 | 255 | if batch_labels.size(0) != args.train_batch_size: 256 | continue 257 | 258 | batch_images_l = torch.cat((batch_images_lo, batch_images_lf)) 259 | batch_images_s = torch.cat((batch_images_so, batch_images_sf)) 260 | batch_labels = torch.cat((batch_labels, batch_labels, batch_labels, batch_labels)) 261 | 262 | batch_images_l = batch_images_l.cuda() 263 | batch_images_s = batch_images_s.cuda() 264 | batch_labels = batch_labels.cuda() 265 | 266 | inputs_l = Variable(batch_images_l) 267 | inputs_s = Variable(batch_images_s) 268 | labels = Variable(batch_labels) 269 | 270 | output_l, hm_l = model(inputs_l) 271 | output_s, hm_s = model(inputs_s) 272 | 273 | output = torch.cat((output_l, output_s)) 274 | loss = criterion(output, labels) 275 | 276 | 277 | # flip 278 | num = hm_l.size(0) // 2 279 | 280 | hm1, hm2 = hm_l.split(num) 281 | flip_grid_large = grid_l.expand(num, -1, -1, -1) 282 | flip_grid_large = Variable(flip_grid_large, requires_grad = False) 283 | flip_grid_large = flip_grid_large.permute(0, 2, 3, 1) 284 | hm2_flip = F.grid_sample(hm2, flip_grid_large, mode = 'bilinear', 285 | padding_mode = 'border') 286 | flip_loss_l = F.mse_loss(hm1, hm2_flip) 287 | 288 | hm1_small, hm2_small = hm_s.split(num) 289 | flip_grid_small = grid_s.expand(num, -1, -1, -1) 290 | flip_grid_small = Variable(flip_grid_small, requires_grad = False) 291 | flip_grid_small = flip_grid_small.permute(0, 2, 3, 1) 292 | hm2_small_flip = F.grid_sample(hm2_small, flip_grid_small, mode = 'bilinear', 293 | padding_mode = 'border') 294 | flip_loss_s = F.mse_loss(hm1_small, hm2_small_flip) 295 | 296 | # scale loss 297 | num = hm_l.size(0) 298 | hm_l = F.upsample(hm_l, 72) 299 | hm_s = F.upsample(hm_s, 72) 300 | scale_loss = F.mse_loss(hm_l, hm_s) 301 | 302 | losses = loss + 0.6 * flip_loss_l + 0.6 * flip_loss_s + 0.8 * scale_loss 303 | 304 | optimizer.zero_grad() 305 | losses.backward() 306 | optimizer.step() 307 | 308 | if (step) % args.display == 0: 309 | print('epoch: {},\ttrain step: {}\tLoss: {:.6f}'.format(epoch+1, 310 | step, losses.data[0])) 311 | print('cls loss: {:.5f}'.format(loss.data[0])) 312 | print('flip_loss_l: {:.5f}'.format(flip_loss_l.data[0])) 313 | print('flip_loss_s: {:.5f}'.format(flip_loss_s.data[0])) 314 | print('scale_loss: {:.5f}'.format(scale_loss.data[0])) 315 | 316 | epoch_end = time.clock() 317 | elapsed = epoch_end - epoch_start 318 | print("Epoch time: ", elapsed) 319 | 320 | 321 | # test 322 | if (epoch+1) % args.snapshot == 0: 323 | 324 | model_file = model_prefix + '_resnet101_{}.pth' 325 | print("Saving model to " + model_file.format(epoch+1)) 326 | torch.save(model, model_file.format(epoch+1)) 327 | 328 | model.eval() 329 | test_start = time.clock() 330 | test(model, test_loader, epoch+1) 331 | test_time = (time.clock() - test_start) 332 | print("test time: ", test_time) 333 | model.train() 334 | 335 | final_model =model_prefix + '_resnet101_final.pth' 336 | print("Saving model to " + final_model) 337 | torch.save(model, final_model) 338 | model.eval() 339 | # test(model, test_loader) 340 | 341 | 342 | 343 | if __name__ == '__main__': 344 | main() 345 | -------------------------------------------------------------------------------- /tmp/ms-coco-fs288/ms_coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import scipy.io 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | import torch 8 | import torch.utils.data as data 9 | import torchvision.transforms as transforms 10 | 11 | from os.path import join 12 | from PIL import Image 13 | 14 | from pycocotools.coco import COCO 15 | import pylab 16 | import random 17 | 18 | def default_loader(path): 19 | return Image.open(path).convert('RGB') 20 | 21 | def imshow(inp, title=None): 22 | """Imshow for Tensor.""" 23 | inp = inp.numpy().transpose((1, 2, 0)) 24 | mean = np.array([0.485, 0.456, 0.406]) 25 | std = np.array([0.229, 0.224, 0.225]) 26 | inp = std * inp + mean 27 | plt.imshow(inp) 28 | plt.show() 29 | if title is not None: 30 | plt.title(title) 31 | 32 | 33 | size_set = [256, 224, 192, 168, 128] 34 | 35 | class MSCOCO(data.Dataset): 36 | 37 | def __init__(self, dataDir, dataType, phase): 38 | # self.annoDir = join(dataDir, 'annotations') 39 | self.imgDir = join(dataDir, dataType) 40 | self.phase = phase 41 | annoFile = '{}/annotations/instances_{}.json'.format(dataDir, dataType) 42 | coco = COCO(annoFile) 43 | 44 | imgFiles = [] 45 | imgLabels = [] 46 | 47 | # get all image ids 48 | imgIds = coco.getImgIds() 49 | n_img = len(imgIds) 50 | catIds = coco.getCatIds() 51 | 52 | for i in range(n_img): 53 | img_id = imgIds[i] 54 | # load img info 55 | img_info = coco.loadImgs(img_id)[0] 56 | img_name = img_info['file_name'] 57 | img_file = join(self.imgDir, img_name) 58 | imgFiles.append(img_file) 59 | 60 | # get labels 61 | labels = {} 62 | for c in catIds: 63 | labels[c] = 0 64 | 65 | # [v for k,v in labels.items()] 66 | annIds = coco.getAnnIds(imgIds = img_id) 67 | for j in range(len(annIds)): 68 | ann = coco.loadAnns(annIds[j])[0] 69 | cat_id = ann['category_id'] 70 | labels[cat_id] = 1 71 | # catCount[cat_id] += 1 72 | imgLabels.append([v for k, v in labels.items()]) 73 | 74 | # for k, v in catCount.items(): 75 | # print('{}: {}'.format(k, v)) 76 | self.imgFiles = imgFiles 77 | self.imgLabels = imgLabels 78 | 79 | self.transform = transforms.Compose([ 80 | # transforms.RandomHorizontalFlip(), 81 | transforms.ToTensor(), 82 | transforms.Normalize(mean = [0.485, 0.456, 0.406], 83 | std = [0.229, 0.224, 0.225]), 84 | ]) 85 | 86 | def __getitem__(self, idx): 87 | imgFile = self.imgFiles[idx] 88 | imgLabel = self.imgLabels[idx] 89 | 90 | img = default_loader(imgFile) 91 | 92 | img1 = img.resize((288, 288)) 93 | img2 = img.resize((256, 256)) 94 | 95 | # imshow(image) 96 | # imshow(image2) 97 | image_lo = self.transform(img1) 98 | # large flip 99 | image_lf = self.transform(img1.transpose(Image.FLIP_LEFT_RIGHT)) 100 | 101 | # small orig 102 | image_so = self.transform(img2) 103 | # small flip 104 | image_sf = self.transform(img2.transpose(Image.FLIP_LEFT_RIGHT)) 105 | 106 | labels = torch.FloatTensor(imgLabel) 107 | # print sample['file_name'] 108 | # print labels 109 | # imshow(image) 110 | # sys.exit() 111 | return image_lo, image_lf, image_so, image_sf, labels 112 | 113 | 114 | def __len__(self): 115 | return len(self.imgFiles) 116 | 117 | 118 | def getSubsets(dataDir): 119 | trainset = MSCOCO(dataDir, 'train2014', 'train') 120 | valset = MSCOCO(dataDir, 'val2014', 'test') 121 | return trainset, valset 122 | 123 | 124 | if __name__ == '__main__': 125 | dataDir = '/data/hguo/Datasets/MS-COCO' 126 | dataType = 'val2014' 127 | mscoco = MSCOCO(dataDir, dataType, 'test') 128 | mscoco[0] 129 | mscoco[1] 130 | -------------------------------------------------------------------------------- /tmp/ms-coco-fs288/ratio.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hguosc/visual_attention_consistency/7c5b0dcfc444d7e6cecadd793b799b4a7ef9da7d/tmp/ms-coco-fs288/ratio.mat -------------------------------------------------------------------------------- /tmp/ms-coco-fs288/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152'] 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | } 20 | 21 | num_cls = 80 22 | 23 | def conv3x3(in_planes, out_planes, stride=1): 24 | "3x3 convolution with padding" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None): 65 | super(Bottleneck, self).__init__() 66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 69 | padding=1, bias=False) 70 | self.bn2 = nn.BatchNorm2d(planes) 71 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(planes * 4) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | residual = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | class ResNet(nn.Module): 101 | 102 | def __init__(self, block, layers, num_classes=1000): 103 | self.inplanes = 64 104 | super(ResNet, self).__init__() 105 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 106 | bias=False) 107 | self.bn1 = nn.BatchNorm2d(64) 108 | self.relu = nn.ReLU(inplace=True) 109 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | self.layer1 = self._make_layer(block, 64, layers[0]) 111 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 112 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 113 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 114 | self.avgpool = nn.AdaptiveAvgPool2d(1) 115 | 116 | self.fc_all = nn.Linear(512 * block.expansion, num_cls) 117 | 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 121 | m.weight.data.normal_(0, math.sqrt(2. / n)) 122 | elif isinstance(m, nn.BatchNorm2d): 123 | m.weight.data.fill_(1) 124 | m.bias.data.zero_() 125 | 126 | def _make_layer(self, block, planes, blocks, stride=1): 127 | downsample = None 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | nn.Conv2d(self.inplanes, planes * block.expansion, 131 | kernel_size=1, stride=stride, bias=False), 132 | nn.BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample)) 137 | self.inplanes = planes * block.expansion 138 | for i in range(1, blocks): 139 | layers.append(block(self.inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | # modify the forward function 145 | x = self.conv1(x) 146 | x = self.bn1(x) 147 | x = self.relu(x) 148 | x = self.maxpool(x) 149 | 150 | x = self.layer1(x) 151 | x = self.layer2(x) 152 | x = self.layer3(x) 153 | x = self.layer4(x) 154 | feat = x 155 | N, C, H, W = feat.shape 156 | 157 | x = self.avgpool(x) 158 | x = x.view(x.size(0), -1) 159 | y = self.fc_all(x) 160 | 161 | params = list(self.parameters()) 162 | fc_weights = params[-2].data 163 | fc_weights = fc_weights.view(1, num_cls, C, 1, 1) 164 | fc_weights = Variable(fc_weights, requires_grad = False) 165 | feat = feat.unsqueeze(1) # N * 1 * C * H * W 166 | hm = feat * fc_weights 167 | hm = hm.sum(2) # N * num_cls * H * W 168 | heatmap = hm 169 | 170 | return y, heatmap 171 | 172 | 173 | def resnet18(pretrained=False, **kwargs): 174 | """Constructs a ResNet-18 model. 175 | 176 | Args: 177 | pretrained (bool): If True, returns a model pre-trained on ImageNet 178 | """ 179 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 180 | if pretrained: 181 | pretrained_dict = model_zoo.load_url(model_urls['resnet18']) 182 | model_dict = model.state_dict() 183 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict} 184 | model_dict.update(pretrained_dict) 185 | model.load_state_dict(model_dict) 186 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 187 | return model 188 | 189 | 190 | def resnet34(pretrained=False, **kwargs): 191 | """Constructs a ResNet-34 model. 192 | 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 197 | if pretrained: 198 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 199 | return model 200 | 201 | 202 | def resnet50(pretrained=False, **kwargs): 203 | """Constructs a ResNet-50 model. 204 | 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | """ 208 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 209 | if pretrained: 210 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 211 | model_dict = model.state_dict() 212 | # for k,v in pretrained_dict.items(): 213 | # if k in model_dict: 214 | # print(k) 215 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict} 216 | model_dict.update(pretrained_dict) 217 | model.load_state_dict(model_dict) 218 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 219 | return model 220 | 221 | 222 | def resnet101(pretrained=False, **kwargs): 223 | """Constructs a ResNet-101 model. 224 | 225 | Args: 226 | pretrained (bool): If True, returns a model pre-trained on ImageNet 227 | """ 228 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 229 | if pretrained: 230 | pretrained_dict = model_zoo.load_url(model_urls['resnet101']) 231 | model_dict = model.state_dict() 232 | # for k,v in pretrained_dict.items(): 233 | # if k in model_dict: 234 | # print(k) 235 | pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict} 236 | model_dict.update(pretrained_dict) 237 | model.load_state_dict(model_dict) 238 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 239 | return model 240 | 241 | 242 | def resnet152(pretrained=False, **kwargs): 243 | """Constructs a ResNet-152 model. 244 | 245 | Args: 246 | pretrained (bool): If True, returns a model pre-trained on ImageNet 247 | """ 248 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 249 | if pretrained: 250 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 251 | return model 252 | -------------------------------------------------------------------------------- /tmp/ms-coco-fs288/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torchvision.models as models 5 | 6 | from torch.autograd import Variable 7 | 8 | import numpy as np 9 | import sys 10 | import math 11 | import time 12 | import matplotlib.pyplot as plt 13 | 14 | import sklearn.metrics as metrics 15 | import scipy.io 16 | 17 | from sklearn.metrics import average_precision_score 18 | from ms_coco import getSubsets 19 | 20 | num_attr = 80 21 | 22 | def imshow(inp, title=None): 23 | """Imshow for Tensor.""" 24 | inp = inp.numpy().transpose((1, 2, 0)) 25 | mean = np.array([0.485, 0.456, 0.406]) 26 | std = np.array([0.229, 0.224, 0.225]) 27 | inp = std * inp + mean 28 | plt.imshow(inp) 29 | plt.show() 30 | if title is not None: 31 | plt.title(title) 32 | 33 | def calc_average_precision(y_true, y_score): 34 | aps = np.zeros(num_attr) 35 | for i in range(num_attr): 36 | true = y_true[i] 37 | score = y_score[i] 38 | 39 | ap = average_precision_score(true, score) 40 | aps[i] = ap 41 | 42 | return aps 43 | 44 | def calc_acc_pr_f1(y_true, y_pred): 45 | precision = np.zeros(num_attr) 46 | recall = np.zeros(num_attr) 47 | accuracy = np.zeros(num_attr) 48 | f1 = np.zeros(num_attr) 49 | for i in range(num_attr): 50 | true = y_true[i] 51 | pred = y_pred[i] 52 | 53 | precision[i] = metrics.precision_score(true, pred) 54 | recall[i] = metrics.recall_score(true, pred) 55 | accuracy[i] = metrics.accuracy_score(true, pred) 56 | f1[i] = metrics.f1_score(true, pred) 57 | 58 | return precision, recall, accuracy, f1 59 | 60 | def calc_mean_acc(y_true, y_pred): 61 | macc = np.zeros(num_attr) 62 | for i in range(num_attr): 63 | true = y_true[i] # 0, 1 64 | pred = y_pred[i] # 0, 1 65 | 66 | temp = true + pred 67 | tp = (temp[temp == 2]).size 68 | tn = (temp[temp == 0]).size 69 | p = (true[true == 1]).size 70 | n = (true[true == 0]).size 71 | 72 | macc[i] = .5 * tp / (p) + .5 * tn / (n) 73 | 74 | return macc 75 | 76 | def calc_acc_pr_f1_overall(y_true, y_pred): 77 | 78 | true = y_true 79 | pred = y_pred 80 | 81 | precision = metrics.precision_score(true, pred) 82 | recall = metrics.recall_score(true, pred) 83 | accuracy = metrics.accuracy_score(true, pred) 84 | f1 = metrics.f1_score(true, pred) 85 | 86 | return precision, recall, accuracy, f1 87 | 88 | def calc_mean_acc_overall(y_true, y_pred): 89 | 90 | true = y_true # 0, 1 91 | pred = y_pred # 0, 1 92 | 93 | temp = true + pred 94 | tp = (temp[temp == 2]).size 95 | tn = (temp[temp == 0]).size 96 | p = (true[true == 1]).size 97 | n = (true[true == 0]).size 98 | macc = .5 * tp / (p) + .5 * tn / (n) 99 | 100 | return macc 101 | 102 | def eval_example(y_true, y_pred): 103 | N = y_true.shape[1] 104 | 105 | acc = 0. 106 | prec = 0. 107 | rec = 0. 108 | f1 = 0. 109 | 110 | for i in range(N): 111 | true_exam = y_true[:,i] # column: labels for an example 112 | pred_exam = y_pred[:,i] 113 | 114 | temp = true_exam + pred_exam 115 | 116 | yi = true_exam.sum() # number of attributes for i 117 | fi = pred_exam.sum() # number of predicted attributes for i 118 | ui = (temp > 0).sum() # temp == 1 or 2 means the union of attributes in yi and fi 119 | ii = (temp == 2).sum() # temp == 2 means the intersection 120 | 121 | acc += 1.0 * ii / ui 122 | prec += 1.0 * ii / fi 123 | rec += 1.0 * ii / yi 124 | 125 | acc /= N 126 | prec /= N 127 | rec /= N 128 | f1 = 2.0 * prec * rec / (prec + rec) 129 | return acc, prec, rec, f1 130 | 131 | 132 | 133 | def test(model, test_loader, epoch): 134 | print("testing ... ") 135 | 136 | probs = torch.FloatTensor() 137 | gtruth = torch.FloatTensor() 138 | probs = probs.cuda() 139 | gtruth = gtruth.cuda() 140 | for i, sample in enumerate(test_loader): 141 | images = sample[0] 142 | labels = sample[4] 143 | labels = labels.type(torch.FloatTensor) 144 | 145 | images = images.cuda() 146 | labels = labels.cuda() 147 | 148 | test_input = Variable(images) 149 | y, _ = model(test_input) 150 | 151 | probs = torch.cat((probs, y.data.transpose(1, 0)), 1) 152 | gtruth = torch.cat((gtruth, labels.transpose(1, 0)), 1) 153 | 154 | print('predicting finished ....') 155 | 156 | preds = np.zeros((probs.size(0), probs.size(1))) 157 | temp = probs.cpu().numpy() 158 | preds[temp > 0.] = 1 159 | 160 | scipy.io.savemat('./preds/prediction_e{}.mat'.format(epoch), dict(gt = gtruth.cpu().numpy(), \ 161 | prob = probs.cpu().numpy(), pred = preds)) 162 | 163 | aps = calc_average_precision(gtruth.cpu().numpy(), probs.cpu().numpy()) 164 | print('>>>>>>>>>>>>>>>>>>>>>>>> Average for Each Attribute >>>>>>>>>>>>>>>>>>>>>>>>>>>') 165 | print("APs") 166 | print(aps) 167 | precision, recall, accuracy, f1 = calc_acc_pr_f1(gtruth.cpu().numpy(), preds) 168 | print('precision scores') 169 | print(precision) 170 | print('recall scores') 171 | print(recall) 172 | print('f1 scores') 173 | print(f1) 174 | print('') 175 | 176 | 177 | 178 | macc = calc_mean_acc(gtruth.cpu().numpy(), preds) 179 | print('mA scores') 180 | print(macc) 181 | 182 | print("\nmean AP: {}".format(aps.mean())) 183 | print('F1-C: {}'.format(f1.mean())) 184 | print('P-C: {}'.format(precision.mean())) 185 | print('R-C: {}'.format(recall.mean())) 186 | print('') 187 | 188 | print('>>>>>>>>>>>>>>>>>>>>>>>> Overall Sample-Label Pairs >>>>>>>>>>>>>>>>>>>>>>>>>>>') 189 | precision, recall, accuracy, f1 = calc_acc_pr_f1_overall(gtruth.cpu().numpy().flatten(), 190 | preds.flatten()) 191 | # macc = calc_mean_acc_overall(gtruth.cpu().numpy().flatten(), preds.flatten()) 192 | # print('mA: {}'.format(macc) ) 193 | print('F1_O: {}'.format(f1)) 194 | print('P_O: {}'.format(precision)) 195 | print('R_O: {}'.format(recall)) 196 | 197 | 198 | print('mean mA') 199 | print(macc.mean()) 200 | 201 | 202 | if __name__ == '__main__': 203 | dataDir = '/data/hguo/Datasets/MS-COCO' 204 | dataType = 'val2014' 205 | trainset, valset = getSubsets(dataDir) 206 | test_loader = torch.utils.data.DataLoader(valset, 207 | batch_size = 8, 208 | shuffle = True, 209 | num_workers = 4) 210 | 211 | model_file = '/data/hguo/models/AttrFlow/mscoco_flow_resnet101_256/model_resnet101_1.pth' 212 | print(model_file) 213 | model = torch.load(model_file) 214 | model.eval() 215 | import time 216 | start_time = time.clock() 217 | test(model, test_loader, 1) 218 | end_time = time.clock() 219 | print('Test time: ', end_time - start_time) 220 | print('\n\n') 221 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from wider import get_subsets 3 | from torch.autograd import Variable 4 | 5 | 6 | def get_dataset(opts): 7 | if opts["dataset"] == "WIDER": 8 | data_dir = opts["data_dir"] 9 | anno_dir = opts["anno_dir"] 10 | trainset, testset = get_subsets(anno_dir, data_dir) 11 | else: 12 | # will be added later 13 | pass 14 | 15 | return trainset, testset 16 | 17 | 18 | def adjust_learning_rate(optimizer, epoch, args): 19 | """Sets the learning rate to the initial LR decayed every 30 epochs""" 20 | lr = args.learning_rate * (args.decay ** (epoch // args.stepsize)) 21 | print("Current learning rate is: {:.5f}".format(lr)) 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] = lr 24 | 25 | 26 | def SigmoidCrossEntropyLoss(x, y, w_p, w_n): 27 | # weighted sigmoid cross entropy loss defined in Li et al. ACPR'15 28 | loss = 0.0 29 | if not x.size() == y.size(): 30 | print("x and y must have the same size") 31 | else: 32 | N = y.size(0) 33 | L = y.size(1) 34 | for i in range(N): 35 | w = torch.zeros(L).cuda() 36 | w[y[i].data == 1] = w_p[y[i].data == 1] 37 | w[y[i].data == 0] = w_n[y[i].data == 0] 38 | 39 | w = Variable(w, requires_grad = False) 40 | temp = - w * ( y[i] * (1 / (1 + (-x[i]).exp())).log() + \ 41 | (1 - y[i]) * ( (-x[i]).exp() / (1 + (-x[i]).exp()) ).log() ) 42 | loss += temp.sum() 43 | 44 | loss = loss / N 45 | return loss 46 | 47 | 48 | def generate_flip_grid(w, h): 49 | # used to flip attention maps 50 | x_ = torch.arange(w).view(1, -1).expand(h, -1) 51 | y_ = torch.arange(h).view(-1, 1).expand(-1, w) 52 | grid = torch.stack([x_, y_], dim=0).float().cuda() 53 | grid = grid.unsqueeze(0).expand(1, -1, -1, -1) 54 | grid[:, 0, :, :] = 2 * grid[:, 0, :, :] / (w - 1) - 1 55 | grid[:, 1, :, :] = 2 * grid[:, 1, :, :] / (h - 1) - 1 56 | 57 | grid[:, 0, :, :] = -grid[:, 0, :, :] 58 | return grid -------------------------------------------------------------------------------- /wider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | import torch 8 | import torch.utils.data as data 9 | import torchvision.transforms as transforms 10 | 11 | from os.path import exists, join, basename 12 | from PIL import Image 13 | 14 | crop_size = 224 15 | scale_size = 224 16 | 17 | def default_loader(path): 18 | return Image.open(path).convert('RGB') 19 | 20 | def imshow(inp, title=None): 21 | """Imshow for Tensor.""" 22 | inp = inp.numpy().transpose((1, 2, 0)) 23 | mean = np.array([0.485, 0.456, 0.406]) 24 | std = np.array([0.229, 0.224, 0.225]) 25 | inp = std * inp + mean 26 | plt.imshow(inp) 27 | plt.show() 28 | if title is not None: 29 | plt.title(title) 30 | 31 | class WiderAttr(data.Dataset): 32 | def __init__(self, subset, anno_dir, data_dir): 33 | self.subset = subset 34 | self.data_dir = data_dir 35 | if self.subset == 'train': 36 | anno_file = join(anno_dir, 'wider_attribute_trainval.json') 37 | else: 38 | anno_file = join(anno_dir, 'wider_attribute_test.json') 39 | 40 | self.transform = transforms.Compose([ 41 | # transforms.CenterCrop((crop_size, crop_size)), 42 | transforms.ToTensor(), 43 | transforms.Normalize(mean = [0.485, 0.456, 0.406], 44 | std = [0.229, 0.224, 0.225]), 45 | ]) 46 | 47 | with open(anno_file) as AF: 48 | anno = json.load(AF) 49 | 50 | self.images = anno['images'] 51 | self.attributes = anno['attribute_id_map'] 52 | self.scenes = anno['scene_id_map'] 53 | 54 | # create a dict to store separate bboxes 55 | samples = {} 56 | num_img = len(self.images) 57 | s_id = 0 58 | for i in range(num_img): 59 | file_name = self.images[i]['file_name']#.encode('utf-8') 60 | scene_id = self.images[i]['scene_id'] 61 | targets = self.images[i]['targets'] 62 | num_tar = len(targets) 63 | 64 | for j in range(num_tar): 65 | attribute = targets[j]['attribute'] 66 | bbox = targets[j]['bbox'] 67 | samples[s_id] = {} 68 | samples[s_id]['file_name'] = file_name 69 | samples[s_id]['scene_id'] = scene_id 70 | samples[s_id]['labels'] = attribute 71 | samples[s_id]['bbox'] = bbox 72 | s_id += 1 73 | 74 | # img_file = join(data_dir, file_name) 75 | # img = default_loader(img_file) 76 | # wd, ht = img.size 77 | # if bbox[0] > wd or bbox[1] > ht: 78 | # print file_name 79 | # print bbox 80 | 81 | self.samples = samples 82 | 83 | 84 | def __getitem__(self, idx): 85 | # sampe: self.samples[idx] 86 | sample = self.samples[idx] 87 | img_file = join(self.data_dir, sample['file_name']) 88 | labels = sample['labels'] 89 | bbox = sample['bbox'] 90 | scene_id = sample['scene_id'] 91 | 92 | # load image 93 | img = default_loader(img_file) 94 | wd, ht = img.size 95 | 96 | # crop bounding box 97 | # bbox: x, y, w, h -- need to be x1, y1, x2, y2 98 | # extend 99 | x = bbox[0] 100 | y = bbox[1] 101 | w = bbox[2] 102 | h = bbox[3] 103 | 104 | bbox[2] = x+w 105 | bbox[3] = y+h 106 | 107 | # there are some samples not annotated well 108 | if x > wd or y > ht: 109 | bbox = [0, 0, wd, ht] 110 | 111 | img_crop = img.crop(tuple(bbox)) 112 | t1,t2 = img_crop.size 113 | if t1 == 0. or t2 == 0.: 114 | # find if there still images not work 115 | print(sample) 116 | 117 | img1 = img_crop.resize((224, 224)) 118 | img2 = img_crop.resize((192, 192)) 119 | 120 | # large orig 121 | image_lo = self.transform(img1) 122 | # large flip 123 | image_lf = self.transform(img1.transpose(Image.FLIP_LEFT_RIGHT)) 124 | 125 | # small orig 126 | image_so = self.transform(img2) 127 | # small flip 128 | image_sf = self.transform(img2.transpose(Image.FLIP_LEFT_RIGHT)) 129 | 130 | labels = torch.FloatTensor(labels) 131 | # print sample['file_name'] 132 | # print labels 133 | # imshow(image) 134 | # import pdb; pdb.set_trace() 135 | return image_lo, image_lf, image_so, image_sf, labels 136 | 137 | def __len__(self): 138 | return len(self.samples) 139 | 140 | 141 | def get_subsets(anno_dir, data_dir): 142 | trainset = WiderAttr('train', anno_dir, data_dir) 143 | testset = WiderAttr('test', anno_dir, data_dir) 144 | return trainset, testset 145 | 146 | if __name__ == '__main__': 147 | subset = 'test' 148 | anno_dir = '/path/to/wider_attribute_annotation' 149 | data_dir = '/path/to/Image' 150 | wa = WiderAttr(subset, anno_dir, data_dir) 151 | trainset, testset = get_subsets(anno_dir, data_dir) 152 | wa[1] 153 | 154 | train_loader = torch.utils.data.DataLoader( 155 | trainset, 156 | batch_size = 32, 157 | shuffle = True, 158 | num_workers = 8, 159 | ) 160 | test_loader = torch.utils.data.DataLoader( 161 | testset, 162 | batch_size = 1, 163 | shuffle = False, 164 | num_workers = 2, 165 | ) 166 | --------------------------------------------------------------------------------