├── .vscode
│   └── settings.json
├── models
│   ├── __init__.py
│   └── add_gcn.py
├── data
│   ├── test_dataset.py
│   ├── coco.py
│   ├── __init__.py
│   └── voc.py
├── README.md
├── main.py
├── util.py
└── trainer.py

/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "python.pythonPath": "/usr/bin/python"
}
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
import torchvision
from .add_gcn import ADD_GCN

model_dict = {'ADD_GCN': ADD_GCN}

def get_model(num_classes, args):
    res101 = torchvision.models.resnet101(pretrained=True)
    model = model_dict[args.model_name](res101, num_classes)
    return model
--------------------------------------------------------------------------------
/data/test_dataset.py:
--------------------------------------------------------------------------------
import os, sys, pdb
import numpy as np
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from coco import COCO2014
from voc import VOC2007, VOC2012

data_root = sys.argv[1]


def collate_fn(batch):
    ret_batch = dict()
    for k in batch[0].keys():
        if k == 'image' or k == 'target':
            ret_batch[k] = torch.cat([b[k].unsqueeze(0) for b in batch])
        else:
            ret_batch[k] = [b[k] for b in batch]
    return ret_batch


transform = transforms.Compose([
    transforms.RandomResizedCrop(448, scale=(0.1, 1.5), ratio=(1.0, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# COCO2014
# train_dataset = COCO2014(data_root, phase='train', transform=transform)
# val_dataset = COCO2014(data_root, phase='val', transform=transform)

# VOC2007
# train_dataset = VOC2007(data_root, phase='trainval', transform=transform)
# val_dataset = VOC2007(data_root, phase='test', transform=transform)

# VOC2012
train_dataset = VOC2012(data_root, phase='trainval', transform=transform)
val_dataset = VOC2012(data_root, phase='test', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4,
                          pin_memory=True, collate_fn=collate_fn, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4,
                        pin_memory=True, collate_fn=collate_fn)

for data in train_loader:
    pdb.set_trace()

pdb.set_trace()
--------------------------------------------------------------------------------
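A minimal sketch (not part of the repository) of what `collate_fn` above produces: the tensor fields `image` and `target` are stacked along a new batch dimension, while remaining fields such as filenames stay Python lists. The samples below are synthetic, assuming 448×448 inputs and 20 classes:

```python
import torch

# Two synthetic samples in the dict format returned by the datasets.
batch = [
    {'image': torch.rand(3, 448, 448), 'target': torch.ones(20), 'name': 'a.jpg'},
    {'image': torch.rand(3, 448, 448), 'target': -torch.ones(20), 'name': 'b.jpg'},
]

# Same stacking logic as collate_fn above.
collated = {
    'image': torch.cat([b['image'].unsqueeze(0) for b in batch]),    # (2, 3, 448, 448)
    'target': torch.cat([b['target'].unsqueeze(0) for b in batch]),  # (2, 20)
    'name': [b['name'] for b in batch],                              # ['a.jpg', 'b.jpg']
}
print(collated['image'].shape, collated['target'].shape, collated['name'])
```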
/README.md:
--------------------------------------------------------------------------------
# ADD-GCN: Attention-Driven Dynamic Graph Convolutional Network for Multi-Label Image Recognition

This project hosts the code for implementing the ADD-GCN algorithm for multi-label image recognition, as presented in our paper:

Attention-Driven Dynamic Graph Convolutional Network for Multi-Label Image Recognition;
Jin Ye, Junjun He, Xiaojiang Peng, Wenhao Wu, Yu Qiao;
In: European Conference on Computer Vision (ECCV), 2020.
arXiv preprint arXiv:2012.02994

The full paper is available at: [https://arxiv.org/abs/2012.02994](https://arxiv.org/abs/2012.02994).

## Installation
#### This project is implemented with PyTorch and has been tested with PyTorch 1.0/1.1/1.2.

## A quick demo
After you have installed PyTorch, you can follow the steps below to run a quick demo.

### Inference for COCO2014

    python main.py --data COCO2014 --data_root_dir {YOUR-ROOT-DATA-DIR} --model_name ADD_GCN --resume {THE-TEST-MODEL} -e -i 448

Please note that:
1) You should put the COCO2014 folder in {YOUR-ROOT-DATA-DIR}.

2) {THE-TEST-MODEL} should be the path to the downloaded test checkpoint.

3) You can get the same ADD-GCN results with [this model](https://pan.baidu.com/s/17Y1knACAo5U6XbV75GUI8w). The password is ``4ebj``.

Model | Test size | mAP
--- |:---:|:---:
ResNet-101 | 448×448 | 79.7
DecoupleNet | 448×448 | 82.2
ML-GCN | 448×448 | 83.0
ADD-GCN | 448×448 | 84.2
ResNet-101 | 576×576 | 80.0
SSGRL | 576×576 | 84.2
ML-GCN | 576×576 | 84.3
ADD-GCN | 576×576 | 85.2
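### Training

Training uses the same entry point. The command below is a sketch built from the defaults in `main.py` ({YOUR-ROOT-DATA-DIR} is a placeholder; adjust the batch size and save directory to your setup):

    python main.py --data COCO2014 --data_root_dir {YOUR-ROOT-DATA-DIR} --model_name ADD_GCN --save_dir ./checkpoint/COCO2014/ -b 16 -i 448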
## Citations
Please consider citing our paper in your publications if the project helps your research. BibTeX reference is as follows.
```
@inproceedings{ye2020add,
  title     = {Attention-Driven Dynamic Graph Convolutional Network for Multi-Label Image Recognition},
  author    = {Ye, Jin and He, Junjun and Peng, Xiaojiang and Wu, Wenhao and Qiao, Yu},
  booktitle = {European Conference on Computer Vision (ECCV)},
  year      = {2020}
}
```


## License

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os, sys, pdb
import argparse
from models import get_model
from data import make_data_loader
import warnings
from trainer import Trainer
import torch
import torch.backends.cudnn as cudnn
import random


parser = argparse.ArgumentParser(description='PyTorch Training for Multi-label Image Classification')

''' Fixed in general '''
parser.add_argument('--data_root_dir', default='./datasets/', type=str, help='root directory of the datasets')
parser.add_argument('--image-size', '-i', default=448, type=int)
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--epoch_step', default=[30, 40], type=int, nargs='+', help='epochs at which the learning rate is decayed')
# parser.add_argument('--device_ids', default=[0], type=int, nargs='+', help='gpu device ids')
parser.add_argument('-b', '--batch-size', default=16, type=int)
parser.add_argument('-j', '--num_workers', default=4, type=int, metavar='INT', help='number of data loading workers (default: 4)')
parser.add_argument('--display_interval', default=200, type=int, metavar='M', help='interval (in iterations) between training log prints')
parser.add_argument('--lr', '--learning-rate', default=0.05, type=float)
parser.add_argument('--lrp', '--learning-rate-pretrained', default=0.1, type=float, metavar='LRP', help='learning rate for pre-trained layers')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--max_clip_grad_norm', default=10.0, type=float, metavar='M', help='max_clip_grad_norm')
parser.add_argument('--seed', default=1, type=int, help='seed for initializing training.')

''' Train setting '''
parser.add_argument('--data', metavar='NAME', help='dataset name (e.g. COCO2014)')
parser.add_argument('--model_name', type=str, default='ADD_GCN')
parser.add_argument('--save_dir', default='./checkpoint/COCO2014/', type=str, help='save path')

''' Val or Test setting '''
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')


def main(args):

    if args.seed is not None:
        print('* absolute seed: {}'.format(args.seed))
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    is_train = not args.evaluate
    train_loader, val_loader, num_classes = make_data_loader(args, is_train=is_train)

    model = get_model(num_classes, args)

    criterion = torch.nn.MultiLabelSoftMarginLoss()

    trainer = Trainer(model, criterion, train_loader, val_loader, args)

    if is_train:
        trainer.train()
    else:
        trainer.validate()

if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/models/add_gcn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class DynamicGraphConvolution(nn.Module):
    def __init__(self, in_features, out_features, num_nodes):
        super(DynamicGraphConvolution, self).__init__()

        self.static_adj = nn.Sequential(
            nn.Conv1d(num_nodes, num_nodes, 1, bias=False),
            nn.LeakyReLU(0.2))
        self.static_weight = nn.Sequential(
            nn.Conv1d(in_features, out_features, 1),
            nn.LeakyReLU(0.2))

        self.gap = nn.AdaptiveAvgPool1d(1)
        self.conv_global = nn.Conv1d(in_features, in_features, 1)
        self.bn_global = nn.BatchNorm1d(in_features)
        self.relu = nn.LeakyReLU(0.2)

        self.conv_create_co_mat = nn.Conv1d(in_features*2, num_nodes, 1)
        self.dynamic_weight = nn.Conv1d(in_features, out_features, 1)

    def forward_static_gcn(self, x):
        x = self.static_adj(x.transpose(1, 2))
        x = self.static_weight(x.transpose(1, 2))
        return x

    def forward_construct_dynamic_graph(self, x):
        ### Model global representations ###
        x_glb = self.gap(x)
        x_glb = self.conv_global(x_glb)
        x_glb = self.bn_global(x_glb)
        x_glb = self.relu(x_glb)
        x_glb = x_glb.expand(x_glb.size(0), x_glb.size(1), x.size(2))

        ### Construct the dynamic correlation matrix ###
        x = torch.cat((x_glb, x), dim=1)
        dynamic_adj = self.conv_create_co_mat(x)
        dynamic_adj = torch.sigmoid(dynamic_adj)
        return dynamic_adj

    def forward_dynamic_gcn(self, x, dynamic_adj):
        x = torch.matmul(x, dynamic_adj)
        x = self.relu(x)
        x = self.dynamic_weight(x)
        x = self.relu(x)
        return x

    def forward(self, x):
        """ D-GCN module

        Shape:
        - Input: (B, C_in, N) # C_in: 1024, N: num_classes
        - Output: (B, C_out, N) # C_out: 1024, N: num_classes
        """
        out_static = self.forward_static_gcn(x)
        x = x + out_static  # residual
        dynamic_adj = self.forward_construct_dynamic_graph(x)
        x = self.forward_dynamic_gcn(x, dynamic_adj)
        return x


class ADD_GCN(nn.Module):
    def __init__(self, model, num_classes):
        super(ADD_GCN, self).__init__()
        self.features = nn.Sequential(
            model.conv1,
            model.bn1,
            model.relu,
            model.maxpool,
            model.layer1,
            model.layer2,
            model.layer3,
            model.layer4,
        )
        self.num_classes = num_classes

        self.fc = nn.Conv2d(model.fc.in_features, num_classes, (1,1), bias=False)

        self.conv_transform = nn.Conv2d(2048, 1024, (1,1))
        self.relu = nn.LeakyReLU(0.2)

        self.gcn = DynamicGraphConvolution(1024, 1024, num_classes)

        self.mask_mat = nn.Parameter(torch.eye(self.num_classes).float())
        self.last_linear = nn.Conv1d(1024, self.num_classes, 1)

        # image normalization
        self.image_normalization_mean = [0.485, 0.456, 0.406]
        self.image_normalization_std = [0.229, 0.224, 0.225]

    def forward_feature(self, x):
        x = self.features(x)
        return x

    def forward_classification_sm(self, x):
        """ Compute the second set of confidence scores {s_m} from the feature maps.

        Shape:
        - Input: (B, C_in, H, W) # C_in: 2048
        - Output: (B, C_out) # C_out: num_classes
        """
        x = self.fc(x)
        x = x.view(x.size(0), x.size(1), -1)
        x = x.topk(1, dim=-1)[0].mean(dim=-1)
        return x

    def forward_sam(self, x):
        """ SAM module

        Shape:
        - Input: (B, C_in, H, W) # C_in: 2048
        - Output: (B, C_out, N) # C_out: 1024, N: num_classes
        """
        mask = self.fc(x)
        mask = mask.view(mask.size(0), mask.size(1), -1)
        mask = torch.sigmoid(mask)
        mask = mask.transpose(1, 2)

        x = self.conv_transform(x)
        x = x.view(x.size(0), x.size(1), -1)
        x = torch.matmul(x, mask)
        return x

    def forward_dgcn(self, x):
        x = self.gcn(x)
        return x

    def forward(self, x):
        x = self.forward_feature(x)

        out1 = self.forward_classification_sm(x)

        v = self.forward_sam(x)  # B*1024*num_classes
        z = self.forward_dgcn(v)
        z = v + z

        out2 = self.last_linear(z)  # B*num_classes*num_classes
        mask_mat = self.mask_mat.detach()
        out2 = (out2 * mask_mat).sum(-1)
        return out1, out2

    def get_config_optim(self, lr, lrp):
        small_lr_layers = list(map(id, self.features.parameters()))
        large_lr_layers = filter(lambda p:id(p) not in small_lr_layers, self.parameters())
        return [
            {'params': self.features.parameters(), 'lr': lr * lrp},
            {'params': large_lr_layers, 'lr': lr},
        ]
--------------------------------------------------------------------------------
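A quick shape check for the model above; this sketch (not part of the repository, assuming torchvision is installed and it is run from the repository root) builds ADD_GCN on an untrained ResNet-101 backbone and confirms that both heads emit one score per class:

```python
import torch
import torchvision
from models.add_gcn import ADD_GCN

num_classes = 80  # e.g. COCO2014
backbone = torchvision.models.resnet101(pretrained=False)  # weights don't matter for a shape check
model = ADD_GCN(backbone, num_classes).eval()

x = torch.rand(2, 3, 448, 448)  # dummy batch of two 448x448 images
with torch.no_grad():
    out1, out2 = model(x)
print(out1.shape, out2.shape)  # torch.Size([2, 80]) torch.Size([2, 80])
```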
/data/coco.py:
--------------------------------------------------------------------------------
import os
import json
import subprocess
from PIL import Image
# import numpy as np
import torch
from torch.utils.data import Dataset

urls = {'train_img':'http://images.cocodataset.org/zips/train2014.zip',
        'val_img' : 'http://images.cocodataset.org/zips/val2014.zip',
        'annotations':'http://images.cocodataset.org/annotations/annotations_trainval2014.zip'}

def download_coco2014(root, phase):
    work_dir = os.getcwd()
    tmpdir = os.path.join(root, 'tmp/')
    if not os.path.exists(root):
        os.makedirs(root)
    if not os.path.exists(tmpdir):
        os.makedirs(tmpdir)
    if phase == 'train':
        filename = 'train2014.zip'
    elif phase == 'val':
        filename = 'val2014.zip'
    cached_file = os.path.join(tmpdir, filename)
    if not os.path.exists(cached_file):
        print('Downloading: "{}" to {}\n'.format(urls[phase + '_img'], cached_file))
        os.chdir(tmpdir)
        subprocess.call('wget ' + urls[phase + '_img'], shell=True)
        os.chdir(root)
    # extract file
    img_data = os.path.join(root, filename.split('.')[0])
    if not os.path.exists(img_data):
        print('[dataset] Extracting zip file {file} to {path}'.format(file=cached_file, path=root))
        command = 'unzip {} -d {}'.format(cached_file, root)
        os.system(command)
    print('[dataset] Done!')

    # train/val images/annotations
    cached_file = os.path.join(tmpdir, 'annotations_trainval2014.zip')
    if not os.path.exists(cached_file):
        print('Downloading: "{}" to {}\n'.format(urls['annotations'], cached_file))
        os.chdir(tmpdir)
        # subprocess.Popen('wget ' + urls['annotations'], shell=True)
        subprocess.call('wget ' + urls['annotations'], shell=True)
        os.chdir(root)
    annotations_data = os.path.join(root, 'annotations')
    if not os.path.exists(annotations_data):
        print('[dataset] Extracting zip file {file} to {path}'.format(file=cached_file, path=root))
        command = 'unzip {} -d {}'.format(cached_file, root)
        os.system(command)
    print('[annotation] Done!')

    annotations_data = os.path.join(root, 'annotations')
    anno = os.path.join(root, '{}_anno.json'.format(phase))
    img_id = {}
    annotations_id = {}
    if not os.path.exists(anno):
        annotations_file = json.load(open(os.path.join(annotations_data, 'instances_{}2014.json'.format(phase))))
        annotations = annotations_file['annotations']
        category = annotations_file['categories']
        category_id = {}
        for cat in category:
            category_id[cat['id']] = cat['name']
        cat2idx = category_to_idx(sorted(category_id.values()))
        images = annotations_file['images']
        for annotation in annotations:
            if annotation['image_id'] not in annotations_id:
                annotations_id[annotation['image_id']] = set()
            annotations_id[annotation['image_id']].add(cat2idx[category_id[annotation['category_id']]])
        for img in images:
            if img['id'] not in annotations_id:
                continue
            if img['id'] not in img_id:
                img_id[img['id']] = {}
            img_id[img['id']]['file_name'] = img['file_name']
            img_id[img['id']]['labels'] = list(annotations_id[img['id']])
        anno_list = []
        for k, v in img_id.items():
            anno_list.append(v)
        json.dump(anno_list, open(anno, 'w'))
        if not os.path.exists(os.path.join(root, 'category.json')):
            json.dump(cat2idx, open(os.path.join(root, 'category.json'), 'w'))
        del img_id
        del anno_list
        del images
        del annotations_id
        del annotations
        del category
        del category_id
    print('[json] Done!')
    os.chdir(work_dir)

def category_to_idx(category):
    cat2idx = {}
    for cat in category:
        cat2idx[cat] = len(cat2idx)
    return cat2idx


class COCO2014(Dataset):
    def __init__(self, root, transform=None, phase='train'):
        self.root = os.path.abspath(root)
        self.phase = phase
        self.img_list = []
        self.transform = transform
        download_coco2014(self.root, phase)
        self.get_anno()
        self.num_classes = len(self.cat2idx)
        print('[dataset] COCO2014 classification phase={} number of classes={} number of images={}'.format(phase, self.num_classes, len(self.img_list)))

    def get_anno(self):
        list_path = os.path.join(self.root, '{}_anno.json'.format(self.phase))
        self.img_list = json.load(open(list_path, 'r'))
        self.cat2idx = json.load(open(os.path.join(self.root, 'category.json'), 'r'))

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        item = self.img_list[index]
        filename = item['file_name']
        labels = sorted(item['labels'])
        img = Image.open(os.path.join(self.root, '{}2014'.format(self.phase), filename)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        # target = np.zeros(self.num_classes, np.float32) - 1
        target = torch.zeros(self.num_classes, dtype=torch.float32) - 1
        target[labels] = 1
        data = {'image':img, 'name': filename, 'target': target}
        return data
        # return image, target
        # return (img, filename), target
--------------------------------------------------------------------------------
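For reference, each record in the generated `{phase}_anno.json` pairs an image file name with the indices of its categories under the alphabetical `cat2idx` mapping. One illustrative entry (the label indices here are made up):

```python
# train_anno.json is a list of dicts shaped like this:
{"file_name": "COCO_train2014_000000000009.jpg", "labels": [45, 49, 50]}
```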
/data/__init__.py:
--------------------------------------------------------------------------------
import os, sys, pdb
from PIL import Image
import random

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from .coco import COCO2014
from .voc import VOC2007, VOC2012

data_dict = {'COCO2014': COCO2014,
             'VOC2007': VOC2007,
             'VOC2012': VOC2012}

def collate_fn(batch):
    ret_batch = dict()
    for k in batch[0].keys():
        if k == 'image' or k == 'target':
            ret_batch[k] = torch.cat([b[k].unsqueeze(0) for b in batch])
        else:
            ret_batch[k] = [b[k] for b in batch]
    return ret_batch

class MultiScaleCrop(object):

    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img):
        im_size = img.size
        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        crop_img_group = img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h))
        ret_img_group = crop_img_group.resize((self.input_size[0], self.input_size[1]), self.interpolation)
        return ret_img_group

    def _sample_crop_size(self, im_size):
        image_w, image_h = im_size[0], im_size[1]

        # find a crop size
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]

        pairs = []
        for i, h in enumerate(crop_h):
            for j, w in enumerate(crop_w):
                if abs(i - j) <= self.max_distort:
                    pairs.append((w, h))

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = list()
        ret.append((0, 0))  # upper left
        ret.append((4 * w_step, 0))  # upper right
        ret.append((0, 4 * h_step))  # lower left
        ret.append((4 * w_step, 4 * h_step))  # lower right
        ret.append((2 * w_step, 2 * h_step))  # center

        if more_fix_crop:
            ret.append((0, 2 * h_step))  # center left
            ret.append((4 * w_step, 2 * h_step))  # center right
            ret.append((2 * w_step, 4 * h_step))  # lower center
            ret.append((2 * w_step, 0 * h_step))  # upper center

            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower right quarter

        return ret

    def __str__(self):
        return self.__class__.__name__


def get_transform(args, is_train=True):
    if is_train:
        transform = transforms.Compose([
            # transforms.RandomResizedCrop(args.image_size, scale=(0.1, 1.5), ratio=(1.0, 1.0)),
            # transforms.RandomResizedCrop(args.image_size, scale=(0.1, 2.0), ratio=(1.0, 1.0)),
            transforms.Resize((args.image_size+64, args.image_size+64)),
            MultiScaleCrop(args.image_size, scales=(1.0, 0.875, 0.75, 0.66, 0.5), max_distort=2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize((args.image_size,args.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    return transform

def make_data_loader(args, is_train=True):
    root_dir = os.path.join(args.data_root_dir, args.data)

    # Build val_loader
    transform = get_transform(args, is_train=False)
    if args.data == 'COCO2014':
        val_dataset = COCO2014(root_dir, phase='val', transform=transform)
    elif args.data in ('VOC2007', 'VOC2012'):
        val_dataset = data_dict[args.data](root_dir, phase='test', transform=transform)
    else:
        raise ValueError('No matching dataset: {}'.format(args.data))

    num_classes = val_dataset[0]['target'].size(-1)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True,
                            collate_fn=collate_fn, drop_last=False)

    if not is_train:
        return None, val_loader, num_classes

    # Build train_loader
    transform = get_transform(args, is_train=True)
    if args.data == 'COCO2014':
        train_dataset = COCO2014(root_dir, phase='train', transform=transform)
    elif args.data in ('VOC2007', 'VOC2012'):
        train_dataset = data_dict[args.data](root_dir, phase='trainval', transform=transform)
    else:
        raise ValueError('No matching dataset: {}'.format(args.data))

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True,
                              collate_fn=collate_fn, drop_last=True)

    return train_loader, val_loader, num_classes
--------------------------------------------------------------------------------
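A minimal sketch (not part of the repository; run from the repository root) of the train-time augmentation pipeline above applied to a single image. `example.jpg` is a hypothetical input file:

```python
from PIL import Image
import torchvision.transforms as transforms
from data import MultiScaleCrop  # defined in the module above

transform = transforms.Compose([
    transforms.Resize((448 + 64, 448 + 64)),
    MultiScaleCrop(448, scales=(1.0, 0.875, 0.75, 0.66, 0.5), max_distort=2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.open('example.jpg').convert('RGB')  # hypothetical input image
print(transform(img).shape)  # torch.Size([3, 448, 448])
```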
/util.py:
--------------------------------------------------------------------------------
import os, sys, pdb
import math
import torch
from PIL import Image
import numpy as np
import random

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def average(self):
        return self.avg

    def value(self):
        return self.val

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

class AveragePrecisionMeter(object):
    """
    The APMeter measures the average precision per class.
    The APMeter is designed to operate on `NxK` Tensors `output` and
    `target`, and optionally a `Nx1` Tensor weight where (1) the `output`
    contains model output scores for `N` examples and `K` classes that ought to
    be higher when the model is more convinced that the example should be
    positively labeled, and smaller when the model believes the example should
    be negatively labeled (for instance, the output of a sigmoid function); (2)
    the `target` contains only values 0 (for negative examples) and 1
    (for positive examples); and (3) the `weight` ( > 0) represents weight for
    each sample.
    """

    def __init__(self, difficult_examples=True):
        super(AveragePrecisionMeter, self).__init__()
        self.reset()
        self.difficult_examples = difficult_examples

    def reset(self):
        """Resets the meter with empty member variables"""
        self.scores = torch.FloatTensor(torch.FloatStorage())
        self.targets = torch.LongTensor(torch.LongStorage())
        self.filenames = []

    def add(self, output, target, filename):
        """
        Args:
            output (Tensor): NxK tensor that for each of the N examples
                indicates the probability of the example belonging to each of
                the K classes, according to the model.
                The scores need not sum to one over classes (e.g., they can
                be per-class sigmoid outputs).
            target (Tensor): binary NxK tensor that encodes which of the K
                classes are associated with the N-th input
                (eg: a row [0, 1, 0, 1] indicates that the example is
                associated with classes 2 and 4)
            weight (optional, Tensor): Nx1 tensor representing the weight for
                each example (each weight > 0)
        """
        if not torch.is_tensor(output):
            output = torch.from_numpy(output)
        if not torch.is_tensor(target):
            target = torch.from_numpy(target)

        if output.dim() == 1:
            output = output.view(-1, 1)
        else:
            assert output.dim() == 2, \
                'wrong output size (should be 1D or 2D with one column \
                per class)'
        if target.dim() == 1:
            target = target.view(-1, 1)
        else:
            assert target.dim() == 2, \
                'wrong target size (should be 1D or 2D with one column \
                per class)'
        if self.scores.numel() > 0:
            assert target.size(1) == self.targets.size(1), \
                'dimensions for output should match previously added examples.'

        # make sure storage is of sufficient size
        if self.scores.storage().size() < self.scores.numel() + output.numel():
            new_size = math.ceil(self.scores.storage().size() * 1.5)
            self.scores.storage().resize_(int(new_size + output.numel()))
            self.targets.storage().resize_(int(new_size + output.numel()))

        # store scores and targets
        offset = self.scores.size(0) if self.scores.dim() > 0 else 0
        self.scores.resize_(offset + output.size(0), output.size(1))
        self.targets.resize_(offset + target.size(0), target.size(1))
        self.scores.narrow(0, offset, output.size(0)).copy_(output)
        self.targets.narrow(0, offset, target.size(0)).copy_(target)

        self.filenames += filename  # record filenames

    def value(self):
        """Returns the model's average precision for each class
        Return:
            ap (FloatTensor): 1xK tensor, with avg precision for each class k
        """

        if self.scores.numel() == 0:
            return 0
        ap = torch.zeros(self.scores.size(1))
        # compute average precision for each class
        for k in range(self.scores.size(1)):
            # sort scores
            scores = self.scores[:, k]
            targets = self.targets[:, k]
            # compute average precision
            ap[k] = AveragePrecisionMeter.average_precision(scores, targets, self.difficult_examples)
        return ap

    @staticmethod
    def average_precision(output, target, difficult_examples=True):

        # sort examples by decreasing score
        _, indices = torch.sort(output, dim=0, descending=True)

        # Computes prec@i
        pos_count = 0.
        total_count = 0.
        precision_at_i = 0.
        for i in indices:
            label = target[i]
            if difficult_examples and label == 0:
                continue
            if label == 1:
                pos_count += 1
            total_count += 1
            if label == 1:
                precision_at_i += pos_count / total_count
        precision_at_i /= pos_count
        return precision_at_i

    def overall(self):
        if self.scores.numel() == 0:
            return 0
        scores = self.scores.cpu().numpy()
        targets = self.targets.clone().cpu().numpy()
        targets[targets == -1] = 0
        return self.evaluation(scores, targets)

    def overall_topk(self, k):
        targets = self.targets.clone().cpu().numpy()
        targets[targets == -1] = 0
        n, c = self.scores.size()
        scores = np.zeros((n, c)) - 1
        index = self.scores.topk(k, 1, True, True)[1].cpu().numpy()
        tmp = self.scores.cpu().numpy()
        for i in range(n):
            for ind in index[i]:
                scores[i, ind] = 1 if tmp[i, ind] >= 0 else -1
        return self.evaluation(scores, targets)

    def evaluation(self, scores_, targets_):
        n, n_class = scores_.shape
        Nc, Np, Ng = np.zeros(n_class), np.zeros(n_class), np.zeros(n_class)
        for k in range(n_class):
            scores = scores_[:, k]
            targets = targets_[:, k]
            targets[targets == -1] = 0
            Ng[k] = np.sum(targets == 1)
            Np[k] = np.sum(scores >= 0)
            Nc[k] = np.sum(targets * (scores >= 0))
        Np[Np == 0] = 1
        OP = np.sum(Nc) / np.sum(Np)
        OR = np.sum(Nc) / np.sum(Ng)
        OF1 = (2 * OP * OR) / (OP + OR)

        CP = np.sum(Nc / Np) / n_class
        CR = np.sum(Nc / Ng) / n_class
        CF1 = (2 * CP * CR) / (CP + CR)
        return OP, OR, OF1, CP, CR, CF1
--------------------------------------------------------------------------------
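A minimal usage sketch for `AveragePrecisionMeter` (synthetic scores and ±1 targets, not from the repository; run from the repository root, assuming the PyTorch 1.x versions this project targets):

```python
import torch
from util import AveragePrecisionMeter

meter = AveragePrecisionMeter()
scores = torch.randn(8, 20)      # synthetic model outputs for 8 images, 20 classes
targets = torch.ones(8, 20)
targets[::2] = -1                # alternate positive/negative labels per class
meter.add(scores, targets, ['img_{}.jpg'.format(i) for i in range(8)])

ap = meter.value()               # per-class average precision, shape (20,)
print(ap.mean())                 # mAP
OP, OR, OF1, CP, CR, CF1 = meter.overall()
```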
/trainer.py:
--------------------------------------------------------------------------------
import os, sys, pdb
import shutil
import time
import numpy as np
import torch
import torch.backends.cudnn as cudnn
# import torchnet as tnt
# import torchvision.transforms as transforms
# from torch.optim import lr_scheduler
from util import AverageMeter, AveragePrecisionMeter
from datetime import datetime
# from pprint import pprint
from tqdm import tqdm


class Trainer(object):
    def __init__(self, model, criterion, train_loader, val_loader, args):
        self.model = model
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.args = args
        # pprint (self.args)
        print('--------Args Items----------')
        for k, v in vars(self.args).items():
            print('{}: {}'.format(k, v))
        print('--------Args Items----------\n')

    def initialize_optimizer_and_scheduler(self):
        self.optimizer = torch.optim.SGD(self.model.get_config_optim(self.args.lr, self.args.lrp),
                                         lr=self.args.lr,
                                         momentum=self.args.momentum,
                                         weight_decay=self.args.weight_decay)
        # self.lr_scheduler = lr_scheduler.MultiStepLR(self.optimizer, self.args.epoch_step, gamma=0.1)

    def initialize_meters(self):
        self.meters = {}
        # meters
        self.meters['loss'] = AverageMeter('loss')
        self.meters['ap_meter'] = AveragePrecisionMeter()
        # time measure
        self.meters['batch_time'] = AverageMeter('batch_time')
        self.meters['data_time'] = AverageMeter('data_time')

    def initialization(self, is_train=False):
        """ initialize self.model and self.criterion here """

        if is_train:
            self.start_epoch = 0
            self.epoch = 0
            self.end_epoch = self.args.epochs
            self.best_score = 0.
            self.lr_now = self.args.lr

            # initialize some settings
            self.initialize_optimizer_and_scheduler()

        self.initialize_meters()

        # load checkpoint if args.resume is a valid checkpoint file.
        if os.path.isfile(self.args.resume) and self.args.resume.endswith('pth'):
            self.load_checkpoint()

        if torch.cuda.is_available():
            cudnn.benchmark = True
            self.model = torch.nn.DataParallel(self.model).cuda()
            self.criterion = self.criterion.cuda()
            # self.train_loader.pin_memory = True
            # self.val_loader.pin_memory = True

    def reset_meters(self):
        for k, v in self.meters.items():
            self.meters[k].reset()

    def on_start_epoch(self):
        self.reset_meters()

    def on_end_epoch(self, is_train=False):

        if is_train:
            # maybe you can do something like 'print the training results' here.
            return
        else:
            # mAP = self.meters['ap_meter'].value().mean()
            ap = self.meters['ap_meter'].value()
            print(ap)
            mAP = ap.mean()
            loss = self.meters['loss'].average()
            data_time = self.meters['data_time'].average()
            batch_time = self.meters['batch_time'].average()

            OP, OR, OF1, CP, CR, CF1 = self.meters['ap_meter'].overall()
            OP_k, OR_k, OF1_k, CP_k, CR_k, CF1_k = self.meters['ap_meter'].overall_topk(3)

            print('* Test\nLoss: {loss:.4f}\t mAP: {map:.4f}\t'
                  'Data_time: {data_time:.4f}\t Batch_time: {batch_time:.4f}'.format(
                      loss=loss, map=mAP, data_time=data_time, batch_time=batch_time))
            print('OP: {OP:.3f}\t OR: {OR:.3f}\t OF1: {OF1:.3f}\t'
                  'CP: {CP:.3f}\t CR: {CR:.3f}\t CF1: {CF1:.3f}'.format(
                      OP=OP, OR=OR, OF1=OF1, CP=CP, CR=CR, CF1=CF1))
            print('OP_3: {OP:.3f}\t OR_3: {OR:.3f}\t OF1_3: {OF1:.3f}\t'
                  'CP_3: {CP:.3f}\t CR_3: {CR:.3f}\t CF1_3: {CF1:.3f}'.format(
                      OP=OP_k, OR=OR_k, OF1=OF1_k, CP=CP_k, CR=CR_k, CF1=CF1_k))

            return mAP

    def on_forward(self, inputs, targets, is_train):
        inputs = inputs.float()
        targets = targets.float()

        if not is_train:
            with torch.no_grad():
                outputs1, outputs2 = self.model(inputs)
        else:
            outputs1, outputs2 = self.model(inputs)
        outputs = (outputs1 + outputs2) / 2

        loss = self.criterion(outputs, targets)
        self.meters['loss'].update(loss.item(), inputs.size(0))

        if is_train:
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.args.max_clip_grad_norm)
            self.optimizer.step()

        return outputs

    def adjust_learning_rate(self):
        """ Sets learning rate if it is needed """
        lr_list = []
        decay = 0.1 if self.epoch in self.args.epoch_step else 1.0
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * decay
            lr_list.append(param_group['lr'])

        return np.unique(lr_list)

    def train(self):
        self.initialization(is_train=True)

        for epoch in range(self.start_epoch, self.end_epoch):
            self.lr_now = self.adjust_learning_rate()
            print('Lr: {}'.format(self.lr_now))

            self.epoch = epoch
            # train for one epoch
            self.run_iteration(self.train_loader, is_train=True)

            # evaluate on validation set
            score = self.run_iteration(self.val_loader, is_train=False)

            # record best score, save checkpoint and result
            is_best = score > self.best_score
            self.best_score = max(score, self.best_score)
            checkpoint = {
                'epoch': epoch + 1,
                'model_name': self.args.model_name,
                'state_dict': self.model.module.state_dict() if torch.cuda.is_available() else self.model.state_dict(),
                'best_score': self.best_score
            }
            model_dir = self.args.save_dir
            # assert os.path.exists(model_dir) == True
            self.save_checkpoint(checkpoint, model_dir, is_best)
            self.save_result(model_dir, is_best)

            print(' * best mAP={best:.4f}'.format(best=self.best_score))

        return self.best_score

    def run_iteration(self, data_loader, is_train=True):
        self.on_start_epoch()

        if not is_train:
            data_loader = tqdm(data_loader, desc='Validate')
            self.model.eval()
        else:
            self.model.train()

        st_time = time.time()
        for i, data in enumerate(data_loader):

            # measure data loading time
            data_time = time.time() - st_time
            self.meters['data_time'].update(data_time)

            # inputs, targets, targets_gt, filenames = self.on_start_batch(data)
            inputs = data['image']
            targets = data['target']

            # for voc
            labels = targets.clone()
            targets[targets==0] = 1
            targets[targets==-1] = 0

            if torch.cuda.is_available():
                inputs = inputs.cuda()
                targets = targets.cuda()

            outputs = self.on_forward(inputs, targets, is_train=is_train)

            # measure elapsed time
            batch_time = time.time() - st_time
            self.meters['batch_time'].update(batch_time)

            self.meters['ap_meter'].add(outputs.data, labels.data, data['name'])
            st_time = time.time()

            if is_train and i % self.args.display_interval == 0:
                print('{}, {} Epoch, {} Iter, Loss: {:.4f}, Data time: {:.4f}, Batch time: {:.4f}'.format(
                    datetime.now().strftime('%Y-%m-%d %H:%M:%S'), self.epoch+1, i,
                    self.meters['loss'].value(), self.meters['data_time'].value(),
                    self.meters['batch_time'].value()))

        return self.on_end_epoch(is_train=is_train)

    def validate(self):
        self.initialization(is_train=False)

        mAP = self.run_iteration(self.val_loader, is_train=False)

        model_dir = os.path.dirname(self.args.resume)
        assert os.path.exists(model_dir)
        self.save_result(model_dir, is_best=False)

        return mAP

    def load_checkpoint(self):
        print("* Loading checkpoint '{}'".format(self.args.resume))
        checkpoint = torch.load(self.args.resume)
        self.start_epoch = checkpoint['epoch']
        self.best_score = checkpoint['best_score']
        model_dict = self.model.state_dict()
        for k, v in checkpoint['state_dict'].items():
            if k in model_dict and v.shape == model_dict[k].shape:
                model_dict[k] = v
            else:
                print('\tMismatched layers: {}'.format(k))
        self.model.load_state_dict(model_dict)

    def save_checkpoint(self, checkpoint, model_dir, is_best=False):
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        # filename = 'Epoch-{}.pth'.format(self.epoch)
        filename = 'checkpoint.pth'
        res_path = os.path.join(model_dir, filename)
        print('Save checkpoint to {}'.format(res_path))
        torch.save(checkpoint, res_path)
        if is_best:
            filename_best = 'checkpoint_best.pth'
            res_path_best = os.path.join(model_dir, filename_best)
            shutil.copyfile(res_path, res_path_best)

    def save_result(self, model_dir, is_best=False):
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        # filename = 'results.csv' if not is_best else 'best_results.csv'
        filename = 'results.csv'
        res_path = os.path.join(model_dir, filename)
        print('Save results to {}'.format(res_path))
        with open(res_path, 'w') as fid:
            for i in range(self.meters['ap_meter'].scores.shape[0]):
                fid.write('{},{},{}\n'.format(self.meters['ap_meter'].filenames[i],
                                              ','.join(map(str, self.meters['ap_meter'].scores[i].numpy())),
                                              ','.join(map(str, self.meters['ap_meter'].targets[i].numpy()))))

        if is_best:
            filename_best = 'output_best.csv'
            res_path_best = os.path.join(model_dir, filename_best)
            shutil.copyfile(res_path, res_path_best)
--------------------------------------------------------------------------------
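A worked example of the decay rule in `adjust_learning_rate` with the defaults from `main.py` (`--lr 0.05`, `--epoch_step 30 40`; the backbone group additionally starts at `lr * lrp = 0.005` via `get_config_optim`):

```python
lr, epoch_step = 0.05, [30, 40]
for epoch in range(50):
    if epoch in epoch_step:  # same condition as Trainer.adjust_learning_rate
        lr *= 0.1
# lr is 0.05 for epochs 0-29, 0.005 for epochs 30-39, and 0.0005 for epochs 40-49.
```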
/data/voc.py:
--------------------------------------------------------------------------------
import csv
import os
import tarfile
from urllib.parse import urlparse
from urllib.request import urlretrieve
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from tqdm import tqdm

object_categories = ['aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor']

urls2007 = {
    'devkit': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar',
    'trainval_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
    'test_images_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
    'test_anno_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtestnoimgs_06-Nov-2007.tar',
}

urls2012 = {
    'devkit': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar',
    # 'trainval_2012': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_06-Nov-2012.tar',
    'trainval_2012': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
    # 'test_images_2012': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtest_06-Nov-2012.tar',
    'test_images_2012': 'http://pjreddie.com/media/files/VOC2012test.tar',
    # 'test_anno_2012': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtestnoimgs_06-Nov-2012.tar',
}


def download_url(url, destination=None, progress_bar=True):
    """Download a URL to a local file.

    Parameters
    ----------
    url : str
        The URL to download.
    destination : str, None
        The destination of the file. If None is given the file is saved to a temporary directory.
    progress_bar : bool
        Whether to show a command-line progress bar while downloading.

    Returns
    -------
    filename : str
        The location of the downloaded file.

    Notes
    -----
    Progress bar use/example adapted from tqdm documentation: https://github.com/tqdm/tqdm
    """

    def my_hook(t):
        last_b = [0]

        def inner(b=1, bsize=1, tsize=None):
            if tsize is not None:
                t.total = tsize
            if b > 0:
                t.update((b - last_b[0]) * bsize)
            last_b[0] = b

        return inner

    if progress_bar:
        with tqdm(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t:
            filename, _ = urlretrieve(url, filename=destination, reporthook=my_hook(t))
    else:
        filename, _ = urlretrieve(url, filename=destination)
    return filename


def read_image_label(file):
    print('[dataset] read ' + file)
    data = dict()
    with open(file, 'r') as f:
        for line in f:
            tmp = line.split()
            name = tmp[0]
            label = int(tmp[-1])
            data[name] = label
    return data


def read_object_labels(root, dataset, phase):
    path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main')
    labeled_data = dict()
    num_classes = len(object_categories)

    for i in range(num_classes):
        file = os.path.join(path_labels, object_categories[i] + '_' + phase + '.txt')
        data = read_image_label(file)

        if i == 0:
            for (name, label) in data.items():
                labels = np.zeros(num_classes)
                labels[i] = label
                labeled_data[name] = labels
        else:
            for (name, label) in data.items():
                labeled_data[name][i] = label

    return labeled_data


def write_object_labels_csv(file, labeled_data):
    # write a csv file
    print('[dataset] write file %s' % file)
    with open(file, 'w') as csvfile:
        fieldnames = ['name']
        fieldnames.extend(object_categories)
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        writer.writeheader()
        for (name, labels) in labeled_data.items():
            example = {'name': name}
            for i in range(len(object_categories)):
                example[fieldnames[i + 1]] = int(labels[i])
            writer.writerow(example)


def read_object_labels_csv(file, header=True):
    images = []
    num_categories = 0
    print('[dataset] read', file)
    with open(file, 'r') as f:
        reader = csv.reader(f)
        rownum = 0
        for row in reader:
            if header and rownum == 0:
                header = row
            else:
                if num_categories == 0:
                    num_categories = len(row) - 1
                name = row[0]
                labels = torch.from_numpy((np.asarray(row[1:num_categories + 1])).astype(np.float32))
                item = (name, labels)
                images.append(item)
            rownum += 1
    return images


# def find_images_classification(root, dataset, phase):
#     path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main')
#     images = []
#     file = os.path.join(path_labels, phase + '.txt')
#     with open(file, 'r') as f:
#         for line in f:
#             images.append(line)
#     return images


def download_voc2007(root):
    path_devkit = os.path.join(root, 'VOCdevkit')
    path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages')
    tmpdir = os.path.join(root, 'tmp')

    # create directory
    if not os.path.exists(root):
        os.makedirs(root)

    if not os.path.exists(path_devkit):

        if not os.path.exists(tmpdir):
            os.makedirs(tmpdir)

        parts = urlparse(urls2007['devkit'])
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(tmpdir, filename)
        if not os.path.exists(cached_file):
            print('Downloading: "{}" to {}\n'.format(urls2007['devkit'], cached_file))
            download_url(urls2007['devkit'], cached_file)

        # extract file
        print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root))
        cwd = os.getcwd()
        tar = tarfile.open(cached_file, "r")
        os.chdir(root)
        tar.extractall()
        tar.close()
        os.chdir(cwd)
        print('[dataset] Done!')

    # train/val images/annotations
    if not os.path.exists(path_images):

        # download train/val images/annotations
        parts = urlparse(urls2007['trainval_2007'])
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(tmpdir, filename)

        if not os.path.exists(cached_file):
            print('Downloading: "{}" to {}\n'.format(urls2007['trainval_2007'], cached_file))
            download_url(urls2007['trainval_2007'], cached_file)

        # extract file
        print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root))
        cwd = os.getcwd()
        tar = tarfile.open(cached_file, "r")
        os.chdir(root)
        tar.extractall()
        tar.close()
        os.chdir(cwd)
        print('[dataset] Done!')

    # test images
    test_image = os.path.join(path_devkit, 'VOC2007/JPEGImages/000001.jpg')
    if not os.path.exists(test_image):

        # download test images
        parts = urlparse(urls2007['test_images_2007'])
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(tmpdir, filename)

        if not os.path.exists(cached_file):
            print('Downloading: "{}" to {}\n'.format(urls2007['test_images_2007'], cached_file))
            download_url(urls2007['test_images_2007'], cached_file)

        # extract file
        print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root))
        cwd = os.getcwd()
        tar = tarfile.open(cached_file, "r")
        os.chdir(root)
        tar.extractall()
        tar.close()
        os.chdir(cwd)
        print('[dataset] Done!')

    # test annotations
    test_anno = os.path.join(path_devkit, 'VOC2007/ImageSets/Main/aeroplane_test.txt')
    if not os.path.exists(test_anno):

        # download test annotations
        parts = urlparse(urls2007['test_anno_2007'])
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(tmpdir, filename)

        if not os.path.exists(cached_file):
            print('Downloading: "{}" to {}\n'.format(urls2007['test_anno_2007'], cached_file))
            download_url(urls2007['test_anno_2007'], cached_file)

        # extract file
        print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root))
        cwd = os.getcwd()
        tar = tarfile.open(cached_file, "r")
        os.chdir(root)
        tar.extractall()
        tar.close()
        os.chdir(cwd)
        print('[dataset] Done!')


def download_voc2012(root):
    path_devkit = os.path.join(root, 'VOCdevkit')
    path_images = os.path.join(root, 'VOCdevkit', 'VOC2012', 'JPEGImages')
    tmpdir = os.path.join(root, 'tmp')

    # create directory
    if not os.path.exists(root):
        os.makedirs(root)

    if not os.path.exists(path_devkit):

        if not os.path.exists(tmpdir):
            os.makedirs(tmpdir)

        parts = urlparse(urls2012['devkit'])
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(tmpdir, filename)
        if not os.path.exists(cached_file):
            print('Downloading: "{}" to {}\n'.format(urls2012['devkit'], cached_file))
            download_url(urls2012['devkit'], cached_file)

        # extract file
        print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root))
        cwd = os.getcwd()
        tar = tarfile.open(cached_file, "r")
        os.chdir(root)
        tar.extractall()
        tar.close()
        os.chdir(cwd)
        print('[dataset] Done!')

    # train/val images/annotations
    if not os.path.exists(path_images):

        # download train/val images/annotations
        parts = urlparse(urls2012['trainval_2012'])
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(tmpdir, filename)

        if not os.path.exists(cached_file):
            print('Downloading: "{}" to {}\n'.format(urls2012['trainval_2012'], cached_file))
            download_url(urls2012['trainval_2012'], cached_file)

        # extract file
        print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root))
        cwd = os.getcwd()
        tar = tarfile.open(cached_file, "r")
        os.chdir(root)
        tar.extractall()
        tar.close()
        os.chdir(cwd)
        print('[dataset] Done!')

    # test images
    test_image = os.path.join(path_devkit, 'VOC2012/JPEGImages/2012_000001.jpg')
    if not os.path.exists(test_image):

        # download test images
        parts = urlparse(urls2012['test_images_2012'])
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(tmpdir, filename)

        if not os.path.exists(cached_file):
            print('Downloading: "{}" to {}\n'.format(urls2012['test_images_2012'], cached_file))
            download_url(urls2012['test_images_2012'], cached_file)

        # extract file
        print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root))
        cwd = os.getcwd()
        tar = tarfile.open(cached_file, "r")
        os.chdir(root)
        tar.extractall()
        tar.close()
        os.chdir(cwd)
        print('[dataset] Done!')


class VOC2007(Dataset):
    def __init__(self, root, phase, transform=None):
        self.root = os.path.abspath(root)
        self.path_devkit = os.path.join(self.root, 'VOCdevkit')
        self.path_images = os.path.join(self.root, 'VOCdevkit', 'VOC2007', 'JPEGImages')
        self.phase = phase
        self.transform = transform
        download_voc2007(self.root)

        # define path of csv file
        path_csv = os.path.join(self.root, 'files', 'VOC2007')
        # define filename of csv file
        file_csv = os.path.join(path_csv, 'classification_' + phase + '.csv')

        # create the csv file if necessary
        if not os.path.exists(file_csv):
            if not os.path.exists(path_csv):  # create dir if necessary
                os.makedirs(path_csv)
            # generate csv file
            labeled_data = read_object_labels(self.root, 'VOC2007', self.phase)
            # write csv file
            write_object_labels_csv(file_csv, labeled_data)

        self.classes = object_categories
        self.images = read_object_labels_csv(file_csv)
        print('[dataset] VOC 2007 classification phase={} number of classes={} number of images={}'.format(phase, len(self.classes), len(self.images)))

    def __getitem__(self, index):
        filename, target = self.images[index]
        img = Image.open(os.path.join(self.path_images, filename + '.jpg')).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        data = {'image':img, 'name': filename, 'target': target}
        return data
        # image = {'image': img, 'name': filename}
        # return image, target
        # return (img, filename), target

    def __len__(self):
        return len(self.images)

    def get_number_classes(self):
        return len(self.classes)


class VOC2012(Dataset):
    def __init__(self, root, phase, transform=None):
        self.root = os.path.abspath(root)
        self.path_devkit = os.path.join(self.root, 'VOCdevkit')
        self.path_images = os.path.join(self.root, 'VOCdevkit', 'VOC2012', 'JPEGImages')
        self.phase = phase
        self.transform = transform
        download_voc2012(self.root)

        # define path of csv file
        path_csv = os.path.join(self.root, 'files', 'VOC2012')
        # define filename of csv file
        file_csv = os.path.join(path_csv, 'classification_' + phase + '.csv')

        # create the csv file if necessary
        if not os.path.exists(file_csv):
            if not os.path.exists(path_csv):  # create dir if necessary
                os.makedirs(path_csv)
            # generate csv file
            labeled_data = read_object_labels(self.root, 'VOC2012', self.phase)
            # write csv file
            write_object_labels_csv(file_csv, labeled_data)

        self.classes = object_categories
        self.images = read_object_labels_csv(file_csv)
        print('[dataset] VOC 2012 classification phase={} number of classes={} number of images={}'.format(phase, len(self.classes), len(self.images)))

    def __getitem__(self, index):
        filename, target = self.images[index]
        img = Image.open(os.path.join(self.path_images, filename + '.jpg')).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        data = {'image':img, 'name': filename, 'target': target}
        return data
        # image = {'image': img, 'name': filename}
        # return image, target
        # return (img, filename), target

    def __len__(self):
        return len(self.images)

    def get_number_classes(self):
        return len(self.classes)
--------------------------------------------------------------------------------
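For reference, the generated `classification_{phase}.csv` holds a header row and one row per image: the VOC image name followed by one column per category, with values 1 (positive), -1 (negative), or 0 (difficult). An illustrative excerpt (the values and the "..." elision are ours, not real annotations):

```
name,aeroplane,bicycle,...,tvmonitor
000005,-1,-1,...,-1
000007,1,-1,...,-1
```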