├── cifar ├── requirements.txt ├── scripts │ ├── test_cifar10.sh │ ├── test_cifar100.sh │ ├── ours_cifar10.sh │ └── ours_cifar100.sh ├── models │ ├── SSHead.py │ ├── ResNet.py │ ├── WideResNet.py │ ├── wide.py │ ├── BigResNet.py │ └── dm.py ├── utils │ ├── misc.py │ ├── offline.py │ ├── contrastive.py │ ├── aug.py │ ├── augmentation.py │ ├── test_helpers.py │ └── prepare_dataset.py ├── README.md ├── TEST.py └── OURS.py ├── imagenet ├── requirements.txt ├── scripts │ ├── test_r.sh │ ├── test_c.sh │ ├── ours_r.sh │ └── ours_c.sh ├── model │ └── resnet.py ├── utils │ ├── create_corruption_dataset.py │ ├── prepare_dataset.py │ ├── test_helpers.py │ └── offline.py ├── README.md ├── TEST.py └── OURS.py ├── imgs └── overview.png ├── LICENSE └── README.md /cifar/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision -------------------------------------------------------------------------------- /imagenet/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | imagenet_c -------------------------------------------------------------------------------- /imgs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yushu-Li/OWTTT/HEAD/imgs/overview.png -------------------------------------------------------------------------------- /imagenet/scripts/test_r.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | export PYTHONPATH=$PYTHONPATH:$(pwd) 4 | STRONG_OOD=$1 5 | 6 | 7 | python ./TEST.py \ 8 | --dataset ImageNet-R \ 9 | --dataroot ./data \ 10 | --strong_OOD ${STRONG_OOD} 11 | 12 | -------------------------------------------------------------------------------- /imagenet/scripts/test_c.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | export PYTHONPATH=$PYTHONPATH:$(pwd) 4 | CORRUPT=$1 5 | STRONG_OOD=$2 6 | 7 | 8 | python ./TEST.py \ 9 | --dataset ImageNet-C \ 10 | --dataroot ./data \ 11 | --strong_OOD ${STRONG_OOD} \ 12 | --corruption ${CORRUPT} 13 | 14 | 15 | -------------------------------------------------------------------------------- /imagenet/scripts/ours_r.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | export PYTHONPATH=$PYTHONPATH:$(pwd) 4 | 5 | STRONG_OOD=$1 6 | 7 | 8 | python ./OURS.py \ 9 | --dataset ImageNet-R \ 10 | --dataroot ./data \ 11 | --strong_OOD ${STRONG_OOD} \ 12 | --lr 0.001 \ 13 | --delta 0.1 \ 14 | --ce_scale 0.05 \ 15 | --da_scale 0.1 16 | 17 | -------------------------------------------------------------------------------- /cifar/scripts/test_cifar10.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | export PYTHONPATH=$PYTHONPATH:$(pwd) 4 | 5 | CORRUPT=$1 6 | STRONG_OOD=$2 7 | 8 | 9 | python TEST.py \ 10 | --dataset cifar10OOD \ 11 | --dataroot ./data \ 12 | --strong_OOD ${STRONG_OOD} \ 13 | --resume ./results/cifar10_joint_resnet50 \ 14 | --corruption ${CORRUPT} 15 | 16 | -------------------------------------------------------------------------------- /cifar/scripts/test_cifar100.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | 3 | export PYTHONPATH=$PYTHONPATH:$(pwd) 4 | 5 | 6 | CORRUPT=$1 7 | STRONG_OOD=$2 8 | 9 | 10 | 11 | python TEST.py \ 12 | --dataset cifar100OOD \ 13 | --dataroot ./data \ 14 | --strong_OOD ${STRONG_OOD} \ 15 | --resume ./results/cifar100_joint_resnet50 \ 16 | --corruption ${CORRUPT} 17 | 18 | -------------------------------------------------------------------------------- /imagenet/scripts/ours_c.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | export PYTHONPATH=$PYTHONPATH:$(pwd) 4 | 5 | CORRUPT=$1 6 | STRONG_OOD=$2 7 | 8 | 9 | python ./OURS.py \ 10 | --dataset ImageNet-C \ 11 | --dataroot ./data \ 12 | --strong_OOD ${STRONG_OOD} \ 13 | --corruption ${CORRUPT} \ 14 | --lr 0.001 \ 15 | --delta 0.1 \ 16 | --ce_scale 0.05 \ 17 | --da_scale 0.1 -------------------------------------------------------------------------------- /cifar/scripts/ours_cifar10.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | export PYTHONPATH=$PYTHONPATH:$(pwd) 4 | 5 | CORRUPT=$1 6 | STRONG_OOD=$2 7 | 8 | python OURS.py \ 9 | --dataset cifar10OOD \ 10 | --dataroot ./data \ 11 | --strong_OOD ${STRONG_OOD} \ 12 | --resume ./results/cifar10_joint_resnet50 \ 13 | --corruption ${CORRUPT} \ 14 | --lr 0.01 \ 15 | --delta 0.1 \ 16 | --da_scale 1 \ 17 | --ce_scale 0.2 18 | 19 | -------------------------------------------------------------------------------- /cifar/scripts/ours_cifar100.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | export PYTHONPATH=$PYTHONPATH:$(pwd) 4 | 5 | CORRUPT=$1 6 | STRONG_OOD=$2 7 | 8 | python OURS.py \ 9 | --dataset cifar100OOD \ 10 | --dataroot ./data \ 11 | --strong_OOD ${STRONG_OOD} \ 12 | --resume ./results/cifar100_joint_resnet50 \ 13 | --corruption ${CORRUPT} \ 14 | --lr 0.001 \ 15 | --delta 0.1 \ 16 | --da_scale 1 \ 17 | --ce_scale 0.2 18 | 19 | -------------------------------------------------------------------------------- /imagenet/model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchvision.models as models 3 | 4 | 5 | class SupCEResNet(nn.Module): 6 | """official Resnet for image classification, e.g., ImageNet""" 7 | def __init__(self, name='resnet50'): 8 | super(SupCEResNet, self).__init__() 9 | self.encoder = models.__dict__[name](pretrained=True) 10 | self.fc = self.encoder.fc 11 | self.encoder.fc = nn.Identity() 12 | 13 | def forward(self, x): 14 | return self.fc(self.encoder(x)) 15 | -------------------------------------------------------------------------------- /cifar/models/SSHead.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import math 3 | import copy 4 | 5 | class ViewFlatten(nn.Module): 6 | def __init__(self): 7 | super(ViewFlatten, self).__init__() 8 | 9 | def forward(self, x): 10 | return x.view(x.size(0), -1) 11 | 12 | class ExtractorHead(nn.Module): 13 | def __init__(self, ext, head): 14 | super(ExtractorHead, self).__init__() 15 | self.ext = ext 16 | self.head = head 17 | 18 | def forward(self, x): 19 | return self.head(self.ext(x)) 20 | 21 | def extractor_from_layer3(net): 22 | layers = [net.conv1, net.layer1, net.layer2, net.layer3, net.bn, net.relu, net.avgpool, ViewFlatten()] 23 | return nn.Sequential(*layers) 24 | 25 | def extractor_from_layer2(net): 26 | layers = [net.conv1, net.layer1, 
net.layer2] 27 | return nn.Sequential(*layers) 28 | 29 | def head_on_layer2(net, width, classes): 30 | head = copy.deepcopy([net.layer3, net.bn, net.relu, net.avgpool]) 31 | head.append(ViewFlatten()) 32 | head.append(nn.Linear(64 * width, classes)) 33 | return nn.Sequential(*head) 34 | 35 | def task_head_on_layer3(net): 36 | layers = [net.fc] 37 | return nn.Sequential(*layers) 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yushu-Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OWTTT 2 | 3 | This repository is an official implementation for our [ICCV 2023 Oral] paper. 4 | 5 | ## On the Robustness of Open-World Test-Time Training: Self-Training with Dynamic Prototype Expansion 6 | 7 | **[Yushu Li](https://yushu-li.github.io/)1**   **[Xun Xu](https://alex-xun-xu.github.io/)2**   **[Yongyi Su](https://yysu.site/)1**   **[Kui Jia](http://kuijia.site/)1** 8 |
9 | 1South China University of Technology   10 |
2Institute for Infocomm Research (I2R), Agency for Science, Technology and Research (A*STAR) 11 |
12 | 13 | [![arXiv preprint](http://img.shields.io/badge/arXiv-2308.09942-b31b1b)](https://arxiv.org/abs/2308.09942) 14 | [![Project Page](http://img.shields.io/badge/Project%20Page-OWTTT-brightgreen)](https://yushu-li.github.io/owttt-site/) 15 | 16 | 17 | ### Overview 18 | 19 | ![](./imgs/overview.png) 20 | 21 | 22 | ### CIFAR10-C/CIFAR100-C 23 | 24 | The code is released in the [cifar](cifar) folder. 25 | 26 | ### ImageNet-C/ImageNet-R 27 | 28 | The code is released in the [imagenet](imagenet) folder. 29 | 30 | ### Citation 31 | 32 | If you find our work useful in your research, please consider citing: 33 | 34 | ```bibtex 35 | @inproceedings{ 36 | li2023robustness, 37 | title={On the Robustness of Open-World Test-Time Training: Self-Training with Dynamic Prototype Expansion}, 38 | author={Li, Yushu and Xu, Xun and Su, Yongyi and Jia, Kui}, 39 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 40 | month={October}, 41 | year={2023} 42 | } 43 | ``` -------------------------------------------------------------------------------- /imagenet/utils/create_corruption_dataset.py: -------------------------------------------------------------------------------- 1 | from imagenet_c import * 2 | from torchvision.datasets import ImageNet 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import DataLoader 5 | import os 6 | import torch 7 | import gorilla 8 | 9 | DATA_ROOT = '/cluster/sc_download/li.yushu/imagenet_ttac' 10 | CORRUPTION_PATH = './corruption' 11 | 12 | 13 | corruption_tuple = (gaussian_noise, shot_noise, impulse_noise, defocus_blur, 14 | glass_blur, motion_blur, zoom_blur, snow, frost, fog, 15 | brightness, contrast, elastic_transform, pixelate, jpeg_compression) 16 | 17 | corruption_dict = {corr_func.__name__: corr_func for corr_func in corruption_tuple} 18 | 19 | class corrupt(object): 20 | def __init__(self, corruption_name, severity=5): 21 | self.corruption_name = corruption_name 22 | self.severity = severity 23 | return 24 | 25 | def __call__(self, x): 26 | # x: PIL.Image 27 | x_corrupted = corruption_dict[self.corruption_name](x, self.severity) 28 | return np.uint8(x_corrupted) 29 | 30 | def __repr__(self): 31 | return "Corruption(name=" + self.corruption_name + ", severity=" + str(self.severity) + ")" 32 | 33 | 34 | print(os.path.join(DATA_ROOT, CORRUPTION_PATH)) 35 | if os.path.exists(os.path.join(DATA_ROOT, CORRUPTION_PATH)) is False: 36 | os.mkdir(os.path.join(DATA_ROOT, CORRUPTION_PATH)) 37 | 38 | 39 | 40 | for corruption in corruption_dict.keys(): 41 | if os.path.exists(os.path.join(DATA_ROOT, CORRUPTION_PATH, corruption + '.pth')): 42 | continue 43 | print(corruption) 44 | val_transform = transforms.Compose([ 45 | transforms.Resize(256), 46 | transforms.CenterCrop(224), 47 | corrupt(corruption, 5) 48 | ]) 49 | 50 | target_dataset = ImageNet(DATA_ROOT, 'val', transform=val_transform) 51 | 52 | target_dataloader = DataLoader(target_dataset, batch_size=256, shuffle=False, drop_last=False, num_workers=16) 53 | 54 | datas = [] 55 | for batch in gorilla.track(target_dataloader): 56 | datas.append(batch[0]) 57 | datas = torch.cat(datas) 58 | torch.save(datas, os.path.join(DATA_ROOT, CORRUPTION_PATH, corruption + '.pth')) 59 | 60 | 61 | -------------------------------------------------------------------------------- /cifar/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | from colorama import Fore 5 | 6 | def 
get_grad(params): 7 | if isinstance(params, torch.Tensor): 8 | params = [params] 9 | params = list(filter(lambda p: p.grad is not None, params)) 10 | grad = [p.grad.data.cpu().view(-1) for p in params] 11 | return torch.cat(grad) 12 | 13 | def write_to_txt(name, content): 14 | with open(name, 'w') as text_file: 15 | text_file.write(content) 16 | 17 | def my_makedir(name): 18 | try: 19 | os.makedirs(name) 20 | except OSError: 21 | pass 22 | 23 | def print_args(opt): 24 | for arg in vars(opt): 25 | print('%s %s' % (arg, getattr(opt, arg))) 26 | 27 | def mean(ls): 28 | return sum(ls) / len(ls) 29 | 30 | def normalize(v): 31 | return (v - v.mean()) / v.std() 32 | 33 | def flat_grad(grad_tuple): 34 | return torch.cat([p.view(-1) for p in grad_tuple]) 35 | 36 | def print_nparams(model): 37 | nparams = sum([param.nelement() for param in model.parameters()]) 38 | print('number of parameters: %d' % (nparams)) 39 | 40 | def print_color(color, string): 41 | print(getattr(Fore, color) + string + Fore.RESET) 42 | 43 | def freeze_params(model): 44 | for name, p in model.named_parameters(): 45 | p.requires_grad = False 46 | print("Freeze parameter until", name) 47 | 48 | def print_params(model): 49 | for name, p in model.named_parameters(): 50 | print(name) 51 | 52 | class AverageMeter(object): 53 | """Computes and stores the average and current value""" 54 | def __init__(self, name, fmt=':f'): 55 | self.name = name 56 | self.fmt = fmt 57 | self.reset() 58 | 59 | def reset(self): 60 | self.val = 0 61 | self.avg = 0 62 | self.sum = 0 63 | self.count = 0 64 | 65 | def update(self, val, n=1): 66 | self.val = val 67 | self.sum += val * n 68 | self.count += n 69 | self.avg = self.sum / self.count 70 | 71 | def __str__(self): 72 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 73 | return fmtstr.format(**self.__dict__) 74 | 75 | def adjust_learning_rate(args, optimizer, epoch): 76 | lr = args.lr 77 | 78 | eta_min = lr * (args.lr_decay_rate ** 3) 79 | lr = eta_min + (lr - eta_min) * ( 80 | 1 + math.cos(math.pi * epoch / args.nepoch)) / 2 81 | 82 | for param_group in optimizer.param_groups: 83 | param_group['lr'] = lr 84 | -------------------------------------------------------------------------------- /cifar/utils/offline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import statistics 3 | import os 4 | 5 | def covariance(features): 6 | assert len(features.size()) == 2, "TODO: multi-dimensional feature map covariance" 7 | n = features.shape[0] 8 | tmp = torch.ones((1, n), device=features.device) @ features 9 | cov = (features.t() @ features - (tmp.t() @ tmp) / n) / n 10 | return cov 11 | 12 | def coral(cs, ct): 13 | d = cs.shape[0] 14 | loss = (cs - ct).pow(2).sum() / (4. 
* d ** 2) 15 | return loss 16 | 17 | 18 | def linear_mmd(ms, mt): 19 | loss = (ms - mt).pow(2).mean() 20 | return loss 21 | 22 | def offline(args,trloader, ext, classifier, head, class_num=10): 23 | if class_num == 10: 24 | if os.path.exists(args.resume+'/offline_cifar10.pth'): 25 | data = torch.load(args.resume+'/offline_cifar10.pth') 26 | return data 27 | elif class_num == 100: 28 | if os.path.exists(args.resume+'/offline_cifar100.pth'): 29 | data = torch.load(args.resume+'/offline_cifar100.pth') 30 | return data 31 | else: 32 | raise Exception("This function only handles CIFAR10 and CIFAR100 datasets.") 33 | ext.eval() 34 | 35 | feat_stack = [[] for i in range(class_num)] 36 | ssh_feat_stack = [[] for i in range(class_num)] 37 | 38 | with torch.no_grad(): 39 | for batch_idx, (inputs, labels) in enumerate(trloader): 40 | 41 | feat = ext(inputs.cuda()) 42 | predict_logit = classifier(feat) 43 | ssh_feat = predict_logit 44 | 45 | pseudo_label = predict_logit.max(dim=1)[1] 46 | 47 | for label in pseudo_label.unique(): 48 | label_mask = pseudo_label == label 49 | feat_stack[label].extend(feat[label_mask, :]) 50 | ssh_feat_stack[label].extend(ssh_feat[label_mask, :]) 51 | ext_mu = [] 52 | ext_cov = [] 53 | ext_all = [] 54 | 55 | ssh_mu = [] 56 | ssh_cov = [] 57 | ssh_all = [] 58 | for feat in feat_stack: 59 | ext_mu.append(torch.stack(feat).mean(dim=0)) 60 | ext_cov.append(covariance(torch.stack(feat))) 61 | ext_all.extend(feat) 62 | 63 | for feat in ssh_feat_stack: 64 | ssh_mu.append(torch.stack(feat).mean(dim=0)) 65 | ssh_cov.append(covariance(torch.stack(feat))) 66 | ssh_all.extend(feat) 67 | 68 | ext_all = torch.stack(ext_all) 69 | ext_all_mu = ext_all.mean(dim=0) 70 | ext_all_cov = covariance(ext_all) 71 | 72 | ssh_all = torch.stack(ssh_all) 73 | ssh_all_mu = ssh_all.mean(dim=0) 74 | ssh_all_cov = covariance(ssh_all) 75 | if class_num == 10: 76 | torch.save((ext_mu, ext_cov, ssh_mu, ssh_cov, ext_all_mu, ext_all_cov, ssh_all_mu, ssh_all_cov), args.resume+'/offline_cifar10.pth') 77 | if class_num == 100: 78 | torch.save((ext_mu, ext_cov, ssh_mu, ssh_cov, ext_all_mu, ext_all_cov, ssh_all_mu, ssh_all_cov), args.resume+'/offline_cifar100.pth') 79 | return ext_mu, ext_cov, ssh_mu, ssh_cov, ext_all_mu, ext_all_cov, ssh_all_mu, ssh_all_cov 80 | 81 | 82 | -------------------------------------------------------------------------------- /cifar/models/ResNet.py: -------------------------------------------------------------------------------- 1 | # Based on the ResNet implementation in torchvision 2 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 3 | 4 | import math 5 | import torch 6 | from torch import nn 7 | from torchvision.models.resnet import conv3x3 8 | 9 | class BasicBlock(nn.Module): 10 | def __init__(self, inplanes, planes, norm_layer, stride=1, downsample=None): 11 | super(BasicBlock, self).__init__() 12 | self.downsample = downsample 13 | self.stride = stride 14 | 15 | self.bn1 = norm_layer(inplanes) 16 | self.relu1 = nn.ReLU(inplace=True) 17 | self.conv1 = conv3x3(inplanes, planes, stride) 18 | 19 | self.bn2 = norm_layer(planes) 20 | self.relu2 = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | 23 | def forward(self, x): 24 | residual = x 25 | residual = self.bn1(residual) 26 | residual = self.relu1(residual) 27 | residual = self.conv1(residual) 28 | 29 | residual = self.bn2(residual) 30 | residual = self.relu2(residual) 31 | residual = self.conv2(residual) 32 | 33 | if self.downsample is not None: 34 | x = self.downsample(x) 35 | return 
x + residual 36 | 37 | class Downsample(nn.Module): 38 | def __init__(self, nIn, nOut, stride): 39 | super(Downsample, self).__init__() 40 | self.avg = nn.AvgPool2d(stride) 41 | assert nOut % nIn == 0 42 | self.expand_ratio = nOut // nIn 43 | 44 | def forward(self, x): 45 | x = self.avg(x) 46 | return torch.cat([x] + [x.mul(0)] * (self.expand_ratio - 1), 1) 47 | 48 | class ResNetCifar(nn.Module): 49 | def __init__(self, depth, width=1, classes=10, channels=3, norm_layer=nn.BatchNorm2d, detach=None): 50 | assert (depth - 2) % 6 == 0 # depth is 6N+2 51 | self.N = (depth - 2) // 6 52 | super(ResNetCifar, self).__init__() 53 | 54 | # Following the Wide ResNet convention, we fix the very first convolution 55 | self.conv1 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) 56 | self.inplanes = 16 57 | self.layer1 = self._make_layer(norm_layer, 16 * width) 58 | self.layer2 = self._make_layer(norm_layer, 32 * width, stride=2) 59 | self.layer3 = self._make_layer(norm_layer, 64 * width, stride=2) 60 | self.bn = norm_layer(64 * width) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.avgpool = nn.AvgPool2d(8) 63 | self.fc = nn.Linear(64 * width, classes) 64 | 65 | # Task-agnostic encoder 66 | self.detach = detach 67 | 68 | # Initialization 69 | for m in self.modules(): 70 | if isinstance(m, nn.Conv2d): 71 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 72 | m.weight.data.normal_(0, math.sqrt(2. / n)) 73 | 74 | def _make_layer(self, norm_layer, planes, stride=1): 75 | downsample = None 76 | if stride != 1 or self.inplanes != planes: 77 | downsample = Downsample(self.inplanes, planes, stride) 78 | layers = [BasicBlock(self.inplanes, planes, norm_layer, stride, downsample)] 79 | self.inplanes = planes 80 | for i in range(self.N - 1): 81 | layers.append(BasicBlock(self.inplanes, planes, norm_layer)) 82 | return nn.Sequential(*layers) 83 | 84 | def forward(self, x): 85 | x = self.conv1(x) 86 | x = self.layer1(x) 87 | x = self.layer2(x) 88 | if self.detach == 'layer2': x = x.detach() 89 | x = self.layer3(x) 90 | x = self.bn(x) 91 | x = self.relu(x) 92 | x = self.avgpool(x) 93 | x = x.view(x.size(0), -1) 94 | if self.detach == 'layer3': x = x.detach() 95 | x = self.fc(x) 96 | return x 97 | -------------------------------------------------------------------------------- /cifar/utils/contrastive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SupConLoss(nn.Module): 6 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 7 | It also supports the unsupervised contrastive loss in SimCLR""" 8 | def __init__(self, temperature=0.07, contrast_mode='all', 9 | base_temperature=0.07): 10 | super(SupConLoss, self).__init__() 11 | self.temperature = temperature 12 | self.contrast_mode = contrast_mode 13 | self.base_temperature = base_temperature 14 | 15 | def forward(self, features, labels=None, mask=None): 16 | """Compute loss for model. If both `labels` and `mask` are None, 17 | it degenerates to SimCLR unsupervised loss: 18 | https://arxiv.org/pdf/2002.05709.pdf 19 | 20 | Args: 21 | features: hidden vector of shape [bsz, n_views, ...]. 22 | labels: ground truth of shape [bsz]. 23 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 24 | has the same class as sample i. Can be asymmetric. 25 | Returns: 26 | A loss scalar. 
27 | """ 28 | device = (torch.device('cuda') 29 | if features.is_cuda 30 | else torch.device('cpu')) 31 | 32 | if len(features.shape) < 3: 33 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 34 | 'at least 3 dimensions are required') 35 | if len(features.shape) > 3: 36 | features = features.view(features.shape[0], features.shape[1], -1) 37 | 38 | batch_size = features.shape[0] 39 | if labels is not None and mask is not None: 40 | raise ValueError('Cannot define both `labels` and `mask`') 41 | elif labels is None and mask is None: 42 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 43 | elif labels is not None: 44 | labels = labels.contiguous().view(-1, 1) 45 | if labels.shape[0] != batch_size: 46 | raise ValueError('Num of labels does not match num of features') 47 | mask = torch.eq(labels, labels.T).float().to(device) 48 | else: 49 | mask = mask.float().to(device) 50 | 51 | contrast_count = features.shape[1] 52 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 53 | if self.contrast_mode == 'one': 54 | anchor_feature = features[:, 0] 55 | anchor_count = 1 56 | elif self.contrast_mode == 'all': 57 | anchor_feature = contrast_feature 58 | anchor_count = contrast_count 59 | else: 60 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 61 | 62 | # compute logits 63 | anchor_dot_contrast = torch.div( 64 | torch.matmul(anchor_feature, contrast_feature.T), 65 | self.temperature) 66 | # for numerical stability 67 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 68 | logits = anchor_dot_contrast - logits_max.detach() 69 | 70 | # tile mask 71 | mask = mask.repeat(anchor_count, contrast_count) 72 | # mask-out self-contrast cases 73 | logits_mask = torch.scatter( 74 | torch.ones_like(mask), 75 | 1, 76 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 77 | 0 78 | ) 79 | 80 | mask = mask * logits_mask 81 | 82 | # compute log_prob 83 | exp_logits = torch.exp(logits) * logits_mask 84 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 85 | 86 | # compute mean of log-likelihood over positive 87 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 88 | 89 | # loss 90 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 91 | loss = loss.view(anchor_count, batch_size).mean() 92 | 93 | return loss 94 | -------------------------------------------------------------------------------- /cifar/README.md: -------------------------------------------------------------------------------- 1 | # OWTTT on CIFAR10-C/100-C 2 | 3 | We evaluate our method and the baseline method TEST (direct test without adaptation) on CIFAR-10-C/100-C under common corruptions or natural shifts. Our implementation is based on [repo](https://github.com/Gorilla-Lab-SCUT/TTAC/tree/master/cifar) and therefore requires some similar preparation processes.
4 | 5 | 6 | ### Requirements 7 | 8 | To install requirements: 9 | 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | To download datasets: 15 | 16 | ``` 17 | export DATADIR=/data/cifar 18 | mkdir -p ${DATADIR} && cd ${DATADIR} 19 | wget -O CIFAR-10-C.tar https://zenodo.org/record/2535967/files/CIFAR-10-C.tar?download=1 20 | tar -xvf CIFAR-10-C.tar 21 | wget -O CIFAR-100-C.tar https://zenodo.org/record/3555552/files/CIFAR-100-C.tar?download=1 22 | tar -xvf CIFAR-100-C.tar 23 | wget -O tiny-imagenet-200.zip http://cs231n.stanford.edu/tiny-imagenet-200.zip 24 | unzip tiny-imagenet-200.zip 25 | ``` 26 | 27 | ### Pre-trained Models 28 | 29 | The checkpoints of the pre-trained ResNet-50 models can be downloaded (214MB) using the following commands: 30 | 31 | ``` 32 | mkdir -p results/cifar10_joint_resnet50 && cd results/cifar10_joint_resnet50 33 | gdown https://drive.google.com/uc?id=1QWyI8UrXJ6_H9lBbrq52qXWpjdpq4PUn && cd ../.. 34 | mkdir -p results/cifar100_joint_resnet50 && cd results/cifar100_joint_resnet50 35 | gdown https://drive.google.com/uc?id=1cau93HVjl4aWuZlrl7cJIMEKBxXXunR9 && cd ../.. 36 | ``` 37 | 38 | These models are obtained by training on the clean CIFAR10/100 images using semi-supervised SimCLR. 39 | 40 | ### Open-World Test-Time Training: 41 | 42 | We present our method and the baseline method TEST (direct test without adaptation) on CIFAR10-C/100-C. 43 | 44 | - run the OURS method or the baseline method TEST on CIFAR10-C under the OWTTT protocol. 45 | 46 | ``` 47 | # OURS: 48 | bash scripts/ours_cifar10.sh "corruption_type" "strong_ood_type" 49 | 50 | # TEST: 51 | bash scripts/test_cifar10.sh "corruption_type" "strong_ood_type" 52 | ``` 53 | Where "corruption_type" is the corruption type in CIFAR10-C, and "strong_ood_type" is the strong OOD type in [noise, MNIST, SVHN, Tiny, cifar100]. 54 | 55 | For example, to run OURS or TEST on CIFAR10-C under the snow corruption with MNIST as strong OOD, we can use the following command: 56 | 57 | ``` 58 | # OURS: 59 | bash scripts/ours_cifar10.sh snow MNIST 60 | 61 | # TEST: 62 | bash scripts/test_cifar10.sh snow MNIST 63 | ``` 64 | 65 | The following results (%) are yielded by the above scripts under the snow corruption with MNIST as strong OOD: 66 | 67 | | Method | ACC_S | ACC_N | ACC_H | 68 | |:------:|:-------:|:-------:|:-------:| 69 | | TEST | 66.36 | 91.56 | 76.95 | 70 | | OURS | 84.05 | 97.46 | 90.26 | 71 | 72 | - run the OURS method or the baseline method TEST on CIFAR100-C under the OWTTT protocol. 73 | 74 | ``` 75 | # OURS: 76 | bash scripts/ours_cifar100.sh "corruption_type" "strong_ood_type" 77 | 78 | # TEST: 79 | bash scripts/test_cifar100.sh "corruption_type" "strong_ood_type" 80 | ``` 81 | Where "corruption_type" is the corruption type in CIFAR100-C, and "strong_ood_type" is the strong OOD type in [noise, MNIST, SVHN, Tiny, cifar10].
82 | 83 | For example, to run OURS or TEST on CIFAR100-C under the snow corruption with MNIST as strong OOD, we can use the following command: 84 | 85 | ``` 86 | # OURS: 87 | bash scripts/ours_cifar100.sh snow MNIST 88 | 89 | # TEST: 90 | bash scripts/test_cifar100.sh snow MNIST 91 | ``` 92 | 93 | The following results (%) are yielded by the above scripts under the snow corruption with MNIST as strong OOD: 94 | 95 | | Method | ACC_S | ACC_N | ACC_H | 96 | |:------:|:-------:|:-------:|:-------:| 97 | | TEST | 29.2 | 53.27 | 37.72 | 98 | | OURS | 44.78 | 93.56 | 60.57 | 99 | 100 | 101 | ### Acknowledgements 102 | 103 | Our code is built upon the public code of the [TTAC](https://github.com/Gorilla-Lab-SCUT/TTAC/tree/master/cifar). 104 | -------------------------------------------------------------------------------- /imagenet/README.md: -------------------------------------------------------------------------------- 1 | # OWTTT on ImageNet-C/R 2 | 3 | We evaluate our method and the baseline method TEST (direct test without adaptation) on ImageNet-C/ImageNet-R under the OWTTT protocol. Our implementation is based on [repo](https://github.com/Gorilla-Lab-SCUT/TTAC/tree/master/imagenet) and therefore requires some similar preparation processes. 4 | 5 | ### Requirements 6 | 7 | - To install requirements: 8 | 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | - To download the ImageNet dataset: 14 | 15 | We first need to download the validation set and the development kit (Task 1 & 2) of ImageNet-1k from [here](https://image-net.org/challenges/LSVRC/2012/index.php), and put them under the `data` folder. 16 | 17 | - To download the ImageNet-R dataset: 18 | 19 | To download datasets: 20 | 21 | ``` 22 | export DATADIR=/data 23 | cd ${DATADIR} 24 | wget -O imagenet-r.tar https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar 25 | tar -xvf imagenet-r.tar 26 | ``` 27 | 28 | - To create the corruption dataset 29 | ``` 30 | python utils/create_corruption_dataset.py 31 | ``` 32 | 33 | The issue `Frost missing after pip install` can be solved by following [here](https://github.com/hendrycks/robustness/issues/4#issuecomment-427226016). 34 | 35 | Finally, the structure of the `data` folder should look like 36 | ``` 37 | data 38 | |_ ILSVRC2012_devkit_t12.tar 39 | |_ ILSVRC2012_img_val.tar 40 | |_ val 41 | |_ n01440764 42 | |_ ... 43 | |_ imagenet-r 44 | |_ n01443537 45 | |_ ... 46 | |_ corruption 47 | |_ brightness.pth 48 | |_ contrast.pth 49 | |_ ... 50 | |_ meta.bin 51 | ``` 52 | 53 | ### Pre-trained Models 54 | 55 | Here, we use the pre-trained model provided by torchvision. 56 | 57 | ### Open-World Test-Time Training: 58 | 59 | We present our method and the baseline method TEST (direct test without adaptation) on ImageNet-C/R. 60 | 61 | - run the OURS method or the baseline method TEST on ImageNet-C under the OWTTT protocol. 62 | 63 | ``` 64 | # OURS: 65 | bash scripts/ours_c.sh "corruption_type" "strong_ood_type" 66 | 67 | # TEST: 68 | bash scripts/test_c.sh "corruption_type" "strong_ood_type" 69 | ``` 70 | Where "corruption_type" is the corruption type in ImageNet-C, and "strong_ood_type" is the strong OOD type in [noise, MNIST, SVHN].
71 | 72 | For example, to run OURS or TEST on ImageNet-C under the snow corruption with MNIST as strong OOD, we can use the following command: 73 | 74 | ``` 75 | # OURS: 76 | bash scripts/ours_c.sh snow MNIST 77 | 78 | # TEST: 79 | bash scripts/test_c.sh snow MNIST 80 | ``` 81 | 82 | The following results (%) are yielded by the above scripts under the snow corruption with MNIST as strong OOD: 83 | 84 | | Method | ACC_S | ACC_N | ACC_H | 85 | |:------:|:-------:|:-------:|:-------:| 86 | | TEST | 17.30 | 99.35 | 29.47 | 87 | | OURS | 45.34 | 100.00 | 62.39 | 88 | 89 | - run the OURS method or the baseline method TEST on ImageNet-R under the OWTTT protocol. 90 | 91 | ``` 92 | # OURS: 93 | bash scripts/ours_r.sh "strong_ood_type" 94 | 95 | # TEST: 96 | bash scripts/test_r.sh "strong_ood_type" 97 | ``` 98 | Where "strong_ood_type" is the strong OOD type in [noise, MNIST, SVHN]. 99 | 100 | For example, to run OURS or TEST on ImageNet-R with MNIST as strong OOD, we can use the following command: 101 | 102 | ``` 103 | # OURS: 104 | bash scripts/ours_r.sh MNIST 105 | 106 | # TEST: 107 | bash scripts/test_r.sh MNIST 108 | ``` 109 | 110 | The following results (%) are yielded by the above scripts with MNIST as strong OOD: 111 | 112 | | Method | ACC_S | ACC_N | ACC_H | 113 | |:------:|:-------:|:-------:|:-------:| 114 | | TEST | 35.50 | 99.96 | 52.39 | 115 | | OURS | 41.40 | 100.00 | 58.56 | 116 | 117 | 118 | ### Acknowledgements 119 | 120 | Our code is built upon the public code of the [TTAC](https://github.com/Gorilla-Lab-SCUT/TTAC/tree/master/imagenet). 121 | -------------------------------------------------------------------------------- /cifar/utils/aug.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from PIL import ImageOps, Image 6 | from torchvision import transforms 7 | 8 | 9 | ## https://github.com/google-research/augmix 10 | 11 | def _augmix_aug(x_orig): 12 | x_orig = preaugment(x_orig) 13 | x_processed = preprocess(x_orig) 14 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0])) 15 | m = np.float32(np.random.beta(1.0, 1.0)) 16 | 17 | mix = torch.zeros_like(x_processed) 18 | for i in range(3): 19 | x_aug = x_orig.copy() 20 | for _ in range(np.random.randint(1, 4)): 21 | x_aug = np.random.choice(augmentations)(x_aug) 22 | mix += w[i] * preprocess(x_aug) 23 | mix = m * x_processed + (1 - m) * mix 24 | return mix 25 | 26 | aug = _augmix_aug 27 | 28 | 29 | def autocontrast(pil_img, level=None): 30 | return ImageOps.autocontrast(pil_img) 31 | 32 | def equalize(pil_img, level=None): 33 | return ImageOps.equalize(pil_img) 34 | 35 | def rotate(pil_img, level): 36 | degrees = int_parameter(rand_lvl(level), 30) 37 | if np.random.uniform() > 0.5: 38 | degrees = -degrees 39 | return pil_img.rotate(degrees, resample=Image.BILINEAR, fillcolor=128) 40 | 41 | def solarize(pil_img, level): 42 | level = int_parameter(rand_lvl(level), 256) 43 | return ImageOps.solarize(pil_img, 256 - level) 44 | 45 | def shear_x(pil_img, level): 46 | level = float_parameter(rand_lvl(level), 0.3) 47 | if np.random.uniform() > 0.5: 48 | level = -level 49 | return pil_img.transform((32, 32), Image.AFFINE, (1, level, 0, 0, 1, 0), resample=Image.BILINEAR, fillcolor=128) 50 | 51 | def shear_y(pil_img, level): 52 | level = float_parameter(rand_lvl(level), 0.3) 53 | if np.random.uniform() > 0.5: 54 | level = -level 55 | return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, level, 1, 0),
resample=Image.BILINEAR, fillcolor=128) 56 | 57 | def translate_x(pil_img, level): 58 | level = int_parameter(rand_lvl(level), 32 / 3) 59 | if np.random.random() > 0.5: 60 | level = -level 61 | return pil_img.transform((32, 32), Image.AFFINE, (1, 0, level, 0, 1, 0), resample=Image.BILINEAR, fillcolor=128) 62 | 63 | def translate_y(pil_img, level): 64 | level = int_parameter(rand_lvl(level), 32 / 3) 65 | if np.random.random() > 0.5: 66 | level = -level 67 | return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, 0, 1, level), resample=Image.BILINEAR, fillcolor=128) 68 | 69 | def posterize(pil_img, level): 70 | level = int_parameter(rand_lvl(level), 4) 71 | return ImageOps.posterize(pil_img, 4 - level) 72 | 73 | 74 | def int_parameter(level, maxval): 75 | """Helper function to scale `val` between 0 and maxval . 76 | Args: 77 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 78 | maxval: Maximum value that the operation can have. This will be scaled 79 | to level/PARAMETER_MAX. 80 | Returns: 81 | An int that results from scaling `maxval` according to `level`. 82 | """ 83 | return int(level * maxval / 10) 84 | 85 | def float_parameter(level, maxval): 86 | """Helper function to scale `val` between 0 and maxval . 87 | Args: 88 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 89 | maxval: Maximum value that the operation can have. This will be scaled 90 | to level/PARAMETER_MAX. 91 | Returns: 92 | A float that results from scaling `maxval` according to `level`. 93 | """ 94 | return float(level) * maxval / 10. 95 | 96 | def rand_lvl(n): 97 | return np.random.uniform(low=0.1, high=n) 98 | 99 | 100 | augmentations = [ 101 | autocontrast, 102 | equalize, 103 | lambda x: rotate(x, 1), 104 | lambda x: solarize(x, 1), 105 | lambda x: shear_x(x, 1), 106 | lambda x: shear_y(x, 1), 107 | lambda x: translate_x(x, 1), 108 | lambda x: translate_y(x, 1), 109 | lambda x: posterize(x, 1), 110 | ] 111 | 112 | mean = [0.5, 0.5, 0.5] 113 | std = [0.5, 0.5, 0.5] 114 | preprocess = transforms.Compose([ 115 | transforms.ToTensor(), 116 | transforms.Normalize(mean, std) 117 | ]) 118 | preaugment = transforms.Compose([ 119 | transforms.RandomCrop(32, padding=4), 120 | transforms.RandomHorizontalFlip(), 121 | ]) -------------------------------------------------------------------------------- /cifar/models/WideResNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate 
> 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | 40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 41 | layers = [] 42 | for i in range(int(nb_layers)): 43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | return self.layer(x) 48 | 49 | 50 | class WideResNet(nn.Module): 51 | """ Based on code from https://github.com/yaodongyu/TRADES """ 52 | def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True): 53 | super(WideResNet, self).__init__() 54 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 55 | assert ((depth - 4) % 6 == 0) 56 | n = (depth - 4) / 6 57 | block = BasicBlock 58 | # 1st conv before any network block 59 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 60 | padding=1, bias=False) 61 | # 1st block 62 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 63 | if sub_block1: 64 | # 1st sub-block 65 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 66 | # 2nd block 67 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 68 | # 3rd block 69 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 70 | # global average pooling and classifier 71 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last) 74 | self.nChannels = nChannels[3] 75 | 76 | for m in self.modules(): 77 | if isinstance(m, nn.Conv2d): 78 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 79 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 80 | elif isinstance(m, nn.BatchNorm2d): 81 | m.weight.data.fill_(1) 82 | m.bias.data.zero_() 83 | elif isinstance(m, nn.Linear) and not m.bias is None: 84 | m.bias.data.zero_() 85 | 86 | def forward(self, x): 87 | out = self.conv1(x) 88 | out = self.block1(out) 89 | out = self.block2(out) 90 | out = self.block3(out) 91 | out = self.relu(self.bn1(out)) 92 | out = F.avg_pool2d(out, 8) 93 | out = out.view(-1, self.nChannels) 94 | return self.fc(out) 95 | -------------------------------------------------------------------------------- /cifar/models/wide.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | 40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 41 | layers = [] 42 | for i in range(int(nb_layers)): 43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | return self.layer(x) 48 | 49 | 50 | class WideResNet(nn.Module): 51 | """ Based on code from https://github.com/yaodongyu/TRADES """ 52 | def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True): 53 | super(WideResNet, self).__init__() 54 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 55 | assert ((depth - 4) % 6 == 0) 56 | n = (depth - 4) / 6 57 | block = BasicBlock 58 | # 1st conv before any network block 59 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 60 | padding=1, bias=False) 61 | # 1st block 62 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 63 | if sub_block1: 64 | # 1st sub-block 65 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 66 | # 2nd block 67 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 68 | # 3rd block 69 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 70 | # global average 
pooling and classifier 71 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 72 | self.relu = nn.ReLU(inplace=True) 73 | # self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last) 74 | self.nChannels = nChannels[3] 75 | self.num_out = self.nChannels 76 | 77 | for m in self.modules(): 78 | if isinstance(m, nn.Conv2d): 79 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 80 | m.weight.data.normal_(0, math.sqrt(2. / n)) 81 | elif isinstance(m, nn.BatchNorm2d): 82 | m.weight.data.fill_(1) 83 | m.bias.data.zero_() 84 | elif isinstance(m, nn.Linear) and not m.bias is None: 85 | m.bias.data.zero_() 86 | 87 | def forward(self, x): 88 | out = self.conv1(x) 89 | out = self.block1(out) 90 | out = self.block2(out) 91 | out = self.block3(out) 92 | out = self.relu(self.bn1(out)) 93 | out = F.avg_pool2d(out, 8) 94 | out = out.view(-1, self.nChannels) 95 | # return self.fc(out) 96 | return out 97 | -------------------------------------------------------------------------------- /cifar/utils/augmentation.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from 2 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 3 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py 4 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py 5 | import logging 6 | import random 7 | 8 | import numpy as np 9 | import PIL 10 | import PIL.ImageOps 11 | import PIL.ImageEnhance 12 | import PIL.ImageDraw 13 | from PIL import Image 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | PARAMETER_MAX = 10 18 | 19 | 20 | def AutoContrast(img, **kwarg): 21 | return PIL.ImageOps.autocontrast(img) 22 | 23 | 24 | def Brightness(img, v, max_v, bias=0): 25 | v = _float_parameter(v, max_v) + bias 26 | return PIL.ImageEnhance.Brightness(img).enhance(v) 27 | 28 | 29 | def Color(img, v, max_v, bias=0): 30 | v = _float_parameter(v, max_v) + bias 31 | return PIL.ImageEnhance.Color(img).enhance(v) 32 | 33 | 34 | def Contrast(img, v, max_v, bias=0): 35 | v = _float_parameter(v, max_v) + bias 36 | return PIL.ImageEnhance.Contrast(img).enhance(v) 37 | 38 | 39 | def Cutout(img, v, max_v, bias=0): 40 | if v == 0: 41 | return img 42 | v = _float_parameter(v, max_v) + bias 43 | v = int(v * min(img.size)) 44 | return CutoutAbs(img, v) 45 | 46 | 47 | def CutoutAbs(img, v, **kwarg): 48 | w, h = img.size 49 | x0 = np.random.uniform(0, w) 50 | y0 = np.random.uniform(0, h) 51 | x0 = int(max(0, x0 - v / 2.)) 52 | y0 = int(max(0, y0 - v / 2.)) 53 | x1 = int(min(w, x0 + v)) 54 | y1 = int(min(h, y0 + v)) 55 | xy = (x0, y0, x1, y1) 56 | # gray 57 | color = (127, 127, 127) 58 | img = img.copy() 59 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 60 | return img 61 | 62 | 63 | def Equalize(img, **kwarg): 64 | return PIL.ImageOps.equalize(img) 65 | 66 | 67 | def Identity(img, **kwarg): 68 | return img 69 | 70 | 71 | def Invert(img, **kwarg): 72 | return PIL.ImageOps.invert(img) 73 | 74 | 75 | def Posterize(img, v, max_v, bias=0): 76 | v = _int_parameter(v, max_v) + bias 77 | return PIL.ImageOps.posterize(img, v) 78 | 79 | 80 | def Rotate(img, v, max_v, bias=0): 81 | v = _int_parameter(v, max_v) + bias 82 | if random.random() < 0.5: 83 | v = -v 84 | return img.rotate(v) 85 | 86 | 87 | def Sharpness(img, v, max_v, bias=0): 88 | v = _float_parameter(v, max_v) + bias 89 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 90 | 91 | 92 | def ShearX(img, v, max_v, bias=0): 93 | v = 
_float_parameter(v, max_v) + bias 94 | if random.random() < 0.5: 95 | v = -v 96 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 97 | 98 | 99 | def ShearY(img, v, max_v, bias=0): 100 | v = _float_parameter(v, max_v) + bias 101 | if random.random() < 0.5: 102 | v = -v 103 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 104 | 105 | 106 | def Solarize(img, v, max_v, bias=0): 107 | v = _int_parameter(v, max_v) + bias 108 | return PIL.ImageOps.solarize(img, 256 - v) 109 | 110 | 111 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128): 112 | v = _int_parameter(v, max_v) + bias 113 | if random.random() < 0.5: 114 | v = -v 115 | img_np = np.array(img).astype(np.int) 116 | img_np = img_np + v 117 | img_np = np.clip(img_np, 0, 255) 118 | img_np = img_np.astype(np.uint8) 119 | img = Image.fromarray(img_np) 120 | return PIL.ImageOps.solarize(img, threshold) 121 | 122 | 123 | def TranslateX(img, v, max_v, bias=0): 124 | v = _float_parameter(v, max_v) + bias 125 | if random.random() < 0.5: 126 | v = -v 127 | v = int(v * img.size[0]) 128 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 129 | 130 | 131 | def TranslateY(img, v, max_v, bias=0): 132 | v = _float_parameter(v, max_v) + bias 133 | if random.random() < 0.5: 134 | v = -v 135 | v = int(v * img.size[1]) 136 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 137 | 138 | 139 | def _float_parameter(v, max_v): 140 | return float(v) * max_v / PARAMETER_MAX 141 | 142 | 143 | def _int_parameter(v, max_v): 144 | return int(v * max_v / PARAMETER_MAX) 145 | 146 | 147 | def fixmatch_augment_pool(): 148 | # FixMatch paper 149 | augs = [(AutoContrast, None, None), 150 | (Brightness, 0.9, 0.05), 151 | (Color, 0.9, 0.05), 152 | (Contrast, 0.9, 0.05), 153 | (Equalize, None, None), 154 | (Identity, None, None), 155 | (Posterize, 4, 4), 156 | (Rotate, 30, 0), 157 | (Sharpness, 0.9, 0.05), 158 | (ShearX, 0.3, 0), 159 | (ShearY, 0.3, 0), 160 | (Solarize, 256, 0), 161 | (TranslateX, 0.3, 0), 162 | (TranslateY, 0.3, 0)] 163 | return augs 164 | 165 | 166 | def my_augment_pool(): 167 | # Test 168 | augs = [(AutoContrast, None, None), 169 | (Brightness, 1.8, 0.1), 170 | (Color, 1.8, 0.1), 171 | (Contrast, 1.8, 0.1), 172 | (Cutout, 0.2, 0), 173 | (Equalize, None, None), 174 | (Invert, None, None), 175 | (Posterize, 4, 4), 176 | (Rotate, 30, 0), 177 | (Sharpness, 1.8, 0.1), 178 | (ShearX, 0.3, 0), 179 | (ShearY, 0.3, 0), 180 | (Solarize, 256, 0), 181 | (SolarizeAdd, 110, 0), 182 | (TranslateX, 0.45, 0), 183 | (TranslateY, 0.45, 0)] 184 | return augs 185 | 186 | 187 | class RandAugmentPC(object): 188 | def __init__(self, n, m): 189 | assert n >= 1 190 | assert 1 <= m <= 10 191 | self.n = n 192 | self.m = m 193 | self.augment_pool = my_augment_pool() 194 | 195 | def __call__(self, img): 196 | ops = random.choices(self.augment_pool, k=self.n) 197 | for op, max_v, bias in ops: 198 | prob = np.random.uniform(0.2, 0.8) 199 | if random.random() + prob >= 1: 200 | img = op(img, v=self.m, max_v=max_v, bias=bias) 201 | img = CutoutAbs(img, int(32*0.5)) 202 | return img 203 | 204 | 205 | class RandAugmentMC(object): 206 | def __init__(self, n, m): 207 | assert n >= 1 208 | assert 1 <= m <= 10 209 | self.n = n 210 | self.m = m 211 | self.augment_pool = fixmatch_augment_pool() 212 | 213 | def __call__(self, img): 214 | ops = random.choices(self.augment_pool, k=self.n) 215 | for op, max_v, bias in ops: 216 | v = np.random.randint(1, self.m) 217 | if random.random() < 0.5: 218 | img = op(img, v=v, 
max_v=max_v, bias=bias) 219 | img = CutoutAbs(img, int(32*0.5)) 220 | return img -------------------------------------------------------------------------------- /cifar/TEST.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.optim as optim 4 | import torch.utils.data as data 5 | 6 | from utils.misc import * 7 | from utils.test_helpers import * 8 | from utils.prepare_dataset import * 9 | 10 | # ---------------------------------- 11 | import copy 12 | import random 13 | import numpy as np 14 | from utils.contrastive import * 15 | from utils.offline import * 16 | from torch import nn 17 | import torch.nn.functional as F 18 | # ---------------------------------- 19 | 20 | 21 | def compute_os_variance(os, th): 22 | """ 23 | Compute Otsu's criterion: the weighted sum of within-class variances of the OOD scores split at a given threshold. 24 | 25 | Parameters: 26 | os : OOD score queue. 27 | th : Given threshold to separate weak and strong OOD samples. 28 | 29 | Returns: 30 | float: Weighted variance at the given threshold th. 31 | """ 32 | 33 | thresholded_os = np.zeros(os.shape) 34 | thresholded_os[os >= th] = 1 35 | 36 | # compute weights 37 | nb_pixels = os.size 38 | nb_pixels1 = np.count_nonzero(thresholded_os) 39 | weight1 = nb_pixels1 / nb_pixels 40 | weight0 = 1 - weight1 41 | 42 | # if one of the classes is empty, e.g. all scores are below or above the threshold, that threshold will not be considered 43 | # in the search for the best threshold 44 | if weight1 == 0 or weight0 == 0: 45 | return np.inf 46 | 47 | # find all scores belonging to each class 48 | val_pixels1 = os[thresholded_os == 1] 49 | val_pixels0 = os[thresholded_os == 0] 50 | 51 | # compute variance of these classes 52 | var0 = np.var(val_pixels0) if len(val_pixels0) > 0 else 0 53 | var1 = np.var(val_pixels1) if len(val_pixels1) > 0 else 0 54 | 55 | return weight0 * var0 + weight1 * var1 56 | 57 | 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('--dataset', default='cifar10OOD') 60 | parser.add_argument('--strong_OOD', default='noise') 61 | parser.add_argument('--strong_ratio', default=1, type=float) 62 | parser.add_argument('--dataroot', default="./data", help='path to dataset') 63 | parser.add_argument('--batch_size', default=256, type=int) 64 | parser.add_argument('--workers', default=4, type=int) 65 | parser.add_argument('--outf', help='folder to output log') 66 | parser.add_argument('--level', default=5, type=int) 67 | parser.add_argument('--N_m', default=512, type=int, help='queue length') 68 | parser.add_argument('--corruption', default='snow') 69 | parser.add_argument('--resume', default='/cluster/personal/code/TTT/TTAC-master/cifar/results/cifar10_joint_resnet50', help='directory of pretrained model') 70 | parser.add_argument('--model', default='resnet50', help='resnet50') 71 | parser.add_argument('--seed', default=0, type=int) 72 | 73 | 74 | # ----------- Args and Dataloader ------------ 75 | args = parser.parse_args() 76 | 77 | print(args) 78 | print('\n') 79 | 80 | 81 | 82 | 83 | class_num = 10 if args.dataset == 'cifar10OOD' else 100 84 | 85 | net, ext, head, ssh, classifier = build_resnet50(args) 86 | 87 | teset, _ = prepare_test_data(args) 88 | teloader = data.DataLoader(teset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, worker_init_fn=seed_worker, pin_memory=True, drop_last=False) 89 | 90 | # ------------------------------- 91 | print('Resuming from %s...'
%(args.resume)) 92 | 93 | load_resnet50(net, head, ssh, classifier, args) 94 | 95 | # ----------- Offline Feature Summarization ------------ 96 | args_align = copy.deepcopy(args) 97 | 98 | _, offlineloader = prepare_train_data(args_align) 99 | ext_src_mu, ext_src_cov, ssh_src_mu, ssh_src_cov, mu_src_ext, cov_src_ext, mu_src_ssh, cov_src_ssh = offline(args,offlineloader, ext, classifier, head, class_num) 100 | 101 | ext_src_mu = torch.stack(ext_src_mu) 102 | weak_prototype = F.normalize(ext_src_mu.clone()).cuda() 103 | 104 | torch.manual_seed(args.seed) 105 | random.seed(args.seed) 106 | np.random.seed(args.seed) 107 | torch.cuda.manual_seed(args.seed) 108 | torch.cuda.manual_seed_all(args.seed) 109 | 110 | # ----------- Open-World Test-time Training ------------ 111 | 112 | correct = [] 113 | unseen_correct= [] 114 | all_correct=[] 115 | cumulative_error = [] 116 | num_open = 0 117 | predicted_list=[] 118 | label_list=[] 119 | 120 | os_inference_queue = [] 121 | queue_length = args.N_m 122 | 123 | ema_total_n = 0. 124 | 125 | print('\n-----Test-Time Training with TEST-----') 126 | for te_idx, (te_inputs, te_labels) in enumerate(teloader): 127 | 128 | 129 | ####-------------------------- Test ----------------------------#### 130 | 131 | with torch.no_grad(): 132 | if isinstance(te_inputs,list): 133 | inputs = te_inputs[0].cuda() 134 | else: 135 | inputs = te_inputs.cuda() 136 | net.eval() 137 | feat_ext = ext(inputs) #b,2048 138 | logit = torch.mm(F.normalize(feat_ext), weak_prototype.t()) 139 | update = 1 140 | softmax_logit = logit.softmax(dim=-1) 141 | pro, predicted = softmax_logit.max(dim=-1) 142 | 143 | ood_score, max_index = logit.max(1) 144 | ood_score = 1-ood_score 145 | os_inference_queue.extend(ood_score.detach().cpu().tolist()) 146 | os_inference_queue = os_inference_queue[-queue_length:] 147 | 148 | threshold_range = np.arange(0,1,0.01) 149 | criterias = [compute_os_variance(np.array(os_inference_queue), th) for th in threshold_range] 150 | best_threshold = threshold_range[np.argmin(criterias)] 151 | unseen_mask = (ood_score > best_threshold) 152 | args.ts = best_threshold 153 | predicted[unseen_mask] = class_num 154 | 155 | one = torch.ones_like(te_labels)*class_num 156 | false = torch.ones_like(te_labels)*-1 157 | predicted = torch.where(predicted>class_num-1, one.cuda(), predicted) 158 | all_labels = torch.where(te_labels>class_num-1, one, te_labels) 159 | seen_labels = torch.where(te_labels>class_num-1, false, te_labels) 160 | unseen_labels = torch.where(te_labels>class_num-1, one, false) 161 | correct.append(predicted.cpu().eq(seen_labels)) 162 | unseen_correct.append(predicted.cpu().eq(unseen_labels)) 163 | all_correct.append(predicted.cpu().eq(all_labels)) 164 | num_open += torch.gt(te_labels, 99).sum() 165 | 166 | predicted_list.append(predicted.long().cpu()) 167 | label_list.append(all_labels.long().cpu()) 168 | 169 | 170 | seen_acc = round(torch.cat(correct).numpy().sum() / (len(torch.cat(correct).numpy())-num_open.numpy()),4) 171 | unseen_acc = round(torch.cat(unseen_correct).numpy().sum() / num_open.numpy(),4) 172 | h_score = round((2*seen_acc*unseen_acc) / (seen_acc + unseen_acc),4) 173 | print('Batch:(', te_idx,'/',len(teloader),\ 174 | '\t Cumulative Results: ACC_S:', seen_acc,\ 175 | '\tACC_N:', unseen_acc,\ 176 | '\tACC_H:',h_score\ 177 | ) 178 | 179 | 180 | print('\nTest time training result:',' ACC_S:', seen_acc,\ 181 | '\tACC_N:', unseen_acc,\ 182 | '\tACC_H:',h_score,'\n\n\n\n'\ 183 | ) 184 | 185 | 186 | if args.outf != None: 187 | my_makedir(args.outf) 188 
| with open (args.outf+'/results.txt','a') as f: 189 | f.write(str(args)+'\n') 190 | f.write( 191 | 'ACC_S:'+ str(seen_acc)+\ 192 | '\tACC_N:'+ str(unseen_acc)+\ 193 | '\tACC_H:'+str(h_score)+'\n\n\n\n'\ 194 | ) -------------------------------------------------------------------------------- /imagenet/TEST.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.optim as optim 4 | import torch.utils.data as data 5 | 6 | import torch.nn as nn 7 | from utils.test_helpers import * 8 | from utils.prepare_dataset import * 9 | 10 | # ---------------------------------- 11 | import copy 12 | import random 13 | import numpy as np 14 | 15 | from utils.test_helpers import build_model, test 16 | from utils.prepare_dataset import prepare_transforms, create_dataloader, ImageNetCorruption, ImageNet_, prepare_ood_test_data,prepare_ood_test_data_r 17 | from utils.offline import offline, offline_r 18 | import torch.nn.functional as F 19 | # ---------------------------------- 20 | 21 | 22 | def compute_os_variance(os, th): 23 | """ 24 | Calculate the area of a rectangle. 25 | 26 | Parameters: 27 | os : OOD score queue. 28 | th : Given threshold to separate weak and strong OOD samples. 29 | 30 | Returns: 31 | float: Weighted variance at the given threshold th. 32 | """ 33 | 34 | thresholded_os = np.zeros(os.shape) 35 | thresholded_os[os >= th] = 1 36 | 37 | # compute weights 38 | nb_pixels = os.size 39 | nb_pixels1 = np.count_nonzero(thresholded_os) 40 | weight1 = nb_pixels1 / nb_pixels 41 | weight0 = 1 - weight1 42 | 43 | # if one the classes is empty, eg all pixels are below or above the threshold, that threshold will not be considered 44 | # in the search for the best threshold 45 | if weight1 == 0 or weight0 == 0: 46 | return np.inf 47 | 48 | # find all pixels belonging to each class 49 | val_pixels1 = os[thresholded_os == 1] 50 | val_pixels0 = os[thresholded_os == 0] 51 | 52 | # compute variance of these classes 53 | var0 = np.var(val_pixels0) if len(val_pixels0) > 0 else 0 54 | var1 = np.var(val_pixels1) if len(val_pixels1) > 0 else 0 55 | 56 | return weight0 * var0 + weight1 * var1 57 | 58 | 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument('--dataset', default='ImageNet-C') 61 | parser.add_argument('--strong_OOD', default='noise') 62 | parser.add_argument('--strong_ratio', default=1, type=float) 63 | parser.add_argument('--dataroot', default='./data') 64 | parser.add_argument('--batch_size', default=128, type=int) 65 | parser.add_argument('--workers', default=8, type=int) 66 | parser.add_argument('--ce_scale', default=0, type=float, help='cross entropy loss scale') 67 | parser.add_argument('--outf', help='folder to output log') 68 | parser.add_argument('--level', default=5, type=int) 69 | parser.add_argument('--N_m', default=512, type=int, help='queue length') 70 | parser.add_argument('--corruption', default='snow') 71 | parser.add_argument('--offline', default='./results/offline/', help='directory of pretrained model') 72 | parser.add_argument('--model', default='resnet50', help='resnet50') 73 | parser.add_argument('--seed', default=0, type=int) 74 | 75 | 76 | # ----------- Args and Dataloader ------------ 77 | args = parser.parse_args() 78 | 79 | print(args) 80 | print('\n') 81 | 82 | 83 | 84 | net, ext, classifier = build_model() 85 | 86 | 87 | train_transform, val_transform, val_corrupt_transform = prepare_transforms() 88 | 89 | source_dataset = ImageNet_(args.dataroot, 'val', transform=val_transform, 
is_carry_index=True) 90 | 91 | if args.dataset == 'ImageNet-C': 92 | target_dataset_test = prepare_ood_test_data(args.dataroot, args.corruption, transform=val_corrupt_transform, is_carry_index=True, OOD=args.strong_OOD,OOD_transform=val_transform) 93 | class_num = 1000 94 | 95 | elif args.dataset == 'ImageNet-R': 96 | indices_in_1k = [wnid in imagenet_r_wnids for wnid in all_wnids] 97 | target_dataset_test = prepare_ood_test_data_r(args.dataroot, args.corruption, transform=val_corrupt_transform, is_carry_index=True, OOD=args.strong_OOD,OOD_transform=val_transform) 98 | class_num = 200 99 | else: 100 | raise NotImplementedError 101 | 102 | source_dataloader = create_dataloader(source_dataset, args, True, False) 103 | target_dataloader_test = create_dataloader(target_dataset_test, args, True, False) 104 | 105 | # ----------- Offline Feature Summarization ------------ 106 | if args.dataset == 'ImageNet-C': 107 | ext_mean, ext_cov, ext_mean_categories, ext_cov_categories = offline(args, source_dataloader, ext, classifier) 108 | weak_prototype = F.normalize(ext_mean_categories.clone()).cuda() 109 | else: 110 | ext_mean, ext_cov, ext_mean_categories, ext_cov_categories = offline_r(args, source_dataloader, ext, classifier) 111 | weak_prototype = F.normalize(ext_mean_categories[indices_in_1k].clone()).cuda() 112 | 113 | torch.manual_seed(args.seed) 114 | random.seed(args.seed) 115 | np.random.seed(args.seed) 116 | torch.cuda.manual_seed(args.seed) 117 | torch.cuda.manual_seed_all(args.seed) 118 | 119 | # ----------- Open-World Test-time Training ------------ 120 | 121 | correct = [] 122 | unseen_correct= [] 123 | all_correct=[] 124 | cumulative_error = [] 125 | num_open = 0 126 | predicted_list=[] 127 | label_list=[] 128 | 129 | os_inference_queue = [] 130 | queue_length = args.N_m 131 | 132 | ema_total_n = 0. 
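# The inference loop below makes no parameter updates: each batch is classified by
# cosine similarity between its normalized features and the source-class prototypes
# in `weak_prototype`, and strong OOD samples are rejected with an adaptive threshold.
# The OOD score 1 - max cosine similarity is pushed into a sliding queue of the last
# N_m scores, and the threshold is chosen by minimizing compute_os_variance (an
# Otsu-style weighted intra-class variance), i.e. roughly:
#   scores = np.array(os_inference_queue)
#   ths = np.arange(0, 1, 0.01)
#   best_threshold = ths[np.argmin([compute_os_variance(scores, t) for t in ths])]
# Samples whose score exceeds the threshold are assigned the extra label `class_num`,
# and cumulative ACC_S (seen), ACC_N (unseen) and their harmonic mean ACC_H are printed.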
133 | 134 | print('\n-----Test-Time Training with TEST-----') 135 | for te_idx, (te_inputs, te_labels) in enumerate(target_dataloader_test): 136 | 137 | if isinstance(te_inputs,list): 138 | inputs = te_inputs[0].cuda() 139 | else: 140 | inputs = te_inputs.cuda() 141 | 142 | ####-------------------------- Test ----------------------------#### 143 | 144 | with torch.no_grad(): 145 | 146 | net.eval() 147 | feat_ext = ext(inputs) #b,2048 148 | logit = torch.mm(F.normalize(feat_ext), weak_prototype.t()) 149 | 150 | 151 | softmax_logit = logit.softmax(dim=-1) 152 | pro, predicted = softmax_logit.max(dim=-1) 153 | 154 | ood_score, max_index = logit.max(1) 155 | ood_score = 1-ood_score 156 | os_inference_queue.extend(ood_score.detach().cpu().tolist()) 157 | os_inference_queue = os_inference_queue[-queue_length:] 158 | 159 | threshold_range = np.arange(0,1,0.01) 160 | criterias = [compute_os_variance(np.array(os_inference_queue), th) for th in threshold_range] 161 | best_threshold = threshold_range[np.argmin(criterias)] 162 | unseen_mask = (ood_score > best_threshold) 163 | args.ts = best_threshold 164 | predicted[unseen_mask] = class_num 165 | 166 | one = torch.ones_like(te_labels)*class_num 167 | false = torch.ones_like(te_labels)*-1 168 | predicted = torch.where(predicted>class_num-1, one.cuda(), predicted) 169 | all_labels = torch.where(te_labels>class_num-1, one, te_labels) 170 | seen_labels = torch.where(te_labels>class_num-1, false, te_labels) 171 | unseen_labels = torch.where(te_labels>class_num-1, one, false) 172 | correct.append(predicted.cpu().eq(seen_labels)) 173 | unseen_correct.append(predicted.cpu().eq(unseen_labels)) 174 | all_correct.append(predicted.cpu().eq(all_labels)) 175 | num_open += torch.gt(te_labels, class_num-1).sum() 176 | 177 | predicted_list.append(predicted.long().cpu()) 178 | label_list.append(all_labels.long().cpu()) 179 | 180 | 181 | seen_acc = round(torch.cat(correct).numpy().sum() / (len(torch.cat(correct).numpy())-num_open.numpy()),4) 182 | unseen_acc = round(torch.cat(unseen_correct).numpy().sum() / num_open.numpy(),4) 183 | h_score = round((2*seen_acc*unseen_acc) / (seen_acc + unseen_acc),4) 184 | print('Batch:(', te_idx,'/',len(target_dataloader_test),\ 185 | '\t Cumulative Results: ACC_S:', seen_acc,\ 186 | '\tACC_N:', unseen_acc,\ 187 | '\tACC_H:',h_score\ 188 | ) 189 | 190 | 191 | print('\nTest time training result:',' ACC_S:', seen_acc,\ 192 | '\tACC_N:', unseen_acc,\ 193 | '\tACC_H:',h_score,'\n\n\n\n'\ 194 | ) 195 | 196 | 197 | if args.outf != None: 198 | my_makedir(args.outf) 199 | with open (args.outf+'/results.txt','a') as f: 200 | f.write(str(args)+'\n') 201 | f.write( 202 | 'ACC_S:'+ str(seen_acc)+\ 203 | '\tACC_N:'+ str(unseen_acc)+\ 204 | '\tACC_H:'+str(h_score)+'\n\n\n\n'\ 205 | ) -------------------------------------------------------------------------------- /cifar/models/BigResNet.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | ImageNet-Style ResNet 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 5 | ResNet adapted from: https://github.com/bearpaw/pytorch-classification 6 | SupConResNet adpated from https://github.com/HobbitLong/SupContrast 7 | """ 8 | from functools import partial 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1, is_last=False): 17 | super(BasicBlock, self).__init__() 18 | self.is_last = is_last 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion * planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion * planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | preact = out 36 | out = F.relu(out) 37 | if self.is_last: 38 | return out, preact 39 | else: 40 | return out 41 | 42 | 43 | class Bottleneck(nn.Module): 44 | expansion = 4 45 | 46 | def __init__(self, in_planes, planes, stride=1, is_last=False): 47 | super(Bottleneck, self).__init__() 48 | self.is_last = is_last 49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion * planes: 58 | self.shortcut = nn.Sequential( 59 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 60 | nn.BatchNorm2d(self.expansion * planes) 61 | ) 62 | 63 | def forward(self, x): 64 | out = F.relu(self.bn1(self.conv1(x))) 65 | out = F.relu(self.bn2(self.conv2(out))) 66 | out = self.bn3(self.conv3(out)) 67 | out += self.shortcut(x) 68 | preact = out 69 | out = F.relu(out) 70 | if self.is_last: 71 | return out, preact 72 | else: 73 | return out 74 | 75 | 76 | class ResNet(nn.Module): 77 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False): 78 | super(ResNet, self).__init__() 79 | self.in_planes = 64 80 | 81 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, 82 | bias=False) 83 | self.bn1 = nn.BatchNorm2d(64) 84 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 85 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 86 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 87 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, is_last=True) 88 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 89 | 90 | for m in self.modules(): 91 | if isinstance(m, nn.Conv2d): 92 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 93 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 94 | nn.init.constant_(m.weight, 1) 95 | nn.init.constant_(m.bias, 0) 96 | 97 | # Zero-initialize the last BN in each residual branch, 98 | # so that the residual branch starts with zeros, and 
each residual block behaves 99 | # like an identity. This improves the model by 0.2~0.3% according to: 100 | # https://arxiv.org/abs/1706.02677 101 | if zero_init_residual: 102 | for m in self.modules(): 103 | if isinstance(m, Bottleneck): 104 | nn.init.constant_(m.bn3.weight, 0) 105 | elif isinstance(m, BasicBlock): 106 | nn.init.constant_(m.bn2.weight, 0) 107 | 108 | def _make_layer(self, block, planes, num_blocks, stride, is_last=False): 109 | strides = [stride] + [1] * (num_blocks - 1) 110 | layers = [] 111 | for i in range(num_blocks): 112 | stride = strides[i] 113 | layers.append(block(self.in_planes, planes, stride)) 114 | self.in_planes = planes * block.expansion 115 | return nn.Sequential(*layers) 116 | 117 | def forward(self, x, layer=100): 118 | out = F.relu(self.bn1(self.conv1(x))) 119 | out = self.layer1(out) 120 | out = self.layer2(out) 121 | out = self.layer3(out) 122 | out = self.layer4(out) 123 | out = self.avgpool(out) 124 | out = torch.flatten(out, 1) 125 | return out 126 | 127 | 128 | def resnet18(**kwargs): 129 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 130 | 131 | 132 | def resnet34(**kwargs): 133 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 134 | 135 | 136 | def resnet50(**kwargs): 137 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 138 | 139 | 140 | def resnet101(**kwargs): 141 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 142 | 143 | 144 | model_dict = { 145 | 'resnet18': [resnet18, 512], 146 | 'resnet34': [resnet34, 512], 147 | 'resnet50': [resnet50, 2048], 148 | 'resnet101': [resnet101, 2048], 149 | } 150 | 151 | 152 | class LinearBatchNorm(nn.Module): 153 | """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose""" 154 | def __init__(self, dim, affine=True): 155 | super(LinearBatchNorm, self).__init__() 156 | self.dim = dim 157 | self.bn = nn.BatchNorm2d(dim, affine=affine) 158 | 159 | def forward(self, x): 160 | x = x.view(-1, self.dim, 1, 1) 161 | x = self.bn(x) 162 | x = x.view(-1, self.dim) 163 | return x 164 | 165 | 166 | class SupConResNet(nn.Module): 167 | """backbone + projection head""" 168 | def __init__(self, name='resnet50', head='mlp', feat_dim=128): 169 | super(SupConResNet, self).__init__() 170 | model_fun, dim_in = model_dict[name] 171 | self.encoder = model_fun() 172 | if head == 'linear': 173 | self.head = nn.Linear(dim_in, feat_dim) 174 | elif head == 'mlp': 175 | self.head = nn.Sequential( 176 | nn.Linear(dim_in, dim_in), 177 | nn.ReLU(inplace=True), 178 | nn.Linear(dim_in, feat_dim) 179 | ) 180 | else: 181 | raise NotImplementedError( 182 | 'head not supported: {}'.format(head)) 183 | 184 | def forward(self, x): 185 | feat = self.encoder(x) 186 | feat = F.normalize(self.head(feat), dim=1) 187 | return feat 188 | 189 | 190 | class LinearClassifier(nn.Module): 191 | """Linear classifier""" 192 | def __init__(self, name='resnet50', num_classes=10,num_dim=None): 193 | super(LinearClassifier, self).__init__() 194 | if num_dim==None: 195 | _, feat_dim = model_dict[name] 196 | else: 197 | feat_dim=num_dim 198 | self.fc = nn.Linear(feat_dim, num_classes) 199 | 200 | def forward(self, features,norm=False): 201 | if not norm: 202 | return self.fc(features) 203 | self.weight_norm() 204 | 205 | return self.fc(F.normalize(features))- self.fc.bias 206 | 207 | def weight_norm(self): 208 | # print(self.fc.bias.data)-+ 209 | w = self.fc.weight.data 210 | norm = w.norm(p=2, dim=1, keepdim=True) 211 | self.fc.weight.data = w.div(norm.expand_as(w)) 212 | 
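A minimal usage sketch of how these classes are assembled elsewhere in this repository (mirroring build_resnet50 in cifar/utils/test_helpers.py, with the .cuda() calls dropped and a toy CIFAR-sized batch used purely for illustration):

import torch
from models.BigResNet import SupConResNet, LinearClassifier
from models.SSHead import ExtractorHead

ssh = SupConResNet(name='resnet50', head='mlp', feat_dim=128)   # backbone encoder + projection head
classifier = LinearClassifier(name='resnet50', num_classes=10)  # linear head on the 2048-d features
net = ExtractorHead(ssh.encoder, classifier)                    # encoder -> classification logits

x = torch.randn(2, 3, 32, 32)       # toy CIFAR-sized batch
embeddings = ssh(x)                 # (2, 128) L2-normalized contrastive features
features = ssh.encoder(x)           # (2, 2048) backbone features
logits = net(x)                     # (2, 10) class logits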
-------------------------------------------------------------------------------- /imagenet/utils/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import torchvision.transforms as transforms 5 | from torchvision.datasets import ImageNet 6 | import os 7 | import torchvision 8 | 9 | def prepare_transforms(): 10 | train_transform = transforms.Compose([ 11 | transforms.RandomResizedCrop(224), 12 | transforms.RandomHorizontalFlip(), 13 | transforms.ToTensor(), 14 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 15 | ]) 16 | val_transform = transforms.Compose([ 17 | transforms.Resize(256), 18 | transforms.CenterCrop(224), 19 | transforms.ToTensor(), 20 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 21 | ]) 22 | val_corrupt_transform = transforms.Compose([ 23 | transforms.ToTensor(), 24 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 25 | ]) 26 | return train_transform, val_transform, val_corrupt_transform 27 | 28 | def seed_worker(worker_id): 29 | worker_seed = torch.initial_seed() % 2**32 30 | np.random.seed(worker_seed) 31 | random.seed(worker_seed) 32 | 33 | def create_dataloader(dataset, args, shuffle=False, drop_last=False): 34 | return torch.utils.data.DataLoader(dataset, 35 | batch_size=args.batch_size, 36 | shuffle=shuffle, 37 | num_workers=args.workers, 38 | worker_init_fn=seed_worker, 39 | pin_memory=True, 40 | drop_last=drop_last) 41 | 42 | 43 | class ImageNetCorruption(ImageNet): 44 | def __init__(self, root, corruption_name="gaussian_noise", transform=None, is_carry_index=False): 45 | super().__init__(root, 'val', transform=transform) 46 | self.root = root 47 | self.corruption_name = corruption_name 48 | self.transform = transform 49 | self.is_carry_index = is_carry_index 50 | self.load_data() 51 | 52 | def load_data(self): 53 | self.data = torch.load(os.path.join(self.root, 'corruption', self.corruption_name + '.pth')).numpy() 54 | self.target = [i[1] for i in self.imgs] 55 | return 56 | 57 | def __getitem__(self, index): 58 | img = self.data[index, :, :, :] 59 | target = self.target[index] 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.is_carry_index: 63 | img = [img, index] 64 | return img, target 65 | 66 | def __len__(self): 67 | return self.data.shape[0] 68 | 69 | class ImageNet_(ImageNet): 70 | def __init__(self, *args, is_carry_index=False, **kwargs): 71 | super().__init__(*args, **kwargs) 72 | self.is_carry_index = is_carry_index 73 | 74 | def __getitem__(self, index: int): 75 | img, target = super().__getitem__(index) 76 | if self.is_carry_index: 77 | if type(img) == list: 78 | img.append(index) 79 | else: 80 | img = [img, index] 81 | return img, target 82 | 83 | 84 | class noise_dataset(torch.utils.data.Dataset):#需要继承data.Dataset 85 | def __init__(self, transform,ratio=1): 86 | #定义好 image 的路径 87 | self.number = int(50000*ratio) 88 | self.transform = transform 89 | 90 | def __getitem__(self, index:int): 91 | image = torch.randn(3,224,224) 92 | target = 1000 93 | # if self.transform is not None: 94 | # image = self.transform(image) 95 | if type(image) == list: 96 | image.append(index) 97 | else: 98 | image = [image, index] 99 | 100 | return image, target 101 | 102 | def __len__(self): 103 | 104 | return self.number 105 | 106 | class imageneta(torchvision.datasets.ImageFolder):#需要继承data.Dataset 107 | def __init__(self, *args, **kwargs): 108 | super().__init__(*args, 
**kwargs) 109 | # self.is_carry_index = is_carry_index 110 | 111 | def __getitem__(self, index: int): 112 | img, target = super().__getitem__(index) 113 | 114 | if type(img) == list: 115 | img.append(index) 116 | else: 117 | img = [img, index] 118 | return img, target 119 | 120 | 121 | class MNIST_openset(torchvision.datasets.MNIST): 122 | def __init__(self, *args, ratio = 1 , **kwargs): 123 | super().__init__(*args, **kwargs) 124 | self.data, self.targets = self.data[:int(50000*ratio)], self.targets[:int(50000*ratio)] 125 | print(ratio) 126 | print(len(self.data)) 127 | return 128 | 129 | def __getitem__(self, index: int): 130 | image, target = super().__getitem__(index) 131 | target = target + 1000 132 | if type(image) == list: 133 | image.append(index) 134 | else: 135 | image = [image, index] 136 | return image, target 137 | 138 | 139 | class SVHN_openset(torchvision.datasets.SVHN): 140 | def __init__(self, *args, ratio = 1 , **kwargs): 141 | super().__init__(*args, **kwargs) 142 | self.data, self.labels = self.data[:int(50000*ratio)], self.labels[:int(50000*ratio)] 143 | print(ratio) 144 | print(len(self.data)) 145 | return 146 | 147 | def __getitem__(self, index: int): 148 | image, target = super().__getitem__(index) 149 | target = target + 1000 150 | if type(image) == list: 151 | image.append(index) 152 | else: 153 | image = [image, index] 154 | return image, target 155 | 156 | 157 | def prepare_ood_test_data_a(root, corruption_name="gaussian_noise", transform=None, is_carry_index=False, OOD = 'noise', OOD_transform=None): 158 | teset_seen = imageneta(root='/cluster/personal/dataset/imagenet-a', transform=OOD_transform) 159 | print(len(teset_seen)) 160 | if OOD =='noise': 161 | teset_unseen = noise_dataset(transform,ratio=0.15) 162 | elif OOD=='SVHN': 163 | teset_unseen = SVHN_openset(root="/cluster/personal/dataset/CIFAR-C", 164 | split='train', download=True, transform=OOD_transform, ratio=0.15) 165 | elif OOD=='MNIST': 166 | te_rize = transforms.Compose([transforms.Grayscale(3), OOD_transform ]) 167 | teset_unseen = MNIST_openset(root="/cluster/personal/dataset/CIFAR-C", 168 | train=True, download=True, transform=te_rize, ratio=0.15) 169 | teset = torch.utils.data.ConcatDataset([teset_seen,teset_unseen]) 170 | return teset 171 | 172 | def prepare_ood_test_data_r(root, corruption_name="gaussian_noise", transform=None, is_carry_index=False, OOD = 'noise', OOD_transform=None): 173 | teset_seen = imageneta(root='/cluster/personal/dataset/imagenet-r', transform=OOD_transform) 174 | print(len(teset_seen)) 175 | if OOD =='noise': 176 | teset_unseen = noise_dataset(transform,ratio=0.6) 177 | elif OOD=='SVHN': 178 | teset_unseen = SVHN_openset(root="/cluster/personal/dataset/CIFAR-C", 179 | split='train', download=True, transform=OOD_transform, ratio=0.6) 180 | elif OOD=='MNIST': 181 | te_rize = transforms.Compose([transforms.Grayscale(3), OOD_transform ]) 182 | teset_unseen = MNIST_openset(root="/cluster/personal/dataset/CIFAR-C", 183 | train=True, download=True, transform=te_rize, ratio=0.6) 184 | teset = torch.utils.data.ConcatDataset([teset_seen,teset_unseen]) 185 | return teset 186 | 187 | def prepare_test_data_r(root, corruption_name="gaussian_noise", transform=None, is_carry_index=False, OOD = 'noise', OOD_transform=None): 188 | teset_seen = imageneta(root=root+'/imagenet-r', transform=OOD_transform) 189 | return teset_seen 190 | 191 | def prepare_ood_test_data(root, corruption_name="gaussian_noise", transform=None, is_carry_index=False, OOD = 'noise', OOD_transform=None): 192 | 
teset_seen = ImageNetCorruption(root, corruption_name, transform=transform, is_carry_index=is_carry_index) 193 | if OOD =='noise': 194 | teset_unseen = noise_dataset(transform) 195 | elif OOD=='SVHN': 196 | teset_unseen = SVHN_openset(root="/cluster/personal/dataset/CIFAR-C", 197 | split='train', download=True, transform=OOD_transform, ratio=1) 198 | elif OOD=='MNIST': 199 | te_rize = transforms.Compose([transforms.Grayscale(3), OOD_transform ]) 200 | teset_unseen = MNIST_openset(root="/cluster/personal/dataset/CIFAR-C", 201 | train=True, download=True, transform=te_rize, ratio=1) 202 | teset = torch.utils.data.ConcatDataset([teset_seen,teset_unseen]) 203 | return teset 204 | -------------------------------------------------------------------------------- /cifar/utils/test_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from utils.misc import * 5 | 6 | 7 | def load_resnet50(net, head, ssh, classifier, args): 8 | 9 | filename = args.resume + '/ckpt.pth' 10 | 11 | ckpt = torch.load(filename) 12 | state_dict = ckpt['model'] 13 | 14 | net_dict = {} 15 | head_dict = {} 16 | for k, v in state_dict.items(): 17 | if k[:4] == "head": 18 | k = k.replace("head.", "") 19 | head_dict[k] = v 20 | else: 21 | k = k.replace("encoder.", "ext.") 22 | k = k.replace("fc.", "head.fc.") 23 | net_dict[k] = v 24 | 25 | net.load_state_dict(net_dict) 26 | head.load_state_dict(head_dict) 27 | 28 | print('Loaded model trained jointly on Classification and SimCLR:', filename) 29 | 30 | def load_resnet50_sne(net, head, ssh, classifier, args): 31 | 32 | filename = args.resume 33 | ckpt = torch.load(filename) 34 | net.load_state_dict(ckpt) 35 | 36 | print('Loaded model trained jointly on Classification and SimCLR:', filename) 37 | 38 | def load_robust_resnet50(net, head, ssh, classifier, args): 39 | 40 | filename = args.resume 41 | try: 42 | sd = torch.load(filename,map_location='cuda:0')['state_dict'] 43 | except: 44 | sd = torch.load(filename) 45 | model_dict=net.state_dict() 46 | 47 | ckpt={} 48 | for k, v in sd.items(): 49 | print(k) 50 | k = k.replace("backbone.",'') 51 | if k[:12] == "module.model": 52 | if not 'linear' in k: 53 | k = k.replace("module.model.", "ext.") 54 | ckpt[k] = v 55 | else: 56 | k = k.replace("module.model.linear",'head.fc') 57 | ckpt[k] = v 58 | else: 59 | if not ('linear' in k or 'logits' in k): 60 | 61 | if not 'fc'in k: 62 | k = 'ext.'+k 63 | ckpt[k] = v 64 | else: 65 | k = 'head.'+k 66 | ckpt[k]=v 67 | else: 68 | k = k.replace("linear",'head.fc') 69 | k = k.replace("logits",'head.fc') 70 | ckpt[k] = v 71 | ckpt = {k: v for k, v in ckpt.items() if k in model_dict} 72 | 73 | net.load_state_dict(ckpt) 74 | 75 | print('Loaded robust model:', filename) 76 | 77 | 78 | def load_ttt(net, head, ssh, classifier, args, ttt=False): 79 | if ttt: 80 | filename = args.resume + '/{}_both_2_15.pth'.format(args.corruption) 81 | else: 82 | filename = args.resume + '/{}_both_15.pth'.format(args.corruption) 83 | ckpt = torch.load(filename) 84 | net.load_state_dict(ckpt['net']) 85 | head.load_state_dict(ckpt['head']) 86 | print('Loaded updated model from', filename) 87 | 88 | 89 | def corrupt_resnet50(ext, args): 90 | try: 91 | # SSL trained encoder 92 | simclr = torch.load(args.restore + '/simclr.pth') 93 | state_dict = simclr['model'] 94 | 95 | ext_dict = {} 96 | for k, v in state_dict.items(): 97 | if k[:7] == "encoder": 98 | k = k.replace("encoder.", "") 99 | ext_dict[k] = v 100 | 
ext.load_state_dict(ext_dict) 101 | 102 | print('Corrupted encoder trained by SimCLR') 103 | 104 | except: 105 | # Jointly trained encoder 106 | filename = args.resume + '/ckpt_epoch_{}.pth'.format(args.restore) 107 | 108 | ckpt = torch.load(filename) 109 | state_dict = ckpt['model'] 110 | 111 | ext_dict = {} 112 | for k, v in state_dict.items(): 113 | if k[:7] == "encoder": 114 | k = k.replace("encoder.", "") 115 | ext_dict[k] = v 116 | ext.load_state_dict(ext_dict) 117 | print('Corrupted encoder jontly trained on Classification and SimCLR') 118 | 119 | 120 | def build_resnet50(args): 121 | from models.BigResNet import SupConResNet, LinearClassifier 122 | from models.SSHead import ExtractorHead 123 | 124 | print('Building ResNet50...') 125 | if args.dataset == 'cifar10+100' or args.dataset == 'cifar10OOD': 126 | classes = 10 127 | if args.dataset == 'cifar10': 128 | classes = 10 129 | elif args.dataset == 'cifar7': 130 | if not hasattr(args, 'modified') or args.modified: 131 | classes = 7 132 | else: 133 | classes = 10 134 | elif args.dataset == "cifar100" or args.dataset == "cifar100OOD": 135 | classes = 100 136 | 137 | classifier = LinearClassifier(num_classes=classes).cuda() 138 | ssh = SupConResNet().cuda() 139 | head = ssh.head 140 | ext = ssh.encoder 141 | net = ExtractorHead(ext, classifier).cuda() 142 | return net, ext, head, ssh, classifier 143 | 144 | def build_net(args): 145 | from models.BigResNet import SupConResNet, LinearClassifier 146 | from models.SSHead import ExtractorHead 147 | from models.dm import CIFAR10_MEAN, CIFAR10_STD, \ 148 | DMWideResNet, Swish, DMPreActResNet 149 | 150 | print('Building '+args.net) 151 | if args.dataset == 'cifar10+100' or args.dataset == 'cifar10OOD': 152 | classes = 10 153 | if args.dataset == 'cifar10': 154 | classes = 10 155 | elif args.dataset == 'cifar7': 156 | if not hasattr(args, 'modified') or args.modified: 157 | classes = 7 158 | else: 159 | classes = 10 160 | elif args.dataset == "cifar100": 161 | classes = 100 162 | 163 | if args.net == "dm": 164 | ext=DMPreActResNet(num_classes=10, 165 | depth=18, 166 | width=0, 167 | activation_fn=Swish, 168 | mean=CIFAR10_MEAN, 169 | std=CIFAR10_STD) 170 | 171 | classifier = LinearClassifier(num_classes=classes,num_dim=ext.num_out).cuda() 172 | ssh = SupConResNet().cuda() 173 | elif args.net == "standard": 174 | from models.wide import WideResNet 175 | ext=WideResNet(depth=28, widen_factor=10) 176 | 177 | classifier = LinearClassifier(num_classes=classes,num_dim=ext.num_out).cuda() 178 | ssh = SupConResNet().cuda() 179 | elif args.net == "resnet18": 180 | classifier = LinearClassifier(num_classes=classes,num_dim=512).cuda() 181 | ssh = SupConResNet(name='resnet18').cuda() 182 | import torchvision.models as models 183 | ext = models.resnet18().cuda() 184 | ext.fc=nn.Sequential().cuda() 185 | else: 186 | classifier = LinearClassifier(num_classes=classes).cuda() 187 | ssh = SupConResNet().cuda() 188 | ext = ssh.encoder 189 | 190 | head = ssh.head 191 | # 192 | net = ExtractorHead(ext, classifier).cuda() 193 | return net, ext, head, ssh, classifier 194 | 195 | 196 | def build_model(args): 197 | from models.ResNet import ResNetCifar as ResNet 198 | from models.SSHead import ExtractorHead 199 | print('Building model...') 200 | if args.dataset == 'cifar10': 201 | classes = 10 202 | elif args.dataset == 'cifar7': 203 | if not hasattr(args, 'modified') or args.modified: 204 | classes = 7 205 | else: 206 | classes = 10 207 | elif args.dataset == "cifar100": 208 | classes = 100 209 | 210 | if 
args.group_norm == 0: 211 | norm_layer = nn.BatchNorm2d 212 | else: 213 | def gn_helper(planes): 214 | return nn.GroupNorm(args.group_norm, planes) 215 | norm_layer = gn_helper 216 | 217 | if hasattr(args, 'detach') and args.detach: 218 | detach = args.shared 219 | else: 220 | detach = None 221 | net = ResNet(args.depth, args.width, channels=3, classes=classes, norm_layer=norm_layer, detach=detach).cuda() 222 | if args.shared == 'none': 223 | args.shared = None 224 | 225 | if args.shared == 'layer3' or args.shared is None: 226 | from models.SSHead import extractor_from_layer3 227 | ext = extractor_from_layer3(net) 228 | if not hasattr(args, 'ssl') or args.ssl == 'rotation': 229 | head = nn.Linear(64 * args.width, 4) 230 | elif args.ssl == 'contrastive': 231 | head = nn.Sequential( 232 | nn.Linear(64 * args.width, 64 * args.width), 233 | nn.ReLU(inplace=True), 234 | nn.Linear(64 * args.width, 16 * args.width) 235 | ) 236 | else: 237 | raise NotImplementedError 238 | elif args.shared == 'layer2': 239 | from models.SSHead import extractor_from_layer2, head_on_layer2 240 | ext = extractor_from_layer2(net) 241 | head = head_on_layer2(net, args.width, 4) 242 | ssh = ExtractorHead(ext, head).cuda() 243 | 244 | if hasattr(args, 'parallel') and args.parallel: 245 | net = torch.nn.DataParallel(net) 246 | ssh = torch.nn.DataParallel(ssh) 247 | return net, ext, head, ssh 248 | 249 | 250 | def test(dataloader, model, **kwargs): 251 | criterion = nn.CrossEntropyLoss(reduction='none').cuda() 252 | model.eval() 253 | correct = [] 254 | losses = [] 255 | for batch_idx, (inputs, labels) in enumerate(dataloader): 256 | if type(inputs) == list: 257 | inputs = inputs[0] 258 | inputs, labels = inputs.cuda(), labels.cuda() 259 | with torch.no_grad(): 260 | outputs = model(inputs, **kwargs) 261 | _, predicted = outputs.max(1) 262 | correct.append(predicted.eq(labels).cpu()) 263 | correct = torch.cat(correct).numpy() 264 | model.train() 265 | return 1-correct.mean(), correct, losses 266 | 267 | def prototype_test(dataloader, ext,prototype, **kwargs): 268 | # criterion = nn.CrossEntropyLoss(reduction='none').cuda() 269 | ext.eval() 270 | correct = [] 271 | losses = [] 272 | for batch_idx, (inputs, labels) in enumerate(dataloader): 273 | if type(inputs) == list: 274 | inputs = inputs[0] 275 | inputs, labels = inputs.cuda(), labels.cuda() 276 | with torch.no_grad(): 277 | feat = ext(inputs, **kwargs) 278 | outputs = torch.mm(torch.nn.functional.normalize(feat), prototype.t()) 279 | _, predicted = outputs.max(1) 280 | correct.append(predicted.eq(labels).cpu()) 281 | correct = torch.cat(correct).numpy() 282 | ext.train() 283 | return 1-correct.mean(), correct, losses 284 | 285 | 286 | def pair_buckets(o1, o2): 287 | crr = np.logical_and( o1, o2 ) 288 | crw = np.logical_and( o1, np.logical_not(o2) ) 289 | cwr = np.logical_and( np.logical_not(o1), o2 ) 290 | cww = np.logical_and( np.logical_not(o1), np.logical_not(o2) ) 291 | return crr, crw, cwr, cww 292 | 293 | 294 | def count_each(tuple): 295 | return [item.sum() for item in tuple] 296 | 297 | 298 | def plot_epochs(all_err_cls, all_err_ssh, fname, use_agg=True): 299 | import matplotlib.pyplot as plt 300 | if use_agg: 301 | plt.switch_backend('agg') 302 | 303 | plt.plot(np.asarray(all_err_cls)*100, color='r', label='classifier') 304 | plt.plot(np.asarray(all_err_ssh)*100, color='b', label='self-supervised') 305 | plt.xlabel('epoch') 306 | plt.ylabel('test error (%)') 307 | plt.legend() 308 | plt.savefig(fname) 309 | plt.close() 310 | 311 | 312 | @torch.jit.script 313 | 
def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 314 | """Entropy of softmax distribution from logits.""" 315 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 316 | 317 | -------------------------------------------------------------------------------- /cifar/models/dm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Deepmind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """WideResNet implementation in PyTorch. From: 16 | https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/pytorch/model_zoo.py 17 | """ 18 | 19 | from typing import Tuple, Type, Union 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | 25 | CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) 26 | CIFAR10_STD = (0.2471, 0.2435, 0.2616) 27 | CIFAR100_MEAN = (0.5071, 0.4865, 0.4409) 28 | CIFAR100_STD = (0.2673, 0.2564, 0.2762) 29 | 30 | 31 | class _Swish(torch.autograd.Function): 32 | """Custom implementation of swish.""" 33 | 34 | @staticmethod 35 | def forward(ctx, i): 36 | result = i * torch.sigmoid(i) 37 | ctx.save_for_backward(i) 38 | return result 39 | 40 | @staticmethod 41 | def backward(ctx, grad_output): 42 | i = ctx.saved_variables[0] 43 | sigmoid_i = torch.sigmoid(i) 44 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) 45 | 46 | 47 | class Swish(nn.Module): 48 | """Module using custom implementation.""" 49 | 50 | def forward(self, input_tensor): 51 | return _Swish.apply(input_tensor) 52 | 53 | 54 | class _Block(nn.Module): 55 | """WideResNet Block.""" 56 | 57 | def __init__(self, 58 | in_planes, 59 | out_planes, 60 | stride, 61 | activation_fn: Type[nn.Module] = nn.ReLU): 62 | super().__init__() 63 | self.batchnorm_0 = nn.BatchNorm2d(in_planes) 64 | self.relu_0 = activation_fn() 65 | # We manually pad to obtain the same effect as `SAME` (necessary when 66 | # `stride` is different than 1). 
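# (the actual padding is applied in forward(): F.pad with (1, 1, 1, 1) when stride is 1
#  and with (0, 1, 0, 1) when stride is 2, before the call to conv_0)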
67 | self.conv_0 = nn.Conv2d(in_planes, 68 | out_planes, 69 | kernel_size=3, 70 | stride=stride, 71 | padding=0, 72 | bias=False) 73 | self.batchnorm_1 = nn.BatchNorm2d(out_planes) 74 | self.relu_1 = activation_fn() 75 | self.conv_1 = nn.Conv2d(out_planes, 76 | out_planes, 77 | kernel_size=3, 78 | stride=1, 79 | padding=1, 80 | bias=False) 81 | self.has_shortcut = in_planes != out_planes 82 | if self.has_shortcut: 83 | self.shortcut = nn.Conv2d(in_planes, 84 | out_planes, 85 | kernel_size=1, 86 | stride=stride, 87 | padding=0, 88 | bias=False) 89 | else: 90 | self.shortcut = None 91 | self._stride = stride 92 | 93 | def forward(self, x): 94 | if self.has_shortcut: 95 | x = self.relu_0(self.batchnorm_0(x)) 96 | else: 97 | out = self.relu_0(self.batchnorm_0(x)) 98 | v = x if self.has_shortcut else out 99 | if self._stride == 1: 100 | v = F.pad(v, (1, 1, 1, 1)) 101 | elif self._stride == 2: 102 | v = F.pad(v, (0, 1, 0, 1)) 103 | else: 104 | raise ValueError('Unsupported `stride`.') 105 | out = self.conv_0(v) 106 | out = self.relu_1(self.batchnorm_1(out)) 107 | out = self.conv_1(out) 108 | out = torch.add(self.shortcut(x) if self.has_shortcut else x, out) 109 | return out 110 | 111 | 112 | class _BlockGroup(nn.Module): 113 | """WideResNet block group.""" 114 | 115 | def __init__(self, 116 | num_blocks, 117 | in_planes, 118 | out_planes, 119 | stride, 120 | activation_fn: Type[nn.Module] = nn.ReLU): 121 | super().__init__() 122 | block = [] 123 | for i in range(num_blocks): 124 | block.append( 125 | _Block(i == 0 and in_planes or out_planes, 126 | out_planes, 127 | i == 0 and stride or 1, 128 | activation_fn=activation_fn)) 129 | self.block = nn.Sequential(*block) 130 | 131 | def forward(self, x): 132 | return self.block(x) 133 | 134 | 135 | class DMWideResNet(nn.Module): 136 | """WideResNet.""" 137 | 138 | def __init__(self, 139 | num_classes: int = 10, 140 | depth: int = 28, 141 | width: int = 10, 142 | activation_fn: Type[nn.Module] = nn.ReLU, 143 | mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN, 144 | std: Union[Tuple[float, ...], float] = CIFAR10_STD, 145 | padding: int = 0, 146 | num_input_channels: int = 3): 147 | super().__init__() 148 | # persistent=False to not put these tensors in the module's state_dict and not try to 149 | # load it from the checkpoint 150 | self.register_buffer('mean', torch.tensor(mean).view(num_input_channels, 1, 1), 151 | persistent=False) 152 | self.register_buffer('std', torch.tensor(std).view(num_input_channels, 1, 1), 153 | persistent=False) 154 | self.padding = padding 155 | num_channels = [16, 16 * width, 32 * width, 64 * width] 156 | self.num_out=num_channels[3] 157 | assert (depth - 4) % 6 == 0 158 | num_blocks = (depth - 4) // 6 159 | self.init_conv = nn.Conv2d(num_input_channels, 160 | num_channels[0], 161 | kernel_size=3, 162 | stride=1, 163 | padding=1, 164 | bias=False) 165 | self.layer = nn.Sequential( 166 | _BlockGroup(num_blocks, 167 | num_channels[0], 168 | num_channels[1], 169 | 1, 170 | activation_fn=activation_fn), 171 | _BlockGroup(num_blocks, 172 | num_channels[1], 173 | num_channels[2], 174 | 2, 175 | activation_fn=activation_fn), 176 | _BlockGroup(num_blocks, 177 | num_channels[2], 178 | num_channels[3], 179 | 2, 180 | activation_fn=activation_fn)) 181 | self.batchnorm = nn.BatchNorm2d(num_channels[3]) 182 | self.relu = activation_fn() 183 | # self.logits = nn.Linear(num_channels[3], num_classes) 184 | self.num_channels = num_channels[3] 185 | 186 | def forward(self, x): 187 | if self.padding > 0: 188 | x = F.pad(x, (self.padding,) 
* 4) 189 | out = (x - self.mean) / self.std 190 | out = self.init_conv(out) 191 | out = self.layer(out) 192 | out = self.relu(self.batchnorm(out)) 193 | out = F.avg_pool2d(out, 8) 194 | out = out.view(-1, self.num_channels) 195 | return out 196 | 197 | 198 | class _PreActBlock(nn.Module): 199 | """Pre-activation ResNet Block.""" 200 | 201 | def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU): 202 | super().__init__() 203 | self._stride = stride 204 | self.batchnorm_0 = nn.BatchNorm2d(in_planes) 205 | self.relu_0 = activation_fn() 206 | # We manually pad to obtain the same effect as `SAME` (necessary when 207 | # `stride` is different than 1). 208 | self.conv_2d_1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, 209 | stride=stride, padding=0, bias=False) 210 | self.batchnorm_1 = nn.BatchNorm2d(out_planes) 211 | self.relu_1 = activation_fn() 212 | self.conv_2d_2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 213 | padding=1, bias=False) 214 | self.has_shortcut = stride != 1 or in_planes != out_planes 215 | if self.has_shortcut: 216 | self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=3, 217 | stride=stride, padding=0, bias=False) 218 | 219 | def _pad(self, x): 220 | if self._stride == 1: 221 | x = F.pad(x, (1, 1, 1, 1)) 222 | elif self._stride == 2: 223 | x = F.pad(x, (0, 1, 0, 1)) 224 | else: 225 | raise ValueError('Unsupported `stride`.') 226 | return x 227 | 228 | def forward(self, x): 229 | out = self.relu_0(self.batchnorm_0(x)) 230 | shortcut = self.shortcut(self._pad(x)) if self.has_shortcut else x 231 | out = self.conv_2d_1(self._pad(out)) 232 | out = self.conv_2d_2(self.relu_1(self.batchnorm_1(out))) 233 | return out + shortcut 234 | 235 | 236 | class DMPreActResNet(nn.Module): 237 | """Pre-activation ResNet.""" 238 | 239 | def __init__(self, 240 | num_classes: int = 10, 241 | depth: int = 18, 242 | width: int = 0, # Used to make the constructor consistent. 
243 | activation_fn: Type[nn.Module] = nn.ReLU, 244 | mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN, 245 | std: Union[Tuple[float, ...], float] = CIFAR10_STD, 246 | padding: int = 0, 247 | num_input_channels: int = 3, 248 | use_cuda: bool = True): 249 | super().__init__() 250 | if width != 0: 251 | raise ValueError('Unsupported `width`.') 252 | # persistent=False to not put these tensors in the module's state_dict and not try to 253 | # load it from the checkpoint 254 | self.register_buffer('mean', torch.tensor(mean).view(num_input_channels, 1, 1), 255 | persistent=False) 256 | self.register_buffer('std', torch.tensor(std).view(num_input_channels, 1, 1), 257 | persistent=False) 258 | self.mean_cuda = None 259 | self.std_cuda = None 260 | self.padding = padding 261 | self.conv_2d = nn.Conv2d(num_input_channels, 64, kernel_size=3, stride=1, 262 | padding=1, bias=False) 263 | if depth == 18: 264 | num_blocks = (2, 2, 2, 2) 265 | elif depth == 34: 266 | num_blocks = (3, 4, 6, 3) 267 | else: 268 | raise ValueError('Unsupported `depth`.') 269 | self.layer_0 = self._make_layer(64, 64, num_blocks[0], 1, activation_fn) 270 | self.layer_1 = self._make_layer(64, 128, num_blocks[1], 2, activation_fn) 271 | self.layer_2 = self._make_layer(128, 256, num_blocks[2], 2, activation_fn) 272 | self.layer_3 = self._make_layer(256, 512, num_blocks[3], 2, activation_fn) 273 | self.batchnorm = nn.BatchNorm2d(512) 274 | self.relu = activation_fn() 275 | self.num_out = 512 276 | # self.logits = nn.Linear(512, num_classes) 277 | 278 | def _make_layer(self, in_planes, out_planes, num_blocks, stride, 279 | activation_fn): 280 | layers = [] 281 | for i, stride in enumerate([stride] + [1] * (num_blocks - 1)): 282 | layers.append( 283 | _PreActBlock(i == 0 and in_planes or out_planes, 284 | out_planes, 285 | stride, 286 | activation_fn)) 287 | return nn.Sequential(*layers) 288 | 289 | def forward(self, x): 290 | if self.padding > 0: 291 | x = F.pad(x, (self.padding,) * 4) 292 | out = (x - self.mean) / self.std 293 | out = self.conv_2d(out) 294 | out = self.layer_0(out) 295 | out = self.layer_1(out) 296 | out = self.layer_2(out) 297 | out = self.layer_3(out) 298 | out = self.relu(self.batchnorm(out)) 299 | out = F.avg_pool2d(out, 4) 300 | out = out.view(out.size(0), -1) 301 | # return self.logits(out) 302 | return out -------------------------------------------------------------------------------- /imagenet/utils/test_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model.resnet import SupCEResNet 3 | import os 4 | 5 | def build_model(): 6 | print("Building ResNet50...") 7 | model = SupCEResNet().cuda() 8 | ext = model.encoder 9 | classifier = model.fc 10 | return model, ext, classifier 11 | 12 | 13 | def test(dataloader, model, **kwargs): 14 | model.eval() 15 | correct = [] 16 | for batch_idx, (inputs, labels) in enumerate(dataloader): 17 | if type(inputs) == list: 18 | inputs = inputs[0] 19 | inputs, labels = inputs.cuda(), labels.cuda() 20 | with torch.no_grad(): 21 | outputs = model(inputs, **kwargs) 22 | _, predicted = outputs.max(1) 23 | correct.append(predicted.eq(labels).cpu()) 24 | correct = torch.cat(correct).numpy() 25 | model.train() 26 | return 1-correct.mean(), correct 27 | 28 | def my_makedir(name): 29 | try: 30 | os.makedirs(name) 31 | except OSError: 32 | pass 33 | 34 | all_wnids = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 
'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 
'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 
'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 
'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141'] 35 | 36 | imagenet_r_wnids = {'n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214', 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', 'n02437616', 
'n02445715', 'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677'} -------------------------------------------------------------------------------- /cifar/OURS.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.optim as optim 4 | import torch.utils.data as data 5 | 6 | from utils.misc import * 7 | from utils.test_helpers import * 8 | from utils.prepare_dataset import * 9 | 10 | # ---------------------------------- 11 | import copy 12 | import random 13 | import numpy as np 14 | from utils.contrastive import * 15 | from utils.offline import * 16 | from torch import nn 17 | import torch.nn.functional as F 18 | # ---------------------------------- 19 | 20 | 21 | def compute_os_variance(os, th): 22 | """ 23 | Calculate the area of a rectangle. 24 | 25 | Parameters: 26 | os : OOD score queue. 27 | th : Given threshold to separate weak and strong OOD samples. 28 | 29 | Returns: 30 | float: Weighted variance at the given threshold th. 31 | """ 32 | 33 | thresholded_os = np.zeros(os.shape) 34 | thresholded_os[os >= th] = 1 35 | 36 | # compute weights 37 | nb_pixels = os.size 38 | nb_pixels1 = np.count_nonzero(thresholded_os) 39 | weight1 = nb_pixels1 / nb_pixels 40 | weight0 = 1 - weight1 41 | 42 | # if one the classes is empty, eg all pixels are below or above the threshold, that threshold will not be considered 43 | # in the search for the best threshold 44 | if weight1 == 0 or weight0 == 0: 45 | return np.inf 46 | 47 | # find all pixels belonging to each class 48 | val_pixels1 = os[thresholded_os == 1] 49 | val_pixels0 = os[thresholded_os == 0] 50 | 51 | # compute variance of these classes 52 | var0 = np.var(val_pixels0) if len(val_pixels0) > 0 else 0 53 | var1 = np.var(val_pixels1) if len(val_pixels1) > 0 else 0 54 | 55 | return weight0 * var0 + weight1 * var1 56 | 57 | 58 | 59 | 60 | class Prototype_Pool(nn.Module): 61 | 62 | """ 63 | Prototype pool containing strong OOD prototypes. 64 | 65 | Methods: 66 | __init__: Constructor method to initialize the prototype pool, storing the values of delta, the number of weak OOD categories, and the maximum count of strong OOD prototypes. 67 | forward: Method to farward pass, return the cosine similarity with strong OOD prototypes. 
68 | update_pool: Method to append and delete strong OOD prototypes. 69 | """ 70 | 71 | 72 | def __init__(self, delta=0.1, class_num=10, max=100): 73 | super(Prototype_Pool, self).__init__() 74 | 75 | self.class_num=class_num 76 | self.max_length = max 77 | self.flag = 0 78 | self.delta = delta 79 | 80 | 81 | def forward(self, x, all=False): 82 | 83 | # if the flag is 0, the prototype pool is empty, return None. 84 | if not self.flag: 85 | return None 86 | 87 | # compute the cosine similarity between the features and the strong OOD prototypes. 88 | out = torch.mm(x, self.memory.t()) 89 | 90 | if all==True: 91 | # if all is True, return the cosine similarity with all the strong OOD prototypes. 92 | return out 93 | else: 94 | # if all is False, return the cosine similarity with the nearest strong OOD prototype. 95 | return torch.max(out/(self.delta),dim=1)[0].unsqueeze(1) 96 | 97 | 98 | def update_pool(self, feature): 99 | 100 | if not self.flag: 101 | # if the flag is 0, the prototype pool is empty, use the feature to init the prototype pool. 102 | self.register_buffer('memory', feature.detach()) 103 | self.flag = 1 104 | else: 105 | if self.memory.shape[0] < self.max_length: 106 | # if the number of strong OOD prototypes is less than the maximum count of strong OOD prototypes, append the feature to the prototype pool. 107 | self.memory = torch.cat([self.memory, feature.detach()],dim=0) 108 | else: 109 | # else then delete the earlest appended strong OOD prototype and append the feature to the prototype pool. 110 | self.memory = torch.cat([self.memory[1:], feature.detach()],dim=0) 111 | self.memory = F.normalize(self.memory) 112 | 113 | 114 | def append_prototypes(pool, feat_ext, logit, ts, ts_pro): 115 | """ 116 | Append strong OOD prototypes to the prototype pool. 117 | 118 | Parameters: 119 | pool : Prototype pool. 120 | feat_ext : Normalized features of the input images. 121 | logit : Cosine similarity between the features and the weak OOD prototypes. 122 | ts : Threshold to separate weak and strong OOD samples. 123 | ts_pro : Threshold to append strong OOD prototypes. 124 | 125 | """ 126 | added_list=[] 127 | update = 1 128 | 129 | while update: 130 | feat_mat = pool(F.normalize(feat_ext),all=True) 131 | if not feat_mat==None: 132 | new_logit = torch.cat([logit, feat_mat], 1) 133 | else: 134 | new_logit = logit 135 | 136 | r_i_pro, _ = new_logit.max(dim=-1) 137 | 138 | r_i, _ = logit.max(dim=-1) 139 | 140 | if added_list!=[]: 141 | for add in added_list: 142 | # if added_list is not empty, set the cosine similarity between the added features and the strong OOD prototypes to 1, to avoid the added features to be appended to the prototype pool again. 143 | r_i[add]=1 144 | min_logit , min_index = r_i.min(dim=0) 145 | 146 | 147 | if (1-min_logit) > ts : 148 | # if the cosine similarity between the feature and the weak OOD prototypes is less than the threshold ts, the feature is a strong OOD sample. 149 | added_list.append(min_index) 150 | if (1-r_i_pro[min_index]) > ts_pro: 151 | # if this strong OOD sample is far away from all the strong OOD prototypes, append it to the prototype pool. 152 | pool.update_pool(F.normalize(feat_ext[min_index].unsqueeze(0))) 153 | else: 154 | # all the features are weak OOD samples, stop the loop. 
155 | update=0 156 | 157 | 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument('--dataset', default='cifar10OOD') 160 | parser.add_argument('--strong_OOD', default='noise') 161 | parser.add_argument('--strong_ratio', default=1, type=float) 162 | parser.add_argument('--dataroot', default="./data", help='path to dataset') 163 | parser.add_argument('--batch_size', default=256, type=int) 164 | parser.add_argument('--workers', default=4, type=int) 165 | parser.add_argument('--lr', default=0.001, type=float) 166 | parser.add_argument('--delta', default=0.1, type=float) 167 | parser.add_argument('--ce_scale', default=0, type=float, help='cross entropy loss scale') 168 | parser.add_argument('--outf', help='folder to output log') 169 | parser.add_argument('--level', default=5, type=int) 170 | parser.add_argument('--N_m', default=512, type=int, help='queue length') 171 | parser.add_argument('--corruption', default='snow') 172 | parser.add_argument('--resume', default='/cluster/personal/code/TTT/TTAC-master/cifar/results/cifar10_joint_resnet50', help='directory of pretrained model') 173 | parser.add_argument('--da_scale', default=1, type=float, help='distribution alignment loss scale') 174 | parser.add_argument('--model', default='resnet50', help='resnet50') 175 | parser.add_argument('--seed', default=0, type=int) 176 | parser.add_argument('--max_prototypes', default=100, type=int) 177 | parser.add_argument('--save', action='store_true', default=False, help='save the model final checkpoint') 178 | 179 | 180 | # ----------- Args and Dataloader ------------ 181 | args = parser.parse_args() 182 | 183 | print(args) 184 | print('\n') 185 | 186 | 187 | 188 | 189 | class_num = 10 if args.dataset == 'cifar10OOD' else 100 190 | 191 | net, ext, head, ssh, classifier = build_resnet50(args) 192 | 193 | teset, _ = prepare_test_data(args) 194 | teloader = data.DataLoader(teset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, worker_init_fn=seed_worker, pin_memory=True, drop_last=False) 195 | 196 | pool = Prototype_Pool(args.delta,class_num=class_num,max = args.max_prototypes).cuda() 197 | 198 | # ------------------------------- 199 | print('Resuming from %s...' %(args.resume)) 200 | 201 | load_resnet50(net, head, ssh, classifier, args) 202 | 203 | optimizer = optim.SGD(ext.parameters(), lr=args.lr, momentum=0.9) 204 | 205 | # ----------- Offline Feature Summarization ------------ 206 | args_align = copy.deepcopy(args) 207 | 208 | _, offlineloader = prepare_train_data(args_align) 209 | ext_src_mu, ext_src_cov, ssh_src_mu, ssh_src_cov, mu_src_ext, cov_src_ext, mu_src_ssh, cov_src_ssh = offline(args,offlineloader, ext, classifier, head, class_num) 210 | 211 | ext_src_mu = torch.stack(ext_src_mu) 212 | ext_src_cov = torch.stack(ext_src_cov) 213 | 214 | ema_ext_mu = ext_src_mu.clone() 215 | ema_ext_cov = ext_src_cov.clone() 216 | ema_ext_total_mu = torch.zeros(2048).float() 217 | ema_ext_total_cov = torch.zeros(2048, 2048).float() 218 | 219 | if class_num == 10: 220 | loss_scale = 0.05 221 | ema_length = 128 222 | else: 223 | loss_scale = 0.05 224 | ema_length = 64 225 | 226 | ema_n = torch.zeros(class_num).cuda() 227 | ema_total_n = 0. 228 | weak_prototype = F.normalize(ext_src_mu.clone()).cuda() 229 | args.ts_pro = 0.0 230 | bias = cov_src_ext.max().item() / 30. 
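# The scaled identity built on the next line is later added to both the source covariance
# (cov_src_ext) and the running target covariance before they are wrapped in MultivariateNormal,
# which keeps the 2048x2048 estimates positive definite so the symmetric KL distribution-alignment
# loss below stays well defined. (Descriptive note on the existing code, not an original comment.)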
231 | template_ext_cov = torch.eye(2048).cuda() * bias 232 | 233 | torch.manual_seed(args.seed) 234 | random.seed(args.seed) 235 | np.random.seed(args.seed) 236 | torch.cuda.manual_seed(args.seed) 237 | torch.cuda.manual_seed_all(args.seed) 238 | 239 | # ----------- Open-World Test-time Training ------------ 240 | 241 | correct = [] 242 | unseen_correct= [] 243 | all_correct=[] 244 | cumulative_error = [] 245 | num_open = 0 246 | predicted_list=[] 247 | label_list=[] 248 | 249 | os_training_queue = [] 250 | os_inference_queue = [] 251 | queue_length = args.N_m 252 | ce_scale = args.ce_scale 253 | 254 | ema_total_n = 0. 255 | 256 | print('\n-----Test-Time Training with OURS-----') 257 | for te_idx, (te_inputs, te_labels) in enumerate(teloader): 258 | classifier.eval() 259 | ext.eval() 260 | 261 | optimizer.zero_grad() 262 | loss = torch.tensor(0.).cuda() 263 | 264 | if isinstance(te_inputs,list): 265 | inputs = te_inputs[0].cuda() 266 | else: 267 | inputs = te_inputs.cuda() 268 | 269 | # features extracted by backbone 270 | feat_ext = ext(inputs) 271 | 272 | # logits of the input images, used to compute the cosine similarity between the features and the weak OOD prototypes. 273 | logit = torch.mm(F.normalize(feat_ext), weak_prototype.t()) / args.delta 274 | 275 | 276 | # compute the cosine similarity between the features and the strong OOD prototypes. 277 | feat_mat = pool(F.normalize(feat_ext)) 278 | if not feat_mat==None: 279 | new_logit = torch.cat([logit, feat_mat], 1) 280 | else: 281 | new_logit = logit 282 | 283 | pro, predicted = new_logit[:,:class_num].max(dim=-1) 284 | 285 | # compute the ood score of the input images. 286 | ood_score = 1-pro*args.delta 287 | os_training_queue.extend(ood_score.detach().cpu().tolist()) 288 | os_training_queue = os_training_queue[-queue_length:] 289 | 290 | 291 | threshold_range = np.arange(0,1,0.01) 292 | criterias = [compute_os_variance(np.array(os_training_queue), th) for th in threshold_range] 293 | 294 | # best threshold is the one minimizing the variance of the two classes 295 | best_threshold = threshold_range[np.argmin(criterias)] 296 | args.ts = best_threshold 297 | seen_mask = (ood_score < args.ts) 298 | unseen_mask = (ood_score >= args.ts) 299 | r_i, pseudo_labels = new_logit.max(dim=-1) 300 | 301 | if unseen_mask.sum().item()!=0: 302 | #compute ts_pro to append new strong OOD prototypes to the prototype pool. 303 | 304 | min_logit , min_index = r_i.min(dim=0) 305 | 306 | in_score = 1-r_i*args.delta 307 | threshold_range = np.arange(0,1,0.01) 308 | criterias = [compute_os_variance(in_score[unseen_mask].detach().cpu().numpy(), th) for th in threshold_range] 309 | 310 | best_threshold = threshold_range[np.argmin(criterias)] 311 | args.ts_pro = best_threshold 312 | 313 | # append new strong OOD prototypes to the prototype pool. 
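# Two thresholds are in play here: ts (estimated above from the OOD-score queue) separates weak-OOD
# samples of seen classes from strong-OOD samples, while ts_pro additionally requires a candidate to
# be far from every prototype already in the pool before append_prototypes() adds it. Both are
# re-estimated per batch with the same Otsu-style variance criterion, so the split falls between the
# two score clusters rather than at a fixed value.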
314 | append_prototypes(pool, feat_ext, logit.detach()*args.delta, args.ts, args.ts_pro) 315 | 316 | len_memory = len(new_logit[0]) 317 | 318 | 319 | if len_memory!=class_num: 320 | 321 | if seen_mask.sum().item()!=0: 322 | pseudo_labels[seen_mask] = new_logit[seen_mask,:class_num].softmax(dim=-1).max(dim=-1)[1] 323 | if unseen_mask.sum().item()!=0: 324 | pseudo_labels[unseen_mask] = class_num 325 | else: 326 | pseudo_labels = new_logit[seen_mask,:class_num].softmax(dim=-1).max(dim=-1)[1] 327 | 328 | 329 | # ------distribuution alignment------ 330 | if seen_mask.sum().item()!=0: 331 | ext.train() 332 | feat_global = ext(inputs[seen_mask]) 333 | # Global Gaussian 334 | b = feat_global.shape[0] 335 | ema_total_n += b 336 | alpha = 1. / 1280 if ema_total_n > 1280 else 1. / ema_total_n 337 | delta_pre = (feat_global - ema_ext_total_mu.cuda()) 338 | delta = alpha * delta_pre.sum(dim=0) 339 | tmp_mu = ema_ext_total_mu.cuda() + delta 340 | tmp_cov = ema_ext_total_cov.cuda() + alpha * (delta_pre.t() @ delta_pre - b * ema_ext_total_cov.cuda()) - delta[:, None] @ delta[None, :] 341 | with torch.no_grad(): 342 | ema_ext_total_mu = tmp_mu.detach().cpu() 343 | ema_ext_total_cov = tmp_cov.detach().cpu() 344 | 345 | source_domain = torch.distributions.MultivariateNormal(mu_src_ext, cov_src_ext + template_ext_cov) 346 | target_domain = torch.distributions.MultivariateNormal(tmp_mu, tmp_cov + template_ext_cov) 347 | loss += args.da_scale*(torch.distributions.kl_divergence(source_domain, target_domain) + torch.distributions.kl_divergence(target_domain, source_domain)) * loss_scale 348 | 349 | 350 | # we only use 50% of samples with ood score far from τ∗ to perform prototype clustering for each batch 351 | if len_memory!=class_num and seen_mask.sum().item()!=0 and unseen_mask.sum().item()!=0: 352 | a, idx1 = torch.sort((ood_score[seen_mask]), descending=True) 353 | filter_down = a[-int(seen_mask.sum().item()*(1/2))] 354 | a, idx1 = torch.sort((ood_score[unseen_mask]), descending=True) 355 | filter_up= a[int(unseen_mask.sum().item()*(1/2))] 356 | for j in range(len(pseudo_labels)): 357 | 358 | if ood_score[j] >=filter_down and seen_mask[j]: 359 | seen_mask[j]=False 360 | if ood_score[j] <=filter_up and unseen_mask[j]: 361 | unseen_mask[j]=False 362 | 363 | 364 | if len_memory!=class_num: 365 | entropy_seen = nn.CrossEntropyLoss()(new_logit[seen_mask,:class_num],pseudo_labels[seen_mask]) 366 | entropy_unseen= nn.CrossEntropyLoss()(new_logit[unseen_mask],pseudo_labels[unseen_mask]) 367 | loss += ce_scale*(entropy_seen+ entropy_unseen)/2 368 | 369 | try: 370 | loss.backward() 371 | optimizer.step() 372 | optimizer.zero_grad() 373 | except: 374 | print('can not backward') 375 | torch.cuda.empty_cache() 376 | 377 | 378 | 379 | ####-------------------------- Test ----------------------------#### 380 | 381 | with torch.no_grad(): 382 | 383 | net.eval() 384 | feat_ext = ext(inputs) #b,2048 385 | logit = torch.mm(F.normalize(feat_ext), weak_prototype.t())/args.delta 386 | update = 1 387 | 388 | 389 | softmax_logit = logit.softmax(dim=-1) 390 | # _, recall_predicted = softmax_logit.max(1) 391 | pro, predicted = softmax_logit.max(dim=-1) 392 | 393 | ood_score, max_index = logit.max(1) 394 | ood_score = 1-ood_score*args.delta 395 | os_inference_queue.extend(ood_score.detach().cpu().tolist()) 396 | os_inference_queue = os_inference_queue[-queue_length:] 397 | 398 | threshold_range = np.arange(0,1,0.01) 399 | criterias = [compute_os_variance(np.array(os_inference_queue), th) for th in threshold_range] 400 | best_threshold 
= threshold_range[np.argmin(criterias)] 401 | unseen_mask = (ood_score > best_threshold) 402 | args.ts = best_threshold 403 | predicted[unseen_mask] = class_num 404 | 405 | one = torch.ones_like(te_labels)*class_num 406 | false = torch.ones_like(te_labels)*-1 407 | predicted = torch.where(predicted>class_num-1, one.cuda(), predicted) 408 | all_labels = torch.where(te_labels>class_num-1, one, te_labels) 409 | seen_labels = torch.where(te_labels>class_num-1, false, te_labels) 410 | unseen_labels = torch.where(te_labels>class_num-1, one, false) 411 | correct.append(predicted.cpu().eq(seen_labels)) 412 | unseen_correct.append(predicted.cpu().eq(unseen_labels)) 413 | all_correct.append(predicted.cpu().eq(all_labels)) 414 | num_open += torch.gt(te_labels, 99).sum() 415 | 416 | predicted_list.append(predicted.long().cpu()) 417 | label_list.append(all_labels.long().cpu()) 418 | 419 | 420 | seen_acc = round(torch.cat(correct).numpy().sum() / (len(torch.cat(correct).numpy())-num_open.numpy()),4) 421 | unseen_acc = round(torch.cat(unseen_correct).numpy().sum() / num_open.numpy(),4) 422 | h_score = round((2*seen_acc*unseen_acc) / (seen_acc + unseen_acc),4) 423 | print('Batch:(', te_idx,'/',len(teloader), ')\tloss:',"%.2f" % loss.item(),\ 424 | '\t Cumulative Results: ACC_S:', seen_acc,\ 425 | '\tACC_N:', unseen_acc,\ 426 | '\tACC_H:',h_score\ 427 | ) 428 | 429 | 430 | print('\nTest time training result:',' ACC_S:', seen_acc,\ 431 | '\tACC_N:', unseen_acc,\ 432 | '\tACC_H:',h_score,'\n\n\n\n'\ 433 | ) 434 | 435 | 436 | if args.outf != None: 437 | my_makedir(args.outf) 438 | with open (args.outf+'/results.txt','a') as f: 439 | f.write(str(args)+'\n') 440 | f.write( 441 | 'ACC_S:'+ str(seen_acc)+\ 442 | '\tACC_N:'+ str(unseen_acc)+\ 443 | '\tACC_H:'+str(h_score)+'\n\n\n\n'\ 444 | ) 445 | if args.save: 446 | torch.save(net.state_dict(), os.path.join(args.outf, 'final.pth')) -------------------------------------------------------------------------------- /imagenet/OURS.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.optim as optim 4 | import torch.utils.data as data 5 | 6 | import torch.nn as nn 7 | from utils.test_helpers import * 8 | from utils.prepare_dataset import * 9 | 10 | # ---------------------------------- 11 | import copy 12 | import random 13 | import numpy as np 14 | 15 | from utils.test_helpers import build_model, test 16 | from utils.prepare_dataset import prepare_transforms, create_dataloader, ImageNetCorruption, ImageNet_, prepare_ood_test_data,prepare_ood_test_data_r 17 | from utils.offline import offline, offline_r 18 | import torch.nn.functional as F 19 | # ---------------------------------- 20 | 21 | 22 | def compute_os_variance(os, th): 23 | """ 24 | Calculate the area of a rectangle. 25 | 26 | Parameters: 27 | os : OOD score queue. 28 | th : Given threshold to separate weak and strong OOD samples. 29 | 30 | Returns: 31 | float: Weighted variance at the given threshold th. 
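    (This is the Otsu criterion: the candidate threshold that minimizes this weighted
    within-class variance best separates the weak-OOD and strong-OOD score populations.)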
32 | """ 33 | 34 | thresholded_os = np.zeros(os.shape) 35 | thresholded_os[os >= th] = 1 36 | 37 | # compute weights 38 | nb_pixels = os.size 39 | nb_pixels1 = np.count_nonzero(thresholded_os) 40 | weight1 = nb_pixels1 / nb_pixels 41 | weight0 = 1 - weight1 42 | 43 | # if one the classes is empty, eg all pixels are below or above the threshold, that threshold will not be considered 44 | # in the search for the best threshold 45 | if weight1 == 0 or weight0 == 0: 46 | return np.inf 47 | 48 | # find all pixels belonging to each class 49 | val_pixels1 = os[thresholded_os == 1] 50 | val_pixels0 = os[thresholded_os == 0] 51 | 52 | # compute variance of these classes 53 | var0 = np.var(val_pixels0) if len(val_pixels0) > 0 else 0 54 | var1 = np.var(val_pixels1) if len(val_pixels1) > 0 else 0 55 | 56 | return weight0 * var0 + weight1 * var1 57 | 58 | 59 | 60 | 61 | class Prototype_Pool(nn.Module): 62 | 63 | """ 64 | Prototype pool containing strong OOD prototypes. 65 | 66 | Methods: 67 | __init__: Constructor method to initialize the prototype pool, storing the values of delta, the number of weak OOD categories, and the maximum count of strong OOD prototypes. 68 | forward: Method to farward pass, return the cosine similarity with strong OOD prototypes. 69 | update_pool: Method to append and delete strong OOD prototypes. 70 | """ 71 | 72 | 73 | def __init__(self, delta=0.1, class_num=10, max=100): 74 | super(Prototype_Pool, self).__init__() 75 | 76 | self.class_num=class_num 77 | self.max_length = max 78 | self.flag = 0 79 | self.delta = delta 80 | 81 | 82 | def forward(self, x, all=False): 83 | 84 | # if the flag is 0, the prototype pool is empty, return None. 85 | if not self.flag: 86 | return None 87 | 88 | # compute the cosine similarity between the features and the strong OOD prototypes. 89 | out = torch.mm(x, self.memory.t()) 90 | 91 | if all==True: 92 | # if all is True, return the cosine similarity with all the strong OOD prototypes. 93 | return out 94 | else: 95 | # if all is False, return the cosine similarity with the nearest strong OOD prototype. 96 | return torch.max(out/(self.delta),dim=1)[0].unsqueeze(1) 97 | 98 | 99 | def update_pool(self, feature): 100 | 101 | if not self.flag: 102 | # if the flag is 0, the prototype pool is empty, use the feature to init the prototype pool. 103 | self.register_buffer('memory', feature.detach()) 104 | self.flag = 1 105 | else: 106 | if self.memory.shape[0] < self.max_length: 107 | # if the number of strong OOD prototypes is less than the maximum count of strong OOD prototypes, append the feature to the prototype pool. 108 | self.memory = torch.cat([self.memory, feature.detach()],dim=0) 109 | else: 110 | # else then delete the earlest appended strong OOD prototype and append the feature to the prototype pool. 111 | self.memory = torch.cat([self.memory[1:], feature.detach()],dim=0) 112 | self.memory = F.normalize(self.memory) 113 | 114 | 115 | def append_prototypes(pool, feat_ext, logit, ts, ts_pro): 116 | """ 117 | Append strong OOD prototypes to the prototype pool. 118 | 119 | Parameters: 120 | pool : Prototype pool. 121 | feat_ext : Normalized features of the input images. 122 | logit : Cosine similarity between the features and the weak OOD prototypes. 123 | ts : Threshold to separate weak and strong OOD samples. 124 | ts_pro : Threshold to append strong OOD prototypes. 
125 | 126 | """ 127 | added_list=[] 128 | update = 1 129 | 130 | while update: 131 | feat_mat = pool(F.normalize(feat_ext),all=True) 132 | if not feat_mat==None: 133 | new_logit = torch.cat([logit, feat_mat], 1) 134 | else: 135 | new_logit = logit 136 | 137 | r_i_pro, _ = new_logit.max(dim=-1) 138 | 139 | r_i, _ = logit.max(dim=-1) 140 | 141 | if added_list!=[]: 142 | for add in added_list: 143 | # if added_list is not empty, set the cosine similarity between the added features and the strong OOD prototypes to 1, to avoid the added features to be appended to the prototype pool again. 144 | r_i[add]=1 145 | min_logit , min_index = r_i.min(dim=0) 146 | 147 | 148 | if (1-min_logit) > ts : 149 | # if the cosine similarity between the feature and the weak OOD prototypes is less than the threshold ts, the feature is a strong OOD sample. 150 | added_list.append(min_index) 151 | if (1-r_i_pro[min_index]) > ts_pro: 152 | # if this strong OOD sample is far away from all the strong OOD prototypes, append it to the prototype pool. 153 | pool.update_pool(F.normalize(feat_ext[min_index].unsqueeze(0))) 154 | else: 155 | # all the features are weak OOD samples, stop the loop. 156 | update=0 157 | 158 | 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument('--dataset', default='ImageNet-C') 161 | parser.add_argument('--strong_OOD', default='noise') 162 | parser.add_argument('--strong_ratio', default=1, type=float) 163 | parser.add_argument('--dataroot', default='./data') 164 | parser.add_argument('--batch_size', default=128, type=int) 165 | parser.add_argument('--workers', default=8, type=int) 166 | parser.add_argument('--lr', default=0.001, type=float) 167 | parser.add_argument('--delta', default=0.1, type=float) 168 | parser.add_argument('--ce_scale', default=0, type=float, help='cross entropy loss scale') 169 | parser.add_argument('--outf', help='folder to output log') 170 | parser.add_argument('--level', default=5, type=int) 171 | parser.add_argument('--N_m', default=512, type=int, help='queue length') 172 | parser.add_argument('--corruption', default='snow') 173 | parser.add_argument('--offline', default='./results/offline/', help='directory of pretrained model') 174 | parser.add_argument('--da_scale', default=1, type=float, help='distribution alignment loss scale') 175 | parser.add_argument('--model', default='resnet50', help='resnet50') 176 | parser.add_argument('--seed', default=0, type=int) 177 | parser.add_argument('--max_prototypes', default=100, type=int) 178 | parser.add_argument('--save', action='store_true', default=False, help='save the model final checkpoint') 179 | 180 | 181 | # ----------- Args and Dataloader ------------ 182 | args = parser.parse_args() 183 | 184 | print(args) 185 | print('\n') 186 | 187 | my_makedir(args.offline) 188 | 189 | 190 | net, ext, classifier = build_model() 191 | 192 | 193 | train_transform, val_transform, val_corrupt_transform = prepare_transforms() 194 | 195 | source_dataset = ImageNet_(args.dataroot, 'val', transform=val_transform, is_carry_index=True) 196 | 197 | if args.dataset == 'ImageNet-C': 198 | target_dataset_test = prepare_ood_test_data(args.dataroot, args.corruption, transform=val_corrupt_transform, is_carry_index=True, OOD=args.strong_OOD,OOD_transform=val_transform) 199 | class_num = 1000 200 | 201 | elif args.dataset == 'ImageNet-R': 202 | indices_in_1k = [wnid in imagenet_r_wnids for wnid in all_wnids] 203 | target_dataset_test = prepare_ood_test_data_r(args.dataroot, args.corruption, transform=val_corrupt_transform, 
is_carry_index=True, OOD=args.strong_OOD,OOD_transform=val_transform) 204 | class_num = 200 205 | else: 206 | raise NotImplementedError 207 | 208 | source_dataloader = create_dataloader(source_dataset, args, True, False) 209 | target_dataloader_test = create_dataloader(target_dataset_test, args, True, False) 210 | 211 | pool = Prototype_Pool(args.delta,class_num=class_num,max = args.max_prototypes).cuda() 212 | 213 | 214 | # ----------- Offline Feature Summarization ------------ 215 | if args.dataset == 'ImageNet-C': 216 | ext_mean, ext_cov, ext_mean_categories, ext_cov_categories = offline(args, source_dataloader, ext, classifier) 217 | weak_prototype = F.normalize(ext_mean_categories.clone()).cuda() 218 | else: 219 | ext_mean, ext_cov, ext_mean_categories, ext_cov_categories = offline_r(args, source_dataloader, ext, classifier) 220 | weak_prototype = F.normalize(ext_mean_categories[indices_in_1k].clone()).cuda() 221 | 222 | 223 | 224 | sample_predict_ema_logit = torch.zeros(len(target_dataset_test), class_num, dtype=torch.float) 225 | sample_alpha = torch.ones(len(target_dataset_test), dtype=torch.float) 226 | 227 | ema_alpha = 0.9 228 | ema_ext_mu = ext_mean_categories.clone() 229 | ema_ext_cov = ext_cov_categories.clone() 230 | ema_ext_total_mu = torch.zeros(2048).cuda() 231 | ema_ext_total_cov = torch.zeros(2048, 2048).cuda() 232 | 233 | class_ema_length = 64 234 | ema_n = torch.ones(class_num).cuda() * class_ema_length 235 | ema_total_n = 0. 236 | 237 | loss_scale = 0.05 238 | ce_scale = args.ce_scale 239 | 240 | args.ts_pro = 0.0 241 | bias = ext_cov.max().item() / 30. 242 | template_ext_cov = torch.eye(2048).cuda() * bias 243 | 244 | optimizer = optim.SGD(ext.parameters(), lr=args.lr, momentum=0.9) 245 | 246 | torch.manual_seed(args.seed) 247 | random.seed(args.seed) 248 | np.random.seed(args.seed) 249 | torch.cuda.manual_seed(args.seed) 250 | torch.cuda.manual_seed_all(args.seed) 251 | 252 | # ----------- Open-World Test-time Training ------------ 253 | 254 | correct = [] 255 | unseen_correct= [] 256 | all_correct=[] 257 | cumulative_error = [] 258 | num_open = 0 259 | predicted_list=[] 260 | label_list=[] 261 | 262 | os_training_queue = [] 263 | os_inference_queue = [] 264 | queue_length = args.N_m 265 | ce_scale = args.ce_scale 266 | 267 | ema_total_n = 0. 268 | 269 | print('\n-----Test-Time Training with OURS-----') 270 | for te_idx, (te_inputs, te_labels) in enumerate(target_dataloader_test): 271 | classifier.eval() 272 | ext.eval() 273 | 274 | optimizer.zero_grad() 275 | loss = torch.tensor(0.).cuda() 276 | 277 | if isinstance(te_inputs,list): 278 | inputs = te_inputs[0].cuda() 279 | else: 280 | inputs = te_inputs.cuda() 281 | 282 | # features extracted by backbone 283 | feat_ext = ext(inputs) 284 | 285 | # logits of the input images, used to compute the cosine similarity between the features and the weak OOD prototypes. 286 | logit = torch.mm(F.normalize(feat_ext), weak_prototype.t()) / args.delta 287 | 288 | 289 | # compute the cosine similarity between the features and the strong OOD prototypes. 290 | feat_mat = pool(F.normalize(feat_ext)) 291 | if not feat_mat==None: 292 | new_logit = torch.cat([logit, feat_mat], 1) 293 | else: 294 | new_logit = logit 295 | 296 | pro, predicted = new_logit[:,:class_num].max(dim=-1) 297 | 298 | # compute the ood score of the input images. 
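# Note: logit was scaled by 1/delta above, so multiplying back by delta recovers the cosine
# similarity to the nearest weak (seen-class) prototype; e.g. with delta = 0.1 a max logit of 9.0
# corresponds to cosine similarity 0.9 and an OOD score of 0.1 (lower score = more likely seen).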
299 | ood_score = 1-pro*args.delta 300 | os_training_queue.extend(ood_score.detach().cpu().tolist()) 301 | os_training_queue = os_training_queue[-queue_length:] 302 | 303 | 304 | threshold_range = np.arange(0,1,0.01) 305 | criterias = [compute_os_variance(np.array(os_training_queue), th) for th in threshold_range] 306 | 307 | # best threshold is the one minimizing the variance of the two classes 308 | best_threshold = threshold_range[np.argmin(criterias)] 309 | args.ts = best_threshold 310 | seen_mask = (ood_score < args.ts) 311 | unseen_mask = (ood_score >= args.ts) 312 | r_i, pseudo_labels = new_logit.max(dim=-1) 313 | 314 | if unseen_mask.sum().item()!=0: 315 | #compute ts_pro to append new strong OOD prototypes to the prototype pool. 316 | 317 | min_logit , min_index = r_i.min(dim=0) 318 | 319 | in_score = 1-r_i*args.delta 320 | threshold_range = np.arange(0,1,0.01) 321 | criterias = [compute_os_variance(in_score[unseen_mask].detach().cpu().numpy(), th) for th in threshold_range] 322 | 323 | best_threshold = threshold_range[np.argmin(criterias)] 324 | args.ts_pro = best_threshold 325 | 326 | # append new strong OOD prototypes to the prototype pool. 327 | append_prototypes(pool, feat_ext, logit.detach()*args.delta, args.ts, args.ts_pro) 328 | 329 | len_memory = len(new_logit[0]) 330 | 331 | 332 | if len_memory!=class_num: 333 | 334 | if seen_mask.sum().item()!=0: 335 | pseudo_labels[seen_mask] = new_logit[seen_mask,:class_num].softmax(dim=-1).max(dim=-1)[1] 336 | if unseen_mask.sum().item()!=0: 337 | pseudo_labels[unseen_mask] = class_num 338 | else: 339 | pseudo_labels = new_logit[seen_mask,:class_num].softmax(dim=-1).max(dim=-1)[1] 340 | 341 | 342 | # ------distribuution alignment------ 343 | if seen_mask.sum().item()!=0: 344 | ext.train() 345 | feat_global = ext(inputs[seen_mask]) 346 | # Global Gaussian 347 | b = feat_global.shape[0] 348 | ema_total_n += b 349 | alpha = 1. / 1280 if ema_total_n > 1280 else 1. 
/ ema_total_n 350 | delta_pre = (feat_global - ema_ext_total_mu.cuda()) 351 | delta = alpha * delta_pre.sum(dim=0) 352 | tmp_mu = ema_ext_total_mu.cuda() + delta 353 | tmp_cov = ema_ext_total_cov.cuda() + alpha * (delta_pre.t() @ delta_pre - b * ema_ext_total_cov.cuda()) - delta[:, None] @ delta[None, :] 354 | with torch.no_grad(): 355 | ema_ext_total_mu = tmp_mu.detach().cpu() 356 | ema_ext_total_cov = tmp_cov.detach().cpu() 357 | 358 | source_domain = torch.distributions.MultivariateNormal(ext_mean, ext_cov + template_ext_cov) 359 | target_domain = torch.distributions.MultivariateNormal(tmp_mu, tmp_cov + template_ext_cov) 360 | global_loss=(torch.distributions.kl_divergence(source_domain, target_domain) + torch.distributions.kl_divergence(target_domain, source_domain)) * loss_scale 361 | 362 | loss += args.da_scale*global_loss 363 | 364 | 365 | # we only use 50% of samples with ood score far from τ∗ to perform prototype clustering for each batch 366 | if len_memory!=class_num and seen_mask.sum().item()!=0 and unseen_mask.sum().item()!=0: 367 | a, idx1 = torch.sort((ood_score[seen_mask]), descending=True) 368 | filter_down = a[-int(seen_mask.sum().item()*(1/2))] 369 | a, idx1 = torch.sort((ood_score[unseen_mask]), descending=True) 370 | filter_up= a[int(unseen_mask.sum().item()*(1/2))] 371 | for j in range(len(pseudo_labels)): 372 | 373 | if ood_score[j] >=filter_down and seen_mask[j]: 374 | seen_mask[j]=False 375 | if ood_score[j] <=filter_up and unseen_mask[j]: 376 | unseen_mask[j]=False 377 | 378 | 379 | if len_memory!=class_num: 380 | entropy_seen = nn.CrossEntropyLoss()(new_logit[seen_mask,:class_num],pseudo_labels[seen_mask]) 381 | entropy_unseen= nn.CrossEntropyLoss()(new_logit[unseen_mask],pseudo_labels[unseen_mask]) 382 | loss += ce_scale*(entropy_seen+ entropy_unseen)/2 383 | 384 | try: 385 | loss.backward() 386 | optimizer.step() 387 | optimizer.zero_grad() 388 | except: 389 | print('can not backward') 390 | torch.cuda.empty_cache() 391 | 392 | 393 | 394 | ####-------------------------- Test ----------------------------#### 395 | 396 | with torch.no_grad(): 397 | 398 | net.eval() 399 | feat_ext = ext(inputs) #b,2048 400 | logit = torch.mm(F.normalize(feat_ext), weak_prototype.t())/args.delta 401 | 402 | 403 | softmax_logit = logit.softmax(dim=-1) 404 | pro, predicted = softmax_logit.max(dim=-1) 405 | 406 | ood_score, max_index = logit.max(1) 407 | ood_score = 1-ood_score*args.delta 408 | os_inference_queue.extend(ood_score.detach().cpu().tolist()) 409 | os_inference_queue = os_inference_queue[-queue_length:] 410 | 411 | threshold_range = np.arange(0,1,0.01) 412 | criterias = [compute_os_variance(np.array(os_inference_queue), th) for th in threshold_range] 413 | best_threshold = threshold_range[np.argmin(criterias)] 414 | unseen_mask = (ood_score > best_threshold) 415 | args.ts = best_threshold 416 | predicted[unseen_mask] = class_num 417 | 418 | one = torch.ones_like(te_labels)*class_num 419 | false = torch.ones_like(te_labels)*-1 420 | predicted = torch.where(predicted>class_num-1, one.cuda(), predicted) 421 | all_labels = torch.where(te_labels>class_num-1, one, te_labels) 422 | seen_labels = torch.where(te_labels>class_num-1, false, te_labels) 423 | unseen_labels = torch.where(te_labels>class_num-1, one, false) 424 | correct.append(predicted.cpu().eq(seen_labels)) 425 | unseen_correct.append(predicted.cpu().eq(unseen_labels)) 426 | all_correct.append(predicted.cpu().eq(all_labels)) 427 | num_open += torch.gt(te_labels, class_num-1).sum() 428 | 429 | 
predicted_list.append(predicted.long().cpu()) 430 | label_list.append(all_labels.long().cpu()) 431 | 432 | 433 | seen_acc = round(torch.cat(correct).numpy().sum() / (len(torch.cat(correct).numpy())-num_open.numpy()),4) 434 | unseen_acc = round(torch.cat(unseen_correct).numpy().sum() / num_open.numpy(),4) 435 | h_score = round((2*seen_acc*unseen_acc) / (seen_acc + unseen_acc),4) 436 | print('Batch:(', te_idx,'/',len(target_dataloader_test), ')\tloss:',"%.2f" % loss.item(),\ 437 | '\t Cumulative Results: ACC_S:', seen_acc,\ 438 | '\tACC_N:', unseen_acc,\ 439 | '\tACC_H:',h_score\ 440 | ) 441 | 442 | 443 | print('\nTest time training result:',' ACC_S:', seen_acc,\ 444 | '\tACC_N:', unseen_acc,\ 445 | '\tACC_H:',h_score,'\n\n\n\n'\ 446 | ) 447 | 448 | 449 | if args.outf != None: 450 | my_makedir(args.outf) 451 | with open (args.outf+'/results.txt','a') as f: 452 | f.write(str(args)+'\n') 453 | f.write( 454 | 'ACC_S:'+ str(seen_acc)+\ 455 | '\tACC_N:'+ str(unseen_acc)+\ 456 | '\tACC_H:'+str(h_score)+'\n\n\n\n'\ 457 | ) 458 | if args.save: 459 | torch.save(net.state_dict(), os.path.join(args.outf, 'final.pth')) -------------------------------------------------------------------------------- /imagenet/utils/offline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | import statistics 4 | import os 5 | 6 | def covariance(features): 7 | assert len(features.size()) == 2, "TODO: multi-dimensional feature map covariance" 8 | n = features.shape[0] 9 | tmp = torch.ones((1, n), device=features.device) @ features 10 | cov = (features.t() @ features - (tmp.t() @ tmp) / n) / n 11 | return cov 12 | 13 | def coral(cs, ct): 14 | d = cs.shape[0] 15 | loss = (cs - ct).pow(2).sum() / (4. * d ** 2) 16 | return loss 17 | 18 | 19 | def linear_mmd(ms, mt): 20 | loss = (ms - mt).pow(2).mean() 21 | return loss 22 | 23 | def offline(args, trloader, ext, classifier, num_classes=1000): 24 | if os.path.exists(args.offline+'/offline.pth'): 25 | data = torch.load(args.offline+'/offline.pth') 26 | return data 27 | 28 | ext.eval() 29 | 30 | feat_ext_mean = torch.zeros(2048).cuda() 31 | feat_ext_variance = torch.zeros(2048, 2048).cuda() 32 | 33 | feat_ext_mean_categories = torch.zeros(num_classes, 2048).cuda() # K, D 34 | feat_ext_variance_categories = torch.zeros(num_classes, 2048).cuda() 35 | 36 | ema_n = torch.zeros(num_classes).cuda() 37 | ema_total_n = 0 38 | 39 | with torch.no_grad(): 40 | for batch_idx, (inputs, labels) in enumerate(trloader): 41 | feat = ext(inputs[0].cuda()) # N, D 42 | b, d = feat.shape 43 | labels = classifier(feat).argmax(dim=-1) 44 | 45 | feat_ext_categories = torch.zeros(num_classes, b, d).cuda() 46 | feat_ext_categories.scatter_add_(dim=0, index=labels[None, :, None].expand(-1, -1, d), src=feat[None, :, :]) 47 | 48 | num_categories = torch.zeros(num_classes, b, dtype=torch.int).cuda() 49 | num_categories.scatter_add_(dim=0, index=labels[None, :], src=torch.ones_like(labels[None, :], dtype=torch.int)) 50 | ema_n += num_categories.sum(dim=1) 51 | alpha_categories = 1 / (ema_n + 1e-10) # K 52 | delta_pre = (feat_ext_categories - feat_ext_mean_categories[:, None, :]) * num_categories[:, :, None] # K, N, D 53 | delta = alpha_categories[:, None] * delta_pre.sum(dim=1) # K, D 54 | feat_ext_mean_categories += delta 55 | feat_ext_variance_categories += alpha_categories[:, None] * ((delta_pre ** 2).sum(dim=1) - num_categories.sum(dim=1)[:, None] * feat_ext_variance_categories) \ 56 | - delta ** 2 57 | 58 | ema_total_n += b 59 | alpha = 1 
/ (ema_total_n + 1e-10) 60 | delta_pre = feat - feat_ext_mean[None, :] # b, d 61 | delta = alpha * (delta_pre).sum(dim=0) 62 | feat_ext_mean += delta 63 | feat_ext_variance += alpha * (delta_pre.t() @ delta_pre - b * feat_ext_variance) - delta[:, None] @ delta[None, :] 64 | print('offline process rate: %.2f%%\r' % ((batch_idx + 1) / len(trloader) * 100.), end='') 65 | 66 | 67 | torch.save((feat_ext_mean, feat_ext_variance, feat_ext_mean_categories, feat_ext_variance_categories), args.offline+'/offline.pth') 68 | return feat_ext_mean, feat_ext_variance, feat_ext_mean_categories, feat_ext_variance_categories 69 | 70 | 71 | def offline_r(args, trloader, ext, classifier, num_classes=1000): 72 | if os.path.exists(args.offline+'/offline_r.pth'): 73 | data = torch.load(args.offline+'/offline_r.pth') 74 | return data 75 | 76 | ext.eval() 77 | all_wnids = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 
'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 
'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 
'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141'] 78 | 79 | imagenet_r_wnids = {'n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 
'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214', 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', 'n02437616', 'n02445715', 'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677'} 80 | 81 | indices_in_1k = [wnid in imagenet_r_wnids for wnid in all_wnids] 82 | 83 | feat_ext_mean = torch.zeros(2048).cuda() 84 | feat_ext_variance = torch.zeros(2048, 2048).cuda() 85 | 86 | feat_ext_mean_categories = torch.zeros(num_classes, 2048).cuda() # K, D 87 | feat_ext_variance_categories = torch.zeros(num_classes, 2048).cuda() 88 | 89 | ema_n = torch.zeros(num_classes).cuda() 90 | ema_total_n = 0 91 | 92 | with torch.no_grad(): 93 | for batch_idx, (inputs, labels) in enumerate(trloader): 94 | l=[] 95 | t=[] 96 | for i in range(len(labels)): 97 | if indices_in_1k[labels[i]]==True: 98 | t.append(labels[i]) 99 | l.append(inputs[0][i].unsqueeze(0)) 100 | inputs = torch.cat(l,dim=0) 101 | feat = ext(inputs.cuda()) # N, D 102 | b, d = feat.shape 103 | labels = classifier(feat).argmax(dim=-1) 104 | 105 | feat_ext_categories = torch.zeros(num_classes, b, d).cuda() 106 | feat_ext_categories.scatter_add_(dim=0, index=labels[None, :, None].expand(-1, -1, d), src=feat[None, :, :]) 107 | 108 | num_categories = torch.zeros(num_classes, b, dtype=torch.int).cuda() 109 | num_categories.scatter_add_(dim=0, index=labels[None, :], src=torch.ones_like(labels[None, :], dtype=torch.int)) 110 | ema_n += num_categories.sum(dim=1) 111 | alpha_categories = 1 / 
(ema_n + 1e-10) # K 112 | delta_pre = (feat_ext_categories - feat_ext_mean_categories[:, None, :]) * num_categories[:, :, None] # K, N, D 113 | delta = alpha_categories[:, None] * delta_pre.sum(dim=1) # K, D 114 | feat_ext_mean_categories += delta 115 | feat_ext_variance_categories += alpha_categories[:, None] * ((delta_pre ** 2).sum(dim=1) - num_categories.sum(dim=1)[:, None] * feat_ext_variance_categories) \ 116 | - delta ** 2 117 | 118 | ema_total_n += b 119 | alpha = 1 / (ema_total_n + 1e-10) 120 | delta_pre = feat - feat_ext_mean[None, :] # b, d 121 | delta = alpha * (delta_pre).sum(dim=0) 122 | feat_ext_mean += delta 123 | feat_ext_variance += alpha * (delta_pre.t() @ delta_pre - b * feat_ext_variance) - delta[:, None] @ delta[None, :] 124 | print('offline process rate: %.2f%%\r' % ((batch_idx + 1) / len(trloader) * 100.), end='') 125 | 126 | 127 | torch.save((feat_ext_mean, feat_ext_variance, feat_ext_mean_categories, feat_ext_variance_categories), args.offline+'/offline_r.pth') 128 | return feat_ext_mean, feat_ext_variance, feat_ext_mean_categories, feat_ext_variance_categories 129 | -------------------------------------------------------------------------------- /cifar/utils/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import random 5 | import torchvision 6 | import numpy as np 7 | from PIL import Image 8 | import torch.utils.data 9 | import torchvision.transforms as transforms 10 | 11 | 12 | class CIFAR10(torchvision.datasets.CIFAR10): 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | return 16 | 17 | def __getitem__(self, index: int): 18 | image, target = super().__getitem__(index) 19 | if type(image) == list: 20 | image.append(index) 21 | else: 22 | image = [image, index] 23 | return image, target 24 | 25 | class CIFAR100(torchvision.datasets.CIFAR100): 26 | def __init__(self, *args, **kwargs): 27 | super().__init__(*args, **kwargs) 28 | return 29 | 30 | def __getitem__(self, index: int): 31 | image, target = super().__getitem__(index) 32 | if type(image) == list: 33 | image.append(index) 34 | else: 35 | image = [image, index] 36 | return image, target 37 | 38 | 39 | class CIFAR100_openset(torchvision.datasets.CIFAR100): 40 | def __init__(self,ratio=1, *args, **kwargs): 41 | super().__init__(*args, **kwargs) 42 | self.data, self.targets = self.data[:int(10000*ratio)], self.targets[:int(10000*ratio)] 43 | return 44 | 45 | def __getitem__(self, index: int): 46 | image, target = super().__getitem__(index) 47 | target = target + 1000 48 | if type(image) == list: 49 | image.append(index) 50 | else: 51 | image = [image, index] 52 | return image, target 53 | 54 | class CIFAR10_openset(torchvision.datasets.CIFAR10): 55 | def __init__(self,ratio=1, *args, **kwargs): 56 | super().__init__(*args, **kwargs) 57 | self.data, self.targets = self.data[:int(10000*ratio)], self.targets[:int(10000*ratio)] 58 | return 59 | 60 | def __getitem__(self, index: int): 61 | image, target = super().__getitem__(index) 62 | target = target + 1000 63 | if type(image) == list: 64 | image.append(index) 65 | else: 66 | image = [image, index] 67 | return image, target 68 | 69 | class noise_dataset(torch.utils.data.Dataset): 70 | def __init__(self, transform,ratio=1): 71 | self.number = int(10000*ratio) 72 | self.transform = transform 73 | 74 | def __getitem__(self, index:int): 75 | image = torch.randn(3,32,32) 76 | target = 1000 77 | if type(image) == list: 78 | 
class noise_dataset(torch.utils.data.Dataset):
    """Random Gaussian-noise images used as a strong OOD source."""

    def __init__(self, transform, ratio=1):
        self.number = int(10000 * ratio)
        self.transform = transform

    def __getitem__(self, index: int):
        image = torch.randn(3, 32, 32)
        target = 1000
        if type(image) == list:
            image.append(index)
        else:
            image = [image, index]

        return image, target

    def __len__(self):
        return self.number


class MNIST_openset(torchvision.datasets.MNIST):
    def __init__(self, *args, ratio=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.data, self.targets = self.data[:int(10000 * ratio)], self.targets[:int(10000 * ratio)]

    def __getitem__(self, index: int):
        image, target = super().__getitem__(index)
        target = target + 1000
        if type(image) == list:
            image.append(index)
        else:
            image = [image, index]
        return image, target


class SVHN_openset(torchvision.datasets.SVHN):
    def __init__(self, *args, ratio=1, **kwargs):
        super().__init__(*args, **kwargs)
        self.data, self.labels = self.data[:int(10000 * ratio)], self.labels[:int(10000 * ratio)]

    def __getitem__(self, index: int):
        image, target = super().__getitem__(index)
        target = target + 1000
        if type(image) == list:
            image.append(index)
        else:
            image = [image, index]
        return image, target


class TinyImageNet_OOD_nonoverlap(torch.utils.data.Dataset):
    """A 20-class subset of Tiny-ImageNet used as a strong OOD source; labels are offset by 1000."""

    def __init__(self, root, train=True, transform=None, list=True, ratio=1):
        self.Train = train
        self.list = list
        self.root_dir = root
        self.transform = transform
        self.train_dir = os.path.join(self.root_dir, "train")
        self.val_dir = os.path.join(self.root_dir, "val")
        self.ratio = ratio

        self.class_list = ['n03544143', 'n03255030', 'n04532106', 'n02669723', 'n02321529', 'n02423022', 'n03854065', 'n02509815', 'n04133789', 'n03970156', 'n01882714', 'n04023962', 'n01768244', 'n04596742', 'n03447447', 'n03617480', 'n07720875', 'n02125311', 'n02793495', 'n04532670']

        if self.Train:
            self._create_class_idx_dict_train()
        else:
            self._create_class_idx_dict_val()

        self._make_dataset(self.Train)

        words_file = os.path.join(self.root_dir, "words.txt")
        wnids_file = os.path.join(self.root_dir, "wnids.txt")

        self.set_nids = set()

        with open(wnids_file, 'r') as fo:
            data = fo.readlines()
            for entry in data:
                if entry.strip("\n") in self.class_list:
                    self.set_nids.add(entry.strip("\n"))

        self.class_to_label = {}
        with open(words_file, 'r') as fo:
            data = fo.readlines()
            for entry in data:
                words = entry.split("\t")
                if words[0] in self.set_nids:
                    self.class_to_label[words[0]] = (words[1].strip("\n").split(","))[0]

    def _create_class_idx_dict_train(self):
        if sys.version_info >= (3, 5):
            classes = [d.name for d in os.scandir(self.train_dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(self.train_dir) if os.path.isdir(os.path.join(self.train_dir, d))]
        classes = sorted(classes)
        num_images = 0
        temp = [0] * len(self.class_list)
        for root, dirs, files in os.walk(self.train_dir):
            for f in files:
                if f.endswith(".JPEG") and f.split("_")[0] in self.class_list:
                    for i in range(len(self.class_list)):
                        if f.split("_")[0] == self.class_list[i]:
                            # Cap each class at 500 training images.
                            if temp[i] < 500:
                                temp[i] += 1
                                num_images = num_images + 1
                            break
        self.len_dataset = num_images

        self.tgt_idx_to_class = {i: classes[i] for i in range(len(classes))}
        self.class_to_tgt_idx = {classes[i]: i for i in range(len(classes))}

    def _create_class_idx_dict_val(self):
        val_image_dir = os.path.join(self.val_dir, "images")
        if sys.version_info >= (3, 5):
            images = [d.name for d in os.scandir(val_image_dir) if d.is_file()]
        else:
            images = [d for d in os.listdir(val_image_dir) if os.path.isfile(os.path.join(val_image_dir, d))]
        val_annotations_file = os.path.join(self.val_dir, "val_annotations.txt")
        self.val_img_to_class = {}
        set_of_classes = set()
        with open(val_annotations_file, 'r') as fo:
            entry = fo.readlines()
            for data in entry:
                words = data.split("\t")
                if words[1] in self.class_list:
                    self.val_img_to_class[words[0]] = words[1]
                    set_of_classes.add(words[1])

        self.len_dataset = len(list(self.val_img_to_class.keys()))
        classes = sorted(list(set_of_classes))
        self.class_to_tgt_idx = {classes[i]: i for i in range(len(classes))}
        self.tgt_idx_to_class = {i: classes[i] for i in range(len(classes))}

    def _make_dataset(self, Train=True):
        self.images = []
        if Train:
            img_root_dir = self.train_dir
            list_of_dirs = [target for target in self.class_to_tgt_idx.keys()]
        else:
            img_root_dir = self.val_dir
            list_of_dirs = ["images"]
        temp = [0] * len(self.class_list)
        for tgt in list_of_dirs:
            dirs = os.path.join(img_root_dir, tgt)
            if not os.path.isdir(dirs):
                continue

            for root, _, files in sorted(os.walk(dirs)):
                for fname in sorted(files):
                    if fname.endswith(".JPEG") and fname.split("_")[0] in self.class_list:
                        path = os.path.join(root, fname)
                        if Train:
                            item = (path, self.class_to_tgt_idx[tgt])
                        else:
                            item = (path, self.class_to_tgt_idx[self.val_img_to_class[fname]])
                        for i in range(len(self.class_list)):
                            if fname.split("_")[0] == self.class_list[i]:
                                temp[i] += 1
                                if temp[i] <= 500:
                                    self.images.append(item)
        print('len', len(self.images))

    def return_label(self, idx):
        return [self.class_to_label[self.tgt_idx_to_class[i.item()]] for i in idx]

    def __len__(self):
        return int(self.len_dataset * self.ratio)

    def __getitem__(self, idx: int):
        img_path, tgt = self.images[idx]
        tgt += 1000
        with open(img_path, 'rb') as f:
            sample = Image.open(f)
            sample = sample.convert('RGB')
        if self.transform is not None:
            sample = self.transform(sample)
        index = idx
        if self.list:
            if type(sample) == list:
                sample.append(index)
            else:
                sample = [sample, index]

        return sample, tgt


def prepare_transforms(dataset):

    if dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif dataset == 'cifar10+100' or dataset == 'cifar10OOD':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif dataset == 'cifar100' or dataset == 'cifar100OOD':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    else:
        raise NotImplementedError

    normalize = transforms.Normalize(mean=mean, std=std)

    te_transforms = transforms.Compose([transforms.ToTensor(), normalize])

    tr_transforms = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize])

    simclr_transforms = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        normalize
    ])

    return tr_transforms, te_transforms, simclr_transforms


class TwoCropTransform:
    """Create two crops of the same image"""
    def __init__(self, transform, te_transform):
        self.transform = transform
        self.te_transform = te_transform

    def __call__(self, x):
        return [self.transform(x), self.transform(x), self.te_transform(x)]

# -------------------------

common_corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
                      'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
                      'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def prepare_test_data(args, ttt=False, num_sample=None, align=False):

    tr_transforms, te_transforms, simclr_transforms = prepare_transforms(args.dataset)

    if args.dataset == 'cifar10OOD':

        tesize = 10000
        if args.corruption in common_corruptions:

            print('Test on %s level %d' % (args.corruption, args.level))
            teset_raw_100 = np.load(args.dataroot + '/CIFAR-100-C/%s.npy' % (args.corruption))
            teset_raw_100 = teset_raw_100[(args.level-1)*tesize: args.level*tesize]
            teset_raw_10 = np.load(args.dataroot + '/CIFAR-10-C/%s.npy' % (args.corruption))
            teset_raw_10 = teset_raw_10[(args.level-1)*tesize: args.level*tesize]
            teset_10 = CIFAR10(root=args.dataroot,
                               train=False, download=True, transform=te_transforms)
            teset_10.data = teset_raw_10

            if args.strong_OOD == 'MNIST':
                te_rize = transforms.Compose([transforms.Resize(size=(32, 32)), transforms.Grayscale(3), te_transforms])
                noise = MNIST_openset(root=args.dataroot,
                                      train=False, download=True, transform=te_rize, ratio=args.strong_ratio)

                teset = torch.utils.data.ConcatDataset([teset_10, noise])

            elif args.strong_OOD == 'noise':
                noise = noise_dataset(te_transforms, args.strong_ratio)

                teset = torch.utils.data.ConcatDataset([teset_10, noise])

            elif args.strong_OOD == 'cifar100':
                teset_raw_100 = np.load(args.dataroot + '/CIFAR-100-C/snow.npy')
                teset_raw_100 = teset_raw_100[(args.level-1)*tesize: args.level*tesize]
                teset_100 = CIFAR100_openset(root=args.dataroot,
                                             train=False, download=True, transform=te_transforms, ratio=args.strong_ratio)
                teset_100.data = teset_raw_100[:int(10000*args.strong_ratio)]
                teset = torch.utils.data.ConcatDataset([teset_10, teset_100])

            elif args.strong_OOD == 'SVHN':
                te_rize = transforms.Compose([te_transforms])
                noise = SVHN_openset(root=args.dataroot,
                                     split='test', download=True, transform=te_rize, ratio=args.strong_ratio)

                teset = torch.utils.data.ConcatDataset([teset_10, noise])

            elif args.strong_OOD == 'Tiny':

                transform_test = transforms.Compose([transforms.Resize(32), te_transforms])
                testset_tiny = TinyImageNet_OOD_nonoverlap(args.dataroot + '/tiny-imagenet-200', transform=transform_test, train=True)
                teset = torch.utils.data.ConcatDataset([teset_10, testset_tiny])
                print(len(teset_10), len(testset_tiny), len(teset))

            else:
                raise NotImplementedError('Unknown strong OOD dataset: %s' % args.strong_OOD)

    elif args.dataset == 'cifar100OOD':

        tesize = 10000

        if args.corruption in common_corruptions:
            print('Test on %s level %d' % (args.corruption, args.level))
            teset_raw_100 = np.load(args.dataroot + '/CIFAR-100-C/%s.npy' % (args.corruption))
            teset_raw_100 = teset_raw_100[(args.level-1)*tesize: args.level*tesize]
            teset_raw_10 = np.load(args.dataroot + '/CIFAR-10-C/%s.npy' % (args.corruption))
            teset_raw_10 = teset_raw_10[(args.level-1)*tesize: args.level*tesize]
            teset_100 = CIFAR100(root=args.dataroot,
                                 train=False, download=True, transform=te_transforms)
            teset_100.data = teset_raw_100

            if args.strong_OOD == 'MNIST':
                te_rize = transforms.Compose([transforms.Resize(size=(32, 32)), transforms.Grayscale(3), te_transforms])
                noise = MNIST_openset(root=args.dataroot,
                                      train=False, download=True, transform=te_rize, ratio=args.strong_ratio)

                teset = torch.utils.data.ConcatDataset([teset_100, noise])

            elif args.strong_OOD == 'noise':
                noise = noise_dataset(te_transforms, args.strong_ratio)

                teset = torch.utils.data.ConcatDataset([teset_100, noise])

            elif args.strong_OOD == 'cifar10':
                teset_raw_10 = np.load(args.dataroot + '/CIFAR-10-C/snow.npy')
                teset_raw_10 = teset_raw_10[(args.level-1)*tesize: args.level*tesize]
                teset_10 = CIFAR10_openset(root=args.dataroot,
                                           train=False, download=True, transform=te_transforms, ratio=args.strong_ratio)
                teset_10.data = teset_raw_10[:int(10000*args.strong_ratio)]
                teset = torch.utils.data.ConcatDataset([teset_100, teset_10])

            elif args.strong_OOD == 'SVHN':
                te_rize = transforms.Compose([te_transforms])
                noise = SVHN_openset(root=args.dataroot,
                                     split='test', download=True, transform=te_rize, ratio=args.strong_ratio)

                teset = torch.utils.data.ConcatDataset([teset_100, noise])

            elif args.strong_OOD == 'Tiny':

                transform_test = transforms.Compose([transforms.Resize(32), te_transforms])
                testset_tiny = TinyImageNet_OOD_nonoverlap(args.dataroot + '/tiny-imagenet-200', transform=transform_test, train=True)
                teset = torch.utils.data.ConcatDataset([teset_100, testset_tiny])

            else:
                raise NotImplementedError('Unknown strong OOD dataset: %s' % args.strong_OOD)

    if not hasattr(args, 'workers') or args.workers < 2:
        pin_memory = False
    else:
        pin_memory = True

    if ttt:
        shuffle = True
        drop_last = True
    else:
        shuffle = True
        drop_last = False

    try:
        teloader = torch.utils.data.DataLoader(teset, batch_size=args.batch_size,
                                               shuffle=shuffle, num_workers=args.workers,
                                               worker_init_fn=seed_worker, pin_memory=pin_memory, drop_last=drop_last)
    except Exception:
        teloader = None

    return teset, teloader


def prepare_train_data(args, num_sample=None):
    print('Preparing data...')

    tr_transforms, te_transforms, simclr_transforms = prepare_transforms(args.dataset)

    if args.dataset == 'cifar10' or args.dataset == 'cifar10+100' or args.dataset == 'cifar10OOD':

        if hasattr(args, 'ssl') and args.ssl == 'contrastive':
            trset = CIFAR10(root=args.dataroot,
                            train=False, download=True,
                            transform=TwoCropTransform(simclr_transforms, te_transforms))
            if hasattr(args, 'corruption') and args.corruption in common_corruptions:
                print('Contrastive on %s level %d' % (args.corruption, args.level))
                tesize = 10000
                trset_raw = np.load(args.dataroot + '/CIFAR-10-C/%s.npy' % (args.corruption))
                trset_raw = trset_raw[(args.level-1)*tesize: args.level*tesize]
                trset.data = trset_raw
            else:
                print('Contrastive on cifar10 training set')
        else:
            trset = torchvision.datasets.CIFAR10(root=args.dataroot,
                                                 train=True, download=True, transform=tr_transforms)
            print('Cifar10 training set')

    elif args.dataset == 'cifar100' or args.dataset == 'cifar100OOD':
        if hasattr(args, 'ssl') and args.ssl == 'contrastive':
            trset = torchvision.datasets.CIFAR100(root=args.dataroot,
                                                  train=True, download=True,
                                                  transform=TwoCropTransform(simclr_transforms, te_transforms))
            if hasattr(args, 'corruption') and args.corruption in common_corruptions:
                print('Contrastive on %s level %d' % (args.corruption, args.level))
                tesize = 10000
                trset_raw = np.load(args.dataroot + '/CIFAR-100-C/%s.npy' % (args.corruption))
                trset_raw = trset_raw[(args.level-1)*tesize: args.level*tesize]
                trset.data = trset_raw
            else:
                print('Contrastive on cifar100 training set')
        else:
            trset = torchvision.datasets.CIFAR100(root=args.dataroot,
                                                  train=True, download=True, transform=tr_transforms)
            print('Cifar100 training set')
    else:
        raise Exception('Dataset not found!')

    if not hasattr(args, 'workers') or args.workers < 2:
        pin_memory = False
    else:
        pin_memory = True

    if num_sample and num_sample < trset.data.shape[0]:
        trset.data = trset.data[:num_sample]
        print("Truncate the training set to {:d} samples".format(num_sample))

    trloader = torch.utils.data.DataLoader(trset, batch_size=args.batch_size,
                                           shuffle=True, num_workers=args.workers,
                                           worker_init_fn=seed_worker, pin_memory=pin_memory, drop_last=False)
    return trset, trloader
--------------------------------------------------------------------------------
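For reference, a minimal sketch of how prepare_test_data above can be driven outside of TEST.py/OURS.py. The functions, class names, and argument fields come from cifar/utils/prepare_dataset.py; the concrete values (corruption, level, batch size, workers) are illustrative assumptions, not defaults taken from this repository.

# Usage sketch (not part of the repository). Assumes the working directory is
# ./cifar and that PYTHONPATH includes it, as the provided scripts arrange.
from argparse import Namespace
from utils.prepare_dataset import prepare_test_data

args = Namespace(
    dataset='cifar10OOD',      # open-world CIFAR-10 test stream
    dataroot='./data',
    corruption='snow',         # any entry of common_corruptions (assumed choice)
    level=5,                   # corruption severity, 1-5 (assumed value)
    strong_OOD='MNIST',        # one of: MNIST, noise, cifar100, SVHN, Tiny
    strong_ratio=1,
    batch_size=256,            # assumed value
    workers=4,                 # assumed value
)

# Returns the concatenated weak-OOD (corrupted CIFAR-10) + strong-OOD dataset
# and a shuffled DataLoader over it.
teset, teloader = prepare_test_data(args)
print(len(teset))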