├── LICENSE
├── README.md
├── __pycache__
│   ├── active_learning.cpython-36.pyc
│   ├── csd.cpython-36.pyc
│   ├── loaders.cpython-36.pyc
│   ├── pseudo_labels.cpython-36.pyc
│   ├── ssd.cpython-36.pyc
│   └── subset_sequential_sampler.cpython-36.pyc
├── active_learning.py
├── csd.py
├── data
│   ├── COCO.txt
│   ├── ONLY_VOC_IN_COCO.txt
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── coco.cpython-36.pyc
│   │   ├── coco.cpython-37.pyc
│   │   ├── coco.cpython-38.pyc
│   │   ├── coco_new.cpython-36.pyc
│   │   ├── config.cpython-36.pyc
│   │   ├── config.cpython-37.pyc
│   │   ├── config.cpython-38.pyc
│   │   ├── voc0712.cpython-36.pyc
│   │   ├── voc0712.cpython-37.pyc
│   │   ├── voc0712.cpython-38.pyc
│   │   ├── voc0712_backup.cpython-36.pyc
│   │   ├── voc0712_consistency.cpython-36.pyc
│   │   ├── voc07_consistency.cpython-36.pyc
│   │   ├── voc07_consistency.cpython-38.pyc
│   │   ├── voc07_consistency_init.cpython-36.pyc
│   │   ├── voc07_consistency_init.cpython-38.pyc
│   │   └── voc07_voc12coco.cpython-36.pyc
│   ├── coco.py
│   ├── coco
│   │   └── coco_labels.txt
│   ├── config.py
│   ├── example.jpg
│   ├── only_voc.txt
│   ├── scripts
│   │   ├── COCO2014.sh
│   │   ├── VOC2007.sh
│   │   └── VOC2012.sh
│   └── voc0712.py
├── eval.py
├── layers
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── box_utils.cpython-36.pyc
│   │   ├── box_utils.cpython-37.pyc
│   │   └── box_utils.cpython-38.pyc
│   ├── box_utils.py
│   ├── functions
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── detection.cpython-36.pyc
│   │   │   ├── detection.cpython-37.pyc
│   │   │   ├── detection.cpython-38.pyc
│   │   │   ├── prior_box.cpython-36.pyc
│   │   │   ├── prior_box.cpython-37.pyc
│   │   │   └── prior_box.cpython-38.pyc
│   │   ├── detection.py
│   │   └── prior_box.py
│   └── modules
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── __init__.cpython-37.pyc
│       │   ├── __init__.cpython-38.pyc
│       │   ├── l2norm.cpython-36.pyc
│       │   ├── l2norm.cpython-37.pyc
│       │   ├── l2norm.cpython-38.pyc
│       │   ├── multibox_loss.cpython-36.pyc
│       │   ├── multibox_loss.cpython-37.pyc
│       │   ├── multibox_loss.cpython-38.pyc
│       │   ├── multibox_loss_gmm.cpython-36.pyc
│       │   ├── multibox_loss_new.cpython-36.pyc
│       │   ├── multibox_loss_new.cpython-37.pyc
│       │   └── multibox_loss_new.cpython-38.pyc
│       ├── l2norm.py
│       └── multibox_loss.py
├── loaders.py
├── pseudo_labels.py
├── ssd.py
├── subset_sequential_sampler.py
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   ├── __init__.cpython-37.pyc
    │   ├── __init__.cpython-38.pyc
    │   ├── augmentations.cpython-36.pyc
    │   ├── augmentations.cpython-37.pyc
    │   └── augmentations.cpython-38.pyc
    └── augmentations.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 |
3 | Nvidia Source Code License-NC
4 |
5 | 1. Definitions
6 |
7 | “Licensor” means any person or entity that distributes its Work.
8 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
9 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
10 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
11 |
12 | 2. License Grant
13 |
14 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
15 |
16 | 3. Limitations
17 |
18 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
19 |
20 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
21 |
22 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
23 |
24 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
25 |
26 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
27 |
28 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
29 |
30 | 4. Disclaimer of Warranty.
31 |
32 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
33 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
34 |
35 | 5. Limitation of Liability.
36 |
37 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
38 |
39 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Not All Labels Are Equal: Rationalizing The Labeling Costs for Training Object Detection
2 |
3 | This repository contains the official PyTorch implementation of the training & evaluation code and the pretrained models for:
4 |
5 | [Not All Labels Are Equal: Rationalizing The Labeling Costs for Training Object Detection](https://openaccess.thecvf.com/content/CVPR2022/papers/Elezi_Not_All_Labels_Are_Equal_Rationalizing_the_Labeling_Costs_for_CVPR_2022_paper.pdf)
6 | Ismail Elezi, Zhiding Yu, Anima Anandkumar, Laura Leal-Taixe, and Jose M. Alvarez.
7 | CVPR 2022.
8 |
9 |
10 | [Code](https://github.com/NVlabs/AL-SSL)
11 |
12 |
13 | ## Installation & Preparation
14 | We experiment with SSD in the PyTorch framework. To use our model, complete the installation & preparation steps on the [SSD PyTorch homepage](https://github.com/amdegroot/ssd.pytorch).
15 |
16 | #### Prerequisites
17 | - Python 3.6
18 | - PyTorch 1.1.0
19 |
20 | ## Training
21 | ```Shell
22 | python train.py
23 | ```
24 |
25 | ## Evaluation
26 | ```Shell
27 | python eval.py
28 | ```
29 |
30 |
31 |
32 | ## License
33 | Copyright © 2022-2023, NVIDIA Corporation and Affiliates. All rights reserved.
34 |
35 | This work is made available under the Nvidia Source Code License-NC. Click [here](https://github.com/NVlabs/AL-SSL/blob/main/LICENSE) to view a copy of this license.
36 |
37 | The pre-trained models are shared under CC-BY-NC-SA-4.0. If you remix, transform, or build upon the material, you must distribute your contributions under the same license as the original.
38 |
39 | For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/).
40 |
41 |
42 | ## Citation
43 |
44 | If you find this code useful, please consider citing the following paper:
45 |
46 | ````
47 | @inproceedings{DBLP:conf/cvpr/EleziYALA22,
48 | author = {Ismail Elezi and
49 | Zhiding Yu and
50 | Anima Anandkumar and
51 | Laura Leal{-}Taix{\'{e}} and
52 | Jose M. Alvarez},
53 | title = {Not All Labels Are Equal: Rationalizing The Labeling Costs for Training
54 | Object Detection},
55 | booktitle = {{IEEE/CVF} Conference on Computer Vision and Pattern Recognition,
56 | {CVPR} 2022, New Orleans, LA, USA, June 18-24, 2022},
57 | pages = {14472--14481},
58 | publisher = {{IEEE}},
59 | year = {2022},
60 | url = {https://doi.org/10.1109/CVPR52688.2022.01409},
61 | doi = {10.1109/CVPR52688.2022.01409},
62 | timestamp = {Wed, 05 Oct 2022 16:31:19 +0200},
63 | biburl = {https://dblp.org/rec/conf/cvpr/EleziYALA22.bib},
64 | bibsource = {dblp computer science bibliography, https://dblp.org}
65 | }
66 | ````
67 |
68 | We use the consistency-based semi-supervised learning method CSD developed by Jeong et al. If you use it, consider citing the following paper:
69 |
70 | ````
71 | @inproceedings{DBLP:conf/nips/JeongLKK19,
72 | author = {Jisoo Jeong and
73 | Seungeui Lee and
74 | Jeesoo Kim and
75 | Nojun Kwak},
76 | editor = {Hanna M. Wallach and
77 | Hugo Larochelle and
78 | Alina Beygelzimer and
79 | Florence d'Alch{\'{e}}{-}Buc and
80 | Emily B. Fox and
81 | Roman Garnett},
82 | title = {Consistency-based Semi-supervised Learning for Object detection},
83 | booktitle = {Advances in Neural Information Processing Systems 32: Annual Conference
84 | on Neural Information Processing Systems 2019, NeurIPS 2019, December
85 | 8-14, 2019, Vancouver, BC, Canada},
86 | pages = {10758--10767},
87 | year = {2019},
88 | url = {https://proceedings.neurips.cc/paper/2019/hash/d0f4dae80c3d0277922f8371d5827292-Abstract.html},
89 | timestamp = {Mon, 16 May 2022 15:41:51 +0200},
90 | biburl = {https://dblp.org/rec/conf/nips/JeongLKK19.bib},
91 | bibsource = {dblp computer science bibliography, https://dblp.org}
92 | }
93 | ````
94 |
--------------------------------------------------------------------------------
/__pycache__/active_learning.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/__pycache__/active_learning.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/csd.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/__pycache__/csd.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/loaders.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/__pycache__/loaders.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/pseudo_labels.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/__pycache__/pseudo_labels.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/ssd.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/__pycache__/ssd.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/subset_sequential_sampler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/__pycache__/subset_sequential_sampler.cpython-36.pyc
--------------------------------------------------------------------------------
/active_learning.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 |
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from collections import defaultdict
12 | from copy import deepcopy
13 |
14 | from layers.box_utils import decode, nms
15 |
16 |
17 | def class_consistency_loss_al(conf, conf_flip):
18 | conf_consistency_criterion = torch.nn.KLDivLoss(size_average=False, reduce=False).cuda()
19 |
20 | conf_class = conf.clone()
21 | conf_class_flip = conf_flip.clone()
22 |
23 | consistency_conf_loss_a = conf_consistency_criterion(conf_class.log(),
24 | conf_class_flip.detach()).sum(-1)
25 | consistency_conf_loss_b = conf_consistency_criterion(conf_class_flip.log(),
26 | conf_class.detach()).sum(-1)
27 | return (consistency_conf_loss_a + consistency_conf_loss_b) / 2
28 |
29 |
30 | def active_learning_inconsistency(args, batch_iterator, labeled_set, unlabeled_set, net, num_classes,
31 | criterion_select='consistency_class', loader=None):
32 | criterion_UC = np.zeros(len(batch_iterator))
33 | batch_iterator = iter(loader)
34 | thresh = args.thresh
35 |
36 | for j in range(len(batch_iterator)): # 3000
37 | print(j)
38 | images, lab, _ = next(batch_iterator)
39 | images = images.cuda()
40 |
41 | out, conf, conf_flip, loc, loc_flip, _ = net(images)
42 | loc, _, priors = out
43 |
44 | num = loc.size(0) # batch size
45 | num_priors = priors.size(0)
46 | output = torch.zeros(num, num_classes, 200, 6)
47 | conf_preds = conf.view(num, num_priors, num_classes).transpose(2, 1)
48 | conf_preds_flip = conf_flip.view(num, num_priors, num_classes).transpose(2, 1)
49 | variance = [0.1, 0.2]
50 |
51 | # Decode predictions into bboxes.
52 | for i in range(num):
53 | decoded_boxes = decode(loc[i], priors, variance)
54 | decoded_boxes_flip = decode(loc_flip[i], priors, variance)
55 | # For each class, perform nms
56 | conf_scores = conf_preds[i].clone()
57 | conf_scores_flip = conf_preds_flip[i].clone()
58 |
59 | if criterion_select == 'consistency_class' or criterion_select == 'consistency':
60 | H = class_consistency_loss_al(conf, conf_flip).squeeze()
61 | else:
62 | H = conf_scores * (
63 | torch.log(conf_scores) / torch.log(torch.tensor(num_classes).type(torch.FloatTensor)))
64 | H = H.sum(dim=0) * (-1.0)
65 |
66 | for cl in range(1, num_classes):
67 | c_mask = conf_scores[cl].gt(0.01) # confidence threshold
68 | scores = conf_scores[cl][c_mask]
69 | Entropy = H[c_mask]
70 | if scores.size(0) == 0:
71 | continue
72 |
73 | l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
74 | boxes = decoded_boxes[l_mask].view(-1, 4)
75 |
76 | ids, count = nms(boxes.detach(), scores.detach(), 0.5, 200)
77 | output[i, cl, :count] = torch.cat(
78 | (scores[ids[:count]].unsqueeze(1), boxes[ids[:count]], Entropy[ids[:count]].unsqueeze(1)), 1)
79 |
80 | count_num = 0
81 | UC_max = 0
82 |         for p in range(output.size(1)):  # output: [1, num_classes, 200, 6]
83 | q = 0
84 |             while output[0, p, q, 0] >= thresh:  # filter by the score threshold; TODO: increase it according to the iteration?
85 | count_num += 1
86 | score = output[0, p, q, 0]
87 | entropy = output[0, p, q, 5:6]
88 | UC_max_temp = entropy.item()
89 | if (UC_max < UC_max_temp):
90 | UC_max = UC_max_temp
91 | q += 1
92 |
93 | if count_num == 0:
94 | criterion_UC[j] = 0
95 | else:
96 | criterion_UC[j] = UC_max
97 |
98 | if args.criterion_select == 'combined':
99 | return criterion_UC
100 |
101 | sorted_indices = np.argsort(criterion_UC)[::-1]
102 | labeled_set += list(np.array(unlabeled_set)[sorted_indices[:args.acquisition_budget]])
103 | unlabeled_set = list(np.array(unlabeled_set)[sorted_indices[args.acquisition_budget:]])
104 |
105 | # assert that sizes of lists are correct and that there are no elements that are in both lists
106 | assert len(list(set(labeled_set) | set(unlabeled_set))) == args.num_total_images
107 | assert len(list(set(labeled_set) & set(unlabeled_set))) == 0
108 |
109 | # save the labeled set
110 | return labeled_set, unlabeled_set
111 |
112 |
113 | def active_learning_entropy(args, batch_iterator, labeled_set, unlabeled_set, net, num_classes, criterion_select, loader=None):
114 | criterion_UC = np.zeros(len(batch_iterator))
115 | batch_iterator = iter(loader)
116 | thresh = args.thresh
117 |
118 | for j in range(len(batch_iterator)): # 3000
119 | print(j)
120 | images, lab, _ = next(batch_iterator)
121 | images = images.cuda()
122 |
123 | out, _, _, _, _, _ = net(images)
124 | loc, conf, priors = out
125 | conf = torch.softmax(conf.detach(), dim=2)
126 |
127 | num = loc.size(0) # batch size
128 | num_priors = priors.size(0)
129 | output = torch.zeros(num, num_classes, 200, 6)
130 | conf_preds = conf.view(num, num_priors, num_classes).transpose(2, 1)
131 | variance = [0.1, 0.2]
132 | # Decode predictions into bboxes.
133 | for i in range(num):
134 | decoded_boxes = decode(loc[i], priors, variance)
135 | # For each class, perform nms
136 | conf_scores = conf_preds[i].clone()
137 |
138 | H = conf_scores * (torch.log(conf_scores) / torch.log(torch.tensor(num_classes).type(torch.FloatTensor)))
139 | H = H.sum(dim=0) * (-1.0)
140 |
141 | for cl in range(1, num_classes):
142 | c_mask = conf_scores[cl].gt(0.01) # confidence threshold
143 | scores = conf_scores[cl][c_mask]
144 |                 Entropy = H[c_mask]
145 | if scores.size(0) == 0:
146 | continue
147 |
148 | l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
149 | boxes = decoded_boxes[l_mask].view(-1, 4)
150 |
151 | ids, count = nms(boxes.detach(), scores.detach(), 0.5, 200)
152 | output[i, cl, :count] = torch.cat(
153 | (scores[ids[:count]].unsqueeze(1), boxes[ids[:count]], Entropy[ids[:count]].unsqueeze(1)), 1)
154 |
155 | if criterion_select == 'random':
156 | val = np.random.normal(0, 1, 1)
157 | criterion_UC[j] = val[0]
158 |
159 |
160 | elif criterion_select == 'Max_aver':
161 | count_num = 0
162 | # UC_sum = 0
163 | UC_max = 0
164 |             for p in range(output.size(1)):  # output: [1, num_classes, 200, 6]
165 | q = 0
166 |                 while output[
167 |                         0, p, q, 0] >= thresh:  # filter by the score threshold; TODO: increase it according to the iteration?
168 | count_num += 1
169 | score = output[0, p, q, 0]
170 | entropy = output[0, p, q, 5:6]
171 | UC_max += entropy.item()
172 | q += 1
173 |
174 | if count_num == 0:
175 | criterion_UC[j] = 0
176 | else:
177 | criterion_UC[j] = UC_max / count_num
178 |
179 | else:
180 | count_num = 0
181 | # UC_sum = 0
182 | UC_max = 0
183 |             for p in range(output.size(1)):  # output: [1, num_classes, 200, 6]
184 | q = 0
185 |                 while output[
186 |                         0, p, q, 0] >= thresh:  # filter by the score threshold; TODO: increase it according to the iteration?
187 | count_num += 1
188 | score = output[0, p, q, 0]
189 | entropy = output[0, p, q, 5:6]
190 | UC_max_temp = entropy.item()
191 | if (UC_max < UC_max_temp):
192 | UC_max = UC_max_temp
193 | q += 1
194 |
195 |             if count_num == 0:
196 |                 # nothing was detected above the threshold; give the image a zero score
197 | criterion_UC[j] = 0
198 | else:
199 | criterion_UC[j] = UC_max
200 |
201 | if args.criterion_select == 'combined':
202 | return criterion_UC
203 | sorted_indices = np.argsort(criterion_UC)[::-1]
204 |
205 | labeled_set += list(np.array(unlabeled_set)[sorted_indices[:args.acquisition_budget]])
206 | unlabeled_set = list(np.array(unlabeled_set)[sorted_indices[args.acquisition_budget:]])
207 |
208 | # assert that sizes of lists are correct and that there are no elements that are in both lists
209 | assert len(list(set(labeled_set) | set(unlabeled_set))) == args.num_total_images
210 | assert len(list(set(labeled_set) & set(unlabeled_set))) == 0
211 |
212 | # save the labeled set
213 | return labeled_set, unlabeled_set
214 |
215 |
216 | def combined_score(args, batch_iterator, labeled_set, unlabeled_set, net, unsupervised_data_loader):
217 | entropy_score = active_learning_entropy(args, batch_iterator, labeled_set, unlabeled_set, net,
218 | args.cfg['num_classes'], 'entropy',
219 | loader=unsupervised_data_loader)
220 | consistency_score = active_learning_inconsistency(args, batch_iterator, labeled_set, unlabeled_set, net,
221 | args.cfg['num_classes'], 'consistency_class',
222 | loader=unsupervised_data_loader)
223 |
224 | len_arr = len(entropy_score)
225 | ind = np.argpartition(entropy_score, -args.filter_entropy_num)[:len_arr - args.filter_entropy_num]
226 |
227 | consistency_score[ind] = 0.
228 |
229 | sorted_indices = np.argsort(consistency_score)[::-1]
230 |
231 | labeled_set += list(np.array(unlabeled_set)[sorted_indices[:args.acquisition_budget]])
232 | unlabeled_set = list(np.array(unlabeled_set)[sorted_indices[args.acquisition_budget:]])
233 |
234 | # assert that sizes of lists are correct and that there are no elements that are in both lists
235 | assert len(list(set(labeled_set) | set(unlabeled_set))) == args.num_total_images
236 | assert len(list(set(labeled_set) & set(unlabeled_set))) == 0
237 |
238 | return labeled_set, unlabeled_set
239 |
--------------------------------------------------------------------------------
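
For reference, a minimal sketch of the two per-box uncertainty scores computed above: the symmetric KL consistency between predictions on an image and on its horizontal flip, and the entropy normalized by log(num_classes). The tensor names and shapes below are illustrative toys, not the repo's data:

```python
import torch

num_priors, num_classes = 8, 21
conf = torch.softmax(torch.randn(num_priors, num_classes), dim=-1)       # predictions on the image
conf_flip = torch.softmax(torch.randn(num_priors, num_classes), dim=-1)  # predictions on its flip

# Symmetric KL, as in class_consistency_loss_al (KLDivLoss takes log-probs first, probs second)
kl = torch.nn.KLDivLoss(reduction='none')
consistency = (kl(conf.log(), conf_flip).sum(-1) + kl(conf_flip.log(), conf).sum(-1)) / 2

# Entropy normalized by log(num_classes), as in active_learning_entropy
# (the repo sums over dim=0 because it works on a transposed [classes, priors] view)
H = -(conf * (conf.log() / torch.log(torch.tensor(float(num_classes))))).sum(-1)

print(consistency.shape, H.shape)  # both torch.Size([8]): one score per prior
```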
/csd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from layers import *
6 | from data import voc300, voc512, coco
7 | import os
8 | import warnings
9 | import math
10 | import numpy as np
11 | import cv2
12 |
13 |
14 | class SSD_CON(nn.Module):
15 | """Single Shot Multibox Architecture
16 | The network is composed of a base VGG network followed by the
17 | added multibox conv layers. Each multibox layer branches into
18 | 1) conv2d for class conf scores
19 | 2) conv2d for localization predictions
20 | 3) associated priorbox layer to produce default bounding
21 | boxes specific to the layer's feature map size.
22 | See: https://arxiv.org/pdf/1512.02325.pdf for more details.
23 |
24 | Args:
25 | phase: (string) Can be "test" or "train"
26 | size: input image size
27 | base: VGG16 layers for input, size of either 300 or 500
28 | extras: extra layers that feed to multibox loc and conf layers
29 | head: "multibox head" consists of loc and conf conv layers
30 | """
31 |
32 | def __init__(self, phase, size, base, extras, head, num_classes):
33 | super(SSD_CON, self).__init__()
34 | self.phase = phase
35 | self.num_classes = num_classes
36 | if(size==300):
37 | self.cfg = (coco, voc300)[num_classes == 21]
38 | else:
39 | self.cfg = (coco, voc512)[num_classes == 21]
40 | self.priorbox = PriorBox(self.cfg)
41 | self.priors = Variable(self.priorbox.forward(), volatile=True)
42 | self.size = size
43 |
44 | # SSD network
45 | self.vgg = nn.ModuleList(base)
46 | # Layer learns to scale the l2 normalized features from conv4_3
47 | self.L2Norm = L2Norm(512, 20)
48 | self.extras = nn.ModuleList(extras)
49 |
50 | self.loc = nn.ModuleList(head[0])
51 | self.conf = nn.ModuleList(head[1])
52 |
53 | self.softmax = nn.Softmax(dim=-1)
54 |
55 | if phase == 'test':
56 | # self.softmax = nn.Softmax(dim=-1)
57 | self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)
58 |
59 | def forward(self, x):
60 | """Applies network layers and ops on input image(s) x.
61 |
62 | Args:
63 | x: input image or batch of images. Shape: [batch,3,300,300].
64 |
65 | Return:
66 | Depending on phase:
67 | test:
68 | Variable(tensor) of output class label predictions,
69 | confidence score, and corresponding location predictions for
70 | each object detected. Shape: [batch,topk,7]
71 |
72 | train:
73 | list of concat outputs from:
74 | 1: confidence layers, Shape: [batch*num_priors,num_classes]
75 | 2: localization layers, Shape: [batch,num_priors*4]
76 | 3: priorbox layers, Shape: [2,num_priors*4]
77 | """
78 |
79 |
80 | x_flip = x.clone()
81 | x_flip = flip(x_flip,3)
82 |
83 | sources = list()
84 | loc = list()
85 | conf = list()
86 |
87 | # apply vgg up to conv4_3 relu
88 | for k in range(23):
89 | x = self.vgg[k](x)
90 |
91 | s = self.L2Norm(x)
92 | sources.append(s)
93 |
94 | # apply vgg up to fc7
95 | for k in range(23, len(self.vgg)):
96 | x = self.vgg[k](x)
97 |
98 |         # keep a reference to the fc7 features so we can return them later
99 | sources.append(x)
100 | feats = x
101 |
102 | # apply extra layers and cache source layer outputs
103 | for k, v in enumerate(self.extras):
104 | x = F.relu(v(x), inplace=True)
105 | if k % 2 == 1:
106 | sources.append(x)
107 |
108 | # apply multibox head to source layers
109 | for (x, l, c) in zip(sources, self.loc, self.conf):
110 | loc.append(l(x).permute(0, 2, 3, 1).contiguous())
111 | conf.append(c(x).permute(0, 2, 3, 1).contiguous())
112 |
113 |
114 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
115 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
116 | # zero_mask = torch.cat([o.view(o.size(0), -1) for o in zero_mask], 1)
117 |
118 | if self.phase == "test":
119 | output = self.detect(
120 | loc.view(loc.size(0), -1, 4), # loc preds
121 | self.softmax(conf.view(conf.size(0), -1,
122 | self.num_classes)), # conf preds
123 | self.priors.type(type(x.data)) # default boxes
124 | )
125 | else:
126 | output = (
127 | loc.view(loc.size(0), -1, 4),
128 | conf.view(conf.size(0), -1, self.num_classes),
129 | self.priors
130 | )
131 |
132 | loc = loc.view(loc.size(0), -1, 4)
133 | conf = self.softmax(conf.view(conf.size(0), -1, self.num_classes))
134 | # basic
135 |
136 | sources_flip = list()
137 | loc_flip = list()
138 | conf_flip = list()
139 |
140 | # apply vgg up to conv4_3 relu
141 | for k in range(23):
142 | x_flip = self.vgg[k](x_flip)
143 |
144 | s_flip = self.L2Norm(x_flip)
145 | sources_flip.append(s_flip)
146 |
147 | # apply vgg up to fc7
148 | for k in range(23, len(self.vgg)):
149 | x_flip = self.vgg[k](x_flip)
150 | sources_flip.append(x_flip)
151 |
152 | # apply extra layers and cache source layer outputs
153 | for k, v in enumerate(self.extras):
154 | x_flip = F.relu(v(x_flip), inplace=True)
155 | if k % 2 == 1:
156 | sources_flip.append(x_flip)
157 |
158 | # apply multibox head to source layers
159 | for (x_flip, l, c) in zip(sources_flip, self.loc, self.conf):
160 | append_loc = l(x_flip).permute(0, 2, 3, 1).contiguous()
161 | append_conf = c(x_flip).permute(0, 2, 3, 1).contiguous()
162 | append_loc = flip(append_loc,2)
163 | append_conf = flip(append_conf,2)
164 | loc_flip.append(append_loc)
165 | conf_flip.append(append_conf)
166 |
167 | loc_flip = torch.cat([o.view(o.size(0), -1) for o in loc_flip], 1)
168 | conf_flip = torch.cat([o.view(o.size(0), -1) for o in conf_flip], 1)
169 |
170 | loc_flip = loc_flip.view(loc_flip.size(0), -1, 4)
171 |
172 | conf_flip = self.softmax(conf_flip.view(conf.size(0), -1, self.num_classes))
173 |
174 |
175 | if self.phase == "test":
176 | return output
177 | else:
178 | return output, conf, conf_flip, loc, loc_flip, feats
179 |
180 | def load_weights(self, base_file):
181 | other, ext = os.path.splitext(base_file)
182 |         if ext == '.pkl' or ext == '.pth':
183 | print('Loading weights into state dict...')
184 | self.load_state_dict(torch.load(base_file,
185 | map_location=lambda storage, loc: storage))
186 | print('Finished!')
187 | else:
188 | print('Sorry only .pth and .pkl files supported.')
189 |
190 |
191 | # This function is derived from torchvision VGG make_layers()
192 | # https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
193 | def vgg(cfg, i, batch_norm=False):
194 | layers = []
195 | in_channels = i
196 | for v in cfg:
197 | if v == 'M':
198 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
199 | elif v == 'C':
200 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
201 | else:
202 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
203 | if batch_norm:
204 | print('batch_norm')
205 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
206 | else:
207 | layers += [conv2d, nn.ReLU(inplace=True)]
208 | in_channels = v
209 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
210 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
211 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
212 | layers += [pool5, conv6,
213 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
214 | return layers
215 |
216 |
217 | def add_extras(cfg, i, batch_norm=False):
218 | # Extra layers added to VGG for feature scaling
219 | layers = []
220 | in_channels = i
221 | flag = False
222 | for k, v in enumerate(cfg):
223 | if in_channels != 'S':
224 | if v == 'S':
225 | layers += [nn.Conv2d(in_channels, cfg[k + 1],
226 | kernel_size=(1, 3)[flag], stride=2, padding=1)]
227 | elif v=='K':
228 | layers += [nn.Conv2d(in_channels, 256,
229 | kernel_size=4, stride=1, padding=1)]
230 | else:
231 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
232 | flag = not flag
233 | in_channels = v
234 | return layers
235 |
236 |
237 |
238 | def multibox(vgg, extra_layers, cfg, num_classes):
239 | loc_layers = []
240 | conf_layers = []
241 | vgg_source = [21, -2]
242 | for k, v in enumerate(vgg_source):
243 | loc_layers += [nn.Conv2d(vgg[v].out_channels,
244 | cfg[k] * 4, kernel_size=3, padding=1)]
245 | conf_layers += [nn.Conv2d(vgg[v].out_channels,
246 | cfg[k] * num_classes, kernel_size=3, padding=1)]
247 | for k, v in enumerate(extra_layers[1::2], 2):
248 | loc_layers += [nn.Conv2d(v.out_channels, cfg[k]
249 | * 4, kernel_size=3, padding=1)]
250 | conf_layers += [nn.Conv2d(v.out_channels, cfg[k]
251 | * num_classes, kernel_size=3, padding=1)]
252 | return vgg, extra_layers, (loc_layers, conf_layers)
253 |
254 |
255 | base = {
256 | '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
257 | 512, 512, 512],
258 | '512': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
259 | 512, 512, 512],
260 | }
261 | extras = {
262 | '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],
263 | '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128, 'K'],
264 | }
265 | mbox = {
266 | '300': [4, 6, 6, 6, 4, 4], # number of boxes per feature map location
267 | '512': [4, 6, 6, 6, 6, 4, 4],
268 | }
269 |
270 | def flip(x, dim):
271 | dim = x.dim() + dim if dim < 0 else dim
272 | return x[tuple(slice(None, None) if i != dim
273 | else torch.arange(x.size(i)-1, -1, -1).long()
274 | for i in range(x.dim()))]
275 |
276 | class GaussianNoise(nn.Module):
277 | def __init__(self, batch_size, input_size=(3, 300, 300), mean=0, std=0.15):
278 | super(GaussianNoise, self).__init__()
279 | self.shape = (batch_size, ) + input_size
280 | self.noise = Variable(torch.zeros(self.shape).cuda())
281 | self.mean = mean
282 | self.std = std
283 |
284 | def forward(self, x):
285 | self.noise.data.normal_(self.mean, std=self.std)
286 | if x.size(0) == self.noise.size(0):
287 | return x + self.noise
288 | else:
289 | #print('---- Noise Size ')
290 | return x + self.noise[:x.size(0)]
291 |
292 |
293 | def build_ssd_con(phase, size=300, num_classes=21):
294 | if phase != "test" and phase != "train":
295 | print("ERROR: Phase: " + phase + " not recognized")
296 | return
297 | if size != 300:
298 | print("ERROR: You specified size " + repr(size) + ". However, " +
299 | "currently only SSD300 (size=300) is supported!")
300 | return
301 | base_, extras_, head_ = multibox(vgg(base[str(size)], 3),
302 | add_extras(extras[str(size)], 1024),
303 | mbox[str(size)], num_classes)
304 | return SSD_CON(phase, size, base_, extras_, head_, num_classes)
305 |
306 |
--------------------------------------------------------------------------------
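
The hand-rolled flip() above predates torch.flip; as a quick sanity check (a sketch, not repo code), the two agree on dim 3, the width axis along which x_flip is built in SSD_CON.forward():

```python
import torch

def flip(x, dim):  # copied from csd.py
    dim = x.dim() + dim if dim < 0 else dim
    return x[tuple(slice(None, None) if i != dim
                   else torch.arange(x.size(i) - 1, -1, -1).long()
                   for i in range(x.dim()))]

x = torch.randn(2, 3, 300, 300)
assert torch.equal(flip(x, 3), torch.flip(x, dims=[3]))  # horizontal flip
```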
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 |
7 |
8 |
9 | from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT
10 | from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT, get_label_map
11 |
12 | from .config import *
13 | import torch
14 | import cv2
15 | import numpy as np
16 |
17 | def detection_collate(batch):
18 | """Custom collate fn for dealing with batches of images that have a different
19 | number of associated object annotations (bounding boxes).
20 |
21 | Arguments:
22 | batch: (tuple) A tuple of tensor images and lists of annotations
23 |
24 | Return:
25 | A tuple containing:
26 | 1) (tensor) batch of images stacked on their 0 dim
27 | 2) (list of tensors) annotations for a given image are stacked on
28 | 0 dim
29 | """
30 |     # changed for semi-supervised training: each sample may carry a "semi" flag
31 | targets = []
32 | imgs = []
33 | semis = []
34 | for sample in batch:
35 | imgs.append(sample[0])
36 | targets.append(torch.FloatTensor(sample[1]))
37 | if(len(sample)==3):
38 | semis.append(torch.FloatTensor(sample[2]))
39 | if(len(sample)==2):
40 | return torch.stack(imgs, 0), targets
41 | else:
42 | return torch.stack(imgs, 0), targets, semis
43 | # return torch.stack(imgs, 0), targets
44 |
45 |
46 | def base_transform(image, size, mean):
47 | x = cv2.resize(image, (size, size)).astype(np.float32)
48 | x -= mean
49 | x = x.astype(np.float32)
50 | return x
51 |
52 |
53 | class BaseTransform:
54 | def __init__(self, size, mean):
55 | self.size = size
56 | self.mean = np.array(mean, dtype=np.float32)
57 |
58 | def __call__(self, image, boxes=None, labels=None):
59 | return base_transform(image, self.size, self.mean), boxes, labels
60 |
--------------------------------------------------------------------------------
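
A minimal usage sketch of detection_collate with a toy dataset (the dataset class and shapes below are made up for illustration, and the import assumes the repo root is on PYTHONPATH with its dependencies installed): images stack into one batch tensor, while the variable-length box lists stay a plain Python list.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from data import detection_collate  # assumption: repo root on PYTHONPATH

class ToyDetection(Dataset):
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        img = torch.zeros(3, 300, 300)
        boxes = [[0.1, 0.1, 0.5, 0.5, 7.0]] * (idx + 1)  # idx + 1 boxes per image
        semi = [1.0]                                      # "labeled" flag
        return img, boxes, semi

loader = DataLoader(ToyDetection(), batch_size=4, collate_fn=detection_collate)
images, targets, semis = next(iter(loader))
print(images.shape)                # torch.Size([4, 3, 300, 300])
print([t.shape for t in targets])  # growing box counts: [1, 5], [2, 5], [3, 5], [4, 5]
```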
/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/coco.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/coco.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/coco.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/coco.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/coco.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/coco.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/coco_new.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/coco_new.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/voc0712.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc0712.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/voc0712.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc0712.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/voc0712.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc0712.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/voc0712_backup.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc0712_backup.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/voc0712_consistency.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc0712_consistency.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/voc07_consistency.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc07_consistency.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/voc07_consistency.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc07_consistency.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/voc07_consistency_init.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc07_consistency_init.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/voc07_consistency_init.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc07_consistency_init.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/voc07_voc12coco.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/__pycache__/voc07_voc12coco.cpython-36.pyc
--------------------------------------------------------------------------------
/data/coco.py:
--------------------------------------------------------------------------------
1 | from .config import HOME
2 | import cv2
3 | import numpy as np
4 | import os
5 | import os.path as osp
6 | import sys
7 | import torch
8 | import torch.utils.data as data
9 | import torchvision.transforms as transforms
10 |
11 |
12 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
13 | #
14 | # This work is made available under the Nvidia Source Code License-NC.
15 | # To view a copy of this license, visit
16 | # https://github.com/NVlabs/AL-SSL/
17 |
18 |
19 | COCO_ROOT = '/usr/wiss/elezi/data/coco'
20 | IMAGES = 'images'
21 | ANNOTATIONS = 'annotations'
22 | COCO_API = 'PythonAPI'
23 | INSTANCES_SET = 'instances_{}.json'
24 | COCO_CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
25 |                 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
26 | 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
27 | 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
28 | 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
29 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
30 | 'kite', 'baseball bat', 'baseball glove', 'skateboard',
31 | 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
32 | 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
33 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
34 | 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
35 | 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
36 |                 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
37 | 'refrigerator', 'book', 'clock', 'vase', 'scissors',
38 | 'teddy bear', 'hair drier', 'toothbrush')
39 |
40 | VOC_CLASSES = ( # always index 0
41 | 'aeroplane', 'bicycle', 'bird', 'boat',
42 | 'bottle', 'bus', 'car', 'cat', 'chair',
43 | 'cow', 'diningtable', 'dog', 'horse',
44 | 'motorbike', 'person', 'pottedplant',
45 | 'sheep', 'sofa', 'train', 'tvmonitor')
46 |
47 |
48 | def get_label_map(label_file):
49 | label_map = {}
50 | labels = open(label_file, 'r')
51 | for line in labels:
52 | ids = line.split(',')
53 | label_map[int(ids[0])] = int(ids[1])
54 | return label_map
55 |
56 |
57 | class COCOAnnotationTransform(object):
58 | """Transforms a COCO annotation into a Tensor of bbox coords and label index
59 |     Initialized with a dictionary lookup of classnames to indexes
60 | """
61 | def __init__(self):
62 | self.label_map = get_label_map(osp.join(COCO_ROOT, 'coco_labels.txt'))
63 |
64 | def __call__(self, target, width, height):
65 | """
66 | Args:
67 | target (dict): COCO target json annotation as a python dict
68 | height (int): height
69 | width (int): width
70 | Returns:
71 | a list containing lists of bounding boxes [bbox coords, class idx]
72 | """
73 | scale = np.array([width, height, width, height])
74 | res = []
75 | for obj in target:
76 | if 'bbox' in obj:
77 | bbox = obj['bbox']
78 | bbox[2] += bbox[0]
79 | bbox[3] += bbox[1]
80 | label_idx = self.label_map[obj['category_id']] - 1
81 | final_box = list(np.array(bbox)/scale)
82 | final_box.append(label_idx)
83 | res += [final_box] # [xmin, ymin, xmax, ymax, label_idx]
84 | else:
85 | print("no bbox problem!")
86 |
87 | return res # [[xmin, ymin, xmax, ymax, label_idx], ... ]
88 |
89 |
90 | class COCODetection(data.Dataset):
91 |     """MS COCO Detection Dataset.
92 | Args:
93 | root (string): Root directory where images are downloaded to.
94 | set_name (string): Name of the specific set of COCO images.
95 | transform (callable, optional): A function/transform that augments the
96 |             raw images
97 | target_transform (callable, optional): A function/transform that takes
98 | in the target (bbox) and transforms it.
99 | """
100 |
101 | def __init__(self, root, supervised_indices=None, image_set='train2014', transform=None,
102 | target_transform=COCOAnnotationTransform(), dataset_name='MS COCO', pseudo_labels={}):
103 | sys.path.append(osp.join(root, COCO_API))
104 | self.supervised_indices = supervised_indices
105 | from pycocotools.coco import COCO
106 | self.root = osp.join(root, IMAGES, image_set)
107 | self.coco = COCO(osp.join(root, ANNOTATIONS,
108 | INSTANCES_SET.format(image_set)))
109 | self.ids = list(self.coco.imgToAnns.keys())
110 | self.supervised_indices = supervised_indices
111 | self.pseudo_labels = pseudo_labels
112 | self.pseudo_labels_indices = self.get_pseudo_label_indices()
113 | self.transform = transform
114 | self.target_transform = target_transform
115 | self.name = dataset_name
116 | self.class_to_ind = dict(zip(COCO_CLASSES, range(len(COCO_CLASSES))))
117 | self.class_to_ind_voc = dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))
118 | self.dict_coco_to_voc = {0: 0, 1: 1, 2: 6, 3: 13, 4: 0, 5: 5, 6: 18, 8: 3, 15: 2, 16: 7, 17: 11, 18: 12, 19: 16,
119 | 20: 9, 40: 4, 57: 8, 58: 17, 59: 15, 61: 10, 63: 19}
120 |
121 | def contain_voc(self):
122 | return self.coco_contain_voc
123 |
124 | def __getitem__(self, index):
125 | """
126 | Args:
127 | index (int): Index
128 | Returns:
129 | tuple: Tuple (image, target).
130 | target is the object returned by ``coco.loadAnns``.
131 | """
132 | im, gt, h, w, semi = self.pull_item(index)
133 | return im, gt, semi
134 |
135 | def __len__(self):
136 | return len(self.ids)
137 |
138 | def pull_item(self, index):
139 | """
140 | Args:
141 | index (int): Index
142 | Returns:
143 | tuple: Tuple (image, target, height, width).
144 | target is the object returned by ``coco.loadAnns``.
145 | """
146 | img_id = self.ids[index]
147 | target = self.coco.imgToAnns[img_id]
148 | ann_ids = self.coco.getAnnIds(imgIds=img_id)
149 |
150 | target = self.coco.loadAnns(ann_ids)
151 | path = osp.join(self.root, self.coco.loadImgs(img_id)[0]['file_name'])
152 | split_path = path.split("/")
153 | split_path = os.path.join(split_path[-2], split_path[-1])
154 | # if split_path in self.coco_for_voc_valid:
155 | # print("Inside")
156 | # self.coco_contain_voc.append(index)
157 | assert osp.exists(path), 'Image path does not exist: {}'.format(path)
158 | # img = cv2.imread(osp.join(self.root, path))
159 | img = cv2.imread(path)
160 |
161 | height, width, _ = img.shape
162 |
163 | if self.target_transform is not None:
164 | if index not in self.pseudo_labels_indices:
165 | target = self.target_transform(target, width, height)
166 | else:
167 | target = self.target_transform_pseudo_label(self.pseudo_labels, index)
168 |
169 | if self.transform is not None:
170 | target = np.array(target)
171 | img, boxes, labels = self.transform(img, target[:, :4],
172 | target[:, 4])
173 | # to rgb
174 | img = img[:, :, (2, 1, 0)]
175 |
176 | target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
177 |
178 |         if self.supervised_indices is not None:
179 | if index in self.supervised_indices:
180 | semi = np.array([1])
181 | elif index in self.pseudo_labels_indices:
182 | semi = np.array([2])
183 | else:
184 | semi = np.array([0])
185 | target = np.zeros([1,5])
186 | else:
187 | # it does not matter
188 | semi = np.array([0])
189 |
190 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width, semi
191 |
192 | def pull_image(self, index):
193 | '''Returns the original image object at index in PIL form
194 |
195 | Note: not using self.__getitem__(), as any transformations passed in
196 | could mess up this functionality.
197 |
198 | Argument:
199 | index (int): index of img to show
200 | Return:
201 | cv2 img
202 | '''
203 | img_id = self.ids[index]
204 | path = self.coco.loadImgs(img_id)[0]['file_name']
205 | return cv2.imread(osp.join(self.root, path), cv2.IMREAD_COLOR)
206 |
207 | def pull_anno(self, index):
208 | '''Returns the original annotation of image at index
209 |
210 | Note: not using self.__getitem__(), as any transformations passed in
211 | could mess up this functionality.
212 |
213 | Argument:
214 | index (int): index of img to get annotation of
215 | Return:
216 | list: [img_id, [(label, bbox coords),...]]
217 | eg: ('001718', [('dog', (96, 13, 438, 332))])
218 | '''
219 | img_id = self.ids[index]
220 | ann_ids = self.coco.getAnnIds(imgIds=img_id)
221 | return self.coco.loadAnns(ann_ids)
222 |
223 | def pull_pseudo_anno(self, index):
224 | '''Returns the original annotation of image at index
225 |
226 | Note: not using self.__getitem__(), as any transformations passed in
227 | could mess up this functionality.
228 |
229 | Argument:
230 | index (int): index of img to get annotation of
231 | Return:
232 | list: [img_id, [(label, bbox coords),...]]
233 | eg: ('001718', [('dog', (96, 13, 438, 332))])
234 | '''
235 | return self.pseudo_labels[index]
236 |
237 | def get_pseudo_label_indices(self):
238 | pseudo_label_indices = []
239 | for key in self.pseudo_labels:
240 | pseudo_label_indices.append(key)
241 | return pseudo_label_indices
242 |
243 | def target_transform_pseudo_label(self, pseudo_labels, index):
244 | # height, width = pseudo_labels[index][0][3], pseudo_labels[index][0][4]
245 | height, width = pseudo_labels[index][0][4], pseudo_labels[index][0][3]
246 | # height, width = 300, 300
247 | res = []
248 | for i in range(len(pseudo_labels[index])):
249 | pts = [pseudo_labels[index][i][5], pseudo_labels[index][i][6], pseudo_labels[index][i][7], pseudo_labels[index][i][8]]
250 | name = pseudo_labels[index][i][2]
251 | pts[0] /= height
252 | pts[2] /= height
253 | pts[1] /= width
254 | pts[3] /= width
255 | for i in range(len(pts)):
256 | if pts[i] < 0:
257 | pts[i] = 0
258 | if pts[0] > height:
259 | pts[0] = height
260 | if pts[2] > height:
261 | pts[2] = height
262 | label_idx = float(self.class_to_ind[name])
263 | # label_idx = float(self.class_to_ind_voc[name])
264 | # transform to voc standard
265 | # label_idx = self.class_to_ind[name]
266 | # label_idx = float(self.dict_coco_to_voc[label_idx])
267 | pts.append(label_idx)
268 | res.append(pts)
269 | return res
270 |
271 | def __repr__(self):
272 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
273 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
274 | fmt_str += ' Root Location: {}\n'.format(self.root)
275 | tmp = ' Transforms (if any): '
276 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
277 | tmp = ' Target Transforms (if any): '
278 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
279 | return fmt_str
280 |
--------------------------------------------------------------------------------
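
A small worked example (toy numbers, nothing from the dataset) of what COCOAnnotationTransform does to one object: COCO stores boxes as [xmin, ymin, w, h] in pixels, and the transform emits [xmin, ymin, xmax, ymax] scaled to [0, 1] plus a zero-based class index:

```python
import numpy as np

width, height = 640, 480
bbox = [100.0, 50.0, 200.0, 100.0]   # COCO format: xmin, ymin, w, h
bbox[2] += bbox[0]                   # xmax = xmin + w
bbox[3] += bbox[1]                   # ymax = ymin + h
scale = np.array([width, height, width, height])
label_idx = 1 - 1                    # e.g. category_id 1 ('person'): label_map[1] - 1
final_box = list(np.array(bbox) / scale) + [label_idx]
print(final_box)                     # [0.15625, 0.10416..., 0.46875, 0.3125, 0]
```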
/data/coco/coco_labels.txt:
--------------------------------------------------------------------------------
1 | 1,1,person
2 | 2,2,bicycle
3 | 3,3,car
4 | 4,4,motorcycle
5 | 5,5,airplane
6 | 6,6,bus
7 | 7,7,train
8 | 8,8,truck
9 | 9,9,boat
10 | 10,10,traffic light
11 | 11,11,fire hydrant
12 | 13,12,stop sign
13 | 14,13,parking meter
14 | 15,14,bench
15 | 16,15,bird
16 | 17,16,cat
17 | 18,17,dog
18 | 19,18,horse
19 | 20,19,sheep
20 | 21,20,cow
21 | 22,21,elephant
22 | 23,22,bear
23 | 24,23,zebra
24 | 25,24,giraffe
25 | 27,25,backpack
26 | 28,26,umbrella
27 | 31,27,handbag
28 | 32,28,tie
29 | 33,29,suitcase
30 | 34,30,frisbee
31 | 35,31,skis
32 | 36,32,snowboard
33 | 37,33,sports ball
34 | 38,34,kite
35 | 39,35,baseball bat
36 | 40,36,baseball glove
37 | 41,37,skateboard
38 | 42,38,surfboard
39 | 43,39,tennis racket
40 | 44,40,bottle
41 | 46,41,wine glass
42 | 47,42,cup
43 | 48,43,fork
44 | 49,44,knife
45 | 50,45,spoon
46 | 51,46,bowl
47 | 52,47,banana
48 | 53,48,apple
49 | 54,49,sandwich
50 | 55,50,orange
51 | 56,51,broccoli
52 | 57,52,carrot
53 | 58,53,hot dog
54 | 59,54,pizza
55 | 60,55,donut
56 | 61,56,cake
57 | 62,57,chair
58 | 63,58,couch
59 | 64,59,potted plant
60 | 65,60,bed
61 | 67,61,dining table
62 | 70,62,toilet
63 | 72,63,tv
64 | 73,64,laptop
65 | 74,65,mouse
66 | 75,66,remote
67 | 76,67,keyboard
68 | 77,68,cell phone
69 | 78,69,microwave
70 | 79,70,oven
71 | 80,71,toaster
72 | 81,72,sink
73 | 82,73,refrigerator
74 | 84,74,book
75 | 85,75,clock
76 | 86,76,vase
77 | 87,77,scissors
78 | 88,78,teddy bear
79 | 89,79,hair drier
80 | 90,80,toothbrush
81 |
--------------------------------------------------------------------------------
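
Each line above is `original_coco_id,contiguous_id,name`. get_label_map() in data/coco.py keeps only the first two fields, so the sparse original category ids (note the gaps at 12, 26, 29, ...) collapse into a contiguous 1-80 range. A quick check (a sketch; the file path is assumed relative to the repo root):

```python
def get_label_map(label_file):  # as in data/coco.py
    label_map = {}
    for line in open(label_file, 'r'):
        ids = line.split(',')
        label_map[int(ids[0])] = int(ids[1])
    return label_map

label_map = get_label_map('data/coco/coco_labels.txt')
print(label_map[13], label_map[90])  # 12 80  (stop sign, toothbrush)
```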
/data/config.py:
--------------------------------------------------------------------------------
1 | # config.py
2 | import os.path
3 |
4 | # gets home dir cross platform
5 | HOME = os.path.expanduser("~")
6 |
7 | # for making bounding boxes pretty
8 | COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128),
9 | (0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128))
10 |
11 | MEANS = (104, 117, 123)
12 |
13 |
14 | voc300 = {
15 | 'num_classes': 21,
16 | 'lr_steps': (80000, 100000, 120000),
17 | 'max_iter': 120000,
18 | 'feature_maps': [38, 19, 10, 5, 3, 1],
19 | 'min_dim': 300,
20 | 'steps': [8, 16, 32, 64, 100, 300],
21 | 'min_sizes': [30, 60, 111, 162, 213, 264],
22 | 'max_sizes': [60, 111, 162, 213, 264, 315],
23 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
24 | 'variance': [0.1, 0.2],
25 | 'clip': True,
26 | 'name': 'VOC',
27 | }
28 | voc512 = {
29 | 'num_classes': 21,
30 | 'lr_steps': (80000, 100000, 120000),
31 | 'max_iter': 120000,
32 | 'feature_maps': [64, 32, 16, 8, 4, 2, 1],
33 | 'min_dim': 512,
34 | 'steps': [8, 16, 32, 64, 128, 256, 512],
35 | 'min_sizes': [35.84, 76.8, 153.6, 230.4, 307.2, 384.0, 460.8],
36 | 'max_sizes': [76.8, 153.6, 230.4, 307.2, 384.0, 460.8, 537.6],
37 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2], [2]],
38 | 'variance': [0.1, 0.2],
39 | 'clip': True,
40 | 'name': 'VOC',
41 | }
42 |
43 |
44 | coco = {
45 | 'num_classes': 81,
46 | 'lr_steps': (80000, 100000, 120000),
47 | 'max_iter': 120000,
48 | 'feature_maps': [38, 19, 10, 5, 3, 1],
49 | 'min_dim': 300,
50 | 'steps': [8, 16, 32, 64, 100, 300],
51 | 'min_sizes': [21, 45, 99, 153, 207, 261],
52 | 'max_sizes': [45, 99, 153, 207, 261, 315],
53 | 'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
54 | 'variance': [0.1, 0.2],
55 | 'clip': True,
56 | 'name': 'COCO',
57 | }
58 |
59 |
--------------------------------------------------------------------------------
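
These SSD300 settings pin down the number of default boxes: every cell of each feature map spawns the per-location box count from mbox['300'] in csd.py. A quick sanity check of the arithmetic (a sketch, not repo code):

```python
feature_maps = [38, 19, 10, 5, 3, 1]   # voc300['feature_maps'] / coco['feature_maps']
boxes_per_loc = [4, 6, 6, 6, 4, 4]     # mbox['300'] in csd.py
total = sum(f * f * b for f, b in zip(feature_maps, boxes_per_loc))
print(total)  # 8732, the canonical SSD300 prior count
```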
/data/example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/data/example.jpg
--------------------------------------------------------------------------------
/data/scripts/COCO2014.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | start=`date +%s`
4 |
5 | # handle optional download dir
6 | if [ -z "$1" ]
7 | then
8 | # navigate to ~/data
9 | echo "navigating to ~/data/ ..."
10 | mkdir -p ~/data
11 | cd ~/data/
12 | mkdir -p ./coco
13 | cd ./coco
14 | mkdir -p ./images
15 | mkdir -p ./annotations
16 | else
17 | # check if specified dir is valid
18 | if [ ! -d $1 ]; then
19 | echo $1 " is not a valid directory"
20 | exit 0
21 | fi
22 | echo "navigating to " $1 " ..."
23 | cd $1
24 | fi
25 |
26 | if [ ! -d images ]
27 | then
28 | mkdir -p ./images
29 | fi
30 |
31 | # Download the image data.
32 | cd ./images
33 | echo "Downloading MSCOCO train images ..."
34 | curl -LO http://images.cocodataset.org/zips/train2014.zip
35 | echo "Downloading MSCOCO val images ..."
36 | curl -LO http://images.cocodataset.org/zips/val2014.zip
37 |
38 | cd ../
39 | if [ ! -d annotations ]
40 | then
41 | mkdir -p ./annotations
42 | fi
43 |
44 | # Download the annotation data.
45 | cd ./annotations
46 | echo "Downloading MSCOCO train/val annotations ..."
47 | curl -LO http://images.cocodataset.org/annotations/annotations_trainval2014.zip
48 | echo "Finished downloading. Now extracting ..."
49 |
50 | # Unzip data
51 | echo "Extracting train images ..."
52 | unzip ../images/train2014.zip -d ../images
53 | echo "Extracting val images ..."
54 | unzip ../images/val2014.zip -d ../images
55 | echo "Extracting annotations ..."
56 | unzip ./annotations_trainval2014.zip
57 |
58 | echo "Removing zip files ..."
59 | rm ../images/train2014.zip
60 | rm ../images/val2014.zip
61 | rm ./annotations_trainval2014.zip
62 |
63 | echo "Creating trainval35k dataset..."
64 |
65 | # Download annotations json
66 | echo "Downloading trainval35k annotations from S3"
67 | curl -LO https://s3.amazonaws.com/amdegroot-datasets/instances_trainval35k.json.zip
68 |
69 | # combine train and val
70 | echo "Combining train and val images"
71 | mkdir ../images/trainval35k
72 | cd ../images/train2014
73 | find -maxdepth 1 -name '*.jpg' -exec cp -t ../trainval35k {} + # dir too large for cp
74 | cd ../val2014
75 | find -maxdepth 1 -name '*.jpg' -exec cp -t ../trainval35k {} +
76 |
77 |
78 | end=`date +%s`
79 | runtime=$((end-start))
80 |
81 | echo "Completed in " $runtime " seconds"
82 |
--------------------------------------------------------------------------------
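
Usage note for this script and the two VOC scripts below: the first argument is an optional, already-existing download directory, e.g. "sh data/scripts/COCO2014.sh /path/to/datasets"; with no argument the data is placed under ~/data.

--------------------------------------------------------------------------------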
/data/scripts/VOC2007.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Ellis Brown
3 |
4 | start=`date +%s`
5 |
6 | # handle optional download dir
7 | if [ -z "$1" ]
8 | then
9 | # navigate to ~/data
10 | echo "navigating to ~/data/ ..."
11 | mkdir -p ~/data
12 | cd ~/data/
13 | else
14 | # check if is valid directory
15 |   if [ ! -d "$1" ]; then
16 |     echo "$1" "is not a valid directory"
17 |     exit 1
18 |   fi
19 |   echo "navigating to" "$1" "..."
20 |   cd "$1"
21 | fi
22 |
23 | echo "Downloading VOC2007 trainval ..."
24 | # Download the data.
25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
26 | echo "Downloading VOC2007 test data ..."
27 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
28 | echo "Done downloading."
29 |
30 | # Extract data
31 | echo "Extracting trainval ..."
32 | tar -xvf VOCtrainval_06-Nov-2007.tar
33 | echo "Extracting test ..."
34 | tar -xvf VOCtest_06-Nov-2007.tar
35 | echo "removing tars ..."
36 | rm VOCtrainval_06-Nov-2007.tar
37 | rm VOCtest_06-Nov-2007.tar
38 |
39 | end=`date +%s`
40 | runtime=$((end-start))
41 |
42 | echo "Completed in" $runtime "seconds"
--------------------------------------------------------------------------------
/data/scripts/VOC2012.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Ellis Brown
3 |
4 | start=`date +%s`
5 |
6 | # handle optional download dir
7 | if [ -z "$1" ]
8 | then
9 | # navigate to ~/data
10 | echo "navigating to ~/data/ ..."
11 | mkdir -p ~/data
12 | cd ~/data/
13 | else
14 | # check if is valid directory
15 |   if [ ! -d "$1" ]; then
16 |     echo "$1" "is not a valid directory"
17 |     exit 1
18 |   fi
19 |   echo "navigating to" "$1" "..."
20 |   cd "$1"
21 | fi
22 |
23 | echo "Downloading VOC2012 trainval ..."
24 | # Download the data.
25 | curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
26 | echo "Done downloading."
27 |
28 |
29 | # Extract data
30 | echo "Extracting trainval ..."
31 | tar -xvf VOCtrainval_11-May-2012.tar
32 | echo "removing tar ..."
33 | rm VOCtrainval_11-May-2012.tar
34 |
35 | end=`date +%s`
36 | runtime=$((end-start))
37 |
38 | echo "Completed in" $runtime "seconds"
--------------------------------------------------------------------------------
/data/voc0712.py:
--------------------------------------------------------------------------------
1 | """VOC Dataset Classes
2 |
3 | Original author: Francisco Massa
4 | https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py
5 |
6 | Updated by: Ellis Brown, Max deGroot
7 | """
8 |
9 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
10 | #
11 | # This work is made available under the Nvidia Source Code License-NC.
12 | # To view a copy of this license, visit
13 | # https://github.com/NVlabs/AL-SSL/
14 |
15 |
16 | from .config import HOME
17 | import os.path as osp
18 | import sys
19 | import torch
20 | import torch.utils.data as data
21 | import cv2
22 | import numpy as np
23 | import matplotlib.pyplot as plt
24 | from utils.augmentations import jaccard_numpy
25 |
26 | if sys.version_info[0] == 2:
27 | import xml.etree.cElementTree as ET
28 | else:
29 | import xml.etree.ElementTree as ET
30 |
31 |
32 |
33 | VOC_CLASSES = ( # always index 0
34 | 'aeroplane', 'bicycle', 'bird', 'boat',
35 | 'bottle', 'bus', 'car', 'cat', 'chair',
36 | 'cow', 'diningtable', 'dog', 'horse',
37 | 'motorbike', 'person', 'pottedplant',
38 | 'sheep', 'sofa', 'train', 'tvmonitor')
39 |
40 | # note: if you used our download scripts, this should be right
41 | # VOC_ROOT = osp.join(HOME, "tmp/VOC0712/")
42 | VOC_ROOT = '/usr/wiss/elezi/data/VOC0712'
43 |
44 | class VOCAnnotationTransform(object):
45 | """Transforms a VOC annotation into a Tensor of bbox coords and label index
46 |     Initialized with a dictionary lookup of class names to indexes
47 |
48 | Arguments:
49 | class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
50 | (default: alphabetic indexing of VOC's 20 classes)
51 | keep_difficult (bool, optional): keep difficult instances or not
52 | (default: False)
53 | height (int): height
54 | width (int): width
55 | """
56 |
57 | def __init__(self, class_to_ind=None, keep_difficult=False):
58 | self.class_to_ind = class_to_ind or dict(
59 | zip(VOC_CLASSES, range(len(VOC_CLASSES))))
60 | self.keep_difficult = keep_difficult
61 |
62 | def __call__(self, target, width, height):
63 | """
64 | Arguments:
65 | target (annotation) : the target annotation to be made usable
66 | will be an ET.Element
67 | Returns:
68 | a list containing lists of bounding boxes [bbox coords, class name]
69 | """
70 | res = []
71 | for obj in target.iter('object'):
72 | difficult = int(obj.find('difficult').text) == 1
73 | if not self.keep_difficult and difficult:
74 | continue
75 | name = obj.find('name').text.lower().strip()
76 | bbox = obj.find('bndbox')
77 |
78 | pts = ['xmin', 'ymin', 'xmax', 'ymax']
79 | bndbox = []
80 | for i, pt in enumerate(pts):
81 | cur_pt = int(bbox.find(pt).text) - 1
82 | # scale height or width
83 | cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
84 | bndbox.append(cur_pt)
85 | label_idx = self.class_to_ind[name]
86 | bndbox.append(label_idx)
87 | res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind]
88 | # img_id = target.find('filename').text[:-4]
89 |
90 | return res # [[xmin, ymin, xmax, ymax, label_ind], ... ]
91 |
92 |
93 | class VOCDetection(data.Dataset):
94 | """VOC Detection Dataset Object
95 |
96 | input is image, target is annotation
97 |
98 | Arguments:
99 | root (string): filepath to VOCdevkit folder.
100 |         image_sets (list): (year, image set) tuples to load,
101 | transform (callable, optional): transformation to perform on the
102 | input image
103 | target_transform (callable, optional): transformation to perform on the
104 | target `annotation`
105 | (eg: take in caption string, return tensor of word indices)
106 | dataset_name (string, optional): which dataset to load
107 | (default: 'VOC2007')
108 | """
109 | def __init__(self, root, supervised_indices=None,
110 | image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
111 | transform=None, target_transform=VOCAnnotationTransform(),
112 | dataset_name='VOC0712', pseudo_labels={}, bounding_box_dict={}):
113 | self.root = root
114 | self.image_set = image_sets
115 | self.transform = transform
116 | self.target_transform = target_transform
117 | self.name = dataset_name
118 | self._annopath = osp.join('%s', 'Annotations', '%s.xml')
119 | self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
120 | self.ids = list()
121 | self.supervised_indices = supervised_indices
122 | self.pseudo_labels = pseudo_labels
123 | self.pseudo_labels_indices = self.get_pseudo_label_indices()
124 | self.bounding_box_dict = bounding_box_dict
125 | self.bounding_box_indices = self.get_bounding_box_indices()
126 | for (year, name) in image_sets:
127 | rootpath = osp.join(self.root, 'VOC' + year)
128 | for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
129 | self.ids.append((rootpath, line.strip()))
130 | self.class_to_ind = dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))
131 |
132 | def __getitem__(self, index):
133 | im, gt, h, w, semi = self.pull_item(index)
134 | return im, gt, semi
135 |
136 | def __len__(self):
137 | return len(self.ids)
138 |
139 | def pull_item(self, index):
140 | img_id = self.ids[index]
141 | colors = plt.cm.hsv(np.linspace(0, 1, 21)).tolist()
142 |
143 | target = ET.parse(self._annopath % img_id).getroot()
144 | img = cv2.imread(self._imgpath % img_id)
145 | height, width, channels = img.shape
146 |
147 | if self.target_transform is not None:
148 | if index in self.bounding_box_indices:
149 | target_real = self.target_transform(target, width, height)
150 | target = self.target_transform_bounding_box(self.bounding_box_dict, index, target_real, width, height)
151 |
152 | else:
153 | target = self.target_transform(target, width, height)
154 |
155 | if self.transform is not None:
156 | target = np.array(target)
157 | img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
158 |
159 | # to rgb
160 | img = img[:, :, (2, 1, 0)]
161 | # img = img.transpose(2, 0, 1)
162 | target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
163 |
164 |         if self.supervised_indices is not None:
165 | if index in self.supervised_indices:
166 | semi = np.array([1])
167 | elif index in self.pseudo_labels_indices or index in self.bounding_box_indices:
168 | semi = np.array([2])
169 | # elif index in self.bounding_box_indices:
170 | # semi = np.array([2])
171 | # print(semi)
172 | else:
173 | semi = np.array([0])
174 | target = np.zeros([1,5])
175 | else:
176 | # it does not matter
177 | semi = np.array([0])
178 |
179 | return torch.from_numpy(img).permute(2, 0, 1), target, height, width, semi
180 |
181 | def pull_image(self, index):
182 |         '''Returns the original image at index as an OpenCV BGR ndarray
183 |
184 | Note: not using self.__getitem__(), as any transformations passed in
185 | could mess up this functionality.
186 |
187 | Argument:
188 | index (int): index of img to show
189 | Return:
190 |             numpy ndarray (H, W, C) in BGR order
191 | '''
192 | img_id = self.ids[index]
193 | return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
194 |
195 | def pull_anno(self, index):
196 | '''Returns the original annotation of image at index
197 |
198 | Note: not using self.__getitem__(), as any transformations passed in
199 | could mess up this functionality.
200 |
201 | Argument:
202 | index (int): index of img to get annotation of
203 | Return:
204 | list: [img_id, [(label, bbox coords),...]]
205 | eg: ('001718', [('dog', (96, 13, 438, 332))])
206 | '''
207 | img_id = self.ids[index]
208 | anno = ET.parse(self._annopath % img_id).getroot()
209 | gt = self.target_transform(anno, 1, 1)
210 | return img_id[1], gt
211 |
212 | def pull_pseudo_anno(self, index):
213 | '''Returns the original annotation of image at index
214 |
215 | Note: not using self.__getitem__(), as any transformations passed in
216 | could mess up this functionality.
217 |
218 | Argument:
219 | index (int): index of img to get annotation of
220 | Return:
221 | list: [img_id, [(label, bbox coords),...]]
222 | eg: ('001718', [('dog', (96, 13, 438, 332))])
223 | '''
224 | return self.pseudo_labels[index]
225 |
226 | def pull_tensor(self, index):
227 | '''Returns the original image at an index in tensor form
228 |
229 | Note: not using self.__getitem__(), as any transformations passed in
230 | could mess up this functionality.
231 |
232 | Argument:
233 | index (int): index of img to show
234 | Return:
235 | tensorized version of img, squeezed
236 | '''
237 | return torch.Tensor(self.pull_image(index)).unsqueeze_(0)
238 |
239 | def get_pseudo_label_indices(self):
240 | pseudo_label_indices = []
241 | if self.pseudo_labels is not None:
242 | for key in self.pseudo_labels:
243 | pseudo_label_indices.append(key)
244 | return pseudo_label_indices
245 | else:
246 | return None
247 |
248 | def get_bounding_box_indices(self):
249 | bounding_box_indices = []
250 | for key in self.bounding_box_dict:
251 | bounding_box_indices.append(key)
252 | return bounding_box_indices
253 |
254 | def target_transform_pseudo_label(self, pseudo_labels, index):
255 | # height, width = pseudo_labels[index][0][3], pseudo_labels[index][0][4]
256 | height, width = pseudo_labels[index][0][4], pseudo_labels[index][0][3]
257 | # height, width = 300, 300
258 | res = []
259 | for i in range(len(pseudo_labels[index])):
260 | pts = [pseudo_labels[index][i][5], pseudo_labels[index][i][6], pseudo_labels[index][i][7], pseudo_labels[index][i][8]]
261 | name = pseudo_labels[index][i][2]
262 | pts[0] /= height
263 | pts[2] /= height
264 | pts[1] /= width
265 | pts[3] /= width
266 | for i in range(len(pts)):
267 | if pts[i] < 0:
268 | pts[i] = 0
269 | if pts[0] > height:
270 | pts[0] = height
271 | if pts[2] > height:
272 | pts[2] = height
273 | label_idx = float(self.class_to_ind[name])
274 | pts.append(label_idx)
275 | res.append(pts)
276 | return res
277 |
278 | def target_transform_bounding_box(self, bounding_box_dict, index, target_real, width, height):
279 | predictions = np.array(bounding_box_dict[index])
280 | predictions = predictions[:, :-1]
281 | target_real_numpy = np.array(target_real)
282 | target_real_numpy = target_real_numpy[:, :-1]
283 | iou_all = np.zeros((predictions.shape[0], len(target_real)))
284 | for i in range(len(target_real)):
285 | iou_all[:, i] = jaccard_numpy(predictions, target_real_numpy[i])
286 | max_val = 2
287 |         # print()
288 |         # print(iou_all)
289 |         # print()
290 |
291 | targets_intersected = []
292 |
293 | # otherwise get the value
294 | while max_val > 0:
295 | max_val = np.max(iou_all)
296 | if max_val > 0:
297 | argmax_val = np.where(iou_all == np.amax(iou_all))
298 | iou_all[argmax_val[0], :] = np.zeros((1, iou_all.shape[1]))
299 | iou_all[:, argmax_val[1]] = np.zeros((iou_all.shape[0], 1))
300 | # get the GT from target_real
301 | targets_intersected.append(target_real[int(argmax_val[1])])
302 |
303 | else:
304 | # get a random one from target real
305 | targets_intersected.append(target_real[0])
306 |
307 | return targets_intersected
308 |
309 |
--------------------------------------------------------------------------------
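
A minimal sketch (not part of the repo; assumes the repo root is on PYTHONPATH and its dependencies are installed) of what VOCAnnotationTransform returns: box coordinates scaled into [0, 1] by image width and height, with the alphabetic class index appended. The annotation path below is hypothetical.

    import xml.etree.ElementTree as ET
    from data.voc0712 import VOCAnnotationTransform

    transform = VOCAnnotationTransform()
    root = ET.parse('VOCdevkit/VOC2007/Annotations/000001.xml').getroot()  # hypothetical path
    target = transform(root, width=353, height=500)
    # -> [[xmin, ymin, xmax, ymax, label_idx], ...] with all coords in [0, 1]

--------------------------------------------------------------------------------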
/eval.py:
--------------------------------------------------------------------------------
1 | """Adapted from:
2 | @longcw faster_rcnn_pytorch: https://github.com/longcw/faster_rcnn_pytorch
3 | @rbgirshick py-faster-rcnn https://github.com/rbgirshick/py-faster-rcnn
4 | Licensed under The MIT License [see LICENSE for details]
5 | """
6 |
7 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
8 | #
9 | # This work is made available under the Nvidia Source Code License-NC.
10 | # To view a copy of this license, visit
11 | # https://github.com/NVlabs/AL-SSL/
12 |
13 | from __future__ import print_function
14 | import torch
15 | import torch.nn as nn
16 | import torch.backends.cudnn as cudnn
17 | from torch.autograd import Variable
18 | from data import VOC_ROOT, VOCAnnotationTransform, VOCDetection, BaseTransform
19 | from utils.augmentations import SSDAugmentation
20 | from data import VOC_CLASSES as labelmap
21 | import torch.utils.data as data
22 |
23 | from ssd import build_ssd
24 |
25 | import sys
26 | import os
27 | import time
28 | import argparse
29 | import numpy as np
30 | import pickle
31 | import cv2
32 | from collections import OrderedDict
33 |
34 | if sys.version_info[0] == 2:
35 | import xml.etree.cElementTree as ET
36 | else:
37 | import xml.etree.ElementTree as ET
38 |
39 |
40 | def str2bool(v):
41 | return v.lower() in ("yes", "true", "t", "1")
42 |
43 |
44 | parser = argparse.ArgumentParser(
45 | description='Single Shot MultiBox Detector Evaluation')
46 | parser.add_argument('--trained_model', default='weights_voc_code/120000combined_id_2_pl_threshold_0.99_labeled_set_3011_.pth',
47 | type=str, help='Trained state_dict file path to open')
48 | parser.add_argument('--save_folder', default='eval/', type=str,
49 | help='File path to save results')
50 | parser.add_argument('--confidence_threshold', default=0.01, type=float,
51 | help='Detection confidence threshold')
52 | parser.add_argument('--top_k', default=5, type=int,
53 | help='Further restrict the number of predictions to parse')
54 | parser.add_argument('--cuda', default=True, type=str2bool,
55 | help='Use cuda to train model')
56 | parser.add_argument('--voc_root', default='../../data/VOC0712/',
57 | help='Location of VOC root directory')
58 | parser.add_argument('--cleanup', default=True, type=str2bool,
59 | help='Cleanup and remove results files following eval')
60 |
61 | args = parser.parse_args()
62 |
63 | if not os.path.exists(args.save_folder):
64 | os.mkdir(args.save_folder)
65 |
66 | if torch.cuda.is_available():
67 | if args.cuda:
68 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
69 | if not args.cuda:
70 | print("WARNING: It looks like you have a CUDA device, but aren't using \
71 | CUDA. Run with --cuda for optimal eval speed.")
72 | torch.set_default_tensor_type('torch.FloatTensor')
73 | else:
74 | torch.set_default_tensor_type('torch.FloatTensor')
75 |
76 | annopath = os.path.join(args.voc_root, 'VOC2007', 'Annotations', '%s.xml')
77 | imgpath = os.path.join(args.voc_root, 'VOC2007', 'JPEGImages', '%s.jpg')
78 | imgsetpath = os.path.join(args.voc_root, 'VOC2007', 'ImageSets',
79 | 'Main', '{:s}.txt')
80 | YEAR = '2007'
81 | devkit_path = args.voc_root + 'VOC' + YEAR
82 | dataset_mean = (104, 117, 123)
83 | set_type = 'test'
84 |
85 |
86 | class Timer(object):
87 | """A simple timer."""
88 | def __init__(self):
89 | self.total_time = 0.
90 | self.calls = 0
91 | self.start_time = 0.
92 | self.diff = 0.
93 | self.average_time = 0.
94 |
95 | def tic(self):
96 |         # using time.time instead of time.clock because time.clock
97 |         # does not normalize for multithreading
98 | self.start_time = time.time()
99 |
100 | def toc(self, average=True):
101 | self.diff = time.time() - self.start_time
102 | self.total_time += self.diff
103 | self.calls += 1
104 | self.average_time = self.total_time / self.calls
105 | if average:
106 | return self.average_time
107 | else:
108 | return self.diff
109 |
110 |
111 | def parse_rec(filename):
112 | """ Parse a PASCAL VOC xml file """
113 | tree = ET.parse(filename)
114 | objects = []
115 | for obj in tree.findall('object'):
116 | obj_struct = {}
117 | obj_struct['name'] = obj.find('name').text
118 | obj_struct['pose'] = obj.find('pose').text
119 | obj_struct['truncated'] = int(obj.find('truncated').text)
120 | obj_struct['difficult'] = int(obj.find('difficult').text)
121 | bbox = obj.find('bndbox')
122 | obj_struct['bbox'] = [int(bbox.find('xmin').text) - 1,
123 | int(bbox.find('ymin').text) - 1,
124 | int(bbox.find('xmax').text) - 1,
125 | int(bbox.find('ymax').text) - 1]
126 | objects.append(obj_struct)
127 |
128 | return objects
129 |
130 |
131 | def get_output_dir(name, phase):
132 | """Return the directory where experimental artifacts are placed.
133 | If the directory does not exist, it is created.
134 | A canonical path is built using the name from an imdb and a network
135 | (if not None).
136 | """
137 | filedir = os.path.join(name, phase)
138 | if not os.path.exists(filedir):
139 | os.makedirs(filedir)
140 | return filedir
141 |
142 |
143 | def get_voc_results_file_template(image_set, cls):
144 | # VOCdevkit/VOC2007/results/det_test_aeroplane.txt
145 | filename = 'det_' + image_set + '_%s.txt' % (cls)
146 | filedir = os.path.join(devkit_path, 'results')
147 | if not os.path.exists(filedir):
148 | os.makedirs(filedir)
149 | path = os.path.join(filedir, filename)
150 | return path
151 |
152 |
153 | def write_voc_results_file(all_boxes, dataset):
154 | for cls_ind, cls in enumerate(labelmap):
155 | print('Writing {:s} VOC results file'.format(cls))
156 | filename = get_voc_results_file_template(set_type, cls)
157 | with open(filename, 'wt') as f:
158 | for im_ind, index in enumerate(dataset.ids):
159 | dets = all_boxes[cls_ind+1][im_ind]
160 |                 if len(dets) == 0:
161 | continue
162 | # the VOCdevkit expects 1-based indices
163 | for k in range(dets.shape[0]):
164 | f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
165 | format(index[1], dets[k, -1],
166 | dets[k, 0] + 1, dets[k, 1] + 1,
167 | dets[k, 2] + 1, dets[k, 3] + 1))
168 |
169 |
170 | def do_python_eval(output_dir='output', use_07=True):
171 | cachedir = os.path.join(devkit_path, 'annotations_cache')
172 | aps = []
173 | # The PASCAL VOC metric changed in 2010
174 | use_07_metric = use_07
175 | print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
176 | if not os.path.isdir(output_dir):
177 | os.mkdir(output_dir)
178 | for i, cls in enumerate(labelmap):
179 | filename = get_voc_results_file_template(set_type, cls)
180 | rec, prec, ap = voc_eval(
181 | filename, annopath, imgsetpath.format(set_type), cls, cachedir,
182 | ovthresh=0.5, use_07_metric=use_07_metric)
183 | aps += [ap]
184 | print('AP for {} = {:.4f}'.format(cls, ap))
185 | with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
186 | pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
187 | print('Mean AP = {:.4f}'.format(np.mean(aps)))
188 | print('~~~~~~~~')
189 | print('Results:')
190 | for ap in aps:
191 | print('{:.3f}'.format(ap))
192 | print('{:.3f}'.format(np.mean(aps)))
193 | print('~~~~~~~~')
194 | print('')
195 | print('--------------------------------------------------------------')
196 | print('Results computed with the **unofficial** Python eval code.')
197 | print('Results should be very close to the official MATLAB eval code.')
198 | print('--------------------------------------------------------------')
199 | return aps
200 |
201 |
202 | def voc_ap(rec, prec, use_07_metric=True):
203 | """ ap = voc_ap(rec, prec, [use_07_metric])
204 | Compute VOC AP given precision and recall.
205 | If use_07_metric is true, uses the
206 | VOC 07 11 point method (default:True).
207 | """
208 | if use_07_metric:
209 | # 11 point metric
210 | ap = 0.
211 | for t in np.arange(0., 1.1, 0.1):
212 | if np.sum(rec >= t) == 0:
213 | p = 0
214 | else:
215 | p = np.max(prec[rec >= t])
216 | ap = ap + p / 11.
217 | else:
218 | # correct AP calculation
219 | # first append sentinel values at the end
220 | mrec = np.concatenate(([0.], rec, [1.]))
221 | mpre = np.concatenate(([0.], prec, [0.]))
222 |
223 | # compute the precision envelope
224 | for i in range(mpre.size - 1, 0, -1):
225 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
226 |
227 | # to calculate area under PR curve, look for points
228 | # where X axis (recall) changes value
229 | i = np.where(mrec[1:] != mrec[:-1])[0]
230 |
231 | # and sum (\Delta recall) * prec
232 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
233 | return ap
234 |
235 |
236 | def voc_eval(detpath,
237 | annopath,
238 | imagesetfile,
239 | classname,
240 | cachedir,
241 | ovthresh=0.5,
242 | use_07_metric=True):
243 | """rec, prec, ap = voc_eval(detpath,
244 | annopath,
245 | imagesetfile,
246 | classname,
247 | [ovthresh],
248 | [use_07_metric])
249 | Top level function that does the PASCAL VOC evaluation.
250 | detpath: Path to detections
251 | detpath.format(classname) should produce the detection results file.
252 | annopath: Path to annotations
253 | annopath.format(imagename) should be the xml annotations file.
254 | imagesetfile: Text file containing the list of images, one image per line.
255 | classname: Category name (duh)
256 | cachedir: Directory for caching the annotations
257 | [ovthresh]: Overlap threshold (default = 0.5)
258 | [use_07_metric]: Whether to use VOC07's 11 point AP computation
259 | (default True)
260 | """
261 | # assumes detections are in detpath.format(classname)
262 | # assumes annotations are in annopath.format(imagename)
263 | # assumes imagesetfile is a text file with each line an image name
264 | # cachedir caches the annotations in a pickle file
265 | # first load gt
266 | if not os.path.isdir(cachedir):
267 | os.mkdir(cachedir)
268 | cachefile = os.path.join(cachedir, 'annots.pkl')
269 | # read list of images
270 | with open(imagesetfile, 'r') as f:
271 | lines = f.readlines()
272 | imagenames = [x.strip() for x in lines]
273 | if not os.path.isfile(cachefile):
274 | # load annots
275 | recs = {}
276 | for i, imagename in enumerate(imagenames):
277 | recs[imagename] = parse_rec(annopath % (imagename))
278 | if i % 100 == 0:
279 | print('Reading annotation for {:d}/{:d}'.format(
280 | i + 1, len(imagenames)))
281 | # save
282 | print('Saving cached annotations to {:s}'.format(cachefile))
283 | with open(cachefile, 'wb') as f:
284 | pickle.dump(recs, f)
285 | else:
286 | # load
287 | with open(cachefile, 'rb') as f:
288 | recs = pickle.load(f)
289 |
290 | # extract gt objects for this class
291 | class_recs = {}
292 | npos = 0
293 | for imagename in imagenames:
294 | R = [obj for obj in recs[imagename] if obj['name'] == classname]
295 | bbox = np.array([x['bbox'] for x in R])
296 |         difficult = np.array([x['difficult'] for x in R]).astype(bool)
297 | det = [False] * len(R)
298 | npos = npos + sum(~difficult)
299 | class_recs[imagename] = {'bbox': bbox,
300 | 'difficult': difficult,
301 | 'det': det}
302 |
303 | # read dets
304 | detfile = detpath.format(classname)
305 | with open(detfile, 'r') as f:
306 | lines = f.readlines()
307 |     if any(lines):
308 |
309 | splitlines = [x.strip().split(' ') for x in lines]
310 | image_ids = [x[0] for x in splitlines]
311 | confidence = np.array([float(x[1]) for x in splitlines])
312 | BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
313 |
314 | # sort by confidence
315 | sorted_ind = np.argsort(-confidence)
316 | sorted_scores = np.sort(-confidence)
317 | BB = BB[sorted_ind, :]
318 | image_ids = [image_ids[x] for x in sorted_ind]
319 |
320 | # go down dets and mark TPs and FPs
321 | nd = len(image_ids)
322 | tp = np.zeros(nd)
323 | fp = np.zeros(nd)
324 | for d in range(nd):
325 | R = class_recs[image_ids[d]]
326 | bb = BB[d, :].astype(float)
327 | ovmax = -np.inf
328 | BBGT = R['bbox'].astype(float)
329 | if BBGT.size > 0:
330 | # compute overlaps
331 | # intersection
332 | ixmin = np.maximum(BBGT[:, 0], bb[0])
333 | iymin = np.maximum(BBGT[:, 1], bb[1])
334 | ixmax = np.minimum(BBGT[:, 2], bb[2])
335 | iymax = np.minimum(BBGT[:, 3], bb[3])
336 | iw = np.maximum(ixmax - ixmin, 0.)
337 | ih = np.maximum(iymax - iymin, 0.)
338 | inters = iw * ih
339 | uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) +
340 | (BBGT[:, 2] - BBGT[:, 0]) *
341 | (BBGT[:, 3] - BBGT[:, 1]) - inters)
342 | overlaps = inters / uni
343 | ovmax = np.max(overlaps)
344 | jmax = np.argmax(overlaps)
345 |
346 | if ovmax > ovthresh:
347 | if not R['difficult'][jmax]:
348 | if not R['det'][jmax]:
349 | tp[d] = 1.
350 | R['det'][jmax] = 1
351 | else:
352 | fp[d] = 1.
353 | else:
354 | fp[d] = 1.
355 |
356 | # compute precision recall
357 | fp = np.cumsum(fp)
358 | tp = np.cumsum(tp)
359 | rec = tp / float(npos)
360 | # avoid divide by zero in case the first detection matches a difficult
361 | # ground truth
362 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
363 | ap = voc_ap(rec, prec, use_07_metric)
364 | else:
365 | rec = -1.
366 | prec = -1.
367 | ap = -1.
368 |
369 | return rec, prec, ap
370 |
371 |
372 | def test_net(save_folder, net, cuda, dataset, transform, top_k,
373 | im_size=300, thresh=0.05):
374 | num_images = len(dataset)
375 |
376 | all_boxes = [[[] for _ in range(num_images)]
377 | for _ in range(len(labelmap)+1)]
378 |
379 | # timers
380 | _t = {'im_detect': Timer(), 'misc': Timer()}
381 | output_dir = get_output_dir('ssd300_120000', set_type)
382 | det_file = os.path.join(output_dir, 'detections.pkl')
383 |
384 | for i in range(num_images):
385 | im, gt, h, w, _ = dataset.pull_item(i)
386 |
387 | x = Variable(im.unsqueeze(0))
388 | if args.cuda:
389 | x = x.cuda()
390 | _t['im_detect'].tic()
391 | detections = net(x).data
392 | detect_time = _t['im_detect'].toc(average=False)
393 |
394 | # skip j = 0, because it's the background class
395 | for j in range(1, detections.size(1)):
396 | dets = detections[0, j, :]
397 | mask = dets[:, 0].gt(0.).expand(5, dets.size(0)).t()
398 | dets = torch.masked_select(dets, mask).view(-1, 5)
399 | if dets.dim() == 0:
400 | continue
401 | boxes = dets[:, 1:]
402 | boxes[:, 0] *= w
403 | boxes[:, 2] *= w
404 | boxes[:, 1] *= h
405 | boxes[:, 3] *= h
406 | scores = dets[:, 0].cpu().numpy()
407 | cls_dets = np.hstack((boxes.cpu().numpy(),
408 | scores[:, np.newaxis])).astype(np.float32,
409 | copy=False)
410 | all_boxes[j][i] = cls_dets
411 |
412 | print('im_detect: {:d}/{:d} {:.3f}s'.format(i + 1,
413 | num_images, detect_time))
414 |
415 | with open(det_file, 'wb') as f:
416 | pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)
417 |
418 | print('Evaluating detections')
419 | res = evaluate_detections(all_boxes, output_dir, dataset)
420 | return res
421 |
422 |
423 | def evaluate_detections(box_list, output_dir, dataset):
424 | write_voc_results_file(box_list, dataset)
425 | res = do_python_eval(output_dir)
426 | return res
427 |
428 |
429 | if __name__ == '__main__':
430 | """
431 | student_model_dict = {}
432 | for key, value in model_student.state_dict().items():
433 | student_model_dict[key] = value
434 |
435 | new_teacher_dict = OrderedDict()
436 | for key, value in model_teacher.state_dict().items():
437 | if key in student_model_dict.keys():
438 | new_teacher_dict[key] = (
439 | student_model_dict[key] * (1 - keep_rate) + value * keep_rate)"""
440 |
441 | # load net
442 | num_classes = len(labelmap) + 1 # +1 for background
443 | net = build_ssd('test', 300, num_classes) # initialize SSD
444 | net = nn.DataParallel(net)
445 | fi_write = open("results.txt", "a")
446 | for key in net.state_dict():
447 | print(key)
448 |
449 | list_of_folders = ['weights'] # folder where the saved networks are
450 | for folder in list_of_folders:
451 | list_nets = os.listdir(folder)
452 | for nnn in sorted(list_nets):
453 |
454 | net.load_state_dict(torch.load(os.path.join(folder, nnn)))
455 | net.eval()
456 | print('Finished loading model!')
457 | # load data
458 | dataset = VOCDetection(args.voc_root, image_sets=[('2007', 'test')],
459 | transform=BaseTransform(300, dataset_mean),
460 | target_transform=VOCAnnotationTransform())
461 | if args.cuda:
462 | net = net.cuda()
463 | cudnn.benchmark = True
464 | # evaluation
465 | m_ap = test_net(args.save_folder, net, args.cuda, dataset,
466 | BaseTransform(300, dataset_mean), args.top_k, 300,
467 | thresh=args.confidence_threshold)
468 | print(m_ap)
469 |             fi_write.write(folder + '_____' + nnn + ": " + str(np.mean(m_ap)))
470 | fi_write.write('\n')
471 | fi_write.close()
472 |
473 |
--------------------------------------------------------------------------------
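
A worked toy example (computed by hand rather than imported, since eval.py parses command-line arguments at module scope) of the 11-point interpolation voc_ap applies when use_07_metric is True: precision is taken as the maximum precision at recall >= t for t = 0.0, 0.1, ..., 1.0, then averaged over the 11 thresholds.

    import numpy as np

    rec = np.array([0.2, 0.4, 0.6])
    prec = np.array([1.0, 0.8, 0.6])
    # t in {0.0, 0.1, 0.2}: max prec where rec >= t is 1.0
    # t in {0.3, 0.4}:      0.8
    # t in {0.5, 0.6}:      0.6
    # t in {0.7, ..., 1.0}: no rec >= t, so p = 0
    ap = (3 * 1.0 + 2 * 0.8 + 2 * 0.6) / 11.0  # ~0.527, matching voc_ap

--------------------------------------------------------------------------------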
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .functions import *
2 | from .modules import *
3 |
--------------------------------------------------------------------------------
/layers/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/box_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/__pycache__/box_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/box_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/__pycache__/box_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/box_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/__pycache__/box_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/box_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 |
4 |
5 | def point_form(boxes):
6 | """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
7 | representation for comparison to point form ground truth data.
8 | Args:
9 | boxes: (tensor) center-size default boxes from priorbox layers.
10 | Return:
11 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
12 | """
13 | return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin
14 | boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax
15 |
16 |
17 | def center_size(boxes):
18 | """ Convert prior_boxes to (cx, cy, w, h)
19 | representation for comparison to center-size form ground truth data.
20 | Args:
21 | boxes: (tensor) point_form boxes
22 | Return:
23 |         boxes: (tensor) Converted (cx, cy, w, h) form of boxes.
24 |     """
25 |     return torch.cat(((boxes[:, 2:] + boxes[:, :2])/2,  # cx, cy
26 |                       boxes[:, 2:] - boxes[:, :2]), 1)  # w, h
27 |
28 |
29 | def intersect(box_a, box_b):
30 | """ We resize both tensors to [A,B,2] without new malloc:
31 | [A,2] -> [A,1,2] -> [A,B,2]
32 | [B,2] -> [1,B,2] -> [A,B,2]
33 | Then we compute the area of intersect between box_a and box_b.
34 | Args:
35 | box_a: (tensor) bounding boxes, Shape: [A,4].
36 | box_b: (tensor) bounding boxes, Shape: [B,4].
37 | Return:
38 | (tensor) intersection area, Shape: [A,B].
39 | """
40 | A = box_a.size(0)
41 | B = box_b.size(0)
42 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
43 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
44 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
45 | box_b[:, :2].unsqueeze(0).expand(A, B, 2))
46 | inter = torch.clamp((max_xy - min_xy), min=0)
47 | return inter[:, :, 0] * inter[:, :, 1]
48 |
49 |
50 | def jaccard(box_a, box_b):
51 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
52 | is simply the intersection over union of two boxes. Here we operate on
53 | ground truth boxes and default boxes.
54 | E.g.:
55 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
56 | Args:
57 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
58 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
59 | Return:
60 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
61 | """
62 | inter = intersect(box_a, box_b)
63 | area_a = ((box_a[:, 2]-box_a[:, 0]) *
64 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
65 | area_b = ((box_b[:, 2]-box_b[:, 0]) *
66 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
67 | union = area_a + area_b - inter
68 | return inter / union # [A,B]
69 |
70 |
71 | def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx):
72 | """Match each prior box with the ground truth box of the highest jaccard
73 | overlap, encode the bounding boxes, then return the matched indices
74 | corresponding to both confidence and location preds.
75 | Args:
76 |         threshold: (float) The overlap threshold used when matching boxes.
77 |         truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
78 | priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
79 | variances: (tensor) Variances corresponding to each prior coord,
80 | Shape: [num_priors, 4].
81 | labels: (tensor) All the class labels for the image, Shape: [num_obj].
82 |         loc_t: (tensor) Tensor to be filled w/ encoded location targets.
83 | conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
84 | idx: (int) current batch index
85 | Return:
86 | The matched indices corresponding to 1)location and 2)confidence preds.
87 | """
88 | # jaccard index
89 | overlaps = jaccard(
90 | truths,
91 | point_form(priors)
92 | )
93 | # (Bipartite Matching)
94 | # [1,num_objects] best prior for each ground truth
95 | best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
96 | # [1,num_priors] best ground truth for each prior
97 | best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
98 | best_truth_idx.squeeze_(0)
99 | best_truth_overlap.squeeze_(0)
100 | best_prior_idx.squeeze_(1)
101 | best_prior_overlap.squeeze_(1)
102 | best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior
103 | # TODO refactor: index best_prior_idx with long tensor
104 | # ensure every gt matches with its prior of max overlap
105 | for j in range(best_prior_idx.size(0)):
106 | best_truth_idx[best_prior_idx[j]] = j
107 | matches = truths[best_truth_idx] # Shape: [num_priors,4]
108 | conf = labels[best_truth_idx] + 1 # Shape: [num_priors]
109 | conf[best_truth_overlap < threshold] = 0 # label as background
110 | loc = encode(matches, priors, variances)
111 | loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
112 | conf_t[idx] = conf # [num_priors] top class label for each prior
113 |
114 |
115 | def encode(matched, priors, variances):
116 | """Encode the variances from the priorbox layers into the ground truth boxes
117 | we have matched (based on jaccard overlap) with the prior boxes.
118 | Args:
119 | matched: (tensor) Coords of ground truth for each prior in point-form
120 | Shape: [num_priors, 4].
121 | priors: (tensor) Prior boxes in center-offset form
122 | Shape: [num_priors,4].
123 | variances: (list[float]) Variances of priorboxes
124 | Return:
125 | encoded boxes (tensor), Shape: [num_priors, 4]
126 | """
127 |
128 | # dist b/t match center and prior's center
129 | g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
130 | # encode variance
131 | g_cxcy /= (variances[0] * priors[:, 2:])
132 | # match wh / prior wh
133 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
134 | g_wh = torch.log(g_wh) / variances[1]
135 | # return target for smooth_l1_loss
136 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
137 |
138 |
139 | # Adapted from https://github.com/Hakuyume/chainer-ssd
140 | def decode(loc, priors, variances):
141 | """Decode locations from predictions using priors to undo
142 | the encoding we did for offset regression at train time.
143 | Args:
144 | loc (tensor): location predictions for loc layers,
145 | Shape: [num_priors,4]
146 | priors (tensor): Prior boxes in center-offset form.
147 | Shape: [num_priors,4].
148 | variances: (list[float]) Variances of priorboxes
149 | Return:
150 | decoded bounding box predictions
151 | """
152 |
153 | boxes = torch.cat((
154 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
155 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
156 | boxes[:, :2] -= boxes[:, 2:] / 2
157 | boxes[:, 2:] += boxes[:, :2]
158 | return boxes
159 |
160 |
161 | def log_sum_exp(x):
162 |     """Utility function for computing log_sum_exp in a numerically stable
163 |     way. This is used to compute the unaveraged confidence loss across
164 |     all examples in a batch.
165 | Args:
166 | x (Variable(tensor)): conf_preds from conf layers
167 | """
168 | x_max = x.data.max()
169 | return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max
170 |
171 |
172 | # Original author: Francisco Massa:
173 | # https://github.com/fmassa/object-detection.torch
174 | # Ported to PyTorch by Max deGroot (02/01/2017)
175 | def nms(boxes, scores, overlap=0.5, top_k=200):
176 | """Apply non-maximum suppression at test time to avoid detecting too many
177 | overlapping bounding boxes for a given object.
178 | Args:
179 | boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
180 |         scores: (tensor) The class pred scores for the img, Shape: [num_priors].
181 | overlap: (float) The overlap thresh for suppressing unnecessary boxes.
182 | top_k: (int) The Maximum number of box preds to consider.
183 | Return:
184 | The indices of the kept boxes with respect to num_priors.
185 | """
186 |
187 | keep = scores.new(scores.size(0)).zero_().long()
188 | if boxes.numel() == 0:
189 | return keep
190 | x1 = boxes[:, 0]
191 | y1 = boxes[:, 1]
192 | x2 = boxes[:, 2]
193 | y2 = boxes[:, 3]
194 | area = torch.mul(x2 - x1, y2 - y1)
195 | v, idx = scores.sort(0) # sort in ascending order
196 | # I = I[v >= 0.01]
197 | idx = idx[-top_k:] # indices of the top-k largest vals
198 | xx1 = boxes.new()
199 | yy1 = boxes.new()
200 | xx2 = boxes.new()
201 | yy2 = boxes.new()
202 | w = boxes.new()
203 | h = boxes.new()
204 |
205 | # keep = torch.Tensor()
206 | count = 0
207 | while idx.numel() > 0:
208 | i = idx[-1] # index of current largest val
209 | # keep.append(i)
210 | keep[count] = i
211 | count += 1
212 | if idx.size(0) == 1:
213 | break
214 | idx = idx[:-1] # remove kept element from view
215 | # load bboxes of next highest vals
216 | torch.index_select(x1, 0, idx, out=xx1)
217 | torch.index_select(y1, 0, idx, out=yy1)
218 | torch.index_select(x2, 0, idx, out=xx2)
219 | torch.index_select(y2, 0, idx, out=yy2)
220 | # store element-wise max with next highest score
221 | xx1 = torch.clamp(xx1, min=x1[i])
222 | yy1 = torch.clamp(yy1, min=y1[i])
223 | xx2 = torch.clamp(xx2, max=x2[i])
224 | yy2 = torch.clamp(yy2, max=y2[i])
225 | w.resize_as_(xx2)
226 | h.resize_as_(yy2)
227 | w = xx2 - xx1
228 | h = yy2 - yy1
229 | # check sizes of xx1 and xx2.. after each iteration
230 | w = torch.clamp(w, min=0.0)
231 | h = torch.clamp(h, min=0.0)
232 | inter = w*h
233 | # IoU = i / (area(a) + area(b) - i)
234 |         rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
235 | union = (rem_areas - inter) + area[i]
236 | IoU = inter/union # store result in iou
237 | # keep only elements with an IoU <= overlap
238 | idx = idx[IoU.le(overlap)]
239 | return keep, count
240 |
--------------------------------------------------------------------------------
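
A minimal round-trip sketch (not part of the repo): decode() inverts encode() for a matched box up to floating-point error, which is the property the regression targets rely on. The variances match the configs above.

    import torch
    from layers.box_utils import encode, decode

    priors = torch.tensor([[0.50, 0.50, 0.20, 0.30]])   # (cx, cy, w, h)
    matched = torch.tensor([[0.35, 0.30, 0.75, 0.80]])  # (xmin, ymin, xmax, ymax)
    variances = [0.1, 0.2]

    loc = encode(matched, priors, variances)
    print(torch.allclose(decode(loc, priors, variances), matched, atol=1e-5))  # True

--------------------------------------------------------------------------------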
/layers/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .detection import Detect
2 | from .prior_box import PriorBox
3 |
4 |
5 | __all__ = ['Detect', 'PriorBox']
6 |
--------------------------------------------------------------------------------
/layers/functions/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/functions/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/functions/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/functions/__pycache__/detection.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/detection.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/functions/__pycache__/detection.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/detection.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/functions/__pycache__/detection.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/detection.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/functions/__pycache__/prior_box.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/prior_box.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/functions/__pycache__/prior_box.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/prior_box.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/functions/__pycache__/prior_box.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/functions/__pycache__/prior_box.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/functions/detection.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from ..box_utils import decode, nms
4 | from data import voc300 as cfg  # variance could also be read from cfg['variance']
5 |
6 |
7 |
8 | class Detect(Function):
9 | """At test time, Detect is the final layer of SSD. Decode location preds,
10 | apply non-maximum suppression to location predictions based on conf
11 | scores and threshold to a top_k number of output predictions for both
12 | confidence score and locations.
13 | """
14 | def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):
15 | self.num_classes = num_classes
16 | self.background_label = bkg_label
17 | self.top_k = top_k
18 | # Parameters used in nms.
19 | self.nms_thresh = nms_thresh
20 | if nms_thresh <= 0:
21 |             raise ValueError('nms_thresh must be positive.')
22 | self.conf_thresh = conf_thresh
23 | self.variance = [0.1, 0.2] #cfg['variance']
24 |
25 | def forward(self, loc_data, conf_data, prior_data):
26 | """
27 | Args:
28 | loc_data: (tensor) Loc preds from loc layers
29 | Shape: [batch,num_priors*4]
30 | conf_data: (tensor) Shape: Conf preds from conf layers
31 | Shape: [batch*num_priors,num_classes]
32 | prior_data: (tensor) Prior boxes and variances from priorbox layers
33 | Shape: [1,num_priors,4]
34 | """
35 | num = loc_data.size(0) # batch size
36 | num_priors = prior_data.size(0)
37 | output = torch.zeros(num, self.num_classes, self.top_k, 5)
38 | conf_preds = conf_data.view(num, num_priors,
39 | self.num_classes).transpose(2, 1)
40 |
41 | # Decode predictions into bboxes.
42 | for i in range(num):
43 | decoded_boxes = decode(loc_data[i], prior_data, self.variance)
44 | # For each class, perform nms
45 | conf_scores = conf_preds[i].clone()
46 |
47 | for cl in range(1, self.num_classes):
48 | c_mask = conf_scores[cl].gt(self.conf_thresh)
49 | scores = conf_scores[cl][c_mask]
50 | if scores.size(0) == 0:
51 | continue
52 | # if scores.dim() == 0:
53 | # continue
54 | l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
55 | boxes = decoded_boxes[l_mask].view(-1, 4)
56 | # idx of highest scoring and non-overlapping boxes per class
57 | ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
58 | output[i, cl, :count] = \
59 | torch.cat((scores[ids[:count]].unsqueeze(1),
60 | boxes[ids[:count]]), 1)
61 | flt = output.contiguous().view(num, -1, 5)
62 | _, idx = flt[:, :, 0].sort(1, descending=True)
63 | _, rank = idx.sort(1)
64 |         flt[(rank >= self.top_k).unsqueeze(-1).expand_as(flt)] = 0  # keep only the global top_k; indexing then fill_() would act on a copy
65 | return output
66 |
--------------------------------------------------------------------------------
/layers/functions/prior_box.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from math import sqrt as sqrt
3 | from itertools import product as product
4 | import torch
5 |
6 |
7 | class PriorBox(object):
8 | """Compute priorbox coordinates in center-offset form for each source
9 | feature map.
10 | """
11 | def __init__(self, cfg):
12 | super(PriorBox, self).__init__()
13 | self.image_size = cfg['min_dim']
14 | # number of priors for feature map location (either 4 or 6)
15 | self.num_priors = len(cfg['aspect_ratios'])
16 | self.variance = cfg['variance'] or [0.1]
17 | self.feature_maps = cfg['feature_maps']
18 | self.min_sizes = cfg['min_sizes']
19 | self.max_sizes = cfg['max_sizes']
20 | self.steps = cfg['steps']
21 | self.aspect_ratios = cfg['aspect_ratios']
22 | self.clip = cfg['clip']
23 | self.version = cfg['name']
24 | for v in self.variance:
25 | if v <= 0:
26 | raise ValueError('Variances must be greater than 0')
27 |
28 | def forward(self):
29 | mean = []
30 | for k, f in enumerate(self.feature_maps):
31 | for i, j in product(range(f), repeat=2):
32 | f_k = self.image_size / self.steps[k]
33 | # unit center x,y
34 | cx = (j + 0.5) / f_k
35 | cy = (i + 0.5) / f_k
36 |
37 | # aspect_ratio: 1
38 | # rel size: min_size
39 | s_k = self.min_sizes[k]/self.image_size
40 | mean += [cx, cy, s_k, s_k]
41 |
42 | # aspect_ratio: 1
43 | # rel size: sqrt(s_k * s_(k+1))
44 | s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size))
45 | mean += [cx, cy, s_k_prime, s_k_prime]
46 |
47 | # rest of aspect ratios
48 | for ar in self.aspect_ratios[k]:
49 | mean += [cx, cy, s_k*sqrt(ar), s_k/sqrt(ar)]
50 | mean += [cx, cy, s_k/sqrt(ar), s_k*sqrt(ar)]
51 | # back to torch land
52 | output = torch.Tensor(mean).view(-1, 4)
53 | if self.clip:
54 | output.clamp_(max=1, min=0)
55 | return output
56 |
--------------------------------------------------------------------------------
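
A quick shape check (a sketch, not part of the repo): with the voc300 settings, forward() returns the 8732 x 4 anchor tensor, clamped into [0, 1] since clip is True in the config.

    from data import voc300
    from layers.functions import PriorBox

    priors = PriorBox(voc300).forward()
    print(priors.shape)                                         # torch.Size([8732, 4])
    print(float(priors.min()) >= 0, float(priors.max()) <= 1)   # True True

--------------------------------------------------------------------------------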
/layers/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .l2norm import L2Norm
2 | from .multibox_loss import MultiBoxLoss
3 |
4 | __all__ = ['L2Norm', 'MultiBoxLoss']
5 |
--------------------------------------------------------------------------------
/layers/modules/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/modules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/modules/__pycache__/l2norm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/l2norm.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/modules/__pycache__/l2norm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/l2norm.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/modules/__pycache__/l2norm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/l2norm.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/modules/__pycache__/multibox_loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/modules/__pycache__/multibox_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/modules/__pycache__/multibox_loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/modules/__pycache__/multibox_loss_gmm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss_gmm.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/modules/__pycache__/multibox_loss_new.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss_new.cpython-36.pyc
--------------------------------------------------------------------------------
/layers/modules/__pycache__/multibox_loss_new.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss_new.cpython-37.pyc
--------------------------------------------------------------------------------
/layers/modules/__pycache__/multibox_loss_new.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/layers/modules/__pycache__/multibox_loss_new.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/modules/l2norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 | from torch.autograd import Variable
5 | import torch.nn.init as init
6 |
7 | class L2Norm(nn.Module):
8 |     def __init__(self, n_channels, scale):
9 |         super(L2Norm, self).__init__()
10 | self.n_channels = n_channels
11 | self.gamma = scale or None
12 | self.eps = 1e-10
13 | self.weight = nn.Parameter(torch.Tensor(self.n_channels))
14 | self.reset_parameters()
15 |
16 | def reset_parameters(self):
17 |         init.constant_(self.weight, self.gamma)
18 |
19 | def forward(self, x):
20 | norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps
21 | #x /= norm
22 | x = torch.div(x,norm)
23 | out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x
24 | return out
25 |
--------------------------------------------------------------------------------
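
A quick sanity check for the layer above — a sketch, not part of the repository; it assumes `L2Norm` is importable from `layers.modules.l2norm`:

    import torch
    from layers.modules.l2norm import L2Norm

    layer = L2Norm(n_channels=512, scale=20)   # conv4_3 configuration used in ssd.py
    x = torch.randn(2, 512, 38, 38)            # conv4_3-sized feature map
    y = layer(x)
    # each spatial position's channel vector is unit-normalized, then rescaled
    # by the learned per-channel weight (initialized to `scale`)
    print(y.pow(2).sum(dim=1).sqrt().mean())   # ~20.0 before any training
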
/layers/modules/multibox_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 |
7 |
8 | # -*- coding: utf-8 -*-
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.autograd import Variable
13 | from data import voc300 as cfg
14 | from ..box_utils import match, log_sum_exp
15 | import numpy as np
16 |
17 |
18 | class MultiBoxLoss(nn.Module):
19 | """SSD Weighted Loss Function
20 | Compute Targets:
21 | 1) Produce Confidence Target Indices by matching ground truth boxes
22 | with (default) 'priorboxes' that have jaccard index > threshold parameter
23 | (default threshold: 0.5).
24 | 2) Produce localization target by 'encoding' variance into offsets of ground
25 | truth boxes and their matched 'priorboxes'.
26 | 3) Hard negative mining to filter the excessive number of negative examples
27 | that comes with using a large number of default bounding boxes.
28 | (default negative:positive ratio 3:1)
29 | Objective Loss:
30 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
31 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
32 | weighted by α which is set to 1 by cross val.
33 | Args:
34 | c: class confidences,
35 | l: predicted boxes,
36 | g: ground truth boxes
37 | N: number of matched default boxes
38 | See: https://arxiv.org/pdf/1512.02325.pdf for more details.
39 | """
40 |
41 | def __init__(self, num_classes, overlap_thresh, prior_for_matching,
42 | bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
43 | use_gpu=True):
44 | super(MultiBoxLoss, self).__init__()
45 | self.use_gpu = use_gpu
46 | self.num_classes = num_classes
47 | self.threshold = overlap_thresh
48 | self.background_label = bkg_label
49 | self.encode_target = encode_target
50 | self.use_prior_for_matching = prior_for_matching
51 | self.do_neg_mining = neg_mining
52 | self.negpos_ratio = neg_pos
53 | self.neg_overlap = neg_overlap
54 | self.variance = cfg['variance']
55 |
56 | def forward(self, predictions, targets, semis=[]):
57 | """Multibox Loss
58 | Args:
59 | predictions (tuple): A tuple containing loc preds, conf preds,
60 | and prior boxes from SSD net.
61 | conf shape: torch.size(batch_size,num_priors,num_classes)
62 | loc shape: torch.size(batch_size,num_priors,4)
63 | priors shape: torch.size(num_priors,4)
64 |
65 | targets (tensor): Ground truth boxes and labels for a batch,
66 | shape: [batch_size,num_objs,5] (last idx is the label).
67 | """
68 | loc_data, conf_data, priors = predictions
69 | num = loc_data.size(0)
70 | priors = priors[:loc_data.size(1), :]
71 | num_priors = (priors.size(0))
72 | num_classes = self.num_classes
73 |
74 | # match priors (default boxes) and ground truth boxes
75 | loc_t = torch.Tensor(num, num_priors, 4)
76 | conf_t = torch.LongTensor(num, num_priors)
77 | for idx in range(num):
78 | truths = targets[idx][:, :-1].data
79 | labels = targets[idx][:, -1].data
80 | defaults = priors.data
81 | match(self.threshold, truths, defaults, self.variance, labels,
82 | loc_t, conf_t, idx)
83 | if self.use_gpu:
84 | loc_t = loc_t.cuda()
85 | conf_t = conf_t.cuda()
86 | # wrap targets
87 | loc_t = Variable(loc_t, requires_grad=False)
88 | conf_t = Variable(conf_t, requires_grad=False)
89 |
90 | pos = conf_t > 0
91 | num_pos = pos.sum(dim=1, keepdim=True)
92 |
93 | # Localization Loss (Smooth L1)
94 | # Shape: [batch,num_priors,4]
95 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
96 | loc_p = loc_data[pos_idx].view(-1, 4)
97 | loc_t = loc_t[pos_idx].view(-1, 4)
98 | loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)
99 |
100 | # Compute max conf across batch for hard negative mining
101 | batch_conf = conf_data.view(-1, self.num_classes)
102 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
103 |
104 | # Hard Negative Mining
105 | loss_c = loss_c.view(pos.size()[0], pos.size()[1])
106 | loss_c[pos] = 0 # filter out pos boxes for now
107 | loss_c = loss_c.view(num, -1)
108 | _, loss_idx = loss_c.sort(1, descending=True)
109 | _, idx_rank = loss_idx.sort(1)
110 | num_pos = pos.long().sum(1, keepdim=True)
111 | num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
112 | neg = idx_rank < num_neg.expand_as(idx_rank)
113 |
114 | # Confidence Loss Including Positive and Negative Examples
115 | pos_idx = pos.unsqueeze(2).expand_as(conf_data)
116 | neg_idx = neg.unsqueeze(2).expand_as(conf_data)
117 |
118 | real_labels, pseudo_labels = [False] * len(semis), [False] * len(semis)
119 | for i, el in enumerate(semis):
120 |             if float(el) == 1.:
121 |                 real_labels[i] = True
122 |             elif float(el) == 2.:
123 |                 pseudo_labels[i] = True
124 |
125 | pos_idx_labels = pos_idx[real_labels]
126 | neg_idx = neg_idx[real_labels]
127 | conf_data_real = conf_data[real_labels]
128 | conf_t_real = conf_t[real_labels]
129 | pos_real = pos[real_labels]
130 | neg_real = neg[real_labels]
131 | pos_idx_pseudo = pos_idx[pseudo_labels]
132 | conf_data_pseudo = conf_data[pseudo_labels]
133 | conf_t_pseudo = conf_t[pseudo_labels]
134 | pos_pseudo = pos[pseudo_labels]
135 |
136 | conf_p_labels = conf_data_real[(pos_idx_labels+neg_idx).gt(0)].view(-1, self.num_classes)
137 | conf_p_pseudo = conf_data_pseudo[(pos_idx_pseudo).gt(0)].view(-1, self.num_classes)
138 | targets_weighted_labels = conf_t_real[(pos_real + neg_real).gt(0)]
139 | targets_weighted_pseudo = conf_t_pseudo[(pos_pseudo).gt(0)]
140 |
141 | loss_c_real = F.cross_entropy(conf_p_labels, targets_weighted_labels, size_average=False)
142 | if targets_weighted_pseudo.shape[0] > 0:
143 | loss_c_pseudo = F.cross_entropy(conf_p_pseudo, targets_weighted_pseudo, size_average=False)
144 | loss_c = loss_c_real + loss_c_pseudo
145 | else:
146 | loss_c = loss_c_real
147 |
148 | N = num_pos.data.sum()
149 | loss_l /= N
150 | loss_c /= N
151 | return loss_l, loss_c
152 |
--------------------------------------------------------------------------------
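
The `semis` flags route each image of the batch into one of two confidence-loss terms: images flagged `1.` (human-labeled) contribute both positive and hard-negative examples, while images flagged `2.` (pseudo-labeled) contribute only their positive matches. A standalone sketch of that masking, with hypothetical values (8732 priors, 21 VOC classes):

    import torch

    # flag convention read by MultiBoxLoss.forward: 1. = labeled, 2. = pseudo-labeled
    semis = [torch.FloatTensor([1.]), torch.FloatTensor([2.]), torch.FloatTensor([1.])]
    real = [float(el) == 1. for el in semis]      # [True, False, True]
    pseudo = [float(el) == 2. for el in semis]    # [False, True, False]

    conf_data = torch.randn(3, 8732, 21)          # [batch, num_priors, num_classes]
    print(conf_data[real].shape)                  # torch.Size([2, 8732, 21])
    print(conf_data[pseudo].shape)                # torch.Size([1, 8732, 21])
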
/loaders.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 |
7 |
8 | import random
9 | import torch.utils.data as data
10 |
11 | from data import *
12 | from pseudo_labels import predict_pseudo_labels
13 | from subset_sequential_sampler import SubsetSequentialSampler, BalancedSampler
14 | from utils.augmentations import SSDAugmentation
15 |
16 | random.seed(314)
17 |
18 |
19 | def create_loaders(args):
20 | indices_1 = list(range(args.num_total_images))
21 | random.shuffle(indices_1)
22 | labeled_set = indices_1[:args.num_initial_labeled_set]
23 | indices = list(range(args.num_total_images))
24 | unlabeled_set = set(indices) - set(labeled_set)
25 | labeled_set = list(labeled_set)
26 | unlabeled_set = list(unlabeled_set)
27 | random.shuffle(indices)
28 |
29 |
30 | print(len(indices))
31 |
32 | if args.dataset_name == 'voc':
33 | supervised_dataset = VOCDetection(root=args.dataset_root, image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
34 | supervised_indices=labeled_set,
35 | transform=SSDAugmentation(args.cfg['min_dim'], MEANS))
36 |
37 | unsupervised_dataset = VOCDetection(args.dataset_root, supervised_indices=None,
38 | image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
39 | transform=BaseTransform(300, MEANS),
40 | target_transform=VOCAnnotationTransform())
41 |
42 | else:
43 | supervised_dataset = COCODetection(root=args.dataset_root, supervised_indices=labeled_set, image_set='train2014',
44 | transform=SSDAugmentation(args.cfg['min_dim']),
45 | target_transform=COCOAnnotationTransform(),
46 | dataset_name='MS COCO')
47 |
48 | unsupervised_dataset = COCODetection(root=args.dataset_root, supervised_indices=None, image_set='val2014',
49 | transform=SSDAugmentation(args.cfg['min_dim']),
50 | target_transform=COCOAnnotationTransform(),
51 | dataset_name='MS COCO')
52 |
53 |
54 | supervised_data_loader = data.DataLoader(supervised_dataset, batch_size=args.batch_size,
55 | num_workers=args.num_workers,
56 | sampler=BalancedSampler(indices, labeled_set, unlabeled_set, ratio=1),
57 | collate_fn=detection_collate,
58 | pin_memory=True)
59 |
60 | unsupervised_data_loader = data.DataLoader(unsupervised_dataset, batch_size=1,
61 | sampler=SubsetSequentialSampler(unlabeled_set),
62 | num_workers=args.num_workers,
63 | collate_fn=detection_collate,
64 | pin_memory=True)
65 |
66 | return supervised_dataset, supervised_data_loader, unsupervised_dataset, unsupervised_data_loader, indices, labeled_set, unlabeled_set
67 |
68 |
69 | def change_loaders(args, supervised_dataset, unsupervised_dataset, labeled_set,
70 | unlabeled_set, indices, net, pseudo=True):
71 | print("Labeled set size: " + str(len(labeled_set)))
72 | unsupervised_data_loader = data.DataLoader(unsupervised_dataset, batch_size=1,
73 | sampler=SubsetSequentialSampler(unlabeled_set),
74 | num_workers=args.num_workers,
75 | collate_fn=detection_collate,
76 | pin_memory=True)
77 |
78 | if pseudo:
79 | if args.dataset_name == 'voc':
80 | voc = True
81 | num_classes = 21
82 | else:
83 | voc = False
84 | num_classes = 81
85 | pseudo_labels = predict_pseudo_labels(unlabeled_set=unlabeled_set, net_name=net,
86 | threshold=args.pseudo_threshold, root=args.dataset_root,
87 | voc=voc, num_classes=num_classes)
88 | else:
89 | pseudo_labels = {}
90 |
91 | if args.dataset_name == 'voc':
92 | supervised_dataset = VOCDetection(root=args.dataset_root,
93 | image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
94 | supervised_indices=labeled_set,
95 | transform=SSDAugmentation(args.cfg['min_dim'], MEANS),
96 | pseudo_labels=pseudo_labels)
97 | else:
98 | supervised_dataset = COCODetection(root=args.dataset_root, supervised_indices=labeled_set, image_set='train2014',
99 | transform=SSDAugmentation(args.cfg['min_dim']),
100 | target_transform=COCOAnnotationTransform(),
101 | dataset_name='MS COCO',
102 | pseudo_labels=pseudo_labels)
103 |
104 | print("Changing the loaders")
105 |
106 | supervised_data_loader = data.DataLoader(supervised_dataset, batch_size=args.batch_size,
107 | num_workers=args.num_workers,
108 | sampler=BalancedSampler(indices, labeled_set, unlabeled_set, ratio=1),
109 | collate_fn=detection_collate,
110 | pin_memory=True)
111 |
112 | return supervised_data_loader, unsupervised_data_loader
113 |
--------------------------------------------------------------------------------
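
A hedged usage sketch for `create_loaders`; the `SimpleNamespace` stands in for the argparse namespace built in train.py, `dataset_root` is a placeholder path, and the call only succeeds with a local VOC07+12 copy on disk:

    from types import SimpleNamespace
    from data import voc300                      # assumes data/config.py exports voc300
    from loaders import create_loaders

    args = SimpleNamespace(num_total_images=16551, num_initial_labeled_set=2011,
                           dataset_name='voc', dataset_root='/path/to/VOC0712',
                           cfg=voc300, batch_size=16, num_workers=4)
    (sup_ds, sup_loader, unsup_ds, unsup_loader,
     indices, labeled_set, unlabeled_set) = create_loaders(args)
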
/pseudo_labels.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 |
7 |
8 | import os
9 | import sys
10 | module_path = os.path.abspath(os.path.join('..'))
11 | if module_path not in sys.path:
12 | sys.path.append(module_path)
13 | import random
14 | import pickle
15 | from torch.autograd import Variable
16 | from ssd import build_ssd
17 | import torch.nn as nn
18 | from data import *
19 | from data import VOC_CLASSES as labels1
20 | from data import COCO_CLASSES as labels_2
21 | from collections import defaultdict
22 |
23 | random.seed(314)
24 | torch.manual_seed(314)
25 |
26 |
27 | def predict_pseudo_labels(unlabeled_set, net_name, threshold=0.5, root='../tmp/VOC0712/', voc=1, num_classes=21):
28 | labels = labels1 if voc else labels_2
29 | if voc:
30 | testset = VOCDetection(root=root, image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
31 | supervised_indices=None,
32 | transform=None)
33 | else:
34 | testset = COCODetection(root=root, image_set='train2014',
35 | supervised_indices=None,
36 | transform=None)
37 |
38 | print("Doing PL")
39 | net = build_ssd('test', 300, num_classes)
40 | net = nn.DataParallel(net)
41 | net.load_state_dict(torch.load(net_name))
42 | boxes = get_pseudo_labels(testset, net, labels, unlabeled_set=unlabeled_set, threshold=threshold, voc=voc)
43 | return boxes
44 |
45 |
46 | def get_pseudo_labels(testset, net, labels, unlabeled_set=None, threshold=0.99, voc=1):
47 |
48 | boxes = defaultdict(list)
49 | for ii, img_id in enumerate(unlabeled_set):
50 | print(ii)
51 | image = testset.pull_image(img_id)
52 | x = cv2.resize(image, (300, 300)).astype(np.float32)
53 | x -= (104.0, 117.0, 123.0)
54 |
55 | x = x.astype(np.float32)
56 | x = x[:, :, ::-1].copy()
57 | x = torch.from_numpy(x).permute(2, 0, 1)
58 |
59 | xx = Variable(x.unsqueeze(0)) # wrap tensor in Variable
60 | if torch.cuda.is_available():
61 | xx = xx.cuda()
62 | y = net(xx)
63 |
64 | detections = y.data
65 | # scale each detection back up to the image
66 | scale = torch.Tensor(image.shape[1::-1]).repeat(2)
67 |
68 | for i in range(detections.size(1)):
69 | j = 0
70 | while detections[0, i, j, 0] >= threshold:
71 | score = detections[0, i, j, 0]
72 |                 label_name = labels[i - 1]
73 |                 pt = (detections[0, i, j, 1:5] * scale).cpu().numpy()
74 |                 j += 1
75 |                 # store each row as [prediction_confidence, label_id, label_name, height, width, x1, y1, x2, y2] with pixel coordinates
76 |                 boxes[img_id].append([score.cpu().detach().item(), (i-1), label_name, image.shape[0], image.shape[1], int(pt[0]), int(pt[1]), int(pt[2]), int(pt[3])])
77 |
78 |     return boxes
79 |
--------------------------------------------------------------------------------
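
The returned `boxes` dictionary maps an unlabeled image index to a list of rows `[confidence, label_id, label_name, img_height, img_width, x1, y1, x2, y2]` in pixel coordinates. A sketch of consuming it (the 0.99 cut-off matches the VOC `pseudo_threshold` default in train.py):

    # boxes: dict as returned by predict_pseudo_labels(...)
    for img_id, detections in boxes.items():
        for conf, label_id, label_name, h, w, x1, y1, x2, y2 in detections:
            if conf >= 0.99:
                print(img_id, label_name, (x1, y1, x2, y2))
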
/ssd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from layers import *
6 | from data import voc300, voc512, coco
7 | import os
8 |
9 |
10 | class SSD(nn.Module):
11 | """Single Shot Multibox Architecture
12 | The network is composed of a base VGG network followed by the
13 | added multibox conv layers. Each multibox layer branches into
14 | 1) conv2d for class conf scores
15 | 2) conv2d for localization predictions
16 | 3) associated priorbox layer to produce default bounding
17 | boxes specific to the layer's feature map size.
18 | See: https://arxiv.org/pdf/1512.02325.pdf for more details.
19 |
20 | Args:
21 | phase: (string) Can be "test" or "train"
22 | size: input image size
23 |         base: VGG16 layers for input, size of either 300 or 512
24 | extras: extra layers that feed to multibox loc and conf layers
25 | head: "multibox head" consists of loc and conf conv layers
26 | """
27 |
28 | def __init__(self, phase, size, base, extras, head, num_classes):
29 | super(SSD, self).__init__()
30 | self.phase = phase
31 | self.num_classes = num_classes
32 |         if size == 300:
33 | self.cfg = (coco, voc300)[num_classes == 21]
34 | else:
35 | self.cfg = (coco, voc512)[num_classes == 21]
36 | self.priorbox = PriorBox(self.cfg)
37 | self.priors = Variable(self.priorbox.forward(), volatile=True)
38 | self.size = size
39 |
40 | # SSD network
41 | self.vgg = nn.ModuleList(base)
42 | # Layer learns to scale the l2 normalized features from conv4_3
43 | self.L2Norm = L2Norm(512, 20)
44 | self.extras = nn.ModuleList(extras)
45 |
46 | self.loc = nn.ModuleList(head[0])
47 | self.conf = nn.ModuleList(head[1])
48 |
49 | if phase == 'test':
50 | self.softmax = nn.Softmax(dim=-1)
51 | self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)
52 |
53 | def forward(self, x):
54 | """Applies network layers and ops on input image(s) x.
55 |
56 | Args:
57 | x: input image or batch of images. Shape: [batch,3,300,300].
58 |
59 | Return:
60 | Depending on phase:
61 | test:
62 | Variable(tensor) of output class label predictions,
63 | confidence score, and corresponding location predictions for
64 | each object detected. Shape: [batch,topk,7]
65 |
66 | train:
67 | list of concat outputs from:
68 | 1: confidence layers, Shape: [batch*num_priors,num_classes]
69 | 2: localization layers, Shape: [batch,num_priors*4]
70 | 3: priorbox layers, Shape: [2,num_priors*4]
71 | """
72 | sources = list()
73 | loc = list()
74 | conf = list()
75 |
76 | # apply vgg up to conv4_3 relu
77 | for k in range(23):
78 | x = self.vgg[k](x)
79 |
80 | s = self.L2Norm(x)
81 | sources.append(s)
82 |
83 | # apply vgg up to fc7
84 | for k in range(23, len(self.vgg)):
85 | x = self.vgg[k](x)
86 | sources.append(x)
87 |
88 | # apply extra layers and cache source layer outputs
89 | for k, v in enumerate(self.extras):
90 | x = F.relu(v(x), inplace=True)
91 | if k % 2 == 1:
92 | sources.append(x)
93 |
94 | # apply multibox head to source layers
95 | for (x, l, c) in zip(sources, self.loc, self.conf):
96 | loc.append(l(x).permute(0, 2, 3, 1).contiguous())
97 | conf.append(c(x).permute(0, 2, 3, 1).contiguous())
98 |
99 | loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
100 | conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
101 | if self.phase == "test":
102 | output = self.detect(
103 | loc.view(loc.size(0), -1, 4),
104 | self.softmax(conf.view(conf.size(0), -1,
105 | self.num_classes)),
106 | self.priors.type(type(x.data))
107 | )
108 | else:
109 | output = (
110 | loc.view(loc.size(0), -1, 4),
111 | conf.view(conf.size(0), -1, self.num_classes),
112 | self.priors
113 | )
114 | return output
115 |
116 | def load_weights(self, base_file):
117 | other, ext = os.path.splitext(base_file)
118 |         if ext == '.pkl' or ext == '.pth':
119 | print('Loading weights into state dict...')
120 | self.load_state_dict(torch.load(base_file,
121 | map_location=lambda storage, loc: storage))
122 | print('Finished!')
123 | else:
124 | print('Sorry only .pth and .pkl files supported.')
125 |
126 |
127 | # This function is derived from torchvision VGG make_layers()
128 | # https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
129 | def vgg(cfg, i, batch_norm=False):
130 | layers = []
131 | in_channels = i
132 | for v in cfg:
133 | if v == 'M':
134 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
135 | elif v == 'C':
136 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
137 | else:
138 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
139 | if batch_norm:
140 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
141 | else:
142 | layers += [conv2d, nn.ReLU(inplace=True)]
143 | in_channels = v
144 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
145 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
146 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
147 | layers += [pool5, conv6,
148 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
149 | return layers
150 |
151 |
152 | def add_extras(cfg, i, batch_norm=False):
153 | # Extra layers added to VGG for feature scaling
154 | layers = []
155 | in_channels = i
156 | flag = False
157 | for k, v in enumerate(cfg):
158 | if in_channels != 'S':
159 | if v == 'S':
160 | layers += [nn.Conv2d(in_channels, cfg[k + 1],
161 | kernel_size=(1, 3)[flag], stride=2, padding=1)]
162 | elif v=='K':
163 | layers += [nn.Conv2d(in_channels, 256,
164 | kernel_size=4, stride=1, padding=1)]
165 | else:
166 | layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
167 | flag = not flag
168 | in_channels = v
169 | return layers
170 |
171 |
172 |
173 | def multibox(vgg, extra_layers, cfg, num_classes):
174 | loc_layers = []
175 | conf_layers = []
176 | vgg_source = [21, -2]
177 | for k, v in enumerate(vgg_source):
178 | loc_layers += [nn.Conv2d(vgg[v].out_channels,
179 | cfg[k] * 4, kernel_size=3, padding=1)]
180 | conf_layers += [nn.Conv2d(vgg[v].out_channels,
181 | cfg[k] * num_classes, kernel_size=3, padding=1)]
182 | for k, v in enumerate(extra_layers[1::2], 2):
183 | loc_layers += [nn.Conv2d(v.out_channels, cfg[k]
184 | * 4, kernel_size=3, padding=1)]
185 | conf_layers += [nn.Conv2d(v.out_channels, cfg[k]
186 | * num_classes, kernel_size=3, padding=1)]
187 | return vgg, extra_layers, (loc_layers, conf_layers)
188 |
189 |
190 | base = {
191 | '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
192 | 512, 512, 512],
193 | '512': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M',
194 | 512, 512, 512],
195 | }
196 | extras = {
197 | '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],
198 | '512': [256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128, 'K'],
199 | }
200 | mbox = {
201 | '300': [4, 6, 6, 6, 4, 4], # number of boxes per feature map location
202 | '512': [4, 6, 6, 6, 6, 4, 4],
203 | }
204 |
205 |
206 |
207 | def build_ssd(phase, size=300, num_classes=21):
208 | if phase != "test" and phase != "train":
209 | print("ERROR: Phase: " + phase + " not recognized")
210 | return
211 | base_, extras_, head_ = multibox(vgg(base[str(size)], 3),
212 | add_extras(extras[str(size)], 1024),
213 | mbox[str(size)], num_classes)
214 | return SSD(phase, size, base_, extras_, head_, num_classes)
215 |
--------------------------------------------------------------------------------
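
A minimal forward-pass sketch in train mode; the shapes assume the standard SSD300 prior count of 8732 (38²·4 + 19²·6 + 10²·6 + 5²·6 + 3²·4 + 1²·4) and 21 VOC classes:

    import torch
    from ssd import build_ssd

    net = build_ssd('train', size=300, num_classes=21)   # randomly initialized
    x = torch.randn(1, 3, 300, 300)
    loc, conf, priors = net(x)
    print(loc.shape, conf.shape, priors.shape)
    # torch.Size([1, 8732, 4]) torch.Size([1, 8732, 21]) torch.Size([8732, 4])
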
/subset_sequential_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 |
7 |
8 | from torch.utils.data.sampler import Sampler
9 | import random
10 | import sys
11 | from copy import deepcopy
12 |
13 |
14 | class SubsetSequentialSampler(Sampler):
15 |     """Samples elements sequentially from a given list of indices, in the order given.
16 | Arguments:
17 | indices (list): a list of indices
18 | """
19 |
20 | def __init__(self, indices):
21 | self.indices = indices
22 |
23 | def __iter__(self):
24 | return (self.indices[i] for i in range(len(self.indices)))
25 |
26 | def __len__(self):
27 | return len(self.indices)
28 |
29 |
30 | class BalancedSubsetSequentialSampler(Sampler):
31 |     """Samples images in a balanced way, ensuring that exactly half of the sampled images have labels.
32 | Arguments:
33 | indices (list): a list of indices
34 | """
35 |
36 |
37 | def __init__(self, indices, supervised, unsupervised):
38 |
39 | self.indices = indices
40 | random.shuffle(supervised)
41 | random.shuffle(unsupervised)
42 | self.len_supervised = len(supervised)
43 | self.len_unsupervised = len(unsupervised)
44 | if self.len_supervised > self.len_unsupervised:
45 | ratio = self.len_supervised // self.len_unsupervised
46 | module = self.len_supervised % self.len_unsupervised
47 | self.supervised = supervised
48 | self.unsupervised = ratio * unsupervised + unsupervised[:module]
49 | else:
50 | ratio = self.len_unsupervised // self.len_supervised
51 | module = self.len_unsupervised % self.len_supervised
52 | self.unsupervised = unsupervised
53 | self.supervised = ratio * supervised + supervised[:module]
54 | print(len(self.supervised), len(self.unsupervised))
55 |
56 | def _add_lists_alternatively(self, lst1, lst2):
57 | random.shuffle(lst1)
58 | random.shuffle(lst2)
59 | return [sub[item] for item in range(len(lst2))
60 | for sub in [lst1, lst2]]
61 |
62 | def __iter__(self):
63 | all_indices = self._add_lists_alternatively(self.supervised, self.unsupervised)
64 | return (all_indices[i] for i in range(len(all_indices)))
65 |
66 | def __len__(self):
67 |         return len(self.supervised) + len(self.unsupervised)
68 |
69 |
70 | class BalancedSampler(Sampler):
71 |     """Interleaves labeled and unlabeled indices: each labeled index is
72 |     followed by `ratio` unlabeled ones, reusing a reshuffled copy of the
73 |     unlabeled pool whenever it runs out.
74 |     """
75 | def __init__(self, indices, supervised, unsupervised, ratio=1):
76 |
77 | self.indices = indices
78 | self.supervised = supervised
79 | self.unsupervised = unsupervised
80 | self.new_list = []
81 | self.ratio = ratio
82 |
83 | def _create_balanced_batch(self):
84 | copy_supervised = deepcopy(self.supervised)
85 | copy_unsupervised = deepcopy(self.unsupervised)
86 | random.shuffle(copy_supervised)
87 | random.shuffle(copy_unsupervised)
88 | self.new_list = []
89 | while copy_supervised:
90 | self.new_list.append(copy_supervised.pop())
91 | if len(copy_unsupervised) < self.ratio:
92 | copy_unsupervised = deepcopy(self.unsupervised)
93 | random.shuffle(copy_unsupervised)
94 | for i in range(self.ratio):
95 | self.new_list.append(copy_unsupervised.pop())
96 | return self.new_list
97 |
98 | def __iter__(self):
99 | all_indices = self._create_balanced_batch()
100 | return (all_indices[i] for i in range(len(all_indices)))
101 |
102 | def __len__(self):
103 | return len(self.new_list)
104 |
--------------------------------------------------------------------------------
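
A small illustration of `BalancedSampler`: each labeled index is followed by `ratio` unlabeled indices, so one pass yields `len(supervised) * (1 + ratio)` samples. The index values below are hypothetical:

    import random
    from subset_sequential_sampler import BalancedSampler

    random.seed(0)
    sampler = BalancedSampler(indices=list(range(8)),
                              supervised=[0, 1, 2], unsupervised=[3, 4, 5, 6, 7],
                              ratio=1)
    print(list(iter(sampler)))
    # 6 indices; labeled (0-2) and unlabeled (3-7) strictly alternate
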
/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 |
7 |
8 | from __future__ import print_function
9 | import sys
10 | import os
11 | import argparse
12 | import torch
13 | import torch.nn as nn
14 | import torch.backends.cudnn as cudnn
15 | import torchvision.transforms as transforms
16 | from torch.autograd import Variable
17 | from data import VOC_ROOT, VOC_CLASSES as labelmap
18 | from PIL import Image
19 | from data import VOCAnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES
20 | import torch.utils.data as data
21 | from ssd import build_ssd
22 |
23 | parser = argparse.ArgumentParser(description='Single Shot MultiBox Detection')
24 | parser.add_argument('--trained_model', default='weights_voc_code/120000combined_id_2_pl_threshold_0.99_labeled_set_3011_.pth',
25 | type=str, help='Trained state_dict file path to open')
26 | # parser.add_argument('--trained_model', default='weights/ssd_300_VOC0712.pth',
27 | # type=str, help='Trained state_dict file path to open')
28 | parser.add_argument('--save_folder', default='eval/', type=str,
29 | help='Dir to save results')
30 | parser.add_argument('--visual_threshold', default=0.6, type=float,
31 | help='Final confidence threshold')
32 | parser.add_argument('--cuda', default=True, type=bool,
33 | help='Use cuda to train model')
34 | parser.add_argument('--voc_root', default='/usr/wiss/elezi/data/VOC0712', help='Location of VOC root directory')
35 | parser.add_argument('-f', default=None, type=str, help="Dummy arg so we can load in Jupyter Notebooks")
36 | args = parser.parse_args()
37 |
38 | if args.cuda and torch.cuda.is_available():
39 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
40 | else:
41 | torch.set_default_tensor_type('torch.FloatTensor')
42 |
43 | if not os.path.exists(args.save_folder):
44 | os.mkdir(args.save_folder)
45 |
46 |
47 | def test_net(save_folder, net, cuda, testset, transform, thresh):
48 | # dump predictions and assoc. ground truth to text file for now
49 | filename = save_folder+'test1.txt'
50 | num_images = len(testset)
51 | for i in range(num_images):
52 | print('Testing image {:d}/{:d}....'.format(i+1, num_images))
53 | img = testset.pull_image(i)
54 | img_id, annotation = testset.pull_anno(i)
55 | x = torch.from_numpy(transform(img)[0]).permute(2, 0, 1)
56 | x = Variable(x.unsqueeze(0))
57 |
58 | with open(filename, mode='a') as f:
59 | f.write('\nGROUND TRUTH FOR: '+img_id+'\n')
60 | for box in annotation:
61 | f.write('label: '+' || '.join(str(b) for b in box)+'\n')
62 | if cuda:
63 | x = x.cuda()
64 |
65 | y = net(x) # forward pass
66 | detections = y.data
67 | # scale each detection back up to the image
68 | scale = torch.Tensor([img.shape[1], img.shape[0],
69 | img.shape[1], img.shape[0]])
70 | pred_num = 0
71 | for i in range(detections.size(1)):
72 | j = 0
73 |             while detections[0, i, j, 0] >= thresh:
74 | if pred_num == 0:
75 | with open(filename, mode='a') as f:
76 | f.write('PREDICTIONS: '+'\n')
77 | score = detections[0, i, j, 0]
78 | label_name = labelmap[i-1]
79 | pt = (detections[0, i, j, 1:]*scale).cpu().numpy()
80 | coords = (pt[0], pt[1], pt[2], pt[3])
81 | pred_num += 1
82 | with open(filename, mode='a') as f:
83 | f.write(str(pred_num)+' label: '+label_name+' score: ' +
84 | str(score) + ' '+' || '.join(str(c) for c in coords) + '\n')
85 | j += 1
86 |
87 |
88 | def test_voc():
89 | # load net
90 | num_classes = len(VOC_CLASSES) + 1 # +1 background
91 | net = build_ssd('test', 300, num_classes) # initialize SSD
92 | net = nn.DataParallel(net)
93 | net.load_state_dict(torch.load(args.trained_model))
94 | net.eval()
95 | print('Finished loading model!')
96 | # load data
97 | testset = VOCDetection(args.voc_root, image_sets=[('2007', 'test')],
98 | transform=None, target_transform=VOCAnnotationTransform())
99 | if args.cuda:
100 | net = net.cuda()
101 | cudnn.benchmark = True
102 | # evaluation
103 | test_net(args.save_folder, net, args.cuda, testset,
104 | BaseTransform(300, (104, 117, 123)),
105 | thresh=args.visual_threshold)
106 |
107 | if __name__ == '__main__':
108 | test_voc()
109 |
--------------------------------------------------------------------------------
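
`test_net` preprocesses with `BaseTransform(300, (104, 117, 123))`, whose definition lives in `data/` outside this listing; assuming it performs the same resize-and-mean-subtract seen in pseudo_labels.py, a stand-in looks like:

    import cv2
    import numpy as np
    import torch

    def preprocess(img_bgr):
        # resize to the 300x300 SSD input and subtract the per-channel BGR means
        x = cv2.resize(img_bgr, (300, 300)).astype(np.float32)
        x -= (104.0, 117.0, 123.0)
        return torch.from_numpy(x).permute(2, 0, 1)   # HWC -> CHW
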
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022-2023, NVIDIA Corporation & Affiliates. All rights reserved.
2 | #
3 | # This work is made available under the Nvidia Source Code License-NC.
4 | # To view a copy of this license, visit
5 | # https://github.com/NVlabs/AL-SSL/
6 |
7 |
8 | import warnings
9 | import argparse
10 | import math
11 | import numpy as np
12 | import random
13 | import os
14 | import time
15 | import torch
16 | import torch.nn as nn
17 | import torch.optim as optim
18 | import torch.backends.cudnn as cudnn
19 | import torch.nn.init as init
20 |
21 | from active_learning import combined_score, active_learning_inconsistency, active_learning_entropy
22 | from csd import build_ssd_con
23 | from data import *
24 | from layers.modules import MultiBoxLoss
25 | from loaders import create_loaders, change_loaders
26 |
27 | random.seed(314)
28 | torch.manual_seed(314)
29 |
30 | warnings.filterwarnings("ignore")
31 |
32 |
33 | def str2bool(v):
34 | return v.lower() in ("yes", "true", "t", "1")
35 |
36 |
37 | class get_al_hyperparams():
38 | def __init__(self, dataset_name='voc'):
39 | self.dataset_name = dataset_name
40 | self.dataset_path = {'voc': '/usr/wiss/elezi/data/VOC0712',
41 | 'coco': '/usr/wiss/elezi/data/coco'}
42 |
43 | self.num_ims = {'voc': 16551, 'coco': 82081}
44 | self.num_init = {'voc': 2011, 'coco': 5000}
45 | self.pseudo_threshold = {'voc': 0.99, 'coco': 0.75}
46 | self.config = {'voc': voc300, 'coco': coco}
47 | self.batch_size = {'voc': 16, 'coco': 32}
48 |
49 | def get_dataset_path(self):
50 | return self.dataset_path[self.dataset_name]
51 |
52 | def get_num_ims(self):
53 | return self.num_ims[self.dataset_name]
54 |
55 | def get_num_init(self):
56 | return self.num_init[self.dataset_name]
57 |
58 | def get_pseudo_threshold(self):
59 | return self.pseudo_threshold[self.dataset_name]
60 |
61 | def get_config(self):
62 | return self.config[self.dataset_name]
63 |
64 | def get_dataset_name(self):
65 | return self.dataset_name
66 |
67 | def get_batch_size(self):
68 | return self.batch_size[self.dataset_name]
69 |
70 |
71 |
72 | parser = argparse.ArgumentParser(
73 | description='Single Shot MultiBox Detector Training With Pytorch')
74 | train_set = parser.add_mutually_exclusive_group()
75 | al_hyperparams = get_al_hyperparams('voc') # 'voc' for voc, 'coco' for coco
76 | parser.add_argument('--dataset_name', default=al_hyperparams.get_dataset_name(), type=str,
77 | help='Dataset name')
78 | parser.add_argument('--cfg', default=al_hyperparams.get_config(), type=dict,
79 |                     help='configuration for the specific dataset')
80 | parser.add_argument('--dataset', default='VOC300', choices=['VOC300', 'VOC512'],
81 | type=str, help='VOC300 or VOC512')
82 | parser.add_argument('--dataset_root', default=al_hyperparams.get_dataset_path(), type=str,
83 | help='Dataset root directory path')
84 | parser.add_argument('--num_total_images', default=al_hyperparams.get_num_ims(), type=int,
85 | help='Number of images in the dataset')
86 | parser.add_argument('--num_initial_labeled_set', default=al_hyperparams.get_num_init(), type=int,
87 | help='Number of initially labeled images')
88 | parser.add_argument('--acquisition_budget', default=1000, type=int,
89 | help='Active labeling cycle budget')
90 | parser.add_argument('--num_cycles', default=5, type=int,
91 | help='Number of active learning cycles')
92 | parser.add_argument('--criterion_select', default='combined',
93 | choices=['random', 'entropy', 'consistency', 'combined'],
94 | help='Active learning acquisition score')
95 | parser.add_argument('--filter_entropy_num', default=3000, type=int,
96 |                     help='How many samples to pre-filter with entropy')
97 | parser.add_argument('--id', default=2, type=int,
98 | help='the id of the experiment')
99 | parser.add_argument('--basenet', default='vgg16_reducedfc.pth',
100 | help='Pretrained base model')
101 | parser.add_argument('--batch_size', default=al_hyperparams.get_batch_size(), type=int,
102 | help='Batch size for training')
103 | parser.add_argument('--num_workers', default=8, type=int,
104 | help='Number of workers used in dataloading')
105 | parser.add_argument('--cuda', default=True, type=str2bool,
106 | help='Use CUDA to train model')
107 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
108 | help='initial learning rate')
109 | parser.add_argument('--momentum', default=0.9, type=float,
110 | help='Momentum value for optim')
111 | parser.add_argument('--weight_decay', default=5e-4, type=float,
112 | help='Weight decay for SGD')
113 | parser.add_argument('--gamma', default=0.1, type=float,
114 | help='Gamma update for SGD')
115 | parser.add_argument('--save_folder', default='../al_ssl/weights/',
116 | help='Directory for saving checkpoint models')
117 | parser.add_argument('--net_name', default=None, type=str,
118 | help='the net checkpoint we need to load')
119 | parser.add_argument('--is_apex', default=0, type=int,
120 | help='if 1 use apex to do mixed precision training')
121 | parser.add_argument('--is_cluster', default=1, type=int,
122 | help='if 1 use GPU cluster, otherwise do the computations on the local PC')
123 | parser.add_argument('--do_PL', default=1, type=int,
124 | help='if 1 use pseudo-labels, otherwise do not use them')
125 | parser.add_argument('--pseudo_threshold', default=al_hyperparams.get_pseudo_threshold(), type=float,
126 | help='pseudo label confidence threshold for voc dataset')
127 | parser.add_argument('--thresh', default=0.5, type=float,
128 | help='we define an object if the probability of one class is above thresh')
129 | parser.add_argument('--do_AL', default=1, type=int, help='if 0 skip AL')
130 | args = parser.parse_args()
131 |
132 | if torch.cuda.is_available():
133 | if args.cuda:
134 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
135 | if not args.cuda:
136 | print("WARNING: It looks like you have a CUDA device, but aren't " +
137 | "using CUDA.\nRun with --cuda for optimal training speed.")
138 | torch.set_default_tensor_type('torch.FloatTensor')
139 | cudnn.benchmark = True
140 | else:
141 | torch.set_default_tensor_type('torch.FloatTensor')
142 |
143 | if not os.path.exists(args.save_folder):
144 | os.mkdir(args.save_folder)
145 |
146 |
147 | def load_net_optimizer_multi(cfg):
148 | net = build_ssd_con('train', cfg['min_dim'], cfg['num_classes'])
149 | vgg_weights = torch.load(args.save_folder + args.basenet)
150 | print('Loading the backbone pretrained in Imagenet...')
151 | net.vgg.load_state_dict(vgg_weights)
152 | net.extras.apply(weights_init)
153 | net.loc.apply(weights_init)
154 | net.conf.apply(weights_init)
155 | if args.is_cluster:
156 | net = nn.DataParallel(net)
157 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum,
158 | weight_decay=args.weight_decay)
159 | return net, optimizer
160 |
161 |
162 | def compute_consistency_loss(conf, loc, conf_flip, loc_flip, conf_consistency_criterion):
163 | conf_class = conf[:, :, 1:].clone()
164 | background_score = conf[:, :, 0].clone()
165 | each_val, each_index = torch.max(conf_class, dim=2)
166 | mask_val = each_val > background_score
167 | mask_val = mask_val.data
168 |
169 | mask_conf_index = mask_val.unsqueeze(2).expand_as(conf)
170 | mask_loc_index = mask_val.unsqueeze(2).expand_as(loc)
171 |
172 | conf_mask_sample = conf.clone()
173 | loc_mask_sample = loc.clone()
174 | conf_sampled = conf_mask_sample[mask_conf_index].view(-1, args.cfg['num_classes'])
175 | loc_sampled = loc_mask_sample[mask_loc_index].view(-1, 4)
176 |
177 | conf_mask_sample_flip = conf_flip.clone()
178 | loc_mask_sample_flip = loc_flip.clone()
179 | conf_sampled_flip = conf_mask_sample_flip[mask_conf_index].view(-1, args.cfg['num_classes'])
180 | loc_sampled_flip = loc_mask_sample_flip[mask_loc_index].view(-1, 4)
181 |
182 | if (mask_val.sum() > 0):
183 |         # Compute a Jensen-Shannon-style divergence (in fact a symmetrized KL)
184 | conf_sampled_flip = conf_sampled_flip + 1e-7
185 | conf_sampled = conf_sampled + 1e-7
186 | consistency_conf_loss_a = conf_consistency_criterion(conf_sampled.log(),
187 | conf_sampled_flip.detach()).sum(-1).mean()
188 | consistency_conf_loss_b = conf_consistency_criterion(conf_sampled_flip.log(),
189 | conf_sampled.detach()).sum(-1).mean()
190 | consistency_conf_loss = consistency_conf_loss_a + consistency_conf_loss_b
191 |
192 | # Compute location consistency loss
193 |         consistency_loc_loss_x = torch.mean(torch.pow(loc_sampled[:, 0] + loc_sampled_flip[:, 0], exponent=2))  # x offsets change sign under a horizontal flip, hence the sum
194 | consistency_loc_loss_y = torch.mean(torch.pow(loc_sampled[:, 1] - loc_sampled_flip[:, 1], exponent=2))
195 | consistency_loc_loss_w = torch.mean(torch.pow(loc_sampled[:, 2] - loc_sampled_flip[:, 2], exponent=2))
196 | consistency_loc_loss_h = torch.mean(torch.pow(loc_sampled[:, 3] - loc_sampled_flip[:, 3], exponent=2))
197 |
198 | consistency_loc_loss = torch.div(
199 | consistency_loc_loss_x + consistency_loc_loss_y + consistency_loc_loss_w + consistency_loc_loss_h,
200 | 4)
201 |
202 | else:
203 | consistency_conf_loss = torch.cuda.FloatTensor([0])
204 | consistency_loc_loss = torch.cuda.FloatTensor([0])
205 |
206 | consistency_loss = torch.div(consistency_conf_loss, 2) + consistency_loc_loss
207 | return consistency_loss
208 |
209 |
210 | def rampweight(iteration):
211 | ramp_up_end = 32000
212 | ramp_down_start = 100000
213 |
214 | if (iteration < ramp_up_end):
215 | ramp_weight = math.exp(-5 * math.pow((1 - iteration / ramp_up_end), 2))
216 | elif (iteration > ramp_down_start):
217 | ramp_weight = math.exp(-12.5 * math.pow((1 - (120000 - iteration) / 20000), 2))
218 | else:
219 | ramp_weight = 1
220 |
221 | if (iteration == 0):
222 | ramp_weight = 0
223 |
224 | return ramp_weight
225 |
226 |
227 | def train(dataset, data_loader, cfg, labeled_set, supervised_dataset, indices):
228 | # net, optimizer = load_net_optimizer_multi(cfg)
229 | criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,
230 | False, args.cuda)
231 | conf_consistency_criterion = torch.nn.KLDivLoss(size_average=False, reduce=False).cuda()
232 |
233 | # loss counters
234 | print('Loading the dataset...')
235 | print('Training SSD on:', dataset.name)
236 | print('Using the specified args:')
237 |
238 | step_index = 0
239 |
240 | # create batch iterator
241 | batch_iterator = iter(data_loader)
242 |
243 | finish_flag = True
244 |
245 | while finish_flag:
246 | net, optimizer = load_net_optimizer_multi(cfg)
247 | net.train()
248 | for iteration in range(cfg['max_iter']):
249 | print(iteration)
250 |
251 | if iteration in cfg['lr_steps']:
252 | step_index += 1
253 | adjust_learning_rate(optimizer, args.gamma, step_index)
254 | try:
255 | images, targets, semis = next(batch_iterator)
256 | except StopIteration:
257 | batch_iterator = iter(data_loader)
258 | images, targets, semis = next(batch_iterator)
259 |
260 | images = images.cuda()
261 | targets = [ann.cuda() for ann in targets]
262 |
263 | # forward
264 | t0 = time.time()
265 | out, conf, conf_flip, loc, loc_flip, _ = net(images)
266 | sup_image_binary_index = np.zeros([len(semis), 1])
267 |
268 | semis_index, new_semis = [], []
269 |             for super_image in range(len(semis)):
270 | new_semis.append(int(semis[super_image]))
271 | if (int(semis[super_image]) > 0):
272 | sup_image_binary_index[super_image] = 1
273 | semis_index.append(super_image)
274 | else:
275 | sup_image_binary_index[super_image] = 0
276 |
277 | if (int(semis[len(semis) - 1 - super_image]) == 0):
278 | del targets[len(semis) - 1 - super_image]
279 |
280 | sup_image_index = np.where(sup_image_binary_index == 1)[0]
281 | loc_data, conf_data, priors = out
282 |
283 | if (len(sup_image_index) != 0):
284 | loc_data = loc_data[sup_image_index, :, :]
285 | conf_data = conf_data[sup_image_index, :, :]
286 | output = (loc_data, conf_data, priors)
287 |
288 | consistency_loss = compute_consistency_loss(conf, loc, conf_flip, loc_flip, conf_consistency_criterion)
289 | ramp_weight = rampweight(iteration)
290 | consistency_loss = torch.mul(consistency_loss, ramp_weight)
291 |
292 | if (len(sup_image_index) == 0):
293 | loss_l, loss_c = torch.cuda.FloatTensor([0]), torch.cuda.FloatTensor([0])
294 | else:
295 | loss_l, loss_c = criterion(output, targets, np.array(new_semis)[semis_index])
296 | loss = loss_l + loss_c + consistency_loss
297 | print(loss)
298 |
299 | if (loss.data > 0):
300 | optimizer.zero_grad()
301 | loss.backward()
302 | optimizer.step()
303 | else:
304 | print("Loss is 0")
305 |
306 | if (float(loss) > 100 or torch.isnan(loss)):
307 | # if the net diverges, go back to point 0 and train from scratch
308 | break
309 | t1 = time.time()
310 |
311 | if iteration % 10 == 0:
312 | print('timer: %.4f sec.' % (t1 - t0))
313 | print('iter ' + repr(
314 | iteration) + ': loss: %.4f , loss_c: %.4f , loss_l: %.4f , loss_con: %.4f, lr : %.4f, super_len : %d\n' % (
315 | loss.data, loss_c.data, loss_l.data, consistency_loss.data,
316 | float(optimizer.param_groups[0]['lr']),
317 | len(sup_image_index)))
318 |
319 | if iteration != 0 and (iteration + 1) % 120000 == 0:
320 | print('Saving state, iter:', iteration)
321 | net_name = 'weights/' + repr(iteration + 1) + args.criterion_select + '_id_' + str(args.id) + \
322 | '_pl_threshold_' + str(args.pseudo_threshold) + '_labeled_set_' + str(len(labeled_set)) + '_.pth'
323 | torch.save(net.state_dict(), net_name)
324 |
325 | if iteration >= 119000:
326 | finish_flag = False
327 | return net, net_name
328 |
329 |
330 | def adjust_learning_rate(optimizer, gamma, step):
331 | """Sets the learning rate to the initial LR decayed by 10 at every
332 | specified step
333 | # Adapted from PyTorch Imagenet example:
334 | # https://github.com/pytorch/examples/blob/master/imagenet/main.py
335 | """
336 | lr = args.lr * (gamma ** (step))
337 | for param_group in optimizer.param_groups:
338 | param_group['lr'] = lr
339 |
340 |
341 | def xavier(param):
342 |     init.xavier_uniform_(param)
343 |
344 |
345 | def weights_init(m):
346 | if isinstance(m, nn.Conv2d):
347 | xavier(m.weight.data)
348 | m.bias.data.zero_()
349 |
350 |
351 | def main():
352 | print(os.path.abspath(os.getcwd()))
353 | if args.cuda: cudnn.benchmark = True
354 | supervised_dataset, supervised_data_loader, unsupervised_dataset, unsupervised_data_loader, indices, labeled_set, unlabeled_set = create_loaders(args)
355 | net, net_name = train(supervised_dataset, supervised_data_loader, args.cfg, labeled_set, supervised_dataset, indices)
356 |
357 | net, _ = load_net_optimizer_multi(args.cfg)
358 | if not args.is_cluster:
359 | net = nn.DataParallel(net)
360 |
361 | # net_name = os.path.join('/usr/wiss/elezi/PycharmProjects/al_ssl/weights_good/120000combined_id_2_pl_threshold_0.99_labeled_set_3011_.pth')
362 | # net.load_state_dict(torch.load(net_name))
363 |
364 | # do active learning cycles
365 | for i in range(args.num_cycles):
366 | net.eval()
367 |
368 | if args.do_AL:
369 | if args.criterion_select in ['Max_aver', 'entropy', 'random']:
370 | batch_iterator = iter(unsupervised_data_loader)
371 | labeled_set, unlabeled_set = active_learning_entropy(args, batch_iterator, labeled_set, unlabeled_set, net,
372 | args.cfg['num_classes'],
373 | args.criterion_select,
374 | loader=unsupervised_data_loader)
375 | elif args.criterion_select == 'consistency':
376 | batch_iterator = iter(unsupervised_data_loader)
377 | labeled_set, unlabeled_set = active_learning_inconsistency(args, batch_iterator, labeled_set, unlabeled_set, net,
378 | args.cfg['num_classes'],
379 | args.criterion_select,
380 | loader=unsupervised_data_loader)
381 | elif args.criterion_select == 'combined':
382 | print("Combined")
383 | batch_iterator = iter(unsupervised_data_loader)
384 | labeled_set, unlabeled_set = combined_score(args, batch_iterator, labeled_set, unlabeled_set, net,
385 | unsupervised_data_loader)
386 |
387 | supervised_data_loader, unsupervised_data_loader = change_loaders(args, supervised_dataset,
388 | unsupervised_dataset, labeled_set, unlabeled_set, indices, net_name, pseudo=args.do_PL)
389 | net, net_name = train(supervised_dataset, supervised_data_loader, args.cfg, labeled_set, supervised_dataset,
390 | indices)
391 |
392 |
393 | if __name__ == '__main__':
394 | main()
395 |
--------------------------------------------------------------------------------
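
For intuition, the consistency weight produced by `rampweight` rises over the first 32k iterations, holds at 1, and decays after iteration 100k. A compact re-derivation of the schedule with a few sampled values:

    import math

    def ramp(it, up_end=32000, down_start=100000, max_iter=120000):
        # same piecewise schedule as rampweight() in train.py
        if it == 0:
            return 0.0
        if it < up_end:
            return math.exp(-5 * (1 - it / up_end) ** 2)
        if it > down_start:
            return math.exp(-12.5 * (1 - (max_iter - it) / 20000) ** 2)
        return 1.0

    for it in (0, 16000, 32000, 110000):
        print(it, round(ramp(it), 4))   # 0.0, 0.2865, 1.0, 0.0439
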
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .augmentations import SSDAugmentation, jaccard_numpy
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/augmentations.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/utils/__pycache__/augmentations.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/augmentations.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/utils/__pycache__/augmentations.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/augmentations.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/AL-SSL/949476f019abe8d713c96e60105a4f7c832d3820/utils/__pycache__/augmentations.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/augmentations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | import cv2
4 | import numpy as np
5 | import types
6 | from numpy import random
7 |
8 |
9 | def intersect(box_a, box_b):
10 | max_xy = np.minimum(box_a[:, 2:], box_b[2:])
11 | min_xy = np.maximum(box_a[:, :2], box_b[:2])
12 | inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf)
13 | return inter[:, 0] * inter[:, 1]
14 |
15 |
16 | def jaccard_numpy(box_a, box_b):
17 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
18 | is simply the intersection over union of two boxes.
19 | E.g.:
20 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
21 | Args:
22 | box_a: Multiple bounding boxes, Shape: [num_boxes,4]
23 | box_b: Single bounding box, Shape: [4]
24 | Return:
25 |         jaccard overlap: Shape: [box_a.shape[0]]
26 | """
27 | inter = intersect(box_a, box_b)
28 |     area_a = ((box_a[:, 2]-box_a[:, 0]) *
29 |               (box_a[:, 3]-box_a[:, 1]))  # shape: [num_boxes]
30 |     area_b = ((box_b[2]-box_b[0]) *
31 |               (box_b[3]-box_b[1]))  # scalar
32 | union = area_a + area_b - inter
33 | return inter / union # [A,B]
34 |
35 |
36 | class Compose(object):
37 | """Composes several augmentations together.
38 | Args:
39 | transforms (List[Transform]): list of transforms to compose.
40 | Example:
41 | >>> augmentations.Compose([
42 | >>> transforms.CenterCrop(10),
43 | >>> transforms.ToTensor(),
44 | >>> ])
45 | """
46 |
47 | def __init__(self, transforms):
48 | self.transforms = transforms
49 |
50 | def __call__(self, img, boxes=None, labels=None):
51 | for t in self.transforms:
52 | img, boxes, labels = t(img, boxes, labels)
53 | return img, boxes, labels
54 |
55 |
56 | class Lambda(object):
57 | """Applies a lambda as a transform."""
58 |
59 | def __init__(self, lambd):
60 | assert isinstance(lambd, types.LambdaType)
61 | self.lambd = lambd
62 |
63 | def __call__(self, img, boxes=None, labels=None):
64 | return self.lambd(img, boxes, labels)
65 |
66 |
67 | class ConvertFromInts(object):
68 | def __call__(self, image, boxes=None, labels=None):
69 | return image.astype(np.float32), boxes, labels
70 |
71 |
72 | class SubtractMeans(object):
73 | def __init__(self, mean):
74 | self.mean = np.array(mean, dtype=np.float32)
75 |
76 | def __call__(self, image, boxes=None, labels=None):
77 | image = image.astype(np.float32)
78 | image -= self.mean
79 | return image.astype(np.float32), boxes, labels
80 |
81 |
82 | class ToAbsoluteCoords(object):
83 | def __call__(self, image, boxes=None, labels=None):
84 | height, width, channels = image.shape
85 | boxes[:, 0] *= width
86 | boxes[:, 2] *= width
87 | boxes[:, 1] *= height
88 | boxes[:, 3] *= height
89 |
90 | return image, boxes, labels
91 |
92 |
93 | class ToPercentCoords(object):
94 | def __call__(self, image, boxes=None, labels=None):
95 | height, width, channels = image.shape
96 | boxes[:, 0] /= width
97 | boxes[:, 2] /= width
98 | boxes[:, 1] /= height
99 | boxes[:, 3] /= height
100 |
101 | return image, boxes, labels
102 |
103 |
104 | class Resize(object):
105 | def __init__(self, size=300):
106 | self.size = size
107 |
108 | def __call__(self, image, boxes=None, labels=None):
109 | image = cv2.resize(image, (self.size,
110 | self.size))
111 | return image, boxes, labels
112 |
113 |
114 | class RandomSaturation(object):
115 | def __init__(self, lower=0.5, upper=1.5):
116 | self.lower = lower
117 | self.upper = upper
118 |         assert self.upper >= self.lower, "saturation upper must be >= lower."
119 |         assert self.lower >= 0, "saturation lower must be non-negative."
120 |
121 | def __call__(self, image, boxes=None, labels=None):
122 | if random.randint(2):
123 | image[:, :, 1] *= random.uniform(self.lower, self.upper)
124 |
125 | return image, boxes, labels
126 |
127 |
128 | class RandomHue(object):
129 | def __init__(self, delta=18.0):
130 | assert delta >= 0.0 and delta <= 360.0
131 | self.delta = delta
132 |
133 | def __call__(self, image, boxes=None, labels=None):
134 | if random.randint(2):
135 | image[:, :, 0] += random.uniform(-self.delta, self.delta)
136 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0
137 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0
138 | return image, boxes, labels
139 |
140 |
141 | class RandomLightingNoise(object):
142 | def __init__(self):
143 | self.perms = ((0, 1, 2), (0, 2, 1),
144 | (1, 0, 2), (1, 2, 0),
145 | (2, 0, 1), (2, 1, 0))
146 |
147 | def __call__(self, image, boxes=None, labels=None):
148 | if random.randint(2):
149 | swap = self.perms[random.randint(len(self.perms))]
150 | shuffle = SwapChannels(swap) # shuffle channels
151 | image = shuffle(image)
152 | return image, boxes, labels
153 |
154 |
155 | class ConvertColor(object):
156 | def __init__(self, current='BGR', transform='HSV'):
157 | self.transform = transform
158 | self.current = current
159 |
160 | def __call__(self, image, boxes=None, labels=None):
161 | if self.current == 'BGR' and self.transform == 'HSV':
162 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
163 | elif self.current == 'HSV' and self.transform == 'BGR':
164 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
165 | else:
166 | raise NotImplementedError
167 | return image, boxes, labels
168 |
169 |
170 | class RandomContrast(object):
171 | def __init__(self, lower=0.5, upper=1.5):
172 | self.lower = lower
173 | self.upper = upper
174 | assert self.upper >= self.lower, "contrast upper must be >= lower."
175 | assert self.lower >= 0, "contrast lower must be non-negative."
176 |
177 | # expects float image
178 | def __call__(self, image, boxes=None, labels=None):
179 | if random.randint(2):
180 | alpha = random.uniform(self.lower, self.upper)
181 | image *= alpha
182 | return image, boxes, labels
183 |
184 |
185 | class RandomBrightness(object):
186 | def __init__(self, delta=32):
187 | assert delta >= 0.0
188 | assert delta <= 255.0
189 | self.delta = delta
190 |
191 | def __call__(self, image, boxes=None, labels=None):
192 | if random.randint(2):
193 | delta = random.uniform(-self.delta, self.delta)
194 | image += delta
195 | return image, boxes, labels
196 |
197 |
198 | class ToCV2Image(object):
199 | def __call__(self, tensor, boxes=None, labels=None):
200 | return tensor.cpu().numpy().astype(np.float32).transpose((1, 2, 0)), boxes, labels
201 |
202 |
203 | class ToTensor(object):
204 | def __call__(self, cvimage, boxes=None, labels=None):
205 | return torch.from_numpy(cvimage.astype(np.float32)).permute(2, 0, 1), boxes, labels
206 |
207 |
208 | class RandomSampleCrop(object):
209 | """Crop
210 | Arguments:
211 | img (Image): the image being input during training
212 | boxes (Tensor): the original bounding boxes in pt form
213 | labels (Tensor): the class labels for each bbox
214 | mode (float tuple): the min and max jaccard overlaps
215 | Return:
216 | (img, boxes, classes)
217 | img (Image): the cropped image
218 | boxes (Tensor): the adjusted bounding boxes in pt form
219 | labels (Tensor): the class labels for each bbox
220 | """
221 | def __init__(self):
222 | self.sample_options = (
223 | # using entire original input image
224 | None,
225 |             # sample a patch s.t. MIN jaccard w/ obj in .1, .3, .7, .9
226 | (0.1, None),
227 | (0.3, None),
228 | (0.7, None),
229 | (0.9, None),
230 | # randomly sample a patch
231 | (None, None),
232 | )
233 |
234 | def __call__(self, image, boxes=None, labels=None):
235 | height, width, _ = image.shape
236 | while True:
237 | # randomly choose a mode
238 | mode = random.choice(self.sample_options)
239 | if mode is None:
240 | return image, boxes, labels
241 |
242 | min_iou, max_iou = mode
243 | if min_iou is None:
244 | min_iou = float('-inf')
245 | if max_iou is None:
246 | max_iou = float('inf')
247 |
248 |             # max trials (50)
249 | for _ in range(50):
250 | current_image = image
251 |
252 | w = random.uniform(0.3 * width, width)
253 | h = random.uniform(0.3 * height, height)
254 |
255 | # aspect ratio constraint b/t .5 & 2
256 | if h / w < 0.5 or h / w > 2:
257 | continue
258 |
259 | left = random.uniform(width - w)
260 | top = random.uniform(height - h)
261 |
262 | # convert to integer rect x1,y1,x2,y2
263 | rect = np.array([int(left), int(top), int(left+w), int(top+h)])
264 |
265 | # calculate IoU (jaccard overlap) b/t the cropped and gt boxes
266 | overlap = jaccard_numpy(boxes, rect)
267 |
268 | # retry if no gt box can satisfy the min/max overlap constraint
269 | if overlap.max() < min_iou or overlap.min() > max_iou:
270 | continue
271 |
272 | # cut the crop from the image
273 | current_image = current_image[rect[1]:rect[3], rect[0]:rect[2],
274 | :]
275 |
276 | # keep a gt box only if its center lies inside the sampled patch
277 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
278 |
279 | # mask in gt boxes whose centers lie right of and below the crop's top-left corner
280 | m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
281 |
282 | # mask in gt boxes whose centers lie left of and above the crop's bottom-right corner
283 | m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
284 |
285 | # keep boxes where both masks hold, i.e. the center is inside the crop
286 | mask = m1 * m2
287 |
288 | # have any valid boxes? try again if not
289 | if not mask.any():
290 | continue
291 |
292 | # take only matching gt boxes
293 | current_boxes = boxes[mask, :].copy()
294 |
295 | # take only matching gt labels
296 | current_labels = labels[mask]
297 |
298 | # clip the box's top-left corner to the crop's top-left
299 | current_boxes[:, :2] = np.maximum(current_boxes[:, :2],
300 | rect[:2])
301 | # adjust to crop (by subtracting crop's left,top)
302 | current_boxes[:, :2] -= rect[:2]
303 |
304 | current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:],
305 | rect[2:])
306 | # adjust to crop (by subtracting crop's left,top)
307 | current_boxes[:, 2:] -= rect[:2]
308 |
309 | return current_image, current_boxes, current_labels
310 |
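# Editor's sketch (hypothetical helper, not in the original file): boxes are
# absolute (x1, y1, x2, y2) coordinates here, and the crop may return the
# input unchanged when the `None` mode is drawn. Uses jaccard_numpy from
# earlier in this module.
def _demo_random_sample_crop():
    img = np.random.uniform(0, 255, (300, 300, 3)).astype(np.float32)
    boxes = np.array([[50., 50., 150., 150.]])
    labels = np.array([1])
    out_img, out_boxes, out_labels = RandomSampleCrop()(img, boxes, labels)
    return out_img.shape, out_boxes, out_labels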
311 |
312 | class Expand(object):
313 | def __init__(self, mean):
314 | self.mean = mean
315 |
316 | def __call__(self, image, boxes, labels):
317 | if random.randint(2):
318 | return image, boxes, labels
319 |
320 | height, width, depth = image.shape
321 | ratio = random.uniform(1, 4)
322 | left = random.uniform(0, width*ratio - width)
323 | top = random.uniform(0, height*ratio - height)
324 |
325 | expand_image = np.zeros(
326 | (int(height*ratio), int(width*ratio), depth),
327 | dtype=image.dtype)
328 | expand_image[:, :, :] = self.mean
329 | expand_image[int(top):int(top + height),
330 | int(left):int(left + width)] = image
331 | image = expand_image
332 |
333 | boxes = boxes.copy()
334 | boxes[:, :2] += (int(left), int(top))
335 | boxes[:, 2:] += (int(left), int(top))
336 |
337 | return image, boxes, labels
338 |
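# Editor's sketch (hypothetical helper): when the expansion fires, the canvas
# grows by a 1-4x ratio, is filled with the dataset mean, and every box is
# shifted by the (left, top) placement offset.
def _demo_expand():
    img = np.random.uniform(0, 255, (100, 100, 3)).astype(np.float32)
    boxes = np.array([[10., 10., 40., 40.]])
    out_img, out_boxes, _ = Expand(mean=(104, 117, 123))(img, boxes, np.array([1]))
    return out_img.shape, out_boxes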
339 |
340 | class RandomMirror(object):
341 | def __call__(self, image, boxes, classes):
342 | _, width, _ = image.shape
343 | if random.randint(2):
344 | image = image[:, ::-1]
345 | boxes = boxes.copy()
346 | boxes[:, 0::2] = width - boxes[:, 2::-2]  # x1' = width - x2, x2' = width - x1
347 | return image, boxes, classes
348 |
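# Editor's sketch (hypothetical helper): the slice boxes[:, 2::-2] picks
# (x2, x1), so a fired flip maps x1 -> width - x2 and x2 -> width - x1.
def _demo_random_mirror():
    img = np.zeros((10, 20, 3), dtype=np.float32)
    boxes = np.array([[2., 3., 8., 9.]])
    _, out_boxes, _ = RandomMirror()(img, boxes, np.array([1]))
    return out_boxes  # [[12., 3., 18., 9.]] if the flip fired, else unchanged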
349 |
350 | class SwapChannels(object):
351 | """Transforms a tensorized image by swapping the channels in the order
352 | specified in the swap tuple.
353 | Args:
354 | swaps (int triple): final order of channels
355 | eg: (2, 1, 0)
356 | """
357 |
358 | def __init__(self, swaps):
359 | self.swaps = swaps
360 |
361 | def __call__(self, image):
362 | """
363 | Args:
364 | image (Tensor): image tensor to be transformed
365 | Return:
366 | a tensor with channels swapped according to swap
367 | """
368 | # if torch.is_tensor(image):
369 | # image = image.data.cpu().numpy()
370 | # else:
371 | # image = np.array(image)
372 | image = image[:, :, self.swaps]
373 | return image
374 |
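# Editor's sketch (hypothetical helper): swaps=(2, 1, 0) reverses the channel
# order, e.g. turning a BGR image into RGB via numpy fancy indexing.
def _demo_swap_channels():
    img = np.random.uniform(0, 255, (4, 4, 3)).astype(np.float32)  # BGR
    rgb = SwapChannels((2, 1, 0))(img)
    assert np.array_equal(rgb[..., 0], img[..., 2])
    return rgb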
375 |
376 | class PhotometricDistort(object):
377 | def __init__(self):
378 | self.pd = [
379 | RandomContrast(),
380 | ConvertColor(transform='HSV'),
381 | RandomSaturation(),
382 | RandomHue(),
383 | ConvertColor(current='HSV', transform='BGR'),
384 | RandomContrast()
385 | ]
386 | self.rand_brightness = RandomBrightness()
387 | self.rand_light_noise = RandomLightingNoise()
388 |
389 | def __call__(self, image, boxes, labels):
390 | im = image.copy()
391 | im, boxes, labels = self.rand_brightness(im, boxes, labels)
392 | if random.randint(2):
393 | distort = Compose(self.pd[:-1])  # apply RandomContrast before the HSV round trip
394 | else:
395 | distort = Compose(self.pd[1:])  # apply RandomContrast after converting back to BGR
396 | im, boxes, labels = distort(im, boxes, labels)
397 | return self.rand_light_noise(im, boxes, labels)
398 |
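# Editor's sketch (hypothetical helper): the full photometric pipeline keeps
# the image shape and the boxes/labels intact and only jitters pixel values.
def _demo_photometric_distort():
    img = np.random.uniform(0, 255, (32, 32, 3)).astype(np.float32)
    boxes = np.array([[0.1, 0.1, 0.5, 0.5]])
    out_img, out_boxes, _ = PhotometricDistort()(img, boxes, np.array([1]))
    assert out_img.shape == img.shape
    return out_img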
399 |
400 | class SSDAugmentation(object):
401 | def __init__(self, size=300, mean=(104, 117, 123)):
402 | self.mean = mean
403 | self.size = size
404 | self.augment = Compose([
405 | ConvertFromInts(),
406 | ToAbsoluteCoords(),
407 | PhotometricDistort(),
408 | Expand(self.mean),
409 | RandomSampleCrop(),
410 | RandomMirror(),
411 | ToPercentCoords(),
412 | Resize(self.size),
413 | SubtractMeans(self.mean)
414 | ])
415 | # (an alternative, geometry-only pipeline is left commented out below)
416 | # self.augment = Compose([
417 | # ToAbsoluteCoords(),
418 | # # Expand(self.mean),
419 | # # RandomSampleCrop(),
420 | # ToPercentCoords(),
421 | # Resize(self.size),
422 | # ])
423 |
424 | def __call__(self, img, boxes, labels):
425 | return self.augment(img, boxes, labels)
426 |
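# Editor's sketch (hypothetical helper, not in the original file): an
# end-to-end pass through the training pipeline. Boxes enter as fractional
# coordinates (ToAbsoluteCoords scales them up, ToPercentCoords scales them
# back) and the output is a mean-subtracted 300x300 BGR image.
def _demo_ssd_augmentation():
    img = np.random.uniform(0, 255, (480, 640, 3)).astype(np.float32)
    boxes = np.array([[0.1, 0.1, 0.5, 0.5]])
    labels = np.array([1])
    out_img, out_boxes, out_labels = SSDAugmentation(size=300)(img, boxes, labels)
    return out_img.shape  # (300, 300, 3)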
--------------------------------------------------------------------------------