├── .idea
│   └── vcs.xml
├── LICENSE
├── README.md
├── dataset
│   ├── __init__.py
│   ├── cifar.py
│   ├── mydataset.py
│   └── randaugment.py
├── eval.py
├── files.zip
├── images
│   └── consistency.png
├── main.py
├── models
│   ├── ema.py
│   ├── resnet_imagenet.py
│   ├── resnext.py
│   └── wideresnet.py
├── run_cifar10.sh
├── run_cifar100.sh
├── run_eval_cifar10.sh
├── run_imagenet.sh
├── trainer.py
└── utils
    ├── __init__.py
    ├── default.py
    ├── misc.py
    └── parser.py
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="VcsDirectoryMappings">
4 |     <mapping directory="$PROJECT_DIR$" vcs="Git" />
5 |   </component>
6 | </project>
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Vision and Learning Group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## [OpenMatch: Open-set Consistency Regularization for Semi-supervised Learning with Outliers (NeurIPS 2021)](https://arxiv.org/pdf/2105.14148.pdf)
2 |
3 | 
4 |
5 |
6 | This is a PyTorch implementation of OpenMatch.
7 | This implementation is based on [Pytorch-FixMatch](https://github.com/kekmodel/FixMatch-pytorch).
8 |
9 |
10 |
11 | ## Requirements
12 | - python 3.6+
13 | - torch 1.4
14 | - torchvision 0.5
15 | - tensorboard
16 | - numpy
17 | - tqdm
18 | - sklearn
19 | - apex (optional)
20 |
21 | See [Pytorch-FixMatch](https://github.com/kekmodel/FixMatch-pytorch) for the details.
22 |
23 | ## Usage
24 |
25 | ### Dataset Preparation
26 | This repository needs CIFAR-10, CIFAR-100, or ImageNet-30 to train a model.
27 |
28 | To fully reproduce the evaluation results, we also need SVHN, LSUN, and ImageNet
29 | for CIFAR-10/100, and LSUN, DTD, CUB, Flowers, Caltech-256, and Stanford Dogs for ImageNet-30.
30 | To prepare these datasets, follow [CSI](https://github.com/alinlab/CSI).
31 |
32 |
33 | ```
34 | mkdir data
35 | ln -s path_to_each_dataset ./data/.
36 |
37 | ## unzip filelist for imagenet_30 experiments.
38 | unzip files.zip
39 | ```
40 |
41 | All datasets are supposed to be under ./data.
42 |
43 | ### Train
44 | Train a model with 50 labeled examples per class on CIFAR-10:
45 |
46 | ```
47 | sh run_cifar10.sh 50 save_directory
48 | ```
49 |
50 | Train a model with 50 labeled examples per class on CIFAR-100 with 55 known classes:
51 |
52 | ```
53 | sh run_cifar100.sh 50 10 save_directory
54 | ```
55 |
56 |
57 | Train a model with 50 labeled examples per class on CIFAR-100 with 80 known classes:
58 |
59 | ```
60 | sh run_cifar100.sh 50 15 save_directory
61 | ```
62 |
63 |
64 | Run experiments on ImageNet-30:
65 |
66 | ```
67 | sh run_imagenet.sh save_directory
68 | ```
69 |
70 |
71 | ### Evaluation
72 | Evaluate a model trained on CIFAR-10:
73 |
74 | ```
75 | sh run_eval_cifar10.sh trained_model.pth
76 | ```
77 |
78 | ### Trained models
79 |
80 | - [CIFAR10-50-labeled](https://drive.google.com/file/d/1oNWAR8jVlxQXH0TMql1P-c7_i5-taU2T/view?usp=sharing)
81 | - [CIFAR100-50-labeled-55class](https://drive.google.com/file/d/1T5a_p4XUEOexEnjLWpGd-3pme4OzJ2pP/view?usp=sharing)
82 | - ImageNet-30 (coming soon)
83 |
84 |
85 | ### Acknowledgement
86 | This repository builds heavily on [Pytorch-FixMatch](https://github.com/kekmodel/FixMatch-pytorch) for the FixMatch implementation and on [CSI](https://github.com/alinlab/CSI) for the anomaly-detection evaluation.
87 | Thanks for sharing the great code bases!
88 |
89 | ### Reference
90 | This repository is contributed by [Kuniaki Saito](http://cs-people.bu.edu/keisaito/).
91 | If you use this code or its derivatives, please consider citing:
92 |
93 | ```
94 | @article{saito2021openmatch,
95 | title={OpenMatch: Open-set Consistency Regularization for Semi-supervised Learning with Outliers},
96 | author={Saito, Kuniaki and Kim, Donghyun and Saenko, Kate},
97 | journal={arXiv preprint arXiv:2105.14148},
98 | year={2021}
99 | }
100 | ```
101 |
102 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .cifar import *
2 |
--------------------------------------------------------------------------------
/dataset/cifar.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import os
4 | import numpy as np
5 | from PIL import Image
6 | from torchvision import datasets
7 | from torchvision import transforms
8 |
9 | from .randaugment import RandAugmentMC
10 | from .mydataset import ImageFolder, ImageFolder_fix
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 | __all__ = ['TransformOpenMatch', 'TransformFixMatch', 'cifar10_mean',
15 | 'cifar10_std', 'cifar100_mean', 'cifar100_std', 'normal_mean',
16 | 'normal_std', 'TransformFixMatch_Imagenet',
17 | 'TransformFixMatch_Imagenet_Weak']
18 | ### Path to the data directory.
19 | DATA_PATH = './data'
20 |
21 | cifar10_mean = (0.4914, 0.4822, 0.4465)
22 | cifar10_std = (0.2471, 0.2435, 0.2616)
23 | cifar100_mean = (0.5071, 0.4867, 0.4408)
24 | cifar100_std = (0.2675, 0.2565, 0.2761)
25 | normal_mean = (0.5, 0.5, 0.5)
26 | normal_std = (0.5, 0.5, 0.5)
27 |
28 |
29 | def get_cifar(args, norm=True):
30 | root = args.root
31 | name = args.dataset
32 | if name == "cifar10":
33 | data_folder = datasets.CIFAR10
34 | data_folder_main = CIFAR10SSL
35 | mean = cifar10_mean
36 | std = cifar10_std
37 | num_class = 10
38 | elif name == "cifar100":
39 | data_folder = CIFAR100FIX
40 | data_folder_main = CIFAR100SSL
41 | mean = cifar100_mean
42 | std = cifar100_std
43 | num_class = 100
44 | num_super = args.num_super
45 |
46 | else:
47 | raise NotImplementedError()
48 | assert num_class > args.num_classes
49 |
50 | if name == "cifar10":
51 | base_dataset = data_folder(root, train=True, download=True)
52 | args.num_classes = 6
53 | elif name == 'cifar100':
54 | base_dataset = data_folder(root, train=True,
55 | download=True, num_super=num_super)
56 | args.num_classes = base_dataset.num_known_class
57 |
58 | base_dataset.targets = np.array(base_dataset.targets)
59 | if name == 'cifar10':
60 | base_dataset.targets -= 2
61 | base_dataset.targets[np.where(base_dataset.targets == -2)[0]] = 8
62 | base_dataset.targets[np.where(base_dataset.targets == -1)[0]] = 9
63 |
64 | train_labeled_idxs, train_unlabeled_idxs, val_idxs = \
65 | x_u_split(args, base_dataset.targets)
66 |
67 | ## This function will be overwritten in trainer.py
68 | norm_func = TransformFixMatch(mean=mean, std=std, norm=norm)
69 | if norm:
70 | norm_func_test = transforms.Compose([
71 | transforms.ToTensor(),
72 | transforms.Normalize(mean=mean, std=std)
73 | ])
74 | else:
75 | norm_func_test = transforms.Compose([
76 | transforms.ToTensor(),
77 | ])
78 |
79 | if name == 'cifar10':
80 | train_labeled_dataset = data_folder_main(
81 | root, train_labeled_idxs, train=True,
82 | transform=norm_func)
83 | train_unlabeled_dataset = data_folder_main(
84 | root, train_unlabeled_idxs, train=True,
85 | transform=norm_func, return_idx=False)
86 | val_dataset = data_folder_main(
87 | root, val_idxs, train=True,
88 | transform=norm_func_test)
89 | elif name == 'cifar100':
90 | train_labeled_dataset = data_folder_main(
91 | root, train_labeled_idxs, num_super = num_super, train=True,
92 | transform=norm_func)
93 | train_unlabeled_dataset = data_folder_main(
94 | root, train_unlabeled_idxs, num_super = num_super, train=True,
95 | transform=norm_func, return_idx=False)
96 | val_dataset = data_folder_main(
97 | root, val_idxs, num_super = num_super,train=True,
98 | transform=norm_func_test)
99 |
100 | if name == 'cifar10':
101 | train_labeled_dataset.targets -= 2
102 | train_unlabeled_dataset.targets -= 2
103 | val_dataset.targets -= 2
104 |
105 |
106 | if name == 'cifar10':
107 | test_dataset = data_folder(
108 | root, train=False, transform=norm_func_test, download=False)
109 | elif name == 'cifar100':
110 | test_dataset = data_folder(
111 | root, train=False, transform=norm_func_test,
112 | download=False, num_super=num_super)
113 | test_dataset.targets = np.array(test_dataset.targets)
114 |
115 | if name == 'cifar10':
116 | test_dataset.targets -= 2
117 | test_dataset.targets[np.where(test_dataset.targets == -2)[0]] = 8
118 | test_dataset.targets[np.where(test_dataset.targets == -1)[0]] = 9
119 |
120 | target_ind = np.where(test_dataset.targets >= args.num_classes)[0]
121 | test_dataset.targets[target_ind] = args.num_classes
122 |
123 |
124 | unique_labeled = np.unique(train_labeled_idxs)
125 | val_labeled = np.unique(val_idxs)
126 | logger.info("Dataset: %s"%name)
127 | logger.info(f"Labeled examples: {len(unique_labeled)}"
128 | f"Unlabeled examples: {len(train_unlabeled_idxs)}"
129 | f"Valdation samples: {len(val_labeled)}")
130 | return train_labeled_dataset, train_unlabeled_dataset, \
131 | test_dataset, val_dataset
132 |
133 |
134 |
135 | def get_imagenet(args, norm=True):
136 | mean = normal_mean
137 | std = normal_std
138 | txt_labeled = "filelist/imagenet_train_labeled.txt"
139 | txt_unlabeled = "filelist/imagenet_train_unlabeled.txt"
140 | txt_val = "filelist/imagenet_val.txt"
141 | txt_test = "filelist/imagenet_test.txt"
142 | ## This function will be overwritten in trainer.py
143 | norm_func = TransformFixMatch_Imagenet(mean=mean, std=std,
144 | norm=norm, size_image=224)
145 | dataset_labeled = ImageFolder(txt_labeled, transform=norm_func)
146 | dataset_unlabeled = ImageFolder_fix(txt_unlabeled, transform=norm_func)
147 |
148 | test_transform = transforms.Compose([
149 | transforms.Resize(256),
150 | transforms.CenterCrop(224),
151 | transforms.ToTensor(),
152 | transforms.Normalize(mean=mean, std=std)
153 | ])
154 | dataset_val = ImageFolder(txt_val, transform=test_transform)
155 | dataset_test = ImageFolder(txt_test, transform=test_transform)
156 | logger.info(f"Labeled examples: {len(dataset_labeled)}"
157 | f"Unlabeled examples: {len(dataset_unlabeled)}"
158 | f"Valdation samples: {len(dataset_val)}")
159 | return dataset_labeled, dataset_unlabeled, dataset_test, dataset_val
160 |
161 |
162 | def x_u_split(args, labels):
163 | label_per_class = args.num_labeled #// args.num_classes
164 | val_per_class = args.num_val #// args.num_classes
165 | labels = np.array(labels)
166 | labeled_idx = []
167 | val_idx = []
168 | unlabeled_idx = []
169 | # unlabeled data: all data (https://github.com/kekmodel/FixMatch-pytorch/issues/10)
170 | for i in range(args.num_classes):
171 | idx = np.where(labels == i)[0]
172 | unlabeled_idx.extend(idx)
173 | idx = np.random.choice(idx, label_per_class+val_per_class, False)
174 | labeled_idx.extend(idx[:label_per_class])
175 | val_idx.extend(idx[label_per_class:])
176 |
177 | labeled_idx = np.array(labeled_idx)
178 |
179 | assert len(labeled_idx) == args.num_labeled * args.num_classes
180 | if args.expand_labels or args.num_labeled < args.batch_size:
181 | num_expand_x = math.ceil(
182 | args.batch_size * args.eval_step / args.num_labeled)
183 | labeled_idx = np.hstack([labeled_idx for _ in range(num_expand_x)])
184 | np.random.shuffle(labeled_idx)
185 |
186 | #if not args.no_out:
187 | unlabeled_idx = np.array(range(len(labels)))
188 | unlabeled_idx = [idx for idx in unlabeled_idx if idx not in labeled_idx]
189 | unlabeled_idx = [idx for idx in unlabeled_idx if idx not in val_idx]
190 | return labeled_idx, unlabeled_idx, val_idx
191 |
192 |
193 | class TransformFixMatch(object):
194 | def __init__(self, mean, std, norm=True, size_image=32):
195 | self.weak = transforms.Compose([
196 | transforms.RandomHorizontalFlip(),
197 | transforms.RandomCrop(size=size_image,
198 | padding=int(size_image*0.125),
199 | padding_mode='reflect')])
200 | self.weak2 = transforms.Compose([
201 | transforms.RandomHorizontalFlip(),])
202 | self.strong = transforms.Compose([
203 | transforms.RandomHorizontalFlip(),
204 | transforms.RandomCrop(size=size_image,
205 | padding=int(size_image*0.125),
206 | padding_mode='reflect'),
207 | RandAugmentMC(n=2, m=10)])
208 | self.normalize = transforms.Compose([
209 | transforms.ToTensor(),
210 | transforms.Normalize(mean=mean, std=std)])
211 | self.norm = norm
212 |
213 | def __call__(self, x):
214 | weak = self.weak(x)
215 | strong = self.strong(x)
216 | if self.norm:
217 | return self.normalize(weak), self.normalize(strong), self.normalize(self.weak2(x))
218 | else:
219 | return weak, strong
220 |
221 | class TransformOpenMatch(object):
222 | def __init__(self, mean, std, norm=True, size_image=32):
223 | self.weak = transforms.Compose([
224 | transforms.RandomHorizontalFlip(),
225 | transforms.RandomCrop(size=size_image,
226 | padding=int(size_image*0.125),
227 | padding_mode='reflect')])
228 | self.weak2 = transforms.Compose([
229 | transforms.RandomHorizontalFlip(),])
230 | self.normalize = transforms.Compose([
231 | transforms.ToTensor(),
232 | transforms.Normalize(mean=mean, std=std)])
233 | self.norm = norm
234 |
235 | def __call__(self, x):
236 | weak = self.weak(x)
237 | strong = self.weak(x)
238 |
239 | if self.norm:
240 | return self.normalize(weak), self.normalize(strong), self.normalize(self.weak2(x))
241 | else:
242 | return weak, strong
243 |
244 |
245 |
246 |
247 | class TransformFixMatch_Imagenet(object):
248 | def __init__(self, mean, std, norm=True, size_image=224):
249 | self.weak = transforms.Compose([
250 | transforms.Resize((256, 256)),
251 | transforms.RandomHorizontalFlip(),
252 | transforms.RandomCrop(size=size_image,
253 | padding=int(size_image*0.125),
254 | padding_mode='reflect')])
255 | self.weak2 = transforms.Compose([
256 | transforms.Resize((256, 256)),
257 | transforms.RandomHorizontalFlip(),
258 | transforms.CenterCrop(size=size_image),
259 | ])
260 | self.strong = transforms.Compose([
261 | transforms.Resize((256, 256)),
262 | transforms.RandomHorizontalFlip(),
263 | transforms.RandomCrop(size=size_image,
264 | padding=int(size_image*0.125),
265 | padding_mode='reflect'),
266 | RandAugmentMC(n=2, m=10)])
267 | self.normalize = transforms.Compose([
268 | transforms.ToTensor(),
269 | transforms.Normalize(mean=mean, std=std)])
270 | self.norm = norm
271 |
272 | def __call__(self, x):
273 | weak = self.weak(x)
274 | weak2 = self.weak2(x)
275 | strong = self.strong(x)
276 | if self.norm:
277 | return self.normalize(weak), self.normalize(strong), self.normalize(weak2)
278 | else:
279 | return weak, strong
280 |
281 |
282 |
283 | class TransformFixMatch_Imagenet_Weak(object):
284 | def __init__(self, mean, std, norm=True, size_image=224):
285 | self.weak = transforms.Compose([
286 | transforms.Resize((256, 256)),
287 | transforms.RandomHorizontalFlip(),
288 | transforms.RandomCrop(size=size_image,
289 | padding=int(size_image*0.125),
290 | padding_mode='reflect')])
291 | self.weak2 = transforms.Compose([
292 | transforms.Resize((256, 256)),
293 | transforms.RandomHorizontalFlip(),
294 | transforms.CenterCrop(size=size_image),
295 | ])
296 | self.strong = transforms.Compose([
297 | transforms.Resize((256, 256)),
298 | transforms.RandomHorizontalFlip(),
299 | transforms.RandomCrop(size=size_image,
300 | padding=int(size_image*0.125),
301 | padding_mode='reflect'),
302 | RandAugmentMC(n=2, m=10)])
303 | self.normalize = transforms.Compose([
304 | transforms.ToTensor(),
305 | transforms.Normalize(mean=mean, std=std)])
306 | self.norm = norm
307 |
308 | def __call__(self, x):
309 | weak = self.weak2(x)
310 | weak2 = self.weak2(x)
311 | strong = self.strong(x)
312 | if self.norm:
313 | return self.normalize(weak), self.normalize(strong), self.normalize(weak2)
314 | else:
315 | return weak, strong
316 |
317 |
318 |
319 |
320 | class CIFAR10SSL(datasets.CIFAR10):
321 | def __init__(self, root, indexs, train=True,
322 | transform=None, target_transform=None,
323 | download=False, return_idx=False):
324 | super().__init__(root, train=train,
325 | transform=transform,
326 | target_transform=target_transform,
327 | download=download)
328 | if indexs is not None:
329 | self.data = self.data[indexs]
330 | self.targets = np.array(self.targets)[indexs]
331 | self.return_idx = return_idx
332 | self.set_index()
333 |
334 | def set_index(self, indexes=None):
335 | if indexes is not None:
336 | self.data_index = self.data[indexes]
337 | self.targets_index = self.targets[indexes]
338 | else:
339 | self.data_index = self.data
340 | self.targets_index = self.targets
341 |
342 | def init_index(self):
343 | self.data_index = self.data
344 | self.targets_index = self.targets
345 |
346 | def __getitem__(self, index):
347 | img, target = self.data_index[index], self.targets_index[index]
348 | img = Image.fromarray(img)
349 |
350 | if self.transform is not None:
351 | img = self.transform(img)
352 |
353 | if self.target_transform is not None:
354 | target = self.target_transform(target)
355 |
356 | if not self.return_idx:
357 | return img, target
358 | else:
359 | return img, target, index
360 |
361 | def __len__(self):
362 | return len(self.data_index)
363 |
364 |
365 |
366 |
367 |
368 |
369 | class CIFAR100FIX(datasets.CIFAR100):
370 | def __init__(self, root, num_super=10, train=True, transform=None,
371 | target_transform=None, download=False, return_idx=False):
372 | super().__init__(root, train=train, transform=transform,
373 | target_transform=target_transform, download=download)
374 |
375 | coarse_labels = np.array([4, 1, 14, 8, 0, 6, 7, 7, 18, 3,
376 | 3, 14, 9, 18, 7, 11, 3, 9, 7, 11,
377 | 6, 11, 5, 10, 7, 6, 13, 15, 3, 15,
378 | 0, 11, 1, 10, 12, 14, 16, 9, 11, 5,
379 | 5, 19, 8, 8, 15, 13, 14, 17, 18, 10,
380 | 16, 4, 17, 4, 2, 0, 17, 4, 18, 17,
381 | 10, 3, 2, 12, 12, 16, 12, 1, 9, 19,
382 | 2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
383 | 16, 19, 2, 4, 6, 19, 5, 5, 8, 19,
384 | 18, 1, 2, 15, 6, 0, 17, 8, 14, 13])
385 | self.coarse_labels = coarse_labels[self.targets]
386 | self.targets = np.array(self.targets)
387 | labels_unknown = self.targets[np.where(self.coarse_labels > num_super)[0]]
388 | labels_known = self.targets[np.where(self.coarse_labels <= num_super)[0]]
389 | unknown_categories = np.unique(labels_unknown)
390 | known_categories = np.unique(labels_known)
391 |
392 | num_unknown = len(unknown_categories)
393 | num_known = len(known_categories)
394 | print("number of unknown categories %s"%num_unknown)
395 | print("number of known categories %s"%num_known)
396 | assert num_known + num_unknown == 100
397 | #new_category_labels = list(range(num_known))
398 | self.targets_new = np.zeros_like(self.targets)
399 | for i, known in enumerate(known_categories):
400 | ind_known = np.where(self.targets==known)[0]
401 | self.targets_new[ind_known] = i
402 | for i, unknown in enumerate(unknown_categories):
403 | ind_unknown = np.where(self.targets == unknown)[0]
404 | self.targets_new[ind_unknown] = num_known
405 |
406 | self.targets = self.targets_new
407 | assert len(np.where(self.targets == num_known)[0]) == len(labels_unknown)
408 | assert len(np.where(self.targets < num_known)[0]) == len(labels_known)
409 | self.num_known_class = num_known
410 |
411 |
412 | def __getitem__(self, index):
413 |
414 | img, target = self.data[index], self.targets[index]
415 | img = Image.fromarray(img)
416 |
417 | if self.transform is not None:
418 | img = self.transform(img)
419 |
420 | if self.target_transform is not None:
421 | target = self.target_transform(target)
422 |
423 | return img, target
424 |
425 |
426 | class CIFAR100SSL(CIFAR100FIX):
427 | def __init__(self, root, indexs, num_super=10, train=True,
428 | transform=None, target_transform=None,
429 | download=False, return_idx=False):
430 | super().__init__(root, num_super=num_super,train=train,
431 | transform=transform,
432 | target_transform=target_transform,
433 | download=download)
434 | self.return_idx = return_idx
435 | if indexs is not None:
436 | self.data = self.data[indexs]
437 | self.targets = np.array(self.targets)[indexs]
438 |
439 | self.set_index()
440 | def set_index(self, indexes=None):
441 | if indexes is not None:
442 | self.data_index = self.data[indexes]
443 | self.targets_index = self.targets[indexes]
444 | else:
445 | self.data_index = self.data
446 | self.targets_index = self.targets
447 |
448 | def init_index(self):
449 | self.data_index = self.data
450 | self.targets_index = self.targets
451 |
452 |
453 | def __getitem__(self, index):
454 | img, target = self.data_index[index], self.targets_index[index]
455 | img = Image.fromarray(img)
456 |
457 | if self.transform is not None:
458 | img = self.transform(img)
459 |
460 | if self.target_transform is not None:
461 | target = self.target_transform(target)
462 | if not self.return_idx:
463 | return img, target
464 | else:
465 | return img, target, index
466 |
467 | def __len__(self):
468 | return len(self.data_index)
469 |
470 | def get_transform(mean, std, image_size=None):
471 | # Note: these transforms are used for OOD evaluation; test_transform only
472 | # resizes and normalizes, and train_transform is defined for completeness.
473 | if image_size: # use pre-specified image size
474 | train_transform = transforms.Compose([
475 | transforms.Resize((image_size[0], image_size[1])),
476 | transforms.RandomHorizontalFlip(),
477 | transforms.ToTensor(),
478 | transforms.Normalize(mean=mean, std=std),
479 | ])
480 | test_transform = transforms.Compose([
481 | transforms.Resize((image_size[0], image_size[1])),
482 | transforms.ToTensor(),
483 | transforms.Normalize(mean=mean, std=std),
484 | ])
485 | else: # use default image size
486 | train_transform = transforms.Compose([
487 | transforms.ToTensor(),
488 | transforms.Normalize(mean=mean, std=std),
489 | ])
490 | test_transform = transforms.ToTensor()
491 |
492 | return train_transform, test_transform
493 |
494 |
495 | def get_ood(dataset, id, test_only=False, image_size=None):
496 | image_size = (32, 32, 3) if image_size is None else image_size
497 | if id == "cifar10":
498 | mean = cifar10_mean
499 | std = cifar10_std
500 | elif id == "cifar100":
501 | mean = cifar100_mean
502 | std = cifar100_std
503 | elif "imagenet" in id or id == "tiny":
504 | mean = normal_mean
505 | std = normal_std
506 |
507 | _, test_transform = get_transform(mean, std, image_size=image_size)
508 |
509 | if dataset == 'cifar10':
510 | test_set = datasets.CIFAR10(DATA_PATH, train=False, download=False,
511 | transform=test_transform)
512 |
513 | elif dataset == 'cifar100':
514 | test_set = datasets.CIFAR100(DATA_PATH, train=False, download=False,
515 | transform=test_transform)
516 |
517 | elif dataset == 'svhn':
518 | test_set = datasets.SVHN(DATA_PATH, split='test', download=True,
519 | transform=test_transform)
520 |
521 | elif dataset == 'lsun':
522 | test_dir = os.path.join(DATA_PATH, 'LSUN_fix')
523 | test_set = datasets.ImageFolder(test_dir, transform=test_transform)
524 |
525 | elif dataset == 'imagenet':
526 | test_dir = os.path.join(DATA_PATH, 'Imagenet_fix')
527 | test_set = datasets.ImageFolder(test_dir, transform=test_transform)
528 | elif dataset == 'stanford_dogs':
529 | test_dir = os.path.join(DATA_PATH, 'stanford_dogs')
530 | test_set = datasets.ImageFolder(test_dir, transform=test_transform)
531 |
532 | elif dataset == 'cub':
533 | test_dir = os.path.join(DATA_PATH, 'cub')
534 | test_set = datasets.ImageFolder(test_dir, transform=test_transform)
535 |
536 | elif dataset == 'flowers102':
537 | test_dir = os.path.join(DATA_PATH, 'flowers102')
538 | test_set = datasets.ImageFolder(test_dir, transform=test_transform)
539 |
540 | elif dataset == 'food_101':
541 | test_dir = os.path.join(DATA_PATH, 'food-101', 'images')
542 | test_set = datasets.ImageFolder(test_dir, transform=test_transform)
543 |
544 | elif dataset == 'caltech_256':
545 | test_dir = os.path.join(DATA_PATH, 'caltech-256')
546 | test_set = datasets.ImageFolder(test_dir, transform=test_transform)
547 |
548 | elif dataset == 'dtd':
549 | test_dir = os.path.join(DATA_PATH, 'dtd')
550 | test_set = datasets.ImageFolder(test_dir, transform=test_transform)
551 |
552 | elif dataset == 'pets':
553 | test_dir = os.path.join(DATA_PATH, 'pets')
554 | test_set = datasets.ImageFolder(test_dir, transform=test_transform)
555 |
556 | return test_set
557 |
558 | DATASET_GETTERS = {'cifar10': get_cifar,
559 | 'cifar100': get_cifar,
560 | 'imagenet': get_imagenet,
561 | }
562 |
--------------------------------------------------------------------------------
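Note: `DATASET_GETTERS` at the bottom of `dataset/cifar.py` is the entry point the rest of the repo uses to build the four dataset splits. A minimal usage sketch; the argument values below are illustrative assumptions (real runs build `args` via the repo's parser), not defaults taken from this code:

```
from types import SimpleNamespace
from dataset.cifar import DATASET_GETTERS

# Hypothetical argument bundle; CIFAR-10 is downloaded to ./data on first use.
args = SimpleNamespace(root='./data', dataset='cifar10', num_classes=6,
                       num_labeled=50, num_val=50, batch_size=64,
                       eval_step=1024, expand_labels=True)
labeled, unlabeled, test, val = DATASET_GETTERS[args.dataset](args)

# With norm=True, TransformFixMatch yields three views per image:
(img_weak, img_strong, img_weak2), target = labeled[0]
```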
/dataset/mydataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import numpy as np
6 |
7 | IMG_EXTENSIONS = [
8 | '.jpg', '.JPG', '.jpeg', '.JPEG',
9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
10 | ]
11 |
12 |
13 | def find_classes(dir):
14 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
15 | classes.sort()
16 | class_to_idx = {classes[i]: i for i in range(len(classes))}
17 | return classes, class_to_idx
18 |
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir, class_to_idx):
25 | images = []
26 | dir = os.path.expanduser(dir)
27 | for target in os.listdir(dir):
28 | d = os.path.join(dir, target)
29 | if not os.path.isdir(d):
30 | continue
31 |
32 | for root, _, fnames in sorted(os.walk(d)):
33 | for fname in fnames:
34 | if is_image_file(fname):
35 | path = os.path.join(root, fname)
36 | item = (path, class_to_idx[target])
37 | images.append(item)
38 |
39 | return images
40 |
41 |
42 | def default_flist_reader(flist):
43 | """
44 | flist format: impath label\nimpath label\n ... (same as Caffe's filelist)
45 | """
46 | imlist = []
47 | with open(flist, 'r') as rf:
48 | for line in rf.readlines():
49 | impath, imlabel = line.strip().split()
50 | imlist.append((impath, int(imlabel)))
51 |
52 | return imlist
53 |
54 |
55 | def default_loader(path):
56 | return Image.open(path).convert('RGB')
57 |
58 |
59 | def make_dataset_nolist(image_list):
60 | with open(image_list) as f:
61 | image_index = [x.split(' ')[0] for x in f.readlines()]
62 | with open(image_list) as f:
63 | label_list = []
64 | selected_list = []
65 | for ind, x in enumerate(f.readlines()):
66 | label = x.split(' ')[1].strip()
67 | label_list.append(int(label))
68 | selected_list.append(ind)
69 | image_index = np.array(image_index)
70 | label_list = np.array(label_list)
71 | image_index = image_index[selected_list]
72 | return image_index, label_list
73 |
74 |
75 | class ImageFolder(data.Dataset):
76 | """A generic data loader where the images are arranged in this way: ::
77 | root/dog/xxx.png
78 | root/dog/xxy.png
79 | root/dog/xxz.png
80 | root/cat/123.png
81 | root/cat/nsdf3.png
82 | root/cat/asd932_.png
83 | Args:
84 | root (string): Root directory path.
85 | transform (callable, optional): A function/transform that takes in an PIL image
86 | and returns a transformed version. E.g, ``transforms.RandomCrop``
87 | target_transform (callable, optional): A function/transform that takes in the
88 | target and transforms it.
89 | loader (callable, optional): A function to load an image given its path.
90 | Attributes:
91 | classes (list): List of the class names.
92 | class_to_idx (dict): Dict with items (class_name, class_index).
93 | imgs (list): List of (image path, class_index) tuples
94 | """
95 |
96 | def __init__(self, image_list, transform=None, target_transform=None, return_paths=False,
97 | loader=default_loader, train=False, return_id=False):
98 | imgs, labels = make_dataset_nolist(image_list)
99 | self.imgs = imgs
100 | self.labels = labels
101 | self.transform = transform
102 | self.target_transform = target_transform
103 | self.loader = loader
104 | self.return_paths = return_paths
105 | self.return_id = return_id
106 | self.train = train
107 |
108 | def __getitem__(self, index):
109 | """
110 | Args:
111 | index (int): Index
112 | Returns:
113 | tuple: (image, target) where target is class_index of the target class.
114 | """
115 |
116 | path = self.imgs[index]
117 | target = self.labels[index]
118 | img = self.loader(path)
119 | img = self.transform(img)
120 |
121 | if self.target_transform is not None:
122 | target = self.target_transform(target)
123 | if self.return_paths:
124 | return img, target, path
125 | elif self.return_id:
126 | return img, target, index
127 | else:
128 | return img, target
129 |
130 | def __len__(self):
131 | return len(self.imgs)
132 |
133 |
134 | class ImageFolder_fix(data.Dataset):
135 | """A generic data loader where the images are arranged in this way: ::
136 | root/dog/xxx.png
137 | root/dog/xxy.png
138 | root/dog/xxz.png
139 | root/cat/123.png
140 | root/cat/nsdf3.png
141 | root/cat/asd932_.png
142 | Args:
143 | root (string): Root directory path.
144 | transform (callable, optional): A function/transform that takes in an PIL image
145 | and returns a transformed version. E.g, ``transforms.RandomCrop``
146 | target_transform (callable, optional): A function/transform that takes in the
147 | target and transforms it.
148 | loader (callable, optional): A function to load an image given its path.
149 | Attributes:
150 | classes (list): List of the class names.
151 | class_to_idx (dict): Dict with items (class_name, class_index).
152 | imgs (list): List of (image path, class_index) tuples
153 | """
154 |
155 | def __init__(self, image_list, transform=None, target_transform=None, return_paths=False,
156 | loader=default_loader, train=False, return_id=False):
157 | imgs, labels = make_dataset_nolist(image_list)
158 | self.imgs = imgs
159 | self.labels = labels
160 | self.transform = transform
161 | self.target_transform = target_transform
162 | self.loader = loader
163 | self.return_paths = return_paths
164 | self.return_id = return_id
165 | self.train = train
166 | self.set_index()
167 |
168 | def set_index(self, indexes=None):
169 | if indexes is not None:
170 | self.imgs_index = self.imgs[indexes]
171 | self.targets_index = self.labels[indexes]
172 | else:
173 | self.imgs_index = self.imgs
174 | self.targets_index = self.labels
175 |
176 | def init_index(self):
177 | self.imgs_index = self.imgs
178 | self.targets_index = self.labels
179 |
180 |
181 |
182 | def __getitem__(self, index):
183 | """
184 | Args:
185 | index (int): Index
186 | Returns:
187 | tuple: (image, target) where target is class_index of the target class.
188 | """
189 |
190 | path = self.imgs_index[index]
191 | target = self.targets_index[index]
192 | img = self.loader(path)
193 | img = self.transform(img)
194 |
195 | if self.target_transform is not None:
196 | target = self.target_transform(target)
197 | if self.return_paths:
198 | return img, target, path
199 | elif self.return_id:
200 | return img, target, index
201 | else:
202 | return img, target
203 |
204 | def __len__(self):
205 | return len(self.imgs_index)
206 |
207 |
--------------------------------------------------------------------------------
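For reference, `ImageFolder`/`ImageFolder_fix` consume a whitespace-separated filelist rather than a directory tree. A sketch with a made-up filelist; the paths, labels, and filename below are purely illustrative:

```
from torchvision import transforms
from dataset.mydataset import ImageFolder

# train.txt (hypothetical contents):
#   data/imagenet30/train/acorn/img_0001.jpg 0
#   data/imagenet30/train/airliner/img_0002.jpg 1
ds = ImageFolder('train.txt', transform=transforms.ToTensor())
img, target = ds[0]  # loads the image at the first path, returns its label
```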
/dataset/randaugment.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from
2 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
3 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py
4 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py
5 | import logging
6 | import random
7 |
8 | import numpy as np
9 | import PIL
10 | import PIL.ImageOps
11 | import PIL.ImageEnhance
12 | import PIL.ImageDraw
13 | from PIL import Image
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | PARAMETER_MAX = 10
18 |
19 |
20 | def AutoContrast(img, **kwarg):
21 | return PIL.ImageOps.autocontrast(img)
22 |
23 |
24 | def Brightness(img, v, max_v, bias=0):
25 | v = _float_parameter(v, max_v) + bias
26 | return PIL.ImageEnhance.Brightness(img).enhance(v)
27 |
28 |
29 | def Color(img, v, max_v, bias=0):
30 | v = _float_parameter(v, max_v) + bias
31 | return PIL.ImageEnhance.Color(img).enhance(v)
32 |
33 |
34 | def Contrast(img, v, max_v, bias=0):
35 | v = _float_parameter(v, max_v) + bias
36 | return PIL.ImageEnhance.Contrast(img).enhance(v)
37 |
38 |
39 | def Cutout(img, v, max_v, bias=0):
40 | if v == 0:
41 | return img
42 | v = _float_parameter(v, max_v) + bias
43 | v = int(v * min(img.size))
44 | return CutoutAbs(img, v)
45 |
46 |
47 | def CutoutAbs(img, v, **kwarg):
48 | w, h = img.size
49 | x0 = np.random.uniform(0, w)
50 | y0 = np.random.uniform(0, h)
51 | x0 = int(max(0, x0 - v / 2.))
52 | y0 = int(max(0, y0 - v / 2.))
53 | x1 = int(min(w, x0 + v))
54 | y1 = int(min(h, y0 + v))
55 | xy = (x0, y0, x1, y1)
56 | # gray
57 | color = (127, 127, 127)
58 | img = img.copy()
59 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
60 | return img
61 |
62 |
63 | def Equalize(img, **kwarg):
64 | return PIL.ImageOps.equalize(img)
65 |
66 |
67 | def Identity(img, **kwarg):
68 | return img
69 |
70 |
71 | def Invert(img, **kwarg):
72 | return PIL.ImageOps.invert(img)
73 |
74 |
75 | def Posterize(img, v, max_v, bias=0):
76 | v = _int_parameter(v, max_v) + bias
77 | return PIL.ImageOps.posterize(img, v)
78 |
79 |
80 | def Rotate(img, v, max_v, bias=0):
81 | v = _int_parameter(v, max_v) + bias
82 | if random.random() < 0.5:
83 | v = -v
84 | return img.rotate(v)
85 |
86 |
87 | def Sharpness(img, v, max_v, bias=0):
88 | v = _float_parameter(v, max_v) + bias
89 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
90 |
91 |
92 | def ShearX(img, v, max_v, bias=0):
93 | v = _float_parameter(v, max_v) + bias
94 | if random.random() < 0.5:
95 | v = -v
96 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
97 |
98 |
99 | def ShearY(img, v, max_v, bias=0):
100 | v = _float_parameter(v, max_v) + bias
101 | if random.random() < 0.5:
102 | v = -v
103 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
104 |
105 |
106 | def Solarize(img, v, max_v, bias=0):
107 | v = _int_parameter(v, max_v) + bias
108 | return PIL.ImageOps.solarize(img, 256 - v)
109 |
110 |
111 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
112 | v = _int_parameter(v, max_v) + bias
113 | if random.random() < 0.5:
114 | v = -v
115 | img_np = np.array(img).astype(int)
116 | img_np = img_np + v
117 | img_np = np.clip(img_np, 0, 255)
118 | img_np = img_np.astype(np.uint8)
119 | img = Image.fromarray(img_np)
120 | return PIL.ImageOps.solarize(img, threshold)
121 |
122 |
123 | def TranslateX(img, v, max_v, bias=0):
124 | v = _float_parameter(v, max_v) + bias
125 | if random.random() < 0.5:
126 | v = -v
127 | v = int(v * img.size[0])
128 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
129 |
130 |
131 | def TranslateY(img, v, max_v, bias=0):
132 | v = _float_parameter(v, max_v) + bias
133 | if random.random() < 0.5:
134 | v = -v
135 | v = int(v * img.size[1])
136 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
137 |
138 |
139 | def _float_parameter(v, max_v):
140 | return float(v) * max_v / PARAMETER_MAX
141 |
142 |
143 | def _int_parameter(v, max_v):
144 | return int(v * max_v / PARAMETER_MAX)
145 |
146 |
147 | def fixmatch_augment_pool():
148 | # FixMatch paper
149 | augs = [(AutoContrast, None, None),
150 | (Brightness, 0.9, 0.05),
151 | (Color, 0.9, 0.05),
152 | (Contrast, 0.9, 0.05),
153 | (Equalize, None, None),
154 | (Identity, None, None),
155 | (Posterize, 4, 4),
156 | (Rotate, 30, 0),
157 | (Sharpness, 0.9, 0.05),
158 | (ShearX, 0.3, 0),
159 | (ShearY, 0.3, 0),
160 | (Solarize, 256, 0),
161 | (TranslateX, 0.3, 0),
162 | (TranslateY, 0.3, 0)]
163 | return augs
164 |
165 |
166 | def my_augment_pool():
167 | # Test
168 | augs = [(AutoContrast, None, None),
169 | (Brightness, 1.8, 0.1),
170 | (Color, 1.8, 0.1),
171 | (Contrast, 1.8, 0.1),
172 | (Cutout, 0.2, 0),
173 | (Equalize, None, None),
174 | (Invert, None, None),
175 | (Posterize, 4, 4),
176 | (Rotate, 30, 0),
177 | (Sharpness, 1.8, 0.1),
178 | (ShearX, 0.3, 0),
179 | (ShearY, 0.3, 0),
180 | (Solarize, 256, 0),
181 | (SolarizeAdd, 110, 0),
182 | (TranslateX, 0.45, 0),
183 | (TranslateY, 0.45, 0)]
184 | return augs
185 |
186 |
187 | class RandAugmentPC(object):
188 | def __init__(self, n, m):
189 | assert n >= 1
190 | assert 1 <= m <= 10
191 | self.n = n
192 | self.m = m
193 | self.augment_pool = my_augment_pool()
194 |
195 | def __call__(self, img):
196 | ops = random.choices(self.augment_pool, k=self.n)
197 | for op, max_v, bias in ops:
198 | prob = np.random.uniform(0.2, 0.8)
199 | if random.random() + prob >= 1:
200 | img = op(img, v=self.m, max_v=max_v, bias=bias)
201 | img = CutoutAbs(img, int(32*0.5))
202 | return img
203 |
204 |
205 | class RandAugmentMC(object):
206 | def __init__(self, n, m):
207 | assert n >= 1
208 | assert 1 <= m <= 10
209 | self.n = n
210 | self.m = m
211 | self.augment_pool = fixmatch_augment_pool()
212 |
213 | def __call__(self, img):
214 | ops = random.choices(self.augment_pool, k=self.n)
215 | for op, max_v, bias in ops:
216 | v = np.random.randint(1, self.m)
217 | if random.random() < 0.5:
218 | img = op(img, v=v, max_v=max_v, bias=bias)
219 | img = CutoutAbs(img, int(32*0.5))
220 | return img
221 |
--------------------------------------------------------------------------------
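`RandAugmentMC` is the strong augmentation used by `TransformFixMatch`: it samples `n` ops from the FixMatch pool, applies each with probability 0.5 at a random magnitude below `m`, then always applies a 16-pixel Cutout. A quick sanity check:

```
from PIL import Image
from dataset.randaugment import RandAugmentMC

aug = RandAugmentMC(n=2, m=10)
img = Image.new('RGB', (32, 32), color=(128, 128, 128))
out = aug(img)             # two random ops, then CutoutAbs(img, 16)
assert out.size == (32, 32)  # all ops preserve the image size
```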
/eval.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from utils import test, test_ood
3 |
4 | logger = logging.getLogger(__name__)
5 | best_acc = 0
6 | best_acc_val = 0
7 | def eval_model(args, labeled_trainloader, unlabeled_dataset, test_loader, val_loader,
8 | ood_loaders, model, ema_model):
9 | if args.amp:
10 | from apex import amp
11 | global best_acc
12 | global best_acc_val
13 |
14 | model.eval()
15 | if args.use_ema:
16 | test_model = ema_model.ema
17 | else:
18 | test_model = model
19 | epoch = 0
20 | if args.local_rank in [-1, 0]:
21 | val_acc = test(args, val_loader, test_model, epoch, val=True)
22 | test_loss, close_valid, test_overall, \
23 | test_unk, test_roc, test_roc_softm, test_id \
24 | = test(args, test_loader, test_model, epoch)
25 | for ood in ood_loaders.keys():
26 | roc_ood = test_ood(args, test_id, ood_loaders[ood], test_model)
27 | logger.info("ROC vs {ood}: {roc}".format(ood=ood, roc=roc_ood))
28 |
29 | overall_valid = test_overall
30 | unk_valid = test_unk
31 | roc_valid = test_roc
32 | roc_softm_valid = test_roc_softm
33 | logger.info('validation closed acc: {:.3f}'.format(val_acc))
34 | logger.info('test closed acc: {:.3f}'.format(close_valid))
35 | logger.info('test overall acc: {:.3f}'.format(overall_valid))
36 | logger.info('test unk acc: {:.3f}'.format(unk_valid))
37 | logger.info('test roc: {:.3f}'.format(roc_valid))
38 | logger.info('test roc soft: {:.3f}'.format(roc_softm_valid))
39 | if args.local_rank in [-1, 0]:
40 | args.writer.close()
41 |
--------------------------------------------------------------------------------
/files.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisionLearningGroup/OP_Match/ba1a59cf42ad8c2920cba428991a6cc717901d52/files.zip
--------------------------------------------------------------------------------
/images/consistency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisionLearningGroup/OP_Match/ba1a59cf42ad8c2920cba428991a6cc717901d52/images/consistency.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import torch
4 | from torch.utils.tensorboard import SummaryWriter
5 | from utils import set_model_config, \
6 | set_dataset, set_models, set_parser, \
7 | set_seed
8 | from eval import eval_model
9 | from trainer import train
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | def main():
15 | args = set_parser()
16 | global best_acc
17 | global best_acc_val
18 |
19 | if args.local_rank == -1:
20 | device = torch.device('cuda', args.gpu_id)
21 | args.world_size = 1
22 | args.n_gpu = torch.cuda.device_count()
23 | else:
24 | torch.cuda.set_device(args.local_rank)
25 | device = torch.device('cuda', args.local_rank)
26 | torch.distributed.init_process_group(backend='nccl')
27 | args.world_size = torch.distributed.get_world_size()
28 | args.n_gpu = 1
29 | args.device = device
30 | logging.basicConfig(
31 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
32 | datefmt="%m/%d/%Y %H:%M:%S",
33 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
34 | logger.warning(
35 | f"Process rank: {args.local_rank}, "
36 | f"device: {args.device}, "
37 | f"n_gpu: {args.n_gpu}, "
38 | f"distributed training: {bool(args.local_rank != -1)}, "
39 | f"16-bits training: {args.amp}",)
40 | logger.info(dict(args._get_kwargs()))
41 | if args.seed is not None:
42 | set_seed(args)
43 | if args.local_rank in [-1, 0]:
44 | os.makedirs(args.out, exist_ok=True)
45 | args.writer = SummaryWriter(args.out)
46 | set_model_config(args)
47 |
48 | if args.local_rank not in [-1, 0]:
49 | torch.distributed.barrier()
50 |
51 | labeled_trainloader, unlabeled_dataset, test_loader, val_loader, ood_loaders \
52 | = set_dataset(args)
53 |
54 | model, optimizer, scheduler = set_models(args)
55 | logger.info("Total params: {:.2f}M".format(
56 | sum(p.numel() for p in model.parameters()) / 1e6))
57 |
58 | from models.ema import ModelEMA
59 | # Keep ema_model defined even when EMA is disabled, since it is passed to train/eval below.
60 | ema_model = ModelEMA(args, model, args.ema_decay) if args.use_ema else None
61 | args.start_epoch = 0
62 | if args.resume:
63 | logger.info("==> Resuming from checkpoint..")
64 | assert os.path.isfile(
65 | args.resume), "Error: no checkpoint directory found!"
66 | args.out = os.path.dirname(args.resume)
67 | checkpoint = torch.load(args.resume)
68 | best_acc = checkpoint['best_acc']
69 | args.start_epoch = checkpoint['epoch']
70 | model.load_state_dict(checkpoint['state_dict'])
71 | if args.use_ema:
72 | ema_model.ema.load_state_dict(checkpoint['ema_state_dict'])
73 | optimizer.load_state_dict(checkpoint['optimizer'])
74 | scheduler.load_state_dict(checkpoint['scheduler'])
75 |
76 | if args.amp:
77 | from apex import amp
78 | model, optimizer = amp.initialize(
79 | model, optimizer, opt_level=args.opt_level)
80 |
81 | if args.local_rank != -1:
82 | model = torch.nn.parallel.DistributedDataParallel(
83 | model, device_ids=[args.local_rank],
84 | output_device=args.local_rank, find_unused_parameters=True)
85 |
86 |
87 | model.zero_grad()
88 | if not args.eval_only:
89 | logger.info("***** Running training *****")
90 | logger.info(f" Task = {args.dataset}@{args.num_labeled}")
91 | logger.info(f" Num Epochs = {args.epochs}")
92 | logger.info(f" Batch size per GPU = {args.batch_size}")
93 | logger.info(f" Total train batch size = {args.batch_size*args.world_size}")
94 | logger.info(f" Total optimization steps = {args.total_steps}")
95 | train(args, labeled_trainloader, unlabeled_dataset, test_loader, val_loader,
96 | ood_loaders, model, optimizer, ema_model, scheduler)
97 | else:
98 | logger.info("***** Running Evaluation *****")
99 | logger.info(f" Task = {args.dataset}@{args.num_labeled}")
100 | eval_model(args, labeled_trainloader, unlabeled_dataset, test_loader, val_loader,
101 | ood_loaders, model, ema_model)
102 |
103 |
104 | if __name__ == '__main__':
105 | main()
106 |
--------------------------------------------------------------------------------
/models/ema.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | import torch
4 |
5 |
6 | class ModelEMA(object):
7 | def __init__(self, args, model, decay):
8 | self.ema = deepcopy(model)
9 | self.ema.to(args.device)
10 | self.ema.eval()
11 | self.decay = decay
12 | self.ema_has_module = hasattr(self.ema, 'module')
13 | self.param_keys = [k for k, _ in self.ema.named_parameters()]
14 | self.buffer_keys = [k for k, _ in self.ema.named_buffers()]
15 | for p in self.ema.parameters():
16 | p.requires_grad_(False)
17 |
18 | def update(self, model):
19 | needs_module = hasattr(model, 'module') and not self.ema_has_module
20 | with torch.no_grad():
21 | msd = model.state_dict()
22 | esd = self.ema.state_dict()
23 | for k in self.param_keys:
24 | if needs_module:
25 | j = 'module.' + k
26 | else:
27 | j = k
28 | model_v = msd[j].detach()
29 | ema_v = esd[k]
30 | esd[k].copy_(ema_v * self.decay + (1. - self.decay) * model_v)
31 |
32 | for k in self.buffer_keys:
33 | if needs_module:
34 | j = 'module.' + k
35 | else:
36 | j = k
37 | esd[k].copy_(msd[j])
38 |
--------------------------------------------------------------------------------
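The update in `ModelEMA.update` is a standard exponential moving average: for each parameter, `ema = decay * ema + (1 - decay) * param`, while buffers (e.g. BatchNorm running stats) are copied over directly. A minimal numeric sketch of the same arithmetic:

```
import torch

decay = 0.999
ema_v = torch.zeros(3)
model_v = torch.ones(3)
# One EMA step, mirroring esd[k].copy_(ema_v * self.decay + (1. - self.decay) * model_v)
ema_v = ema_v * decay + (1 - decay) * model_v
print(ema_v)  # tensor([0.0010, 0.0010, 0.0010])
```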
/models/resnet_imagenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 |
6 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
7 | """3x3 convolution with padding"""
8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
9 | padding=dilation, groups=groups, bias=False, dilation=dilation)
10 |
11 |
12 | def conv1x1(in_planes, out_planes, stride=1):
13 | """1x1 convolution"""
14 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
15 |
16 |
17 | class BasicBlock(nn.Module):
18 | expansion = 1
19 |
20 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
21 | base_width=64, dilation=1, norm_layer=None):
22 | super(BasicBlock, self).__init__()
23 | if norm_layer is None:
24 | norm_layer = nn.BatchNorm2d
25 | if groups != 1 or base_width != 64:
26 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
27 | if dilation > 1:
28 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
29 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
30 | self.conv1 = conv3x3(inplanes, planes, stride)
31 | self.bn1 = norm_layer(planes)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.conv2 = conv3x3(planes, planes)
34 | self.bn2 = norm_layer(planes)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | identity = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 |
48 | if self.downsample is not None:
49 | identity = self.downsample(x)
50 |
51 | out += identity
52 | out = self.relu(out)
53 |
54 | return out
55 |
56 |
57 | class Bottleneck(nn.Module):
58 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
59 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
60 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
61 | # This variant is also known as ResNet V1.5 and improves accuracy according to
62 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
63 |
64 | expansion = 4
65 |
66 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
67 | base_width=64, dilation=1, norm_layer=None):
68 | super(Bottleneck, self).__init__()
69 | if norm_layer is None:
70 | norm_layer = nn.BatchNorm2d
71 | width = int(planes * (base_width / 64.)) * groups
72 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
73 | self.conv1 = conv1x1(inplanes, width)
74 | self.bn1 = norm_layer(width)
75 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
76 | self.bn2 = norm_layer(width)
77 | self.conv3 = conv1x1(width, planes * self.expansion)
78 | self.bn3 = norm_layer(planes * self.expansion)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.downsample = downsample
81 | self.stride = stride
82 |
83 | def forward(self, x):
84 | identity = x
85 |
86 | out = self.conv1(x)
87 | out = self.bn1(out)
88 | out = self.relu(out)
89 |
90 | out = self.conv2(out)
91 | out = self.bn2(out)
92 | out = self.relu(out)
93 |
94 | out = self.conv3(out)
95 | out = self.bn3(out)
96 |
97 | if self.downsample is not None:
98 | identity = self.downsample(x)
99 |
100 | out += identity
101 | out = self.relu(out)
102 |
103 | return out
104 |
105 |
106 | class ResNet(nn.Module):
107 | def __init__(self, block, layers, num_classes=10,
108 | zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None,
109 | norm_layer=None):
110 | super(ResNet, self).__init__()
111 | last_dim = 512 * block.expansion
112 |
113 | if norm_layer is None:
114 | norm_layer = nn.BatchNorm2d
115 | self._norm_layer = norm_layer
116 |
117 | self.inplanes = 64
118 | self.dilation = 1
119 | if replace_stride_with_dilation is None:
120 | # each element in the tuple indicates if we should replace
121 | # the 2x2 stride with a dilated convolution instead
122 | replace_stride_with_dilation = [False, False, False]
123 | if len(replace_stride_with_dilation) != 3:
124 | raise ValueError("replace_stride_with_dilation should be None "
125 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
126 | self.groups = groups
127 | self.base_width = width_per_group
128 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
129 | bias=False)
130 | self.bn1 = norm_layer(self.inplanes)
131 | self.relu = nn.ReLU(inplace=True)
132 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
133 | self.layer1 = self._make_layer(block, 64, layers[0])
134 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
135 | dilate=replace_stride_with_dilation[0])
136 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
137 | dilate=replace_stride_with_dilation[1])
138 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
139 | dilate=replace_stride_with_dilation[2])
140 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
141 | #self.normalize = NormalizeLayer()
142 | self.last_dim = 512 * block.expansion
143 | self.fc = nn.Linear(last_dim, num_classes)
144 | self.fc_open = nn.Linear(last_dim, num_classes * 2, bias=False)
145 | self.simclr_layer = nn.Sequential(
146 | nn.Linear(last_dim, 128),
147 | nn.ReLU(),
148 | nn.Linear(128, 128),
149 | )
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
154 | nn.init.constant_(m.weight, 1)
155 | nn.init.constant_(m.bias, 0)
156 |
157 | # Zero-initialize the last BN in each residual branch,
158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
159 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
160 | if zero_init_residual:
161 | for m in self.modules():
162 | if isinstance(m, Bottleneck):
163 | nn.init.constant_(m.bn3.weight, 0)
164 | elif isinstance(m, BasicBlock):
165 | nn.init.constant_(m.bn2.weight, 0)
166 |
167 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
168 | norm_layer = self._norm_layer
169 | downsample = None
170 | previous_dilation = self.dilation
171 | if dilate:
172 | self.dilation *= stride
173 | stride = 1
174 | if stride != 1 or self.inplanes != planes * block.expansion:
175 | downsample = nn.Sequential(
176 | conv1x1(self.inplanes, planes * block.expansion, stride),
177 | norm_layer(planes * block.expansion),
178 | )
179 |
180 | layers = []
181 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
182 | self.base_width, previous_dilation, norm_layer))
183 | self.inplanes = planes * block.expansion
184 | for _ in range(1, blocks):
185 | layers.append(block(self.inplanes, planes, groups=self.groups,
186 | base_width=self.base_width, dilation=self.dilation,
187 | norm_layer=norm_layer))
188 |
189 | return nn.Sequential(*layers)
190 |
191 | def forward(self, x, feature=False, feat_only=False):
192 | # See note [TorchScript super()]
193 |
194 | #x = self.normalize(x)
195 | x = self.conv1(x)
196 | x = self.bn1(x)
197 | x = self.relu(x)
198 | x = self.maxpool(x)
199 | x = self.layer1(x)
200 | x = self.layer2(x)
201 | x = self.layer3(x)
202 | x = self.layer4(x)
203 | x = self.avgpool(x)
204 | x = torch.flatten(x, 1)
205 | if feat_only:
206 | return self.simclr_layer(x)
207 | if feature:
208 | return self.fc(x), self.fc_open(x), self.simclr_layer(x)
209 | else:
210 | return self.fc(x), self.fc_open(x)
211 |
212 |
213 | def _resnet(arch, block, layers, **kwargs):
214 | model = ResNet(block, layers, **kwargs)
215 | return model
216 |
217 |
218 | def resnet18(**kwargs):
219 | r"""ResNet-18 model from
220 | `"Deep Residual Learning for Image Recognition" `_
221 | """
222 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], **kwargs)
223 |
224 |
225 | def resnet50(**kwargs):
226 | r"""ResNet-50 model from
227 | `"Deep Residual Learning for Image Recognition" `_
228 | """
229 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], **kwargs)
230 |
231 |
232 |
233 | def _tresnet(arch, block, layers, **kwargs):  # NOTE: TResNet is not defined in this file; this helper and tresnet18 below appear to be dead code.
234 | model = TResNet(block, layers, **kwargs)
235 | return model
236 |
237 |
238 | def tresnet18(**kwargs):
239 | r"""ResNet-18 model from
240 | `"Deep Residual Learning for Image Recognition" `_
241 | """
242 | return _tresnet('resnet18', BasicBlock, [2, 2, 2, 2], **kwargs)
243 |
244 |
--------------------------------------------------------------------------------
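Unlike the stock torchvision ResNet, this variant exposes three heads: a closed-set classifier (`fc`), a one-vs-all open-set head (`fc_open`, two logits per class), and a SimCLR projection head. A shape sketch, assuming the script runs from the repo root (the input size is illustrative; any size the stem supports works):

```
import torch
from models.resnet_imagenet import resnet18

model = resnet18(num_classes=20)
x = torch.randn(2, 3, 224, 224)
logits, logits_open = model(x)           # closed-set and open-set heads
assert logits.shape == (2, 20) and logits_open.shape == (2, 40)
feat = model(x, feat_only=True)          # 128-d SimCLR projection
assert feat.shape == (2, 128)
```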
/models/resnext.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def mish(x):
11 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function (https://arxiv.org/abs/1908.08681)"""
12 | return x * torch.tanh(F.softplus(x))
13 |
14 |
15 | class PSBatchNorm2d(nn.BatchNorm2d):
16 | """How Does BN Increase Collapsed Neural Network Filters? (https://arxiv.org/abs/2001.11216)"""
17 |
18 | def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
19 | super().__init__(num_features, eps, momentum, affine, track_running_stats)
20 | self.alpha = alpha
21 |
22 | def forward(self, x):
23 | return super().forward(x) + self.alpha
24 |
25 |
26 | class ResNeXtBottleneck(nn.Module):
27 | """
28 | RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
29 | """
30 |
31 | def __init__(self, in_channels, out_channels, stride,
32 | cardinality, base_width, widen_factor):
33 | """ Constructor
34 | Args:
35 | in_channels: input channel dimensionality
36 | out_channels: output channel dimensionality
37 | stride: conv stride. Replaces pooling layer.
38 | cardinality: num of convolution groups.
39 | base_width: base number of channels in each group.
40 | widen_factor: factor to reduce the input dimensionality before convolution.
41 | """
42 | super().__init__()
43 | width_ratio = out_channels / (widen_factor * 64.)
44 | D = cardinality * int(base_width * width_ratio)
45 | self.conv_reduce = nn.Conv2d(
46 | in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
47 | self.bn_reduce = nn.BatchNorm2d(D, momentum=0.001)
48 | self.conv_conv = nn.Conv2d(D, D,
49 | kernel_size=3, stride=stride, padding=1,
50 | groups=cardinality, bias=False)
51 | self.bn = nn.BatchNorm2d(D, momentum=0.001)
52 | self.act = mish
53 | self.conv_expand = nn.Conv2d(
54 | D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
55 | self.bn_expand = nn.BatchNorm2d(out_channels, momentum=0.001)
56 |
57 | self.shortcut = nn.Sequential()
58 | if in_channels != out_channels:
59 | self.shortcut.add_module('shortcut_conv',
60 | nn.Conv2d(in_channels, out_channels,
61 | kernel_size=1,
62 | stride=stride,
63 | padding=0,
64 | bias=False))
65 | self.shortcut.add_module(
66 | 'shortcut_bn', nn.BatchNorm2d(out_channels, momentum=0.001))
67 |
68 | def forward(self, x):
69 | bottleneck = self.conv_reduce.forward(x)
70 | bottleneck = self.act(self.bn_reduce.forward(bottleneck))
71 | bottleneck = self.conv_conv.forward(bottleneck)
72 | bottleneck = self.act(self.bn.forward(bottleneck))
73 | bottleneck = self.conv_expand.forward(bottleneck)
74 | bottleneck = self.bn_expand.forward(bottleneck)
75 | residual = self.shortcut.forward(x)
76 | return self.act(residual + bottleneck)
77 |
78 |
79 | class CifarResNeXt(nn.Module):
80 | """
81 |     ResNeXt optimized for the CIFAR dataset, as specified in
82 | https://arxiv.org/pdf/1611.05431.pdf
83 | """
84 |
85 | def __init__(self, cardinality, depth, num_classes,
86 | base_width, widen_factor=4):
87 | """ Constructor
88 | Args:
89 | cardinality: number of convolution groups.
90 | depth: number of layers.
91 |             num_classes: number of classes
92 | base_width: base number of channels in each group.
93 | widen_factor: factor to adjust the channel dimensionality
94 | """
95 | super().__init__()
96 | self.cardinality = cardinality
97 | self.depth = depth
98 | self.block_depth = (self.depth - 2) // 9
99 | self.base_width = base_width
100 | self.widen_factor = widen_factor
101 | self.nlabels = num_classes
102 | self.output_size = 64
103 | self.stages = [64, 64 * self.widen_factor, 128 *
104 | self.widen_factor, 256 * self.widen_factor]
105 |
106 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
107 | self.bn_1 = nn.BatchNorm2d(64, momentum=0.001)
108 | self.act = mish
109 | self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1)
110 | self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2)
111 | self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2)
112 | self.classifier = nn.Linear(self.stages[3], num_classes)
113 |
114 | for m in self.modules():
115 | if isinstance(m, nn.Conv2d):
116 | nn.init.kaiming_normal_(m.weight,
117 | mode='fan_out',
118 | nonlinearity='leaky_relu')
119 | elif isinstance(m, nn.BatchNorm2d):
120 | nn.init.constant_(m.weight, 1.0)
121 | nn.init.constant_(m.bias, 0.0)
122 | elif isinstance(m, nn.Linear):
123 | nn.init.xavier_normal_(m.weight)
124 | nn.init.constant_(m.bias, 0.0)
125 |
126 | def block(self, name, in_channels, out_channels, pool_stride=2):
127 | """ Stack n bottleneck modules where n is inferred from the depth of the network.
128 | Args:
129 | name: string name of the current block.
130 | in_channels: number of input channels
131 | out_channels: number of output channels
132 | pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
133 | Returns: a Module consisting of n sequential bottlenecks.
134 | """
135 | block = nn.Sequential()
136 | for bottleneck in range(self.block_depth):
137 | name_ = '%s_bottleneck_%d' % (name, bottleneck)
138 | if bottleneck == 0:
139 | block.add_module(name_, ResNeXtBottleneck(in_channels,
140 | out_channels,
141 | pool_stride,
142 | self.cardinality,
143 | self.base_width,
144 | self.widen_factor))
145 | else:
146 | block.add_module(name_,
147 | ResNeXtBottleneck(out_channels,
148 | out_channels,
149 | 1,
150 | self.cardinality,
151 | self.base_width,
152 | self.widen_factor))
153 | return block
154 |
155 | def forward(self, x):
156 | x = self.conv_1_3x3.forward(x)
157 | x = self.act(self.bn_1.forward(x))
158 | x = self.stage_1.forward(x)
159 | x = self.stage_2.forward(x)
160 | x = self.stage_3.forward(x)
161 | x = F.adaptive_avg_pool2d(x, 1)
162 | x = x.view(-1, self.stages[3])
163 | return self.classifier(x)
164 |
165 |
166 | def build_resnext(cardinality, depth, width, num_classes):
167 | logger.info(f"Model: ResNeXt {depth+1}x{width}")
168 | return CifarResNeXt(cardinality=cardinality,
169 | depth=depth,
170 | base_width=width,
171 | num_classes=num_classes)
172 |
--------------------------------------------------------------------------------
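A minimal usage sketch (mine, not from the repo), mirroring the CIFAR-10 ResNeXt configuration that set_model_config() in utils/default.py selects (cardinality=4, depth=28, width=4) together with the parser's default of 6 known classes:

```python
import torch
from models.resnext import build_resnext

# Configuration chosen for --dataset cifar10 --arch resnext.
model = build_resnext(cardinality=4, depth=28, width=4, num_classes=6)

x = torch.randn(2, 3, 32, 32)
logits = model(x)    # plain closed-set classifier: shape (2, 6)
print(logits.shape)
```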
/models/wideresnet.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | def mish(x):
11 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function (https://arxiv.org/abs/1908.08681)"""
12 | return x * torch.tanh(F.softplus(x))
13 |
14 |
15 | class PSBatchNorm2d(nn.BatchNorm2d):
16 | """How Does BN Increase Collapsed Neural Network Filters? (https://arxiv.org/abs/2001.11216)"""
17 |
18 | def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True):
19 | super().__init__(num_features, eps, momentum, affine, track_running_stats)
20 | self.alpha = alpha
21 |
22 | def forward(self, x):
23 | return super().forward(x) + self.alpha
24 |
25 |
26 | class BasicBlock(nn.Module):
27 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0, activate_before_residual=False):
28 | super(BasicBlock, self).__init__()
29 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001)
30 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
31 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
32 | padding=1, bias=False)
33 | self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001)
34 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
35 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
36 | padding=1, bias=False)
37 | self.drop_rate = drop_rate
38 | self.equalInOut = (in_planes == out_planes)
39 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
40 | padding=0, bias=False) or None
41 | self.activate_before_residual = activate_before_residual
42 |
43 | def forward(self, x):
44 |         if not self.equalInOut and self.activate_before_residual:
45 | x = self.relu1(self.bn1(x))
46 | else:
47 | out = self.relu1(self.bn1(x))
48 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
49 | if self.drop_rate > 0:
50 | out = F.dropout(out, p=self.drop_rate, training=self.training)
51 | out = self.conv2(out)
52 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
53 |
54 |
55 | class NetworkBlock(nn.Module):
56 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0, activate_before_residual=False):
57 | super(NetworkBlock, self).__init__()
58 | self.layer = self._make_layer(
59 | block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual)
60 |
61 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual):
62 | layers = []
63 | for i in range(int(nb_layers)):
64 | layers.append(block(i == 0 and in_planes or out_planes, out_planes,
65 | i == 0 and stride or 1, drop_rate, activate_before_residual))
66 | return nn.Sequential(*layers)
67 |
68 | def forward(self, x):
69 | return self.layer(x)
70 |
71 |
72 | class WideResNet(nn.Module):
73 | def __init__(self, num_classes, depth=28, widen_factor=2, drop_rate=0.0):
74 | super(WideResNet, self).__init__()
75 | channels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
76 | assert((depth - 4) % 6 == 0)
77 |         n = (depth - 4) // 6
78 | block = BasicBlock
79 | # 1st conv before any network block
80 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1,
81 | padding=1, bias=False)
82 | # 1st block
83 | self.block1 = NetworkBlock(
84 | n, channels[0], channels[1], block, 1, drop_rate, activate_before_residual=True)
85 | # 2nd block
86 | self.block2 = NetworkBlock(
87 | n, channels[1], channels[2], block, 2, drop_rate)
88 | # 3rd block
89 | self.block3 = NetworkBlock(
90 | n, channels[2], channels[3], block, 2, drop_rate)
91 | # global average pooling and classifier
92 | self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001)
93 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
94 | self.fc = nn.Linear(channels[3], num_classes)
95 | self.channels = channels[3]
96 |
97 | for m in self.modules():
98 | if isinstance(m, nn.Conv2d):
99 | nn.init.kaiming_normal_(m.weight,
100 | mode='fan_out',
101 | nonlinearity='leaky_relu')
102 | elif isinstance(m, nn.BatchNorm2d):
103 | nn.init.constant_(m.weight, 1.0)
104 | nn.init.constant_(m.bias, 0.0)
105 | elif isinstance(m, nn.Linear):
106 | nn.init.xavier_normal_(m.weight)
107 | nn.init.constant_(m.bias, 0.0)
108 |
109 | def forward(self, x):
110 | out = self.conv1(x)
111 | out = self.block1(out)
112 | out = self.block2(out)
113 | out = self.block3(out)
114 | out = self.relu(self.bn1(out))
115 | out = F.adaptive_avg_pool2d(out, 1)
116 | out = out.view(-1, self.channels)
117 | return self.fc(out)
118 |
119 |
120 | class WideResNet_Open(nn.Module):
121 | def __init__(self, num_classes, depth=28, widen_factor=2, drop_rate=0.0):
122 | super(WideResNet_Open, self).__init__()
123 | channels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
124 | assert((depth - 4) % 6 == 0)
125 |         n = (depth - 4) // 6
126 | block = BasicBlock
127 | # 1st conv before any network block
128 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1,
129 | padding=1, bias=False)
130 | # 1st block
131 | self.block1 = NetworkBlock(
132 | n, channels[0], channels[1], block, 1, drop_rate, activate_before_residual=True)
133 | # 2nd block
134 | self.block2 = NetworkBlock(
135 | n, channels[1], channels[2], block, 2, drop_rate)
136 | # 3rd block
137 | self.block3 = NetworkBlock(
138 | n, channels[2], channels[3], block, 2, drop_rate)
139 | # global average pooling and classifier
140 | self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001)
141 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
142 | self.simclr_layer = nn.Sequential(
143 | nn.Linear(channels[3], 128),
144 | nn.ReLU(),
145 | nn.Linear(128, 128),
146 | )
147 | self.fc = nn.Linear(channels[3], num_classes)
148 | out_open = 2 * num_classes
149 | self.fc_open = nn.Linear(channels[3], out_open, bias=False)
150 | self.channels = channels[3]
151 |
152 | for m in self.modules():
153 | if isinstance(m, nn.Conv2d):
154 | nn.init.kaiming_normal_(m.weight,
155 | mode='fan_out',
156 | nonlinearity='leaky_relu')
157 | elif isinstance(m, nn.BatchNorm2d):
158 | nn.init.constant_(m.weight, 1.0)
159 | nn.init.constant_(m.bias, 0.0)
160 | elif isinstance(m, nn.Linear):
161 | nn.init.xavier_normal_(m.weight)
162 | if m.bias is not None:
163 | nn.init.constant_(m.bias, 0.0)
164 |
165 | def forward(self, x, feature=False, feat_only=False):
166 | #self.weight_norm()
167 | out = self.conv1(x)
168 | out = self.block1(out)
169 | out = self.block2(out)
170 | out = self.block3(out)
171 | out = self.relu(self.bn1(out))
172 | out = F.adaptive_avg_pool2d(out, 1)
173 | out = out.view(-1, self.channels)
174 |
175 |
176 | if feat_only:
177 | return self.simclr_layer(out)
178 | out_open = self.fc_open(out)
179 | if feature:
180 | return self.fc(out), out_open, out
181 | else:
182 | return self.fc(out), out_open
183 |
184 | def weight_norm(self):
185 | w = self.fc_open.weight.data
186 | norm = w.norm(p=2, dim=1, keepdim=True)
187 | self.fc_open.weight.data = w.div(norm.expand_as(w))
188 |
189 |
190 | #
191 | #
192 | class ResBasicBlock(nn.Module):
193 | expansion = 1
194 |
195 | def __init__(self, in_planes, planes, stride=1):
196 | super(ResBasicBlock, self).__init__()
197 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
198 | self.bn1 = nn.BatchNorm2d(planes)
199 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
200 | self.bn2 = nn.BatchNorm2d(planes)
201 |
202 | self.shortcut = nn.Sequential()
203 | if stride != 1 or in_planes != self.expansion*planes:
204 | self.shortcut = nn.Sequential(
205 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
206 | nn.BatchNorm2d(self.expansion*planes)
207 | )
208 |
209 | def forward(self, x):
210 | out = F.relu(self.bn1(self.conv1(x)))
211 | out = self.bn2(self.conv2(out))
212 | out += self.shortcut(x)
213 | out = F.relu(out)
214 | return out
215 |
216 | #
217 | class ResNet_Open(nn.Module):
218 | def __init__(self, block, num_blocks, low_dim=128, num_classes=10):
219 | super(ResNet_Open, self).__init__()
220 | self.in_planes = 64
221 |
222 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
223 | self.bn1 = nn.BatchNorm2d(64)
224 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
225 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
226 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
227 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
228 | self.linear = nn.Linear(512*block.expansion, low_dim)
229 | self.simclr_layer = nn.Sequential(
230 | nn.Linear(512*block.expansion, 128),
231 | nn.ReLU(),
232 | nn.Linear(128, 128),
233 | )
234 | self.fc1 = nn.Linear(512*block.expansion, num_classes)
235 | self.fc_open = nn.Linear(512*block.expansion, num_classes*2, bias=False)
236 |
237 |
238 | #self.l2norm = Normalize(2)
239 |
240 | def _make_layer(self, block, planes, num_blocks, stride):
241 | strides = [stride] + [1]*(num_blocks-1)
242 | layers = []
243 | for stride in strides:
244 | layers.append(block(self.in_planes, planes, stride))
245 | self.in_planes = planes * block.expansion
246 | return nn.Sequential(*layers)
247 |
248 | def forward(self, x, feature=False):
249 | out = F.relu(self.bn1(self.conv1(x)))
250 | out = self.layer1(out)
251 | out = self.layer2(out)
252 | out = self.layer3(out)
253 | out = self.layer4(out)
254 | out = F.avg_pool2d(out, 4)
255 | out = out.view(out.size(0), -1)
256 | out_open = self.fc_open(out)
257 | if feature:
258 | return self.fc1(out), out_open, self.simclr_layer(out)
259 | else:
260 | return self.fc1(out), out_open
261 |
262 |
263 |
264 | def ResNet18(low_dim=128, num_classes=10):
265 | return ResNet_Open(ResBasicBlock, [2,2,2,2], low_dim, num_classes)
266 |
267 |
268 | def build_wideresnet(depth, widen_factor, dropout, num_classes, open=False):
269 | logger.info(f"Model: WideResNet {depth}x{widen_factor}")
270 | build_func = WideResNet_Open if open else WideResNet
271 | return build_func(depth=depth,
272 | widen_factor=widen_factor,
273 | drop_rate=dropout,
274 | num_classes=num_classes)
275 |
--------------------------------------------------------------------------------
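As a quick shape check (a standalone sketch, not repo code) of the open-set variant that build_wideresnet returns when `open=True`, using the WRN-28-2 / 6-class configuration that utils/default.py picks for CIFAR-10:

```python
import torch
from models.wideresnet import build_wideresnet

model = build_wideresnet(depth=28, widen_factor=2, dropout=0,
                         num_classes=6, open=True)

x = torch.randn(2, 3, 32, 32)
logits, logits_open = model(x)
print(logits.shape)       # torch.Size([2, 6])   -- closed-set head
print(logits_open.shape)  # torch.Size([2, 12])  -- viewed as (2, 2, 6) by ova_loss
```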
/run_cifar10.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=$3 python main.py --dataset cifar10 --num-labeled $1 --out $2 --arch wideresnet --lambda_oem 0.1 --lambda_socr 0.5 \
2 | --batch-size 64 --lr 0.03 --expand-labels --seed 0 --opt_level O2 --amp --mu 2
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
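For reference, run_cifar10.sh takes three positional arguments: `$1` is the number of labeled samples per class, `$2` the output directory, and `$3` the GPU id exported through `CUDA_VISIBLE_DEVICES`, e.g. `sh run_cifar10.sh 50 ./out 0`.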
/run_cifar100.sh:
--------------------------------------------------------------------------------
1 | python main.py --dataset cifar100 --num-labeled $1 --out $2 --num-super $3 --arch wideresnet --lambda_oem 0.1 --lambda_socr 1.0 \
2 | --batch-size 64 --lr 0.03 --expand-labels --seed 0 --opt_level O2 --amp --mu 2
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
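Likewise, run_cifar100.sh reads `$1` as the number of labeled samples per class, `$2` as the output directory, and `$3` as the number of known super-classes passed to `--num-super` (10 or 15), e.g. `sh run_cifar100.sh 50 ./out 10`.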
/run_eval_cifar10.sh:
--------------------------------------------------------------------------------
1 | python main.py --dataset cifar10 --resume $1 --arch wideresnet --eval_only 1
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/run_imagenet.sh:
--------------------------------------------------------------------------------
1 | python main.py --dataset imagenet --out $1 --arch resnet_imagenet --lambda_oem 0.1 --lambda_socr 0.5 \
2 | --batch-size 64 --lr 0.03 --expand-labels --seed 0 --opt_level O2 --amp --mu 2 --epochs 100
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | import copy
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
8 | from torch.utils.data.distributed import DistributedSampler
9 | from dataset import TransformOpenMatch, cifar10_mean, cifar10_std, \
10 | cifar100_std, cifar100_mean, normal_mean, \
11 | normal_std, TransformFixMatch_Imagenet_Weak
12 | from tqdm import tqdm
13 | from utils import AverageMeter, ova_loss,\
14 | save_checkpoint, ova_ent, \
15 | test, test_ood, exclude_dataset
16 |
17 | logger = logging.getLogger(__name__)
18 | best_acc = 0
19 | best_acc_val = 0
20 |
21 | def train(args, labeled_trainloader, unlabeled_dataset, test_loader, val_loader,
22 | ood_loaders, model, optimizer, ema_model, scheduler):
23 | if args.amp:
24 | from apex import amp
25 |
26 | global best_acc
27 | global best_acc_val
28 |
29 | test_accs = []
30 | batch_time = AverageMeter()
31 | data_time = AverageMeter()
32 | losses = AverageMeter()
33 | losses_x = AverageMeter()
34 | losses_o = AverageMeter()
35 | losses_oem = AverageMeter()
36 | losses_socr = AverageMeter()
37 | losses_fix = AverageMeter()
38 | mask_probs = AverageMeter()
39 | end = time.time()
40 |
41 |
42 | if args.world_size > 1:
43 | labeled_epoch = 0
44 | unlabeled_epoch = 0
45 | labeled_iter = iter(labeled_trainloader)
46 | default_out = "Epoch: {epoch}/{epochs:4}. " \
47 | "LR: {lr:.6f}. " \
48 | "Lab: {loss_x:.4f}. " \
49 | "Open: {loss_o:.4f}"
50 | output_args = vars(args)
51 | default_out += " OEM {loss_oem:.4f}"
52 | default_out += " SOCR {loss_socr:.4f}"
53 | default_out += " Fix {loss_fix:.4f}"
54 |
55 | model.train()
56 | unlabeled_dataset_all = copy.deepcopy(unlabeled_dataset)
57 | if args.dataset == 'cifar10':
58 | mean = cifar10_mean
59 | std = cifar10_std
60 | func_trans = TransformOpenMatch
61 | elif args.dataset == 'cifar100':
62 | mean = cifar100_mean
63 | std = cifar100_std
64 | func_trans = TransformOpenMatch
65 | elif 'imagenet' in args.dataset:
66 | mean = normal_mean
67 | std = normal_std
68 | func_trans = TransformFixMatch_Imagenet_Weak
69 |
70 |
71 | unlabeled_dataset_all.transform = func_trans(mean=mean, std=std)
72 | labeled_dataset = copy.deepcopy(labeled_trainloader.dataset)
73 | labeled_dataset.transform = func_trans(mean=mean, std=std)
74 | train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
75 | labeled_trainloader = DataLoader(
76 | labeled_dataset,
77 | sampler=train_sampler(labeled_dataset),
78 | batch_size=args.batch_size,
79 | num_workers=args.num_workers,
80 | drop_last=True)
81 |
82 |
83 |
84 | for epoch in range(args.start_epoch, args.epochs):
85 | output_args["epoch"] = epoch
86 | if not args.no_progress:
87 | p_bar = tqdm(range(args.eval_step),
88 | disable=args.local_rank not in [-1, 0])
89 |
90 | if epoch >= args.start_fix:
91 | ## pick pseudo-inliers
92 | exclude_dataset(args, unlabeled_dataset, ema_model.ema)
93 |
94 |
95 | unlabeled_trainloader = DataLoader(unlabeled_dataset,
96 | sampler = train_sampler(unlabeled_dataset),
97 | batch_size = args.batch_size * args.mu,
98 | num_workers = args.num_workers,
99 | drop_last = True)
100 | unlabeled_trainloader_all = DataLoader(unlabeled_dataset_all,
101 | sampler=train_sampler(unlabeled_dataset_all),
102 | batch_size=args.batch_size * args.mu,
103 | num_workers=args.num_workers,
104 | drop_last=True)
105 |
106 | unlabeled_iter = iter(unlabeled_trainloader)
107 | unlabeled_all_iter = iter(unlabeled_trainloader_all)
108 |
109 | for batch_idx in range(args.eval_step):
110 | ## Data loading
111 |
112 |             try:
113 |                 (_, inputs_x_s, inputs_x), targets_x = next(labeled_iter)
114 |             except StopIteration:
115 |                 if args.world_size > 1:
116 |                     labeled_epoch += 1
117 |                     labeled_trainloader.sampler.set_epoch(labeled_epoch)
118 |                 labeled_iter = iter(labeled_trainloader)
119 |                 (_, inputs_x_s, inputs_x), targets_x = next(labeled_iter)
120 |             try:
121 |                 (inputs_u_w, inputs_u_s, _), _ = next(unlabeled_iter)
122 |             except StopIteration:
123 |                 if args.world_size > 1:
124 |                     unlabeled_epoch += 1
125 |                     unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch)
126 |                 unlabeled_iter = iter(unlabeled_trainloader)
127 |                 (inputs_u_w, inputs_u_s, _), _ = next(unlabeled_iter)
128 |             try:
129 |                 (inputs_all_w, inputs_all_s, _), _ = next(unlabeled_all_iter)
130 |             except StopIteration:
131 |                 unlabeled_all_iter = iter(unlabeled_trainloader_all)
132 |                 (inputs_all_w, inputs_all_s, _), _ = next(unlabeled_all_iter)
133 | data_time.update(time.time() - end)
134 |
135 | b_size = inputs_x.shape[0]
136 |
137 | inputs_all = torch.cat([inputs_all_w, inputs_all_s], 0)
138 | inputs = torch.cat([inputs_x, inputs_x_s,
139 | inputs_all], 0).to(args.device)
140 | targets_x = targets_x.to(args.device)
141 | ## Feed data
142 | logits, logits_open = model(inputs)
143 | logits_open_u1, logits_open_u2 = logits_open[2*b_size:].chunk(2)
144 |
145 | ## Loss for labeled samples
146 | Lx = F.cross_entropy(logits[:2*b_size],
147 | targets_x.repeat(2), reduction='mean')
148 | Lo = ova_loss(logits_open[:2*b_size], targets_x.repeat(2))
149 |
150 | ## Open-set entropy minimization
151 | L_oem = ova_ent(logits_open_u1) / 2.
152 | L_oem += ova_ent(logits_open_u2) / 2.
153 |
154 |             ## Soft consistency regularization
155 | logits_open_u1 = logits_open_u1.view(logits_open_u1.size(0), 2, -1)
156 | logits_open_u2 = logits_open_u2.view(logits_open_u2.size(0), 2, -1)
157 | logits_open_u1 = F.softmax(logits_open_u1, 1)
158 | logits_open_u2 = F.softmax(logits_open_u2, 1)
159 | L_socr = torch.mean(torch.sum(torch.sum(torch.abs(
160 | logits_open_u1 - logits_open_u2)**2, 1), 1))
161 |
162 | if epoch >= args.start_fix:
163 | inputs_ws = torch.cat([inputs_u_w, inputs_u_s], 0).to(args.device)
164 | logits, logits_open_fix = model(inputs_ws)
165 | logits_u_w, logits_u_s = logits.chunk(2)
166 | pseudo_label = torch.softmax(logits_u_w.detach()/args.T, dim=-1)
167 | max_probs, targets_u = torch.max(pseudo_label, dim=-1)
168 | mask = max_probs.ge(args.threshold).float()
169 | L_fix = (F.cross_entropy(logits_u_s,
170 | targets_u,
171 | reduction='none') * mask).mean()
172 | mask_probs.update(mask.mean().item())
173 |
174 | else:
175 | L_fix = torch.zeros(1).to(args.device).mean()
176 | loss = Lx + Lo + args.lambda_oem * L_oem \
177 | + args.lambda_socr * L_socr + L_fix
178 | if args.amp:
179 | with amp.scale_loss(loss, optimizer) as scaled_loss:
180 | scaled_loss.backward()
181 | else:
182 | loss.backward()
183 |
184 | losses.update(loss.item())
185 | losses_x.update(Lx.item())
186 | losses_o.update(Lo.item())
187 | losses_oem.update(L_oem.item())
188 | losses_socr.update(L_socr.item())
189 | losses_fix.update(L_fix.item())
190 |
191 | output_args["batch"] = batch_idx
192 | output_args["loss_x"] = losses_x.avg
193 | output_args["loss_o"] = losses_o.avg
194 | output_args["loss_oem"] = losses_oem.avg
195 | output_args["loss_socr"] = losses_socr.avg
196 | output_args["loss_fix"] = losses_fix.avg
197 | output_args["lr"] = [group["lr"] for group in optimizer.param_groups][0]
198 |
199 |
200 | optimizer.step()
201 | if args.opt != 'adam':
202 | scheduler.step()
203 | if args.use_ema:
204 | ema_model.update(model)
205 | model.zero_grad()
206 | batch_time.update(time.time() - end)
207 | end = time.time()
208 |
209 | if not args.no_progress:
210 | p_bar.set_description(default_out.format(**output_args))
211 | p_bar.update()
212 |
213 | if not args.no_progress:
214 | p_bar.close()
215 |
216 | if args.use_ema:
217 | test_model = ema_model.ema
218 | else:
219 | test_model = model
220 |
221 | if args.local_rank in [-1, 0]:
222 |
223 | val_acc = test(args, val_loader, test_model, epoch, val=True)
224 | test_loss, test_acc_close, test_overall, \
225 | test_unk, test_roc, test_roc_softm, test_id \
226 | = test(args, test_loader, test_model, epoch)
227 |
228 | for ood in ood_loaders.keys():
229 | roc_ood = test_ood(args, test_id, ood_loaders[ood], test_model)
230 | logger.info("ROC vs {ood}: {roc}".format(ood=ood, roc=roc_ood))
231 |
232 | args.writer.add_scalar('train/1.train_loss', losses.avg, epoch)
233 | args.writer.add_scalar('train/2.train_loss_x', losses_x.avg, epoch)
234 | args.writer.add_scalar('train/3.train_loss_o', losses_o.avg, epoch)
235 | args.writer.add_scalar('train/4.train_loss_oem', losses_oem.avg, epoch)
236 | args.writer.add_scalar('train/5.train_loss_socr', losses_socr.avg, epoch)
237 |             args.writer.add_scalar('train/6.train_loss_fix', losses_fix.avg, epoch)
238 |             args.writer.add_scalar('train/7.mask', mask_probs.avg, epoch)
239 | args.writer.add_scalar('test/1.test_acc', test_acc_close, epoch)
240 | args.writer.add_scalar('test/2.test_loss', test_loss, epoch)
241 |
242 | is_best = val_acc > best_acc_val
243 | best_acc_val = max(val_acc, best_acc_val)
244 | if is_best:
245 | overall_valid = test_overall
246 | close_valid = test_acc_close
247 | unk_valid = test_unk
248 | roc_valid = test_roc
249 | roc_softm_valid = test_roc_softm
250 | model_to_save = model.module if hasattr(model, "module") else model
251 | if args.use_ema:
252 | ema_to_save = ema_model.ema.module if hasattr(
253 | ema_model.ema, "module") else ema_model.ema
254 |
255 | save_checkpoint({
256 | 'epoch': epoch + 1,
257 | 'state_dict': model_to_save.state_dict(),
258 | 'ema_state_dict': ema_to_save.state_dict() if args.use_ema else None,
259 | 'acc close': test_acc_close,
260 | 'acc overall': test_overall,
261 | 'unk': test_unk,
262 | 'best_acc': best_acc,
263 | 'optimizer': optimizer.state_dict(),
264 | 'scheduler': scheduler.state_dict(),
265 | }, is_best, args.out)
266 | test_accs.append(test_acc_close)
267 | logger.info('Best val closed acc: {:.3f}'.format(best_acc_val))
268 | logger.info('Valid closed acc: {:.3f}'.format(close_valid))
269 | logger.info('Valid overall acc: {:.3f}'.format(overall_valid))
270 | logger.info('Valid unk acc: {:.3f}'.format(unk_valid))
271 | logger.info('Valid roc: {:.3f}'.format(roc_valid))
272 | logger.info('Valid roc soft: {:.3f}'.format(roc_softm_valid))
273 | logger.info('Mean top-1 acc: {:.3f}\n'.format(
274 | np.mean(test_accs[-20:])))
275 | if args.local_rank in [-1, 0]:
276 | args.writer.close()
277 |
--------------------------------------------------------------------------------
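For reference, the per-batch objective assembled in train() above can be written as (notation mine):

```latex
\mathcal{L} = L_x + L_o + \lambda_{\mathrm{oem}} L_{\mathrm{oem}} + \lambda_{\mathrm{socr}} L_{\mathrm{socr}} + L_{\mathrm{fix}}
```

where `Lx` is the cross-entropy on the two augmented labeled views, `Lo` the one-vs-all loss from utils/misc.py, `L_oem` the open-set entropy-minimization term, `L_socr` the soft open-set consistency between the two views of each unlabeled image, and `L_fix` the FixMatch pseudo-label loss, which only becomes active once `epoch >= args.start_fix`. The coefficients are the `--lambda_oem` and `--lambda_socr` flags.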
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .misc import *
2 | from .default import *
3 | from .parser import *
--------------------------------------------------------------------------------
/utils/default.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch import nn
4 | import math
5 | import random
6 | import shutil
7 | import numpy as np
8 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
9 | from torch.utils.data.distributed import DistributedSampler
10 | import torch.optim as optim
11 | from torch.optim.lr_scheduler import LambdaLR
12 |
13 | from dataset.cifar import DATASET_GETTERS, get_ood
14 |
15 | __all__ = ['create_model', 'set_model_config',
16 | 'set_dataset', 'set_models',
17 | 'save_checkpoint', 'set_seed']
18 |
19 |
20 | def create_model(args):
21 | if 'wideresnet' in args.arch:
22 | import models.wideresnet as models
23 | model = models.build_wideresnet(depth=args.model_depth,
24 | widen_factor=args.model_width,
25 | dropout=0,
26 | num_classes=args.num_classes,
27 | open=True)
28 | elif args.arch == 'resnext':
29 | import models.resnext as models
30 | model = models.build_resnext(cardinality=args.model_cardinality,
31 | depth=args.model_depth,
32 | width=args.model_width,
33 | num_classes=args.num_classes)
34 | elif args.arch == 'resnet_imagenet':
35 | import models.resnet_imagenet as models
36 | model = models.resnet18(num_classes=args.num_classes)
37 |
38 | return model
39 |
40 |
41 |
42 | def set_model_config(args):
43 |
44 | if args.dataset == 'cifar10':
45 | if args.arch == 'wideresnet':
46 | args.model_depth = 28
47 | args.model_width = 2
48 | elif args.arch == 'resnext':
49 | args.model_cardinality = 4
50 | args.model_depth = 28
51 | args.model_width = 4
52 |
53 | elif args.dataset == 'cifar100':
54 | args.num_classes = 55
55 | if args.arch == 'wideresnet':
56 | args.model_depth = 28
57 | args.model_width = 2
58 | elif args.arch == 'wideresnet_10':
59 | args.model_depth = 28
60 | args.model_width = 8
61 | elif args.arch == 'resnext':
62 | args.model_cardinality = 8
63 | args.model_depth = 29
64 | args.model_width = 64
65 |
66 | elif args.dataset == "imagenet":
67 | args.num_classes = 20
68 |
69 | args.image_size = (32, 32, 3)
70 | if args.dataset == 'cifar10':
71 | args.ood_data = ["svhn", 'cifar100', 'lsun', 'imagenet']
72 |
73 | elif args.dataset == 'cifar100':
74 | args.ood_data = ['cifar10', "svhn", 'lsun', 'imagenet']
75 |
76 | elif 'imagenet' in args.dataset:
77 | args.ood_data = ['lsun', 'dtd', 'cub', 'flowers102',
78 | 'caltech_256', 'stanford_dogs']
79 | args.image_size = (224, 224, 3)
80 |
81 | def set_dataset(args):
82 | labeled_dataset, unlabeled_dataset, test_dataset, val_dataset = \
83 | DATASET_GETTERS[args.dataset](args)
84 |
85 | ood_loaders = {}
86 | for ood in args.ood_data:
87 | ood_dataset = get_ood(ood, args.dataset, image_size=args.image_size)
88 | ood_loaders[ood] = DataLoader(ood_dataset,
89 | batch_size=args.batch_size,
90 | num_workers=args.num_workers)
91 |
92 | if args.local_rank == 0:
93 | torch.distributed.barrier()
94 |
95 | train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
96 |
97 | labeled_trainloader = DataLoader(
98 | labeled_dataset,
99 | sampler=train_sampler(labeled_dataset),
100 | batch_size=args.batch_size,
101 | num_workers=args.num_workers,
102 | drop_last=True)
103 |
104 | test_loader = DataLoader(
105 | test_dataset,
106 | sampler=SequentialSampler(test_dataset),
107 | batch_size=args.batch_size,
108 | num_workers=args.num_workers)
109 | val_loader = DataLoader(
110 | val_dataset,
111 | sampler=SequentialSampler(val_dataset),
112 | batch_size=args.batch_size,
113 | num_workers=args.num_workers)
114 | if args.local_rank not in [-1, 0]:
115 | torch.distributed.barrier()
116 |
117 | return labeled_trainloader, unlabeled_dataset, \
118 | test_loader, val_loader, ood_loaders
119 |
120 |
121 | def get_cosine_schedule_with_warmup(optimizer,
122 | num_warmup_steps,
123 | num_training_steps,
124 | num_cycles=7./16.,
125 | last_epoch=-1):
126 | def _lr_lambda(current_step):
127 | if current_step < num_warmup_steps:
128 | return float(current_step) / float(max(1, num_warmup_steps))
129 | no_progress = float(current_step - num_warmup_steps) / \
130 | float(max(1, num_training_steps - num_warmup_steps))
131 | return max(0., math.cos(math.pi * num_cycles * no_progress))
132 |
133 | return LambdaLR(optimizer, _lr_lambda, last_epoch)
134 |
135 |
136 | def set_models(args):
137 | model = create_model(args)
138 | if args.local_rank == 0:
139 | torch.distributed.barrier()
140 | model.to(args.device)
141 |
142 | no_decay = ['bias', 'bn']
143 | grouped_parameters = [
144 | {'params': [p for n, p in model.named_parameters() if not any(
145 | nd in n for nd in no_decay)], 'weight_decay': args.wdecay},
146 | {'params': [p for n, p in model.named_parameters() if any(
147 | nd in n for nd in no_decay)], 'weight_decay': 0.0}
148 | ]
149 | if args.opt == 'sgd':
150 | optimizer = optim.SGD(grouped_parameters, lr=args.lr,
151 | momentum=0.9, nesterov=args.nesterov)
152 | elif args.opt == 'adam':
153 | optimizer = optim.Adam(grouped_parameters, lr=2e-3)
154 |
155 | # args.epochs = math.ceil(args.total_steps / args.eval_step)
156 | scheduler = get_cosine_schedule_with_warmup(
157 | optimizer, args.warmup, args.total_steps)
158 |
159 | return model, optimizer, scheduler
160 |
161 |
162 | def save_checkpoint(state, is_best, checkpoint, filename='checkpoint.pth.tar'):
163 | filepath = os.path.join(checkpoint, filename)
164 | torch.save(state, filepath)
165 | if is_best:
166 | shutil.copyfile(filepath, os.path.join(checkpoint,
167 | 'model_best.pth.tar'))
168 |
169 |
170 | def set_seed(args):
171 | random.seed(args.seed)
172 | np.random.seed(args.seed)
173 | torch.manual_seed(args.seed)
174 | if args.n_gpu > 0:
175 | torch.cuda.manual_seed_all(args.seed)
176 |
--------------------------------------------------------------------------------
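As a numeric sanity check on get_cosine_schedule_with_warmup above (a standalone sketch, not repo code): the multiplier ramps linearly during warmup and then follows cos(pi * 7/16 * progress), so it ends near cos(7*pi/16) ≈ 0.195 rather than decaying all the way to zero:

```python
import math

def lr_lambda(step, warmup, total, num_cycles=7.0 / 16.0):
    # Mirrors _lr_lambda in utils/default.py.
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, math.cos(math.pi * num_cycles * progress))

total = 2 ** 19  # the default --total-steps
for step in (0, 1024, total // 2, total - 1):
    print(step, round(lr_lambda(step, warmup=0, total=total), 4))
# prints multipliers starting at 1.0 and ending near 0.195
```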
/utils/misc.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 | - get_mean_and_std: calculate the mean and std value of dataset.
3 | '''
4 | import logging
5 | import time
6 | from tqdm import tqdm
7 | import torch.nn.functional as F
8 | import numpy as np
9 |
10 | import torch
11 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
12 | from sklearn.metrics import roc_auc_score
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 | __all__ = ['get_mean_and_std', 'accuracy', 'AverageMeter',
17 | 'accuracy_open', 'ova_loss', 'compute_roc',
18 | 'roc_id_ood', 'ova_ent', 'exclude_dataset',
19 | 'test_ood', 'test']
20 |
21 |
22 | def get_mean_and_std(dataset):
23 | '''Compute the mean and std value of dataset.'''
24 | dataloader = torch.utils.data.DataLoader(
25 | dataset, batch_size=1, shuffle=False, num_workers=4)
26 |
27 | mean = torch.zeros(3)
28 | std = torch.zeros(3)
29 | logger.info('==> Computing mean and std..')
30 | for inputs, targets in dataloader:
31 | for i in range(3):
32 | mean[i] += inputs[:, i, :, :].mean()
33 | std[i] += inputs[:, i, :, :].std()
34 | mean.div_(len(dataset))
35 | std.div_(len(dataset))
36 | return mean, std
37 |
38 |
39 | def accuracy(output, target, topk=(1,)):
40 | """Computes the precision@k for the specified values of k"""
41 | maxk = max(topk)
42 | batch_size = target.size(0)
43 |
44 | _, pred = output.topk(maxk, 1, True, True)
45 | pred = pred.t()
46 | correct = pred.eq(target.reshape(1, -1).expand_as(pred))
47 |
48 | res = []
49 |
50 | for k in topk:
51 | correct_k = correct[:k].reshape(-1).float().sum(0)
52 | res.append(correct_k.mul_(100.0 / batch_size))
53 | return res
54 |
55 |
56 | def accuracy_open(pred, target, topk=(1,), num_classes=5):
57 | """Computes the precision@k for the specified values of k,
58 | num_classes are the number of known classes.
59 | This function returns overall accuracy,
60 | accuracy to reject unknown samples,
61 | the size of unknown samples in this batch."""
62 | maxk = max(topk)
63 | batch_size = target.size(0)
64 | pred = pred.view(-1, 1)
65 | pred = pred.t()
66 | ind = (target == num_classes)
 67 |     unknown_size = int(ind.sum())  # number of unknown samples, not the batch size
68 | correct = pred.eq(target.view(1, -1).expand_as(pred))
69 | if ind.sum() > 0:
70 | unk_corr = pred.eq(target).view(-1)[ind]
71 | acc = torch.sum(unk_corr).item() / unk_corr.size(0)
72 | else:
73 | acc = 0
74 |
75 | res = []
76 | for k in topk:
77 | correct_k = correct[:k].view(-1).float().sum(0)
78 | res.append(correct_k.mul_(100.0 / batch_size))
79 | return res[0], acc, unknown_size
80 |
81 |
82 | def compute_roc(unk_all, label_all, num_known):
83 | Y_test = np.zeros(unk_all.shape[0])
84 | unk_pos = np.where(label_all >= num_known)[0]
85 | Y_test[unk_pos] = 1
86 | return roc_auc_score(Y_test, unk_all)
87 |
88 |
89 | def roc_id_ood(score_id, score_ood):
90 | id_all = np.r_[score_id, score_ood]
91 | Y_test = np.zeros(score_id.shape[0]+score_ood.shape[0])
92 | Y_test[score_id.shape[0]:] = 1
93 | return roc_auc_score(Y_test, id_all)
94 |
95 |
96 | def ova_loss(logits_open, label):
97 | logits_open = logits_open.view(logits_open.size(0), 2, -1)
98 | logits_open = F.softmax(logits_open, 1)
99 | label_s_sp = torch.zeros((logits_open.size(0),
100 | logits_open.size(2))).long().to(label.device)
101 |     label_range = torch.arange(logits_open.size(0)).long()
102 | label_s_sp[label_range, label] = 1
103 | label_sp_neg = 1 - label_s_sp
104 | open_loss = torch.mean(torch.sum(-torch.log(logits_open[:, 1, :]
105 | + 1e-8) * label_s_sp, 1))
106 | open_loss_neg = torch.mean(torch.max(-torch.log(logits_open[:, 0, :]
107 | + 1e-8) * label_sp_neg, 1)[0])
108 | Lo = open_loss_neg + open_loss
109 | return Lo
110 |
111 |
112 | def ova_ent(logits_open):
113 | logits_open = logits_open.view(logits_open.size(0), 2, -1)
114 | logits_open = F.softmax(logits_open, 1)
115 | Le = torch.mean(torch.mean(torch.sum(-logits_open *
116 | torch.log(logits_open + 1e-8), 1), 1))
117 | return Le
118 |
119 |
120 | class AverageMeter(object):
121 | """Computes and stores the average and current value
122 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
123 | """
124 |
125 | def __init__(self):
126 | self.reset()
127 |
128 | def reset(self):
129 | self.val = 0
130 | self.avg = 0
131 | self.sum = 0
132 | self.count = 0
133 |
134 | def update(self, val, n=1):
135 | self.val = val
136 | self.sum += val * n
137 | self.count += n
138 | self.avg = self.sum / self.count
139 |
140 |
141 |
142 | def exclude_dataset(args, dataset, model, exclude_known=False):
143 | data_time = AverageMeter()
144 | end = time.time()
145 | dataset.init_index()
146 | test_loader = DataLoader(
147 | dataset,
148 | batch_size=args.batch_size,
149 | num_workers=args.num_workers,
150 | drop_last=False,
151 | shuffle=False)
152 | if not args.no_progress:
153 | test_loader = tqdm(test_loader,
154 | disable=args.local_rank not in [-1, 0])
155 | model.eval()
156 | with torch.no_grad():
157 | for batch_idx, ((_, _, inputs), targets) in enumerate(test_loader):
158 | data_time.update(time.time() - end)
159 |
160 | inputs = inputs.to(args.device)
161 | outputs, outputs_open = model(inputs)
162 | outputs = F.softmax(outputs, 1)
163 | out_open = F.softmax(outputs_open.view(outputs_open.size(0), 2, -1), 1)
164 |             tmp_range = torch.arange(out_open.size(0)).long().cuda()
165 | pred_close = outputs.data.max(1)[1]
166 | unk_score = out_open[tmp_range, 0, pred_close]
167 | known_ind = unk_score < 0.5
168 | if batch_idx == 0:
169 | known_all = known_ind
170 | else:
171 | known_all = torch.cat([known_all, known_ind], 0)
172 | if not args.no_progress:
173 | test_loader.close()
174 | known_all = known_all.data.cpu().numpy()
175 | if exclude_known:
176 | ind_selected = np.where(known_all == 0)[0]
177 | else:
178 | ind_selected = np.where(known_all != 0)[0]
179 | print("selected ratio %s"%( (len(ind_selected)/ len(known_all))))
180 | model.train()
181 | dataset.set_index(ind_selected)
182 |
183 | def test(args, test_loader, model, epoch, val=False):
184 | batch_time = AverageMeter()
185 | data_time = AverageMeter()
186 | losses = AverageMeter()
187 | top1 = AverageMeter()
188 | acc = AverageMeter()
189 | unk = AverageMeter()
190 | top5 = AverageMeter()
191 | end = time.time()
192 |
193 | if not args.no_progress:
194 | test_loader = tqdm(test_loader,
195 | disable=args.local_rank not in [-1, 0])
196 | with torch.no_grad():
197 | for batch_idx, (inputs, targets) in enumerate(test_loader):
198 | data_time.update(time.time() - end)
199 | model.eval()
200 | inputs = inputs.to(args.device)
201 | targets = targets.to(args.device)
202 | outputs, outputs_open = model(inputs)
203 | outputs = F.softmax(outputs, 1)
204 | out_open = F.softmax(outputs_open.view(outputs_open.size(0), 2, -1), 1)
205 |             tmp_range = torch.arange(out_open.size(0)).long().cuda()
206 | pred_close = outputs.data.max(1)[1]
207 | unk_score = out_open[tmp_range, 0, pred_close]
208 | known_score = outputs.max(1)[0]
209 | targets_unk = targets >= int(outputs.size(1))
210 | targets[targets_unk] = int(outputs.size(1))
211 |             known_targets = targets < int(outputs.size(1))
212 | known_pred = outputs[known_targets]
213 | known_targets = targets[known_targets]
214 |
215 | if len(known_pred) > 0:
216 | prec1, prec5 = accuracy(known_pred, known_targets, topk=(1, 5))
217 | top1.update(prec1.item(), known_pred.shape[0])
218 | top5.update(prec5.item(), known_pred.shape[0])
219 |
220 | ind_unk = unk_score > 0.5
221 | pred_close[ind_unk] = int(outputs.size(1))
222 | acc_all, unk_acc, size_unk = accuracy_open(pred_close,
223 | targets,
224 | num_classes=int(outputs.size(1)))
225 | acc.update(acc_all.item(), inputs.shape[0])
226 |                 if size_unk > 0: unk.update(unk_acc, size_unk)  # skip batches with no unknowns
227 |
228 | batch_time.update(time.time() - end)
229 | end = time.time()
230 | if batch_idx == 0:
231 | unk_all = unk_score
232 | known_all = known_score
233 | label_all = targets
234 | else:
235 | unk_all = torch.cat([unk_all, unk_score], 0)
236 | known_all = torch.cat([known_all, known_score], 0)
237 | label_all = torch.cat([label_all, targets], 0)
238 |
239 | if not args.no_progress:
240 | test_loader.set_description("Test Iter: {batch:4}/{iter:4}. "
241 | "Data: {data:.3f}s."
242 | "Batch: {bt:.3f}s. "
243 | "Loss: {loss:.4f}. "
244 | "Closed t1: {top1:.3f} "
245 | "t5: {top5:.3f} "
246 | "acc: {acc:.3f}. "
247 | "unk: {unk:.3f}. ".format(
248 | batch=batch_idx + 1,
249 | iter=len(test_loader),
250 | data=data_time.avg,
251 | bt=batch_time.avg,
252 | loss=losses.avg,
253 | top1=top1.avg,
254 | top5=top5.avg,
255 | acc=acc.avg,
256 | unk=unk.avg,
257 | ))
258 | if not args.no_progress:
259 | test_loader.close()
260 | ## ROC calculation
261 |
262 |
263 | unk_all = unk_all.data.cpu().numpy()
264 | known_all = known_all.data.cpu().numpy()
265 | label_all = label_all.data.cpu().numpy()
266 | if not val:
267 | roc = compute_roc(unk_all, label_all,
268 | num_known=int(outputs.size(1)))
269 | roc_soft = compute_roc(-known_all, label_all,
270 | num_known=int(outputs.size(1)))
271 | ind_known = np.where(label_all < int(outputs.size(1)))[0]
272 | id_score = unk_all[ind_known]
273 | logger.info("Closed acc: {:.3f}".format(top1.avg))
274 | logger.info("Overall acc: {:.3f}".format(acc.avg))
275 | logger.info("Unk acc: {:.3f}".format(unk.avg))
276 | logger.info("ROC: {:.3f}".format(roc))
277 | logger.info("ROC Softmax: {:.3f}".format(roc_soft))
278 | return losses.avg, top1.avg, acc.avg, \
279 | unk.avg, roc, roc_soft, id_score
280 | else:
281 | logger.info("Closed acc: {:.3f}".format(top1.avg))
282 | return top1.avg
283 |
284 |
285 | def test_ood(args, test_id, test_loader, model):
286 | batch_time = AverageMeter()
287 | data_time = AverageMeter()
288 | end = time.time()
289 |
290 | if not args.no_progress:
291 | test_loader = tqdm(test_loader,
292 | disable=args.local_rank not in [-1, 0])
293 | with torch.no_grad():
294 | for batch_idx, (inputs, targets) in enumerate(test_loader):
295 | data_time.update(time.time() - end)
296 | model.eval()
297 | inputs = inputs.to(args.device)
298 | outputs, outputs_open = model(inputs)
299 | out_open = F.softmax(outputs_open.view(outputs_open.size(0), 2, -1), 1)
300 |             tmp_range = torch.arange(out_open.size(0)).long().cuda()
301 | pred_close = outputs.data.max(1)[1]
302 | unk_score = out_open[tmp_range, 0, pred_close]
303 | batch_time.update(time.time() - end)
304 | end = time.time()
305 | if batch_idx == 0:
306 | unk_all = unk_score
307 | else:
308 | unk_all = torch.cat([unk_all, unk_score], 0)
309 | if not args.no_progress:
310 | test_loader.close()
311 | ## ROC calculation
312 | unk_all = unk_all.data.cpu().numpy()
313 | roc = roc_id_ood(test_id, unk_all)
314 |
315 | return roc
316 |
--------------------------------------------------------------------------------
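A toy illustration (mine, not from the repo) of the tensor layout that ova_loss, exclude_dataset, and test all rely on: the open-set head emits 2*K logits per sample, and `view(B, 2, K)` pairs an outlier score (index 0) with an inlier score (index 1) for each class, normalized by a softmax over that two-way axis:

```python
import torch
import torch.nn.functional as F

B, K = 4, 6
logits_open = torch.randn(B, 2 * K)
p = F.softmax(logits_open.view(B, 2, K), dim=1)

# Each class gets a binary distribution: outlier prob + inlier prob = 1.
assert torch.allclose(p.sum(dim=1), torch.ones(B, K))

# exclude_dataset()/test() read the outlier probability of the class the
# closed-set head predicts, and flag the sample as unknown above 0.5.
pred_close = torch.randint(0, K, (B,))  # stand-in for the closed-set argmax
unk_score = p[torch.arange(B), 0, pred_close]
print(unk_score > 0.5)
```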
/utils/parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | __all__ = ['set_parser']
4 |
5 | def set_parser():
6 | parser = argparse.ArgumentParser(description='PyTorch OpenMatch Training')
7 | ## Computational Configurations
8 |     parser.add_argument('--gpu-id', default=0, type=int,
9 | help='id(s) for CUDA_VISIBLE_DEVICES')
10 | parser.add_argument('--num-workers', type=int, default=4,
11 | help='number of workers')
12 | parser.add_argument('--seed', default=None, type=int,
13 | help="random seed")
14 | parser.add_argument("--amp", action="store_true",
15 | help="use 16-bit (mixed) precision through NVIDIA apex AMP")
16 | parser.add_argument("--opt_level", type=str, default="O1",
17 | help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
18 | "See details at https://nvidia.github.io/apex/amp.html")
19 | parser.add_argument("--local_rank", type=int, default=-1,
20 | help="For distributed training: local_rank")
21 | parser.add_argument('--no-progress', action='store_true',
22 | help="don't use progress bar")
23 | parser.add_argument('--eval_only', type=int, default=0,
24 |                         help='set to 1 to run evaluation only')
25 | parser.add_argument('--num_classes', type=int, default=6,
26 |                         help='number of known classes (default is for cifar10)')
27 |
28 | parser.add_argument('--out', default='result',
29 | help='directory to output the result')
30 | parser.add_argument('--resume', default='', type=str,
31 | help='path to latest checkpoint (default: none)')
32 | parser.add_argument('--root', default='./data', type=str,
33 | help='path to data directory')
34 | parser.add_argument('--dataset', default='cifar10', type=str,
35 | choices=['cifar10', 'cifar100', 'imagenet'],
36 | help='dataset name')
37 | ## Hyper-parameters
38 | parser.add_argument('--opt', default='sgd', type=str,
39 | choices=['sgd', 'adam'],
40 |                         help='optimizer name')
41 | parser.add_argument('--num-labeled', type=int, default=400,
42 | choices=[25, 50, 100, 400],
43 |                         help='number of labeled samples per class')
44 | parser.add_argument('--num_val', type=int, default=50,
45 |                         help='number of validation samples per class')
46 | parser.add_argument('--num-super', type=int, default=10,
47 |                         help='number of known super-classes for cifar100 (10 or 15)')
48 | parser.add_argument("--expand-labels", action="store_true",
49 | help="expand labels to fit eval steps")
50 | parser.add_argument('--arch', default='wideresnet', type=str,
51 | choices=['wideresnet', 'resnext',
52 | 'resnet_imagenet'],
53 |                         help='model architecture')
54 | ## HP unique to OpenMatch (Some are changed from FixMatch)
55 | parser.add_argument('--lambda_oem', default=0.1, type=float,
56 | help='coefficient of OEM loss')
57 | parser.add_argument('--lambda_socr', default=0.5, type=float,
58 | help='coefficient of SOCR loss, 0.5 for CIFAR10, ImageNet, '
59 | '1.0 for CIFAR100')
60 | parser.add_argument('--start_fix', default=10, type=int,
61 | help='epoch to start fixmatch training')
62 | parser.add_argument('--mu', default=2, type=int,
63 |                         help='ratio of unlabeled to labeled batch size')
64 | parser.add_argument('--total-steps', default=2 ** 19, type=int,
65 | help='number of total steps to run')
66 | parser.add_argument('--epochs', default=512, type=int,
67 | help='number of epochs to run')
68 | parser.add_argument('--threshold', default=0.0, type=float,
69 | help='pseudo label threshold')
70 | ##
71 | parser.add_argument('--eval-step', default=1024, type=int,
72 | help='number of eval steps to run')
73 |
74 | parser.add_argument('--start-epoch', default=0, type=int,
75 | help='manual epoch number (useful on restarts)')
76 | parser.add_argument('--batch-size', default=64, type=int,
77 |                         help='train batch size')
78 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
79 | help='initial learning rate')
80 | parser.add_argument('--warmup', default=0, type=float,
81 | help='warmup epochs (unlabeled data based)')
82 | parser.add_argument('--wdecay', default=5e-4, type=float,
83 | help='weight decay')
84 | parser.add_argument('--nesterov', action='store_true', default=True,
85 | help='use nesterov momentum')
86 | parser.add_argument('--use-ema', action='store_true', default=True,
87 | help='use EMA model')
88 | parser.add_argument('--ema-decay', default=0.999, type=float,
89 | help='EMA decay rate')
90 | parser.add_argument('--T', default=1, type=float,
91 | help='pseudo label temperature')
92 |
93 |
94 | args = parser.parse_args()
95 | return args
--------------------------------------------------------------------------------