├── PSOL_inference.py ├── PSOL_training.py ├── README.md ├── generate_box_imagenet.py ├── loader ├── ddt_imagenet_dataset.py └── imagenet_loader.py ├── models └── models.py ├── poster.png ├── requirements.txt └── utils ├── IoU.py ├── augment.py ├── func.py ├── nms.py └── vis.py /PSOL_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as transforms 7 | from torch.backends import cudnn 8 | import torch.nn as nn 9 | import torchvision 10 | from PIL import Image 11 | from utils.func import * 12 | from utils.vis import * 13 | from utils.IoU import * 14 | from models.models import choose_locmodel,choose_clsmodel 15 | from utils.augment import * 16 | import argparse 17 | 18 | parser = argparse.ArgumentParser(description='Parameters for PSOL evaluation') 19 | parser.add_argument('--loc-model', metavar='locarg', type=str, default='vgg16',dest='locmodel') 20 | parser.add_argument('--cls-model', metavar='clsarg', type=str, default='vgg16',dest='clsmodel') 21 | parser.add_argument('--input_size',default=256,dest='input_size') 22 | parser.add_argument('--crop_size',default=224,dest='crop_size') 23 | parser.add_argument('--ten-crop', help='tencrop', action='store_true',dest='tencrop') 24 | parser.add_argument('--gpu',help='which gpu to use',default='4',dest='gpu') 25 | parser.add_argument('data',metavar='DIR',help='path to imagenet dataset') 26 | 27 | args = parser.parse_args() 28 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 29 | os.environ['OMP_NUM_THREADS'] = "4" 30 | os.environ['MKL_NUM_THREADS'] = "4" 31 | cudnn.benchmark = True 32 | TEN_CROP = args.tencrop 33 | normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 34 | transform = transforms.Compose([ 35 | transforms.Resize((args.input_size,args.input_size)), 36 | transforms.CenterCrop(args.crop_size), 37 | transforms.ToTensor(), 38 | normalize 39 | ]) 40 | cls_transform = transforms.Compose([ 41 | transforms.Resize((args.input_size,args.input_size)), 42 | transforms.CenterCrop(args.crop_size), 43 | transforms.ToTensor(), 44 | normalize 45 | ]) 46 | ten_crop_aug = transforms.Compose([ 47 | transforms.Resize((args.input_size,args.input_size)), 48 | transforms.TenCrop(args.crop_size), 49 | transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), 50 | transforms.Lambda(lambda crops: torch.stack([normalize(crop) for crop in crops])), 51 | ]) 52 | locname = args.locmodel 53 | model = choose_locmodel(locname, True) 54 | 55 | print(model) 56 | model = model.to(0) 57 | model.eval() 58 | clsname = args.clsmodel 59 | cls_model = choose_clsmodel(clsname) 60 | cls_model = cls_model.to(0) 61 | cls_model.eval() 62 | 63 | root = args.data 64 | val_imagedir = os.path.join(root, 'val') 65 | 66 | anno_root = os.path.join(root,'bbox') 67 | val_annodir = os.path.join(anno_root, 'myval') 68 | 69 | 70 | classes = os.listdir(val_imagedir) 71 | classes.sort() 72 | temp_softmax = nn.Softmax() 73 | #print(classes[0]) 74 | 75 | 76 | class_to_idx = {classes[i]:i for i in range(len(classes))} 77 | 78 | result = {} 79 | 80 | accs = [] 81 | accs_top5 = [] 82 | loc_accs = [] 83 | cls_accs = [] 84 | final_cls = [] 85 | final_loc = [] 86 | final_clsloc = [] 87 | final_clsloctop5 = [] 88 | final_ind = [] 89 | for k in range(1000): 90 | cls = classes[k] 91 | 92 | total = 0 93 | IoUSet = [] 94 | IoUSetTop5 = [] 95 | LocSet = [] 96 | ClsSet = [] 97 | 
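# For each image in this class, the code below: (1) runs the localization model on a resized and
# center-cropped copy to predict a single class-agnostic box, (2) runs the classification model
# (optionally with ten-crop averaging) to obtain the top-5 class predictions, and (3) maps the
# ground-truth boxes through the same Resize + CenterCrop transform before comparing them with the
# predicted box. LocSet collects the raw best IoU (GT-known localization), IoUSet zeroes the IoU
# when the top-1 class is wrong (top-1 Cls-Loc), and IoUSetTop5 keeps it if any of the top-5
# classes is correct. A prediction counts as correct when its (possibly zeroed) IoU exceeds 0.5.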
98 | files = os.listdir(os.path.join(val_imagedir, cls)) 99 | files.sort() 100 | 101 | for (i, name) in enumerate(files): 102 | # raw_img = cv2.imread(os.path.join(imagedir, cls, name)) 103 | now_index = int(name.split('_')[-1].split('.')[0]) 104 | final_ind.append(now_index-1) 105 | xmlfile = os.path.join(val_annodir, cls, name.split('.')[0] + '.xml') 106 | gt_boxes = get_cls_gt_boxes(xmlfile, cls) 107 | if len(gt_boxes)==0: 108 | continue 109 | 110 | raw_img = Image.open(os.path.join(val_imagedir, cls, name)).convert('RGB') 111 | w, h = raw_img.size 112 | 113 | with torch.no_grad(): 114 | img = transform(raw_img) 115 | img = torch.unsqueeze(img, 0) 116 | img = img.to(0) 117 | reg_outputs = model(img) 118 | 119 | bbox = to_data(reg_outputs) 120 | bbox = torch.squeeze(bbox) 121 | bbox = bbox.numpy() 122 | if TEN_CROP: 123 | img = ten_crop_aug(raw_img) 124 | img = img.to(0) 125 | vgg16_out = cls_model(img) 126 | vgg16_out = temp_softmax(vgg16_out) 127 | vgg16_out = torch.mean(vgg16_out,dim=0,keepdim=True) 128 | vgg16_out = torch.topk(vgg16_out, 5, 1)[1] 129 | else: 130 | img = cls_transform(raw_img) 131 | img = torch.unsqueeze(img, 0) 132 | img = img.to(0) 133 | vgg16_out = cls_model(img) 134 | vgg16_out = torch.topk(vgg16_out, 5, 1)[1] 135 | vgg16_out = to_data(vgg16_out) 136 | vgg16_out = torch.squeeze(vgg16_out) 137 | vgg16_out = vgg16_out.numpy() 138 | out = vgg16_out 139 | ClsSet.append(out[0]==class_to_idx[cls]) 140 | 141 | #handle resize and centercrop for gt_boxes 142 | for j in range(len(gt_boxes)): 143 | temp_list = list(gt_boxes[j]) 144 | raw_img_i, gt_bbox_i = ResizedBBoxCrop((256,256))(raw_img, temp_list) 145 | raw_img_i, gt_bbox_i = CenterBBoxCrop((224))(raw_img_i, gt_bbox_i) 146 | w, h = raw_img_i.size 147 | 148 | gt_bbox_i[0] = gt_bbox_i[0] * w 149 | gt_bbox_i[2] = gt_bbox_i[2] * w 150 | gt_bbox_i[1] = gt_bbox_i[1] * h 151 | gt_bbox_i[3] = gt_bbox_i[3] * h 152 | 153 | gt_boxes[j] = gt_bbox_i 154 | 155 | w, h = raw_img_i.size 156 | 157 | bbox[0] = bbox[0] * w 158 | bbox[2] = bbox[2] * w + bbox[0] 159 | bbox[1] = bbox[1] * h 160 | bbox[3] = bbox[3] * h + bbox[1] 161 | 162 | max_iou = -1 163 | for gt_bbox in gt_boxes: 164 | iou = IoU(bbox, gt_bbox) 165 | if iou > max_iou: 166 | max_iou = iou 167 | 168 | LocSet.append(max_iou) 169 | temp_loc_iou = max_iou 170 | if out[0] != class_to_idx[cls]: 171 | max_iou = 0 172 | 173 | # print(max_iou) 174 | result[os.path.join(cls, name)] = max_iou 175 | IoUSet.append(max_iou) 176 | #cal top5 IoU 177 | max_iou = 0 178 | for i in range(5): 179 | if out[i] == class_to_idx[cls]: 180 | max_iou = temp_loc_iou 181 | IoUSetTop5.append(max_iou) 182 | #visualization code 183 | ''' 184 | opencv_image = deepcopy(np.array(raw_img_i)) 185 | opencv_image = opencv_image[:, :, ::-1].copy() 186 | for gt_bbox in gt_boxes: 187 | cv2.rectangle(opencv_image, (int(gt_bbox[0]), int(gt_bbox[1])), 188 | (int(gt_bbox[2]), int(gt_bbox[3])), (0, 255, 0), 4) 189 | cv2.rectangle(opencv_image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), 190 | (0, 255, 255), 4) 191 | cv2.imwrite(os.path.join(savepath, str(name) + '.jpg'), np.asarray(opencv_image)) 192 | ''' 193 | cls_loc_acc = np.sum(np.array(IoUSet) > 0.5) / len(IoUSet) 194 | final_clsloc.extend(IoUSet) 195 | cls_loc_acc_top5 = np.sum(np.array(IoUSetTop5) > 0.5) / len(IoUSetTop5) 196 | final_clsloctop5.extend(IoUSetTop5) 197 | loc_acc = np.sum(np.array(LocSet) > 0.5) / len(LocSet) 198 | final_loc.extend(LocSet) 199 | cls_acc = np.sum(np.array(ClsSet))/len(ClsSet) 200 | final_cls.extend(ClsSet) 201 | print('{} cls-loc acc is 
{}, loc acc is {}, vgg16 cls acc is {}'.format(cls, cls_loc_acc, loc_acc, cls_acc)) 202 | with open('inference_CorLoc.txt', 'a+') as corloc_f: 203 | corloc_f.write('{} {}\n'.format(cls, loc_acc)) 204 | accs.append(cls_loc_acc) 205 | accs_top5.append(cls_loc_acc_top5) 206 | loc_accs.append(loc_acc) 207 | cls_accs.append(cls_acc) 208 | if (k+1) %100==0: 209 | print(k) 210 | 211 | 212 | print(accs) 213 | print('Cls-Loc acc {}'.format(np.mean(accs))) 214 | print('Cls-Loc acc Top 5 {}'.format(np.mean(accs_top5))) 215 | 216 | print('GT Loc acc {}'.format(np.mean(loc_accs))) 217 | print('{} cls acc {}'.format(clsname, np.mean(cls_accs))) 218 | with open('Corloc_result.txt', 'w') as f: 219 | for k in sorted(result.keys()): 220 | f.write('{} {}\n'.format(k, str(result[k]))) 221 | -------------------------------------------------------------------------------- /PSOL_training.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # In[1]: 4 | 5 | import time 6 | import os 7 | import random 8 | import math 9 | 10 | import torch 11 | import torchvision 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | from torch.utils.data import Dataset 17 | from torch.optim import lr_scheduler 18 | from torch.autograd import Variable 19 | from torchvision import datasets, models, transforms 20 | from PIL import Image 21 | from loader.imagenet_loader import ImageNetDataset 22 | from utils.func import * 23 | from utils.IoU import * 24 | from models.models import * 25 | import warnings 26 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 27 | import argparse 28 | 29 | # In[2]: 30 | 31 | ### Some utilities 32 | 33 | 34 | # In[3]: 35 | 36 | def compute_reg_acc(preds, targets, theta=0.5): 37 | # preds = box_transform_inv(preds.clone(), im_sizes) 38 | # preds = crop_boxes(preds, im_sizes) 39 | # targets = box_transform_inv(targets.clone(), im_sizes) 40 | IoU = compute_IoU(preds, targets) 41 | # print(preds, targets, IoU) 42 | corr = (IoU >= theta).sum() 43 | return float(corr) / float(preds.size(0)) 44 | 45 | 46 | def compute_cls_acc(preds, targets): 47 | pred = torch.max(preds, 1)[1] 48 | # print(preds, pred) 49 | num_correct = (pred == targets).sum() 50 | return float(num_correct) / float(preds.size(0)) 51 | 52 | 53 | def compute_acc(reg_preds, reg_targets, cls_preds, cls_targets, theta=0.5): 54 | IoU = compute_IoU(reg_preds, reg_targets) 55 | reg_corr = (IoU >= theta) 56 | 57 | pred = torch.max(cls_preds, 1)[1] 58 | cls_corr = (pred == cls_targets) 59 | 60 | corr = (reg_corr & cls_corr).sum() 61 | 62 | return float(corr) / float(reg_preds.size(0)) 63 | 64 | 65 | class AverageMeter(object): 66 | def __init__(self): 67 | self.reset() 68 | 69 | def reset(self): 70 | self.val = 0 71 | self.avg = 0 72 | self.sum = 0 73 | self.cnt = 0 74 | 75 | def update(self, val, n=1): 76 | self.val = val 77 | self.sum += val * n 78 | self.cnt += n 79 | self.avg = self.sum / self.cnt 80 | 81 | 82 | # ### Visualize training data 83 | 84 | # In[8]: 85 | 86 | train_transform = transforms.Compose([ 87 | transforms.ToTensor(), 88 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 89 | ]) 90 | test_transfrom = transforms.Compose([ 91 | transforms.ToTensor(), 92 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 93 | ]) 94 | # ### Training 95 | 96 | # In[10]: 97 | 98 | # prepare data 99 | parser = 
argparse.ArgumentParser(description='Parameters for PSOL evaluation') 100 | parser.add_argument('--loc-model', metavar='locarg', type=str, default='resnet50',dest='locmodel') 101 | parser.add_argument('--input_size',default=256,dest='input_size') 102 | parser.add_argument('--crop_size',default=224,dest='crop_size') 103 | parser.add_argument('--epochs',default=6,dest='epochs') 104 | parser.add_argument('--gpu',help='which gpu to use',default='4,5,6,7',dest='gpu') 105 | parser.add_argument('--ddt_path',help='generated ddt path',default='ImageNet/Projection/VGG16-448',dest="ddt_path") 106 | parser.add_argument('--gt_path',help='validation groundtruth path',default='ImageNet_gt/',dest="gt_path") 107 | parser.add_argument('--save_path',help='model save path',default='ImageNet_checkpoint',dest='save_path') 108 | parser.add_argument('--batch_size',default=256,dest='batch_size') 109 | parser.add_argument('data',metavar='DIR',help='path to imagenet dataset') 110 | 111 | 112 | args = parser.parse_args() 113 | batch_size = args.batch_size 114 | #lr = 1e-3 * (batch_size / 64) 115 | lr = 1e-3 * (batch_size / 256) 116 | # lr = 3e-4 117 | momentum = 0.9 118 | weight_decay = 1e-4 119 | print_freq = 10 120 | root = args.data 121 | savepath = args.save_path 122 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 123 | os.environ['OMP_NUM_THREADS'] = '20' 124 | os.environ['MKL_NUM_THREADS'] = '20' 125 | 126 | MyTrainData = ImageNetDataset(root=root, ddt_path=args.ddt_path, gt_path=args.gt_path,train=True, input_size=args.input_size,crop_size = args.crop_size, 127 | transform=train_transform) 128 | MyTestData = ImageNetDataset(root=root, ddt_path=args.ddt_path, gt_path=args.gt_path, train=False, input_size=args.input_size,crop_size = args.crop_size, 129 | transform=test_transfrom) 130 | 131 | 132 | 133 | train_loader = torch.utils.data.DataLoader(dataset=MyTrainData, 134 | batch_size=batch_size, 135 | shuffle=True, num_workers=20, pin_memory=True) 136 | test_loader = torch.utils.data.DataLoader(dataset=MyTestData, batch_size=batch_size, 137 | num_workers=8, pin_memory=True) 138 | dataloaders = {'train': train_loader, 'test': test_loader} 139 | 140 | # construct model 141 | model = choose_locmodel(args.locmodel) 142 | print(model) 143 | model = torch.nn.DataParallel(model).cuda() 144 | reg_criterion = nn.MSELoss().cuda() 145 | dense1_params = list(map(id, model.module.fc.parameters())) 146 | rest_params = filter(lambda x: id(x) not in dense1_params, model.parameters()) 147 | param_list = [{'params': model.module.fc.parameters(), 'lr': 2 * lr}, 148 | {'params': rest_params,'lr': 1 * lr}] 149 | optimizer = torch.optim.SGD(param_list, lr, momentum=momentum, 150 | weight_decay=weight_decay) 151 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.epochs, gamma=0.1) 152 | torch.backends.cudnn.benchmark = True 153 | best_model_state = model.state_dict() 154 | best_epoch = -1 155 | best_acc = 0.0 156 | 157 | epoch_loss = {'train': [], 'test': []} 158 | epoch_acc = {'train': [], 'test': []} 159 | epochs = args.epochs 160 | lambda_reg = 0 161 | for epoch in range(epochs): 162 | lambda_reg = 50 163 | for phase in ('train', 'test'): 164 | reg_accs = AverageMeter() 165 | accs = AverageMeter() 166 | reg_losses = AverageMeter() 167 | batch_time = AverageMeter() 168 | data_time = AverageMeter() 169 | if phase == 'train': 170 | if epoch >0: 171 | scheduler.step() 172 | model.train() 173 | else: 174 | model.eval() 175 | 176 | end = time.time() 177 | cnt = 0 178 | for ims, labels, boxes in dataloaders[phase]: 179 | 
data_time.update(time.time() - end)
180 | inputs = Variable(ims.cuda())
181 | boxes = Variable(boxes.cuda())
182 | labels = Variable(labels.cuda())
183 |
184 | optimizer.zero_grad()
185 |
186 | # forward
187 | if phase == 'train':
188 | if 'inception' in args.locmodel:
189 | reg_outputs1,reg_outputs2 = model(inputs)
190 | reg_loss1 = reg_criterion(reg_outputs1, boxes)
191 | reg_loss2 = reg_criterion(reg_outputs2, boxes)
192 | reg_loss = 1 * reg_loss1 + 0.3 * reg_loss2
193 | reg_outputs = reg_outputs1
194 | else:
195 | reg_outputs = model(inputs)
196 | reg_loss = reg_criterion(reg_outputs, boxes)
197 | #_,reg_loss = compute_iou(reg_outputs,boxes)
198 | else:
199 | with torch.no_grad():
200 | reg_outputs = model(inputs)
201 | reg_loss = reg_criterion(reg_outputs, boxes)
202 | loss = lambda_reg * reg_loss
203 | reg_acc = compute_reg_acc(reg_outputs.data.cpu(), boxes.data.cpu())
204 |
205 | nsample = inputs.size(0)
206 | reg_accs.update(reg_acc, nsample)
207 | reg_losses.update(reg_loss.item(), nsample)
208 | if phase == 'train':
209 | loss.backward()
210 | optimizer.step()
211 | batch_time.update(time.time() - end)
212 | end = time.time()
213 | if cnt % print_freq == 0:
214 | print(
215 | '[{}]\tEpoch: {}/{}\t Iter: {}/{} Time {:.3f} ({:.3f})\t Data {:.3f} ({:.3f})\tLoc Loss: {:.4f}\tLoc Acc: {:.2%}\t'.format(
216 | phase, epoch + 1, epochs, cnt, len(dataloaders[phase]), batch_time.val,batch_time.avg,data_time.val,data_time.avg,lambda_reg * reg_losses.avg, reg_accs.avg))
217 | cnt += 1
218 | if phase == 'test' and reg_accs.avg > best_acc:
219 | best_acc = reg_accs.avg
220 | best_epoch = epoch
221 | best_model_state = model.state_dict()
222 |
223 | elapsed_time = time.time() - end
224 | print(
225 | '[{}]\tEpoch: {}/{}\tLoc Loss: {:.4f}\tLoc Acc: {:.2%}\tTime: {:.3f}'.format(
226 | phase, epoch + 1, epochs, lambda_reg * reg_losses.avg, reg_accs.avg,elapsed_time))
227 | epoch_loss[phase].append(reg_losses.avg)
228 | epoch_acc[phase].append(reg_accs.avg)
229 |
230 | print('[Info] best test acc: {:.2%} at {}th epoch'.format(best_acc, best_epoch + 1))
231 | if not os.path.exists(savepath):
232 | os.makedirs(savepath)
233 | torch.save(model.state_dict(), os.path.join(savepath,'checkpoint_localization_imagenet_ddt_' + args.locmodel + "_" + str(epoch) + '.pth.tar'))
234 | torch.save(best_model_state, os.path.join(savepath,'best_cls_localization_imagenet_ddt_' + args.locmodel + "_" + str(epoch) + '.pth.tar'))
235 |
236 |
237 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PSOL
2 | ## We have updated the training and validation code for PSOL on ImageNet.
3 | This is the official website for PSOL.
4 |
5 | We have uploaded the poster for PSOL.
6 | ![Poster][poster]
7 | This package is developed by Mr. Chen-Lin Zhang (http://www.lamda.nju.edu.cn/zhangcl/) and Mr. Yun-Hao Cao (http://www.lamda.nju.edu.cn/caoyh/). If you have any problems with
8 | the code, please feel free to contact Mr. Chen-Lin Zhang (zhangcl@lamda.nju.edu.cn).
9 | The package is free for academic usage. You run it at your own risk. For other purposes, please contact Prof. Jianxin Wu (mailto:wujx2001@gmail.com).
10 |
11 | If you find our package useful in your research, please cite our paper:
12 |
13 | Reference:
14 |
15 | [1] C.-L. Zhang, Y.-H. Cao and J. Wu. Rethinking the Route Towards Weakly Supervised Object Localization.
16 |
In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR 2020), Seattle, WA (Virtual), USA, pp. 13460-13469.
17 | ## Requirements
18 | The code needs [PyTorch][pytorch] and OpenCV.
19 |
20 | A GPU is optional, but recommended for faster speed.
21 |
22 | Other requirements are in requirements.txt. You can simply install them with `pip install -r requirements.txt`. We recommend using Anaconda with virtual environments.
23 |
24 | ## Usage
25 |
26 |
27 | ### Prepare datasets
28 | #### Training:
29 | First you should download the ImageNet training set. Then, use our program `generate_box_imagenet.py` to generate pseudo bounding boxes with the DDT method:
30 |
31 | `python generate_box_imagenet.py PATH`
32 |
33 | We follow the organization of the official PyTorch ImageNet training script. There are some options in `generate_box_imagenet.py`:
34 |
35 | ```
36 | PATH the path to the ImageNet dataset
37 | --input_size input size for each image (448 by default)
38 | --gpu which gpus to use
39 | --output_path the output pseudo boxes folder
40 | --batch_size batch size for the forward pass of the CNN
41 | ```
42 |
43 | If you use the default parameters, you are using DDT-VGG16 at 448x448 as illustrated in the paper. For efficiency, we provide the generated VGG-16 448x448 boxes as [pseudo boxes][pseudo boxes]. You can download them directly.
44 | #### Validation:
45 | First you should download the ImageNet validation set and the corresponding annotation XMLs. We need both the validation set and the annotation XMLs in PyTorch format. Please refer to the [ImageNet example][imagenet example] for details.
46 |
47 | #### Folder Structure
48 | We expect the ImageNet folder to have this structure:
49 | ```
50 | root/
51 | |-- train/
52 | | |-- class1 |-- image1.jpg
53 | | |-- class2 |-- image2.jpg
54 | | |-- class3 |-- image3.jpg
55 | | ...
56 | |-- val/
57 | | |-- class1 |-- image1.jpg
58 | | |-- class2 |-- image2.jpg
59 | | |-- class3 |-- image3.jpg
60 | | ...
61 | |-- myval/ (ground-truth annotation XML files; you can change the folder name and modify it in Line 67 of PSOL_inference.py)
62 | | |-- class1 |-- image1.xml
63 | | |-- class2 |-- image2.xml
64 | | |-- class3 |-- image3.xml
65 | ...
66 | ```
67 |
68 | ### Training
69 |
70 | Please first refer to the Prepare datasets section for generating pseudo bounding boxes. Then, you can use `PSOL_training.py` to train `PSOL-Sep` models on the pseudo bounding boxes.
71 |
72 | `python PSOL_training.py PATH`. There are some options in `PSOL_training.py`:
73 |
74 | ```
75 | PATH the path to the ImageNet dataset
76 | --loc-model which localization model to use. Current options: densenet161, vgg16, vgggap, resnet50
77 | --input_size input size for each image (256 by default)
78 | --crop_size crop size for each image (only used in the validation process)
79 | --ddt_path the path of the generated DDT pseudo boxes
80 | --gt_path the path to the groundtruth for the ImageNet val dataset
81 | --save_path where to save the checkpoint
82 | --gpu which gpus to use
83 | --batch_size batch size for the forward pass of the CNN
84 | ```
85 | ### Testing
86 |
87 | We have provided some [pretrained models and cached groundtruth files][modellink]. If you want to test directly, please download them.
88 |
89 | We use the same options as `PSOL_training.py` in `PSOL_inference.py`. Then you can run:
90 |
91 | `python PSOL_inference.py --loc-model {$LOC_MODEL} --cls-model {$CLS_MODEL} {--ten-crop}`
92 |
93 | Extra Options:
94 | ```
95 | --cls-model represents the classification models we support now.
Current options: densenet161, vgg16, resnet50, inceptionv3, dpn131, efficientnetb7 (need to install efficientnet-pytorch). 96 | --ten-crop use ten crop to boost the classification performance. 97 | ``` 98 | 99 | Please note that for SOTA classification models, you should change the resolutions in cls_transforms (Line248-Line249). 100 | 101 | Then you can get the final Corloc and Clsloc result. 102 | 103 | [pytorch]:https://pytorch.org/ 104 | [imagenet example]:https://github.com/pytorch/examples/tree/master/imagenet 105 | [annolink]:https://drive.google.com/open?id=1XcSJ4WIhgema_jSPI0PJX14W2l8xQkvP 106 | [poster]:poster.png 107 | [modellink]:https://drive.google.com/open?id=1uepi6B6cL2EorygHgODG77HN9V1_ZTXX 108 | [pseudo boxes]: https://drive.google.com/file/d/1bOrHlcqLhCMt5-d6FH-GryNEIn66YLWi/view?usp=sharing 109 | -------------------------------------------------------------------------------- /generate_box_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import json 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as transforms 8 | from torch.backends import cudnn 9 | from torch.autograd import Variable 10 | import torch.nn as nn 11 | import torchvision 12 | import torchvision.models as models 13 | from PIL import Image 14 | from skimage import measure 15 | # from scipy.misc import imresize 16 | from utils.func import * 17 | from utils.vis import * 18 | from utils.IoU import * 19 | import argparse 20 | from loader.ddt_imagenet_dataset import DDTImageNetDataset 21 | 22 | 23 | parser = argparse.ArgumentParser(description='Parameters for DDT generate box') 24 | parser.add_argument('--input_size',default=448,dest='input_size') 25 | parser.add_argument('data',metavar='DIR',help='path to imagenet dataset') 26 | parser.add_argument('--gpu',help='which gpu to use',default='4,5,6,7',dest='gpu') 27 | parser.add_argument('--output_path',default='ImageNet/Projection/VGG16-448',dest='output_path') 28 | parser.add_argument('--batch_size',default=256,dest='batch_size') 29 | args = parser.parse_args() 30 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 31 | os.environ['OMP_NUM_THREADS'] = "10" 32 | os.environ['MKL_NUM_THREADS'] = "10" 33 | cudnn.benchmark = True 34 | model_ft = models.vgg16(pretrained=True) 35 | model = model_ft.features 36 | #removed = list(model.children())[:-1] 37 | #model = torch.nn.Sequential(*removed) 38 | model = torch.nn.DataParallel(model).cuda() 39 | model.eval() 40 | projdir = args.output_path 41 | if not os.path.exists(projdir): 42 | os.makedirs(projdir) 43 | 44 | transform = transforms.Compose([ 45 | transforms.Resize((args.input_size,args.input_size)), 46 | transforms.ToTensor(), 47 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 48 | ]) 49 | batch_size = args.batch_size 50 | a = DDTImageNetDataset(root=os.path.join(args.data,'train'),batch_size=args.batch_size, transforms=transform) 51 | 52 | # print(classes[0]) 53 | 54 | for class_ind in range(1000): 55 | #if class_ind == 10: 56 | # import sys 57 | # sys.exit() 58 | now_class_dict = {} 59 | feature_list = [] 60 | ddt_bbox = {} 61 | with torch.no_grad(): 62 | for (input_img,path) in a[class_ind]: 63 | input_img = to_variable(input_img) 64 | output = model(input_img) 65 | output = to_data(output) 66 | output = torch.squeeze(output).numpy() 67 | if len(output.shape) == 3: 68 | output = np.expand_dims(output,0) 69 | output = np.transpose(output,(0,2,3,1)) 70 | n,h,w,c = 
output.shape 71 | for i in range(n): 72 | now_class_dict[path[i]] = output[i,:,:,:] 73 | output = np.reshape(output,(n*h*w,c)) 74 | feature_list.append(output) 75 | X = np.concatenate(feature_list,axis=0) 76 | mean_matrix = np.mean(X, 0) 77 | X = X - mean_matrix 78 | print("Before PCA") 79 | trans_matrix = sk_pca(X, 1) 80 | print("AFTER PCA") 81 | cls = a.label_class_dict[class_ind] 82 | # save json 83 | d = {'mean_matrix': mean_matrix.tolist(), 'trans_matrix': trans_matrix.tolist()} 84 | with open(os.path.join(projdir, '%s_trans.json' % cls), 'w') as f: 85 | json.dump(d, f) 86 | # load json 87 | with open(os.path.join(projdir, '%s_trans.json' % cls), 'r') as f: 88 | t = json.load(f) 89 | mean_matrix = np.array(t['mean_matrix']) 90 | trans_matrix = np.array(t['trans_matrix']) 91 | 92 | print('trans_matrix shape is {}'.format(trans_matrix.shape)) 93 | cnt = 0 94 | for k,v in now_class_dict.items(): 95 | w = 14 96 | h = 14 97 | he = 448 98 | wi = 448 99 | v = np.reshape(v,(h * w,512)) 100 | v = v - mean_matrix 101 | 102 | heatmap = np.dot(v, trans_matrix.T) 103 | heatmap = np.reshape(heatmap, (h, w)) 104 | highlight = np.zeros(heatmap.shape) 105 | highlight[heatmap > 0] = 1 106 | # max component 107 | all_labels = measure.label(highlight) 108 | highlight = np.zeros(highlight.shape) 109 | highlight[all_labels == count_max(all_labels.tolist())] = 1 110 | 111 | # visualize heatmap 112 | # show highlight in origin image 113 | highlight = np.round(highlight * 255) 114 | highlight_big = cv2.resize(highlight, (he, wi), interpolation=cv2.INTER_NEAREST) 115 | props = measure.regionprops(highlight_big.astype(int)) 116 | 117 | if len(props) == 0: 118 | #print(highlight) 119 | bbox = [0, 0, wi, he] 120 | else: 121 | temp = props[0]['bbox'] 122 | bbox = [temp[1], temp[0], temp[3], temp[2]] 123 | 124 | temp_bbox = [bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]] 125 | temp_save_box = [x / 448 for x in temp_bbox] 126 | ddt_bbox[os.path.join(cls, k)] = temp_save_box 127 | 128 | highlight_big = np.expand_dims(np.asarray(highlight_big), 2) 129 | highlight_3 = np.concatenate((np.zeros((he, wi, 1)), np.zeros((he, wi, 1))), axis=2) 130 | highlight_3 = np.concatenate((highlight_3, highlight_big), axis=2) 131 | cnt +=1 132 | if cnt < 10: 133 | savepath = 'ImageNet/Visualization/DDT/VGG16-vis-test/%s' % cls 134 | if not os.path.exists(savepath): 135 | os.makedirs(savepath) 136 | from PIL import Image 137 | raw_img = Image.open(k).convert("RGB") 138 | raw_img = raw_img.resize((448,448)) 139 | raw_img = np.asarray(raw_img) 140 | raw_img = cv2.cvtColor(raw_img,cv2.COLOR_BGR2RGB) 141 | cv2.rectangle(raw_img, (temp_bbox[0], temp_bbox[1]), 142 | (temp_bbox[2] + temp_bbox[0], temp_bbox[3] + temp_bbox[1]), (255, 0, 0), 4) 143 | save_name = k.split('/')[-1] 144 | cv2.imwrite(os.path.join(savepath, save_name), np.asarray(raw_img)) 145 | with open(os.path.join(projdir, '%s_bbox.json' % cls), 'w') as fp: 146 | json.dump(ddt_bbox, fp) -------------------------------------------------------------------------------- /loader/ddt_imagenet_dataset.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import ImageFolder 2 | import torch.utils.data as data 3 | from PIL import Image 4 | def pil_loader(path): 5 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 6 | with open(path, 'rb') as f: 7 | with Image.open(f) as img: 8 | return img.convert('RGB') 9 | class SubImageDataset(data.Dataset): 10 | def __init__(self, 
now_list, transforms=None): 11 | self.imgs = now_list 12 | self.transforms = transforms 13 | 14 | def __getitem__(self, index): 15 | """ 16 | Args: 17 | index (int): Index 18 | Returns: 19 | tuple: (image, target) where target is class_index of the target class. 20 | """ 21 | path,label = self.imgs[index] 22 | now_img = pil_loader(path) 23 | if self.transforms is not None: 24 | now_img = self.transforms(now_img) 25 | return now_img,path 26 | 27 | def __len__(self): 28 | return len(self.imgs) 29 | 30 | def __next__(self): 31 | pass 32 | class DDTImageNetDataset(data.Dataset): 33 | def __init__(self, root='/mnt/ramdisk/ImageNet/train/', transforms=None, batch_size=128,target_transform=None): 34 | self.img_dataset = ImageFolder(root) 35 | self.label_class_dict = {} 36 | for k,v in self.img_dataset.class_to_idx.items(): 37 | self.label_class_dict[v] = k 38 | 39 | from collections import defaultdict 40 | self.class_dict = defaultdict(list) 41 | for i,(location,label) in enumerate(self.img_dataset.imgs): 42 | self.class_dict[label].append((location,label)) 43 | self.all_dataset = [] 44 | for i in range(1000): 45 | self.all_dataset.append(SubImageDataset(self.class_dict[i],transforms=transforms)) 46 | self.batch_size = batch_size 47 | self.transforms = transforms 48 | def __getitem__(self, index): 49 | """ 50 | Args: 51 | index (int): Index 52 | Returns: 53 | tuple: (image, target) where target is class_index of the target class. 54 | """ 55 | now_dataset = self.all_dataset[index] 56 | now_loader = data.DataLoader(now_dataset,batch_size=self.batch_size,shuffle=False,num_workers=8) 57 | for i,(img,path) in enumerate(now_loader): 58 | yield img,path 59 | pass 60 | 61 | def __len__(self): 62 | return len(self.class_dict) 63 | 64 | def __next__(self): 65 | pass 66 | if __name__ == '__main__': 67 | import torchvision 68 | transform = torchvision.transforms.Compose([ 69 | torchvision.transforms.Resize((224,224)), 70 | torchvision.transforms.ToTensor(), 71 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 72 | ] 73 | ) 74 | a = DDTImageNetDataset(batch_size=2,transforms=transform) 75 | for i in a[0]: 76 | print(i) 77 | import sys 78 | sys.exit() 79 | -------------------------------------------------------------------------------- /loader/imagenet_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | import torchvision 4 | from PIL import Image 5 | import os 6 | import os.path 7 | import numpy as np 8 | import json 9 | from torchvision.transforms import functional as F 10 | import warnings 11 | import random 12 | import math 13 | import copy 14 | import numbers 15 | from utils.augment import * 16 | IMG_EXTENSIONS = [ 17 | '.jpg', '.JPG', '.jpeg', '.JPEG', 18 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 19 | ] 20 | 21 | 22 | def is_image_file(filename): 23 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 24 | 25 | def find_classes(dir): 26 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 27 | classes.sort() 28 | class_to_idx = {classes[i]: i for i in range(len(classes))} 29 | return classes, class_to_idx 30 | 31 | def get_bbox_dict(root): 32 | print('loading from ground truth bbox') 33 | name_idx_dict = {} 34 | with open(os.path.join(root, 'images.txt')) as f: 35 | filelines = f.readlines() 36 | for fileline in filelines: 37 | fileline = fileline.strip('\n').split() 38 | idx, name = fileline[0], fileline[1] 39 | 
name_idx_dict[name] = idx 40 | 41 | idx_bbox_dict = {} 42 | with open(os.path.join(root, 'bounding_boxes.txt')) as f: 43 | filelines = f.readlines() 44 | for fileline in filelines: 45 | fileline = fileline.strip('\n').split() 46 | idx, bbox = fileline[0], list(map(float, fileline[1:])) 47 | idx_bbox_dict[idx] = bbox 48 | 49 | name_bbox_dict = {} 50 | for name in name_idx_dict.keys(): 51 | name_bbox_dict[name] = idx_bbox_dict[name_idx_dict[name]] 52 | 53 | return name_bbox_dict 54 | 55 | def pil_loader(path): 56 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 57 | with open(path, 'rb') as f: 58 | with Image.open(f) as img: 59 | return img.convert('RGB') 60 | 61 | def default_loader(path): 62 | from torchvision import get_image_backend 63 | #if get_image_backend() == 'accimage': 64 | # return accimage_loader(path) 65 | #else: 66 | return pil_loader(path) 67 | 68 | 69 | def load_train_bbox(label_dict,bbox_dir): 70 | #bbox_dir = 'ImageNet/Projection/VGG16-448' 71 | final_dict = {} 72 | for i in range(1000): 73 | now_name = label_dict[i] 74 | now_json_file = os.path.join(bbox_dir,now_name+"_bbox.json") 75 | with open(now_json_file, 'r') as fp: 76 | name_bbox_dict = json.load(fp) 77 | final_dict[i] = name_bbox_dict 78 | return final_dict 79 | def load_val_bbox(label_dict,all_imgs,gt_location): 80 | #gt_location ='/data/zhangcl/DDT-code/ImageNet_gt' 81 | import scipy.io as sio 82 | gt_label = sio.loadmat(os.path.join(gt_location,'cache_groundtruth.mat')) 83 | locs = [(x[0].split('/')[-1],x[0],x[1]) for x in all_imgs] 84 | locs.sort() 85 | final_bbox_dict = {} 86 | for i in range(len(locs)): 87 | #gt_label['rec'][:,1][0][0][0], if multilabel then get length, for final eval 88 | final_bbox_dict[locs[i][1]] = gt_label['rec'][:,i][0][0][0][0][1][0] 89 | return final_bbox_dict 90 | class ImageNetDataset(data.Dataset): 91 | """A generic data loader where the images are arranged in this way: :: 92 | root/dog/xxx.png 93 | root/dog/xxy.png 94 | root/dog/xxz.png 95 | root/cat/123.png 96 | root/cat/nsdf3.png 97 | root/cat/asd932_.png 98 | Args: 99 | root (string): Root directory path. 100 | transform (callable, optional): A function/transform that takes in an PIL image 101 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 102 | target_transform (callable, optional): A function/transform that takes in the 103 | target and transforms it. 104 | loader (callable, optional): A function to load an image given its path. 105 | Attributes: 106 | classes (list): List of the class names. 107 | class_to_idx (dict): Dict with items (class_name, class_index). 
108 | imgs (list): List of (image path, class_index) tuples 109 | """ 110 | 111 | def __init__(self, root, ddt_path,gt_path, input_size=256, crop_size=224,train=True, transform=None, target_transform=None, loader=default_loader): 112 | from torchvision.datasets import ImageFolder 113 | self.train = train 114 | self.input_size = input_size 115 | self.crop_size = crop_size 116 | self.ddt_path = ddt_path 117 | self.gt_path = gt_path 118 | if self.train: 119 | self.img_dataset = ImageFolder(os.path.join(root,'train')) 120 | else: 121 | self.img_dataset = ImageFolder(os.path.join(root,'val')) 122 | if len(self.img_dataset) == 0: 123 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 124 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 125 | self.label_class_dict = {} 126 | self.train = train 127 | 128 | for k, v in self.img_dataset.class_to_idx.items(): 129 | self.label_class_dict[v] = k 130 | if self.train: 131 | #load train bbox 132 | self.bbox_dict = load_train_bbox(self.label_class_dict,self.ddt_path) 133 | else: 134 | #load test bbox 135 | self.bbox_dict = load_val_bbox(self.label_class_dict,self.img_dataset.imgs,self.gt_path) 136 | self.img_dataset = self.img_dataset.imgs 137 | 138 | self.transform = transform 139 | self.target_transform = target_transform 140 | self.loader = loader 141 | def __getitem__(self, index): 142 | """ 143 | Args: 144 | index (int): Index 145 | Returns: 146 | tuple: (image, target) where target is class_index of the target class. 147 | """ 148 | path, target = self.img_dataset[index] 149 | img = self.loader(path) 150 | if self.train: 151 | bbox = self.bbox_dict[target][path] 152 | else: 153 | bbox = self.bbox_dict[path] 154 | w,h = img.size 155 | 156 | bbox = np.array(bbox, dtype='float32') 157 | 158 | #convert from x, y, w, h to x1,y1,x2,y2 159 | 160 | 161 | 162 | 163 | if self.train: 164 | bbox[0] = bbox[0] 165 | bbox[2] = bbox[0] + bbox[2] 166 | bbox[1] = bbox[1] 167 | bbox[3] = bbox[1] + bbox[3] 168 | bbox[0] = math.ceil(bbox[0] * w) 169 | bbox[2] = math.ceil(bbox[2] * w) 170 | bbox[1] = math.ceil(bbox[1] * h) 171 | bbox[3] = math.ceil(bbox[3] * h) 172 | img_i, bbox_i = RandomResizedBBoxCrop((self.crop_size))(img, bbox) 173 | #img_i, bbox_i = ResizedBBoxCrop((256,256))(img, bbox) 174 | #img_i, bbox_i = RandomBBoxCrop((224))(img_i, bbox_i) 175 | #img_i, bbox_i = ResizedBBoxCrop((320,320))(img, bbox) 176 | #img_i, bbox_i = RandomBBoxCrop((299))(img_i, bbox_i) 177 | img, bbox = RandomHorizontalFlipBBox()(img_i, bbox_i) 178 | #img, bbox = img_i, bbox_i 179 | else: 180 | img_i, bbox_i = ResizedBBoxCrop((self.input_size,self.input_size))(img, bbox) 181 | img, bbox = CenterBBoxCrop((self.crop_size))(img_i, bbox_i) 182 | 183 | 184 | bbox[2] = bbox[2] - bbox[0] 185 | bbox[3] = bbox[3] - bbox[1] 186 | 187 | if self.transform is not None: 188 | img = self.transform(img) 189 | if self.target_transform is not None: 190 | target = self.target_transform(target) 191 | return img, target, bbox 192 | 193 | def __len__(self): 194 | return len(self.img_dataset) 195 | 196 | if __name__ == '__main__': 197 | a =ImageNetDataset('/mnt/ramdisk/ImageNet/val/',train=False) -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch.nn as nn 3 | from utils.func import * 4 | class VGGGAP(nn.Module): 5 | def __init__(self, pretrained=True, num_classes=200): 6 | super(VGGGAP,self).__init__() 7 | 
self.features = torchvision.models.vgg16(pretrained=pretrained).features 8 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 9 | self.classifier = nn.Sequential((nn.Linear(512,512),nn.ReLU(),nn.Linear(512,4),nn.Sigmoid())) 10 | def forward(self, x): 11 | x = self.features(x) 12 | x = self.avgpool(x) 13 | x = x.view(x.size(0),-1) 14 | x = self.classifier(x) 15 | return x 16 | class VGG16(nn.Module): 17 | def __init__(self, pretrained=True, num_classes=200): 18 | super(VGG16,self).__init__() 19 | self.features = torchvision.models.vgg16(pretrained=pretrained).features 20 | temp_classifier = torchvision.models.vgg16(pretrained=pretrained).classifier 21 | removed = list(temp_classifier.children()) 22 | removed = removed[:-1] 23 | temp_layer = nn.Sequential(nn.Linear(4096,512),nn.ReLU(),nn.Linear(512,4),nn.Sigmoid()) 24 | removed.append(temp_layer) 25 | self.classifier = nn.Sequential(*removed) 26 | def forward(self, x): 27 | x = self.features(x) 28 | x = x.view(x.size(0),-1) 29 | x = self.classifier(x) 30 | return x 31 | 32 | def choose_locmodel(model_name,pretrained=False): 33 | if model_name == 'densenet161': 34 | model = torchvision.models.densenet161(pretrained=True) 35 | 36 | model.classifier = nn.Sequential( 37 | nn.Linear(2208, 512), 38 | nn.ReLU(), 39 | nn.Linear(512, 4), 40 | nn.Sigmoid() 41 | ) 42 | if pretrained: 43 | model = copy_parameters(model, torch.load('densenet161loc.pth.tar')) 44 | elif model_name == 'resnet50': 45 | model = torchvision.models.resnet50(pretrained=True, num_classes=1000) 46 | model.fc = nn.Sequential( 47 | nn.Linear(2048, 512), 48 | nn.ReLU(), 49 | nn.Linear(512, 4), 50 | nn.Sigmoid() 51 | ) 52 | if pretrained: 53 | model = copy_parameters(model, torch.load('resnet50loc.pth.tar')) 54 | elif model_name == 'vgggap': 55 | model = VGGGAP(pretrained=True,num_classes=1000) 56 | if pretrained: 57 | model = copy_parameters(model, torch.load('vgggaploc.pth.tar')) 58 | elif model_name == 'vgg16': 59 | model = VGG16(pretrained=True,num_classes=1000) 60 | if pretrained: 61 | model = copy_parameters(model, torch.load('vgg16loc.pth.tar')) 62 | elif model_name == 'inceptionv3': 63 | #need for rollback inceptionv3 official code 64 | pass 65 | else: 66 | raise ValueError('Do not have this model currently!') 67 | return model 68 | def choose_clsmodel(model_name): 69 | if model_name == 'vgg16': 70 | cls_model = torchvision.models.vgg16(pretrained=True) 71 | elif model_name == 'inceptionv3': 72 | cls_model = torchvision.models.inception_v3(pretrained=True, aux_logits=True, transform_input=True) 73 | elif model_name == 'resnet50': 74 | cls_model = torchvision.models.resnet50(pretrained=True) 75 | elif model_name == 'densenet161': 76 | cls_model = torchvision.models.densenet161(pretrained=True) 77 | elif model_name == 'dpn131': 78 | cls_model = torch.hub.load('rwightman/pytorch-dpn-pretrained', 'dpn131', pretrained=True,test_time_pool=True) 79 | elif model_name == 'efficientnetb7': 80 | from efficientnet_pytorch import EfficientNet 81 | cls_model = EfficientNet.from_pretrained('efficientnet-b7') 82 | return cls_model -------------------------------------------------------------------------------- /poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzzcl/PSOL/ddecfce1119a50b97ec51daffea387bebfc7991d/poster.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pillow 2 | numpy 3 
| matplotlib 4 | sklearn -------------------------------------------------------------------------------- /utils/IoU.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import xml.etree.ElementTree as ET 3 | import torch 4 | def get_gt_boxes(xmlfile): 5 | '''get ground-truth bbox from VOC xml file''' 6 | tree = ET.parse(xmlfile) 7 | objs = tree.findall('object') 8 | num_objs = len(objs) 9 | gt_boxes = [] 10 | for obj in objs: 11 | bbox = obj.find('bndbox') 12 | x1 = float(bbox.find('xmin').text)-1 13 | y1 = float(bbox.find('ymin').text)-1 14 | x2 = float(bbox.find('xmax').text)-1 15 | y2 = float(bbox.find('ymax').text)-1 16 | 17 | gt_boxes.append((x1, y1, x2, y2)) 18 | return gt_boxes 19 | 20 | def get_cls_gt_boxes(xmlfile, cls): 21 | '''get ground-truth bbox from VOC xml file''' 22 | tree = ET.parse(xmlfile) 23 | objs = tree.findall('object') 24 | num_objs = len(objs) 25 | gt_boxes = [] 26 | for obj in objs: 27 | bbox = obj.find('bndbox') 28 | cls_name = obj.find('name').text 29 | #print(cls_name, cls) 30 | if cls_name != cls: 31 | continue 32 | x1 = float(bbox.find('xmin').text)-1 33 | y1 = float(bbox.find('ymin').text)-1 34 | x2 = float(bbox.find('xmax').text)-1 35 | y2 = float(bbox.find('ymax').text)-1 36 | 37 | gt_boxes.append((x1, y1, x2, y2)) 38 | if len(gt_boxes)==0: 39 | pass 40 | #print('%s bbox = 0'%cls) 41 | 42 | return gt_boxes 43 | 44 | def get_cls_and_gt_boxes(xmlfile, cls,class_to_idx): 45 | '''get ground-truth bbox from VOC xml file''' 46 | tree = ET.parse(xmlfile) 47 | objs = tree.findall('object') 48 | num_objs = len(objs) 49 | gt_boxes = [] 50 | for obj in objs: 51 | bbox = obj.find('bndbox') 52 | cls_name = obj.find('name').text 53 | #print(cls_name, cls) 54 | if cls_name != cls: 55 | continue 56 | x1 = float(bbox.find('xmin').text)-1 57 | y1 = float(bbox.find('ymin').text)-1 58 | x2 = float(bbox.find('xmax').text)-1 59 | y2 = float(bbox.find('ymax').text)-1 60 | 61 | gt_boxes.append((class_to_idx[cls_name],[x1, y1, x2-x1, y2-y1])) 62 | if len(gt_boxes)==0: 63 | pass 64 | #print('%s bbox = 0'%cls) 65 | 66 | return gt_boxes 67 | def convert_boxes(boxes): 68 | ''' convert the bbox to the format (x1, y1, x2, y2) where x1,y10 100 | if aarea + barea - inter <=0: 101 | print(a) 102 | print(b) 103 | o = inter / (aarea+barea-inter) 104 | #if w<=0 or h<=0: 105 | # o = 0 106 | return o 107 | 108 | def to_2d_tensor(inp): 109 | inp = torch.Tensor(inp) 110 | if len(inp.size()) < 2: 111 | inp = inp.unsqueeze(0) 112 | return inp 113 | 114 | 115 | def xywh_to_x1y1x2y2(boxes): 116 | boxes = to_2d_tensor(boxes) 117 | boxes[:, 2] += boxes[:, 0] - 1 118 | boxes[:, 3] += boxes[:, 1] - 1 119 | return boxes 120 | 121 | 122 | def x1y1x2y2_to_xywh(boxes): 123 | boxes = to_2d_tensor(boxes) 124 | boxes[:, 2] -= boxes[:, 0] - 1 125 | boxes[:, 3] -= boxes[:, 1] - 1 126 | return boxes 127 | 128 | def compute_IoU(pred_box, gt_box): 129 | boxes1 = to_2d_tensor(pred_box) 130 | # boxes1 = xywh_to_x1y1x2y2(boxes1) 131 | boxes1[:, 2] = torch.clamp(boxes1[:, 0] + boxes1[:, 2], 0, 1) 132 | boxes1[:, 3] = torch.clamp(boxes1[:, 1] + boxes1[:, 3], 0, 1) 133 | 134 | boxes2 = to_2d_tensor(gt_box) 135 | boxes2[:, 2] = torch.clamp(boxes2[:, 0] + boxes2[:, 2], 0, 1) 136 | boxes2[:, 3] = torch.clamp(boxes2[:, 1] + boxes2[:, 3], 0, 1) 137 | # boxes2 = xywh_to_x1y1x2y2(boxes2) 138 | 139 | intersec = boxes1.clone() 140 | intersec[:, 0] = torch.max(boxes1[:, 0], boxes2[:, 0]) 141 | intersec[:, 1] = torch.max(boxes1[:, 1], boxes2[:, 1]) 142 | intersec[:, 2] = 
torch.min(boxes1[:, 2], boxes2[:, 2]) 143 | intersec[:, 3] = torch.min(boxes1[:, 3], boxes2[:, 3]) 144 | 145 | def compute_area(boxes): 146 | # in (x1, y1, x2, y2) format 147 | dx = boxes[:, 2] - boxes[:, 0] 148 | dx[dx < 0] = 0 149 | dy = boxes[:, 3] - boxes[:, 1] 150 | dy[dy < 0] = 0 151 | return dx * dy 152 | 153 | a1 = compute_area(boxes1) 154 | a2 = compute_area(boxes2) 155 | ia = compute_area(intersec) 156 | assert ((a1 + a2 - ia < 0).sum() == 0) 157 | return ia / (a1 + a2 - ia) -------------------------------------------------------------------------------- /utils/augment.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import copy 3 | import numbers 4 | from torchvision.transforms import functional as F 5 | from .func import * 6 | import random 7 | import warnings 8 | import math 9 | 10 | class RandomHorizontalFlipBBox(object): 11 | def __init__(self, p=0.5): 12 | self.p = p 13 | 14 | def __call__(self, img, bbox): 15 | if random.random() < self.p: 16 | flipbox = copy.deepcopy(bbox) 17 | flipbox[0] = 1-bbox[2] 18 | flipbox[2] = 1-bbox[0] 19 | return F.hflip(img), flipbox 20 | 21 | return img, bbox 22 | class RandomResizedBBoxCrop(object): 23 | """Crop the given PIL Image to random size and aspect ratio. 24 | 25 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 26 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 27 | is finally resized to given size. 28 | This is popularly used to train the Inception networks. 29 | 30 | Args: 31 | size: expected output size of each edge 32 | scale: range of size of the origin size cropped 33 | ratio: range of aspect ratio of the origin aspect ratio cropped 34 | interpolation: Default: PIL.Image.BILINEAR 35 | """ 36 | 37 | def __init__(self, size, scale=(0.2, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 38 | if isinstance(size, tuple): 39 | self.size = size 40 | else: 41 | self.size = (size, size) 42 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 43 | warnings.warn("range should be of kind (min, max)") 44 | 45 | self.interpolation = interpolation 46 | self.scale = scale 47 | self.ratio = ratio 48 | 49 | @staticmethod 50 | def get_params(img, bbox, scale, ratio): 51 | """Get parameters for ``crop`` for a random sized crop. 52 | 53 | Args: 54 | img (PIL Image): Image to be cropped. 55 | scale (tuple): range of size of the origin size cropped 56 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 57 | 58 | Returns: 59 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 60 | sized crop. 
61 | """ 62 | 63 | area = img.size[0] * img.size[1] 64 | 65 | for attempt in range(30): 66 | target_area = random.uniform(*scale) * area 67 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 68 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 69 | 70 | w = int(round(math.sqrt(target_area * aspect_ratio))) 71 | h = int(round(math.sqrt(target_area / aspect_ratio))) 72 | 73 | if w <= img.size[0] and h <= img.size[1]: 74 | 75 | i = random.randint(0, img.size[1] - h) # i is y actually 76 | j = random.randint(0, img.size[0] - w) # j is x 77 | 78 | #compute intersection between crop image and bbox 79 | intersec = compute_intersec(i, j, h, w, bbox) 80 | 81 | if intersec[2]-intersec[0]>0 and intersec[3]-intersec[1]>0: 82 | intersec = normalize_intersec(i, j, h, w, intersec) 83 | return i, j, h, w, intersec 84 | 85 | # Fallback to central crop 86 | in_ratio = img.size[0] / img.size[1] 87 | if (in_ratio < min(ratio)): 88 | w = img.size[0] 89 | h = int(round(w / min(ratio))) 90 | elif (in_ratio > max(ratio)): 91 | h = img.size[1] 92 | w = int(round(h * max(ratio))) 93 | else: # whole image 94 | w = img.size[0] 95 | h = img.size[1] 96 | 97 | i = (img.size[1] - h) // 2 98 | j = (img.size[0] - w) // 2 99 | 100 | intersec = compute_intersec(i, j, h, w, bbox) 101 | intersec = normalize_intersec(i, j, h, w, intersec) 102 | return i, j, h, w, intersec 103 | 104 | def __call__(self, img, bbox): 105 | """ 106 | Args: 107 | img (PIL Image): Image to be cropped and resized. 108 | 109 | Returns: 110 | PIL Image: Randomly cropped and resized image. 111 | """ 112 | i, j, h, w, crop_bbox = self.get_params(img, bbox, self.scale, self.ratio) 113 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), crop_bbox 114 | class RandomBBoxCrop(object): 115 | """Crop the given PIL Image at a random location. 116 | 117 | Args: 118 | size (sequence or int): Desired output size of the crop. If size is an 119 | int instead of sequence like (h, w), a square crop (size, size) is 120 | made. 121 | padding (int or sequence, optional): Optional padding on each border 122 | of the image. Default is None, i.e no padding. If a sequence of length 123 | 4 is provided, it is used to pad left, top, right, bottom borders 124 | respectively. If a sequence of length 2 is provided, it is used to 125 | pad left/right, top/bottom borders, respectively. 126 | pad_if_needed (boolean): It will pad the image if smaller than the 127 | desired size to avoid raising an exception. Since cropping is done 128 | after padding, the padding seems to be done at a random offset. 129 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 130 | length 3, it is used to fill R, G, B channels respectively. 131 | This value is only used when the padding_mode is constant 132 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 
133 | 134 | - constant: pads with a constant value, this value is specified with fill 135 | 136 | - edge: pads with the last value on the edge of the image 137 | 138 | - reflect: pads with reflection of image (without repeating the last value on the edge) 139 | 140 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 141 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 142 | 143 | - symmetric: pads with reflection of image (repeating the last value on the edge) 144 | 145 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 146 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 147 | 148 | """ 149 | 150 | def __init__(self, size): 151 | if isinstance(size, numbers.Number): 152 | self.size = (int(size), int(size)) 153 | else: 154 | self.size = size 155 | 156 | @staticmethod 157 | def get_params(img, bbox, output_size): 158 | """Get parameters for ``crop`` for a random crop. 159 | 160 | Args: 161 | img (PIL Image): Image to be cropped. 162 | output_size (tuple): Expected output size of the crop. 163 | 164 | Returns: 165 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 166 | """ 167 | w, h = img.size 168 | th, tw = output_size 169 | if w == tw and h == th: 170 | return 0, 0, h, w 171 | 172 | i = random.randint(0, h - th) 173 | j = random.randint(0, w - tw) 174 | intersec = compute_intersec(i, j, h, w, bbox) 175 | intersec = normalize_intersec(i, j, h, w, intersec) 176 | return i, j, th, tw, intersec 177 | 178 | def __call__(self, img, bbox): 179 | """ 180 | Args: 181 | img (PIL Image): Image to be cropped. 182 | 183 | Returns: 184 | PIL Image: Cropped image. 185 | """ 186 | 187 | i, j, h, w,crop_bbox = self.get_params(img, bbox, self.size) 188 | 189 | return F.crop(img, i, j, h, w),crop_bbox 190 | 191 | def __repr__(self): 192 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) 193 | 194 | class ResizedBBoxCrop(object): 195 | 196 | def __init__(self, size, interpolation=Image.BILINEAR): 197 | self.size = size 198 | 199 | self.interpolation = interpolation 200 | 201 | @staticmethod 202 | def get_params(img, bbox, size): 203 | #resize to 256 204 | if isinstance(size, int): 205 | w, h = img.size 206 | if (w <= h and w == size) or (h <= w and h == size): 207 | img = copy.deepcopy(img) 208 | ow, oh = w, h 209 | if w < h: 210 | ow = size 211 | oh = int(size*h/w) 212 | else: 213 | oh = size 214 | ow = int(size*w/h) 215 | else: 216 | ow, oh = size[::-1] 217 | w, h = img.size 218 | 219 | 220 | intersec = copy.deepcopy(bbox) 221 | ratew = ow / w 222 | rateh = oh / h 223 | intersec[0] = bbox[0]*ratew 224 | intersec[2] = bbox[2]*ratew 225 | intersec[1] = bbox[1]*rateh 226 | intersec[3] = bbox[3]*rateh 227 | 228 | #intersec = normalize_intersec(i, j, h, w, intersec) 229 | return (oh, ow), intersec 230 | 231 | def __call__(self, img, bbox): 232 | """ 233 | Args: 234 | img (PIL Image): Image to be cropped and resized. 235 | 236 | Returns: 237 | PIL Image: Randomly cropped and resized image. 
238 | """ 239 | size, crop_bbox = self.get_params(img, bbox, self.size) 240 | return F.resize(img, self.size, self.interpolation), crop_bbox 241 | 242 | 243 | class CenterBBoxCrop(object): 244 | 245 | def __init__(self, size, interpolation=Image.BILINEAR): 246 | self.size = size 247 | 248 | self.interpolation = interpolation 249 | 250 | @staticmethod 251 | def get_params(img, bbox, size): 252 | #center crop 253 | if isinstance(size, numbers.Number): 254 | output_size = (int(size), int(size)) 255 | 256 | w, h = img.size 257 | th, tw = output_size 258 | 259 | i = int(round((h - th) / 2.)) 260 | j = int(round((w - tw) / 2.)) 261 | 262 | intersec = compute_intersec(i, j, th, tw, bbox) 263 | intersec = normalize_intersec(i, j, th, tw, intersec) 264 | 265 | #intersec = normalize_intersec(i, j, h, w, intersec) 266 | return i, j, th, tw, intersec 267 | 268 | def __call__(self, img, bbox): 269 | """ 270 | Args: 271 | img (PIL Image): Image to be cropped and resized. 272 | 273 | Returns: 274 | PIL Image: Randomly cropped and resized image. 275 | """ 276 | i, j, th, tw, crop_bbox = self.get_params(img, bbox, self.size) 277 | return F.center_crop(img, self.size), crop_bbox 278 | 279 | 280 | -------------------------------------------------------------------------------- /utils/func.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import copy 5 | def count_max(x): 6 | count_dict = {} 7 | for xlist in x: 8 | for item in xlist: 9 | if item==0: 10 | continue 11 | if item not in count_dict.keys(): 12 | count_dict[item] = 0 13 | count_dict[item] += 1 14 | if count_dict == {}: 15 | return -1 16 | count_dict = sorted(count_dict.items(), key=lambda d:d[1], reverse=True) 17 | return count_dict[0][0] 18 | 19 | 20 | def sk_pca(X, k): 21 | from sklearn.decomposition import PCA 22 | pca = PCA(k) 23 | pca.fit(X) 24 | vec = pca.components_ 25 | #print(vec.shape) 26 | return vec 27 | 28 | def fld(x1, x2): 29 | x1, x2 = np.mat(x1), np.mat(x2) 30 | n1 = x1.shape[0] 31 | n2 = x2.shape[0] 32 | k = x1.shape[1] 33 | 34 | m1 = np.mean(x1, axis=0) 35 | m2 = np.mean(x2, axis=0) 36 | m = np.mean(np.concatenate((x1, x2), axis=0), axis=0) 37 | print(x1.shape, m1.shape) 38 | 39 | 40 | c1 = np.cov(x1.T) 41 | s1 = c1*(n1-1) 42 | c2 = np.cov(x2.T) 43 | s2 = c2*(n2-1) 44 | Sw = s1/n1 + s2/n2 45 | print(Sw.shape) 46 | W = np.dot(np.linalg.inv(Sw), (m1-m2).T) 47 | print(W.shape) 48 | W = W / np.linalg.norm(W, 2) 49 | return np.mean(np.dot(x1, W)), np.mean(np.dot(x2, W)), W 50 | 51 | def pca(X, k): 52 | n, m = X.shape 53 | mean = np.mean(X, 0) 54 | #print(mean.shape) 55 | temp = X - mean 56 | conv = np.cov(X.T) 57 | #print(conv.shape) 58 | conv1 = np.cov(temp.T) 59 | #print(conv-conv1) 60 | 61 | w, v = np.linalg.eig(conv) 62 | #print(w.shape) 63 | #print(v.shape) 64 | index = np.argsort(-w) 65 | vec = np.matrix(v.T[index[:k]]) 66 | #print(vec.shape) 67 | 68 | recon = (temp * vec.T)*vec+mean 69 | 70 | #print(X-recon) 71 | return vec 72 | def to_variable(x): 73 | if torch.cuda.is_available(): 74 | x = x.to(0) 75 | return torch.autograd.Variable(x) 76 | 77 | def to_data(x): 78 | if torch.cuda.is_available(): 79 | x = x.cpu() 80 | return x.data 81 | 82 | def copy_parameters(model, pretrained_dict): 83 | model_dict = model.state_dict() 84 | 85 | pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict and pretrained_dict[k].size()==model_dict[k[7:]].size()} 86 | for k, v in pretrained_dict.items(): 
87 | print(k) 88 | model_dict.update(pretrained_dict) 89 | model.load_state_dict(model_dict) 90 | return model 91 | 92 | 93 | def compute_intersec(i, j, h, w, bbox): 94 | ''' 95 | intersection box between croped box and GT BBox 96 | ''' 97 | intersec = copy.deepcopy(bbox) 98 | 99 | intersec[0] = max(j, bbox[0]) 100 | intersec[1] = max(i, bbox[1]) 101 | intersec[2] = min(j + w, bbox[2]) 102 | intersec[3] = min(i + h, bbox[3]) 103 | return intersec 104 | 105 | 106 | def normalize_intersec(i, j, h, w, intersec): 107 | ''' 108 | return: normalize into [0, 1] 109 | ''' 110 | 111 | intersec[0] = (intersec[0] - j) / w 112 | intersec[2] = (intersec[2] - j) / w 113 | intersec[1] = (intersec[1] - i) / h 114 | intersec[3] = (intersec[3] - i) / h 115 | return intersec 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /utils/nms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def nms(boxes, scores, thresh): 4 | if len(boxes)==0: 5 | return [] 6 | boxes = np.array(boxes) 7 | #scores = np.array(scores) 8 | x1 = boxes[:, 0] 9 | y1 = boxes[:, 1] 10 | x2 = boxes[:, 2] 11 | y2 = boxes[:, 3] 12 | 13 | areas = (x2-x1+1)*(y2-y1+1) 14 | 15 | scores = np.array(scores) 16 | 17 | order = np.argsort(scores)[-100:] 18 | keep_boxes = [] 19 | while order.size > 0: 20 | i = order[-1] 21 | keep_boxes.append(boxes[i]) 22 | 23 | xx1 = np.maximum(x1[i], x1[order[:-1]]) 24 | yy1 = np.maximum(y1[i], y1[order[:-1]]) 25 | xx2 = np.minimum(x2[i], x2[order[:-1]]) 26 | yy2 = np.minimum(y2[i], y2[order[:-1]]) 27 | 28 | w = np.maximum(0.0, xx2-xx1+1) 29 | h = np.maximum(0.0, yy2-yy1+1) 30 | inter = w*h 31 | 32 | ovr = inter / (areas[i] + areas[order[:-1]] - inter) 33 | inds = np.where(ovr <= thresh) 34 | order = order[inds] 35 | 36 | return keep_boxes 37 | -------------------------------------------------------------------------------- /utils/vis.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import cv2 4 | import numpy as np 5 | import os 6 | 7 | _GREEN = (18, 217, 15) 8 | _RED = (15, 18, 217) 9 | 10 | def vis_bbox(img, bbox, color=_GREEN, thick=1): 11 | '''Visualize a bounding box''' 12 | img = img.astype(np.uint8) 13 | (x0, y0, x1, y1) = bbox 14 | cv2.rectangle(img, (int(x0), int(y0)), (int(x1), int(y1)), color, thickness=thick) 15 | return img 16 | 17 | def vis_one_image(img, boxes, color=_GREEN): 18 | for bbox in boxes: 19 | img = vis_bbox(img, (bbox[0], bbox[1], bbox[2], bbox[3]), color) 20 | return img 21 | 22 | 23 | --------------------------------------------------------------------------------
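A minimal usage sketch (not part of the repository) tying the utilities above together: `choose_locmodel` from `models/models.py` predicts a normalized (x, y, w, h) box, which is converted to absolute (x1, y1, x2, y2) coordinates before computing the overlap with `IoU` from `utils/IoU.py` and drawing it with `vis_one_image` from `utils/vis.py`. The image path, the ground-truth box, and the output filename are placeholders; loading the ResNet-50 localizer with the pretrained flag assumes the `resnet50loc.pth.tar` checkpoint from the pretrained-model link sits in the working directory. For simplicity the sketch resizes the image directly to 224x224 instead of the Resize(256) + CenterCrop(224) pipeline used in `PSOL_inference.py`, so the predicted normalized box maps back to the original image by multiplying with its width and height.

```
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
import cv2

from models.models import choose_locmodel
from utils.IoU import IoU
from utils.vis import vis_one_image

# same normalization as PSOL_inference.py, but with a plain 224x224 resize
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

model = choose_locmodel('resnet50', True)  # expects resnet50loc.pth.tar in the working directory
model.eval()

img = Image.open('example.jpg').convert('RGB')  # placeholder image path
w, h = img.size
with torch.no_grad():
    box = model(transform(img).unsqueeze(0)).squeeze(0).numpy()  # normalized (x, y, w, h)

# scale back to the original resolution and convert to (x1, y1, x2, y2)
x1, y1 = box[0] * w, box[1] * h
x2, y2 = x1 + box[2] * w, y1 + box[3] * h

gt_box = (48, 32, 310, 270)  # placeholder ground-truth box in pixel coordinates
print('IoU with GT: {:.3f}'.format(IoU([x1, y1, x2, y2], gt_box)))

# draw the predicted box on a BGR copy of the image and save it
vis = vis_one_image(np.array(img)[:, :, ::-1], [[x1, y1, x2, y2]])
cv2.imwrite('example_pred.jpg', vis)
```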