├── .idea └── vcs.xml ├── LICENSE ├── README.md ├── dataset ├── __init__.py ├── cifar.py ├── mydataset.py └── randaugment.py ├── eval.py ├── files.zip ├── images └── consistency.png ├── main.py ├── models ├── ema.py ├── resnet_imagenet.py ├── resnext.py └── wideresnet.py ├── run_cifar10.sh ├── run_cifar100.sh ├── run_eval_cifar10.sh ├── run_imagenet.sh ├── trainer.py └── utils ├── __init__.py ├── default.py ├── misc.py └── parser.py /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Vision and Learning Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## [OpenMatch: Open-set Consistency Regularization for Semi-supervised Learning with Outliers (NeurIPS 2021)](https://arxiv.org/pdf/2105.14148.pdf) 2 | 3 | ![OpenMatch Overview](images/consistency.png) 4 | 5 | 6 | This is a PyTorch implementation of OpenMatch. 7 | This implementation is based on [Pytorch-FixMatch](https://github.com/kekmodel/FixMatch-pytorch). 8 | 9 | 10 | 11 | ## Requirements 12 | - python 3.6+ 13 | - torch 1.4 14 | - torchvision 0.5 15 | - tensorboard 16 | - numpy 17 | - tqdm 18 | - sklearn 19 | - apex (optional) 20 | 21 | See [Pytorch-FixMatch](https://github.com/kekmodel/FixMatch-pytorch) for details. 22 | 23 | ## Usage 24 | 25 | ### Dataset Preparation 26 | This repository requires CIFAR-10, CIFAR-100, or ImageNet-30 to train a model. 27 | 28 | To fully reproduce the evaluation results, SVHN, LSUN, and ImageNet are also needed 29 | for CIFAR-10/100, and LSUN, DTD, CUB, Flowers, Caltech-256, and Stanford Dogs for ImageNet-30. 30 | To prepare these datasets, follow [CSI](https://github.com/alinlab/CSI). 31 | 32 | 33 | ``` 34 | mkdir data 35 | ln -s path_to_each_dataset ./data/. 36 | 37 | ## Unzip the file lists for the ImageNet-30 experiments. 38 | unzip files.zip 39 | ``` 40 | 41 | All datasets are expected to be under ./data.
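Before running the evaluation, it can help to confirm that the out-of-distribution folders sit where the code looks for them. The sketch below is a hypothetical helper, not part of this repository: the folder names are copied from the paths that `get_ood()` in `dataset/cifar.py` builds under `DATA_PATH = './data'`, and it assumes the evaluation passes these same dataset keys to `get_ood()`. CIFAR-10/100 and SVHN are loaded through torchvision's dataset classes and are not checked here.

```python
import os

# Hypothetical helper (not part of this repository). The folder names mirror
# the paths that get_ood() in dataset/cifar.py builds under DATA_PATH.
OOD_DIRS = {
    "lsun": "LSUN_fix",
    "imagenet": "Imagenet_fix",
    "stanford_dogs": "stanford_dogs",
    "cub": "cub",
    "flowers102": "flowers102",
    "food_101": os.path.join("food-101", "images"),
    "caltech_256": "caltech-256",
    "dtd": "dtd",
    "pets": "pets",
}


def missing_ood_dirs(data_path="./data", names=None):
    """Return the OOD dataset keys whose expected folder is not a directory."""
    if names is None:
        names = list(OOD_DIRS)
    return [name for name in names
            if not os.path.isdir(os.path.join(data_path, OOD_DIRS[name]))]


if __name__ == "__main__":
    # Folder-based OOD sets listed above for the CIFAR experiments.
    print("missing (CIFAR):", missing_ood_dirs(names=["lsun", "imagenet"]))
    # OOD sets listed above for the ImageNet-30 experiments.
    print("missing (ImageNet-30):", missing_ood_dirs(
        names=["lsun", "dtd", "cub", "flowers102", "caltech_256", "stanford_dogs"]))
```

Run it from the repository root; an empty list means every expected folder was found. The ImageNet-30 file lists (`filelist/*.txt`, read by `get_imagenet()`) come from `files.zip` and are not checked by this sketch.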
42 | 43 | ### Train 44 | Train a model with 50 labeled examples per class on CIFAR-10: 45 | 46 | ``` 47 | sh run_cifar10.sh 50 save_directory 48 | ``` 49 | 50 | Train a model with 50 labeled examples per class on CIFAR-100 with 55 known classes: 51 | 52 | ``` 53 | sh run_cifar100.sh 50 10 save_directory 54 | ``` 55 | 56 | 57 | Train a model with 50 labeled examples per class on CIFAR-100 with 80 known classes: 58 | 59 | ``` 60 | sh run_cifar100.sh 50 15 save_directory 61 | ``` 62 | 63 | 64 | Run experiments on ImageNet-30: 65 | 66 | ``` 67 | sh run_imagenet.sh save_directory 68 | ``` 69 | 70 | 71 | ### Evaluation 72 | Evaluate a model trained on CIFAR-10: 73 | 74 | ``` 75 | sh run_eval_cifar10.sh trained_model.pth 76 | ``` 77 | 78 | ### Trained models 79 | The CIFAR models are available below; the ImageNet-30 model is coming soon. 80 | 81 | - [CIFAR10-50-labeled](https://drive.google.com/file/d/1oNWAR8jVlxQXH0TMql1P-c7_i5-taU2T/view?usp=sharing) 82 | - [CIFAR100-50-labeled-55class](https://drive.google.com/file/d/1T5a_p4XUEOexEnjLWpGd-3pme4OzJ2pP/view?usp=sharing) 83 | - ImageNet-30 84 | 85 | ### Acknowledgement 86 | This repository relies heavily on [Pytorch-FixMatch](https://github.com/kekmodel/FixMatch-pytorch) for the FixMatch implementation and on [CSI](https://github.com/alinlab/CSI) for the anomaly detection evaluation. 87 | Thanks for sharing these great code bases! 88 | 89 | ### Reference 90 | This repository was contributed by [Kuniaki Saito](http://cs-people.bu.edu/keisaito/).
91 | If you use this code or its derivatives, please consider citing: 92 | 93 | ``` 94 | @article{saito2021openmatch, 95 | title={OpenMatch: Open-set Consistency Regularization for Semi-supervised Learning with Outliers}, 96 | author={Saito, Kuniaki and Kim, Donghyun and Saenko, Kate}, 97 | journal={arXiv preprint arXiv:2105.14148}, 98 | year={2021} 99 | } 100 | ``` 101 | 102 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar import * 2 | -------------------------------------------------------------------------------- /dataset/cifar.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | import numpy as np 5 | from PIL import Image 6 | from torchvision import datasets 7 | from torchvision import transforms 8 | 9 | from .randaugment import RandAugmentMC 10 | from .mydataset import ImageFolder, ImageFolder_fix 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | __all__ = ['TransformOpenMatch', 'TransformFixMatch', 'cifar10_mean', 15 | 'cifar10_std', 'cifar100_mean', 'cifar100_std', 'normal_mean', 16 | 'normal_std', 'TransformFixMatch_Imagenet', 17 | 'TransformFixMatch_Imagenet_Weak'] 18 | ### Path to the data directory.
19 | DATA_PATH = './data' 20 | 21 | cifar10_mean = (0.4914, 0.4822, 0.4465) 22 | cifar10_std = (0.2471, 0.2435, 0.2616) 23 | cifar100_mean = (0.5071, 0.4867, 0.4408) 24 | cifar100_std = (0.2675, 0.2565, 0.2761) 25 | normal_mean = (0.5, 0.5, 0.5) 26 | normal_std = (0.5, 0.5, 0.5) 27 | 28 | 29 | def get_cifar(args, norm=True): 30 | root = args.root 31 | name = args.dataset 32 | if name == "cifar10": 33 | data_folder = datasets.CIFAR10 34 | data_folder_main = CIFAR10SSL 35 | mean = cifar10_mean 36 | std = cifar10_std 37 | num_class = 10 38 | elif name == "cifar100": 39 | data_folder = CIFAR100FIX 40 | data_folder_main = CIFAR100SSL 41 | mean = cifar100_mean 42 | std = cifar100_std 43 | num_class = 100 44 | num_super = args.num_super 45 | 46 | else: 47 | raise NotImplementedError() 48 | assert num_class > args.num_classes 49 | 50 | if name == "cifar10": 51 | base_dataset = data_folder(root, train=True, download=True) 52 | args.num_classes = 6 53 | elif name == 'cifar100': 54 | base_dataset = data_folder(root, train=True, 55 | download=True, num_super=num_super) 56 | args.num_classes = base_dataset.num_known_class 57 | 58 | base_dataset.targets = np.array(base_dataset.targets) 59 | if name == 'cifar10': 60 | base_dataset.targets -= 2 61 | base_dataset.targets[np.where(base_dataset.targets == -2)[0]] = 8 62 | base_dataset.targets[np.where(base_dataset.targets == -1)[0]] = 9 63 | 64 | train_labeled_idxs, train_unlabeled_idxs, val_idxs = \ 65 | x_u_split(args, base_dataset.targets) 66 | 67 | ## This function will be overwritten in trainer.py 68 | norm_func = TransformFixMatch(mean=mean, std=std, norm=norm) 69 | if norm: 70 | norm_func_test = transforms.Compose([ 71 | transforms.ToTensor(), 72 | transforms.Normalize(mean=mean, std=std) 73 | ]) 74 | else: 75 | norm_func_test = transforms.Compose([ 76 | transforms.ToTensor(), 77 | ]) 78 | 79 | if name == 'cifar10': 80 | train_labeled_dataset = data_folder_main( 81 | root, train_labeled_idxs, train=True, 82 | 
transform=norm_func) 83 | train_unlabeled_dataset = data_folder_main( 84 | root, train_unlabeled_idxs, train=True, 85 | transform=norm_func, return_idx=False) 86 | val_dataset = data_folder_main( 87 | root, val_idxs, train=True, 88 | transform=norm_func_test) 89 | elif name == 'cifar100': 90 | train_labeled_dataset = data_folder_main( 91 | root, train_labeled_idxs, num_super = num_super, train=True, 92 | transform=norm_func) 93 | train_unlabeled_dataset = data_folder_main( 94 | root, train_unlabeled_idxs, num_super = num_super, train=True, 95 | transform=norm_func, return_idx=False) 96 | val_dataset = data_folder_main( 97 | root, val_idxs, num_super = num_super,train=True, 98 | transform=norm_func_test) 99 | 100 | if name == 'cifar10': 101 | train_labeled_dataset.targets -= 2 102 | train_unlabeled_dataset.targets -= 2 103 | val_dataset.targets -= 2 104 | 105 | 106 | if name == 'cifar10': 107 | test_dataset = data_folder( 108 | root, train=False, transform=norm_func_test, download=False) 109 | elif name == 'cifar100': 110 | test_dataset = data_folder( 111 | root, train=False, transform=norm_func_test, 112 | download=False, num_super=num_super) 113 | test_dataset.targets = np.array(test_dataset.targets) 114 | 115 | if name == 'cifar10': 116 | test_dataset.targets -= 2 117 | test_dataset.targets[np.where(test_dataset.targets == -2)[0]] = 8 118 | test_dataset.targets[np.where(test_dataset.targets == -1)[0]] = 9 119 | 120 | target_ind = np.where(test_dataset.targets >= args.num_classes)[0] 121 | test_dataset.targets[target_ind] = args.num_classes 122 | 123 | 124 | unique_labeled = np.unique(train_labeled_idxs) 125 | val_labeled = np.unique(val_idxs) 126 | logger.info("Dataset: %s"%name) 127 | logger.info(f"Labeled examples: {len(unique_labeled)} " 128 | f"Unlabeled examples: {len(train_unlabeled_idxs)} " 129 | f"Validation samples: {len(val_labeled)}") 130 | return train_labeled_dataset, train_unlabeled_dataset, \ 131 | test_dataset, val_dataset 132 | 133 | 134 | 135 | def
get_imagenet(args, norm=True): 136 | mean = normal_mean 137 | std = normal_std 138 | txt_labeled = "filelist/imagenet_train_labeled.txt" 139 | txt_unlabeled = "filelist/imagenet_train_unlabeled.txt" 140 | txt_val = "filelist/imagenet_val.txt" 141 | txt_test = "filelist/imagenet_test.txt" 142 | ## This function will be overwritten in trainer.py 143 | norm_func = TransformFixMatch_Imagenet(mean=mean, std=std, 144 | norm=norm, size_image=224) 145 | dataset_labeled = ImageFolder(txt_labeled, transform=norm_func) 146 | dataset_unlabeled = ImageFolder_fix(txt_unlabeled, transform=norm_func) 147 | 148 | test_transform = transforms.Compose([ 149 | transforms.Resize(256), 150 | transforms.CenterCrop(224), 151 | transforms.ToTensor(), 152 | transforms.Normalize(mean=mean, std=std) 153 | ]) 154 | dataset_val = ImageFolder(txt_val, transform=test_transform) 155 | dataset_test = ImageFolder(txt_test, transform=test_transform) 156 | logger.info(f"Labeled examples: {len(dataset_labeled)} " 157 | f"Unlabeled examples: {len(dataset_unlabeled)} " 158 | f"Validation samples: {len(dataset_val)}") 159 | return dataset_labeled, dataset_unlabeled, dataset_test, dataset_val 160 | 161 | 162 | def x_u_split(args, labels): 163 | label_per_class = args.num_labeled #// args.num_classes 164 | val_per_class = args.num_val #// args.num_classes 165 | labels = np.array(labels) 166 | labeled_idx = [] 167 | val_idx = [] 168 | unlabeled_idx = [] 169 | # unlabeled data: all data (https://github.com/kekmodel/FixMatch-pytorch/issues/10) 170 | for i in range(args.num_classes): 171 | idx = np.where(labels == i)[0] 172 | unlabeled_idx.extend(idx) 173 | idx = np.random.choice(idx, label_per_class+val_per_class, False) 174 | labeled_idx.extend(idx[:label_per_class]) 175 | val_idx.extend(idx[label_per_class:]) 176 | 177 | labeled_idx = np.array(labeled_idx) 178 | 179 | assert len(labeled_idx) == args.num_labeled * args.num_classes 180 | if args.expand_labels or args.num_labeled < args.batch_size: 181 |
num_expand_x = math.ceil( 182 | args.batch_size * args.eval_step / args.num_labeled) 183 | labeled_idx = np.hstack([labeled_idx for _ in range(num_expand_x)]) 184 | np.random.shuffle(labeled_idx) 185 | 186 | #if not args.no_out: 187 | # Set lookups keep this linear; scanning the (possibly expanded) labeled array for every index is very slow. 188 | excluded_idx = set(labeled_idx.tolist()) | set(val_idx) 189 | unlabeled_idx = [idx for idx in range(len(labels)) if idx not in excluded_idx] 190 | return labeled_idx, unlabeled_idx, val_idx 191 | 192 | 193 | class TransformFixMatch(object): 194 | def __init__(self, mean, std, norm=True, size_image=32): 195 | self.weak = transforms.Compose([ 196 | transforms.RandomHorizontalFlip(), 197 | transforms.RandomCrop(size=size_image, 198 | padding=int(size_image*0.125), 199 | padding_mode='reflect')]) 200 | self.weak2 = transforms.Compose([ 201 | transforms.RandomHorizontalFlip(),]) 202 | self.strong = transforms.Compose([ 203 | transforms.RandomHorizontalFlip(), 204 | transforms.RandomCrop(size=size_image, 205 | padding=int(size_image*0.125), 206 | padding_mode='reflect'), 207 | RandAugmentMC(n=2, m=10)]) 208 | self.normalize = transforms.Compose([ 209 | transforms.ToTensor(), 210 | transforms.Normalize(mean=mean, std=std)]) 211 | self.norm = norm 212 | 213 | def __call__(self, x): 214 | weak = self.weak(x) 215 | strong = self.strong(x) 216 | if self.norm: 217 | return self.normalize(weak), self.normalize(strong), self.normalize(self.weak2(x)) 218 | else: 219 | return weak, strong 220 | 221 | class TransformOpenMatch(object): 222 | def __init__(self, mean, std, norm=True, size_image=32): 223 | self.weak = transforms.Compose([ 224 | transforms.RandomHorizontalFlip(), 225 | transforms.RandomCrop(size=size_image, 226 | padding=int(size_image*0.125), 227 | padding_mode='reflect')]) 228 | self.weak2 = transforms.Compose([ 229 | transforms.RandomHorizontalFlip(),]) 230 | self.normalize = transforms.Compose([ 231 | transforms.ToTensor(), 232 | transforms.Normalize(mean=mean, std=std)]) 233 |
self.norm = norm 234 | 235 | def __call__(self, x): 236 | weak = self.weak(x) 237 | strong = self.weak(x) 238 | 239 | if self.norm: 240 | return self.normalize(weak), self.normalize(strong), self.normalize(self.weak2(x)) 241 | else: 242 | return weak, strong 243 | 244 | 245 | 246 | 247 | class TransformFixMatch_Imagenet(object): 248 | def __init__(self, mean, std, norm=True, size_image=224): 249 | self.weak = transforms.Compose([ 250 | transforms.Resize((256, 256)), 251 | transforms.RandomHorizontalFlip(), 252 | transforms.RandomCrop(size=size_image, 253 | padding=int(size_image*0.125), 254 | padding_mode='reflect')]) 255 | self.weak2 = transforms.Compose([ 256 | transforms.Resize((256, 256)), 257 | transforms.RandomHorizontalFlip(), 258 | transforms.CenterCrop(size=size_image), 259 | ]) 260 | self.strong = transforms.Compose([ 261 | transforms.Resize((256, 256)), 262 | transforms.RandomHorizontalFlip(), 263 | transforms.RandomCrop(size=size_image, 264 | padding=int(size_image*0.125), 265 | padding_mode='reflect'), 266 | RandAugmentMC(n=2, m=10)]) 267 | self.normalize = transforms.Compose([ 268 | transforms.ToTensor(), 269 | transforms.Normalize(mean=mean, std=std)]) 270 | self.norm = norm 271 | 272 | def __call__(self, x): 273 | weak = self.weak(x) 274 | weak2 = self.weak2(x) 275 | strong = self.strong(x) 276 | if self.norm: 277 | return self.normalize(weak), self.normalize(strong), self.normalize(weak2) 278 | else: 279 | return weak, strong 280 | 281 | 282 | 283 | class TransformFixMatch_Imagenet_Weak(object): 284 | def __init__(self, mean, std, norm=True, size_image=224): 285 | self.weak = transforms.Compose([ 286 | transforms.Resize((256, 256)), 287 | transforms.RandomHorizontalFlip(), 288 | transforms.RandomCrop(size=size_image, 289 | padding=int(size_image*0.125), 290 | padding_mode='reflect')]) 291 | self.weak2 = transforms.Compose([ 292 | transforms.Resize((256, 256)), 293 | transforms.RandomHorizontalFlip(), 294 | transforms.CenterCrop(size=size_image), 295 |
]) 296 | self.strong = transforms.Compose([ 297 | transforms.Resize((256, 256)), 298 | transforms.RandomHorizontalFlip(), 299 | transforms.RandomCrop(size=size_image, 300 | padding=int(size_image*0.125), 301 | padding_mode='reflect'), 302 | RandAugmentMC(n=2, m=10)]) 303 | self.normalize = transforms.Compose([ 304 | transforms.ToTensor(), 305 | transforms.Normalize(mean=mean, std=std)]) 306 | self.norm = norm 307 | 308 | def __call__(self, x): 309 | weak = self.weak2(x) 310 | weak2 = self.weak2(x) 311 | strong = self.strong(x) 312 | if self.norm: 313 | return self.normalize(weak), self.normalize(strong), self.normalize(weak2) 314 | else: 315 | return weak, strong 316 | 317 | 318 | 319 | 320 | class CIFAR10SSL(datasets.CIFAR10): 321 | def __init__(self, root, indexs, train=True, 322 | transform=None, target_transform=None, 323 | download=False, return_idx=False): 324 | super().__init__(root, train=train, 325 | transform=transform, 326 | target_transform=target_transform, 327 | download=download) 328 | if indexs is not None: 329 | self.data = self.data[indexs] 330 | self.targets = np.array(self.targets)[indexs] 331 | self.return_idx = return_idx 332 | self.set_index() 333 | 334 | def set_index(self, indexes=None): 335 | if indexes is not None: 336 | self.data_index = self.data[indexes] 337 | self.targets_index = self.targets[indexes] 338 | else: 339 | self.data_index = self.data 340 | self.targets_index = self.targets 341 | 342 | def init_index(self): 343 | self.data_index = self.data 344 | self.targets_index = self.targets 345 | 346 | def __getitem__(self, index): 347 | img, target = self.data_index[index], self.targets_index[index] 348 | img = Image.fromarray(img) 349 | 350 | if self.transform is not None: 351 | img = self.transform(img) 352 | 353 | if self.target_transform is not None: 354 | target = self.target_transform(target) 355 | 356 | if not self.return_idx: 357 | return img, target 358 | else: 359 | return img, target, index 360 | 361 | def __len__(self):
362 | return len(self.data_index) 363 | 364 | 365 | 366 | 367 | 368 | 369 | class CIFAR100FIX(datasets.CIFAR100): 370 | def __init__(self, root, num_super=10, train=True, transform=None, 371 | target_transform=None, download=False, return_idx=False): 372 | super().__init__(root, train=train, transform=transform, 373 | target_transform=target_transform, download=download) 374 | 375 | coarse_labels = np.array([4, 1, 14, 8, 0, 6, 7, 7, 18, 3, 376 | 3, 14, 9, 18, 7, 11, 3, 9, 7, 11, 377 | 6, 11, 5, 10, 7, 6, 13, 15, 3, 15, 378 | 0, 11, 1, 10, 12, 14, 16, 9, 11, 5, 379 | 5, 19, 8, 8, 15, 13, 14, 17, 18, 10, 380 | 16, 4, 17, 4, 2, 0, 17, 4, 18, 17, 381 | 10, 3, 2, 12, 12, 16, 12, 1, 9, 19, 382 | 2, 10, 0, 1, 16, 12, 9, 13, 15, 13, 383 | 16, 19, 2, 4, 6, 19, 5, 5, 8, 19, 384 | 18, 1, 2, 15, 6, 0, 17, 8, 14, 13]) 385 | self.course_labels = coarse_labels[self.targets] 386 | self.targets = np.array(self.targets) 387 | labels_unknown = self.targets[np.where(self.course_labels > num_super)[0]] 388 | labels_known = self.targets[np.where(self.course_labels <= num_super)[0]] 389 | unknown_categories = np.unique(labels_unknown) 390 | known_categories = np.unique(labels_known) 391 | 392 | num_unknown = len(unknown_categories) 393 | num_known = len(known_categories) 394 | print("number of unknown categories %s"%num_unknown) 395 | print("number of known categories %s"%num_known) 396 | assert num_known + num_unknown == 100 397 | #new_category_labels = list(range(num_known)) 398 | self.targets_new = np.zeros_like(self.targets) 399 | for i, known in enumerate(known_categories): 400 | ind_known = np.where(self.targets==known)[0] 401 | self.targets_new[ind_known] = i 402 | for i, unknown in enumerate(unknown_categories): 403 | ind_unknown = np.where(self.targets == unknown)[0] 404 | self.targets_new[ind_unknown] = num_known 405 | 406 | self.targets = self.targets_new 407 | assert len(np.where(self.targets == num_known)[0]) == len(labels_unknown) 408 | assert len(np.where(self.targets < 
num_known)[0]) == len(labels_known) 409 | self.num_known_class = num_known 410 | 411 | 412 | def __getitem__(self, index): 413 | 414 | img, target = self.data[index], self.targets[index] 415 | img = Image.fromarray(img) 416 | 417 | if self.transform is not None: 418 | img = self.transform(img) 419 | 420 | if self.target_transform is not None: 421 | target = self.target_transform(target) 422 | 423 | return img, target 424 | 425 | 426 | class CIFAR100SSL(CIFAR100FIX): 427 | def __init__(self, root, indexs, num_super=10, train=True, 428 | transform=None, target_transform=None, 429 | download=False, return_idx=False): 430 | super().__init__(root, num_super=num_super,train=train, 431 | transform=transform, 432 | target_transform=target_transform, 433 | download=download) 434 | self.return_idx = return_idx 435 | if indexs is not None: 436 | self.data = self.data[indexs] 437 | self.targets = np.array(self.targets)[indexs] 438 | 439 | self.set_index() 440 | def set_index(self, indexes=None): 441 | if indexes is not None: 442 | self.data_index = self.data[indexes] 443 | self.targets_index = self.targets[indexes] 444 | else: 445 | self.data_index = self.data 446 | self.targets_index = self.targets 447 | 448 | def init_index(self): 449 | self.data_index = self.data 450 | self.targets_index = self.targets 451 | 452 | 453 | def __getitem__(self, index): 454 | img, target = self.data_index[index], self.targets_index[index] 455 | img = Image.fromarray(img) 456 | 457 | if self.transform is not None: 458 | img = self.transform(img) 459 | 460 | if self.target_transform is not None: 461 | target = self.target_transform(target) 462 | if not self.return_idx: 463 | return img, target 464 | else: 465 | return img, target, index 466 | 467 | def __len__(self): 468 | return len(self.data_index) 469 | 470 | def get_transform(mean, std, image_size=None): 471 | # Note: data augmentation is implemented in the layers 472 | # Hence, we only define the identity transformation here 473 | if 
image_size: # use pre-specified image size 474 | train_transform = transforms.Compose([ 475 | transforms.Resize((image_size[0], image_size[1])), 476 | transforms.RandomHorizontalFlip(), 477 | transforms.ToTensor(), 478 | transforms.Normalize(mean=mean, std=std), 479 | ]) 480 | test_transform = transforms.Compose([ 481 | transforms.Resize((image_size[0], image_size[1])), 482 | transforms.ToTensor(), 483 | transforms.Normalize(mean=mean, std=std), 484 | ]) 485 | else: # use default image size 486 | train_transform = transforms.Compose([ 487 | transforms.ToTensor(), 488 | transforms.Normalize(mean=mean, std=std), 489 | ]) 490 | test_transform = transforms.ToTensor() 491 | 492 | return train_transform, test_transform 493 | 494 | 495 | def get_ood(dataset, id, test_only=False, image_size=None): 496 | image_size = (32, 32, 3) if image_size is None else image_size 497 | if id == "cifar10": 498 | mean = cifar10_mean 499 | std = cifar10_std 500 | elif id == "cifar100": 501 | mean = cifar100_mean 502 | std = cifar100_std 503 | elif "imagenet" in id or id == "tiny": 504 | mean = normal_mean 505 | std = normal_std 506 | 507 | _, test_transform = get_transform(mean, std, image_size=image_size) 508 | 509 | if dataset == 'cifar10': 510 | test_set = datasets.CIFAR10(DATA_PATH, train=False, download=False, 511 | transform=test_transform) 512 | 513 | elif dataset == 'cifar100': 514 | test_set = datasets.CIFAR100(DATA_PATH, train=False, download=False, 515 | transform=test_transform) 516 | 517 | elif dataset == 'svhn': 518 | test_set = datasets.SVHN(DATA_PATH, split='test', download=True, 519 | transform=test_transform) 520 | 521 | elif dataset == 'lsun': 522 | test_dir = os.path.join(DATA_PATH, 'LSUN_fix') 523 | test_set = datasets.ImageFolder(test_dir, transform=test_transform) 524 | 525 | elif dataset == 'imagenet': 526 | test_dir = os.path.join(DATA_PATH, 'Imagenet_fix') 527 | test_set = datasets.ImageFolder(test_dir, transform=test_transform) 528 | elif dataset == 
'stanford_dogs': 529 | test_dir = os.path.join(DATA_PATH, 'stanford_dogs') 530 | test_set = datasets.ImageFolder(test_dir, transform=test_transform) 531 | 532 | elif dataset == 'cub': 533 | test_dir = os.path.join(DATA_PATH, 'cub') 534 | test_set = datasets.ImageFolder(test_dir, transform=test_transform) 535 | 536 | elif dataset == 'flowers102': 537 | test_dir = os.path.join(DATA_PATH, 'flowers102') 538 | test_set = datasets.ImageFolder(test_dir, transform=test_transform) 539 | 540 | elif dataset == 'food_101': 541 | test_dir = os.path.join(DATA_PATH, 'food-101', 'images') 542 | test_set = datasets.ImageFolder(test_dir, transform=test_transform) 543 | 544 | elif dataset == 'caltech_256': 545 | test_dir = os.path.join(DATA_PATH, 'caltech-256') 546 | test_set = datasets.ImageFolder(test_dir, transform=test_transform) 547 | 548 | elif dataset == 'dtd': 549 | test_dir = os.path.join(DATA_PATH, 'dtd') 550 | test_set = datasets.ImageFolder(test_dir, transform=test_transform) 551 | 552 | elif dataset == 'pets': 553 | test_dir = os.path.join(DATA_PATH, 'pets') 554 | test_set = datasets.ImageFolder(test_dir, transform=test_transform) 555 | 556 | return test_set 557 | 558 | DATASET_GETTERS = {'cifar10': get_cifar, 559 | 'cifar100': get_cifar, 560 | 'imagenet': get_imagenet, 561 | } 562 | -------------------------------------------------------------------------------- /dataset/mydataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 10 | ] 11 | 12 | 13 | def find_classes(dir): 14 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 15 | classes.sort() 16 | class_to_idx = {classes[i]: i for i in range(len(classes))} 17 | return classes, class_to_idx 18 | 19 | 20 | def 
is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir, class_to_idx): 25 | images = [] 26 | dir = os.path.expanduser(dir) 27 | for target in os.listdir(dir): 28 | d = os.path.join(dir, target) 29 | if not os.path.isdir(d): 30 | continue 31 | 32 | for root, _, fnames in sorted(os.walk(d)): 33 | for fname in fnames: 34 | if is_image_file(fname): 35 | path = os.path.join(root, fname) 36 | item = (path, class_to_idx[target]) 37 | images.append(item) 38 | 39 | return images 40 | 41 | 42 | def default_flist_reader(flist): 43 | """ 44 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 45 | """ 46 | imlist = [] 47 | with open(flist, 'r') as rf: 48 | for line in rf.readlines(): 49 | impath, imlabel = line.strip().split() 50 | imlist.append((impath, int(imlabel))) 51 | 52 | return imlist 53 | 54 | 55 | def default_loader(path): 56 | return Image.open(path).convert('RGB') 57 | 58 | 59 | def make_dataset_nolist(image_list): 60 | with open(image_list) as f: 61 | image_index = [x.split(' ')[0] for x in f.readlines()] 62 | with open(image_list) as f: 63 | label_list = [] 64 | selected_list = [] 65 | for ind, x in enumerate(f.readlines()): 66 | label = x.split(' ')[1].strip() 67 | label_list.append(int(label)) 68 | selected_list.append(ind) 69 | image_index = np.array(image_index) 70 | label_list = np.array(label_list) 71 | image_index = image_index[selected_list] 72 | return image_index, label_list 73 | 74 | 75 | class ImageFolder(data.Dataset): 76 | """A generic data loader where the images are arranged in this way: :: 77 | root/dog/xxx.png 78 | root/dog/xxy.png 79 | root/dog/xxz.png 80 | root/cat/123.png 81 | root/cat/nsdf3.png 82 | root/cat/asd932_.png 83 | Args: 84 | root (string): Root directory path. 85 | transform (callable, optional): A function/transform that takes in an PIL image 86 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 87 | target_transform (callable, optional): A function/transform that takes in the 88 | target and transforms it. 89 | loader (callable, optional): A function to load an image given its path. 90 | Attributes: 91 | classes (list): List of the class names. 92 | class_to_idx (dict): Dict with items (class_name, class_index). 93 | imgs (list): List of (image path, class_index) tuples 94 | """ 95 | 96 | def __init__(self, image_list, transform=None, target_transform=None, return_paths=False, 97 | loader=default_loader,train=False, return_id=False): 98 | imgs, labels = make_dataset_nolist(image_list) 99 | self.imgs = imgs 100 | self.labels= labels 101 | self.transform = transform 102 | self.target_transform = target_transform 103 | self.loader = loader 104 | self.return_paths = return_paths 105 | self.return_id = return_id 106 | self.train = train 107 | 108 | def __getitem__(self, index): 109 | """ 110 | Args: 111 | index (int): Index 112 | Returns: 113 | tuple: (image, target) where target is class_index of the target class. 114 | """ 115 | 116 | path = self.imgs[index] 117 | target = self.labels[index] 118 | img = self.loader(path) 119 | img = self.transform(img) 120 | 121 | if self.target_transform is not None: 122 | target = self.target_transform(target) 123 | if self.return_paths: 124 | return img, target, path 125 | elif self.return_id: 126 | return img, target ,index 127 | else: 128 | return img, target 129 | 130 | def __len__(self): 131 | return len(self.imgs) 132 | 133 | 134 | class ImageFolder_fix(data.Dataset): 135 | """A generic data loader where the images are arranged in this way: :: 136 | root/dog/xxx.png 137 | root/dog/xxy.png 138 | root/dog/xxz.png 139 | root/cat/123.png 140 | root/cat/nsdf3.png 141 | root/cat/asd932_.png 142 | Args: 143 | root (string): Root directory path. 144 | transform (callable, optional): A function/transform that takes in an PIL image 145 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 146 | target_transform (callable, optional): A function/transform that takes in the 147 | target and transforms it. 148 | loader (callable, optional): A function to load an image given its path. 149 | Attributes: 150 | classes (list): List of the class names. 151 | class_to_idx (dict): Dict with items (class_name, class_index). 152 | imgs (list): List of (image path, class_index) tuples 153 | """ 154 | 155 | def __init__(self, image_list, transform=None, target_transform=None, return_paths=False, 156 | loader=default_loader,train=False, return_id=False): 157 | imgs, labels = make_dataset_nolist(image_list) 158 | self.imgs = imgs 159 | self.labels= labels 160 | self.transform = transform 161 | self.target_transform = target_transform 162 | self.loader = loader 163 | self.return_paths = return_paths 164 | self.return_id = return_id 165 | self.train = train 166 | self.set_index() 167 | 168 | def set_index(self, indexes=None): 169 | if indexes is not None: 170 | self.imgs_index = self.imgs[indexes] 171 | self.targets_index = self.labels[indexes] 172 | else: 173 | self.imgs_index = self.imgs 174 | self.targets_index = self.labels 175 | 176 | def init_index(self): 177 | self.imgs_index = self.imgs 178 | self.targets_index = self.labels 179 | 180 | 181 | 182 | def __getitem__(self, index): 183 | """ 184 | Args: 185 | index (int): Index 186 | Returns: 187 | tuple: (image, target) where target is class_index of the target class. 
188 | """ 189 | 190 | path = self.imgs_index[index] 191 | target = self.targets_index[index] 192 | img = self.loader(path) 193 | img = self.transform(img) 194 | 195 | if self.target_transform is not None: 196 | target = self.target_transform(target) 197 | if self.return_paths: 198 | return img, target, path 199 | elif self.return_id: 200 | return img, target, index 201 | else: 202 | return img, target 203 | 204 | def __len__(self): 205 | return len(self.imgs_index) 206 | 207 | -------------------------------------------------------------------------------- /dataset/randaugment.py: -------------------------------------------------------------------------------- 1 | # Code in this file is adapted from 2 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 3 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py 4 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py 5 | import logging 6 | import random 7 | 8 | import numpy as np 9 | import PIL 10 | import PIL.ImageOps 11 | import PIL.ImageEnhance 12 | import PIL.ImageDraw 13 | from PIL import Image 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | PARAMETER_MAX = 10 18 | 19 | 20 | def AutoContrast(img, **kwarg): 21 | return PIL.ImageOps.autocontrast(img) 22 | 23 | 24 | def Brightness(img, v, max_v, bias=0): 25 | v = _float_parameter(v, max_v) + bias 26 | return PIL.ImageEnhance.Brightness(img).enhance(v) 27 | 28 | 29 | def Color(img, v, max_v, bias=0): 30 | v = _float_parameter(v, max_v) + bias 31 | return PIL.ImageEnhance.Color(img).enhance(v) 32 | 33 | 34 | def Contrast(img, v, max_v, bias=0): 35 | v = _float_parameter(v, max_v) + bias 36 | return PIL.ImageEnhance.Contrast(img).enhance(v) 37 | 38 | 39 | def Cutout(img, v, max_v, bias=0): 40 | if v == 0: 41 | return img 42 | v = _float_parameter(v, max_v) + bias 43 | v = int(v * min(img.size)) 44 | return CutoutAbs(img, v) 45 | 46 | 47 | def
CutoutAbs(img, v, **kwarg): 48 | w, h = img.size 49 | x0 = np.random.uniform(0, w) 50 | y0 = np.random.uniform(0, h) 51 | x0 = int(max(0, x0 - v / 2.)) 52 | y0 = int(max(0, y0 - v / 2.)) 53 | x1 = int(min(w, x0 + v)) 54 | y1 = int(min(h, y0 + v)) 55 | xy = (x0, y0, x1, y1) 56 | # gray 57 | color = (127, 127, 127) 58 | img = img.copy() 59 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 60 | return img 61 | 62 | 63 | def Equalize(img, **kwarg): 64 | return PIL.ImageOps.equalize(img) 65 | 66 | 67 | def Identity(img, **kwarg): 68 | return img 69 | 70 | 71 | def Invert(img, **kwarg): 72 | return PIL.ImageOps.invert(img) 73 | 74 | 75 | def Posterize(img, v, max_v, bias=0): 76 | v = _int_parameter(v, max_v) + bias 77 | return PIL.ImageOps.posterize(img, v) 78 | 79 | 80 | def Rotate(img, v, max_v, bias=0): 81 | v = _int_parameter(v, max_v) + bias 82 | if random.random() < 0.5: 83 | v = -v 84 | return img.rotate(v) 85 | 86 | 87 | def Sharpness(img, v, max_v, bias=0): 88 | v = _float_parameter(v, max_v) + bias 89 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 90 | 91 | 92 | def ShearX(img, v, max_v, bias=0): 93 | v = _float_parameter(v, max_v) + bias 94 | if random.random() < 0.5: 95 | v = -v 96 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 97 | 98 | 99 | def ShearY(img, v, max_v, bias=0): 100 | v = _float_parameter(v, max_v) + bias 101 | if random.random() < 0.5: 102 | v = -v 103 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 104 | 105 | 106 | def Solarize(img, v, max_v, bias=0): 107 | v = _int_parameter(v, max_v) + bias 108 | return PIL.ImageOps.solarize(img, 256 - v) 109 | 110 | 111 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128): 112 | v = _int_parameter(v, max_v) + bias 113 | if random.random() < 0.5: 114 | v = -v 115 | img_np = np.array(img).astype(int) 116 | img_np = img_np + v 117 | img_np = np.clip(img_np, 0, 255) 118 | img_np = img_np.astype(np.uint8) 119 | img = Image.fromarray(img_np) 120 | return
PIL.ImageOps.solarize(img, threshold) 121 | 122 | 123 | def TranslateX(img, v, max_v, bias=0): 124 | v = _float_parameter(v, max_v) + bias 125 | if random.random() < 0.5: 126 | v = -v 127 | v = int(v * img.size[0]) 128 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 129 | 130 | 131 | def TranslateY(img, v, max_v, bias=0): 132 | v = _float_parameter(v, max_v) + bias 133 | if random.random() < 0.5: 134 | v = -v 135 | v = int(v * img.size[1]) 136 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 137 | 138 | 139 | def _float_parameter(v, max_v): 140 | return float(v) * max_v / PARAMETER_MAX 141 | 142 | 143 | def _int_parameter(v, max_v): 144 | return int(v * max_v / PARAMETER_MAX) 145 | 146 | 147 | def fixmatch_augment_pool(): 148 | # FixMatch paper 149 | augs = [(AutoContrast, None, None), 150 | (Brightness, 0.9, 0.05), 151 | (Color, 0.9, 0.05), 152 | (Contrast, 0.9, 0.05), 153 | (Equalize, None, None), 154 | (Identity, None, None), 155 | (Posterize, 4, 4), 156 | (Rotate, 30, 0), 157 | (Sharpness, 0.9, 0.05), 158 | (ShearX, 0.3, 0), 159 | (ShearY, 0.3, 0), 160 | (Solarize, 256, 0), 161 | (TranslateX, 0.3, 0), 162 | (TranslateY, 0.3, 0)] 163 | return augs 164 | 165 | 166 | def my_augment_pool(): 167 | # Test 168 | augs = [(AutoContrast, None, None), 169 | (Brightness, 1.8, 0.1), 170 | (Color, 1.8, 0.1), 171 | (Contrast, 1.8, 0.1), 172 | (Cutout, 0.2, 0), 173 | (Equalize, None, None), 174 | (Invert, None, None), 175 | (Posterize, 4, 4), 176 | (Rotate, 30, 0), 177 | (Sharpness, 1.8, 0.1), 178 | (ShearX, 0.3, 0), 179 | (ShearY, 0.3, 0), 180 | (Solarize, 256, 0), 181 | (SolarizeAdd, 110, 0), 182 | (TranslateX, 0.45, 0), 183 | (TranslateY, 0.45, 0)] 184 | return augs 185 | 186 | 187 | class RandAugmentPC(object): 188 | def __init__(self, n, m): 189 | assert n >= 1 190 | assert 1 <= m <= 10 191 | self.n = n 192 | self.m = m 193 | self.augment_pool = my_augment_pool() 194 | 195 | def __call__(self, img): 196 | ops = 
random.choices(self.augment_pool, k=self.n) 197 | for op, max_v, bias in ops: 198 | prob = np.random.uniform(0.2, 0.8) 199 | if random.random() + prob >= 1: 200 | img = op(img, v=self.m, max_v=max_v, bias=bias) 201 | img = CutoutAbs(img, int(32*0.5)) 202 | return img 203 | 204 | 205 | class RandAugmentMC(object): 206 | def __init__(self, n, m): 207 | assert n >= 1 208 | assert 1 <= m <= 10 209 | self.n = n 210 | self.m = m 211 | self.augment_pool = fixmatch_augment_pool() 212 | 213 | def __call__(self, img): 214 | ops = random.choices(self.augment_pool, k=self.n) 215 | for op, max_v, bias in ops: 216 | v = np.random.randint(1, self.m) 217 | if random.random() < 0.5: 218 | img = op(img, v=v, max_v=max_v, bias=bias) 219 | img = CutoutAbs(img, int(32*0.5)) 220 | return img 221 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from utils import test, test_ood 3 | 4 | logger = logging.getLogger(__name__) 5 | best_acc = 0 6 | best_acc_val = 0 7 | def eval_model(args, labeled_trainloader, unlabeled_dataset, test_loader, val_loader, 8 | ood_loaders, model, ema_model): 9 | if args.amp: 10 | from apex import amp 11 | global best_acc 12 | global best_acc_val 13 | 14 | model.eval() 15 | if args.use_ema: 16 | test_model = ema_model.ema 17 | else: 18 | test_model = model 19 | epoch = 0 20 | if args.local_rank in [-1, 0]: 21 | val_acc = test(args, val_loader, test_model, epoch, val=True) 22 | test_loss, close_valid, test_overall, \ 23 | test_unk, test_roc, test_roc_softm, test_id \ 24 | = test(args, test_loader, test_model, epoch) 25 | for ood in ood_loaders.keys(): 26 | roc_ood = test_ood(args, test_id, ood_loaders[ood], test_model) 27 | logger.info("ROC vs {ood}: {roc}".format(ood=ood, roc=roc_ood)) 28 | 29 | overall_valid = test_overall 30 | unk_valid = test_unk 31 | roc_valid = test_roc 32 | roc_softm_valid = test_roc_softm 33 | 
logger.info('validation closed acc: {:.3f}'.format(val_acc)) 34 | logger.info('test closed acc: {:.3f}'.format(close_valid)) 35 | logger.info('test overall acc: {:.3f}'.format(overall_valid)) 36 | logger.info('test unk acc: {:.3f}'.format(unk_valid)) 37 | logger.info('test roc: {:.3f}'.format(roc_valid)) 38 | logger.info('test roc soft: {:.3f}'.format(roc_softm_valid)) 39 | if args.local_rank in [-1, 0]: 40 | args.writer.close() 41 | -------------------------------------------------------------------------------- /files.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisionLearningGroup/OP_Match/ba1a59cf42ad8c2920cba428991a6cc717901d52/files.zip -------------------------------------------------------------------------------- /images/consistency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisionLearningGroup/OP_Match/ba1a59cf42ad8c2920cba428991a6cc717901d52/images/consistency.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torch.utils.tensorboard import SummaryWriter 5 | from utils import set_model_config, \ 6 | set_dataset, set_models, set_parser, \ 7 | set_seed 8 | from eval import eval_model 9 | from trainer import train 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def main(): 15 | args = set_parser() 16 | global best_acc 17 | global best_acc_val 18 | 19 | if args.local_rank == -1: 20 | device = torch.device('cuda', args.gpu_id) 21 | args.world_size = 1 22 | args.n_gpu = torch.cuda.device_count() 23 | else: 24 | torch.cuda.set_device(args.local_rank) 25 | device = torch.device('cuda', args.local_rank) 26 | torch.distributed.init_process_group(backend='nccl') 27 | args.world_size = torch.distributed.get_world_size() 28 
| args.n_gpu = 1 29 | args.device = device 30 | logging.basicConfig( 31 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 32 | datefmt="%m/%d/%Y %H:%M:%S", 33 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 34 | logger.warning( 35 | f"Process rank: {args.local_rank}, " 36 | f"device: {args.device}, " 37 | f"n_gpu: {args.n_gpu}, " 38 | f"distributed training: {bool(args.local_rank != -1)}, " 39 | f"16-bits training: {args.amp}",) 40 | logger.info(dict(args._get_kwargs())) 41 | if args.seed is not None: 42 | set_seed(args) 43 | if args.local_rank in [-1, 0]: 44 | os.makedirs(args.out, exist_ok=True) 45 | args.writer = SummaryWriter(args.out) 46 | set_model_config(args) 47 | 48 | if args.local_rank not in [-1, 0]: 49 | torch.distributed.barrier() 50 | 51 | labeled_trainloader, unlabeled_dataset, test_loader, val_loader, ood_loaders \ 52 | = set_dataset(args) 53 | 54 | model, optimizer, scheduler = set_models(args) 55 | logger.info("Total params: {:.2f}M".format( 56 | sum(p.numel() for p in model.parameters()) / 1e6)) 57 | 58 | if args.use_ema: 59 | from models.ema import ModelEMA 60 | ema_model = ModelEMA(args, model, args.ema_decay) 61 | args.start_epoch = 0 62 | if args.resume: 63 | logger.info("==> Resuming from checkpoint..") 64 | assert os.path.isfile( 65 | args.resume), "Error: no checkpoint directory found!" 
66 | args.out = os.path.dirname(args.resume) 67 | checkpoint = torch.load(args.resume) 68 | best_acc = checkpoint['best_acc'] 69 | args.start_epoch = checkpoint['epoch'] 70 | model.load_state_dict(checkpoint['state_dict']) 71 | if args.use_ema: 72 | ema_model.ema.load_state_dict(checkpoint['ema_state_dict']) 73 | optimizer.load_state_dict(checkpoint['optimizer']) 74 | scheduler.load_state_dict(checkpoint['scheduler']) 75 | 76 | if args.amp: 77 | from apex import amp 78 | model, optimizer = amp.initialize( 79 | model, optimizer, opt_level=args.opt_level) 80 | 81 | if args.local_rank != -1: 82 | model = torch.nn.parallel.DistributedDataParallel( 83 | model, device_ids=[args.local_rank], 84 | output_device=args.local_rank, find_unused_parameters=True) 85 | 86 | 87 | model.zero_grad() 88 | if not args.eval_only: 89 | logger.info("***** Running training *****") 90 | logger.info(f" Task = {args.dataset}@{args.num_labeled}") 91 | logger.info(f" Num Epochs = {args.epochs}") 92 | logger.info(f" Batch size per GPU = {args.batch_size}") 93 | logger.info(f" Total train batch size = {args.batch_size*args.world_size}") 94 | logger.info(f" Total optimization steps = {args.total_steps}") 95 | train(args, labeled_trainloader, unlabeled_dataset, test_loader, val_loader, 96 | ood_loaders, model, optimizer, ema_model, scheduler) 97 | else: 98 | logger.info("***** Running Evaluation *****") 99 | logger.info(f" Task = {args.dataset}@{args.num_labeled}") 100 | eval_model(args, labeled_trainloader, unlabeled_dataset, test_loader, val_loader, 101 | ood_loaders, model, ema_model) 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import torch 4 | 5 | 6 | class ModelEMA(object): 7 | def __init__(self, args, model, decay): 8 | self.ema = deepcopy(model) 9 | 
self.ema.to(args.device) 10 | self.ema.eval() 11 | self.decay = decay 12 | self.ema_has_module = hasattr(self.ema, 'module') 13 | self.param_keys = [k for k, _ in self.ema.named_parameters()] 14 | self.buffer_keys = [k for k, _ in self.ema.named_buffers()] 15 | for p in self.ema.parameters(): 16 | p.requires_grad_(False) 17 | 18 | def update(self, model): 19 | needs_module = hasattr(model, 'module') and not self.ema_has_module 20 | with torch.no_grad(): 21 | msd = model.state_dict() 22 | esd = self.ema.state_dict() 23 | for k in self.param_keys: 24 | if needs_module: 25 | j = 'module.' + k 26 | else: 27 | j = k 28 | model_v = msd[j].detach() 29 | ema_v = esd[k] 30 | esd[k].copy_(ema_v * self.decay + (1. - self.decay) * model_v) 31 | 32 | for k in self.buffer_keys: 33 | if needs_module: 34 | j = 'module.' + k 35 | else: 36 | j = k 37 | esd[k].copy_(msd[j]) 38 | -------------------------------------------------------------------------------- /models/resnet_imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | 6 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 7 | """3x3 convolution with padding""" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 9 | padding=dilation, groups=groups, bias=False, dilation=dilation) 10 | 11 | 12 | def conv1x1(in_planes, out_planes, stride=1): 13 | """1x1 convolution""" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 15 | 16 | 17 | class BasicBlock(nn.Module): 18 | expansion = 1 19 | 20 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 21 | base_width=64, dilation=1, norm_layer=None): 22 | super(BasicBlock, self).__init__() 23 | if norm_layer is None: 24 | norm_layer = nn.BatchNorm2d 25 | if groups != 1 or base_width != 64: 26 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 27 | if dilation > 1: 28 | raise 
NotImplementedError("Dilation > 1 not supported in BasicBlock") 29 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = norm_layer(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = norm_layer(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | identity = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 59 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 60 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 61 | # This variant is also known as ResNet V1.5 and improves accuracy according to 62 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
63 | 64 | expansion = 4 65 | 66 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 67 | base_width=64, dilation=1, norm_layer=None): 68 | super(Bottleneck, self).__init__() 69 | if norm_layer is None: 70 | norm_layer = nn.BatchNorm2d 71 | width = int(planes * (base_width / 64.)) * groups 72 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 73 | self.conv1 = conv1x1(inplanes, width) 74 | self.bn1 = norm_layer(width) 75 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 76 | self.bn2 = norm_layer(width) 77 | self.conv3 = conv1x1(width, planes * self.expansion) 78 | self.bn3 = norm_layer(planes * self.expansion) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | identity = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | identity = self.downsample(x) 99 | 100 | out += identity 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class ResNet(nn.Module): 107 | def __init__(self, block, layers, num_classes=10, 108 | zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, 109 | norm_layer=None): 110 | super(ResNet, self).__init__() 111 | last_dim = 512 * block.expansion 112 | 113 | if norm_layer is None: 114 | norm_layer = nn.BatchNorm2d 115 | self._norm_layer = norm_layer 116 | 117 | self.inplanes = 64 118 | self.dilation = 1 119 | if replace_stride_with_dilation is None: 120 | # each element in the tuple indicates if we should replace 121 | # the 2x2 stride with a dilated convolution instead 122 | replace_stride_with_dilation = [False, False, False] 123 | if len(replace_stride_with_dilation) != 3: 124 | raise ValueError("replace_stride_with_dilation 
should be None " 125 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 126 | self.groups = groups 127 | self.base_width = width_per_group 128 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 129 | bias=False) 130 | self.bn1 = norm_layer(self.inplanes) 131 | self.relu = nn.ReLU(inplace=True) 132 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 133 | self.layer1 = self._make_layer(block, 64, layers[0]) 134 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 135 | dilate=replace_stride_with_dilation[0]) 136 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 137 | dilate=replace_stride_with_dilation[1]) 138 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 139 | dilate=replace_stride_with_dilation[2]) 140 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 141 | #self.normalize = NormalizeLayer() 142 | self.last_dim = 512 * block.expansion 143 | self.fc = nn.Linear(last_dim, num_classes) 144 | self.fc_open = nn.Linear(last_dim, num_classes * 2, bias=False) 145 | self.simclr_layer = nn.Sequential( 146 | nn.Linear(last_dim, 128), 147 | nn.ReLU(), 148 | nn.Linear(128, 128), 149 | ) 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 154 | nn.init.constant_(m.weight, 1) 155 | nn.init.constant_(m.bias, 0) 156 | 157 | # Zero-initialize the last BN in each residual branch, 158 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
159 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 160 | if zero_init_residual: 161 | for m in self.modules(): 162 | if isinstance(m, Bottleneck): 163 | nn.init.constant_(m.bn3.weight, 0) 164 | elif isinstance(m, BasicBlock): 165 | nn.init.constant_(m.bn2.weight, 0) 166 | 167 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 168 | norm_layer = self._norm_layer 169 | downsample = None 170 | previous_dilation = self.dilation 171 | if dilate: 172 | self.dilation *= stride 173 | stride = 1 174 | if stride != 1 or self.inplanes != planes * block.expansion: 175 | downsample = nn.Sequential( 176 | conv1x1(self.inplanes, planes * block.expansion, stride), 177 | norm_layer(planes * block.expansion), 178 | ) 179 | 180 | layers = [] 181 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 182 | self.base_width, previous_dilation, norm_layer)) 183 | self.inplanes = planes * block.expansion 184 | for _ in range(1, blocks): 185 | layers.append(block(self.inplanes, planes, groups=self.groups, 186 | base_width=self.base_width, dilation=self.dilation, 187 | norm_layer=norm_layer)) 188 | 189 | return nn.Sequential(*layers) 190 | 191 | def forward(self, x, feature=False, feat_only=False): 192 | # See note [TorchScript super()] 193 | 194 | #x = self.normalize(x) 195 | x = self.conv1(x) 196 | x = self.bn1(x) 197 | x = self.relu(x) 198 | x = self.maxpool(x) 199 | x = self.layer1(x) 200 | x = self.layer2(x) 201 | x = self.layer3(x) 202 | x = self.layer4(x) 203 | x = self.avgpool(x) 204 | x = torch.flatten(x, 1) 205 | if feat_only: 206 | return self.simclr_layer(x) 207 | if feature: 208 | return self.fc(x), self.fc_open(x), self.simclr_layer(x) 209 | else: 210 | return self.fc(x), self.fc_open(x) 211 | 212 | 213 | def _resnet(arch, block, layers, **kwargs): 214 | model = ResNet(block, layers, **kwargs) 215 | return model 216 | 217 | 218 | def resnet18(**kwargs): 219 | r"""ResNet-18 model from 220 | 
`"Deep Residual Learning for Image Recognition" `_ 221 | """ 222 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], **kwargs) 223 | 224 | 225 | def resnet50(**kwargs): 226 | r"""ResNet-50 model from 227 | `"Deep Residual Learning for Image Recognition" `_ 228 | """ 229 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], **kwargs) 230 | 231 | 232 | 233 | def _tresnet(arch, block, layers, **kwargs): 234 | model = TResNet(block, layers, **kwargs) 235 | return model 236 | 237 | 238 | def tresnet18(**kwargs): 239 | r"""ResNet-18 model from 240 | `"Deep Residual Learning for Image Recognition" `_ 241 | """ 242 | return _tresnet('resnet18', BasicBlock, [2, 2, 2, 2], **kwargs) 243 | 244 | -------------------------------------------------------------------------------- /models/resnext.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def mish(x): 11 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function (https://arxiv.org/abs/1908.08681)""" 12 | return x * torch.tanh(F.softplus(x)) 13 | 14 | 15 | class PSBatchNorm2d(nn.BatchNorm2d): 16 | """How Does BN Increase Collapsed Neural Network Filters?
(https://arxiv.org/abs/2001.11216)""" 17 | 18 | def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): 19 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 20 | self.alpha = alpha 21 | 22 | def forward(self, x): 23 | return super().forward(x) + self.alpha 24 | 25 | 26 | class ResNeXtBottleneck(nn.Module): 27 | """ 28 | ResNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua) 29 | """ 30 | 31 | def __init__(self, in_channels, out_channels, stride, 32 | cardinality, base_width, widen_factor): 33 | """ Constructor 34 | Args: 35 | in_channels: input channel dimensionality 36 | out_channels: output channel dimensionality 37 | stride: conv stride. Replaces pooling layer. 38 | cardinality: number of convolution groups. 39 | base_width: base number of channels in each group. 40 | widen_factor: factor to reduce the input dimensionality before convolution. 41 | """ 42 | super().__init__() 43 | width_ratio = out_channels / (widen_factor * 64.)
44 | D = cardinality * int(base_width * width_ratio) 45 | self.conv_reduce = nn.Conv2d( 46 | in_channels, D, kernel_size=1, stride=1, padding=0, bias=False) 47 | self.bn_reduce = nn.BatchNorm2d(D, momentum=0.001) 48 | self.conv_conv = nn.Conv2d(D, D, 49 | kernel_size=3, stride=stride, padding=1, 50 | groups=cardinality, bias=False) 51 | self.bn = nn.BatchNorm2d(D, momentum=0.001) 52 | self.act = mish 53 | self.conv_expand = nn.Conv2d( 54 | D, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 55 | self.bn_expand = nn.BatchNorm2d(out_channels, momentum=0.001) 56 | 57 | self.shortcut = nn.Sequential() 58 | if in_channels != out_channels: 59 | self.shortcut.add_module('shortcut_conv', 60 | nn.Conv2d(in_channels, out_channels, 61 | kernel_size=1, 62 | stride=stride, 63 | padding=0, 64 | bias=False)) 65 | self.shortcut.add_module( 66 | 'shortcut_bn', nn.BatchNorm2d(out_channels, momentum=0.001)) 67 | 68 | def forward(self, x): 69 | bottleneck = self.conv_reduce.forward(x) 70 | bottleneck = self.act(self.bn_reduce.forward(bottleneck)) 71 | bottleneck = self.conv_conv.forward(bottleneck) 72 | bottleneck = self.act(self.bn.forward(bottleneck)) 73 | bottleneck = self.conv_expand.forward(bottleneck) 74 | bottleneck = self.bn_expand.forward(bottleneck) 75 | residual = self.shortcut.forward(x) 76 | return self.act(residual + bottleneck) 77 | 78 | 79 | class CifarResNeXt(nn.Module): 80 | """ 81 | ResNeXt optimized for the CIFAR dataset, as specified in 82 | https://arxiv.org/pdf/1611.05431.pdf 83 | """ 84 | 85 | def __init__(self, cardinality, depth, num_classes, 86 | base_width, widen_factor=4): 87 | """ Constructor 88 | Args: 89 | cardinality: number of convolution groups. 90 | depth: number of layers. 91 | num_classes: number of classes. 92 | base_width: base number of channels in each group.
93 | widen_factor: factor to adjust the channel dimensionality 94 | """ 95 | super().__init__() 96 | self.cardinality = cardinality 97 | self.depth = depth 98 | self.block_depth = (self.depth - 2) // 9 99 | self.base_width = base_width 100 | self.widen_factor = widen_factor 101 | self.nlabels = num_classes 102 | self.output_size = 64 103 | self.stages = [64, 64 * self.widen_factor, 128 * 104 | self.widen_factor, 256 * self.widen_factor] 105 | 106 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False) 107 | self.bn_1 = nn.BatchNorm2d(64, momentum=0.001) 108 | self.act = mish 109 | self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1) 110 | self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2) 111 | self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2) 112 | self.classifier = nn.Linear(self.stages[3], num_classes) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | nn.init.kaiming_normal_(m.weight, 117 | mode='fan_out', 118 | nonlinearity='leaky_relu') 119 | elif isinstance(m, nn.BatchNorm2d): 120 | nn.init.constant_(m.weight, 1.0) 121 | nn.init.constant_(m.bias, 0.0) 122 | elif isinstance(m, nn.Linear): 123 | nn.init.xavier_normal_(m.weight) 124 | nn.init.constant_(m.bias, 0.0) 125 | 126 | def block(self, name, in_channels, out_channels, pool_stride=2): 127 | """ Stack n bottleneck modules where n is inferred from the depth of the network. 128 | Args: 129 | name: string name of the current block. 130 | in_channels: number of input channels 131 | out_channels: number of output channels 132 | pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block. 133 | Returns: a Module consisting of n sequential bottlenecks. 
134 | """ 135 | block = nn.Sequential() 136 | for bottleneck in range(self.block_depth): 137 | name_ = '%s_bottleneck_%d' % (name, bottleneck) 138 | if bottleneck == 0: 139 | block.add_module(name_, ResNeXtBottleneck(in_channels, 140 | out_channels, 141 | pool_stride, 142 | self.cardinality, 143 | self.base_width, 144 | self.widen_factor)) 145 | else: 146 | block.add_module(name_, 147 | ResNeXtBottleneck(out_channels, 148 | out_channels, 149 | 1, 150 | self.cardinality, 151 | self.base_width, 152 | self.widen_factor)) 153 | return block 154 | 155 | def forward(self, x): 156 | x = self.conv_1_3x3.forward(x) 157 | x = self.act(self.bn_1.forward(x)) 158 | x = self.stage_1.forward(x) 159 | x = self.stage_2.forward(x) 160 | x = self.stage_3.forward(x) 161 | x = F.adaptive_avg_pool2d(x, 1) 162 | x = x.view(-1, self.stages[3]) 163 | return self.classifier(x) 164 | 165 | 166 | def build_resnext(cardinality, depth, width, num_classes): 167 | logger.info(f"Model: ResNeXt {depth+1}x{width}") 168 | return CifarResNeXt(cardinality=cardinality, 169 | depth=depth, 170 | base_width=width, 171 | num_classes=num_classes) 172 | -------------------------------------------------------------------------------- /models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def mish(x): 11 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function (https://arxiv.org/abs/1908.08681)""" 12 | return x * torch.tanh(F.softplus(x)) 13 | 14 | 15 | class PSBatchNorm2d(nn.BatchNorm2d): 16 | """How Does BN Increase Collapsed Neural Network Filters? 
(https://arxiv.org/abs/2001.11216)""" 17 | 18 | def __init__(self, num_features, alpha=0.1, eps=1e-05, momentum=0.001, affine=True, track_running_stats=True): 19 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 20 | self.alpha = alpha 21 | 22 | def forward(self, x): 23 | return super().forward(x) + self.alpha 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0, activate_before_residual=False): 28 | super(BasicBlock, self).__init__() 29 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001) 30 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True) 31 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 32 | padding=1, bias=False) 33 | self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001) 34 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True) 35 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 36 | padding=1, bias=False) 37 | self.drop_rate = drop_rate 38 | self.equalInOut = (in_planes == out_planes) 39 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 40 | padding=0, bias=False) or None 41 | self.activate_before_residual = activate_before_residual 42 | 43 | def forward(self, x): 44 | if not self.equalInOut and self.activate_before_residual == True: 45 | x = self.relu1(self.bn1(x)) 46 | else: 47 | out = self.relu1(self.bn1(x)) 48 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 49 | if self.drop_rate > 0: 50 | out = F.dropout(out, p=self.drop_rate, training=self.training) 51 | out = self.conv2(out) 52 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 53 | 54 | 55 | class NetworkBlock(nn.Module): 56 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0, activate_before_residual=False): 57 | super(NetworkBlock, self).__init__() 58 | self.layer = self._make_layer( 59 | block, 
in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual) 60 | 61 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual): 62 | layers = [] 63 | for i in range(int(nb_layers)): 64 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, 65 | i == 0 and stride or 1, drop_rate, activate_before_residual)) 66 | return nn.Sequential(*layers) 67 | 68 | def forward(self, x): 69 | return self.layer(x) 70 | 71 | 72 | class WideResNet(nn.Module): 73 | def __init__(self, num_classes, depth=28, widen_factor=2, drop_rate=0.0): 74 | super(WideResNet, self).__init__() 75 | channels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 76 | assert((depth - 4) % 6 == 0) 77 | n = (depth - 4) / 6 78 | block = BasicBlock 79 | # 1st conv before any network block 80 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, 81 | padding=1, bias=False) 82 | # 1st block 83 | self.block1 = NetworkBlock( 84 | n, channels[0], channels[1], block, 1, drop_rate, activate_before_residual=True) 85 | # 2nd block 86 | self.block2 = NetworkBlock( 87 | n, channels[1], channels[2], block, 2, drop_rate) 88 | # 3rd block 89 | self.block3 = NetworkBlock( 90 | n, channels[2], channels[3], block, 2, drop_rate) 91 | # global average pooling and classifier 92 | self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001) 93 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 94 | self.fc = nn.Linear(channels[3], num_classes) 95 | self.channels = channels[3] 96 | 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | nn.init.kaiming_normal_(m.weight, 100 | mode='fan_out', 101 | nonlinearity='leaky_relu') 102 | elif isinstance(m, nn.BatchNorm2d): 103 | nn.init.constant_(m.weight, 1.0) 104 | nn.init.constant_(m.bias, 0.0) 105 | elif isinstance(m, nn.Linear): 106 | nn.init.xavier_normal_(m.weight) 107 | nn.init.constant_(m.bias, 0.0) 108 | 109 | def forward(self, x): 110 | out = self.conv1(x) 111 | 
out = self.block1(out) 112 | out = self.block2(out) 113 | out = self.block3(out) 114 | out = self.relu(self.bn1(out)) 115 | out = F.adaptive_avg_pool2d(out, 1) 116 | out = out.view(-1, self.channels) 117 | return self.fc(out) 118 | 119 | 120 | class WideResNet_Open(nn.Module): 121 | def __init__(self, num_classes, depth=28, widen_factor=2, drop_rate=0.0): 122 | super(WideResNet_Open, self).__init__() 123 | channels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 124 | assert((depth - 4) % 6 == 0) 125 | n = (depth - 4) / 6 126 | block = BasicBlock 127 | # 1st conv before any network block 128 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, 129 | padding=1, bias=False) 130 | # 1st block 131 | self.block1 = NetworkBlock( 132 | n, channels[0], channels[1], block, 1, drop_rate, activate_before_residual=True) 133 | # 2nd block 134 | self.block2 = NetworkBlock( 135 | n, channels[1], channels[2], block, 2, drop_rate) 136 | # 3rd block 137 | self.block3 = NetworkBlock( 138 | n, channels[2], channels[3], block, 2, drop_rate) 139 | # global average pooling and classifier 140 | self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001) 141 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 142 | self.simclr_layer = nn.Sequential( 143 | nn.Linear(channels[3], 128), 144 | nn.ReLU(), 145 | nn.Linear(128, 128), 146 | ) 147 | self.fc = nn.Linear(channels[3], num_classes) 148 | out_open = 2 * num_classes 149 | self.fc_open = nn.Linear(channels[3], out_open, bias=False) 150 | self.channels = channels[3] 151 | 152 | for m in self.modules(): 153 | if isinstance(m, nn.Conv2d): 154 | nn.init.kaiming_normal_(m.weight, 155 | mode='fan_out', 156 | nonlinearity='leaky_relu') 157 | elif isinstance(m, nn.BatchNorm2d): 158 | nn.init.constant_(m.weight, 1.0) 159 | nn.init.constant_(m.bias, 0.0) 160 | elif isinstance(m, nn.Linear): 161 | nn.init.xavier_normal_(m.weight) 162 | if m.bias is not None: 163 | nn.init.constant_(m.bias, 0.0) 164 | 165 | def 
forward(self, x, feature=False, feat_only=False): 166 | #self.weight_norm() 167 | out = self.conv1(x) 168 | out = self.block1(out) 169 | out = self.block2(out) 170 | out = self.block3(out) 171 | out = self.relu(self.bn1(out)) 172 | out = F.adaptive_avg_pool2d(out, 1) 173 | out = out.view(-1, self.channels) 174 | 175 | 176 | if feat_only: 177 | return self.simclr_layer(out) 178 | out_open = self.fc_open(out) 179 | if feature: 180 | return self.fc(out), out_open, out 181 | else: 182 | return self.fc(out), out_open 183 | 184 | def weight_norm(self): 185 | w = self.fc_open.weight.data 186 | norm = w.norm(p=2, dim=1, keepdim=True) 187 | self.fc_open.weight.data = w.div(norm.expand_as(w)) 188 | 189 | 190 | # 191 | # 192 | class ResBasicBlock(nn.Module): 193 | expansion = 1 194 | 195 | def __init__(self, in_planes, planes, stride=1): 196 | super(ResBasicBlock, self).__init__() 197 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 198 | self.bn1 = nn.BatchNorm2d(planes) 199 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 200 | self.bn2 = nn.BatchNorm2d(planes) 201 | 202 | self.shortcut = nn.Sequential() 203 | if stride != 1 or in_planes != self.expansion*planes: 204 | self.shortcut = nn.Sequential( 205 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 206 | nn.BatchNorm2d(self.expansion*planes) 207 | ) 208 | 209 | def forward(self, x): 210 | out = F.relu(self.bn1(self.conv1(x))) 211 | out = self.bn2(self.conv2(out)) 212 | out += self.shortcut(x) 213 | out = F.relu(out) 214 | return out 215 | 216 | # 217 | class ResNet_Open(nn.Module): 218 | def __init__(self, block, num_blocks, low_dim=128, num_classes=10): 219 | super(ResNet_Open, self).__init__() 220 | self.in_planes = 64 221 | 222 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 223 | self.bn1 = nn.BatchNorm2d(64) 224 | self.layer1 = self._make_layer(block, 64, 
num_blocks[0], stride=1) 225 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 226 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 227 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 228 | self.linear = nn.Linear(512*block.expansion, low_dim) 229 | self.simclr_layer = nn.Sequential( 230 | nn.Linear(512*block.expansion, 128), 231 | nn.ReLU(), 232 | nn.Linear(128, 128), 233 | ) 234 | self.fc1 = nn.Linear(512*block.expansion, num_classes) 235 | self.fc_open = nn.Linear(512*block.expansion, num_classes*2, bias=False) 236 | 237 | 238 | #self.l2norm = Normalize(2) 239 | 240 | def _make_layer(self, block, planes, num_blocks, stride): 241 | strides = [stride] + [1]*(num_blocks-1) 242 | layers = [] 243 | for stride in strides: 244 | layers.append(block(self.in_planes, planes, stride)) 245 | self.in_planes = planes * block.expansion 246 | return nn.Sequential(*layers) 247 | 248 | def forward(self, x, feature=False): 249 | out = F.relu(self.bn1(self.conv1(x))) 250 | out = self.layer1(out) 251 | out = self.layer2(out) 252 | out = self.layer3(out) 253 | out = self.layer4(out) 254 | out = F.avg_pool2d(out, 4) 255 | out = out.view(out.size(0), -1) 256 | out_open = self.fc_open(out) 257 | if feature: 258 | return self.fc1(out), out_open, self.simclr_layer(out) 259 | else: 260 | return self.fc1(out), out_open 261 | 262 | 263 | 264 | def ResNet18(low_dim=128, num_classes=10): 265 | return ResNet_Open(ResBasicBlock, [2,2,2,2], low_dim, num_classes) 266 | 267 | 268 | def build_wideresnet(depth, widen_factor, dropout, num_classes, open=False): 269 | logger.info(f"Model: WideResNet {depth}x{widen_factor}") 270 | build_func = WideResNet_Open if open else WideResNet 271 | return build_func(depth=depth, 272 | widen_factor=widen_factor, 273 | drop_rate=dropout, 274 | num_classes=num_classes) 275 | -------------------------------------------------------------------------------- /run_cifar10.sh: 
-------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$3 python main.py --dataset cifar10 --num-labeled $1 --out $2 --arch wideresnet --lambda_oem 0.1 --lambda_socr 0.5 \ 2 | --batch-size 64 --lr 0.03 --expand-labels --seed 0 --opt_level O2 --amp --mu 2 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /run_cifar100.sh: -------------------------------------------------------------------------------- 1 | python main.py --dataset cifar100 --num-labeled $1 --out $2 --num-super $3 --arch wideresnet --lambda_oem 0.1 --lambda_socr 1.0 \ 2 | --batch-size 64 --lr 0.03 --expand-labels --seed 0 --opt_level O2 --amp --mu 2 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /run_eval_cifar10.sh: -------------------------------------------------------------------------------- 1 | python main.py --dataset cifar10 --resume $1 --arch wideresnet --eval_only 1 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /run_imagenet.sh: -------------------------------------------------------------------------------- 1 | python main.py --dataset imagenet --out $1 --arch resnet_imagenet --lambda_oem 0.1 --lambda_socr 0.5 \ 2 | --batch-size 64 --lr 0.03 --expand-labels --seed 0 --opt_level O2 --amp --mu 2 --epochs 100 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import copy 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 8 | from torch.utils.data.distributed import DistributedSampler 9 | from dataset import TransformOpenMatch, cifar10_mean, 
cifar10_std, \ 10 | cifar100_std, cifar100_mean, normal_mean, \ 11 | normal_std, TransformFixMatch_Imagenet_Weak 12 | from tqdm import tqdm 13 | from utils import AverageMeter, ova_loss,\ 14 | save_checkpoint, ova_ent, \ 15 | test, test_ood, exclude_dataset 16 | 17 | logger = logging.getLogger(__name__) 18 | best_acc = 0 19 | best_acc_val = 0 20 | 21 | def train(args, labeled_trainloader, unlabeled_dataset, test_loader, val_loader, 22 | ood_loaders, model, optimizer, ema_model, scheduler): 23 | if args.amp: 24 | from apex import amp 25 | 26 | global best_acc 27 | global best_acc_val 28 | 29 | test_accs = [] 30 | batch_time = AverageMeter() 31 | data_time = AverageMeter() 32 | losses = AverageMeter() 33 | losses_x = AverageMeter() 34 | losses_o = AverageMeter() 35 | losses_oem = AverageMeter() 36 | losses_socr = AverageMeter() 37 | losses_fix = AverageMeter() 38 | mask_probs = AverageMeter() 39 | end = time.time() 40 | 41 | 42 | if args.world_size > 1: 43 | labeled_epoch = 0 44 | unlabeled_epoch = 0 45 | labeled_iter = iter(labeled_trainloader) 46 | default_out = "Epoch: {epoch}/{epochs:4}. " \ 47 | "LR: {lr:.6f}. " \ 48 | "Lab: {loss_x:.4f}. 
" \ 49 | "Open: {loss_o:.4f}" 50 | output_args = vars(args) 51 | default_out += " OEM {loss_oem:.4f}" 52 | default_out += " SOCR {loss_socr:.4f}" 53 | default_out += " Fix {loss_fix:.4f}" 54 | 55 | model.train() 56 | unlabeled_dataset_all = copy.deepcopy(unlabeled_dataset) 57 | if args.dataset == 'cifar10': 58 | mean = cifar10_mean 59 | std = cifar10_std 60 | func_trans = TransformOpenMatch 61 | elif args.dataset == 'cifar100': 62 | mean = cifar100_mean 63 | std = cifar100_std 64 | func_trans = TransformOpenMatch 65 | elif 'imagenet' in args.dataset: 66 | mean = normal_mean 67 | std = normal_std 68 | func_trans = TransformFixMatch_Imagenet_Weak 69 | 70 | 71 | unlabeled_dataset_all.transform = func_trans(mean=mean, std=std) 72 | labeled_dataset = copy.deepcopy(labeled_trainloader.dataset) 73 | labeled_dataset.transform = func_trans(mean=mean, std=std) 74 | train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler 75 | labeled_trainloader = DataLoader( 76 | labeled_dataset, 77 | sampler=train_sampler(labeled_dataset), 78 | batch_size=args.batch_size, 79 | num_workers=args.num_workers, 80 | drop_last=True) 81 | train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler 82 | 83 | 84 | for epoch in range(args.start_epoch, args.epochs): 85 | output_args["epoch"] = epoch 86 | if not args.no_progress: 87 | p_bar = tqdm(range(args.eval_step), 88 | disable=args.local_rank not in [-1, 0]) 89 | 90 | if epoch >= args.start_fix: 91 | ## pick pseudo-inliers 92 | exclude_dataset(args, unlabeled_dataset, ema_model.ema) 93 | 94 | 95 | unlabeled_trainloader = DataLoader(unlabeled_dataset, 96 | sampler = train_sampler(unlabeled_dataset), 97 | batch_size = args.batch_size * args.mu, 98 | num_workers = args.num_workers, 99 | drop_last = True) 100 | unlabeled_trainloader_all = DataLoader(unlabeled_dataset_all, 101 | sampler=train_sampler(unlabeled_dataset_all), 102 | batch_size=args.batch_size * args.mu, 103 | num_workers=args.num_workers, 104 | 
drop_last=True) 105 | 106 | unlabeled_iter = iter(unlabeled_trainloader) 107 | unlabeled_all_iter = iter(unlabeled_trainloader_all) 108 | 109 | for batch_idx in range(args.eval_step): 110 | ## Data loading 111 | 112 | try: 113 | (_, inputs_x_s, inputs_x), targets_x = next(labeled_iter) 114 | except StopIteration: 115 | if args.world_size > 1: 116 | labeled_epoch += 1 117 | labeled_trainloader.sampler.set_epoch(labeled_epoch) 118 | labeled_iter = iter(labeled_trainloader) 119 | (_, inputs_x_s, inputs_x), targets_x = next(labeled_iter) 120 | try: 121 | (inputs_u_w, inputs_u_s, _), _ = next(unlabeled_iter) 122 | except StopIteration: 123 | if args.world_size > 1: 124 | unlabeled_epoch += 1 125 | unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch) 126 | unlabeled_iter = iter(unlabeled_trainloader) 127 | (inputs_u_w, inputs_u_s, _), _ = next(unlabeled_iter) 128 | try: 129 | (inputs_all_w, inputs_all_s, _), _ = next(unlabeled_all_iter) 130 | except StopIteration: 131 | unlabeled_all_iter = iter(unlabeled_trainloader_all) 132 | (inputs_all_w, inputs_all_s, _), _ = next(unlabeled_all_iter) 133 | data_time.update(time.time() - end) 134 | 135 | b_size = inputs_x.shape[0] 136 | 137 | inputs_all = torch.cat([inputs_all_w, inputs_all_s], 0) 138 | inputs = torch.cat([inputs_x, inputs_x_s, 139 | inputs_all], 0).to(args.device) 140 | targets_x = targets_x.to(args.device) 141 | ## Feed data 142 | logits, logits_open = model(inputs) 143 | logits_open_u1, logits_open_u2 = logits_open[2*b_size:].chunk(2) 144 | 145 | ## Loss for labeled samples 146 | Lx = F.cross_entropy(logits[:2*b_size], 147 | targets_x.repeat(2), reduction='mean') 148 | Lo = ova_loss(logits_open[:2*b_size], targets_x.repeat(2)) 149 | 150 | ## Open-set entropy minimization 151 | L_oem = ova_ent(logits_open_u1) / 2. 152 | L_oem += ova_ent(logits_open_u2) / 2.
153 | 154 | ## Soft consistency regularization 155 | logits_open_u1 = logits_open_u1.view(logits_open_u1.size(0), 2, -1) 156 | logits_open_u2 = logits_open_u2.view(logits_open_u2.size(0), 2, -1) 157 | logits_open_u1 = F.softmax(logits_open_u1, 1) 158 | logits_open_u2 = F.softmax(logits_open_u2, 1) 159 | L_socr = torch.mean(torch.sum(torch.sum(torch.abs( 160 | logits_open_u1 - logits_open_u2)**2, 1), 1)) 161 | 162 | if epoch >= args.start_fix: 163 | inputs_ws = torch.cat([inputs_u_w, inputs_u_s], 0).to(args.device) 164 | logits, logits_open_fix = model(inputs_ws) 165 | logits_u_w, logits_u_s = logits.chunk(2) 166 | pseudo_label = torch.softmax(logits_u_w.detach()/args.T, dim=-1) 167 | max_probs, targets_u = torch.max(pseudo_label, dim=-1) 168 | mask = max_probs.ge(args.threshold).float() 169 | L_fix = (F.cross_entropy(logits_u_s, 170 | targets_u, 171 | reduction='none') * mask).mean() 172 | mask_probs.update(mask.mean().item()) 173 | 174 | else: 175 | L_fix = torch.zeros(1).to(args.device).mean() 176 | loss = Lx + Lo + args.lambda_oem * L_oem \ 177 | + args.lambda_socr * L_socr + L_fix 178 | if args.amp: 179 | with amp.scale_loss(loss, optimizer) as scaled_loss: 180 | scaled_loss.backward() 181 | else: 182 | loss.backward() 183 | 184 | losses.update(loss.item()) 185 | losses_x.update(Lx.item()) 186 | losses_o.update(Lo.item()) 187 | losses_oem.update(L_oem.item()) 188 | losses_socr.update(L_socr.item()) 189 | losses_fix.update(L_fix.item()) 190 | 191 | output_args["batch"] = batch_idx 192 | output_args["loss_x"] = losses_x.avg 193 | output_args["loss_o"] = losses_o.avg 194 | output_args["loss_oem"] = losses_oem.avg 195 | output_args["loss_socr"] = losses_socr.avg 196 | output_args["loss_fix"] = losses_fix.avg 197 | output_args["lr"] = [group["lr"] for group in optimizer.param_groups][0] 198 | 199 | 200 | optimizer.step() 201 | if args.opt != 'adam': 202 | scheduler.step() 203 | if args.use_ema: 204 | ema_model.update(model) 205 | model.zero_grad() 206 |
batch_time.update(time.time() - end) 207 | end = time.time() 208 | 209 | if not args.no_progress: 210 | p_bar.set_description(default_out.format(**output_args)) 211 | p_bar.update() 212 | 213 | if not args.no_progress: 214 | p_bar.close() 215 | 216 | if args.use_ema: 217 | test_model = ema_model.ema 218 | else: 219 | test_model = model 220 | 221 | if args.local_rank in [-1, 0]: 222 | 223 | val_acc = test(args, val_loader, test_model, epoch, val=True) 224 | test_loss, test_acc_close, test_overall, \ 225 | test_unk, test_roc, test_roc_softm, test_id \ 226 | = test(args, test_loader, test_model, epoch) 227 | 228 | for ood in ood_loaders.keys(): 229 | roc_ood = test_ood(args, test_id, ood_loaders[ood], test_model) 230 | logger.info("ROC vs {ood}: {roc}".format(ood=ood, roc=roc_ood)) 231 | 232 | args.writer.add_scalar('train/1.train_loss', losses.avg, epoch) 233 | args.writer.add_scalar('train/2.train_loss_x', losses_x.avg, epoch) 234 | args.writer.add_scalar('train/3.train_loss_o', losses_o.avg, epoch) 235 | args.writer.add_scalar('train/4.train_loss_oem', losses_oem.avg, epoch) 236 | args.writer.add_scalar('train/5.train_loss_socr', losses_socr.avg, epoch) 237 | args.writer.add_scalar('train/5.train_loss_fix', losses_fix.avg, epoch) 238 | args.writer.add_scalar('train/6.mask', mask_probs.avg, epoch) 239 | args.writer.add_scalar('test/1.test_acc', test_acc_close, epoch) 240 | args.writer.add_scalar('test/2.test_loss', test_loss, epoch) 241 | 242 | is_best = val_acc > best_acc_val 243 | best_acc_val = max(val_acc, best_acc_val) 244 | if is_best: 245 | overall_valid = test_overall 246 | close_valid = test_acc_close 247 | unk_valid = test_unk 248 | roc_valid = test_roc 249 | roc_softm_valid = test_roc_softm 250 | model_to_save = model.module if hasattr(model, "module") else model 251 | if args.use_ema: 252 | ema_to_save = ema_model.ema.module if hasattr( 253 | ema_model.ema, "module") else ema_model.ema 254 | 255 | save_checkpoint({ 256 | 'epoch': epoch + 1, 257 | 
'state_dict': model_to_save.state_dict(), 258 | 'ema_state_dict': ema_to_save.state_dict() if args.use_ema else None, 259 | 'acc close': test_acc_close, 260 | 'acc overall': test_overall, 261 | 'unk': test_unk, 262 | 'best_acc': best_acc, 263 | 'optimizer': optimizer.state_dict(), 264 | 'scheduler': scheduler.state_dict(), 265 | }, is_best, args.out) 266 | test_accs.append(test_acc_close) 267 | logger.info('Best val closed acc: {:.3f}'.format(best_acc_val)) 268 | logger.info('Valid closed acc: {:.3f}'.format(close_valid)) 269 | logger.info('Valid overall acc: {:.3f}'.format(overall_valid)) 270 | logger.info('Valid unk acc: {:.3f}'.format(unk_valid)) 271 | logger.info('Valid roc: {:.3f}'.format(roc_valid)) 272 | logger.info('Valid roc soft: {:.3f}'.format(roc_softm_valid)) 273 | logger.info('Mean top-1 acc: {:.3f}\n'.format( 274 | np.mean(test_accs[-20:]))) 275 | if args.local_rank in [-1, 0]: 276 | args.writer.close() 277 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import * 2 | from .default import * 3 | from .parser import * -------------------------------------------------------------------------------- /utils/default.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | import math 5 | import random 6 | import shutil 7 | import numpy as np 8 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 9 | from torch.utils.data.distributed import DistributedSampler 10 | import torch.optim as optim 11 | from torch.optim.lr_scheduler import LambdaLR 12 | 13 | from dataset.cifar import DATASET_GETTERS, get_ood 14 | 15 | __all__ = ['create_model', 'set_model_config', 16 | 'set_dataset', 'set_models', 17 | 'save_checkpoint', 'set_seed'] 18 | 19 | 20 | def create_model(args): 21 | if 'wideresnet' in args.arch: 22 
| import models.wideresnet as models 23 | model = models.build_wideresnet(depth=args.model_depth, 24 | widen_factor=args.model_width, 25 | dropout=0, 26 | num_classes=args.num_classes, 27 | open=True) 28 | elif args.arch == 'resnext': 29 | import models.resnext as models 30 | model = models.build_resnext(cardinality=args.model_cardinality, 31 | depth=args.model_depth, 32 | width=args.model_width, 33 | num_classes=args.num_classes) 34 | elif args.arch == 'resnet_imagenet': 35 | import models.resnet_imagenet as models 36 | model = models.resnet18(num_classes=args.num_classes) 37 | 38 | return model 39 | 40 | 41 | 42 | def set_model_config(args): 43 | 44 | if args.dataset == 'cifar10': 45 | if args.arch == 'wideresnet': 46 | args.model_depth = 28 47 | args.model_width = 2 48 | elif args.arch == 'resnext': 49 | args.model_cardinality = 4 50 | args.model_depth = 28 51 | args.model_width = 4 52 | 53 | elif args.dataset == 'cifar100': 54 | args.num_classes = 55 55 | if args.arch == 'wideresnet': 56 | args.model_depth = 28 57 | args.model_width = 2 58 | elif args.arch == 'wideresnet_10': 59 | args.model_depth = 28 60 | args.model_width = 8 61 | elif args.arch == 'resnext': 62 | args.model_cardinality = 8 63 | args.model_depth = 29 64 | args.model_width = 64 65 | 66 | elif args.dataset == "imagenet": 67 | args.num_classes = 20 68 | 69 | args.image_size = (32, 32, 3) 70 | if args.dataset == 'cifar10': 71 | args.ood_data = ["svhn", 'cifar100', 'lsun', 'imagenet'] 72 | 73 | elif args.dataset == 'cifar100': 74 | args.ood_data = ['cifar10', "svhn", 'lsun', 'imagenet'] 75 | 76 | elif 'imagenet' in args.dataset: 77 | args.ood_data = ['lsun', 'dtd', 'cub', 'flowers102', 78 | 'caltech_256', 'stanford_dogs'] 79 | args.image_size = (224, 224, 3) 80 | 81 | def set_dataset(args): 82 | labeled_dataset, unlabeled_dataset, test_dataset, val_dataset = \ 83 | DATASET_GETTERS[args.dataset](args) 84 | 85 | ood_loaders = {} 86 | for ood in args.ood_data: 87 | ood_dataset = get_ood(ood, 
args.dataset, image_size=args.image_size) 88 | ood_loaders[ood] = DataLoader(ood_dataset, 89 | batch_size=args.batch_size, 90 | num_workers=args.num_workers) 91 | 92 | if args.local_rank == 0: 93 | torch.distributed.barrier() 94 | 95 | train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler 96 | 97 | labeled_trainloader = DataLoader( 98 | labeled_dataset, 99 | sampler=train_sampler(labeled_dataset), 100 | batch_size=args.batch_size, 101 | num_workers=args.num_workers, 102 | drop_last=True) 103 | 104 | test_loader = DataLoader( 105 | test_dataset, 106 | sampler=SequentialSampler(test_dataset), 107 | batch_size=args.batch_size, 108 | num_workers=args.num_workers) 109 | val_loader = DataLoader( 110 | val_dataset, 111 | sampler=SequentialSampler(val_dataset), 112 | batch_size=args.batch_size, 113 | num_workers=args.num_workers) 114 | if args.local_rank not in [-1, 0]: 115 | torch.distributed.barrier() 116 | 117 | return labeled_trainloader, unlabeled_dataset, \ 118 | test_loader, val_loader, ood_loaders 119 | 120 | 121 | def get_cosine_schedule_with_warmup(optimizer, 122 | num_warmup_steps, 123 | num_training_steps, 124 | num_cycles=7./16., 125 | last_epoch=-1): 126 | def _lr_lambda(current_step): 127 | if current_step < num_warmup_steps: 128 | return float(current_step) / float(max(1, num_warmup_steps)) 129 | no_progress = float(current_step - num_warmup_steps) / \ 130 | float(max(1, num_training_steps - num_warmup_steps)) 131 | return max(0., math.cos(math.pi * num_cycles * no_progress)) 132 | 133 | return LambdaLR(optimizer, _lr_lambda, last_epoch) 134 | 135 | 136 | def set_models(args): 137 | model = create_model(args) 138 | if args.local_rank == 0: 139 | torch.distributed.barrier() 140 | model.to(args.device) 141 | 142 | no_decay = ['bias', 'bn'] 143 | grouped_parameters = [ 144 | {'params': [p for n, p in model.named_parameters() if not any( 145 | nd in n for nd in no_decay)], 'weight_decay': args.wdecay}, 146 | {'params': [p for n, p in 
model.named_parameters() if any( 147 | nd in n for nd in no_decay)], 'weight_decay': 0.0} 148 | ] 149 | if args.opt == 'sgd': 150 | optimizer = optim.SGD(grouped_parameters, lr=args.lr, 151 | momentum=0.9, nesterov=args.nesterov) 152 | elif args.opt == 'adam': 153 | optimizer = optim.Adam(grouped_parameters, lr=2e-3) 154 | 155 | # args.epochs = math.ceil(args.total_steps / args.eval_step) 156 | scheduler = get_cosine_schedule_with_warmup( 157 | optimizer, args.warmup, args.total_steps) 158 | 159 | return model, optimizer, scheduler 160 | 161 | 162 | def save_checkpoint(state, is_best, checkpoint, filename='checkpoint.pth.tar'): 163 | filepath = os.path.join(checkpoint, filename) 164 | torch.save(state, filepath) 165 | if is_best: 166 | shutil.copyfile(filepath, os.path.join(checkpoint, 167 | 'model_best.pth.tar')) 168 | 169 | 170 | def set_seed(args): 171 | random.seed(args.seed) 172 | np.random.seed(args.seed) 173 | torch.manual_seed(args.seed) 174 | if args.n_gpu > 0: 175 | torch.cuda.manual_seed_all(args.seed) 176 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 
3 | ''' 4 | import logging 5 | import time 6 | from tqdm import tqdm 7 | import torch.nn.functional as F 8 | import numpy as np 9 | 10 | import torch 11 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 12 | from sklearn.metrics import roc_auc_score 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | __all__ = ['get_mean_and_std', 'accuracy', 'AverageMeter', 17 | 'accuracy_open', 'ova_loss', 'compute_roc', 18 | 'roc_id_ood', 'ova_ent', 'exclude_dataset', 19 | 'test_ood', 'test'] 20 | 21 | 22 | def get_mean_and_std(dataset): 23 | '''Compute the mean and std value of dataset.''' 24 | dataloader = torch.utils.data.DataLoader( 25 | dataset, batch_size=1, shuffle=False, num_workers=4) 26 | 27 | mean = torch.zeros(3) 28 | std = torch.zeros(3) 29 | logger.info('==> Computing mean and std..') 30 | for inputs, targets in dataloader: 31 | for i in range(3): 32 | mean[i] += inputs[:, i, :, :].mean() 33 | std[i] += inputs[:, i, :, :].std() 34 | mean.div_(len(dataset)) 35 | std.div_(len(dataset)) 36 | return mean, std 37 | 38 | 39 | def accuracy(output, target, topk=(1,)): 40 | """Computes the precision@k for the specified values of k""" 41 | maxk = max(topk) 42 | batch_size = target.size(0) 43 | 44 | _, pred = output.topk(maxk, 1, True, True) 45 | pred = pred.t() 46 | correct = pred.eq(target.reshape(1, -1).expand_as(pred)) 47 | 48 | res = [] 49 | 50 | for k in topk: 51 | correct_k = correct[:k].reshape(-1).float().sum(0) 52 | res.append(correct_k.mul_(100.0 / batch_size)) 53 | return res 54 | 55 | 56 | def accuracy_open(pred, target, topk=(1,), num_classes=5): 57 | """Computes the precision@k for the specified values of k, 58 | num_classes are the number of known classes. 
59 | This function returns overall accuracy, 60 | accuracy to reject unknown samples, 61 | the size of unknown samples in this batch.""" 62 | maxk = max(topk) 63 | batch_size = target.size(0) 64 | pred = pred.view(-1, 1) 65 | pred = pred.t() 66 | ind = (target == num_classes) 67 | unknown_size = len(ind) 68 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 69 | if ind.sum() > 0: 70 | unk_corr = pred.eq(target).view(-1)[ind] 71 | acc = torch.sum(unk_corr).item() / unk_corr.size(0) 72 | else: 73 | acc = 0 74 | 75 | res = [] 76 | for k in topk: 77 | correct_k = correct[:k].view(-1).float().sum(0) 78 | res.append(correct_k.mul_(100.0 / batch_size)) 79 | return res[0], acc, unknown_size 80 | 81 | 82 | def compute_roc(unk_all, label_all, num_known): 83 | Y_test = np.zeros(unk_all.shape[0]) 84 | unk_pos = np.where(label_all >= num_known)[0] 85 | Y_test[unk_pos] = 1 86 | return roc_auc_score(Y_test, unk_all) 87 | 88 | 89 | def roc_id_ood(score_id, score_ood): 90 | id_all = np.r_[score_id, score_ood] 91 | Y_test = np.zeros(score_id.shape[0]+score_ood.shape[0]) 92 | Y_test[score_id.shape[0]:] = 1 93 | return roc_auc_score(Y_test, id_all) 94 | 95 | 96 | def ova_loss(logits_open, label): 97 | logits_open = logits_open.view(logits_open.size(0), 2, -1) 98 | logits_open = F.softmax(logits_open, 1) 99 | label_s_sp = torch.zeros((logits_open.size(0), 100 | logits_open.size(2))).long().to(label.device) 101 | label_range = torch.arange(logits_open.size(0), device=label.device) 102 | label_s_sp[label_range, label] = 1 103 | label_sp_neg = 1 - label_s_sp 104 | open_loss = torch.mean(torch.sum(-torch.log(logits_open[:, 1, :] 105 | + 1e-8) * label_s_sp, 1)) 106 | open_loss_neg = torch.mean(torch.max(-torch.log(logits_open[:, 0, :] 107 | + 1e-8) * label_sp_neg, 1)[0]) 108 | Lo = open_loss_neg + open_loss 109 | return Lo 110 | 111 | 112 | def ova_ent(logits_open): 113 | logits_open = logits_open.view(logits_open.size(0), 2, -1) 114 | logits_open = F.softmax(logits_open, 1) 115 | Le =
torch.mean(torch.mean(torch.sum(-logits_open * 116 | torch.log(logits_open + 1e-8), 1), 1)) 117 | return Le 118 | 119 | 120 | class AverageMeter(object): 121 | """Computes and stores the average and current value 122 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 123 | """ 124 | 125 | def __init__(self): 126 | self.reset() 127 | 128 | def reset(self): 129 | self.val = 0 130 | self.avg = 0 131 | self.sum = 0 132 | self.count = 0 133 | 134 | def update(self, val, n=1): 135 | self.val = val 136 | self.sum += val * n 137 | self.count += n 138 | self.avg = self.sum / self.count 139 | 140 | 141 | 142 | def exclude_dataset(args, dataset, model, exclude_known=False): 143 | data_time = AverageMeter() 144 | end = time.time() 145 | dataset.init_index() 146 | test_loader = DataLoader( 147 | dataset, 148 | batch_size=args.batch_size, 149 | num_workers=args.num_workers, 150 | drop_last=False, 151 | shuffle=False) 152 | if not args.no_progress: 153 | test_loader = tqdm(test_loader, 154 | disable=args.local_rank not in [-1, 0]) 155 | model.eval() 156 | with torch.no_grad(): 157 | for batch_idx, ((_, _, inputs), targets) in enumerate(test_loader): 158 | data_time.update(time.time() - end) 159 | 160 | inputs = inputs.to(args.device) 161 | outputs, outputs_open = model(inputs) 162 | outputs = F.softmax(outputs, 1) 163 | out_open = F.softmax(outputs_open.view(outputs_open.size(0), 2, -1), 1) 164 | tmp_range = torch.arange(out_open.size(0), device=out_open.device) 165 | pred_close = outputs.data.max(1)[1] 166 | unk_score = out_open[tmp_range, 0, pred_close] 167 | known_ind = unk_score < 0.5 168 | if batch_idx == 0: 169 | known_all = known_ind 170 | else: 171 | known_all = torch.cat([known_all, known_ind], 0) 172 | if not args.no_progress: 173 | test_loader.close() 174 | known_all = known_all.data.cpu().numpy() 175 | if exclude_known: 176 | ind_selected = np.where(known_all == 0)[0] 177 | else: 178 | ind_selected = np.where(known_all != 0)[0] 179 |
print("selected ratio %s"%( (len(ind_selected)/ len(known_all)))) 180 | model.train() 181 | dataset.set_index(ind_selected) 182 | 183 | def test(args, test_loader, model, epoch, val=False): 184 | batch_time = AverageMeter() 185 | data_time = AverageMeter() 186 | losses = AverageMeter() 187 | top1 = AverageMeter() 188 | acc = AverageMeter() 189 | unk = AverageMeter() 190 | top5 = AverageMeter() 191 | end = time.time() 192 | 193 | if not args.no_progress: 194 | test_loader = tqdm(test_loader, 195 | disable=args.local_rank not in [-1, 0]) 196 | with torch.no_grad(): 197 | for batch_idx, (inputs, targets) in enumerate(test_loader): 198 | data_time.update(time.time() - end) 199 | model.eval() 200 | inputs = inputs.to(args.device) 201 | targets = targets.to(args.device) 202 | outputs, outputs_open = model(inputs) 203 | outputs = F.softmax(outputs, 1) 204 | out_open = F.softmax(outputs_open.view(outputs_open.size(0), 2, -1), 1) 205 | tmp_range = torch.arange(out_open.size(0), device=out_open.device) 206 | pred_close = outputs.data.max(1)[1] 207 | unk_score = out_open[tmp_range, 0, pred_close] 208 | known_score = outputs.max(1)[0] 209 | targets_unk = targets >= int(outputs.size(1)) 210 | targets[targets_unk] = int(outputs.size(1)) 211 | known_targets = targets < int(outputs.size(1)) 212 | known_pred = outputs[known_targets] 213 | known_targets = targets[known_targets] 214 | 215 | if len(known_pred) > 0: 216 | prec1, prec5 = accuracy(known_pred, known_targets, topk=(1, 5)) 217 | top1.update(prec1.item(), known_pred.shape[0]) 218 | top5.update(prec5.item(), known_pred.shape[0]) 219 | 220 | ind_unk = unk_score > 0.5 221 | pred_close[ind_unk] = int(outputs.size(1)) 222 | acc_all, unk_acc, size_unk = accuracy_open(pred_close, 223 | targets, 224 | num_classes=int(outputs.size(1))) 225 | acc.update(acc_all.item(), inputs.shape[0]) 226 | unk.update(unk_acc, size_unk) 227 | 228 | batch_time.update(time.time() - end) 229 | end = time.time() 230 | if batch_idx == 0: 231 | unk_all =
unk_score 232 | known_all = known_score 233 | label_all = targets 234 | else: 235 | unk_all = torch.cat([unk_all, unk_score], 0) 236 | known_all = torch.cat([known_all, known_score], 0) 237 | label_all = torch.cat([label_all, targets], 0) 238 | 239 | if not args.no_progress: 240 | test_loader.set_description("Test Iter: {batch:4}/{iter:4}. " 241 | "Data: {data:.3f}s. " 242 | "Batch: {bt:.3f}s. " 243 | "Loss: {loss:.4f}. " 244 | "Closed t1: {top1:.3f} " 245 | "t5: {top5:.3f} " 246 | "acc: {acc:.3f}. " 247 | "unk: {unk:.3f}. ".format( 248 | batch=batch_idx + 1, 249 | iter=len(test_loader), 250 | data=data_time.avg, 251 | bt=batch_time.avg, 252 | loss=losses.avg, 253 | top1=top1.avg, 254 | top5=top5.avg, 255 | acc=acc.avg, 256 | unk=unk.avg, 257 | )) 258 | if not args.no_progress: 259 | test_loader.close() 260 | ## ROC calculation 261 | #import pdb 262 | #pdb.set_trace() 263 | unk_all = unk_all.data.cpu().numpy() 264 | known_all = known_all.data.cpu().numpy() 265 | label_all = label_all.data.cpu().numpy() 266 | if not val: 267 | roc = compute_roc(unk_all, label_all, 268 | num_known=int(outputs.size(1))) 269 | roc_soft = compute_roc(-known_all, label_all, 270 | num_known=int(outputs.size(1))) 271 | ind_known = np.where(label_all < int(outputs.size(1)))[0] 272 | id_score = unk_all[ind_known] 273 | logger.info("Closed acc: {:.3f}".format(top1.avg)) 274 | logger.info("Overall acc: {:.3f}".format(acc.avg)) 275 | logger.info("Unk acc: {:.3f}".format(unk.avg)) 276 | logger.info("ROC: {:.3f}".format(roc)) 277 | logger.info("ROC Softmax: {:.3f}".format(roc_soft)) 278 | return losses.avg, top1.avg, acc.avg, \ 279 | unk.avg, roc, roc_soft, id_score 280 | else: 281 | logger.info("Closed acc: {:.3f}".format(top1.avg)) 282 | return top1.avg 283 | 284 | 285 | def test_ood(args, test_id, test_loader, model): 286 | batch_time = AverageMeter() 287 | data_time = AverageMeter() 288 | end = time.time() 289 | 290 | if not args.no_progress: 291 | test_loader = tqdm(test_loader, 292 |
disable=args.local_rank not in [-1, 0]) 293 | with torch.no_grad(): 294 | for batch_idx, (inputs, targets) in enumerate(test_loader): 295 | data_time.update(time.time() - end) 296 | model.eval() 297 | inputs = inputs.to(args.device) 298 | outputs, outputs_open = model(inputs) 299 | out_open = F.softmax(outputs_open.view(outputs_open.size(0), 2, -1), 1) 300 | tmp_range = torch.range(0, out_open.size(0) - 1).long().cuda() 301 | pred_close = outputs.data.max(1)[1] 302 | unk_score = out_open[tmp_range, 0, pred_close] 303 | batch_time.update(time.time() - end) 304 | end = time.time() 305 | if batch_idx == 0: 306 | unk_all = unk_score 307 | else: 308 | unk_all = torch.cat([unk_all, unk_score], 0) 309 | if not args.no_progress: 310 | test_loader.close() 311 | ## ROC calculation 312 | unk_all = unk_all.data.cpu().numpy() 313 | roc = roc_id_ood(test_id, unk_all) 314 | 315 | return roc 316 | -------------------------------------------------------------------------------- /utils/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | __all__ = ['set_parser'] 4 | 5 | def set_parser(): 6 | parser = argparse.ArgumentParser(description='PyTorch OpenMatch Training') 7 | ## Computational Configurations 8 | parser.add_argument('--gpu-id', default='0', type=int, 9 | help='id(s) for CUDA_VISIBLE_DEVICES') 10 | parser.add_argument('--num-workers', type=int, default=4, 11 | help='number of workers') 12 | parser.add_argument('--seed', default=None, type=int, 13 | help="random seed") 14 | parser.add_argument("--amp", action="store_true", 15 | help="use 16-bit (mixed) precision through NVIDIA apex AMP") 16 | parser.add_argument("--opt_level", type=str, default="O1", 17 | help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
18 | "See details at https://nvidia.github.io/apex/amp.html") 19 | parser.add_argument("--local_rank", type=int, default=-1, 20 | help="For distributed training: local_rank") 21 | parser.add_argument('--no-progress', action='store_true', 22 | help="disable the progress bar") 23 | parser.add_argument('--eval_only', type=int, default=0, 24 | help='set to 1 for evaluation-only mode') 25 | parser.add_argument('--num_classes', type=int, default=6, 26 | help='number of known classes (default 6 is for cifar10)') 27 | 28 | parser.add_argument('--out', default='result', 29 | help='directory to output the result') 30 | parser.add_argument('--resume', default='', type=str, 31 | help='path to latest checkpoint (default: none)') 32 | parser.add_argument('--root', default='./data', type=str, 33 | help='path to data directory') 34 | parser.add_argument('--dataset', default='cifar10', type=str, 35 | choices=['cifar10', 'cifar100', 'imagenet'], 36 | help='dataset name') 37 | ## Hyper-parameters 38 | parser.add_argument('--opt', default='sgd', type=str, 39 | choices=['sgd', 'adam'], 40 | help='optimizer name') 41 | parser.add_argument('--num-labeled', type=int, default=400, 42 | choices=[25, 50, 100, 400], 43 | help='number of labeled samples per class') 44 | parser.add_argument('--num_val', type=int, default=50, 45 | help='number of validation samples per class') 46 | parser.add_argument('--num-super', type=int, default=10, 47 | help='number of known super-classes for cifar100: 10 or 15') 48 | parser.add_argument("--expand-labels", action="store_true", 49 | help="expand labels to fit eval steps") 50 | parser.add_argument('--arch', default='wideresnet', type=str, 51 | choices=['wideresnet', 'resnext', 52 | 'resnet_imagenet'], 53 | help='model architecture') 54 | ## HP unique to OpenMatch (Some are changed from FixMatch) 55 | parser.add_argument('--lambda_oem', default=0.1, type=float, 56 | help='coefficient of OEM loss') 57 | parser.add_argument('--lambda_socr', default=0.5, type=float, 58 | help='coefficient of SOCR loss, 0.5 for
CIFAR10, ImageNet, ' 59 | '1.0 for CIFAR100') 60 | parser.add_argument('--start_fix', default=10, type=int, 61 | help='epoch to start fixmatch training') 62 | parser.add_argument('--mu', default=2, type=int, 63 | help='multiplier of the labeled batch size used for the unlabeled batch') 64 | parser.add_argument('--total-steps', default=2 ** 19, type=int, 65 | help='number of total steps to run') 66 | parser.add_argument('--epochs', default=512, type=int, 67 | help='number of epochs to run') 68 | parser.add_argument('--threshold', default=0.0, type=float, 69 | help='pseudo label threshold') 70 | ## 71 | parser.add_argument('--eval-step', default=1024, type=int, 72 | help='number of eval steps to run') 73 | 74 | parser.add_argument('--start-epoch', default=0, type=int, 75 | help='manual epoch number (useful on restarts)') 76 | parser.add_argument('--batch-size', default=64, type=int, 77 | help='train batch size') 78 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float, 79 | help='initial learning rate') 80 | parser.add_argument('--warmup', default=0, type=float, 81 | help='warmup epochs (unlabeled data based)') 82 | parser.add_argument('--wdecay', default=5e-4, type=float, 83 | help='weight decay') 84 | parser.add_argument('--nesterov', action='store_true', default=True, 85 | help='use nesterov momentum (enabled by default; this flag cannot turn it off)') 86 | parser.add_argument('--use-ema', action='store_true', default=True, 87 | help='use EMA model (enabled by default; this flag cannot turn it off)') 88 | parser.add_argument('--ema-decay', default=0.999, type=float, 89 | help='EMA decay rate') 90 | parser.add_argument('--T', default=1, type=float, 91 | help='pseudo label temperature') 92 | 93 | 94 | args = parser.parse_args() 95 | return args --------------------------------------------------------------------------------
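For a quick CPU-only sanity check of the open-set scoring in `test_ood`, the following NumPy sketch re-states it under a few assumptions: `outputs_open` is laid out so that `.view(B, 2, -1)` puts the two one-vs-all channels on axis 1, and channel 0 is the one that `test_ood` reads as `unk_score`. `softmax`, `unknown_score`, and `auroc` are hypothetical helpers written for illustration. In particular, `auroc` is a plain pairwise (Mann-Whitney) AUROC with OOD samples as the positive class, not the repository's `roc_id_ood` or `compute_roc`, whose code is not shown here.

```python
import numpy as np


def softmax(x, axis):
    """Numerically stable softmax along `axis`."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def unknown_score(logits_closed, logits_open):
    """NumPy re-statement of the scoring step in test_ood (sketch).

    logits_closed: (B, K) closed-set logits (`outputs`).
    logits_open:   (B, 2*K) open-set logits (`outputs_open`); reshaping to
                   (B, 2, K) matches `outputs_open.view(B, 2, -1)`.
    Returns, per sample, the channel-0 probability of the one-vs-all head
    belonging to the closed-set predicted class (`unk_score` in test_ood).
    """
    batch = logits_open.shape[0]
    probs = softmax(logits_open.reshape(batch, 2, -1), axis=1)
    pred_close = logits_closed.argmax(axis=1)
    # Advanced indexing: one (row, channel 0, predicted class) entry per sample.
    return probs[np.arange(batch), 0, pred_close]


def auroc(id_scores, ood_scores):
    """Hypothetical helper, not the repo's roc_id_ood / compute_roc.

    Pairwise AUROC with OOD as the positive class: the fraction of
    (OOD, ID) pairs in which the OOD sample has the higher unknown score,
    counting ties as 1/2. Uses O(N*M) memory, which is fine for small checks.
    """
    id_scores = np.asarray(id_scores, dtype=float)
    ood_scores = np.asarray(ood_scores, dtype=float)
    diff = ood_scores[:, None] - id_scores[None, :]
    wins = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return wins / diff.size


if __name__ == "__main__":
    closed = np.array([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]])
    open_ = np.array([[0.0, 0.0, 0.0, 2.0, 0.0, 0.0],
                      [0.0, 2.0, 0.0, 0.0, 0.0, 0.0]])
    print(unknown_score(closed, open_))   # approximately [0.119, 0.881]
    print(auroc([0.1, 0.2], [0.9, 0.15]))  # 0.75
```

Indexing with `np.arange(batch)` plays the role of `tmp_range` in `test_ood`: it picks exactly one entry per row, at the class predicted by the closed-set head, so the score reflects only that class's one-vs-all classifier.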