├── LICENSE ├── README.md ├── __pycache__ ├── active_learning.cpython-36.pyc ├── csd.cpython-36.pyc ├── loaders.cpython-36.pyc ├── pseudo_labels.cpython-36.pyc ├── ssd.cpython-36.pyc └── subset_sequential_sampler.cpython-36.pyc ├── active_learning.py ├── csd.py ├── data ├── COCO.txt ├── ONLY_VOC_IN_COCO.txt ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── coco.cpython-36.pyc │ ├── coco.cpython-37.pyc │ ├── coco.cpython-38.pyc │ ├── coco_new.cpython-36.pyc │ ├── config.cpython-36.pyc │ ├── config.cpython-37.pyc │ ├── config.cpython-38.pyc │ ├── voc0712.cpython-36.pyc │ ├── voc0712.cpython-37.pyc │ ├── voc0712.cpython-38.pyc │ ├── voc0712_backup.cpython-36.pyc │ ├── voc0712_consistency.cpython-36.pyc │ ├── voc07_consistency.cpython-36.pyc │ ├── voc07_consistency.cpython-38.pyc │ ├── voc07_consistency_init.cpython-36.pyc │ ├── voc07_consistency_init.cpython-38.pyc │ └── voc07_voc12coco.cpython-36.pyc ├── coco.py ├── coco │ └── coco_labels.txt ├── config.py ├── example.jpg ├── only_voc.txt ├── scripts │ ├── COCO2014.sh │ ├── VOC2007.sh │ └── VOC2012.sh └── voc0712.py ├── eval.py ├── layers ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── box_utils.cpython-36.pyc │ ├── box_utils.cpython-37.pyc │ └── box_utils.cpython-38.pyc ├── box_utils.py ├── functions │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── detection.cpython-36.pyc │ │ ├── detection.cpython-37.pyc │ │ ├── detection.cpython-38.pyc │ │ ├── prior_box.cpython-36.pyc │ │ ├── prior_box.cpython-37.pyc │ │ └── prior_box.cpython-38.pyc │ ├── detection.py │ └── prior_box.py └── modules │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── l2norm.cpython-36.pyc │ ├── l2norm.cpython-37.pyc │ ├── l2norm.cpython-38.pyc │ ├── multibox_loss.cpython-36.pyc │ ├── multibox_loss.cpython-37.pyc │ ├── multibox_loss.cpython-38.pyc │ ├── multibox_loss_gmm.cpython-36.pyc │ ├── multibox_loss_new.cpython-36.pyc │ ├── multibox_loss_new.cpython-37.pyc │ └── multibox_loss_new.cpython-38.pyc │ ├── l2norm.py │ └── multibox_loss.py ├── loaders.py ├── pseudo_labels.py ├── ssd.py ├── subset_sequential_sampler.py ├── test.py ├── train.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── __init__.cpython-38.pyc ├── augmentations.cpython-36.pyc ├── augmentations.cpython-37.pyc └── augmentations.cpython-38.pyc └── augmentations.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved. 2 | 3 | Nvidia Source Code License-NC 4 | 5 | 1. Definitions 6 | 7 | “Licensor” means any person or entity that distributes its Work. 8 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license. 9 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 
10 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license. 11 | 12 | 2. License Grant 13 | 14 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 15 | 16 | 3. Limitations 17 | 18 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 19 | 20 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 21 | 22 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 23 | 24 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately. 25 | 26 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license. 27 | 28 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately. 29 | 30 | 4. Disclaimer of Warranty. 31 | 32 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 33 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 34 | 35 | 5. Limitation of Liability. 36 | 37 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 
38 | 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Not All Labels Are Equal: Rationalizing The Labeling Costs for Training Object Detection 2 | 3 | This repository contains the official PyTorch implementation of training & evaluation code and the pretrained models for: 4 | 5 | [Not All Labels Are Equal: Rationalizing The Labeling Costs for Training Object Detection](https://openaccess.thecvf.com/content/CVPR2022/papers/Elezi_Not_All_Labels_Are_Equal_Rationalizing_the_Labeling_Costs_for_CVPR_2022_paper.pdf)
6 | Ismail Elezi, Zhiding Yu, Anima Anandkumar, Laura Leal-Taixé, and Jose M. Alvarez.
7 | CVPR 2022. 8 | 9 | 10 | [Code](https://github.com/NVlabs/AL-SSL) 11 | 12 | 13 | ## Installation & Preparation 14 | We experimented with SSD in the PyTorch framework. To use our model, complete the installation & preparation steps on the [SSD PyTorch homepage](https://github.com/amdegroot/ssd.pytorch). 15 | 16 | #### Prerequisites 17 | - Python 3.6 18 | - PyTorch 1.1.0 19 | 20 | ## Training 21 | ```Shell 22 | python train.py 23 | ``` 24 | 25 | ## Evaluation 26 | ```Shell 27 | python eval.py 28 | ``` 29 | 30 | 31 | 32 | ## License 33 | Copyright © 2022-2023, NVIDIA Corporation and Affiliates. All rights reserved. 34 | 35 | This work is made available under the Nvidia Source Code License-NC. Click [here](https://github.com/NVlabs/AL-SSL/blob/main/LICENSE) to view a copy of this license. 36 | 37 | The pre-trained models are shared under CC-BY-NC-SA-4.0. If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original. 38 | 39 | For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/). 40 | 41 | 42 | ## Citation 43 | 44 | If you find this code useful, please consider citing the following paper: 45 | 46 | ```` 47 | @inproceedings{DBLP:conf/cvpr/EleziYALA22, 48 | author = {Ismail Elezi and 49 | Zhiding Yu and 50 | Anima Anandkumar and 51 | Laura Leal{-}Taix{\'{e}} and 52 | Jose M. Alvarez}, 53 | title = {Not All Labels Are Equal: Rationalizing The Labeling Costs for Training 54 | Object Detection}, 55 | booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition, 56 | {CVPR} 2022, New Orleans, LA, USA, June 18-24, 2022}, 57 | pages = {14472--14481}, 58 | publisher = {{IEEE}}, 59 | year = {2022}, 60 | url = {https://doi.org/10.1109/CVPR52688.2022.01409}, 61 | doi = {10.1109/CVPR52688.2022.01409}, 62 | timestamp = {Wed, 05 Oct 2022 16:31:19 +0200}, 63 | biburl = {https://dblp.org/rec/conf/cvpr/EleziYALA22.bib}, 64 | bibsource = {dblp computer science bibliography, https://dblp.org} 65 | } 66 | ```` 67 | 68 | We use the consistency-based semi-supervised detection method CSD developed by Jeong et al. Please also consider citing the following paper: 69 | 70 | ```` 71 | @inproceedings{DBLP:conf/nips/JeongLKK19, 72 | author = {Jisoo Jeong and 73 | Seungeui Lee and 74 | Jeesoo Kim and 75 | Nojun Kwak}, 76 | editor = {Hanna M. Wallach and 77 | Hugo Larochelle and 78 | Alina Beygelzimer and 79 | Florence d'Alch{\'{e}}{-}Buc and 80 | Emily B.
Fox and 81 | Roman Garnett}, 82 | title = {Consistency-based Semi-supervised Learning for Object detection}, 83 | booktitle = {Advances in Neural Information Processing Systems 32: Annual Conference 84 | on Neural Information Processing Systems 2019, NeurIPS 2019, December 85 | 8-14, 2019, Vancouver, BC, Canada}, 86 | pages = {10758--10767}, 87 | year = {2019}, 88 | url = {https://proceedings.neurips.cc/paper/2019/hash/d0f4dae80c3d0277922f8371d5827292-Abstract.html}, 89 | timestamp = {Mon, 16 May 2022 15:41:51 +0200}, 90 | biburl = {https://dblp.org/rec/conf/nips/JeongLKK19.bib}, 91 | bibsource = {dblp computer science bibliography, https://dblp.org} 92 | } 93 | ```` 94 | -------------------------------------------------------------------------------- /__pycache__/active_learning.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/__pycache__/active_learning.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/csd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/__pycache__/csd.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/loaders.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/__pycache__/loaders.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/pseudo_labels.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/__pycache__/pseudo_labels.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/ssd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/__pycache__/ssd.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/subset_sequential_sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/__pycache__/subset_sequential_sampler.cpython-36.pyc -------------------------------------------------------------------------------- /active_learning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit 5 | # https://github.com/NVlabs/AL-SSL/ 6 | 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from collections import defaultdict 12 | from copy import deepcopy 13 | 14 | from layers.box_utils import decode, nms 15 | 16 | 17 | def class_consistency_loss_al(conf, conf_flip): 18 | conf_consistency_criterion = torch.nn.KLDivLoss(size_average=False, reduce=False).cuda() 19 | 20 | conf_class = conf.clone() 21 | conf_class_flip = conf_flip.clone() 22 | 23 | consistency_conf_loss_a = conf_consistency_criterion(conf_class.log(), 24 | conf_class_flip.detach()).sum(-1) 25 | consistency_conf_loss_b = conf_consistency_criterion(conf_class_flip.log(), 26 | conf_class.detach()).sum(-1) 27 | return (consistency_conf_loss_a + consistency_conf_loss_b) / 2 28 | 29 | 30 | def active_learning_inconsistency(args, batch_iterator, labeled_set, unlabeled_set, net, num_classes, 31 | criterion_select='consistency_class', loader=None): 32 | criterion_UC = np.zeros(len(batch_iterator)) 33 | batch_iterator = iter(loader) 34 | thresh = args.thresh 35 | 36 | for j in range(len(batch_iterator)): # 3000 37 | print(j) 38 | images, lab, _ = next(batch_iterator) 39 | images = images.cuda() 40 | 41 | out, conf, conf_flip, loc, loc_flip, _ = net(images) 42 | loc, _, priors = out 43 | 44 | num = loc.size(0) # batch size 45 | num_priors = priors.size(0) 46 | output = torch.zeros(num, num_classes, 200, 6) 47 | conf_preds = conf.view(num, num_priors, num_classes).transpose(2, 1) 48 | conf_preds_flip = conf_flip.view(num, num_priors, num_classes).transpose(2, 1) 49 | variance = [0.1, 0.2] 50 | 51 | # Decode predictions into bboxes. 52 | for i in range(num): 53 | decoded_boxes = decode(loc[i], priors, variance) 54 | decoded_boxes_flip = decode(loc_flip[i], priors, variance) 55 | # For each class, perform nms 56 | conf_scores = conf_preds[i].clone() 57 | conf_scores_flip = conf_preds_flip[i].clone() 58 | 59 | if criterion_select == 'consistency_class' or criterion_select == 'consistency': 60 | H = class_consistency_loss_al(conf, conf_flip).squeeze() 61 | else: 62 | H = conf_scores * ( 63 | torch.log(conf_scores) / torch.log(torch.tensor(num_classes).type(torch.FloatTensor))) 64 | H = H.sum(dim=0) * (-1.0) 65 | 66 | for cl in range(1, num_classes): 67 | c_mask = conf_scores[cl].gt(0.01) # confidence threshold 68 | scores = conf_scores[cl][c_mask] 69 | Entropy = H[c_mask] 70 | if scores.size(0) == 0: 71 | continue 72 | 73 | l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes) 74 | boxes = decoded_boxes[l_mask].view(-1, 4) 75 | 76 | ids, count = nms(boxes.detach(), scores.detach(), 0.5, 200) 77 | output[i, cl, :count] = torch.cat( 78 | (scores[ids[:count]].unsqueeze(1), boxes[ids[:count]], Entropy[ids[:count]].unsqueeze(1)), 1) 79 | 80 | count_num = 0 81 | UC_max = 0 82 | for p in range(output.size(1)): # [1, 21, 200, 9] 83 | q = 0 84 | while output[0, p, q, 0] >= thresh: # filtering using threshold, To do: Increasing acoording to iteration? 
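# The rows of `output` are NMS-sorted [score, x1, y1, x2, y2, uncertainty] entries:
# we walk them while the confidence (column 0) stays above args.thresh and keep the
# largest per-box inconsistency (column 5) as the image-level acquisition score.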
85 | count_num += 1 86 | score = output[0, p, q, 0] 87 | entropy = output[0, p, q, 5:6] 88 | UC_max_temp = entropy.item() 89 | if (UC_max < UC_max_temp): 90 | UC_max = UC_max_temp 91 | q += 1 92 | 93 | if count_num == 0: 94 | criterion_UC[j] = 0 95 | else: 96 | criterion_UC[j] = UC_max 97 | 98 | if args.criterion_select == 'combined': 99 | return criterion_UC 100 | 101 | sorted_indices = np.argsort(criterion_UC)[::-1] 102 | labeled_set += list(np.array(unlabeled_set)[sorted_indices[:args.acquisition_budget]]) 103 | unlabeled_set = list(np.array(unlabeled_set)[sorted_indices[args.acquisition_budget:]]) 104 | 105 | # assert that sizes of lists are correct and that there are no elements that are in both lists 106 | assert len(list(set(labeled_set) | set(unlabeled_set))) == args.num_total_images 107 | assert len(list(set(labeled_set) & set(unlabeled_set))) == 0 108 | 109 | # save the labeled set 110 | return labeled_set, unlabeled_set 111 | 112 | 113 | def active_learning_entropy(args, batch_iterator, labeled_set, unlabeled_set, net, num_classes, criterion_select, loader=None): 114 | criterion_UC = np.zeros(len(batch_iterator)) 115 | batch_iterator = iter(loader) 116 | thresh = args.thresh 117 | 118 | for j in range(len(batch_iterator)): # 3000 119 | print(j) 120 | images, lab, _ = next(batch_iterator) 121 | images = images.cuda() 122 | 123 | out, _, _, _, _, _ = net(images) 124 | loc, conf, priors = out 125 | conf = torch.softmax(conf.detach(), dim=2) 126 | 127 | num = loc.size(0) # batch size 128 | num_priors = priors.size(0) 129 | output = torch.zeros(num, num_classes, 200, 6) 130 | conf_preds = conf.view(num, num_priors, num_classes).transpose(2, 1) 131 | variance = [0.1, 0.2] 132 | # Decode predictions into bboxes. 133 | for i in range(num): 134 | decoded_boxes = decode(loc[i], priors, variance) 135 | # For each class, perform nms 136 | conf_scores = conf_preds[i].clone() 137 | 138 | H = conf_scores * (torch.log(conf_scores) / torch.log(torch.tensor(num_classes).type(torch.FloatTensor))) 139 | H = H.sum(dim=0) * (-1.0) 140 | 141 | for cl in range(1, num_classes): 142 | c_mask = conf_scores[cl].gt(0.01) # confidence threshold 143 | scores = conf_scores[cl][c_mask] 144 | Entropy = H[c_mask] # jwchoi 145 | if scores.size(0) == 0: 146 | continue 147 | 148 | l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes) 149 | boxes = decoded_boxes[l_mask].view(-1, 4) 150 | 151 | ids, count = nms(boxes.detach(), scores.detach(), 0.5, 200) 152 | output[i, cl, :count] = torch.cat( 153 | (scores[ids[:count]].unsqueeze(1), boxes[ids[:count]], Entropy[ids[:count]].unsqueeze(1)), 1) 154 | 155 | if criterion_select == 'random': 156 | val = np.random.normal(0, 1, 1) 157 | criterion_UC[j] = val[0] 158 | 159 | 160 | elif criterion_select == 'Max_aver': 161 | count_num = 0 162 | # UC_sum = 0 163 | UC_max = 0 164 | for p in range(output.size(1)): # [1, 21, 200, 9] 165 | q = 0 166 | while output[ 167 | 0, p, q, 0] >= thresh: # filtering using threshold, To do: Increasing acoording to iteration? 168 | count_num += 1 169 | score = output[0, p, q, 0] 170 | entropy = output[0, p, q, 5:6] 171 | UC_max += entropy.item() 172 | q += 1 173 | 174 | if count_num == 0: 175 | criterion_UC[j] = 0 176 | else: 177 | criterion_UC[j] = UC_max / count_num 178 | 179 | else: 180 | count_num = 0 181 | # UC_sum = 0 182 | UC_max = 0 183 | for p in range(output.size(1)): # [1, 21, 200, 9] 184 | q = 0 185 | while output[ 186 | 0, p, q, 0] >= thresh: # filtering using threshold, To do: Increasing acoording to iteration? 
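# Same filtering as in the Max_aver branch above, but here the per-box uncertainty is
# the normalized entropy H = -sum_c p_c * log(p_c) / log(C), so H lies in [0, 1], and the
# image score is the maximum entropy among detections scoring above args.thresh.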
187 | count_num += 1 188 | score = output[0, p, q, 0] 189 | entropy = output[0, p, q, 5:6] 190 | UC_max_temp = entropy.item() 191 | if (UC_max < UC_max_temp): 192 | UC_max = UC_max_temp 193 | q += 1 194 | 195 | if count_num == 0: # UC_sum == 0: 196 | # if the net cannot detect anything, give it a very high inconsistency score, that image is hard 197 | criterion_UC[j] = 0 198 | else: 199 | criterion_UC[j] = UC_max 200 | 201 | if args.criterion_select == 'combined': 202 | return criterion_UC 203 | sorted_indices = np.argsort(criterion_UC)[::-1] 204 | 205 | labeled_set += list(np.array(unlabeled_set)[sorted_indices[:args.acquisition_budget]]) 206 | unlabeled_set = list(np.array(unlabeled_set)[sorted_indices[args.acquisition_budget:]]) 207 | 208 | # assert that sizes of lists are correct and that there are no elements that are in both lists 209 | assert len(list(set(labeled_set) | set(unlabeled_set))) == args.num_total_images 210 | assert len(list(set(labeled_set) & set(unlabeled_set))) == 0 211 | 212 | # save the labeled set 213 | return labeled_set, unlabeled_set 214 | 215 | 216 | def combined_score(args, batch_iterator, labeled_set, unlabeled_set, net, unsupervised_data_loader): 217 | entropy_score = active_learning_entropy(args, batch_iterator, labeled_set, unlabeled_set, net, 218 | args.cfg['num_classes'], 'entropy', 219 | loader=unsupervised_data_loader) 220 | consistency_score = active_learning_inconsistency(args, batch_iterator, labeled_set, unlabeled_set, net, 221 | args.cfg['num_classes'], 'consistency_class', 222 | loader=unsupervised_data_loader) 223 | 224 | len_arr = len(entropy_score) 225 | ind = np.argpartition(entropy_score, -args.filter_entropy_num)[:len_arr - args.filter_entropy_num] 226 | 227 | consistency_score[ind] = 0. 228 | 229 | sorted_indices = np.argsort(consistency_score)[::-1] 230 | 231 | labeled_set += list(np.array(unlabeled_set)[sorted_indices[:args.acquisition_budget]]) 232 | unlabeled_set = list(np.array(unlabeled_set)[sorted_indices[args.acquisition_budget:]]) 233 | 234 | # assert that sizes of lists are correct and that there are no elements that are in both lists 235 | assert len(list(set(labeled_set) | set(unlabeled_set))) == args.num_total_images 236 | assert len(list(set(labeled_set) & set(unlabeled_set))) == 0 237 | 238 | return labeled_set, unlabeled_set 239 | -------------------------------------------------------------------------------- /csd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from layers import * 6 | from data import voc300, voc512, coco 7 | import os 8 | import warnings 9 | import math 10 | import numpy as np 11 | import cv2 12 | 13 | 14 | class SSD_CON(nn.Module): 15 | """Single Shot Multibox Architecture 16 | The network is composed of a base VGG network followed by the 17 | added multibox conv layers. Each multibox layer branches into 18 | 1) conv2d for class conf scores 19 | 2) conv2d for localization predictions 20 | 3) associated priorbox layer to produce default bounding 21 | boxes specific to the layer's feature map size. 22 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 
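At train time this CSD variant also feeds a horizontally flipped copy of each
input through the same weights and returns the flipped confidence and
localization predictions alongside the originals, so that a consistency loss
can be computed between the two views.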
23 | 24 | Args: 25 | phase: (string) Can be "test" or "train" 26 | size: input image size 27 | base: VGG16 layers for input, size of either 300 or 500 28 | extras: extra layers that feed to multibox loc and conf layers 29 | head: "multibox head" consists of loc and conf conv layers 30 | """ 31 | 32 | def __init__(self, phase, size, base, extras, head, num_classes): 33 | super(SSD_CON, self).__init__() 34 | self.phase = phase 35 | self.num_classes = num_classes 36 | if(size==300): 37 | self.cfg = (coco, voc300)[num_classes == 21] 38 | else: 39 | self.cfg = (coco, voc512)[num_classes == 21] 40 | self.priorbox = PriorBox(self.cfg) 41 | self.priors = Variable(self.priorbox.forward(), volatile=True) 42 | self.size = size 43 | 44 | # SSD network 45 | self.vgg = nn.ModuleList(base) 46 | # Layer learns to scale the l2 normalized features from conv4_3 47 | self.L2Norm = L2Norm(512, 20) 48 | self.extras = nn.ModuleList(extras) 49 | 50 | self.loc = nn.ModuleList(head[0]) 51 | self.conf = nn.ModuleList(head[1]) 52 | 53 | self.softmax = nn.Softmax(dim=-1) 54 | 55 | if phase == 'test': 56 | # self.softmax = nn.Softmax(dim=-1) 57 | self.detect = Detect(num_classes, 0, 200, 0.01, 0.45) 58 | 59 | def forward(self, x): 60 | """Applies network layers and ops on input image(s) x. 61 | 62 | Args: 63 | x: input image or batch of images. Shape: [batch,3,300,300]. 64 | 65 | Return: 66 | Depending on phase: 67 | test: 68 | Variable(tensor) of output class label predictions, 69 | confidence score, and corresponding location predictions for 70 | each object detected. Shape: [batch,topk,7] 71 | 72 | train: 73 | list of concat outputs from: 74 | 1: confidence layers, Shape: [batch*num_priors,num_classes] 75 | 2: localization layers, Shape: [batch,num_priors*4] 76 | 3: priorbox layers, Shape: [2,num_priors*4] 77 | """ 78 | 79 | 80 | x_flip = x.clone() 81 | x_flip = flip(x_flip,3) 82 | 83 | sources = list() 84 | loc = list() 85 | conf = list() 86 | 87 | # apply vgg up to conv4_3 relu 88 | for k in range(23): 89 | x = self.vgg[k](x) 90 | 91 | s = self.L2Norm(x) 92 | sources.append(s) 93 | 94 | # apply vgg up to fc7 95 | for k in range(23, len(self.vgg)): 96 | x = self.vgg[k](x) 97 | 98 | # just a name so we can later point to return this 99 | sources.append(x) 100 | feats = x 101 | 102 | # apply extra layers and cache source layer outputs 103 | for k, v in enumerate(self.extras): 104 | x = F.relu(v(x), inplace=True) 105 | if k % 2 == 1: 106 | sources.append(x) 107 | 108 | # apply multibox head to source layers 109 | for (x, l, c) in zip(sources, self.loc, self.conf): 110 | loc.append(l(x).permute(0, 2, 3, 1).contiguous()) 111 | conf.append(c(x).permute(0, 2, 3, 1).contiguous()) 112 | 113 | 114 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) 115 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) 116 | # zero_mask = torch.cat([o.view(o.size(0), -1) for o in zero_mask], 1) 117 | 118 | if self.phase == "test": 119 | output = self.detect( 120 | loc.view(loc.size(0), -1, 4), # loc preds 121 | self.softmax(conf.view(conf.size(0), -1, 122 | self.num_classes)), # conf preds 123 | self.priors.type(type(x.data)) # default boxes 124 | ) 125 | else: 126 | output = ( 127 | loc.view(loc.size(0), -1, 4), 128 | conf.view(conf.size(0), -1, self.num_classes), 129 | self.priors 130 | ) 131 | 132 | loc = loc.view(loc.size(0), -1, 4) 133 | conf = self.softmax(conf.view(conf.size(0), -1, self.num_classes)) 134 | # basic 135 | 136 | sources_flip = list() 137 | loc_flip = list() 138 | conf_flip = list() 139 | 140 | # apply 
vgg up to conv4_3 relu 141 | for k in range(23): 142 | x_flip = self.vgg[k](x_flip) 143 | 144 | s_flip = self.L2Norm(x_flip) 145 | sources_flip.append(s_flip) 146 | 147 | # apply vgg up to fc7 148 | for k in range(23, len(self.vgg)): 149 | x_flip = self.vgg[k](x_flip) 150 | sources_flip.append(x_flip) 151 | 152 | # apply extra layers and cache source layer outputs 153 | for k, v in enumerate(self.extras): 154 | x_flip = F.relu(v(x_flip), inplace=True) 155 | if k % 2 == 1: 156 | sources_flip.append(x_flip) 157 | 158 | # apply multibox head to source layers 159 | for (x_flip, l, c) in zip(sources_flip, self.loc, self.conf): 160 | append_loc = l(x_flip).permute(0, 2, 3, 1).contiguous() 161 | append_conf = c(x_flip).permute(0, 2, 3, 1).contiguous() 162 | append_loc = flip(append_loc, 2) 163 | append_conf = flip(append_conf, 2) 164 | loc_flip.append(append_loc) 165 | conf_flip.append(append_conf) 166 | 167 | loc_flip = torch.cat([o.view(o.size(0), -1) for o in loc_flip], 1) 168 | conf_flip = torch.cat([o.view(o.size(0), -1) for o in conf_flip], 1) 169 | 170 | loc_flip = loc_flip.view(loc_flip.size(0), -1, 4) 171 | 172 | conf_flip = self.softmax(conf_flip.view(conf_flip.size(0), -1, self.num_classes)) 173 | 174 | 175 | if self.phase == "test": 176 | return output 177 | else: 178 | return output, conf, conf_flip, loc, loc_flip, feats 179 | 180 | def load_weights(self, base_file): 181 | other, ext = os.path.splitext(base_file) 182 | if ext == '.pkl' or ext == '.pth': 183 | print('Loading weights into state dict...') 184 | self.load_state_dict(torch.load(base_file, 185 | map_location=lambda storage, loc: storage)) 186 | print('Finished!') 187 | else: 188 | print('Sorry only .pth and .pkl files supported.') 189 | 190 | 191 | # This function is derived from torchvision VGG make_layers() 192 | # https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 193 | def vgg(cfg, i, batch_norm=False): 194 | layers = [] 195 | in_channels = i 196 | for v in cfg: 197 | if v == 'M': 198 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 199 | elif v == 'C': 200 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 201 | else: 202 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 203 | if batch_norm: 204 | print('batch_norm') 205 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 206 | else: 207 | layers += [conv2d, nn.ReLU(inplace=True)] 208 | in_channels = v 209 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 210 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6) 211 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1) 212 | layers += [pool5, conv6, 213 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)] 214 | return layers 215 | 216 | 217 | def add_extras(cfg, i, batch_norm=False): 218 | # Extra layers added to VGG for feature scaling 219 | layers = [] 220 | in_channels = i 221 | flag = False 222 | for k, v in enumerate(cfg): 223 | if in_channels != 'S': 224 | if v == 'S': 225 | layers += [nn.Conv2d(in_channels, cfg[k + 1], 226 | kernel_size=(1, 3)[flag], stride=2, padding=1)] 227 | elif v == 'K': 228 | layers += [nn.Conv2d(in_channels, 256, 229 | kernel_size=4, stride=1, padding=1)] 230 | else: 231 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])] 232 | flag = not flag 233 | in_channels = v 234 | return layers 235 | 236 | 237 | 238 | def multibox(vgg, extra_layers, cfg, num_classes): 239 | loc_layers = [] 240 | conf_layers = [] 241 | vgg_source = [21, -2] 242 | for k, v in enumerate(vgg_source): 243 | loc_layers += 
[nn.Conv2d(vgg[v].out_channels, 244 | cfg[k] * 4, kernel_size=3, padding=1)] 245 | conf_layers += [nn.Conv2d(vgg[v].out_channels, 246 | cfg[k] * num_classes, kernel_size=3, padding=1)] 247 | for k, v in enumerate(extra_layers[1::2], 2): 248 | loc_layers += [nn.Conv2d(v.out_channels, cfg[k] 249 | * 4, kernel_size=3, padding=1)] 250 | conf_layers += [nn.Conv2d(v.out_channels, cfg[k] 251 | * num_classes, kernel_size=3, padding=1)] 252 | return vgg, extra_layers, (loc_layers, conf_layers) 253 | 254 | 255 | base = { 256 | '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 257 | 512, 512, 512], 258 | '512': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 259 | 512, 512, 512], 260 | } 261 | extras = { 262 | '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], 263 | '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128, 'K'], 264 | } 265 | mbox = { 266 | '300': [4, 6, 6, 6, 4, 4], # number of boxes per feature map location 267 | '512': [4, 6, 6, 6, 6, 4, 4], 268 | } 269 | 270 | def flip(x, dim): 271 | dim = x.dim() + dim if dim < 0 else dim 272 | return x[tuple(slice(None, None) if i != dim 273 | else torch.arange(x.size(i)-1, -1, -1).long() 274 | for i in range(x.dim()))] 275 | 276 | class GaussianNoise(nn.Module): 277 | def __init__(self, batch_size, input_size=(3, 300, 300), mean=0, std=0.15): 278 | super(GaussianNoise, self).__init__() 279 | self.shape = (batch_size, ) + input_size 280 | self.noise = Variable(torch.zeros(self.shape).cuda()) 281 | self.mean = mean 282 | self.std = std 283 | 284 | def forward(self, x): 285 | self.noise.data.normal_(self.mean, std=self.std) 286 | if x.size(0) == self.noise.size(0): 287 | return x + self.noise 288 | else: 289 | #print('---- Noise Size ') 290 | return x + self.noise[:x.size(0)] 291 | 292 | 293 | def build_ssd_con(phase, size=300, num_classes=21): 294 | if phase != "test" and phase != "train": 295 | print("ERROR: Phase: " + phase + " not recognized") 296 | return 297 | if size != 300: 298 | print("ERROR: You specified size " + repr(size) + ". However, " + 299 | "currently only SSD300 (size=300) is supported!") 300 | return 301 | base_, extras_, head_ = multibox(vgg(base[str(size)], 3), 302 | add_extras(extras[str(size)], 1024), 303 | mbox[str(size)], num_classes) 304 | return SSD_CON(phase, size, base_, extras_, head_, num_classes) 305 | 306 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://github.com/NVlabs/AL-SSL/ 6 | 7 | 8 | 9 | from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT 10 | from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT, get_label_map 11 | 12 | from .config import * 13 | import torch 14 | import cv2 15 | import numpy as np 16 | 17 | def detection_collate(batch): 18 | """Custom collate fn for dealing with batches of images that have a different 19 | number of associated object annotations (bounding boxes). 
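Each sample is either an (image, target) pair or an (image, target, semi)
triple; when the semi flag is present it is collected as well
(1 = labeled, 2 = pseudo-labeled, 0 = unlabeled).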
20 | 21 | Arguments: 22 | batch: (tuple) A tuple of tensor images and lists of annotations 23 | 24 | Return: 25 | A tuple containing: 26 | 1) (tensor) batch of images stacked on their 0 dim 27 | 2) (list of tensors) annotations for a given image are stacked on 28 | 0 dim 29 | """ 30 | ### changed when semi-supervised 31 | targets = [] 32 | imgs = [] 33 | semis = [] 34 | for sample in batch: 35 | imgs.append(sample[0]) 36 | targets.append(torch.FloatTensor(sample[1])) 37 | if(len(sample)==3): 38 | semis.append(torch.FloatTensor(sample[2])) 39 | if(len(sample)==2): 40 | return torch.stack(imgs, 0), targets 41 | else: 42 | return torch.stack(imgs, 0), targets, semis 43 | # return torch.stack(imgs, 0), targets 44 | 45 | 46 | def base_transform(image, size, mean): 47 | x = cv2.resize(image, (size, size)).astype(np.float32) 48 | x -= mean 49 | x = x.astype(np.float32) 50 | return x 51 | 52 | 53 | class BaseTransform: 54 | def __init__(self, size, mean): 55 | self.size = size 56 | self.mean = np.array(mean, dtype=np.float32) 57 | 58 | def __call__(self, image, boxes=None, labels=None): 59 | return base_transform(image, self.size, self.mean), boxes, labels 60 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/coco.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/coco.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/coco.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/coco.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/coco.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/coco.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/coco_new.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/coco_new.cpython-36.pyc 
-------------------------------------------------------------------------------- /data/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc0712.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc0712.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc0712.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc0712.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc0712.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc0712.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc0712_backup.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc0712_backup.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc0712_consistency.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc0712_consistency.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc07_consistency.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc07_consistency.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc07_consistency.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc07_consistency.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc07_consistency_init.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc07_consistency_init.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc07_consistency_init.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc07_consistency_init.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/voc07_voc12coco.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc07_voc12coco.cpython-36.pyc -------------------------------------------------------------------------------- /data/coco.py: -------------------------------------------------------------------------------- 1 | from .config import HOME 2 | import cv2 3 | import numpy as np 4 | import os 5 | import os.path as osp 6 | import sys 7 | import torch 8 | import torch.utils.data as data 9 | import torchvision.transforms as transforms 10 | 11 | 12 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved. 13 | # 14 | # This work is made available under the Nvidia Source Code License-NC. 15 | # To view a copy of this license, visit 16 | # https://github.com/NVlabs/AL-SSL/ 17 | 18 | 19 | COCO_ROOT = '/usr/wiss/elezi/data/coco' 20 | IMAGES = 'images' 21 | ANNOTATIONS = 'annotations' 22 | COCO_API = 'PythonAPI' 23 | INSTANCES_SET = 'instances_{}.json' 24 | COCO_CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 25 | 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 26 | 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 27 | 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 28 | 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 29 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 30 | 'kite', 'baseball bat', 'baseball glove', 'skateboard', 31 | 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 32 | 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 33 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 34 | 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 35 | 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 36 | 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 37 | 'refrigerator', 'book', 'clock', 'vase', 'scissors', 38 | 'teddy bear', 'hair drier', 'toothbrush') 39 | 40 | VOC_CLASSES = ( # always index 0 41 | 'aeroplane', 'bicycle', 'bird', 'boat', 42 | 'bottle', 'bus', 'car', 'cat', 'chair', 43 | 'cow', 'diningtable', 'dog', 'horse', 44 | 'motorbike', 'person', 'pottedplant', 45 | 'sheep', 'sofa', 'train', 'tvmonitor') 46 | 47 | 48 | def get_label_map(label_file): 49 | label_map = {} 50 | labels = open(label_file, 'r') 51 | for line in labels: 52 | ids = line.split(',') 53 | label_map[int(ids[0])] = int(ids[1]) 54 | return label_map 55 | 56 | 57 | class COCOAnnotationTransform(object): 58 | """Transforms a COCO annotation into a Tensor of bbox coords and label index 59 | Initialized with a dictionary lookup of classnames to indexes 60 | """ 61 | def __init__(self): 62 | self.label_map = get_label_map(osp.join(COCO_ROOT, 'coco_labels.txt')) 63 | 64 | def
__call__(self, target, width, height): 65 | """ 66 | Args: 67 | target (dict): COCO target json annotation as a python dict 68 | height (int): height 69 | width (int): width 70 | Returns: 71 | a list containing lists of bounding boxes [bbox coords, class idx] 72 | """ 73 | scale = np.array([width, height, width, height]) 74 | res = [] 75 | for obj in target: 76 | if 'bbox' in obj: 77 | bbox = obj['bbox'] 78 | bbox[2] += bbox[0] 79 | bbox[3] += bbox[1] 80 | label_idx = self.label_map[obj['category_id']] - 1 81 | final_box = list(np.array(bbox)/scale) 82 | final_box.append(label_idx) 83 | res += [final_box] # [xmin, ymin, xmax, ymax, label_idx] 84 | else: 85 | print("no bbox problem!") 86 | 87 | return res # [[xmin, ymin, xmax, ymax, label_idx], ... ] 88 | 89 | 90 | class COCODetection(data.Dataset): 91 | """`MS Coco Detection `_ Dataset. 92 | Args: 93 | root (string): Root directory where images are downloaded to. 94 | set_name (string): Name of the specific set of COCO images. 95 | transform (callable, optional): A function/transform that augments the 96 | raw images` 97 | target_transform (callable, optional): A function/transform that takes 98 | in the target (bbox) and transforms it. 99 | """ 100 | 101 | def __init__(self, root, supervised_indices=None, image_set='train2014', transform=None, 102 | target_transform=COCOAnnotationTransform(), dataset_name='MS COCO', pseudo_labels={}): 103 | sys.path.append(osp.join(root, COCO_API)) 104 | self.supervised_indices = supervised_indices 105 | from pycocotools.coco import COCO 106 | self.root = osp.join(root, IMAGES, image_set) 107 | self.coco = COCO(osp.join(root, ANNOTATIONS, 108 | INSTANCES_SET.format(image_set))) 109 | self.ids = list(self.coco.imgToAnns.keys()) 110 | self.supervised_indices = supervised_indices 111 | self.pseudo_labels = pseudo_labels 112 | self.pseudo_labels_indices = self.get_pseudo_label_indices() 113 | self.transform = transform 114 | self.target_transform = target_transform 115 | self.name = dataset_name 116 | self.class_to_ind = dict(zip(COCO_CLASSES, range(len(COCO_CLASSES)))) 117 | self.class_to_ind_voc = dict(zip(VOC_CLASSES, range(len(VOC_CLASSES)))) 118 | self.dict_coco_to_voc = {0: 0, 1: 1, 2: 6, 3: 13, 4: 0, 5: 5, 6: 18, 8: 3, 15: 2, 16: 7, 17: 11, 18: 12, 19: 16, 119 | 20: 9, 40: 4, 57: 8, 58: 17, 59: 15, 61: 10, 63: 19} 120 | 121 | def contain_voc(self): 122 | return self.coco_contain_voc 123 | 124 | def __getitem__(self, index): 125 | """ 126 | Args: 127 | index (int): Index 128 | Returns: 129 | tuple: Tuple (image, target). 130 | target is the object returned by ``coco.loadAnns``. 131 | """ 132 | im, gt, h, w, semi = self.pull_item(index) 133 | return im, gt, semi 134 | 135 | def __len__(self): 136 | return len(self.ids) 137 | 138 | def pull_item(self, index): 139 | """ 140 | Args: 141 | index (int): Index 142 | Returns: 143 | tuple: Tuple (image, target, height, width). 144 | target is the object returned by ``coco.loadAnns``. 
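The returned semi flag encodes the sample type: 1 if the index is in
supervised_indices, 2 if the image carries a pseudo label, and 0 for
unlabeled images, whose target is replaced by a single zero row.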
145 | """ 146 | img_id = self.ids[index] 147 | target = self.coco.imgToAnns[img_id] 148 | ann_ids = self.coco.getAnnIds(imgIds=img_id) 149 | 150 | target = self.coco.loadAnns(ann_ids) 151 | path = osp.join(self.root, self.coco.loadImgs(img_id)[0]['file_name']) 152 | split_path = path.split("/") 153 | split_path = os.path.join(split_path[-2], split_path[-1]) 154 | # if split_path in self.coco_for_voc_valid: 155 | # print("Inside") 156 | # self.coco_contain_voc.append(index) 157 | assert osp.exists(path), 'Image path does not exist: {}'.format(path) 158 | # img = cv2.imread(osp.join(self.root, path)) 159 | img = cv2.imread(path) 160 | 161 | height, width, _ = img.shape 162 | 163 | if self.target_transform is not None: 164 | if index not in self.pseudo_labels_indices: 165 | target = self.target_transform(target, width, height) 166 | else: 167 | target = self.target_transform_pseudo_label(self.pseudo_labels, index) 168 | 169 | if self.transform is not None: 170 | target = np.array(target) 171 | img, boxes, labels = self.transform(img, target[:, :4], 172 | target[:, 4]) 173 | # to rgb 174 | img = img[:, :, (2, 1, 0)] 175 | 176 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 177 | 178 | if self.supervised_indices != None: 179 | if index in self.supervised_indices: 180 | semi = np.array([1]) 181 | elif index in self.pseudo_labels_indices: 182 | semi = np.array([2]) 183 | else: 184 | semi = np.array([0]) 185 | target = np.zeros([1,5]) 186 | else: 187 | # it does not matter 188 | semi = np.array([0]) 189 | 190 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width, semi 191 | 192 | def pull_image(self, index): 193 | '''Returns the original image object at index in PIL form 194 | 195 | Note: not using self.__getitem__(), as any transformations passed in 196 | could mess up this functionality. 197 | 198 | Argument: 199 | index (int): index of img to show 200 | Return: 201 | cv2 img 202 | ''' 203 | img_id = self.ids[index] 204 | path = self.coco.loadImgs(img_id)[0]['file_name'] 205 | return cv2.imread(osp.join(self.root, path), cv2.IMREAD_COLOR) 206 | 207 | def pull_anno(self, index): 208 | '''Returns the original annotation of image at index 209 | 210 | Note: not using self.__getitem__(), as any transformations passed in 211 | could mess up this functionality. 212 | 213 | Argument: 214 | index (int): index of img to get annotation of 215 | Return: 216 | list: [img_id, [(label, bbox coords),...]] 217 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 218 | ''' 219 | img_id = self.ids[index] 220 | ann_ids = self.coco.getAnnIds(imgIds=img_id) 221 | return self.coco.loadAnns(ann_ids) 222 | 223 | def pull_pseudo_anno(self, index): 224 | '''Returns the original annotation of image at index 225 | 226 | Note: not using self.__getitem__(), as any transformations passed in 227 | could mess up this functionality. 
228 | 229 | Argument: 230 | index (int): index of img to get annotation of 231 | Return: 232 | list: [img_id, [(label, bbox coords),...]] 233 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 234 | ''' 235 | return self.pseudo_labels[index] 236 | 237 | def get_pseudo_label_indices(self): 238 | pseudo_label_indices = [] 239 | for key in self.pseudo_labels: 240 | pseudo_label_indices.append(key) 241 | return pseudo_label_indices 242 | 243 | def target_transform_pseudo_label(self, pseudo_labels, index): 244 | # height, width = pseudo_labels[index][0][3], pseudo_labels[index][0][4] 245 | height, width = pseudo_labels[index][0][4], pseudo_labels[index][0][3] 246 | # height, width = 300, 300 247 | res = [] 248 | for i in range(len(pseudo_labels[index])): 249 | pts = [pseudo_labels[index][i][5], pseudo_labels[index][i][6], pseudo_labels[index][i][7], pseudo_labels[index][i][8]] 250 | name = pseudo_labels[index][i][2] 251 | pts[0] /= height 252 | pts[2] /= height 253 | pts[1] /= width 254 | pts[3] /= width 255 | for i in range(len(pts)): 256 | if pts[i] < 0: 257 | pts[i] = 0 258 | if pts[0] > height: 259 | pts[0] = height 260 | if pts[2] > height: 261 | pts[2] = height 262 | label_idx = float(self.class_to_ind[name]) 263 | # label_idx = float(self.class_to_ind_voc[name]) 264 | # transform to voc standard 265 | # label_idx = self.class_to_ind[name] 266 | # label_idx = float(self.dict_coco_to_voc[label_idx]) 267 | pts.append(label_idx) 268 | res.append(pts) 269 | return res 270 | 271 | def __repr__(self): 272 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 273 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 274 | fmt_str += ' Root Location: {}\n'.format(self.root) 275 | tmp = ' Transforms (if any): ' 276 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 277 | tmp = ' Target Transforms (if any): ' 278 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 279 | return fmt_str 280 | -------------------------------------------------------------------------------- /data/coco/coco_labels.txt: -------------------------------------------------------------------------------- 1 | 1,1,person 2 | 2,2,bicycle 3 | 3,3,car 4 | 4,4,motorcycle 5 | 5,5,airplane 6 | 6,6,bus 7 | 7,7,train 8 | 8,8,truck 9 | 9,9,boat 10 | 10,10,traffic light 11 | 11,11,fire hydrant 12 | 13,12,stop sign 13 | 14,13,parking meter 14 | 15,14,bench 15 | 16,15,bird 16 | 17,16,cat 17 | 18,17,dog 18 | 19,18,horse 19 | 20,19,sheep 20 | 21,20,cow 21 | 22,21,elephant 22 | 23,22,bear 23 | 24,23,zebra 24 | 25,24,giraffe 25 | 27,25,backpack 26 | 28,26,umbrella 27 | 31,27,handbag 28 | 32,28,tie 29 | 33,29,suitcase 30 | 34,30,frisbee 31 | 35,31,skis 32 | 36,32,snowboard 33 | 37,33,sports ball 34 | 38,34,kite 35 | 39,35,baseball bat 36 | 40,36,baseball glove 37 | 41,37,skateboard 38 | 42,38,surfboard 39 | 43,39,tennis racket 40 | 44,40,bottle 41 | 46,41,wine glass 42 | 47,42,cup 43 | 48,43,fork 44 | 49,44,knife 45 | 50,45,spoon 46 | 51,46,bowl 47 | 52,47,banana 48 | 53,48,apple 49 | 54,49,sandwich 50 | 55,50,orange 51 | 56,51,broccoli 52 | 57,52,carrot 53 | 58,53,hot dog 54 | 59,54,pizza 55 | 60,55,donut 56 | 61,56,cake 57 | 62,57,chair 58 | 63,58,couch 59 | 64,59,potted plant 60 | 65,60,bed 61 | 67,61,dining table 62 | 70,62,toilet 63 | 72,63,tv 64 | 73,64,laptop 65 | 74,65,mouse 66 | 75,66,remote 67 | 76,67,keyboard 68 | 77,68,cell phone 69 | 78,69,microwave 70 | 79,70,oven 71 | 80,71,toaster 72 | 81,72,sink 73 | 
82,73,refrigerator 74 | 84,74,book 75 | 85,75,clock 76 | 86,76,vase 77 | 87,77,scissors 78 | 88,78,teddy bear 79 | 89,79,hair drier 80 | 90,80,toothbrush 81 | -------------------------------------------------------------------------------- /data/config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | import os.path 3 | 4 | # gets home dir cross platform 5 | HOME = os.path.expanduser("~") 6 | 7 | # for making bounding boxes pretty 8 | COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128), 9 | (0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128)) 10 | 11 | MEANS = (104, 117, 123) 12 | 13 | 14 | voc300 = { 15 | 'num_classes': 21, 16 | 'lr_steps': (80000, 100000, 120000), 17 | 'max_iter': 120000, 18 | 'feature_maps': [38, 19, 10, 5, 3, 1], 19 | 'min_dim': 300, 20 | 'steps': [8, 16, 32, 64, 100, 300], 21 | 'min_sizes': [30, 60, 111, 162, 213, 264], 22 | 'max_sizes': [60, 111, 162, 213, 264, 315], 23 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]], 24 | 'variance': [0.1, 0.2], 25 | 'clip': True, 26 | 'name': 'VOC', 27 | } 28 | voc512 = { 29 | 'num_classes': 21, 30 | 'lr_steps': (80000, 100000, 120000), 31 | 'max_iter': 120000, 32 | 'feature_maps': [64, 32, 16, 8, 4, 2, 1], 33 | 'min_dim': 512, 34 | 'steps': [8, 16, 32, 64, 128, 256, 512], 35 | 'min_sizes': [35.84, 76.8, 153.6, 230.4, 307.2, 384.0, 460.8], 36 | 'max_sizes': [76.8, 153.6, 230.4, 307.2, 384.0, 460.8, 537.6], 37 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2], [2]], 38 | 'variance': [0.1, 0.2], 39 | 'clip': True, 40 | 'name': 'VOC', 41 | } 42 | 43 | 44 | coco = { 45 | 'num_classes': 81, 46 | 'lr_steps': (80000, 100000, 120000), 47 | 'max_iter': 120000, 48 | 'feature_maps': [38, 19, 10, 5, 3, 1], 49 | 'min_dim': 300, 50 | 'steps': [8, 16, 32, 64, 100, 300], 51 | 'min_sizes': [21, 45, 99, 153, 207, 261], 52 | 'max_sizes': [45, 99, 153, 207, 261, 315], 53 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]], 54 | 'variance': [0.1, 0.2], 55 | 'clip': True, 56 | 'name': 'COCO', 57 | } 58 | 59 | -------------------------------------------------------------------------------- /data/example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/example.jpg -------------------------------------------------------------------------------- /data/scripts/COCO2014.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | start=`date +%s` 4 | 5 | # handle optional download dir 6 | if [ -z "$1" ] 7 | then 8 | # navigate to ~/data 9 | echo "navigating to ~/data/ ..." 10 | mkdir -p ~/data 11 | cd ~/data/ 12 | mkdir -p ./coco 13 | cd ./coco 14 | mkdir -p ./images 15 | mkdir -p ./annotations 16 | else 17 | # check if specified dir is valid 18 | if [ ! -d $1 ]; then 19 | echo $1 " is not a valid directory" 20 | exit 0 21 | fi 22 | echo "navigating to " $1 " ..." 23 | cd $1 24 | fi 25 | 26 | if [ ! -d images ] 27 | then 28 | mkdir -p ./images 29 | fi 30 | 31 | # Download the image data. 32 | cd ./images 33 | echo "Downloading MSCOCO train images ..." 34 | curl -LO http://images.cocodataset.org/zips/train2014.zip 35 | echo "Downloading MSCOCO val images ..." 36 | curl -LO http://images.cocodataset.org/zips/val2014.zip 37 | 38 | cd ../ 39 | if [ ! -d annotations ] 40 | then 41 | mkdir -p ./annotations 42 | fi 43 | 44 | # Download the annotation data.
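# The instances_*.json files fetched below are what COCODetection loads through pycocotools (see data/coco.py).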
45 | cd ./annotations 46 | echo "Downloading MSCOCO train/val annotations ..." 47 | curl -LO http://images.cocodataset.org/annotations/annotations_trainval2014.zip 48 | echo "Finished downloading. Now extracting ..." 49 | 50 | # Unzip data 51 | echo "Extracting train images ..." 52 | unzip ../images/train2014.zip -d ../images 53 | echo "Extracting val images ..." 54 | unzip ../images/val2014.zip -d ../images 55 | echo "Extracting annotations ..." 56 | unzip ./annotations_trainval2014.zip 57 | 58 | echo "Removing zip files ..." 59 | rm ../images/train2014.zip 60 | rm ../images/val2014.zip 61 | rm ./annotations_trainval2014.zip 62 | 63 | echo "Creating trainval35k dataset..." 64 | 65 | # Download annotations json 66 | echo "Downloading trainval35k annotations from S3" 67 | curl -LO https://s3.amazonaws.com/amdegroot-datasets/instances_trainval35k.json.zip 68 | 69 | # combine train and val 70 | echo "Combining train and val images" 71 | mkdir ../images/trainval35k 72 | cd ../images/train2014 73 | find -maxdepth 1 -name '*.jpg' -exec cp -t ../trainval35k {} + # dir too large for cp 74 | cd ../val2014 75 | find -maxdepth 1 -name '*.jpg' -exec cp -t ../trainval35k {} + 76 | 77 | 78 | end=`date +%s` 79 | runtime=$((end-start)) 80 | 81 | echo "Completed in " $runtime " seconds" 82 | -------------------------------------------------------------------------------- /data/scripts/VOC2007.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ellis Brown 3 | 4 | start=`date +%s` 5 | 6 | # handle optional download dir 7 | if [ -z "$1" ] 8 | then 9 | # navigate to ~/data 10 | echo "navigating to ~/data/ ..." 11 | mkdir -p ~/data 12 | cd ~/data/ 13 | else 14 | # check if is valid directory 15 | if [ ! -d $1 ]; then 16 | echo $1 "is not a valid directory" 17 | exit 0 18 | fi 19 | echo "navigating to" $1 "..." 20 | cd $1 21 | fi 22 | 23 | echo "Downloading VOC2007 trainval ..." 24 | # Download the data. 25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar 26 | echo "Downloading VOC2007 test data ..." 27 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar 28 | echo "Done downloading." 29 | 30 | # Extract data 31 | echo "Extracting trainval ..." 32 | tar -xvf VOCtrainval_06-Nov-2007.tar 33 | echo "Extracting test ..." 34 | tar -xvf VOCtest_06-Nov-2007.tar 35 | echo "removing tars ..." 36 | rm VOCtrainval_06-Nov-2007.tar 37 | rm VOCtest_06-Nov-2007.tar 38 | 39 | end=`date +%s` 40 | runtime=$((end-start)) 41 | 42 | echo "Completed in" $runtime "seconds" -------------------------------------------------------------------------------- /data/scripts/VOC2012.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Ellis Brown 3 | 4 | start=`date +%s` 5 | 6 | # handle optional download dir 7 | if [ -z "$1" ] 8 | then 9 | # navigate to ~/data 10 | echo "navigating to ~/data/ ..." 11 | mkdir -p ~/data 12 | cd ~/data/ 13 | else 14 | # check if is valid directory 15 | if [ ! -d $1 ]; then 16 | echo $1 "is not a valid directory" 17 | exit 0 18 | fi 19 | echo "navigating to" $1 "..." 20 | cd $1 21 | fi 22 | 23 | echo "Downloading VOC2012 trainval ..." 24 | # Download the data. 25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 26 | echo "Done downloading." 27 | 28 | 29 | # Extract data 30 | echo "Extracting trainval ..." 31 | tar -xvf VOCtrainval_11-May-2012.tar 32 | echo "removing tar ..." 
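# A slightly safer variant (sketch only): tie removal to a successful
# extraction, so that a truncated download is not deleted along the way:
#   tar -xvf VOCtrainval_11-May-2012.tar && rm VOCtrainval_11-May-2012.tar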
33 | rm VOCtrainval_11-May-2012.tar 34 | 35 | end=`date +%s` 36 | runtime=$((end-start)) 37 | 38 | echo "Completed in" $runtime "seconds" -------------------------------------------------------------------------------- /data/voc0712.py: -------------------------------------------------------------------------------- 1 | """VOC Dataset Classes 2 | 3 | Original author: Francisco Massa 4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py 5 | 6 | Updated by: Ellis Brown, Max deGroot 7 | """ 8 | 9 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved. 10 | # 11 | # This work is made available under the Nvidia Source Code License-NC. 12 | # To view a copy of this license, visit 13 | # https://github.com/NVlabs/AL-SSL/ 14 | 15 | 16 | from .config import HOME 17 | import os.path as osp 18 | import sys 19 | import torch 20 | import torch.utils.data as data 21 | import cv2 22 | import numpy as np 23 | import matplotlib.pyplot as plt 24 | from utils.augmentations import jaccard_numpy 25 | 26 | if sys.version_info[0] == 2: 27 | import xml.etree.cElementTree as ET 28 | else: 29 | import xml.etree.ElementTree as ET 30 | 31 | import matplotlib.pyplot as plt 32 | 33 | VOC_CLASSES = ( # always index 0 34 | 'aeroplane', 'bicycle', 'bird', 'boat', 35 | 'bottle', 'bus', 'car', 'cat', 'chair', 36 | 'cow', 'diningtable', 'dog', 'horse', 37 | 'motorbike', 'person', 'pottedplant', 38 | 'sheep', 'sofa', 'train', 'tvmonitor') 39 | 40 | # note: if you used our download scripts, this should be right 41 | # VOC_ROOT = osp.join(HOME, "tmp/VOC0712/") 42 | VOC_ROOT = '/usr/wiss/elezi/data/VOC0712' 43 | 44 | class VOCAnnotationTransform(object): 45 | """Transforms a VOC annotation into a Tensor of bbox coords and label index 46 | Initilized with a dictionary lookup of classnames to indexes 47 | 48 | Arguments: 49 | class_to_ind (dict, optional): dictionary lookup of classnames -> indexes 50 | (default: alphabetic indexing of VOC's 20 classes) 51 | keep_difficult (bool, optional): keep difficult instances or not 52 | (default: False) 53 | height (int): height 54 | width (int): width 55 | """ 56 | 57 | def __init__(self, class_to_ind=None, keep_difficult=False): 58 | self.class_to_ind = class_to_ind or dict( 59 | zip(VOC_CLASSES, range(len(VOC_CLASSES)))) 60 | self.keep_difficult = keep_difficult 61 | 62 | def __call__(self, target, width, height): 63 | """ 64 | Arguments: 65 | target (annotation) : the target annotation to be made usable 66 | will be an ET.Element 67 | Returns: 68 | a list containing lists of bounding boxes [bbox coords, class name] 69 | """ 70 | res = [] 71 | for obj in target.iter('object'): 72 | difficult = int(obj.find('difficult').text) == 1 73 | if not self.keep_difficult and difficult: 74 | continue 75 | name = obj.find('name').text.lower().strip() 76 | bbox = obj.find('bndbox') 77 | 78 | pts = ['xmin', 'ymin', 'xmax', 'ymax'] 79 | bndbox = [] 80 | for i, pt in enumerate(pts): 81 | cur_pt = int(bbox.find(pt).text) - 1 82 | # scale height or width 83 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height 84 | bndbox.append(cur_pt) 85 | label_idx = self.class_to_ind[name] 86 | bndbox.append(label_idx) 87 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] 88 | # img_id = target.find('filename').text[:-4] 89 | 90 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... 
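# A worked example of the transform above (illustrative values): for a
# 500x375 image whose XML lists one non-difficult 'dog' object with
# bndbox (96, 13, 438, 332), the call
#   VOCAnnotationTransform()(target, width=500, height=375)
# returns [[(96-1)/500, (13-1)/375, (438-1)/500, (332-1)/375, 11]],
# i.e. [[0.19, 0.032, 0.874, 0.8827, 11]], since 'dog' is index 11 in the
# alphabetical VOC_CLASSES ordering.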
] 91 | 92 | 93 | class VOCDetection(data.Dataset): 94 | """VOC Detection Dataset Object 95 | 96 | input is image, target is annotation 97 | 98 | Arguments: 99 | root (string): filepath to VOCdevkit folder. 100 | image_set (string): imageset to use (eg. 'train', 'val', 'test') 101 | transform (callable, optional): transformation to perform on the 102 | input image 103 | target_transform (callable, optional): transformation to perform on the 104 | target `annotation` 105 | (eg: take in caption string, return tensor of word indices) 106 | dataset_name (string, optional): which dataset to load 107 | (default: 'VOC2007') 108 | """ 109 | def __init__(self, root, supervised_indices=None, 110 | image_sets=[('2007', 'trainval'), ('2012', 'trainval')], 111 | transform=None, target_transform=VOCAnnotationTransform(), 112 | dataset_name='VOC0712', pseudo_labels={}, bounding_box_dict={}): 113 | self.root = root 114 | self.image_set = image_sets 115 | self.transform = transform 116 | self.target_transform = target_transform 117 | self.name = dataset_name 118 | self._annopath = osp.join('%s', 'Annotations', '%s.xml') 119 | self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg') 120 | self.ids = list() 121 | self.supervised_indices = supervised_indices 122 | self.pseudo_labels = pseudo_labels 123 | self.pseudo_labels_indices = self.get_pseudo_label_indices() 124 | self.bounding_box_dict = bounding_box_dict 125 | self.bounding_box_indices = self.get_bounding_box_indices() 126 | for (year, name) in image_sets: 127 | rootpath = osp.join(self.root, 'VOC' + year) 128 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')): 129 | self.ids.append((rootpath, line.strip())) 130 | self.class_to_ind = dict(zip(VOC_CLASSES, range(len(VOC_CLASSES)))) 131 | 132 | def __getitem__(self, index): 133 | im, gt, h, w, semi = self.pull_item(index) 134 | return im, gt, semi 135 | 136 | def __len__(self): 137 | return len(self.ids) 138 | 139 | def pull_item(self, index): 140 | img_id = self.ids[index] 141 | colors = plt.cm.hsv(np.linspace(0, 1, 21)).tolist() 142 | 143 | target = ET.parse(self._annopath % img_id).getroot() 144 | img = cv2.imread(self._imgpath % img_id) 145 | height, width, channels = img.shape 146 | 147 | if self.target_transform is not None: 148 | if index in self.bounding_box_indices: 149 | target_real = self.target_transform(target, width, height) 150 | target = self.target_transform_bounding_box(self.bounding_box_dict, index, target_real, width, height) 151 | 152 | else: 153 | target = self.target_transform(target, width, height) 154 | 155 | if self.transform is not None: 156 | target = np.array(target) 157 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) 158 | 159 | # to rgb 160 | img = img[:, :, (2, 1, 0)] 161 | # img = img.transpose(2, 0, 1) 162 | target = np.hstack((boxes, np.expand_dims(labels, axis=1))) 163 | 164 | if self.supervised_indices != None: 165 | if index in self.supervised_indices: 166 | semi = np.array([1]) 167 | elif index in self.pseudo_labels_indices or index in self.bounding_box_indices: 168 | semi = np.array([2]) 169 | # elif index in self.bounding_box_indices: 170 | # semi = np.array([2]) 171 | # print(semi) 172 | else: 173 | semi = np.array([0]) 174 | target = np.zeros([1,5]) 175 | else: 176 | # it does not matter 177 | semi = np.array([0]) 178 | 179 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width, semi 180 | 181 | def pull_image(self, index): 182 | '''Returns the original image object at index in PIL form 183 | 184 | 
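        (Despite the wording above, what is actually returned is the raw
        cv2.imread result: a numpy ndarray in BGR channel order, not a
        PIL Image.)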
Note: not using self.__getitem__(), as any transformations passed in 185 | could mess up this functionality. 186 | 187 | Argument: 188 | index (int): index of img to show 189 | Return: 190 | PIL img 191 | ''' 192 | img_id = self.ids[index] 193 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR) 194 | 195 | def pull_anno(self, index): 196 | '''Returns the original annotation of image at index 197 | 198 | Note: not using self.__getitem__(), as any transformations passed in 199 | could mess up this functionality. 200 | 201 | Argument: 202 | index (int): index of img to get annotation of 203 | Return: 204 | list: [img_id, [(label, bbox coords),...]] 205 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 206 | ''' 207 | img_id = self.ids[index] 208 | anno = ET.parse(self._annopath % img_id).getroot() 209 | gt = self.target_transform(anno, 1, 1) 210 | return img_id[1], gt 211 | 212 | def pull_pseudo_anno(self, index): 213 | '''Returns the original annotation of image at index 214 | 215 | Note: not using self.__getitem__(), as any transformations passed in 216 | could mess up this functionality. 217 | 218 | Argument: 219 | index (int): index of img to get annotation of 220 | Return: 221 | list: [img_id, [(label, bbox coords),...]] 222 | eg: ('001718', [('dog', (96, 13, 438, 332))]) 223 | ''' 224 | return self.pseudo_labels[index] 225 | 226 | def pull_tensor(self, index): 227 | '''Returns the original image at an index in tensor form 228 | 229 | Note: not using self.__getitem__(), as any transformations passed in 230 | could mess up this functionality. 231 | 232 | Argument: 233 | index (int): index of img to show 234 | Return: 235 | tensorized version of img, squeezed 236 | ''' 237 | return torch.Tensor(self.pull_image(index)).unsqueeze_(0) 238 | 239 | def get_pseudo_label_indices(self): 240 | pseudo_label_indices = [] 241 | if self.pseudo_labels is not None: 242 | for key in self.pseudo_labels: 243 | pseudo_label_indices.append(key) 244 | return pseudo_label_indices 245 | else: 246 | return None 247 | 248 | def get_bounding_box_indices(self): 249 | bounding_box_indices = [] 250 | for key in self.bounding_box_dict: 251 | bounding_box_indices.append(key) 252 | return bounding_box_indices 253 | 254 | def target_transform_pseudo_label(self, pseudo_labels, index): 255 | # height, width = pseudo_labels[index][0][3], pseudo_labels[index][0][4] 256 | height, width = pseudo_labels[index][0][4], pseudo_labels[index][0][3] 257 | # height, width = 300, 300 258 | res = [] 259 | for i in range(len(pseudo_labels[index])): 260 | pts = [pseudo_labels[index][i][5], pseudo_labels[index][i][6], pseudo_labels[index][i][7], pseudo_labels[index][i][8]] 261 | name = pseudo_labels[index][i][2] 262 | pts[0] /= height 263 | pts[2] /= height 264 | pts[1] /= width 265 | pts[3] /= width 266 | for j in range(len(pts)): 267 | if pts[j] < 0: 268 | pts[j] = 0 269 | # clamp normalized coords to [0, 1]; the previous comparison against 270 | # the pixel height could never trigger after the division above 271 | if pts[j] > 1: 272 | pts[j] = 1 273 | label_idx = float(self.class_to_ind[name]) 274 | pts.append(label_idx) 275 | res.append(pts) 276 | return res 277 | 278 | def target_transform_bounding_box(self, bounding_box_dict, index, target_real, width, height): 279 | predictions = np.array(bounding_box_dict[index]) 280 | predictions = predictions[:, :-1] 281 | target_real_numpy = np.array(target_real) 282 | target_real_numpy = target_real_numpy[:, :-1] 283 | iou_all = np.zeros((predictions.shape[0], len(target_real))) 284 | for i in range(len(target_real)): 285 | iou_all[:, i] = jaccard_numpy(predictions,
target_real_numpy[i]) 286 | max_val = 2 287 | print() 288 | print(iou_all) 289 | print() 290 | 291 | targets_intersected = [] 292 | 293 | # otherwise get the value 294 | while max_val > 0: 295 | max_val = np.max(iou_all) 296 | if max_val > 0: 297 | argmax_val = np.where(iou_all == np.amax(iou_all)) 298 | iou_all[argmax_val[0], :] = np.zeros((1, iou_all.shape[1])) 299 | iou_all[:, argmax_val[1]] = np.zeros((iou_all.shape[0], 1)) 300 | # get the GT from target_real 301 | targets_intersected.append(target_real[int(argmax_val[1])]) 302 | 303 | else: 304 | # get a random one from target real 305 | targets_intersected.append(target_real[0]) 306 | 307 | return targets_intersected 308 | 309 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """Adapted from: 2 | @longcw faster_rcnn_pytorch: https://github.com/longcw/faster_rcnn_pytorch 3 | @rbgirshick py-faster-rcnn https://github.com/rbgirshick/py-faster-rcnn 4 | Licensed under The MIT License [see LICENSE for details] 5 | """ 6 | 7 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved. 8 | # 9 | # This work is made available under the Nvidia Source Code License-NC. 10 | # To view a copy of this license, visit 11 | # https://github.com/NVlabs/AL-SSL/ 12 | 13 | from __future__ import print_function 14 | import torch 15 | import torch.nn as nn 16 | import torch.backends.cudnn as cudnn 17 | from torch.autograd import Variable 18 | from data import VOC_ROOT, VOCAnnotationTransform, VOCDetection, BaseTransform 19 | from utils.augmentations import SSDAugmentation 20 | from data import VOC_CLASSES as labelmap 21 | import torch.utils.data as data 22 | 23 | from ssd import build_ssd 24 | 25 | import sys 26 | import os 27 | import time 28 | import argparse 29 | import numpy as np 30 | import pickle 31 | import cv2 32 | from collections import OrderedDict 33 | 34 | if sys.version_info[0] == 2: 35 | import xml.etree.cElementTree as ET 36 | else: 37 | import xml.etree.ElementTree as ET 38 | 39 | 40 | def str2bool(v): 41 | return v.lower() in ("yes", "true", "t", "1") 42 | 43 | 44 | parser = argparse.ArgumentParser( 45 | description='Single Shot MultiBox Detector Evaluation') 46 | parser.add_argument('--trained_model', default='weights_voc_code/120000combined_id_2_pl_threshold_0.99_labeled_set_3011_.pth', 47 | type=str, help='Trained state_dict file path to open') 48 | parser.add_argument('--save_folder', default='eval/', type=str, 49 | help='File path to save results') 50 | parser.add_argument('--confidence_threshold', default=0.01, type=float, 51 | help='Detection confidence threshold') 52 | parser.add_argument('--top_k', default=5, type=int, 53 | help='Further restrict the number of predictions to parse') 54 | parser.add_argument('--cuda', default=True, type=str2bool, 55 | help='Use cuda to train model') 56 | parser.add_argument('--voc_root', default='../../data/VOC0712/', 57 | help='Location of VOC root directory') 58 | parser.add_argument('--cleanup', default=True, type=str2bool, 59 | help='Cleanup and remove results files following eval') 60 | 61 | args = parser.parse_args() 62 | 63 | if not os.path.exists(args.save_folder): 64 | os.mkdir(args.save_folder) 65 | 66 | if torch.cuda.is_available(): 67 | if args.cuda: 68 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 69 | if not args.cuda: 70 | print("WARNING: It looks like you have a CUDA device, but aren't using \ 71 | CUDA. 
Run with --cuda for optimal eval speed.") 72 | torch.set_default_tensor_type('torch.FloatTensor') 73 | else: 74 | torch.set_default_tensor_type('torch.FloatTensor') 75 | 76 | annopath = os.path.join(args.voc_root, 'VOC2007', 'Annotations', '%s.xml') 77 | imgpath = os.path.join(args.voc_root, 'VOC2007', 'JPEGImages', '%s.jpg') 78 | imgsetpath = os.path.join(args.voc_root, 'VOC2007', 'ImageSets', 79 | 'Main', '{:s}.txt') 80 | YEAR = '2007' 81 | devkit_path = args.voc_root + 'VOC' + YEAR 82 | dataset_mean = (104, 117, 123) 83 | set_type = 'test' 84 | 85 | 86 | class Timer(object): 87 | """A simple timer.""" 88 | def __init__(self): 89 | self.total_time = 0. 90 | self.calls = 0 91 | self.start_time = 0. 92 | self.diff = 0. 93 | self.average_time = 0. 94 | 95 | def tic(self): 96 | # using time.time instead of time.clock because time time.clock 97 | # does not normalize for multithreading 98 | self.start_time = time.time() 99 | 100 | def toc(self, average=True): 101 | self.diff = time.time() - self.start_time 102 | self.total_time += self.diff 103 | self.calls += 1 104 | self.average_time = self.total_time / self.calls 105 | if average: 106 | return self.average_time 107 | else: 108 | return self.diff 109 | 110 | 111 | def parse_rec(filename): 112 | """ Parse a PASCAL VOC xml file """ 113 | tree = ET.parse(filename) 114 | objects = [] 115 | for obj in tree.findall('object'): 116 | obj_struct = {} 117 | obj_struct['name'] = obj.find('name').text 118 | obj_struct['pose'] = obj.find('pose').text 119 | obj_struct['truncated'] = int(obj.find('truncated').text) 120 | obj_struct['difficult'] = int(obj.find('difficult').text) 121 | bbox = obj.find('bndbox') 122 | obj_struct['bbox'] = [int(bbox.find('xmin').text) - 1, 123 | int(bbox.find('ymin').text) - 1, 124 | int(bbox.find('xmax').text) - 1, 125 | int(bbox.find('ymax').text) - 1] 126 | objects.append(obj_struct) 127 | 128 | return objects 129 | 130 | 131 | def get_output_dir(name, phase): 132 | """Return the directory where experimental artifacts are placed. 133 | If the directory does not exist, it is created. 134 | A canonical path is built using the name from an imdb and a network 135 | (if not None). 136 | """ 137 | filedir = os.path.join(name, phase) 138 | if not os.path.exists(filedir): 139 | os.makedirs(filedir) 140 | return filedir 141 | 142 | 143 | def get_voc_results_file_template(image_set, cls): 144 | # VOCdevkit/VOC2007/results/det_test_aeroplane.txt 145 | filename = 'det_' + image_set + '_%s.txt' % (cls) 146 | filedir = os.path.join(devkit_path, 'results') 147 | if not os.path.exists(filedir): 148 | os.makedirs(filedir) 149 | path = os.path.join(filedir, filename) 150 | return path 151 | 152 | 153 | def write_voc_results_file(all_boxes, dataset): 154 | for cls_ind, cls in enumerate(labelmap): 155 | print('Writing {:s} VOC results file'.format(cls)) 156 | filename = get_voc_results_file_template(set_type, cls) 157 | with open(filename, 'wt') as f: 158 | for im_ind, index in enumerate(dataset.ids): 159 | dets = all_boxes[cls_ind+1][im_ind] 160 | if dets == []: 161 | continue 162 | # the VOCdevkit expects 1-based indices 163 | for k in range(dets.shape[0]): 164 | f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'. 
165 | format(index[1], dets[k, -1], 166 | dets[k, 0] + 1, dets[k, 1] + 1, 167 | dets[k, 2] + 1, dets[k, 3] + 1)) 168 | 169 | 170 | def do_python_eval(output_dir='output', use_07=True): 171 | cachedir = os.path.join(devkit_path, 'annotations_cache') 172 | aps = [] 173 | # The PASCAL VOC metric changed in 2010 174 | use_07_metric = use_07 175 | print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No')) 176 | if not os.path.isdir(output_dir): 177 | os.mkdir(output_dir) 178 | for i, cls in enumerate(labelmap): 179 | filename = get_voc_results_file_template(set_type, cls) 180 | rec, prec, ap = voc_eval( 181 | filename, annopath, imgsetpath.format(set_type), cls, cachedir, 182 | ovthresh=0.5, use_07_metric=use_07_metric) 183 | aps += [ap] 184 | print('AP for {} = {:.4f}'.format(cls, ap)) 185 | with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f: 186 | pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f) 187 | print('Mean AP = {:.4f}'.format(np.mean(aps))) 188 | print('~~~~~~~~') 189 | print('Results:') 190 | for ap in aps: 191 | print('{:.3f}'.format(ap)) 192 | print('{:.3f}'.format(np.mean(aps))) 193 | print('~~~~~~~~') 194 | print('') 195 | print('--------------------------------------------------------------') 196 | print('Results computed with the **unofficial** Python eval code.') 197 | print('Results should be very close to the official MATLAB eval code.') 198 | print('--------------------------------------------------------------') 199 | return aps 200 | 201 | 202 | def voc_ap(rec, prec, use_07_metric=True): 203 | """ ap = voc_ap(rec, prec, [use_07_metric]) 204 | Compute VOC AP given precision and recall. 205 | If use_07_metric is true, uses the 206 | VOC 07 11 point method (default:True). 207 | """ 208 | if use_07_metric: 209 | # 11 point metric 210 | ap = 0. 211 | for t in np.arange(0., 1.1, 0.1): 212 | if np.sum(rec >= t) == 0: 213 | p = 0 214 | else: 215 | p = np.max(prec[rec >= t]) 216 | ap = ap + p / 11. 217 | else: 218 | # correct AP calculation 219 | # first append sentinel values at the end 220 | mrec = np.concatenate(([0.], rec, [1.])) 221 | mpre = np.concatenate(([0.], prec, [0.])) 222 | 223 | # compute the precision envelope 224 | for i in range(mpre.size - 1, 0, -1): 225 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 226 | 227 | # to calculate area under PR curve, look for points 228 | # where X axis (recall) changes value 229 | i = np.where(mrec[1:] != mrec[:-1])[0] 230 | 231 | # and sum (\Delta recall) * prec 232 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 233 | return ap 234 | 235 | 236 | def voc_eval(detpath, 237 | annopath, 238 | imagesetfile, 239 | classname, 240 | cachedir, 241 | ovthresh=0.5, 242 | use_07_metric=True): 243 | """rec, prec, ap = voc_eval(detpath, 244 | annopath, 245 | imagesetfile, 246 | classname, 247 | [ovthresh], 248 | [use_07_metric]) 249 | Top level function that does the PASCAL VOC evaluation. 250 | detpath: Path to detections 251 | detpath.format(classname) should produce the detection results file. 252 | annopath: Path to annotations 253 | annopath.format(imagename) should be the xml annotations file. 254 | imagesetfile: Text file containing the list of images, one image per line. 
255 | classname: Category name (duh) 256 | cachedir: Directory for caching the annotations 257 | [ovthresh]: Overlap threshold (default = 0.5) 258 | [use_07_metric]: Whether to use VOC07's 11 point AP computation 259 | (default True) 260 | """ 261 | # assumes detections are in detpath.format(classname) 262 | # assumes annotations are in annopath.format(imagename) 263 | # assumes imagesetfile is a text file with each line an image name 264 | # cachedir caches the annotations in a pickle file 265 | # first load gt 266 | if not os.path.isdir(cachedir): 267 | os.mkdir(cachedir) 268 | cachefile = os.path.join(cachedir, 'annots.pkl') 269 | # read list of images 270 | with open(imagesetfile, 'r') as f: 271 | lines = f.readlines() 272 | imagenames = [x.strip() for x in lines] 273 | if not os.path.isfile(cachefile): 274 | # load annots 275 | recs = {} 276 | for i, imagename in enumerate(imagenames): 277 | recs[imagename] = parse_rec(annopath % (imagename)) 278 | if i % 100 == 0: 279 | print('Reading annotation for {:d}/{:d}'.format( 280 | i + 1, len(imagenames))) 281 | # save 282 | print('Saving cached annotations to {:s}'.format(cachefile)) 283 | with open(cachefile, 'wb') as f: 284 | pickle.dump(recs, f) 285 | else: 286 | # load 287 | with open(cachefile, 'rb') as f: 288 | recs = pickle.load(f) 289 | 290 | # extract gt objects for this class 291 | class_recs = {} 292 | npos = 0 293 | for imagename in imagenames: 294 | R = [obj for obj in recs[imagename] if obj['name'] == classname] 295 | bbox = np.array([x['bbox'] for x in R]) 296 | difficult = np.array([x['difficult'] for x in R]).astype(np.bool) 297 | det = [False] * len(R) 298 | npos = npos + sum(~difficult) 299 | class_recs[imagename] = {'bbox': bbox, 300 | 'difficult': difficult, 301 | 'det': det} 302 | 303 | # read dets 304 | detfile = detpath.format(classname) 305 | with open(detfile, 'r') as f: 306 | lines = f.readlines() 307 | if any(lines) == 1: 308 | 309 | splitlines = [x.strip().split(' ') for x in lines] 310 | image_ids = [x[0] for x in splitlines] 311 | confidence = np.array([float(x[1]) for x in splitlines]) 312 | BB = np.array([[float(z) for z in x[2:]] for x in splitlines]) 313 | 314 | # sort by confidence 315 | sorted_ind = np.argsort(-confidence) 316 | sorted_scores = np.sort(-confidence) 317 | BB = BB[sorted_ind, :] 318 | image_ids = [image_ids[x] for x in sorted_ind] 319 | 320 | # go down dets and mark TPs and FPs 321 | nd = len(image_ids) 322 | tp = np.zeros(nd) 323 | fp = np.zeros(nd) 324 | for d in range(nd): 325 | R = class_recs[image_ids[d]] 326 | bb = BB[d, :].astype(float) 327 | ovmax = -np.inf 328 | BBGT = R['bbox'].astype(float) 329 | if BBGT.size > 0: 330 | # compute overlaps 331 | # intersection 332 | ixmin = np.maximum(BBGT[:, 0], bb[0]) 333 | iymin = np.maximum(BBGT[:, 1], bb[1]) 334 | ixmax = np.minimum(BBGT[:, 2], bb[2]) 335 | iymax = np.minimum(BBGT[:, 3], bb[3]) 336 | iw = np.maximum(ixmax - ixmin, 0.) 337 | ih = np.maximum(iymax - iymin, 0.) 338 | inters = iw * ih 339 | uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) + 340 | (BBGT[:, 2] - BBGT[:, 0]) * 341 | (BBGT[:, 3] - BBGT[:, 1]) - inters) 342 | overlaps = inters / uni 343 | ovmax = np.max(overlaps) 344 | jmax = np.argmax(overlaps) 345 | 346 | if ovmax > ovthresh: 347 | if not R['difficult'][jmax]: 348 | if not R['det'][jmax]: 349 | tp[d] = 1. 350 | R['det'][jmax] = 1 351 | else: 352 | fp[d] = 1. 353 | else: 354 | fp[d] = 1. 
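            # At this point, per detection (visited in descending confidence
            # order): tp[d] is set when it overlaps a not-yet-matched,
            # non-difficult ground-truth box above ovthresh, fp[d] otherwise;
            # a duplicate hit on an already-matched box counts as a false
            # positive, and matches to 'difficult' boxes set neither flag.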
355 | 356 | # compute precision recall 357 | fp = np.cumsum(fp) 358 | tp = np.cumsum(tp) 359 | rec = tp / float(npos) 360 | # avoid divide by zero in case the first detection matches a difficult 361 | # ground truth 362 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 363 | ap = voc_ap(rec, prec, use_07_metric) 364 | else: 365 | rec = -1. 366 | prec = -1. 367 | ap = -1. 368 | 369 | return rec, prec, ap 370 | 371 | 372 | def test_net(save_folder, net, cuda, dataset, transform, top_k, 373 | im_size=300, thresh=0.05): 374 | num_images = len(dataset) 375 | 376 | all_boxes = [[[] for _ in range(num_images)] 377 | for _ in range(len(labelmap)+1)] 378 | 379 | # timers 380 | _t = {'im_detect': Timer(), 'misc': Timer()} 381 | output_dir = get_output_dir('ssd300_120000', set_type) 382 | det_file = os.path.join(output_dir, 'detections.pkl') 383 | 384 | for i in range(num_images): 385 | im, gt, h, w, _ = dataset.pull_item(i) 386 | 387 | x = Variable(im.unsqueeze(0)) 388 | if args.cuda: 389 | x = x.cuda() 390 | _t['im_detect'].tic() 391 | detections = net(x).data 392 | detect_time = _t['im_detect'].toc(average=False) 393 | 394 | # skip j = 0, because it's the background class 395 | for j in range(1, detections.size(1)): 396 | dets = detections[0, j, :] 397 | mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t() 398 | dets = torch.masked_select(dets, mask).view(-1, 5) 399 | if dets.dim() == 0: 400 | continue 401 | boxes = dets[:, 1:] 402 | boxes[:, 0] *= w 403 | boxes[:, 2] *= w 404 | boxes[:, 1] *= h 405 | boxes[:, 3] *= h 406 | scores = dets[:, 0].cpu().numpy() 407 | cls_dets = np.hstack((boxes.cpu().numpy(), 408 | scores[:, np.newaxis])).astype(np.float32, 409 | copy=False) 410 | all_boxes[j][i] = cls_dets 411 | 412 | print('im_detect: {:d}/{:d} {:.3f}s'.format(i + 1, 413 | num_images, detect_time)) 414 | 415 | with open(det_file, 'wb') as f: 416 | pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL) 417 | 418 | print('Evaluating detections') 419 | res = evaluate_detections(all_boxes, output_dir, dataset) 420 | return res 421 | 422 | 423 | def evaluate_detections(box_list, output_dir, dataset): 424 | write_voc_results_file(box_list, dataset) 425 | res = do_python_eval(output_dir) 426 | return res 427 | 428 | 429 | if __name__ == '__main__': 430 | """ 431 | student_model_dict = {} 432 | for key, value in model_student.state_dict().items(): 433 | student_model_dict[key] = value 434 | 435 | new_teacher_dict = OrderedDict() 436 | for key, value in model_teacher.state_dict().items(): 437 | if key in student_model_dict.keys(): 438 | new_teacher_dict[key] = ( 439 | student_model_dict[key] * (1 - keep_rate) + value * keep_rate)""" 440 | 441 | # load net 442 | num_classes = len(labelmap) + 1 # +1 for background 443 | net = build_ssd('test', 300, num_classes) # initialize SSD 444 | net = nn.DataParallel(net) 445 | fi_write = open("results.txt", "a") 446 | for key in net.state_dict(): 447 | print(key) 448 | 449 | list_of_folders = ['weights'] # folder where the saved networks are 450 | for folder in list_of_folders: 451 | list_nets = os.listdir(folder) 452 | for nnn in sorted(list_nets): 453 | 454 | net.load_state_dict(torch.load(os.path.join(folder, nnn))) 455 | net.eval() 456 | print('Finished loading model!') 457 | # load data 458 | dataset = VOCDetection(args.voc_root, image_sets=[('2007', 'test')], 459 | transform=BaseTransform(300, dataset_mean), 460 | target_transform=VOCAnnotationTransform()) 461 | if args.cuda: 462 | net = net.cuda() 463 | cudnn.benchmark = True 464 | # evaluation 465 | m_ap = 
test_net(args.save_folder, net, args.cuda, dataset, 466 | BaseTransform(300, dataset_mean), args.top_k, 300, 467 | thresh=args.confidence_threshold) 468 | print(m_ap) 469 | fi_write.write(folder + '_____' + nnn + ": " +str(np.mean(m_ap))) 470 | fi_write.write('\n') 471 | fi_write.close() 472 | 473 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /layers/__pycache__/box_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/__pycache__/box_utils.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/box_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/__pycache__/box_utils.cpython-37.pyc -------------------------------------------------------------------------------- /layers/__pycache__/box_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/__pycache__/box_utils.cpython-38.pyc -------------------------------------------------------------------------------- /layers/box_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | 4 | 5 | def point_form(boxes): 6 | """ Convert prior_boxes to (xmin, ymin, xmax, ymax) 7 | representation for comparison to point form ground truth data. 8 | Args: 9 | boxes: (tensor) center-size default boxes from priorbox layers. 10 | Return: 11 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 12 | """ 13 | return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin 14 | boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax 15 | 16 | 17 | def center_size(boxes): 18 | """ Convert prior_boxes to (cx, cy, w, h) 19 | representation for comparison to center-size form ground truth data. 20 | Args: 21 | boxes: (tensor) point_form boxes 22 | Return: 23 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 
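        (Strictly speaking the converted boxes are in (cx, cy, w, h)
        center-size form; this function is the inverse of point_form above.)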
24 | """ 25 | return torch.cat((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy 26 | boxes[:, 2:] - boxes[:, :2], 1) # w, h 27 | 28 | 29 | def intersect(box_a, box_b): 30 | """ We resize both tensors to [A,B,2] without new malloc: 31 | [A,2] -> [A,1,2] -> [A,B,2] 32 | [B,2] -> [1,B,2] -> [A,B,2] 33 | Then we compute the area of intersect between box_a and box_b. 34 | Args: 35 | box_a: (tensor) bounding boxes, Shape: [A,4]. 36 | box_b: (tensor) bounding boxes, Shape: [B,4]. 37 | Return: 38 | (tensor) intersection area, Shape: [A,B]. 39 | """ 40 | A = box_a.size(0) 41 | B = box_b.size(0) 42 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 43 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 44 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 45 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 46 | inter = torch.clamp((max_xy - min_xy), min=0) 47 | return inter[:, :, 0] * inter[:, :, 1] 48 | 49 | 50 | def jaccard(box_a, box_b): 51 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 52 | is simply the intersection over union of two boxes. Here we operate on 53 | ground truth boxes and default boxes. 54 | E.g.: 55 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 56 | Args: 57 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 58 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 59 | Return: 60 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 61 | """ 62 | inter = intersect(box_a, box_b) 63 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 64 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 65 | area_b = ((box_b[:, 2]-box_b[:, 0]) * 66 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 67 | union = area_a + area_b - inter 68 | return inter / union # [A,B] 69 | 70 | 71 | def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx): 72 | """Match each prior box with the ground truth box of the highest jaccard 73 | overlap, encode the bounding boxes, then return the matched indices 74 | corresponding to both confidence and location preds. 75 | Args: 76 | threshold: (float) The overlap threshold used when mathing boxes. 77 | truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors]. 78 | priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. 79 | variances: (tensor) Variances corresponding to each prior coord, 80 | Shape: [num_priors, 4]. 81 | labels: (tensor) All the class labels for the image, Shape: [num_obj]. 82 | loc_t: (tensor) Tensor to be filled w/ endcoded location targets. 83 | conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. 84 | idx: (int) current batch index 85 | Return: 86 | The matched indices corresponding to 1)location and 2)confidence preds. 
87 | """ 88 | # jaccard index 89 | overlaps = jaccard( 90 | truths, 91 | point_form(priors) 92 | ) 93 | # (Bipartite Matching) 94 | # [1,num_objects] best prior for each ground truth 95 | best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) 96 | # [1,num_priors] best ground truth for each prior 97 | best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) 98 | best_truth_idx.squeeze_(0) 99 | best_truth_overlap.squeeze_(0) 100 | best_prior_idx.squeeze_(1) 101 | best_prior_overlap.squeeze_(1) 102 | best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior 103 | # TODO refactor: index best_prior_idx with long tensor 104 | # ensure every gt matches with its prior of max overlap 105 | for j in range(best_prior_idx.size(0)): 106 | best_truth_idx[best_prior_idx[j]] = j 107 | matches = truths[best_truth_idx] # Shape: [num_priors,4] 108 | conf = labels[best_truth_idx] + 1 # Shape: [num_priors] 109 | conf[best_truth_overlap < threshold] = 0 # label as background 110 | loc = encode(matches, priors, variances) 111 | loc_t[idx] = loc # [num_priors,4] encoded offsets to learn 112 | conf_t[idx] = conf # [num_priors] top class label for each prior 113 | 114 | 115 | def encode(matched, priors, variances): 116 | """Encode the variances from the priorbox layers into the ground truth boxes 117 | we have matched (based on jaccard overlap) with the prior boxes. 118 | Args: 119 | matched: (tensor) Coords of ground truth for each prior in point-form 120 | Shape: [num_priors, 4]. 121 | priors: (tensor) Prior boxes in center-offset form 122 | Shape: [num_priors,4]. 123 | variances: (list[float]) Variances of priorboxes 124 | Return: 125 | encoded boxes (tensor), Shape: [num_priors, 4] 126 | """ 127 | 128 | # dist b/t match center and prior's center 129 | g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2] 130 | # encode variance 131 | g_cxcy /= (variances[0] * priors[:, 2:]) 132 | # match wh / prior wh 133 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 134 | g_wh = torch.log(g_wh) / variances[1] 135 | # return target for smooth_l1_loss 136 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 137 | 138 | 139 | # Adapted from https://github.com/Hakuyume/chainer-ssd 140 | def decode(loc, priors, variances): 141 | """Decode locations from predictions using priors to undo 142 | the encoding we did for offset regression at train time. 143 | Args: 144 | loc (tensor): location predictions for loc layers, 145 | Shape: [num_priors,4] 146 | priors (tensor): Prior boxes in center-offset form. 147 | Shape: [num_priors,4]. 148 | variances: (list[float]) Variances of priorboxes 149 | Return: 150 | decoded bounding box predictions 151 | """ 152 | 153 | boxes = torch.cat(( 154 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 155 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 156 | boxes[:, :2] -= boxes[:, 2:] / 2 157 | boxes[:, 2:] += boxes[:, :2] 158 | return boxes 159 | 160 | 161 | def log_sum_exp(x): 162 | """Utility function for computing log_sum_exp while determining 163 | This will be used to determine unaveraged confidence loss across 164 | all examples in a batch. 
165 | Args: 166 | x (Variable(tensor)): conf_preds from conf layers 167 | """ 168 | x_max = x.data.max() 169 | return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max 170 | 171 | 172 | # Original author: Francisco Massa: 173 | # https://github.com/fmassa/object-detection.torch 174 | # Ported to PyTorch by Max deGroot (02/01/2017) 175 | def nms(boxes, scores, overlap=0.5, top_k=200): 176 | """Apply non-maximum suppression at test time to avoid detecting too many 177 | overlapping bounding boxes for a given object. 178 | Args: 179 | boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. 180 | scores: (tensor) The class predscores for the img, Shape:[num_priors]. 181 | overlap: (float) The overlap thresh for suppressing unnecessary boxes. 182 | top_k: (int) The Maximum number of box preds to consider. 183 | Return: 184 | The indices of the kept boxes with respect to num_priors. 185 | """ 186 | 187 | keep = scores.new(scores.size(0)).zero_().long() 188 | if boxes.numel() == 0: 189 | return keep 190 | x1 = boxes[:, 0] 191 | y1 = boxes[:, 1] 192 | x2 = boxes[:, 2] 193 | y2 = boxes[:, 3] 194 | area = torch.mul(x2 - x1, y2 - y1) 195 | v, idx = scores.sort(0) # sort in ascending order 196 | # I = I[v >= 0.01] 197 | idx = idx[-top_k:] # indices of the top-k largest vals 198 | xx1 = boxes.new() 199 | yy1 = boxes.new() 200 | xx2 = boxes.new() 201 | yy2 = boxes.new() 202 | w = boxes.new() 203 | h = boxes.new() 204 | 205 | # keep = torch.Tensor() 206 | count = 0 207 | while idx.numel() > 0: 208 | i = idx[-1] # index of current largest val 209 | # keep.append(i) 210 | keep[count] = i 211 | count += 1 212 | if idx.size(0) == 1: 213 | break 214 | idx = idx[:-1] # remove kept element from view 215 | # load bboxes of next highest vals 216 | torch.index_select(x1, 0, idx, out=xx1) 217 | torch.index_select(y1, 0, idx, out=yy1) 218 | torch.index_select(x2, 0, idx, out=xx2) 219 | torch.index_select(y2, 0, idx, out=yy2) 220 | # store element-wise max with next highest score 221 | xx1 = torch.clamp(xx1, min=x1[i]) 222 | yy1 = torch.clamp(yy1, min=y1[i]) 223 | xx2 = torch.clamp(xx2, max=x2[i]) 224 | yy2 = torch.clamp(yy2, max=y2[i]) 225 | w.resize_as_(xx2) 226 | h.resize_as_(yy2) 227 | w = xx2 - xx1 228 | h = yy2 - yy1 229 | # check sizes of xx1 and xx2.. 
after each iteration 230 | w = torch.clamp(w, min=0.0) 231 | h = torch.clamp(h, min=0.0) 232 | inter = w*h 233 | # IoU = i / (area(a) + area(b) - i) 234 | rem_areas = torch.index_select(area, 0, idx) # load remaining areas) 235 | union = (rem_areas - inter) + area[i] 236 | IoU = inter/union # store result in iou 237 | # keep only elements with an IoU <= overlap 238 | idx = idx[IoU.le(overlap)] 239 | return keep, count 240 | -------------------------------------------------------------------------------- /layers/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection import Detect 2 | from .prior_box import PriorBox 3 | 4 | 5 | __all__ = ['Detect', 'PriorBox'] 6 | -------------------------------------------------------------------------------- /layers/functions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /layers/functions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /layers/functions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /layers/functions/__pycache__/detection.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/detection.cpython-36.pyc -------------------------------------------------------------------------------- /layers/functions/__pycache__/detection.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/detection.cpython-37.pyc -------------------------------------------------------------------------------- /layers/functions/__pycache__/detection.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/detection.cpython-38.pyc -------------------------------------------------------------------------------- /layers/functions/__pycache__/prior_box.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/prior_box.cpython-36.pyc -------------------------------------------------------------------------------- /layers/functions/__pycache__/prior_box.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/prior_box.cpython-37.pyc -------------------------------------------------------------------------------- /layers/functions/__pycache__/prior_box.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/prior_box.cpython-38.pyc -------------------------------------------------------------------------------- /layers/functions/detection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from ..box_utils import decode, nms 4 | from data import voc300 as cfg 5 | from data import voc512 as cfg 6 | 7 | 8 | class Detect(Function): 9 | """At test time, Detect is the final layer of SSD. Decode location preds, 10 | apply non-maximum suppression to location predictions based on conf 11 | scores and threshold to a top_k number of output predictions for both 12 | confidence score and locations. 13 | """ 14 | def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh): 15 | self.num_classes = num_classes 16 | self.background_label = bkg_label 17 | self.top_k = top_k 18 | # Parameters used in nms. 19 | self.nms_thresh = nms_thresh 20 | if nms_thresh <= 0: 21 | raise ValueError('nms_threshold must be non negative.') 22 | self.conf_thresh = conf_thresh 23 | self.variance = [0.1, 0.2] #cfg['variance'] 24 | 25 | def forward(self, loc_data, conf_data, prior_data): 26 | """ 27 | Args: 28 | loc_data: (tensor) Loc preds from loc layers 29 | Shape: [batch,num_priors*4] 30 | conf_data: (tensor) Shape: Conf preds from conf layers 31 | Shape: [batch*num_priors,num_classes] 32 | prior_data: (tensor) Prior boxes and variances from priorbox layers 33 | Shape: [1,num_priors,4] 34 | """ 35 | num = loc_data.size(0) # batch size 36 | num_priors = prior_data.size(0) 37 | output = torch.zeros(num, self.num_classes, self.top_k, 5) 38 | conf_preds = conf_data.view(num, num_priors, 39 | self.num_classes).transpose(2, 1) 40 | 41 | # Decode predictions into bboxes. 
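        # The loop below works one image at a time: decode all priors into
        # corner-form boxes, then, for every non-background class, keep the
        # scores above conf_thresh, run NMS on the surviving boxes, and
        # write at most top_k (score, x1, y1, x2, y2) rows into output[i, cl].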
42 | for i in range(num): 43 | decoded_boxes = decode(loc_data[i], prior_data, self.variance) 44 | # For each class, perform nms 45 | conf_scores = conf_preds[i].clone() 46 | 47 | for cl in range(1, self.num_classes): 48 | c_mask = conf_scores[cl].gt(self.conf_thresh) 49 | scores = conf_scores[cl][c_mask] 50 | if scores.size(0) == 0: 51 | continue 52 | # if scores.dim() == 0: 53 | # continue 54 | l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes) 55 | boxes = decoded_boxes[l_mask].view(-1, 4) 56 | # idx of highest scoring and non-overlapping boxes per class 57 | ids, count = nms(boxes, scores, self.nms_thresh, self.top_k) 58 | output[i, cl, :count] = \ 59 | torch.cat((scores[ids[:count]].unsqueeze(1), 60 | boxes[ids[:count]]), 1) 61 | flt = output.contiguous().view(num, -1, 5) 62 | _, idx = flt[:, :, 0].sort(1, descending=True) 63 | _, rank = idx.sort(1) 64 | flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0) 65 | return output 66 | -------------------------------------------------------------------------------- /layers/functions/prior_box.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from math import sqrt as sqrt 3 | from itertools import product as product 4 | import torch 5 | 6 | 7 | class PriorBox(object): 8 | """Compute priorbox coordinates in center-offset form for each source 9 | feature map. 10 | """ 11 | def __init__(self, cfg): 12 | super(PriorBox, self).__init__() 13 | self.image_size = cfg['min_dim'] 14 | # number of priors for feature map location (either 4 or 6) 15 | self.num_priors = len(cfg['aspect_ratios']) 16 | self.variance = cfg['variance'] or [0.1] 17 | self.feature_maps = cfg['feature_maps'] 18 | self.min_sizes = cfg['min_sizes'] 19 | self.max_sizes = cfg['max_sizes'] 20 | self.steps = cfg['steps'] 21 | self.aspect_ratios = cfg['aspect_ratios'] 22 | self.clip = cfg['clip'] 23 | self.version = cfg['name'] 24 | for v in self.variance: 25 | if v <= 0: 26 | raise ValueError('Variances must be greater than 0') 27 | 28 | def forward(self): 29 | mean = [] 30 | for k, f in enumerate(self.feature_maps): 31 | for i, j in product(range(f), repeat=2): 32 | f_k = self.image_size / self.steps[k] 33 | # unit center x,y 34 | cx = (j + 0.5) / f_k 35 | cy = (i + 0.5) / f_k 36 | 37 | # aspect_ratio: 1 38 | # rel size: min_size 39 | s_k = self.min_sizes[k]/self.image_size 40 | mean += [cx, cy, s_k, s_k] 41 | 42 | # aspect_ratio: 1 43 | # rel size: sqrt(s_k * s_(k+1)) 44 | s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size)) 45 | mean += [cx, cy, s_k_prime, s_k_prime] 46 | 47 | # rest of aspect ratios 48 | for ar in self.aspect_ratios[k]: 49 | mean += [cx, cy, s_k*sqrt(ar), s_k/sqrt(ar)] 50 | mean += [cx, cy, s_k/sqrt(ar), s_k*sqrt(ar)] 51 | # back to torch land 52 | output = torch.Tensor(mean).view(-1, 4) 53 | if self.clip: 54 | output.clamp_(max=1, min=0) 55 | return output 56 | -------------------------------------------------------------------------------- /layers/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .l2norm import L2Norm 2 | from .multibox_loss import MultiBoxLoss 3 | 4 | __all__ = ['L2Norm', 'MultiBoxLoss'] 5 | -------------------------------------------------------------------------------- /layers/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/l2norm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/l2norm.cpython-36.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/l2norm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/l2norm.cpython-37.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/l2norm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/l2norm.cpython-38.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/multibox_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss.cpython-36.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/multibox_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss.cpython-37.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/multibox_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss.cpython-38.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/multibox_loss_gmm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss_gmm.cpython-36.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/multibox_loss_new.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss_new.cpython-36.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/multibox_loss_new.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss_new.cpython-37.pyc -------------------------------------------------------------------------------- /layers/modules/__pycache__/multibox_loss_new.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss_new.cpython-38.pyc -------------------------------------------------------------------------------- /layers/modules/l2norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from torch.autograd import Variable 5 | import torch.nn.init as init 6 | 7 | class L2Norm(nn.Module): 8 | def __init__(self,n_channels, scale): 9 | super(L2Norm,self).__init__() 10 | self.n_channels = n_channels 11 | self.gamma = scale or None 12 | self.eps = 1e-10 13 | self.weight = nn.Parameter(torch.Tensor(self.n_channels)) 14 | self.reset_parameters() 15 | 16 | def reset_parameters(self): 17 | init.constant(self.weight,self.gamma) 18 | 19 | def forward(self, x): 20 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps 21 | #x /= norm 22 | x = torch.div(x,norm) 23 | out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x 24 | return out 25 | -------------------------------------------------------------------------------- /layers/modules/multibox_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://github.com/NVlabs/AL-SSL/ 6 | 7 | 8 | # -*- coding: utf-8 -*- 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | from data import voc300 as cfg 14 | from ..box_utils import match, log_sum_exp 15 | import numpy as np 16 | 17 | 18 | class MultiBoxLoss(nn.Module): 19 | """SSD Weighted Loss Function 20 | Compute Targets: 21 | 1) Produce Confidence Target Indices by matching ground truth boxes 22 | with (default) 'priorboxes' that have jaccard index > threshold parameter 23 | (default threshold: 0.5). 24 | 2) Produce localization target by 'encoding' variance into offsets of ground 25 | truth boxes and their matched 'priorboxes'. 26 | 3) Hard negative mining to filter the excessive number of negative examples 27 | that comes with using a large number of default bounding boxes. 28 | (default negative:positive ratio 3:1) 29 | Objective Loss: 30 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 31 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss 32 | weighted by α which is set to 1 by cross val. 
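        In this repository's semi-supervised variant, forward() additionally
        receives per-image `semis` flags (1 = human-labeled, 2 =
        pseudo-labeled, 0 = unlabeled), which the loss uses below to separate
        ground-truth images from pseudo-labeled ones when building the
        confidence-loss masks.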
--------------------------------------------------------------------------------
/layers/modules/multibox_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 | 
7 | 
8 | # -*- coding: utf-8 -*-
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.autograd import Variable
13 | from data import voc300 as cfg
14 | from ..box_utils import match, log_sum_exp
15 | import numpy as np
16 | 
17 | 
18 | class MultiBoxLoss(nn.Module):
19 |     """SSD Weighted Loss Function
20 |     Compute Targets:
21 |         1) Produce Confidence Target Indices by matching ground truth boxes
22 |            with (default) 'priorboxes' that have jaccard index > threshold parameter
23 |            (default threshold: 0.5).
24 |         2) Produce localization target by 'encoding' variance into offsets of ground
25 |            truth boxes and their matched 'priorboxes'.
26 |         3) Hard negative mining to filter the excessive number of negative examples
27 |            that comes with using a large number of default bounding boxes.
28 |            (default negative:positive ratio 3:1)
29 |     Objective Loss:
30 |         L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
31 |         Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
32 |         weighted by α which is set to 1 by cross val.
33 |     Args:
34 |         c: class confidences,
35 |         l: predicted boxes,
36 |         g: ground truth boxes
37 |         N: number of matched default boxes
38 |         See: https://arxiv.org/pdf/1512.02325.pdf for more details.
39 |     """
40 | 
41 |     def __init__(self, num_classes, overlap_thresh, prior_for_matching,
42 |                  bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
43 |                  use_gpu=True):
44 |         super(MultiBoxLoss, self).__init__()
45 |         self.use_gpu = use_gpu
46 |         self.num_classes = num_classes
47 |         self.threshold = overlap_thresh
48 |         self.background_label = bkg_label
49 |         self.encode_target = encode_target
50 |         self.use_prior_for_matching = prior_for_matching
51 |         self.do_neg_mining = neg_mining
52 |         self.negpos_ratio = neg_pos
53 |         self.neg_overlap = neg_overlap
54 |         self.variance = cfg['variance']
55 | 
56 |     def forward(self, predictions, targets, semis=[]):
57 |         """Multibox Loss
58 |         Args:
59 |             predictions (tuple): A tuple containing loc preds, conf preds,
60 |             and prior boxes from SSD net.
61 |                 conf shape: torch.size(batch_size,num_priors,num_classes)
62 |                 loc shape: torch.size(batch_size,num_priors,4)
63 |                 priors shape: torch.size(num_priors,4)
64 | 
65 |             targets (tensor): Ground truth boxes and labels for a batch,
66 |                 shape: [batch_size,num_objs,5] (last idx is the label).
67 |         """
68 |         loc_data, conf_data, priors = predictions
69 |         num = loc_data.size(0)
70 |         priors = priors[:loc_data.size(1), :]
71 |         num_priors = (priors.size(0))
72 |         num_classes = self.num_classes
73 | 
74 |         # match priors (default boxes) and ground truth boxes
75 |         loc_t = torch.Tensor(num, num_priors, 4)
76 |         conf_t = torch.LongTensor(num, num_priors)
77 |         for idx in range(num):
78 |             truths = targets[idx][:, :-1].data
79 |             labels = targets[idx][:, -1].data
80 |             defaults = priors.data
81 |             match(self.threshold, truths, defaults, self.variance, labels,
82 |                   loc_t, conf_t, idx)
83 |         if self.use_gpu:
84 |             loc_t = loc_t.cuda()
85 |             conf_t = conf_t.cuda()
86 |         # wrap targets
87 |         loc_t = Variable(loc_t, requires_grad=False)
88 |         conf_t = Variable(conf_t, requires_grad=False)
89 | 
90 |         pos = conf_t > 0
91 |         num_pos = pos.sum(dim=1, keepdim=True)
92 | 
93 |         # Localization Loss (Smooth L1)
94 |         # Shape: [batch,num_priors,4]
95 |         pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
96 |         loc_p = loc_data[pos_idx].view(-1, 4)
97 |         loc_t = loc_t[pos_idx].view(-1, 4)
98 |         loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
99 | 
100 |         # Compute max conf across batch for hard negative mining
101 |         batch_conf = conf_data.view(-1, self.num_classes)
102 |         loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
103 | 
104 |         # Hard Negative Mining
105 |         loss_c = loss_c.view(pos.size()[0], pos.size()[1])
106 |         loss_c[pos] = 0  # filter out pos boxes for now
107 |         loss_c = loss_c.view(num, -1)
108 |         _, loss_idx = loss_c.sort(1, descending=True)
109 |         _, idx_rank = loss_idx.sort(1)
110 |         num_pos = pos.long().sum(1, keepdim=True)
111 |         num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
112 |         neg = idx_rank < num_neg.expand_as(idx_rank)
113 | 
114 |         # Confidence Loss Including Positive and Negative Examples
115 |         pos_idx = pos.unsqueeze(2).expand_as(conf_data)
116 |         neg_idx = neg.unsqueeze(2).expand_as(conf_data)
117 | 
118 |         real_labels, pseudo_labels = [False] * len(semis), [False] * len(semis)  # semis: 1 = human label, 2 = pseudo-label
119 |         for i, el in enumerate(semis):
120 |             if el == torch.FloatTensor([1.]):
121 |                 real_labels[i] = True
122 |             elif el == torch.FloatTensor([2.]):
123 |                 pseudo_labels[i] = True
124 | 
125 |         pos_idx_labels = pos_idx[real_labels]
126 |         neg_idx = neg_idx[real_labels]
127 |         conf_data_real = conf_data[real_labels]
128 |         conf_t_real = conf_t[real_labels]
129 |         pos_real = pos[real_labels]
130 |         neg_real = neg[real_labels]
131 |         pos_idx_pseudo = pos_idx[pseudo_labels]
132 |         conf_data_pseudo = conf_data[pseudo_labels]
133 |         conf_t_pseudo = conf_t[pseudo_labels]
134 |         pos_pseudo = pos[pseudo_labels]
135 | 
136 |         conf_p_labels = conf_data_real[(pos_idx_labels+neg_idx).gt(0)].view(-1, self.num_classes)
137 |         conf_p_pseudo = conf_data_pseudo[(pos_idx_pseudo).gt(0)].view(-1, self.num_classes)
138 |         targets_weighted_labels = conf_t_real[(pos_real + neg_real).gt(0)]
139 |         targets_weighted_pseudo = conf_t_pseudo[(pos_pseudo).gt(0)]
140 | 
141 |         loss_c_real = F.cross_entropy(conf_p_labels, targets_weighted_labels, reduction='sum')
142 |         if targets_weighted_pseudo.shape[0] > 0:
143 |             loss_c_pseudo = F.cross_entropy(conf_p_pseudo, targets_weighted_pseudo, reduction='sum')
144 |             loss_c = loss_c_real + loss_c_pseudo
145 |         else:
146 |             loss_c = loss_c_real
147 | 
148 |         N = num_pos.data.sum()
149 |         loss_l /= N
150 |         loss_c /= N
151 |         return loss_l, loss_c
152 | 
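A minimal sketch of how train.py wires this loss (editor's illustration; the constructor call mirrors train.py, while the tensors are CPU stand-ins with SSD300 shapes, not real network outputs):

    import torch
    from layers.modules import MultiBoxLoss

    # num_classes=21 for VOC; same positional arguments as in train.py
    criterion = MultiBoxLoss(21, 0.5, True, 0, True, 3, 0.5, False, use_gpu=False)

    loc = torch.randn(4, 8732, 4)        # localization predictions
    conf = torch.randn(4, 8732, 21)      # class confidences
    priors = torch.rand(8732, 4)         # default boxes, center-offset form
    # one ground-truth box per image: [x1, y1, x2, y2, class_index]
    targets = [torch.tensor([[0.10, 0.20, 0.60, 0.80, 14.]]) for _ in range(4)]
    semis = [1, 1, 1, 1]                 # 1 = human label, 2 = pseudo-label

    loss_l, loss_c = criterion((loc, conf, priors), targets, semis)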
--------------------------------------------------------------------------------
/loaders.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 | 
7 | 
8 | import random
9 | import torch.utils.data as data
10 | 
11 | from data import *
12 | from pseudo_labels import predict_pseudo_labels
13 | from subset_sequential_sampler import SubsetSequentialSampler, BalancedSampler
14 | from utils.augmentations import SSDAugmentation
15 | 
16 | random.seed(314)
17 | 
18 | 
19 | def create_loaders(args):
20 |     indices_1 = list(range(args.num_total_images))
21 |     random.shuffle(indices_1)
22 |     labeled_set = indices_1[:args.num_initial_labeled_set]
23 |     indices = list(range(args.num_total_images))
24 |     unlabeled_set = set(indices) - set(labeled_set)
25 |     labeled_set = list(labeled_set)
26 |     unlabeled_set = list(unlabeled_set)
27 |     random.shuffle(indices)
28 | 
29 | 
30 |     print(len(indices))
31 | 
32 |     if args.dataset_name == 'voc':
33 |         supervised_dataset = VOCDetection(root=args.dataset_root, image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
34 |                                           supervised_indices=labeled_set,
35 |                                           transform=SSDAugmentation(args.cfg['min_dim'], MEANS))
36 | 
37 |         unsupervised_dataset = VOCDetection(args.dataset_root, supervised_indices=None,
38 |                                             image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
39 |                                             transform=BaseTransform(300, MEANS),
40 |                                             target_transform=VOCAnnotationTransform())
41 | 
42 |     else:
43 |         supervised_dataset = COCODetection(root=args.dataset_root, supervised_indices=labeled_set, image_set='train2014',
44 |                                            transform=SSDAugmentation(args.cfg['min_dim']),
45 |                                            target_transform=COCOAnnotationTransform(),
46 |                                            dataset_name='MS COCO')
47 | 
48 |         unsupervised_dataset = COCODetection(root=args.dataset_root, supervised_indices=None, image_set='val2014',
49 |                                              transform=SSDAugmentation(args.cfg['min_dim']),
50 |                                              target_transform=COCOAnnotationTransform(),
51 |                                              dataset_name='MS COCO')
52 | 
53 | 
54 |     supervised_data_loader = data.DataLoader(supervised_dataset, batch_size=args.batch_size,
55 |                                              num_workers=args.num_workers,
56 |                                              sampler=BalancedSampler(indices, labeled_set, unlabeled_set, ratio=1),
57 |                                              collate_fn=detection_collate,
58 |                                              pin_memory=True)
59 | 
60 |     unsupervised_data_loader = data.DataLoader(unsupervised_dataset, batch_size=1,
61 |                                                sampler=SubsetSequentialSampler(unlabeled_set),
62 |                                                num_workers=args.num_workers,
63 |                                                collate_fn=detection_collate,
64 |                                                pin_memory=True)
65 | 
66 |     return supervised_dataset, supervised_data_loader, unsupervised_dataset, unsupervised_data_loader, indices, labeled_set, unlabeled_set
67 | 
68 | 
69 | def change_loaders(args, supervised_dataset, unsupervised_dataset, labeled_set,
70 |                    unlabeled_set, indices, net, pseudo=True):  # net: path to a saved checkpoint (see main() in train.py)
71 |     print("Labeled set size: " + str(len(labeled_set)))
72 |     unsupervised_data_loader = data.DataLoader(unsupervised_dataset, batch_size=1,
73 |                                                sampler=SubsetSequentialSampler(unlabeled_set),
74 |                                                num_workers=args.num_workers,
75 |                                                collate_fn=detection_collate,
76 |                                                pin_memory=True)
77 | 
78 |     if pseudo:
79 |         if args.dataset_name == 'voc':
80 |             voc = True
81 |             num_classes = 21
82 |         else:
83 |             voc = False
84 |             num_classes = 81
85 |         pseudo_labels = predict_pseudo_labels(unlabeled_set=unlabeled_set, net_name=net,
86 |                                               threshold=args.pseudo_threshold, root=args.dataset_root,
87 |                                               voc=voc, num_classes=num_classes)
88 |     else:
89 |         pseudo_labels = {}
90 | 
91 |     if args.dataset_name == 'voc':
92 |         supervised_dataset = VOCDetection(root=args.dataset_root,
93 |                                           image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
94 |                                           supervised_indices=labeled_set,
95 |                                           transform=SSDAugmentation(args.cfg['min_dim'], MEANS),
96 |                                           pseudo_labels=pseudo_labels)
97 |     else:
98 |         supervised_dataset = COCODetection(root=args.dataset_root, supervised_indices=labeled_set, image_set='train2014',
99 |                                            transform=SSDAugmentation(args.cfg['min_dim']),
100 |                                            target_transform=COCOAnnotationTransform(),
101 |                                            dataset_name='MS COCO',
102 |                                            pseudo_labels=pseudo_labels)
103 | 
104 |     print("Changing the loaders")
105 | 
106 |     supervised_data_loader = data.DataLoader(supervised_dataset, batch_size=args.batch_size,
107 |                                              num_workers=args.num_workers,
108 |                                              sampler=BalancedSampler(indices, labeled_set, unlabeled_set, ratio=1),
109 |                                              collate_fn=detection_collate,
110 |                                              pin_memory=True)
111 | 
112 |     return supervised_data_loader, unsupervised_data_loader
113 | 
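A sketch of the attribute bag create_loaders() reads (editor's illustration; in the repo these come from train.py's argparse, the values and dataset path here are placeholders, and running it requires the VOC images on disk):

    from types import SimpleNamespace
    from data import voc300
    from loaders import create_loaders

    args = SimpleNamespace(dataset_name='voc', dataset_root='/data/VOC0712',
                           num_total_images=16551, num_initial_labeled_set=2011,
                           cfg=voc300, batch_size=16, num_workers=4)
    (sup_ds, sup_loader, unsup_ds, unsup_loader,
     indices, labeled_set, unlabeled_set) = create_loaders(args)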
--------------------------------------------------------------------------------
/pseudo_labels.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 | 
7 | 
8 | import os
9 | import sys
10 | module_path = os.path.abspath(os.path.join('..'))
11 | if module_path not in sys.path:
12 |     sys.path.append(module_path)
13 | import random
14 | import pickle
15 | from torch.autograd import Variable
16 | from ssd import build_ssd
17 | import torch.nn as nn
18 | from data import *
19 | from data import VOC_CLASSES as labels1
20 | from data import COCO_CLASSES as labels_2
21 | from collections import defaultdict
22 | 
23 | random.seed(314)
24 | torch.manual_seed(314)
25 | 
26 | 
27 | def predict_pseudo_labels(unlabeled_set, net_name, threshold=0.5, root='../tmp/VOC0712/', voc=1, num_classes=21):
28 |     labels = labels1 if voc else labels_2
29 |     if voc:
30 |         testset = VOCDetection(root=root, image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
31 |                                supervised_indices=None,
32 |                                transform=None)
33 |     else:
34 |         testset = COCODetection(root=root, image_set='train2014',
35 |                                 supervised_indices=None,
36 |                                 transform=None)
37 | 
38 |     print("Doing PL")
39 |     net = build_ssd('test', 300, num_classes)
40 |     net = nn.DataParallel(net)
41 |     net.load_state_dict(torch.load(net_name))
42 |     boxes = get_pseudo_labels(testset, net, labels, unlabeled_set=unlabeled_set, threshold=threshold, voc=voc)
43 |     return boxes
44 | 
45 | 
46 | def get_pseudo_labels(testset, net, labels, unlabeled_set=None, threshold=0.99, voc=1):
47 | 
48 |     boxes = defaultdict(list)
49 |     for ii, img_id in enumerate(unlabeled_set):
50 |         print(ii)
51 |         image = testset.pull_image(img_id)
52 |         x = cv2.resize(image, (300, 300)).astype(np.float32)
53 |         x -= (104.0, 117.0, 123.0)
54 | 
55 |         x = x.astype(np.float32)
56 |         x = x[:, :, ::-1].copy()
57 |         x = torch.from_numpy(x).permute(2, 0, 1)
58 | 
59 |         xx = Variable(x.unsqueeze(0))  # wrap tensor in Variable
60 |         if torch.cuda.is_available():
61 |             xx = xx.cuda()
62 |         y = net(xx)
63 | 
64 |         detections = y.data
65 |         # scale each detection back up to the image
66 |         scale = torch.Tensor(image.shape[1::-1]).repeat(2)
67 | 
68 |         for i in range(detections.size(1)):
69 |             j = 0
70 |             while detections[0, i, j, 0] >= threshold:
71 |                 score = detections[0, i, j, 0]
72 |                 label_name = labels[i - 1]  # class index i is offset by 1 for the background class
73 |                 pt = (detections[0, i, j, 1:5] * scale).cpu().numpy()
74 |                 j += 1
75 |                 # store as [prediction_confidence, label_id, label_name, height, width, bbox_coordinates in range [0, inf]]
76 |                 boxes[img_id].append([score.cpu().detach().item(), (i-1), label_name, image.shape[0], image.shape[1], int(pt[0]), int(pt[1]), int(pt[2]), int(pt[3])])
77 | 
78 |     return boxes
79 | 
80 | 
81 | 
82 | 
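The shape of the dict predict_pseudo_labels() returns, per the storage comment above (editor's illustration with made-up values; keys are unlabeled image indices):

    pseudo = {
        42: [[0.995, 11, 'dog', 375, 500, 48, 240, 195, 371],
             [0.991, 14, 'person', 375, 500, 8, 12, 352, 498]],
    }   # each row: [confidence, label_id, label_name, img_height, img_width, x1, y1, x2, y2]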
--------------------------------------------------------------------------------
/ssd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from layers import *
6 | from data import voc300, voc512, coco
7 | import os
8 | 
9 | 
10 | class SSD(nn.Module):
11 |     """Single Shot Multibox Architecture
12 |     The network is composed of a base VGG network followed by the
13 |     added multibox conv layers.  Each multibox layer branches into
14 |         1) conv2d for class conf scores
15 |         2) conv2d for localization predictions
16 |         3) associated priorbox layer to produce default bounding
17 |            boxes specific to the layer's feature map size.
18 |     See: https://arxiv.org/pdf/1512.02325.pdf for more details.
19 | 
20 |     Args:
21 |         phase: (string) Can be "test" or "train"
22 |         size: input image size
23 |         base: VGG16 layers for input, size of either 300 or 512
24 |         extras: extra layers that feed to multibox loc and conf layers
25 |         head: "multibox head" consists of loc and conf conv layers
26 |     """
27 | 
28 |     def __init__(self, phase, size, base, extras, head, num_classes):
29 |         super(SSD, self).__init__()
30 |         self.phase = phase
31 |         self.num_classes = num_classes
32 |         if size == 300:
33 |             self.cfg = (coco, voc300)[num_classes == 21]
34 |         else:
35 |             self.cfg = (coco, voc512)[num_classes == 21]
36 |         self.priorbox = PriorBox(self.cfg)
37 |         self.priors = Variable(self.priorbox.forward(), requires_grad=False)
38 |         self.size = size
39 | 
40 |         # SSD network
41 |         self.vgg = nn.ModuleList(base)
42 |         # Layer learns to scale the l2 normalized features from conv4_3
43 |         self.L2Norm = L2Norm(512, 20)
44 |         self.extras = nn.ModuleList(extras)
45 | 
46 |         self.loc = nn.ModuleList(head[0])
47 |         self.conf = nn.ModuleList(head[1])
48 | 
49 |         if phase == 'test':
50 |             self.softmax = nn.Softmax(dim=-1)
51 |             self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)
52 | 
53 |     def forward(self, x):
54 |         """Applies network layers and ops on input image(s) x.
55 | 
56 |         Args:
57 |             x: input image or batch of images. Shape: [batch,3,300,300].
58 | 
59 |         Return:
60 |             Depending on phase:
61 |             test:
62 |                 Variable(tensor) of output class label predictions,
63 |                 confidence score, and corresponding location predictions for
64 |                 each object detected. Shape: [batch,topk,7]
65 | 
66 |             train:
67 |                 list of concat outputs from:
68 |                     1: confidence layers, Shape: [batch*num_priors,num_classes]
69 |                     2: localization layers, Shape: [batch,num_priors*4]
70 |                     3: priorbox layers, Shape: [2,num_priors*4]
71 |         """
72 |         sources = list()
73 |         loc = list()
74 |         conf = list()
75 | 
76 |         # apply vgg up to conv4_3 relu
77 |         for k in range(23):
78 |             x = self.vgg[k](x)
79 | 
80 |         s = self.L2Norm(x)
81 |         sources.append(s)
82 | 
83 |         # apply vgg up to fc7
84 |         for k in range(23, len(self.vgg)):
85 |             x = self.vgg[k](x)
86 |         sources.append(x)
87 | 
88 |         # apply extra layers and cache source layer outputs
89 |         for k, v in enumerate(self.extras):
90 |             x = F.relu(v(x), inplace=True)
91 |             if k % 2 == 1:
92 |                 sources.append(x)
93 | 
94 |         # apply multibox head to source layers
95 |         for (x, l, c) in zip(sources, self.loc, self.conf):
96 |             loc.append(l(x).permute(0, 2, 3, 1).contiguous())
97 |             conf.append(c(x).permute(0, 2, 3, 1).contiguous())
98 | 
99 |         loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
100 |         conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
101 |         if self.phase == "test":
102 |             output = self.detect(
103 |                 loc.view(loc.size(0), -1, 4),
104 |                 self.softmax(conf.view(conf.size(0), -1,
105 |                                        self.num_classes)),
106 |                 self.priors.type(type(x.data))
107 |             )
108 |         else:
109 |             output = (
110 |                 loc.view(loc.size(0), -1, 4),
111 |                 conf.view(conf.size(0), -1, self.num_classes),
112 |                 self.priors
113 |             )
114 |         return output
115 | 
116 |     def load_weights(self, base_file):
117 |         other, ext = os.path.splitext(base_file)
118 |         if ext in ('.pkl', '.pth'):
119 |             print('Loading weights into state dict...')
120 |             self.load_state_dict(torch.load(base_file,
121 |                                  map_location=lambda storage, loc: storage))
122 |             print('Finished!')
123 |         else:
124 |             print('Sorry only .pth and .pkl files supported.')
125 | 
126 | 
127 | # This function is derived from torchvision VGG make_layers()
128 | # https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
129 | def vgg(cfg, i, batch_norm=False):
130 |     layers = []
131 |     in_channels = i
132 |     for v in cfg:
133 |         if v == 'M':
134 |             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
135 |         elif v == 'C':
136 |             layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
137 |         else:
138 |             conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
139 |             if batch_norm:
140 |                 layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
141 |             else:
142 |                 layers += [conv2d, nn.ReLU(inplace=True)]
143 |             in_channels = v
144 |     pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
145 |     conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
146 |     conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
147 |     layers += [pool5, conv6,
148 |                nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
149 |     return layers
150 | 
151 | 
152 | def add_extras(cfg, i, batch_norm=False):
153 |     # Extra layers added to VGG for feature scaling
154 |     layers = []
155 |     in_channels = i
156 |     flag = False
157 |     for k, v in enumerate(cfg):
158 |         if in_channels != 'S':
159 |             if v == 'S':
160 |                 layers += [nn.Conv2d(in_channels, cfg[k + 1],
161 |                            kernel_size=(1, 3)[flag], stride=2, padding=1)]
162 |             elif v == 'K':
163 |                 layers += [nn.Conv2d(in_channels, 256,
164 |                            kernel_size=4, stride=1, padding=1)]
165 |             else:
166 |                 layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
167 |             flag = not flag
168 |         in_channels = v
169 |     return layers
170 | 
171 | 
172 | 
173 | def multibox(vgg, extra_layers, cfg, num_classes):
174 |     loc_layers = []
175 |     conf_layers = []
176 |     vgg_source = [21, -2]
177 |     for k, v in enumerate(vgg_source):
178 |         loc_layers += [nn.Conv2d(vgg[v].out_channels,
179 |                                  cfg[k] * 4, kernel_size=3, padding=1)]
180 |         conf_layers += [nn.Conv2d(vgg[v].out_channels,
181 |                                   cfg[k] * num_classes, kernel_size=3, padding=1)]
182 |     for k, v in enumerate(extra_layers[1::2], 2):
183 |         loc_layers += [nn.Conv2d(v.out_channels, cfg[k]
184 |                                  * 4, kernel_size=3, padding=1)]
185 |         conf_layers += [nn.Conv2d(v.out_channels, cfg[k]
186 |                                   * num_classes, kernel_size=3, padding=1)]
187 |     return vgg, extra_layers, (loc_layers, conf_layers)
188 | 
189 | 
190 | base = {
191 |     '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
192 |             512, 512, 512],
193 |     '512': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
194 |             512, 512, 512],
195 | }
196 | extras = {
197 |     '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],
198 |     '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128, 'K'],
199 | }
200 | mbox = {
201 |     '300': [4, 6, 6, 6, 4, 4],  # number of boxes per feature map location
202 |     '512': [4, 6, 6, 6, 6, 4, 4],
203 | }
204 | 
205 | 
206 | 
207 | def build_ssd(phase, size=300, num_classes=21):
208 |     if phase != "test" and phase != "train":
209 |         print("ERROR: Phase: " + phase + " not recognized")
210 |         return
211 |     base_, extras_, head_ = multibox(vgg(base[str(size)], 3),
212 |                                      add_extras(extras[str(size)], 1024),
213 |                                      mbox[str(size)], num_classes)
214 |     return SSD(phase, size, base_, extras_, head_, num_classes)
215 | 
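A minimal sketch of constructing the detector in the training phase (editor's illustration; weights are untrained, the point is only the output shapes for SSD300 on VOC):

    import torch
    from ssd import build_ssd

    net = build_ssd('train', size=300, num_classes=21)
    loc, conf, priors = net(torch.randn(1, 3, 300, 300))
    print(loc.shape, conf.shape, priors.shape)
    # torch.Size([1, 8732, 4]) torch.Size([1, 8732, 21]) torch.Size([8732, 4])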
--------------------------------------------------------------------------------
/subset_sequential_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 | 
7 | 
8 | from torch.utils.data.sampler import Sampler
9 | import random
10 | import sys
11 | from copy import deepcopy
12 | 
13 | 
14 | class SubsetSequentialSampler(Sampler):
15 |     """Samples elements sequentially from a given list of indices, without replacement.
16 |     Arguments:
17 |         indices (list): a list of indices
18 |     """
19 | 
20 |     def __init__(self, indices):
21 |         self.indices = indices
22 | 
23 |     def __iter__(self):
24 |         return (self.indices[i] for i in range(len(self.indices)))
25 | 
26 |     def __len__(self):
27 |         return len(self.indices)
28 | 
29 | 
30 | class BalancedSubsetSequentialSampler(Sampler):
31 |     """Samples images in a balanced way, ensuring that exactly half of the sampled images have labels.
32 |     Arguments:
33 |         indices (list): a list of indices
34 |     """
35 | 
36 | 
37 |     def __init__(self, indices, supervised, unsupervised):
38 | 
39 |         self.indices = indices
40 |         random.shuffle(supervised)
41 |         random.shuffle(unsupervised)
42 |         self.len_supervised = len(supervised)
43 |         self.len_unsupervised = len(unsupervised)
44 |         if self.len_supervised > self.len_unsupervised:
45 |             ratio = self.len_supervised // self.len_unsupervised
46 |             module = self.len_supervised % self.len_unsupervised
47 |             self.supervised = supervised
48 |             self.unsupervised = ratio * unsupervised + unsupervised[:module]
49 |         else:
50 |             ratio = self.len_unsupervised // self.len_supervised
51 |             module = self.len_unsupervised % self.len_supervised
52 |             self.unsupervised = unsupervised
53 |             self.supervised = ratio * supervised + supervised[:module]
54 |         print(len(self.supervised), len(self.unsupervised))
55 | 
56 |     def _add_lists_alternatively(self, lst1, lst2):
57 |         random.shuffle(lst1)
58 |         random.shuffle(lst2)
59 |         return [sub[item] for item in range(len(lst2))
60 |                 for sub in [lst1, lst2]]
61 | 
62 |     def __iter__(self):
63 |         all_indices = self._add_lists_alternatively(self.supervised, self.unsupervised)
64 |         return (all_indices[i] for i in range(len(all_indices)))
65 | 
66 |     def __len__(self):
67 |         return self.len_supervised + self.len_unsupervised
68 | 
69 | 
70 | class BalancedSampler(Sampler):
71 |     """Samples one labeled index followed by `ratio` unlabeled indices, without replacement.
72 | Arguments: 73 | indices (list): a list of indices 74 | """ 75 | def __init__(self, indices, supervised, unsupervised, ratio=1): 76 | 77 | self.indices = indices 78 | self.supervised = supervised 79 | self.unsupervised = unsupervised 80 | self.new_list = [] 81 | self.ratio = ratio 82 | 83 | def _create_balanced_batch(self): 84 | copy_supervised = deepcopy(self.supervised) 85 | copy_unsupervised = deepcopy(self.unsupervised) 86 | random.shuffle(copy_supervised) 87 | random.shuffle(copy_unsupervised) 88 | self.new_list = [] 89 | while copy_supervised: 90 | self.new_list.append(copy_supervised.pop()) 91 | if len(copy_unsupervised) < self.ratio: 92 | copy_unsupervised = deepcopy(self.unsupervised) 93 | random.shuffle(copy_unsupervised) 94 | for i in range(self.ratio): 95 | self.new_list.append(copy_unsupervised.pop()) 96 | return self.new_list 97 | 98 | def __iter__(self): 99 | all_indices = self._create_balanced_batch() 100 | return (all_indices[i] for i in range(len(all_indices))) 101 | 102 | def __len__(self): 103 | return len(self.new_list) 104 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://github.com/NVlabs/AL-SSL/ 6 | 7 | 8 | from __future__ import print_function 9 | import sys 10 | import os 11 | import argparse 12 | import torch 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | import torchvision.transforms as transforms 16 | from torch.autograd import Variable 17 | from data import VOC_ROOT, VOC_CLASSES as labelmap 18 | from PIL import Image 19 | from data import VOCAnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES 20 | import torch.utils.data as data 21 | from ssd import build_ssd 22 | 23 | parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection') 24 | parser.add_argument('--trained_model', default='weights_voc_code/120000combined_id_2_pl_threshold_0.99_labeled_set_3011_.pth', 25 | type=str, help='Trained state_dict file path to open') 26 | # parser.add_argument('--trained_model', default='weights/ssd_300_VOC0712.pth', 27 | # type=str, help='Trained state_dict file path to open') 28 | parser.add_argument('--save_folder', default='eval/', type=str, 29 | help='Dir to save results') 30 | parser.add_argument('--visual_threshold', default=0.6, type=float, 31 | help='Final confidence threshold') 32 | parser.add_argument('--cuda', default=True, type=bool, 33 | help='Use cuda to train model') 34 | parser.add_argument('--voc_root', default='/usr/wiss/elezi/data/VOC0712', help='Location of VOC root directory') 35 | parser.add_argument('-f', default=None, type=str, help="Dummy arg so we can load in Jupyter Notebooks") 36 | args = parser.parse_args() 37 | 38 | if args.cuda and torch.cuda.is_available(): 39 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 40 | else: 41 | torch.set_default_tensor_type('torch.FloatTensor') 42 | 43 | if not os.path.exists(args.save_folder): 44 | os.mkdir(args.save_folder) 45 | 46 | 47 | def test_net(save_folder, net, cuda, testset, transform, thresh): 48 | # dump predictions and assoc. 
ground truth to text file for now 49 | filename = save_folder+'test1.txt' 50 | num_images = len(testset) 51 | for i in range(num_images): 52 | print('Testing image {:d}/{:d}....'.format(i+1, num_images)) 53 | img = testset.pull_image(i) 54 | img_id, annotation = testset.pull_anno(i) 55 | x = torch.from_numpy(transform(img)[0]).permute(2, 0, 1) 56 | x = Variable(x.unsqueeze(0)) 57 | 58 | with open(filename, mode='a') as f: 59 | f.write('\nGROUND TRUTH FOR: '+img_id+'\n') 60 | for box in annotation: 61 | f.write('label: '+' || '.join(str(b) for b in box)+'\n') 62 | if cuda: 63 | x = x.cuda() 64 | 65 | y = net(x) # forward pass 66 | detections = y.data 67 | # scale each detection back up to the image 68 | scale = torch.Tensor([img.shape[1], img.shape[0], 69 | img.shape[1], img.shape[0]]) 70 | pred_num = 0 71 | for i in range(detections.size(1)): 72 | j = 0 73 | while detections[0, i, j, 0] >= 0.6: 74 | if pred_num == 0: 75 | with open(filename, mode='a') as f: 76 | f.write('PREDICTIONS: '+'\n') 77 | score = detections[0, i, j, 0] 78 | label_name = labelmap[i-1] 79 | pt = (detections[0, i, j, 1:]*scale).cpu().numpy() 80 | coords = (pt[0], pt[1], pt[2], pt[3]) 81 | pred_num += 1 82 | with open(filename, mode='a') as f: 83 | f.write(str(pred_num)+' label: '+label_name+' score: ' + 84 | str(score) + ' '+' || '.join(str(c) for c in coords) + '\n') 85 | j += 1 86 | 87 | 88 | def test_voc(): 89 | # load net 90 | num_classes = len(VOC_CLASSES) + 1 # +1 background 91 | net = build_ssd('test', 300, num_classes) # initialize SSD 92 | net = nn.DataParallel(net) 93 | net.load_state_dict(torch.load(args.trained_model)) 94 | net.eval() 95 | print('Finished loading model!') 96 | # load data 97 | testset = VOCDetection(args.voc_root, image_sets=[('2007', 'test')], 98 | transform=None, target_transform=VOCAnnotationTransform()) 99 | if args.cuda: 100 | net = net.cuda() 101 | cudnn.benchmark = True 102 | # evaluation 103 | test_net(args.save_folder, net, args.cuda, testset, 104 | BaseTransform(300, (104, 117, 123)), 105 | thresh=args.visual_threshold) 106 | 107 | if __name__ == '__main__': 108 | test_voc() 109 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 | 
7 | 
8 | import warnings
9 | import argparse
10 | import math
11 | import numpy as np
12 | import random
13 | import os
14 | import time
15 | import torch
16 | import torch.nn as nn
17 | import torch.optim as optim
18 | import torch.backends.cudnn as cudnn
19 | import torch.nn.init as init
20 | 
21 | from active_learning import combined_score, active_learning_inconsistency, active_learning_entropy
22 | from csd import build_ssd_con
23 | from data import *
24 | from layers.modules import MultiBoxLoss
25 | from loaders import create_loaders, change_loaders
26 | 
27 | random.seed(314)
28 | torch.manual_seed(314)
29 | 
30 | warnings.filterwarnings("ignore")
31 | 
32 | 
33 | def str2bool(v):
34 |     return v.lower() in ("yes", "true", "t", "1")
35 | 
36 | 
37 | class get_al_hyperparams():
38 |     def __init__(self, dataset_name='voc'):
39 |         self.dataset_name = dataset_name
40 |         self.dataset_path = {'voc': '/usr/wiss/elezi/data/VOC0712',
41 |                              'coco': '/usr/wiss/elezi/data/coco'}
42 | 
43 |         self.num_ims = {'voc': 16551, 'coco': 82081}
44 |         self.num_init = {'voc': 2011, 'coco': 5000}
45 |         self.pseudo_threshold = {'voc': 0.99, 'coco': 0.75}
46 |         self.config = {'voc': voc300, 'coco': coco}
47 |         self.batch_size = {'voc': 16, 'coco': 32}
48 | 
49 |     def get_dataset_path(self):
50 |         return self.dataset_path[self.dataset_name]
51 | 
52 |     def get_num_ims(self):
53 |         return self.num_ims[self.dataset_name]
54 | 
55 |     def get_num_init(self):
56 |         return self.num_init[self.dataset_name]
57 | 
58 |     def get_pseudo_threshold(self):
59 |         return self.pseudo_threshold[self.dataset_name]
60 | 
61 |     def get_config(self):
62 |         return self.config[self.dataset_name]
63 | 
64 |     def get_dataset_name(self):
65 |         return self.dataset_name
66 | 
67 |     def get_batch_size(self):
68 |         return self.batch_size[self.dataset_name]
69 | 
70 | 
71 | 
72 | parser = argparse.ArgumentParser(
73 |     description='Single Shot MultiBox Detector Training With Pytorch')
74 | train_set = parser.add_mutually_exclusive_group()
75 | al_hyperparams = get_al_hyperparams('voc')  # 'voc' for voc, 'coco' for coco
76 | parser.add_argument('--dataset_name', default=al_hyperparams.get_dataset_name(), type=str,
77 |                     help='Dataset name')
78 | parser.add_argument('--cfg', default=al_hyperparams.get_config(), type=dict,
79 |                     help='configuration for the specific dataset')
80 | parser.add_argument('--dataset', default='VOC300', choices=['VOC300', 'VOC512'],
81 |                     type=str, help='VOC300 or VOC512')
82 | parser.add_argument('--dataset_root', default=al_hyperparams.get_dataset_path(), type=str,
83 |                     help='Dataset root directory path')
84 | parser.add_argument('--num_total_images', default=al_hyperparams.get_num_ims(), type=int,
85 |                     help='Number of images in the dataset')
86 | parser.add_argument('--num_initial_labeled_set', default=al_hyperparams.get_num_init(), type=int,
87 |                     help='Number of initially labeled images')
88 | parser.add_argument('--acquisition_budget', default=1000, type=int,
89 |                     help='Active labeling cycle budget')
90 | parser.add_argument('--num_cycles', default=5, type=int,
91 |                     help='Number of active learning cycles')
92 | parser.add_argument('--criterion_select', default='combined',
93 |                     choices=['random', 'entropy', 'consistency', 'combined'],
94 |                     help='Active learning acquisition score')
95 | parser.add_argument('--filter_entropy_num', default=3000, type=int,
96 |                     help='How many samples to pre-filter with entropy')
97 | parser.add_argument('--id', default=2, type=int,
98 |                     help='the id of the experiment')
99 | parser.add_argument('--basenet', default='vgg16_reducedfc.pth',
100 |                     help='Pretrained base model')
101 | parser.add_argument('--batch_size', default=al_hyperparams.get_batch_size(), type=int,
102 |                     help='Batch size for training')
103 | parser.add_argument('--num_workers', default=8, type=int,
104 |                     help='Number of workers used in dataloading')
105 | parser.add_argument('--cuda', default=True, type=str2bool,
106 |                     help='Use CUDA to train model')
107 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
108 |                     help='initial learning rate')
109 | parser.add_argument('--momentum', default=0.9, type=float,
110 |                     help='Momentum value for optim')
111 | parser.add_argument('--weight_decay', default=5e-4, type=float,
112 |                     help='Weight decay for SGD')
113 | parser.add_argument('--gamma', default=0.1, type=float,
114 |                     help='Gamma update for SGD')
115 | parser.add_argument('--save_folder', default='../al_ssl/weights/',
116 |                     help='Directory for saving checkpoint models')
117 | parser.add_argument('--net_name', default=None, type=str,
118 |                     help='the network checkpoint to load')
119 | parser.add_argument('--is_apex', default=0, type=int,
120 |                     help='if 1 use apex to do mixed precision training')
121 | parser.add_argument('--is_cluster', default=1, type=int,
122 |                     help='if 1 use GPU cluster, otherwise do the computations on the local PC')
123 | parser.add_argument('--do_PL', default=1, type=int,
124 |                     help='if 1 use pseudo-labels, otherwise do not use them')
125 | parser.add_argument('--pseudo_threshold', default=al_hyperparams.get_pseudo_threshold(), type=float,
126 |                     help='pseudo label confidence threshold for voc dataset')
127 | parser.add_argument('--thresh', default=0.5, type=float,
128 |                     help='we define an object if the probability of one class is above thresh')
129 | parser.add_argument('--do_AL', default=1, type=int, help='if 0 skip AL')
130 | args = parser.parse_args()
131 | 
132 | if torch.cuda.is_available():
133 |     if args.cuda:
134 |         torch.set_default_tensor_type('torch.cuda.FloatTensor')
135 |     if not args.cuda:
136 |         print("WARNING: It looks like you have a CUDA device, but aren't " +
137 |               "using CUDA.\nRun with --cuda for optimal training speed.")
138 |         torch.set_default_tensor_type('torch.FloatTensor')
139 |     cudnn.benchmark = True
140 | else:
141 |     torch.set_default_tensor_type('torch.FloatTensor')
142 | 
143 | if not os.path.exists(args.save_folder):
144 |     os.mkdir(args.save_folder)
145 | 
146 | 
147 | def load_net_optimizer_multi(cfg):
148 |     net = build_ssd_con('train', cfg['min_dim'], cfg['num_classes'])
149 |     vgg_weights = torch.load(args.save_folder + args.basenet)
150 |     print('Loading the backbone pretrained in Imagenet...')
151 |     net.vgg.load_state_dict(vgg_weights)
152 |     net.extras.apply(weights_init)
153 |     net.loc.apply(weights_init)
154 |     net.conf.apply(weights_init)
155 |     if args.is_cluster:
156 |         net = nn.DataParallel(net)
157 |     optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
158 |                           weight_decay=args.weight_decay)
159 |     return net, optimizer
160 | 
161 | 
162 | def compute_consistency_loss(conf, loc, conf_flip, loc_flip, conf_consistency_criterion):
163 |     conf_class = conf[:, :, 1:].clone()
164 |     background_score = conf[:, :, 0].clone()
165 |     each_val, each_index = torch.max(conf_class, dim=2)
166 |     mask_val = each_val > background_score
167 |     mask_val = mask_val.data
168 | 
169 |     mask_conf_index = mask_val.unsqueeze(2).expand_as(conf)
170 |     mask_loc_index = mask_val.unsqueeze(2).expand_as(loc)
171 | 
172 |     conf_mask_sample = conf.clone()
173 |     loc_mask_sample = loc.clone()
174 |     conf_sampled = conf_mask_sample[mask_conf_index].view(-1, args.cfg['num_classes'])
175 |     loc_sampled = loc_mask_sample[mask_loc_index].view(-1, 4)
176 | 
177 |     conf_mask_sample_flip = conf_flip.clone()
178 |     loc_mask_sample_flip = loc_flip.clone()
179 |     conf_sampled_flip = conf_mask_sample_flip[mask_conf_index].view(-1, args.cfg['num_classes'])
180 |     loc_sampled_flip = loc_mask_sample_flip[mask_loc_index].view(-1, 4)
181 | 
182 |     if (mask_val.sum() > 0):
183 |         # Compute Jensen-Shannon divergence (symmetric KL, actually)
184 |         conf_sampled_flip = conf_sampled_flip + 1e-7
185 |         conf_sampled = conf_sampled + 1e-7
186 |         consistency_conf_loss_a = conf_consistency_criterion(conf_sampled.log(),
187 |                                                              conf_sampled_flip.detach()).sum(-1).mean()
188 |         consistency_conf_loss_b = conf_consistency_criterion(conf_sampled_flip.log(),
189 |                                                              conf_sampled.detach()).sum(-1).mean()
190 |         consistency_conf_loss = consistency_conf_loss_a + consistency_conf_loss_b
191 | 
192 |         # Compute location consistency loss
193 |         consistency_loc_loss_x = torch.mean(torch.pow(loc_sampled[:, 0] + loc_sampled_flip[:, 0], exponent=2))  # x-offsets swap sign under a horizontal flip, hence the sum
194 |         consistency_loc_loss_y = torch.mean(torch.pow(loc_sampled[:, 1] - loc_sampled_flip[:, 1], exponent=2))
195 |         consistency_loc_loss_w = torch.mean(torch.pow(loc_sampled[:, 2] - loc_sampled_flip[:, 2], exponent=2))
196 |         consistency_loc_loss_h = torch.mean(torch.pow(loc_sampled[:, 3] - loc_sampled_flip[:, 3], exponent=2))
197 | 
198 |         consistency_loc_loss = torch.div(
199 |             consistency_loc_loss_x + consistency_loc_loss_y + consistency_loc_loss_w + consistency_loc_loss_h,
200 |             4)
201 | 
202 |     else:
203 |         consistency_conf_loss = torch.cuda.FloatTensor([0])
204 |         consistency_loc_loss = torch.cuda.FloatTensor([0])
205 | 
206 |     consistency_loss = torch.div(consistency_conf_loss, 2) + consistency_loc_loss
207 |     return consistency_loss
208 | 
209 | 
210 | def rampweight(iteration):
211 |     ramp_up_end = 32000
212 |     ramp_down_start = 100000
213 | 
214 |     if (iteration < ramp_up_end):
215 |         ramp_weight = math.exp(-5 * math.pow((1 - iteration / ramp_up_end), 2))
216 |     elif (iteration > ramp_down_start):
217 |         ramp_weight = math.exp(-12.5 * math.pow((1 - (120000 - iteration) / 20000), 2))
218 |     else:
219 |         ramp_weight = 1
220 | 
221 |     if (iteration == 0):
222 |         ramp_weight = 0
223 | 
224 |     return ramp_weight
225 | 
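# Editor's arithmetic check of rampweight() above (illustrative, not part of the repo):
#   iteration 0                -> 0.0 (forced)
#   iteration 16000 (ramp-up)  -> exp(-5 * (1 - 0.5)**2) = exp(-1.25) ~ 0.29
#   iteration 32000 .. 100000  -> 1.0
#   iteration 110000 (ramp-down) -> exp(-12.5 * (1 - 0.5)**2) = exp(-3.125) ~ 0.04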
226 | 
227 | def train(dataset, data_loader, cfg, labeled_set, supervised_dataset, indices):
228 |     # net, optimizer = load_net_optimizer_multi(cfg)
229 |     criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,
230 |                              False, args.cuda)
231 |     conf_consistency_criterion = torch.nn.KLDivLoss(reduction='none').cuda()
232 | 
233 |     # loss counters
234 |     print('Loading the dataset...')
235 |     print('Training SSD on:', dataset.name)
236 |     print('Using the specified args:')
237 | 
238 |     step_index = 0
239 | 
240 |     # create batch iterator
241 |     batch_iterator = iter(data_loader)
242 | 
243 |     finish_flag = True
244 | 
245 |     while finish_flag:
246 |         net, optimizer = load_net_optimizer_multi(cfg)
247 |         net.train()
248 |         for iteration in range(cfg['max_iter']):
249 |             print(iteration)
250 | 
251 |             if iteration in cfg['lr_steps']:
252 |                 step_index += 1
253 |                 adjust_learning_rate(optimizer, args.gamma, step_index)
254 |             try:
255 |                 images, targets, semis = next(batch_iterator)
256 |             except StopIteration:
257 |                 batch_iterator = iter(data_loader)
258 |                 images, targets, semis = next(batch_iterator)
259 | 
260 |             images = images.cuda()
261 |             targets = [ann.cuda() for ann in targets]
262 | 
263 |             # forward
264 |             t0 = time.time()
265 |             out, conf, conf_flip, loc, loc_flip, _ = net(images)
266 |             sup_image_binary_index = np.zeros([len(semis), 1])
267 | 
268 |             semis_index, new_semis = [], []
269 |             for super_image in range(len(semis)):
270 |                 new_semis.append(int(semis[super_image]))
271 |                 if (int(semis[super_image]) > 0):
272 |                     sup_image_binary_index[super_image] = 1
273 |                     semis_index.append(super_image)
274 |                 else:
275 |                     sup_image_binary_index[super_image] = 0
276 | 
277 |                 if (int(semis[len(semis) - 1 - super_image]) == 0):  # walk from the back so deletions keep indices valid
278 |                     del targets[len(semis) - 1 - super_image]
279 | 
280 |             sup_image_index = np.where(sup_image_binary_index == 1)[0]
281 |             loc_data, conf_data, priors = out
282 | 
283 |             if (len(sup_image_index) != 0):
284 |                 loc_data = loc_data[sup_image_index, :, :]
285 |                 conf_data = conf_data[sup_image_index, :, :]
286 |                 output = (loc_data, conf_data, priors)
287 | 
288 |             consistency_loss = compute_consistency_loss(conf, loc, conf_flip, loc_flip, conf_consistency_criterion)
289 |             ramp_weight = rampweight(iteration)
290 |             consistency_loss = torch.mul(consistency_loss, ramp_weight)
291 | 
292 |             if (len(sup_image_index) == 0):
293 |                 loss_l, loss_c = torch.cuda.FloatTensor([0]), torch.cuda.FloatTensor([0])
294 |             else:
295 |                 loss_l, loss_c = criterion(output, targets, np.array(new_semis)[semis_index])
296 |             loss = loss_l + loss_c + consistency_loss
297 |             print(loss)
298 | 
299 |             if (loss.data > 0):
300 |                 optimizer.zero_grad()
301 |                 loss.backward()
302 |                 optimizer.step()
303 |             else:
304 |                 print("Loss is 0")
305 | 
306 |             if (float(loss) > 100 or torch.isnan(loss)):
307 |                 # if the net diverges, go back to point 0 and train from scratch
308 |                 break
309 |             t1 = time.time()
310 | 
311 |             if iteration % 10 == 0:
312 |                 print('timer: %.4f sec.' % (t1 - t0))
313 |                 print('iter ' + repr(
314 |                     iteration) + ': loss: %.4f , loss_c: %.4f , loss_l: %.4f , loss_con: %.4f, lr : %.4f, super_len : %d\n' % (
315 |                     loss.data, loss_c.data, loss_l.data, consistency_loss.data,
316 |                     float(optimizer.param_groups[0]['lr']),
317 |                     len(sup_image_index)))
318 | 
319 |             if iteration != 0 and (iteration + 1) % 120000 == 0:
320 |                 print('Saving state, iter:', iteration)
321 |                 net_name = 'weights/' + repr(iteration + 1) + args.criterion_select + '_id_' + str(args.id) + \
322 |                            '_pl_threshold_' + str(args.pseudo_threshold) + '_labeled_set_' + str(len(labeled_set)) + '_.pth'
323 |                 torch.save(net.state_dict(), net_name)
324 | 
325 |             if iteration >= 119000:
326 |                 finish_flag = False
327 |                 return net, net_name
328 | 
329 | 
330 | def adjust_learning_rate(optimizer, gamma, step):
331 |     """Sets the learning rate to the initial LR decayed by 10 at every
332 |        specified step
333 |     # Adapted from PyTorch Imagenet example:
334 |     # https://github.com/pytorch/examples/blob/master/imagenet/main.py
335 |     """
336 |     lr = args.lr * (gamma ** (step))
337 |     for param_group in optimizer.param_groups:
338 |         param_group['lr'] = lr
339 | 
340 | 
341 | def xavier(param):
342 |     init.xavier_uniform_(param)
343 | 
344 | 
345 | def weights_init(m):
346 |     if isinstance(m, nn.Conv2d):
347 |         xavier(m.weight.data)
348 |         m.bias.data.zero_()
349 | 
350 | 
351 | def main():
352 |     print(os.path.abspath(os.getcwd()))
353 |     if args.cuda: cudnn.benchmark = True
354 |     supervised_dataset, supervised_data_loader, unsupervised_dataset, unsupervised_data_loader, indices, labeled_set, unlabeled_set = create_loaders(args)
355 |     net, net_name = train(supervised_dataset, supervised_data_loader, args.cfg, labeled_set, supervised_dataset, indices)
356 | 
357 |     net, _ = load_net_optimizer_multi(args.cfg)
358 |     if not args.is_cluster:
359 |         net = nn.DataParallel(net)
360 | 
361 |     # net_name = os.path.join('/usr/wiss/elezi/PycharmProjects/al_ssl/weights_good/120000combined_id_2_pl_threshold_0.99_labeled_set_3011_.pth')
362 |     # net.load_state_dict(torch.load(net_name))
363 | 
364 |     # do active learning cycles
365 |     for i in range(args.num_cycles):
366 |         net.eval()
367 | 
368 |         if args.do_AL:
369 |             if args.criterion_select in ['Max_aver', 'entropy', 'random']:  # note: only 'entropy'/'random' are reachable given the parser choices
370 |                 batch_iterator = iter(unsupervised_data_loader)
371 |                 labeled_set, unlabeled_set = active_learning_entropy(args, batch_iterator, labeled_set, unlabeled_set, net,
372 |                                                                      args.cfg['num_classes'],
373 |                                                                      args.criterion_select,
374 |                                                                      loader=unsupervised_data_loader)
375 |             elif args.criterion_select == 'consistency':
376 |                 batch_iterator = iter(unsupervised_data_loader)
377 |                 labeled_set, unlabeled_set = active_learning_inconsistency(args, batch_iterator, labeled_set, unlabeled_set, net,
378 |                                                                            args.cfg['num_classes'],
379 |                                                                            args.criterion_select,
380 |                                                                            loader=unsupervised_data_loader)
381 |             elif args.criterion_select == 'combined':
382 |                 print("Combined")
383 |                 batch_iterator = iter(unsupervised_data_loader)
384 |                 labeled_set, unlabeled_set = combined_score(args, batch_iterator, labeled_set, unlabeled_set, net,
385 |                                                             unsupervised_data_loader)
386 | 
387 |         supervised_data_loader, unsupervised_data_loader = change_loaders(args, supervised_dataset,
388 |                                                                           unsupervised_dataset, labeled_set, unlabeled_set, indices, net_name, pseudo=args.do_PL)
389 |         net, net_name = train(supervised_dataset, supervised_data_loader, args.cfg, labeled_set, supervised_dataset,
390 |                               indices)
391 | 
392 | 
393 | if __name__ == '__main__':
394 |     main()
395 | 
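A typical invocation, as a hedged sketch (the dataset path is a placeholder; all flags are defined in the parser above):

    python train.py --dataset_name voc --dataset_root /data/VOC0712 \
                    --criterion_select combined --do_PL 1 --num_cycles 5 --acquisition_budget 1000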
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .augmentations import SSDAugmentation, jaccard_numpy
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/augmentations.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/utils/__pycache__/augmentations.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/augmentations.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/utils/__pycache__/augmentations.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/augmentations.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/utils/__pycache__/augmentations.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/augmentations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | import cv2
4 | import numpy as np
5 | import types
6 | from numpy import random
7 | 
8 | 
9 | def intersect(box_a, box_b):
10 |     max_xy = np.minimum(box_a[:, 2:], box_b[2:])
11 |     min_xy = np.maximum(box_a[:, :2], box_b[:2])
12 |     inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf)
13 |     return inter[:, 0] * inter[:, 1]
14 | 
15 | 
16 | def jaccard_numpy(box_a, box_b):
17 |     """Compute the jaccard overlap of two sets of boxes.  The jaccard overlap
18 |     is simply the intersection over union of two boxes.
19 |     E.g.:
20 |         A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
21 |     Args:
22 |         box_a: Multiple bounding boxes, Shape: [num_boxes,4]
23 |         box_b: Single bounding box, Shape: [4]
24 |     Return:
25 |         jaccard overlap: Shape: [box_a.shape[0]]
26 |     """
27 |     inter = intersect(box_a, box_b)
28 |     area_a = ((box_a[:, 2]-box_a[:, 0]) *
29 |               (box_a[:, 3]-box_a[:, 1]))  # [A,B]
30 |     area_b = ((box_b[2]-box_b[0]) *
31 |               (box_b[3]-box_b[1]))  # [A,B]
32 |     union = area_a + area_b - inter
33 |     return inter / union  # [A,B]
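# Editor's check of jaccard_numpy above (illustrative, not part of the repo):
#   >>> boxes = np.array([[0., 0., 10., 10.], [20., 20., 30., 30.]])
#   >>> jaccard_numpy(boxes, np.array([0., 0., 10., 10.]))
#   array([1., 0.])   # identical boxes give IoU 1, disjoint boxes give 0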
34 | 
35 | 
36 | class Compose(object):
37 |     """Composes several augmentations together.
38 |     Args:
39 |         transforms (List[Transform]): list of transforms to compose.
40 |     Example:
41 |         >>> augmentations.Compose([
42 |         >>>     transforms.CenterCrop(10),
43 |         >>>     transforms.ToTensor(),
44 |         >>> ])
45 |     """
46 | 
47 |     def __init__(self, transforms):
48 |         self.transforms = transforms
49 | 
50 |     def __call__(self, img, boxes=None, labels=None):
51 |         for t in self.transforms:
52 |             img, boxes, labels = t(img, boxes, labels)
53 |         return img, boxes, labels
54 | 
55 | 
56 | class Lambda(object):
57 |     """Applies a lambda as a transform."""
58 | 
59 |     def __init__(self, lambd):
60 |         assert isinstance(lambd, types.LambdaType)
61 |         self.lambd = lambd
62 | 
63 |     def __call__(self, img, boxes=None, labels=None):
64 |         return self.lambd(img, boxes, labels)
65 | 
66 | 
67 | class ConvertFromInts(object):
68 |     def __call__(self, image, boxes=None, labels=None):
69 |         return image.astype(np.float32), boxes, labels
70 | 
71 | 
72 | class SubtractMeans(object):
73 |     def __init__(self, mean):
74 |         self.mean = np.array(mean, dtype=np.float32)
75 | 
76 |     def __call__(self, image, boxes=None, labels=None):
77 |         image = image.astype(np.float32)
78 |         image -= self.mean
79 |         return image.astype(np.float32), boxes, labels
80 | 
81 | 
82 | class ToAbsoluteCoords(object):
83 |     def __call__(self, image, boxes=None, labels=None):
84 |         height, width, channels = image.shape
85 |         boxes[:, 0] *= width
86 |         boxes[:, 2] *= width
87 |         boxes[:, 1] *= height
88 |         boxes[:, 3] *= height
89 | 
90 |         return image, boxes, labels
91 | 
92 | 
93 | class ToPercentCoords(object):
94 |     def __call__(self, image, boxes=None, labels=None):
95 |         height, width, channels = image.shape
96 |         boxes[:, 0] /= width
97 |         boxes[:, 2] /= width
98 |         boxes[:, 1] /= height
99 |         boxes[:, 3] /= height
100 | 
101 |         return image, boxes, labels
102 | 
103 | 
104 | class Resize(object):
105 |     def __init__(self, size=300):
106 |         self.size = size
107 | 
108 |     def __call__(self, image, boxes=None, labels=None):
109 |         image = cv2.resize(image, (self.size,
110 |                                    self.size))
111 |         return image, boxes, labels
112 | 
113 | 
114 | class RandomSaturation(object):
115 |     def __init__(self, lower=0.5, upper=1.5):
116 |         self.lower = lower
117 |         self.upper = upper
118 |         assert self.upper >= self.lower, "saturation upper must be >= lower."
119 |         assert self.lower >= 0, "saturation lower must be non-negative."
120 | 121 | def __call__(self, image, boxes=None, labels=None): 122 | if random.randint(2): 123 | image[:, :, 1] *= random.uniform(self.lower, self.upper) 124 | 125 | return image, boxes, labels 126 | 127 | 128 | class RandomHue(object): 129 | def __init__(self, delta=18.0): 130 | assert delta >= 0.0 and delta <= 360.0 131 | self.delta = delta 132 | 133 | def __call__(self, image, boxes=None, labels=None): 134 | if random.randint(2): 135 | image[:, :, 0] += random.uniform(-self.delta, self.delta) 136 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 137 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 138 | return image, boxes, labels 139 | 140 | 141 | class RandomLightingNoise(object): 142 | def __init__(self): 143 | self.perms = ((0, 1, 2), (0, 2, 1), 144 | (1, 0, 2), (1, 2, 0), 145 | (2, 0, 1), (2, 1, 0)) 146 | 147 | def __call__(self, image, boxes=None, labels=None): 148 | if random.randint(2): 149 | swap = self.perms[random.randint(len(self.perms))] 150 | shuffle = SwapChannels(swap) # shuffle channels 151 | image = shuffle(image) 152 | return image, boxes, labels 153 | 154 | 155 | class ConvertColor(object): 156 | def __init__(self, current='BGR', transform='HSV'): 157 | self.transform = transform 158 | self.current = current 159 | 160 | def __call__(self, image, boxes=None, labels=None): 161 | if self.current == 'BGR' and self.transform == 'HSV': 162 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 163 | elif self.current == 'HSV' and self.transform == 'BGR': 164 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 165 | else: 166 | raise NotImplementedError 167 | return image, boxes, labels 168 | 169 | 170 | class RandomContrast(object): 171 | def __init__(self, lower=0.5, upper=1.5): 172 | self.lower = lower 173 | self.upper = upper 174 | assert self.upper >= self.lower, "contrast upper must be >= lower." 175 | assert self.lower >= 0, "contrast lower must be non-negative." 176 | 177 | # expects float image 178 | def __call__(self, image, boxes=None, labels=None): 179 | if random.randint(2): 180 | alpha = random.uniform(self.lower, self.upper) 181 | image *= alpha 182 | return image, boxes, labels 183 | 184 | 185 | class RandomBrightness(object): 186 | def __init__(self, delta=32): 187 | assert delta >= 0.0 188 | assert delta <= 255.0 189 | self.delta = delta 190 | 191 | def __call__(self, image, boxes=None, labels=None): 192 | if random.randint(2): 193 | delta = random.uniform(-self.delta, self.delta) 194 | image += delta 195 | return image, boxes, labels 196 | 197 | 198 | class ToCV2Image(object): 199 | def __call__(self, tensor, boxes=None, labels=None): 200 | return tensor.cpu().numpy().astype(np.float32).transpose((1, 2, 0)), boxes, labels 201 | 202 | 203 | class ToTensor(object): 204 | def __call__(self, cvimage, boxes=None, labels=None): 205 | return torch.from_numpy(cvimage.astype(np.float32)).permute(2, 0, 1), boxes, labels 206 | 207 | 208 | class RandomSampleCrop(object): 209 | """Crop 210 | Arguments: 211 | img (Image): the image being input during training 212 | boxes (Tensor): the original bounding boxes in pt form 213 | labels (Tensor): the class labels for each bbox 214 | mode (float tuple): the min and max jaccard overlaps 215 | Return: 216 | (img, boxes, classes) 217 | img (Image): the cropped image 218 | boxes (Tensor): the adjusted bounding boxes in pt form 219 | labels (Tensor): the class labels for each bbox 220 | """ 221 | def __init__(self): 222 | self.sample_options = ( 223 | # using entire original input image 224 | None, 225 | # sample a patch s.t. 
MIN jaccard w/ obj in .1,.3,.4,.7,.9
226 |             (0.1, None),
227 |             (0.3, None),
228 |             (0.7, None),
229 |             (0.9, None),
230 |             # randomly sample a patch
231 |             (None, None),
232 |         )
233 | 
234 |     def __call__(self, image, boxes=None, labels=None):
235 |         height, width, _ = image.shape
236 |         while True:
237 |             # randomly choose a mode
238 |             mode = random.choice(self.sample_options)
239 |             if mode is None:
240 |                 return image, boxes, labels
241 | 
242 |             min_iou, max_iou = mode
243 |             if min_iou is None:
244 |                 min_iou = float('-inf')
245 |             if max_iou is None:
246 |                 max_iou = float('inf')
247 | 
248 |             # max trials (50)
249 |             for _ in range(50):
250 |                 current_image = image
251 | 
252 |                 w = random.uniform(0.3 * width, width)
253 |                 h = random.uniform(0.3 * height, height)
254 | 
255 |                 # aspect ratio constraint b/t .5 & 2
256 |                 if h / w < 0.5 or h / w > 2:
257 |                     continue
258 | 
259 |                 left = random.uniform(width - w)
260 |                 top = random.uniform(height - h)
261 | 
262 |                 # convert to integer rect x1,y1,x2,y2
263 |                 rect = np.array([int(left), int(top), int(left+w), int(top+h)])
264 | 
265 |                 # calculate IoU (jaccard overlap) b/t the cropped and gt boxes
266 |                 overlap = jaccard_numpy(boxes, rect)
267 | 
268 |                 # is min and max overlap constraint satisfied? if not try again
269 |                 if overlap.min() < min_iou and max_iou < overlap.max():
270 |                     continue
271 | 
272 |                 # cut the crop from the image
273 |                 current_image = current_image[rect[1]:rect[3], rect[0]:rect[2],
274 |                                               :]
275 | 
276 |                 # keep overlap with gt box IF center in sampled patch
277 |                 centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
278 | 
279 |                 # mask in all gt boxes that are above and to the left of centers
280 |                 m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
281 | 
282 |                 # mask in all gt boxes that are under and to the right of centers
283 |                 m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
284 | 
285 |                 # keep boxes where both m1 and m2 are true
286 |                 mask = m1 * m2
287 | 
288 |                 # have any valid boxes? try again if not
289 |                 if not mask.any():
290 |                     continue
291 | 
292 |                 # take only matching gt boxes
293 |                 current_boxes = boxes[mask, :].copy()
294 | 
295 |                 # take only matching gt labels
296 |                 current_labels = labels[mask]
297 | 
298 |                 # should we use the box left and top corner or the crop's
299 |                 current_boxes[:, :2] = np.maximum(current_boxes[:, :2],
300 |                                                   rect[:2])
301 |                 # adjust to crop (by subtracting crop's left,top)
302 |                 current_boxes[:, :2] -= rect[:2]
303 | 
304 |                 current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:],
305 |                                                   rect[2:])
306 |                 # adjust to crop (by subtracting crop's left,top)
307 |                 current_boxes[:, 2:] -= rect[:2]
308 | 
309 |                 return current_image, current_boxes, current_labels
310 | 
311 | 
312 | class Expand(object):
313 |     def __init__(self, mean):
314 |         self.mean = mean
315 | 
316 |     def __call__(self, image, boxes, labels):
317 |         if random.randint(2):
318 |             return image, boxes, labels
319 | 
320 |         height, width, depth = image.shape
321 |         ratio = random.uniform(1, 4)
322 |         left = random.uniform(0, width*ratio - width)
323 |         top = random.uniform(0, height*ratio - height)
324 | 
325 |         expand_image = np.zeros(
326 |             (int(height*ratio), int(width*ratio), depth),
327 |             dtype=image.dtype)
328 |         expand_image[:, :, :] = self.mean
329 |         expand_image[int(top):int(top + height),
330 |                      int(left):int(left + width)] = image
331 |         image = expand_image
332 | 
333 |         boxes = boxes.copy()
334 |         boxes[:, :2] += (int(left), int(top))
335 |         boxes[:, 2:] += (int(left), int(top))
336 | 
337 |         return image, boxes, labels
338 | 
339 | 
340 | class RandomMirror(object):
341 |     def __call__(self, image, boxes, classes):
342 |         _, width, _ = image.shape
343 |         if random.randint(2):
344 |             image = image[:, ::-1]
345 |             boxes = boxes.copy()
346 |             boxes[:, 0::2] = width - boxes[:, 2::-2]
347 |         return image, boxes, classes
348 | 
349 | 
350 | class SwapChannels(object):
351 |     """Transforms a tensorized image by swapping the channels in the order
352 |     specified in the swap tuple.
353 | Args: 354 | swaps (int triple): final order of channels 355 | eg: (2, 1, 0) 356 | """ 357 | 358 | def __init__(self, swaps): 359 | self.swaps = swaps 360 | 361 | def __call__(self, image): 362 | """ 363 | Args: 364 | image (Tensor): image tensor to be transformed 365 | Return: 366 | a tensor with channels swapped according to swap 367 | """ 368 | # if torch.is_tensor(image): 369 | # image = image.data.cpu().numpy() 370 | # else: 371 | # image = np.array(image) 372 | image = image[:, :, self.swaps] 373 | return image 374 | 375 | 376 | class PhotometricDistort(object): 377 | def __init__(self): 378 | self.pd = [ 379 | RandomContrast(), 380 | ConvertColor(transform='HSV'), 381 | RandomSaturation(), 382 | RandomHue(), 383 | ConvertColor(current='HSV', transform='BGR'), 384 | RandomContrast() 385 | ] 386 | self.rand_brightness = RandomBrightness() 387 | self.rand_light_noise = RandomLightingNoise() 388 | 389 | def __call__(self, image, boxes, labels): 390 | im = image.copy() 391 | im, boxes, labels = self.rand_brightness(im, boxes, labels) 392 | if random.randint(2): 393 | distort = Compose(self.pd[:-1]) 394 | else: 395 | distort = Compose(self.pd[1:]) 396 | im, boxes, labels = distort(im, boxes, labels) 397 | return self.rand_light_noise(im, boxes, labels) 398 | 399 | 400 | class SSDAugmentation(object): 401 | def __init__(self, size=300, mean=(104, 117, 123)): 402 | self.mean = mean 403 | self.size = size 404 | self.augment = Compose([ 405 | ConvertFromInts(), 406 | ToAbsoluteCoords(), 407 | PhotometricDistort(), 408 | Expand(self.mean), 409 | RandomSampleCrop(), 410 | RandomMirror(), 411 | ToPercentCoords(), 412 | Resize(self.size), 413 | SubtractMeans(self.mean) 414 | ]) 415 | 416 | # self.augment = Compose([ 417 | # ToAbsoluteCoords(), 418 | # # Expand(self.mean), 419 | # # RandomSampleCrop(), 420 | # ToPercentCoords(), 421 | # Resize(self.size), 422 | # ]) 423 | 424 | def __call__(self, img, boxes, labels): 425 | return self.augment(img, boxes, labels) 426 | --------------------------------------------------------------------------------
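A minimal sketch of running the full training-time pipeline on one image (editor's illustration; the image and box are random stand-ins, and the pipeline itself is stochastic, so exact outputs vary):

    import numpy as np
    from utils.augmentations import SSDAugmentation

    aug = SSDAugmentation(size=300, mean=(104, 117, 123))
    img = np.random.randint(0, 255, (375, 500, 3)).astype(np.float32)
    boxes = np.array([[0.1, 0.2, 0.6, 0.8]])   # percent coords, [x1, y1, x2, y2]
    labels = np.array([11])
    img_t, boxes_t, labels_t = aug(img, boxes, labels)
    print(img_t.shape)  # (300, 300, 3): mean-subtracted float image, boxes back in percent coords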