├── cifar
│   ├── requirements.txt
│   ├── scripts
│   │   ├── test_cifar10.sh
│   │   ├── test_cifar100.sh
│   │   ├── ours_cifar10.sh
│   │   └── ours_cifar100.sh
│   ├── models
│   │   ├── SSHead.py
│   │   ├── ResNet.py
│   │   ├── WideResNet.py
│   │   ├── wide.py
│   │   ├── BigResNet.py
│   │   └── dm.py
│   ├── utils
│   │   ├── misc.py
│   │   ├── offline.py
│   │   ├── contrastive.py
│   │   ├── aug.py
│   │   ├── augmentation.py
│   │   ├── test_helpers.py
│   │   └── prepare_dataset.py
│   ├── README.md
│   ├── TEST.py
│   └── OURS.py
├── imagenet
│   ├── requirements.txt
│   ├── scripts
│   │   ├── test_r.sh
│   │   ├── test_c.sh
│   │   ├── ours_r.sh
│   │   └── ours_c.sh
│   ├── model
│   │   └── resnet.py
│   ├── utils
│   │   ├── create_corruption_dataset.py
│   │   ├── prepare_dataset.py
│   │   ├── test_helpers.py
│   │   └── offline.py
│   ├── README.md
│   ├── TEST.py
│   └── OURS.py
├── imgs
│   └── overview.png
├── LICENSE
└── README.md
/cifar/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
--------------------------------------------------------------------------------
/imagenet/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | imagenet_c
--------------------------------------------------------------------------------
/imgs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Yushu-Li/OWTTT/HEAD/imgs/overview.png
--------------------------------------------------------------------------------
/imagenet/scripts/test_r.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | export PYTHONPATH=$PYTHONPATH:$(pwd)
4 | STRONG_OOD=$1
5 |
6 |
7 | python ./TEST.py \
8 | --dataset ImageNet-R \
9 | --dataroot ./data \
10 | --strong_OOD ${STRONG_OOD}
11 |
12 |
--------------------------------------------------------------------------------
/imagenet/scripts/test_c.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | export PYTHONPATH=$PYTHONPATH:$(pwd)
4 | CORRUPT=$1
5 | STRONG_OOD=$2
6 |
7 |
8 | python ./TEST.py \
9 | --dataset ImageNet-C \
10 | --dataroot ./data \
11 | --strong_OOD ${STRONG_OOD} \
12 | --corruption ${CORRUPT}
13 |
14 |
15 |
--------------------------------------------------------------------------------
/imagenet/scripts/ours_r.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | export PYTHONPATH=$PYTHONPATH:$(pwd)
4 |
5 | STRONG_OOD=$1
6 |
7 |
8 | python ./OURS.py \
9 | --dataset ImageNet-R \
10 | --dataroot ./data \
11 | --strong_OOD ${STRONG_OOD} \
12 | --lr 0.001 \
13 | --delta 0.1 \
14 | --ce_scale 0.05 \
15 | --da_scale 0.1
16 |
17 |
--------------------------------------------------------------------------------
/cifar/scripts/test_cifar10.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | export PYTHONPATH=$PYTHONPATH:$(pwd)
4 |
5 | CORRUPT=$1
6 | STRONG_OOD=$2
7 |
8 |
9 | python TEST.py \
10 | --dataset cifar10OOD \
11 | --dataroot ./data \
12 | --strong_OOD ${STRONG_OOD} \
13 | --resume ./results/cifar10_joint_resnet50 \
14 | --corruption ${CORRUPT}
15 |
16 |
--------------------------------------------------------------------------------
/cifar/scripts/test_cifar100.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | export PYTHONPATH=$PYTHONPATH:$(pwd)
4 |
5 |
6 | CORRUPT=$1
7 | STRONG_OOD=$2
8 |
9 |
10 |
11 | python TEST.py \
12 | --dataset cifar100OOD \
13 | --dataroot ./data \
14 | --strong_OOD ${STRONG_OOD} \
15 | --resume ./results/cifar100_joint_resnet50 \
16 | --corruption ${CORRUPT}
17 |
18 |
--------------------------------------------------------------------------------
/imagenet/scripts/ours_c.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | export PYTHONPATH=$PYTHONPATH:$(pwd)
4 |
5 | CORRUPT=$1
6 | STRONG_OOD=$2
7 |
8 |
9 | python ./OURS.py \
10 | --dataset ImageNet-C \
11 | --dataroot ./data \
12 | --strong_OOD ${STRONG_OOD} \
13 | --corruption ${CORRUPT} \
14 | --lr 0.001 \
15 | --delta 0.1 \
16 | --ce_scale 0.05 \
17 | --da_scale 0.1
--------------------------------------------------------------------------------
/cifar/scripts/ours_cifar10.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | export PYTHONPATH=$PYTHONPATH:$(pwd)
4 |
5 | CORRUPT=$1
6 | STRONG_OOD=$2
7 |
8 | python OURS.py \
9 | --dataset cifar10OOD \
10 | --dataroot ./data \
11 | --strong_OOD ${STRONG_OOD} \
12 | --resume ./results/cifar10_joint_resnet50 \
13 | --corruption ${CORRUPT} \
14 | --lr 0.01 \
15 | --delta 0.1 \
16 | --da_scale 1 \
17 | --ce_scale 0.2
18 |
19 |
--------------------------------------------------------------------------------
/cifar/scripts/ours_cifar100.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | export PYTHONPATH=$PYTHONPATH:$(pwd)
4 |
5 | CORRUPT=$1
6 | STRONG_OOD=$2
7 |
8 | python OURS.py \
9 | --dataset cifar100OOD \
10 | --dataroot ./data \
11 | --strong_OOD ${STRONG_OOD} \
12 | --resume ./results/cifar100_joint_resnet50 \
13 | --corruption ${CORRUPT} \
14 | --lr 0.001 \
15 | --delta 0.1 \
16 | --da_scale 1 \
17 | --ce_scale 0.2
18 |
19 |
--------------------------------------------------------------------------------
/imagenet/model/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torchvision.models as models
3 |
4 |
5 | class SupCEResNet(nn.Module):
6 | """official Resnet for image classification, e.g., ImageNet"""
7 | def __init__(self, name='resnet50'):
8 | super(SupCEResNet, self).__init__()
9 | self.encoder = models.__dict__[name](pretrained=True)
10 | self.fc = self.encoder.fc
11 | self.encoder.fc = nn.Identity()
12 |
13 | def forward(self, x):
14 | return self.fc(self.encoder(x))
15 |
--------------------------------------------------------------------------------
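A minimal usage sketch for `SupCEResNet` above: the torchvision backbone is split into `encoder` (its `fc` replaced by `nn.Identity`) and a separate `fc` head, so penultimate features and logits can be taken independently. This assumes it is run from the `imagenet` folder; instantiating the model downloads the torchvision weights on first use.

```python
import torch
from model.resnet import SupCEResNet

model = SupCEResNet('resnet50').eval()
x = torch.randn(2, 3, 224, 224)              # dummy batch
with torch.no_grad():
    feat = model.encoder(x)                  # 2048-d penultimate features
    logits = model.fc(feat)                  # 1000-way ImageNet logits
assert feat.shape == (2, 2048) and logits.shape == (2, 1000)
```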
/cifar/models/SSHead.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import math
3 | import copy
4 |
5 | class ViewFlatten(nn.Module):
6 | def __init__(self):
7 | super(ViewFlatten, self).__init__()
8 |
9 | def forward(self, x):
10 | return x.view(x.size(0), -1)
11 |
12 | class ExtractorHead(nn.Module):
13 | def __init__(self, ext, head):
14 | super(ExtractorHead, self).__init__()
15 | self.ext = ext
16 | self.head = head
17 |
18 | def forward(self, x):
19 | return self.head(self.ext(x))
20 |
21 | def extractor_from_layer3(net):
22 | layers = [net.conv1, net.layer1, net.layer2, net.layer3, net.bn, net.relu, net.avgpool, ViewFlatten()]
23 | return nn.Sequential(*layers)
24 |
25 | def extractor_from_layer2(net):
26 | layers = [net.conv1, net.layer1, net.layer2]
27 | return nn.Sequential(*layers)
28 |
29 | def head_on_layer2(net, width, classes):
30 | head = copy.deepcopy([net.layer3, net.bn, net.relu, net.avgpool])
31 | head.append(ViewFlatten())
32 | head.append(nn.Linear(64 * width, classes))
33 | return nn.Sequential(*head)
34 |
35 | def task_head_on_layer3(net):
36 | layers = [net.fc]
37 | return nn.Sequential(*layers)
38 |
--------------------------------------------------------------------------------
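A minimal sketch of how these helpers compose with the CIFAR ResNet from `models/ResNet.py`: `extractor_from_layer2` gives the shared trunk, and `head_on_layer2` deep-copies the remaining stage into a separate branch. The 4-way head below is only an illustrative choice (e.g., a rotation-prediction task), not a setting from this repo.

```python
import torch
from models.ResNet import ResNetCifar
from models.SSHead import ExtractorHead, extractor_from_layer2, head_on_layer2

net = ResNetCifar(depth=26, width=1, classes=10)
ext = extractor_from_layer2(net)                 # conv1 + layer1 + layer2
head = head_on_layer2(net, width=1, classes=4)   # copied layer3/bn/pool + linear
ssh = ExtractorHead(ext, head)

x = torch.randn(8, 3, 32, 32)
assert ssh(x).shape == (8, 4)
```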
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Yushu-Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OWTTT
2 |
3 | This repository is the official implementation of our ICCV 2023 (Oral) paper.
4 |
5 | ## On the Robustness of Open-World Test-Time Training: Self-Training with Dynamic Prototype Expansion
6 |
7 | **[Yushu Li](https://yushu-li.github.io/)<sup>1</sup>** **[Xun Xu](https://alex-xun-xu.github.io/)<sup>2</sup>** **[Yongyi Su](https://yysu.site/)<sup>1</sup>** **[Kui Jia](http://kuijia.site/)<sup>1</sup>**
8 |
9 | <sup>1</sup>South China University of Technology
10 | <sup>2</sup>Institute for Infocomm Research (I2R), Agency for Science, Technology and Research (A*STAR)
11 |
12 |
13 | [arXiv](https://arxiv.org/abs/2308.09942)
14 | [Project Page](https://yushu-li.github.io/owttt-site/)
15 |
16 |
17 | ### Overview
18 |
19 | 
20 |
21 |
22 | ### CIFAR10-C/CIFAR100-C
23 |
24 | The code is released in the [cifar](cifar) folder.
25 |
26 | ### ImageNet-C/ImageNet-R
27 |
28 | The code is released in the [imagenet](imagenet) folder.
29 |
30 | ### Citation
31 |
32 | If you find our work useful in your research, please consider citing:
33 |
34 | ```bibtex
35 | @inproceedings{
36 | li2023robustness,
37 | title={On the Robustness of Open-World Test-Time Training: Self-Training with Dynamic Prototype Expansion},
38 | author={Li, Yushu and Xu, Xun and Su, Yongyi and Jia, Kui},
39 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
40 | month={October},
41 | year={2023}
42 | }
43 | ```
--------------------------------------------------------------------------------
/imagenet/utils/create_corruption_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np  # np.uint8 is used below; safer than relying on the wildcard import to re-export numpy
2 | from imagenet_c import *
3 | from torchvision.datasets import ImageNet
4 | import torchvision.transforms as transforms
5 | from torch.utils.data import DataLoader
6 | import os
7 | import torch
8 | import gorilla
9 | DATA_ROOT = '/cluster/sc_download/li.yushu/imagenet_ttac'
10 | CORRUPTION_PATH = './corruption'
11 |
12 |
13 | corruption_tuple = (gaussian_noise, shot_noise, impulse_noise, defocus_blur,
14 | glass_blur, motion_blur, zoom_blur, snow, frost, fog,
15 | brightness, contrast, elastic_transform, pixelate, jpeg_compression)
16 |
17 | corruption_dict = {corr_func.__name__: corr_func for corr_func in corruption_tuple}
18 |
19 | class corrupt(object):
20 | def __init__(self, corruption_name, severity=5):
21 | self.corruption_name = corruption_name
22 | self.severity = severity
23 | return
24 |
25 | def __call__(self, x):
26 | # x: PIL.Image
27 | x_corrupted = corruption_dict[self.corruption_name](x, self.severity)
28 | return np.uint8(x_corrupted)
29 |
30 | def __repr__(self):
31 | return "Corruption(name=" + self.corruption_name + ", severity=" + str(self.severity) + ")"
32 |
33 |
34 | print(os.path.join(DATA_ROOT, CORRUPTION_PATH))
35 | if os.path.exists(os.path.join(DATA_ROOT, CORRUPTION_PATH)) is False:
36 | os.mkdir(os.path.join(DATA_ROOT, CORRUPTION_PATH))
37 |
38 |
39 |
40 | for corruption in corruption_dict.keys():
41 | if os.path.exists(os.path.join(DATA_ROOT, CORRUPTION_PATH, corruption + '.pth')):
42 | continue
43 | print(corruption)
44 | val_transform = transforms.Compose([
45 | transforms.Resize(256),
46 | transforms.CenterCrop(224),
47 | corrupt(corruption, 5)
48 | ])
49 |
50 | target_dataset = ImageNet(DATA_ROOT, 'val', transform=val_transform)
51 |
52 | target_dataloader = DataLoader(target_dataset, batch_size=256, shuffle=False, drop_last=False, num_workers=16)
53 |
54 | datas = []
55 | for batch in gorilla.track(target_dataloader):
56 | datas.append(batch[0])
57 | datas = torch.cat(datas)
58 | torch.save(datas, os.path.join(DATA_ROOT, CORRUPTION_PATH, corruption + '.pth'))
59 |
60 |
61 |
--------------------------------------------------------------------------------
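Because no `ToTensor()` is applied in the pipeline above, each saved `.pth` stacks the corrupted validation images as uint8 HWC tensors in the clean `val` order (`shuffle=False`). A hedged loading sketch; `./data` and `snow.pth` are placeholder choices that mirror paths used elsewhere in this repo:

```python
import os
import torch

DATA_ROOT = './data'  # assumption: adjust to your data root
imgs = torch.load(os.path.join(DATA_ROOT, 'corruption', 'snow.pth'))
print(imgs.shape, imgs.dtype)  # e.g. torch.Size([50000, 224, 224, 3]) torch.uint8
# Labels follow the clean validation order, e.g.:
# from torchvision.datasets import ImageNet
# labels = ImageNet(DATA_ROOT, 'val').targets
```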
/cifar/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | from colorama import Fore
5 |
6 | def get_grad(params):
7 | if isinstance(params, torch.Tensor):
8 | params = [params]
9 | params = list(filter(lambda p: p.grad is not None, params))
10 | grad = [p.grad.data.cpu().view(-1) for p in params]
11 | return torch.cat(grad)
12 |
13 | def write_to_txt(name, content):
14 | with open(name, 'w') as text_file:
15 | text_file.write(content)
16 |
17 | def my_makedir(name):
18 | try:
19 | os.makedirs(name)
20 | except OSError:
21 | pass
22 |
23 | def print_args(opt):
24 | for arg in vars(opt):
25 | print('%s %s' % (arg, getattr(opt, arg)))
26 |
27 | def mean(ls):
28 | return sum(ls) / len(ls)
29 |
30 | def normalize(v):
31 | return (v - v.mean()) / v.std()
32 |
33 | def flat_grad(grad_tuple):
34 | return torch.cat([p.view(-1) for p in grad_tuple])
35 |
36 | def print_nparams(model):
37 | nparams = sum([param.nelement() for param in model.parameters()])
38 | print('number of parameters: %d' % (nparams))
39 |
40 | def print_color(color, string):
41 | print(getattr(Fore, color) + string + Fore.RESET)
42 |
43 | def freeze_params(model):
44 | for name, p in model.named_parameters():
45 | p.requires_grad = False
46 | print("Freeze parameter until", name)
47 |
48 | def print_params(model):
49 | for name, p in model.named_parameters():
50 | print(name)
51 |
52 | class AverageMeter(object):
53 | """Computes and stores the average and current value"""
54 | def __init__(self, name, fmt=':f'):
55 | self.name = name
56 | self.fmt = fmt
57 | self.reset()
58 |
59 | def reset(self):
60 | self.val = 0
61 | self.avg = 0
62 | self.sum = 0
63 | self.count = 0
64 |
65 | def update(self, val, n=1):
66 | self.val = val
67 | self.sum += val * n
68 | self.count += n
69 | self.avg = self.sum / self.count
70 |
71 | def __str__(self):
72 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
73 | return fmtstr.format(**self.__dict__)
74 |
75 | def adjust_learning_rate(args, optimizer, epoch):
76 | lr = args.lr
77 |
78 | eta_min = lr * (args.lr_decay_rate ** 3)
79 | lr = eta_min + (lr - eta_min) * (
80 | 1 + math.cos(math.pi * epoch / args.nepoch)) / 2
81 |
82 | for param_group in optimizer.param_groups:
83 | param_group['lr'] = lr
84 |
--------------------------------------------------------------------------------
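A quick sketch of `AverageMeter` above, which tracks the latest value and the running (sample-weighted) average:

```python
from utils.misc import AverageMeter

top1 = AverageMeter('Acc@1', ':.2f')
top1.update(0.75, n=128)   # batch accuracy 75% over 128 samples
top1.update(0.81, n=128)
print(top1)                # -> Acc@1 0.81 (0.78)
```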
/cifar/utils/offline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import statistics
3 | import os
4 |
5 | def covariance(features):
6 | assert len(features.size()) == 2, "TODO: multi-dimensional feature map covariance"
7 | n = features.shape[0]
8 | tmp = torch.ones((1, n), device=features.device) @ features
9 | cov = (features.t() @ features - (tmp.t() @ tmp) / n) / n
10 | return cov
11 |
12 | def coral(cs, ct):
13 | d = cs.shape[0]
14 | loss = (cs - ct).pow(2).sum() / (4. * d ** 2)
15 | return loss
16 |
17 |
18 | def linear_mmd(ms, mt):
19 | loss = (ms - mt).pow(2).mean()
20 | return loss
21 |
22 | def offline(args,trloader, ext, classifier, head, class_num=10):
23 | if class_num == 10:
24 | if os.path.exists(args.resume+'/offline_cifar10.pth'):
25 | data = torch.load(args.resume+'/offline_cifar10.pth')
26 | return data
27 | elif class_num == 100:
28 | if os.path.exists(args.resume+'/offline_cifar100.pth'):
29 | data = torch.load(args.resume+'/offline_cifar100.pth')
30 | return data
31 | else:
32 | raise Exception("This function only handles CIFAR10 and CIFAR100 datasets.")
33 | ext.eval()
34 |
35 | feat_stack = [[] for i in range(class_num)]
36 | ssh_feat_stack = [[] for i in range(class_num)]
37 |
38 | with torch.no_grad():
39 | for batch_idx, (inputs, labels) in enumerate(trloader):
40 |
41 | feat = ext(inputs.cuda())
42 | predict_logit = classifier(feat)
43 | ssh_feat = predict_logit
44 |
45 | pseudo_label = predict_logit.max(dim=1)[1]
46 |
47 | for label in pseudo_label.unique():
48 | label_mask = pseudo_label == label
49 | feat_stack[label].extend(feat[label_mask, :])
50 | ssh_feat_stack[label].extend(ssh_feat[label_mask, :])
51 | ext_mu = []
52 | ext_cov = []
53 | ext_all = []
54 |
55 | ssh_mu = []
56 | ssh_cov = []
57 | ssh_all = []
58 | for feat in feat_stack:
59 | ext_mu.append(torch.stack(feat).mean(dim=0))
60 | ext_cov.append(covariance(torch.stack(feat)))
61 | ext_all.extend(feat)
62 |
63 | for feat in ssh_feat_stack:
64 | ssh_mu.append(torch.stack(feat).mean(dim=0))
65 | ssh_cov.append(covariance(torch.stack(feat)))
66 | ssh_all.extend(feat)
67 |
68 | ext_all = torch.stack(ext_all)
69 | ext_all_mu = ext_all.mean(dim=0)
70 | ext_all_cov = covariance(ext_all)
71 |
72 | ssh_all = torch.stack(ssh_all)
73 | ssh_all_mu = ssh_all.mean(dim=0)
74 | ssh_all_cov = covariance(ssh_all)
75 | if class_num == 10:
76 | torch.save((ext_mu, ext_cov, ssh_mu, ssh_cov, ext_all_mu, ext_all_cov, ssh_all_mu, ssh_all_cov), args.resume+'/offline_cifar10.pth')
77 | if class_num == 100:
78 | torch.save((ext_mu, ext_cov, ssh_mu, ssh_cov, ext_all_mu, ext_all_cov, ssh_all_mu, ssh_all_cov), args.resume+'/offline_cifar100.pth')
79 | return ext_mu, ext_cov, ssh_mu, ssh_cov, ext_all_mu, ext_all_cov, ssh_all_mu, ssh_all_cov
80 |
81 |
82 |
--------------------------------------------------------------------------------
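`covariance` above computes the biased empirical covariance, (XᵀX − n·μμᵀ)/n. A small sanity-check sketch against `torch.cov` with `correction=0` (which expects variables in rows, hence the transpose):

```python
import torch
from utils.offline import covariance

X = torch.randn(1000, 8)   # 1000 samples, 8 feature dimensions
assert torch.allclose(covariance(X), torch.cov(X.t(), correction=0), atol=1e-4)
```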
/cifar/models/ResNet.py:
--------------------------------------------------------------------------------
1 | # Based on the ResNet implementation in torchvision
2 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
3 |
4 | import math
5 | import torch
6 | from torch import nn
7 | from torchvision.models.resnet import conv3x3
8 |
9 | class BasicBlock(nn.Module):
10 | def __init__(self, inplanes, planes, norm_layer, stride=1, downsample=None):
11 | super(BasicBlock, self).__init__()
12 | self.downsample = downsample
13 | self.stride = stride
14 |
15 | self.bn1 = norm_layer(inplanes)
16 | self.relu1 = nn.ReLU(inplace=True)
17 | self.conv1 = conv3x3(inplanes, planes, stride)
18 |
19 | self.bn2 = norm_layer(planes)
20 | self.relu2 = nn.ReLU(inplace=True)
21 | self.conv2 = conv3x3(planes, planes)
22 |
23 | def forward(self, x):
24 | residual = x
25 | residual = self.bn1(residual)
26 | residual = self.relu1(residual)
27 | residual = self.conv1(residual)
28 |
29 | residual = self.bn2(residual)
30 | residual = self.relu2(residual)
31 | residual = self.conv2(residual)
32 |
33 | if self.downsample is not None:
34 | x = self.downsample(x)
35 | return x + residual
36 |
37 | class Downsample(nn.Module):
38 | def __init__(self, nIn, nOut, stride):
39 | super(Downsample, self).__init__()
40 | self.avg = nn.AvgPool2d(stride)
41 | assert nOut % nIn == 0
42 | self.expand_ratio = nOut // nIn
43 |
44 | def forward(self, x):
45 | x = self.avg(x)
46 | return torch.cat([x] + [x.mul(0)] * (self.expand_ratio - 1), 1)
47 |
48 | class ResNetCifar(nn.Module):
49 | def __init__(self, depth, width=1, classes=10, channels=3, norm_layer=nn.BatchNorm2d, detach=None):
50 | assert (depth - 2) % 6 == 0 # depth is 6N+2
51 | self.N = (depth - 2) // 6
52 | super(ResNetCifar, self).__init__()
53 |
54 | # Following the Wide ResNet convention, we fix the very first convolution
55 | self.conv1 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
56 | self.inplanes = 16
57 | self.layer1 = self._make_layer(norm_layer, 16 * width)
58 | self.layer2 = self._make_layer(norm_layer, 32 * width, stride=2)
59 | self.layer3 = self._make_layer(norm_layer, 64 * width, stride=2)
60 | self.bn = norm_layer(64 * width)
61 | self.relu = nn.ReLU(inplace=True)
62 | self.avgpool = nn.AvgPool2d(8)
63 | self.fc = nn.Linear(64 * width, classes)
64 |
65 | # Task-agnostic encoder
66 | self.detach = detach
67 |
68 | # Initialization
69 | for m in self.modules():
70 | if isinstance(m, nn.Conv2d):
71 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
72 | m.weight.data.normal_(0, math.sqrt(2. / n))
73 |
74 | def _make_layer(self, norm_layer, planes, stride=1):
75 | downsample = None
76 | if stride != 1 or self.inplanes != planes:
77 | downsample = Downsample(self.inplanes, planes, stride)
78 | layers = [BasicBlock(self.inplanes, planes, norm_layer, stride, downsample)]
79 | self.inplanes = planes
80 | for i in range(self.N - 1):
81 | layers.append(BasicBlock(self.inplanes, planes, norm_layer))
82 | return nn.Sequential(*layers)
83 |
84 | def forward(self, x):
85 | x = self.conv1(x)
86 | x = self.layer1(x)
87 | x = self.layer2(x)
88 | if self.detach == 'layer2': x = x.detach()
89 | x = self.layer3(x)
90 | x = self.bn(x)
91 | x = self.relu(x)
92 | x = self.avgpool(x)
93 | x = x.view(x.size(0), -1)
94 | if self.detach == 'layer3': x = x.detach()
95 | x = self.fc(x)
96 | return x
97 |
--------------------------------------------------------------------------------
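A shape sketch for the pre-activation CIFAR ResNet above; the depth must satisfy depth = 6N + 2 (here N = 4 blocks per stage):

```python
import torch
from models.ResNet import ResNetCifar

net = ResNetCifar(depth=26, width=1, classes=10)
x = torch.randn(4, 3, 32, 32)
assert net(x).shape == (4, 10)
```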
/cifar/utils/contrastive.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SupConLoss(nn.Module):
6 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
7 | It also supports the unsupervised contrastive loss in SimCLR"""
8 | def __init__(self, temperature=0.07, contrast_mode='all',
9 | base_temperature=0.07):
10 | super(SupConLoss, self).__init__()
11 | self.temperature = temperature
12 | self.contrast_mode = contrast_mode
13 | self.base_temperature = base_temperature
14 |
15 | def forward(self, features, labels=None, mask=None):
16 | """Compute loss for model. If both `labels` and `mask` are None,
17 | it degenerates to SimCLR unsupervised loss:
18 | https://arxiv.org/pdf/2002.05709.pdf
19 |
20 | Args:
21 | features: hidden vector of shape [bsz, n_views, ...].
22 | labels: ground truth of shape [bsz].
23 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
24 | has the same class as sample i. Can be asymmetric.
25 | Returns:
26 | A loss scalar.
27 | """
28 | device = (torch.device('cuda')
29 | if features.is_cuda
30 | else torch.device('cpu'))
31 |
32 | if len(features.shape) < 3:
33 | raise ValueError('`features` needs to be [bsz, n_views, ...],'
34 | 'at least 3 dimensions are required')
35 | if len(features.shape) > 3:
36 | features = features.view(features.shape[0], features.shape[1], -1)
37 |
38 | batch_size = features.shape[0]
39 | if labels is not None and mask is not None:
40 | raise ValueError('Cannot define both `labels` and `mask`')
41 | elif labels is None and mask is None:
42 | mask = torch.eye(batch_size, dtype=torch.float32).to(device)
43 | elif labels is not None:
44 | labels = labels.contiguous().view(-1, 1)
45 | if labels.shape[0] != batch_size:
46 | raise ValueError('Num of labels does not match num of features')
47 | mask = torch.eq(labels, labels.T).float().to(device)
48 | else:
49 | mask = mask.float().to(device)
50 |
51 | contrast_count = features.shape[1]
52 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
53 | if self.contrast_mode == 'one':
54 | anchor_feature = features[:, 0]
55 | anchor_count = 1
56 | elif self.contrast_mode == 'all':
57 | anchor_feature = contrast_feature
58 | anchor_count = contrast_count
59 | else:
60 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
61 |
62 | # compute logits
63 | anchor_dot_contrast = torch.div(
64 | torch.matmul(anchor_feature, contrast_feature.T),
65 | self.temperature)
66 | # for numerical stability
67 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
68 | logits = anchor_dot_contrast - logits_max.detach()
69 |
70 | # tile mask
71 | mask = mask.repeat(anchor_count, contrast_count)
72 | # mask-out self-contrast cases
73 | logits_mask = torch.scatter(
74 | torch.ones_like(mask),
75 | 1,
76 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
77 | 0
78 | )
79 |
80 | mask = mask * logits_mask
81 |
82 | # compute log_prob
83 | exp_logits = torch.exp(logits) * logits_mask
84 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
85 |
86 | # compute mean of log-likelihood over positive
87 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
88 |
89 | # loss
90 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
91 | loss = loss.view(anchor_count, batch_size).mean()
92 |
93 | return loss
94 |
--------------------------------------------------------------------------------
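A minimal `SupConLoss` sketch: features are expected as L2-normalized projections of shape [bsz, n_views, dim]; with labels it is the supervised contrastive loss, without them it degenerates to SimCLR. The shapes below are illustrative only.

```python
import torch
import torch.nn.functional as F
from utils.contrastive import SupConLoss

criterion = SupConLoss(temperature=0.07)
feats = F.normalize(torch.randn(16, 2, 128), dim=-1)  # 2 views per sample
labels = torch.randint(0, 10, (16,))
loss_supcon = criterion(feats, labels)   # supervised contrastive
loss_simclr = criterion(feats)           # unsupervised (SimCLR)
```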
/cifar/README.md:
--------------------------------------------------------------------------------
1 | # OWTTT on CIFAR10-C/100-C
2 |
3 | This folder provides our method (OURS) and the baseline TEST (direct testing without adaptation) on CIFAR10-C/100-C under common corruptions or natural shifts. Our implementation is based on the [TTAC repo](https://github.com/Gorilla-Lab-SCUT/TTAC/tree/master/cifar) and therefore requires similar preparation steps.
4 |
5 |
6 | ### Requirements
7 |
8 | To install requirements:
9 |
10 | ```
11 | pip install -r requirements.txt
12 | ```
13 |
14 | To download datasets:
15 |
16 | ```
17 | export DATADIR=/data/cifar
18 | mkdir -p ${DATADIR} && cd ${DATADIR}
19 | wget -O CIFAR-10-C.tar https://zenodo.org/record/2535967/files/CIFAR-10-C.tar?download=1
20 | tar -xvf CIFAR-10-C.tar
21 | wget -O CIFAR-100-C.tar https://zenodo.org/record/3555552/files/CIFAR-100-C.tar?download=1
22 | tar -xvf CIFAR-100-C.tar
23 | wget -O tiny-imagenet-200.zip http://cs231n.stanford.edu/tiny-imagenet-200.zip
24 | unzip tiny-imagenet-200.zip
25 | ```
26 |
27 | ### Pre-trained Models
28 |
29 | The checkpoints of the pre-trained ResNet-50 can be downloaded (214 MB) using the following commands:
30 |
31 | ```
32 | mkdir -p results/cifar10_joint_resnet50 && cd results/cifar10_joint_resnet50
33 | gdown https://drive.google.com/uc?id=1QWyI8UrXJ6_H9lBbrq52qXWpjdpq4PUn && cd ../..
34 | mkdir -p results/cifar100_joint_resnet50 && cd results/cifar100_joint_resnet50
35 | gdown https://drive.google.com/uc?id=1cau93HVjl4aWuZlrl7cJIMEKBxXXunR9 && cd ../..
36 | ```
37 |
38 | These models are obtained by training on the clean CIFAR10/100 images using semi-supervised SimCLR.
39 |
40 | ### Open-World Test-Time Training:
41 |
42 | We present our method and the baseline method TEST (direct test without adaptation) on CIFAR10-C/100-C.
43 |
44 | - Run the OURS method or the baseline method TEST on CIFAR10-C under the OWTTT protocol.
45 |
46 | ```
47 | # OURS:
48 | bash scripts/ours_cifar10.sh "corruption_type" "strong_ood_type"
49 |
50 | # TEST:
51 | bash scripts/test_cifar10.sh "corruption_type" "strong_ood_type"
52 | ```
53 | Where "corruption_type" is the corruption type in CIFAR10-C, and "strong_ood_type" is the strong OOD type in [noise, MNIST, SVHN, Tiny, cifar100].
54 |
55 | For example, to run OURS or TEST on CIFAR10-C under the snow corruption with MNIST as strong OOD, we can use the following command:
56 |
57 | ```
58 | # OURS:
59 | bash scripts/ours_cifar10.sh snow MNIST
60 |
61 | # TEST:
62 | bash scripts/test_cifar10.sh snow MNIST
63 | ```
64 |
65 | The above scripts yield the following results (%) under the snow corruption with MNIST as strong OOD:
66 |
67 | | Method | ACC_S | ACC_N | ACC_H |
68 | |:------:|:-------:|:-------:|:-------:|
69 | | TEST | 66.36 | 91.56 | 76.95 |
70 | | OURS | 84.05 | 97.46 | 90.26 |
71 |
72 | - Run the OURS method or the baseline method TEST on CIFAR100-C under the OWTTT protocol.
73 |
74 | ```
75 | # OURS:
76 | bash scripts/ours_cifar100.sh "corruption_type" "strong_ood_type"
77 |
78 | # TEST:
79 | bash scripts/test_cifar100.sh "corruption_type" "strong_ood_type"
80 | ```
81 | Where "corruption_type" is the corruption type in CIFAR100-C, and "strong_ood_type" is the strong OOD type in [noise, MNIST, SVHN, Tiny, cifar10].
82 |
83 | For example, to run OURS or TEST on CIFAR100-C under the snow corruption with MNIST as strong OOD, we can use the following command:
84 |
85 | ```
86 | # OURS:
87 | bash scripts/ours_cifar100.sh snow MNIST
88 |
89 | # TEST:
90 | bash scripts/test_cifar100.sh snow MNIST
91 | ```
92 |
93 | The above scripts yield the following results (%) under the snow corruption with MNIST as strong OOD:
94 |
95 | | Method | ACC_S | ACC_N | ACC_H |
96 | |:------:|:-------:|:-------:|:-------:|
97 | | TEST | 29.2 | 53.27 | 37.72 |
98 | | OURS | 44.78 | 93.56 | 60.57 |
99 |
100 |
101 | ### Acknowledgements
102 |
103 | Our code is built upon the public code of the [TTAC](https://github.com/Gorilla-Lab-SCUT/TTAC/tree/master/cifar).
104 |
--------------------------------------------------------------------------------
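For reference, ACC_H in the README tables above is the harmonic mean of ACC_S (accuracy on seen/weak-OOD samples) and ACC_N (accuracy on strong-OOD samples), matching `h_score` in `TEST.py`. A quick check against the CIFAR10-C TEST row:

```python
acc_s, acc_n = 66.36, 91.56
acc_h = 2 * acc_s * acc_n / (acc_s + acc_n)
print(round(acc_h, 2))   # -> 76.95
```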
/imagenet/README.md:
--------------------------------------------------------------------------------
1 | # OWTTT on ImageNet-C/R
2 |
3 | This folder provides our method (OURS) and the baseline TEST (direct testing without adaptation) on ImageNet-C/ImageNet-R under the OWTTT protocol. Our implementation is based on the [TTAC repo](https://github.com/Gorilla-Lab-SCUT/TTAC/tree/master/imagenet) and therefore requires similar preparation steps.
4 |
5 | ### Requirements
6 |
7 | - To install requirements:
8 |
9 | ```
10 | pip install -r requirements.txt
11 | ```
12 |
13 | - To download ImageNet dataset:
14 |
15 | First, download the validation set and the development kit (Task 1 & 2) of ImageNet-1k from [here](https://image-net.org/challenges/LSVRC/2012/index.php) and put them under the `data` folder.
16 |
17 | - To download the ImageNet-R dataset:
18 |
21 | ```
22 | export DATADIR=/data
23 | cd ${DATADIR}
24 | wget -O imagenet-r.tar https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar
25 | tar -xvf imagenet-r.tar
26 | ```
27 |
28 | - To create the corruption dataset
29 | ```
30 | python utils/create_corruption_dataset.py
31 | ```
32 |
33 | The issue `Frost missing after pip install` can be solved by following [this workaround](https://github.com/hendrycks/robustness/issues/4#issuecomment-427226016).
34 |
35 | Finally, the structure of the `data` folder should be like
36 | ```
37 | data
38 | |_ ILSVRC2012_devkit_t12.tar
39 | |_ ILSVRC2012_img_val.tar
40 | |_ val
41 | |_ n01440764
42 | |_ ...
43 | |_ imagenet-r
44 | |_ n01443537
45 | |_ ...
46 | |_ corruption
47 | |_ brightness.pth
48 | |_ contrast.pth
49 | |_ ...
50 | |_ meta.bin
51 | ```
52 |
53 | ### Pre-trained Models
54 |
55 | Here, we use the pre-trained model provided by torchvision.
56 |
57 | ### Open-World Test-Time Training:
58 |
59 | We present our method and the baseline method TEST (direct test without adaptation) on ImageNet-C/R.
60 |
61 | - Run the OURS method or the baseline method TEST on ImageNet-C under the OWTTT protocol.
62 |
63 | ```
64 | # OURS:
65 | bash scripts/ours_c.sh "corruption_type" "strong_ood_type"
66 |
67 | # TEST:
68 | bash scripts/test_c.sh "corruption_type" "strong_ood_type"
69 | ```
70 | Where "corruption_type" is the corruption type in ImageNet-C, and "strong_ood_type" is the strong OOD type in [noise, MNIST, SVHN].
71 |
72 | For example, to run OURS or TEST on ImageNet-C under the snow corruption with MNIST as strong OOD, we can use the following command:
73 |
74 | ```
75 | # OURS:
76 | bash scripts/ours_c.sh snow MNIST
77 |
78 | # TEST:
79 | bash scripts/test_c.sh snow MNIST
80 | ```
81 |
82 | The above scripts yield the following results (%) under the snow corruption with MNIST as strong OOD:
83 |
84 | | Method | ACC_S | ACC_N | ACC_H |
85 | |:------:|:-------:|:-------:|:-------:|
86 | | TEST | 17.30 | 99.35 | 29.47 |
87 | | OURS | 45.34 | 100.00 | 62.39 |
88 |
89 | - Run the OURS method or the baseline method TEST on ImageNet-R under the OWTTT protocol.
90 |
91 | ```
92 | # OURS:
93 | bash scripts/ours_r.sh "strong_ood_type"
94 |
95 | # TEST:
96 | bash scripts/test_r.sh "strong_ood_type"
97 | ```
98 | Where "strong_ood_type" is the strong OOD type in [noise, MNIST, SVHN].
99 |
100 | For example, to run OURS or TEST on ImageNet-R with MNIST as strong OOD, we can use the following command:
101 |
102 | ```
103 | # OURS:
104 | bash scripts/ours_r.sh MNIST
105 |
106 | # TEST:
107 | bash scripts/test_r.sh MNIST
108 | ```
109 |
110 | The above scripts yield the following results (%) with MNIST as strong OOD:
111 |
112 | | Method | ACC_S | ACC_N | ACC_H |
113 | |:------:|:-------:|:-------:|:-------:|
114 | | TEST | 35.50 | 99.96 | 52.39 |
115 | | OURS | 41.40 | 100.00 | 58.56 |
116 |
117 |
118 | ### Acknowledgements
119 |
120 | Our code is built upon the public code of the [TTAC](https://github.com/Gorilla-Lab-SCUT/TTAC/tree/master/imagenet).
121 |
--------------------------------------------------------------------------------
/cifar/utils/aug.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | from PIL import ImageOps, Image
6 | from torchvision import transforms
7 |
8 |
9 | ## https://github.com/google-research/augmix
10 |
11 | def _augmix_aug(x_orig):
12 | x_orig = preaugment(x_orig)
13 | x_processed = preprocess(x_orig)
14 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0]))
15 | m = np.float32(np.random.beta(1.0, 1.0))
16 |
17 | mix = torch.zeros_like(x_processed)
18 | for i in range(3):
19 | x_aug = x_orig.copy()
20 | for _ in range(np.random.randint(1, 4)):
21 | x_aug = np.random.choice(augmentations)(x_aug)
22 | mix += w[i] * preprocess(x_aug)
23 | mix = m * x_processed + (1 - m) * mix
24 | return mix
25 |
26 | aug = _augmix_aug
27 |
28 |
29 | def autocontrast(pil_img, level=None):
30 | return ImageOps.autocontrast(pil_img)
31 |
32 | def equalize(pil_img, level=None):
33 | return ImageOps.equalize(pil_img)
34 |
35 | def rotate(pil_img, level):
36 | degrees = int_parameter(rand_lvl(level), 30)
37 | if np.random.uniform() > 0.5:
38 | degrees = -degrees
39 | return pil_img.rotate(degrees, resample=Image.BILINEAR, fillcolor=128)
40 |
41 | def solarize(pil_img, level):
42 | level = int_parameter(rand_lvl(level), 256)
43 | return ImageOps.solarize(pil_img, 256 - level)
44 |
45 | def shear_x(pil_img, level):
46 | level = float_parameter(rand_lvl(level), 0.3)
47 | if np.random.uniform() > 0.5:
48 | level = -level
49 | return pil_img.transform((32, 32), Image.AFFINE, (1, level, 0, 0, 1, 0), resample=Image.BILINEAR, fillcolor=128)
50 |
51 | def shear_y(pil_img, level):
52 | level = float_parameter(rand_lvl(level), 0.3)
53 | if np.random.uniform() > 0.5:
54 | level = -level
55 | return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, level, 1, 0), resample=Image.BILINEAR, fillcolor=128)
56 |
57 | def translate_x(pil_img, level):
58 | level = int_parameter(rand_lvl(level), 32 / 3)
59 | if np.random.random() > 0.5:
60 | level = -level
61 | return pil_img.transform((32, 32), Image.AFFINE, (1, 0, level, 0, 1, 0), resample=Image.BILINEAR, fillcolor=128)
62 |
63 | def translate_y(pil_img, level):
64 | level = int_parameter(rand_lvl(level), 32 / 3)
65 | if np.random.random() > 0.5:
66 | level = -level
67 | return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, 0, 1, level), resample=Image.BILINEAR, fillcolor=128)
68 |
69 | def posterize(pil_img, level):
70 | level = int_parameter(rand_lvl(level), 4)
71 | return ImageOps.posterize(pil_img, 4 - level)
72 |
73 |
74 | def int_parameter(level, maxval):
75 | """Helper function to scale `val` between 0 and maxval .
76 | Args:
77 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
78 | maxval: Maximum value that the operation can have. This will be scaled
79 | to level/PARAMETER_MAX.
80 | Returns:
81 | An int that results from scaling `maxval` according to `level`.
82 | """
83 | return int(level * maxval / 10)
84 |
85 | def float_parameter(level, maxval):
86 | """Helper function to scale `val` between 0 and maxval .
87 | Args:
88 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
89 | maxval: Maximum value that the operation can have. This will be scaled
90 | to level/PARAMETER_MAX.
91 | Returns:
92 | A float that results from scaling `maxval` according to `level`.
93 | """
94 | return float(level) * maxval / 10.
95 |
96 | def rand_lvl(n):
97 | return np.random.uniform(low=0.1, high=n)
98 |
99 |
100 | augmentations = [
101 | autocontrast,
102 | equalize,
103 | lambda x: rotate(x, 1),
104 | lambda x: solarize(x, 1),
105 | lambda x: shear_x(x, 1),
106 | lambda x: shear_y(x, 1),
107 | lambda x: translate_x(x, 1),
108 | lambda x: translate_y(x, 1),
109 | lambda x: posterize(x, 1),
110 | ]
111 |
112 | mean = [0.5, 0.5, 0.5]
113 | std = [0.5, 0.5, 0.5]
114 | preprocess = transforms.Compose([
115 | transforms.ToTensor(),
116 | transforms.Normalize(mean, std)
117 | ])
118 | preaugment = transforms.Compose([
119 | transforms.RandomCrop(32, padding=4),
120 | transforms.RandomHorizontalFlip(),
121 | ])
--------------------------------------------------------------------------------
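A usage sketch for the AugMix entry point above: `aug` takes a 32x32 PIL image, mixes three randomly augmented chains with Dirichlet weights, and returns a normalized tensor. The random input is only a stand-in for a CIFAR image.

```python
import numpy as np
from PIL import Image
from utils.aug import aug

pil_img = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
mixed = aug(pil_img)
print(mixed.shape)   # torch.Size([3, 32, 32])
```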
/cifar/models/WideResNet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class BasicBlock(nn.Module):
8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
9 | super(BasicBlock, self).__init__()
10 | self.bn1 = nn.BatchNorm2d(in_planes)
11 | self.relu1 = nn.ReLU(inplace=True)
12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(out_planes)
15 | self.relu2 = nn.ReLU(inplace=True)
16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
17 | padding=1, bias=False)
18 | self.droprate = dropRate
19 | self.equalInOut = (in_planes == out_planes)
20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
21 | padding=0, bias=False) or None
22 |
23 | def forward(self, x):
24 | if not self.equalInOut:
25 | x = self.relu1(self.bn1(x))
26 | else:
27 | out = self.relu1(self.bn1(x))
28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
29 | if self.droprate > 0:
30 | out = F.dropout(out, p=self.droprate, training=self.training)
31 | out = self.conv2(out)
32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
33 |
34 |
35 | class NetworkBlock(nn.Module):
36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
37 | super(NetworkBlock, self).__init__()
38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
39 |
40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
41 | layers = []
42 | for i in range(int(nb_layers)):
43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
44 | return nn.Sequential(*layers)
45 |
46 | def forward(self, x):
47 | return self.layer(x)
48 |
49 |
50 | class WideResNet(nn.Module):
51 | """ Based on code from https://github.com/yaodongyu/TRADES """
52 | def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True):
53 | super(WideResNet, self).__init__()
54 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
55 | assert ((depth - 4) % 6 == 0)
56 | n = (depth - 4) / 6
57 | block = BasicBlock
58 | # 1st conv before any network block
59 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
60 | padding=1, bias=False)
61 | # 1st block
62 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
63 | if sub_block1:
64 | # 1st sub-block
65 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
66 | # 2nd block
67 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
68 | # 3rd block
69 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
70 | # global average pooling and classifier
71 | self.bn1 = nn.BatchNorm2d(nChannels[3])
72 | self.relu = nn.ReLU(inplace=True)
73 | self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last)
74 | self.nChannels = nChannels[3]
75 |
76 | for m in self.modules():
77 | if isinstance(m, nn.Conv2d):
78 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
79 | m.weight.data.normal_(0, math.sqrt(2. / n))
80 | elif isinstance(m, nn.BatchNorm2d):
81 | m.weight.data.fill_(1)
82 | m.bias.data.zero_()
83 | elif isinstance(m, nn.Linear) and not m.bias is None:
84 | m.bias.data.zero_()
85 |
86 | def forward(self, x):
87 | out = self.conv1(x)
88 | out = self.block1(out)
89 | out = self.block2(out)
90 | out = self.block3(out)
91 | out = self.relu(self.bn1(out))
92 | out = F.avg_pool2d(out, 8)
93 | out = out.view(-1, self.nChannels)
94 | return self.fc(out)
95 |
--------------------------------------------------------------------------------
/cifar/models/wide.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class BasicBlock(nn.Module):
8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
9 | super(BasicBlock, self).__init__()
10 | self.bn1 = nn.BatchNorm2d(in_planes)
11 | self.relu1 = nn.ReLU(inplace=True)
12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(out_planes)
15 | self.relu2 = nn.ReLU(inplace=True)
16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
17 | padding=1, bias=False)
18 | self.droprate = dropRate
19 | self.equalInOut = (in_planes == out_planes)
20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
21 | padding=0, bias=False) or None
22 |
23 | def forward(self, x):
24 | if not self.equalInOut:
25 | x = self.relu1(self.bn1(x))
26 | else:
27 | out = self.relu1(self.bn1(x))
28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
29 | if self.droprate > 0:
30 | out = F.dropout(out, p=self.droprate, training=self.training)
31 | out = self.conv2(out)
32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
33 |
34 |
35 | class NetworkBlock(nn.Module):
36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
37 | super(NetworkBlock, self).__init__()
38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
39 |
40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
41 | layers = []
42 | for i in range(int(nb_layers)):
43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
44 | return nn.Sequential(*layers)
45 |
46 | def forward(self, x):
47 | return self.layer(x)
48 |
49 |
50 | class WideResNet(nn.Module):
51 | """ Based on code from https://github.com/yaodongyu/TRADES """
52 | def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True):
53 | super(WideResNet, self).__init__()
54 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
55 | assert ((depth - 4) % 6 == 0)
56 | n = (depth - 4) / 6
57 | block = BasicBlock
58 | # 1st conv before any network block
59 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
60 | padding=1, bias=False)
61 | # 1st block
62 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
63 | if sub_block1:
64 | # 1st sub-block
65 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
66 | # 2nd block
67 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
68 | # 3rd block
69 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
70 | # global average pooling and classifier
71 | self.bn1 = nn.BatchNorm2d(nChannels[3])
72 | self.relu = nn.ReLU(inplace=True)
73 | # self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last)
74 | self.nChannels = nChannels[3]
75 | self.num_out = self.nChannels
76 |
77 | for m in self.modules():
78 | if isinstance(m, nn.Conv2d):
79 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
80 | m.weight.data.normal_(0, math.sqrt(2. / n))
81 | elif isinstance(m, nn.BatchNorm2d):
82 | m.weight.data.fill_(1)
83 | m.bias.data.zero_()
84 | elif isinstance(m, nn.Linear) and not m.bias is None:
85 | m.bias.data.zero_()
86 |
87 | def forward(self, x):
88 | out = self.conv1(x)
89 | out = self.block1(out)
90 | out = self.block2(out)
91 | out = self.block3(out)
92 | out = self.relu(self.bn1(out))
93 | out = F.avg_pool2d(out, 8)
94 | out = out.view(-1, self.nChannels)
95 | # return self.fc(out)
96 | return out
97 |
--------------------------------------------------------------------------------
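`wide.py` is the headless variant of `WideResNet.py`: the classifier `fc` is commented out and the forward pass returns pooled features of width `num_out = 64 * widen_factor`. A quick shape sketch:

```python
import torch
from models.wide import WideResNet

net = WideResNet(depth=28, widen_factor=10)
x = torch.randn(2, 3, 32, 32)
assert net(x).shape == (2, net.num_out)   # (2, 640)
```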
/cifar/utils/augmentation.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from
2 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
3 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py
4 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py
5 | import logging
6 | import random
7 |
8 | import numpy as np
9 | import PIL
10 | import PIL.ImageOps
11 | import PIL.ImageEnhance
12 | import PIL.ImageDraw
13 | from PIL import Image
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | PARAMETER_MAX = 10
18 |
19 |
20 | def AutoContrast(img, **kwarg):
21 | return PIL.ImageOps.autocontrast(img)
22 |
23 |
24 | def Brightness(img, v, max_v, bias=0):
25 | v = _float_parameter(v, max_v) + bias
26 | return PIL.ImageEnhance.Brightness(img).enhance(v)
27 |
28 |
29 | def Color(img, v, max_v, bias=0):
30 | v = _float_parameter(v, max_v) + bias
31 | return PIL.ImageEnhance.Color(img).enhance(v)
32 |
33 |
34 | def Contrast(img, v, max_v, bias=0):
35 | v = _float_parameter(v, max_v) + bias
36 | return PIL.ImageEnhance.Contrast(img).enhance(v)
37 |
38 |
39 | def Cutout(img, v, max_v, bias=0):
40 | if v == 0:
41 | return img
42 | v = _float_parameter(v, max_v) + bias
43 | v = int(v * min(img.size))
44 | return CutoutAbs(img, v)
45 |
46 |
47 | def CutoutAbs(img, v, **kwarg):
48 | w, h = img.size
49 | x0 = np.random.uniform(0, w)
50 | y0 = np.random.uniform(0, h)
51 | x0 = int(max(0, x0 - v / 2.))
52 | y0 = int(max(0, y0 - v / 2.))
53 | x1 = int(min(w, x0 + v))
54 | y1 = int(min(h, y0 + v))
55 | xy = (x0, y0, x1, y1)
56 | # gray
57 | color = (127, 127, 127)
58 | img = img.copy()
59 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
60 | return img
61 |
62 |
63 | def Equalize(img, **kwarg):
64 | return PIL.ImageOps.equalize(img)
65 |
66 |
67 | def Identity(img, **kwarg):
68 | return img
69 |
70 |
71 | def Invert(img, **kwarg):
72 | return PIL.ImageOps.invert(img)
73 |
74 |
75 | def Posterize(img, v, max_v, bias=0):
76 | v = _int_parameter(v, max_v) + bias
77 | return PIL.ImageOps.posterize(img, v)
78 |
79 |
80 | def Rotate(img, v, max_v, bias=0):
81 | v = _int_parameter(v, max_v) + bias
82 | if random.random() < 0.5:
83 | v = -v
84 | return img.rotate(v)
85 |
86 |
87 | def Sharpness(img, v, max_v, bias=0):
88 | v = _float_parameter(v, max_v) + bias
89 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
90 |
91 |
92 | def ShearX(img, v, max_v, bias=0):
93 | v = _float_parameter(v, max_v) + bias
94 | if random.random() < 0.5:
95 | v = -v
96 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
97 |
98 |
99 | def ShearY(img, v, max_v, bias=0):
100 | v = _float_parameter(v, max_v) + bias
101 | if random.random() < 0.5:
102 | v = -v
103 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
104 |
105 |
106 | def Solarize(img, v, max_v, bias=0):
107 | v = _int_parameter(v, max_v) + bias
108 | return PIL.ImageOps.solarize(img, 256 - v)
109 |
110 |
111 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
112 | v = _int_parameter(v, max_v) + bias
113 | if random.random() < 0.5:
114 | v = -v
115 | img_np = np.array(img).astype(int)  # np.int was removed in NumPy 1.24
116 | img_np = img_np + v
117 | img_np = np.clip(img_np, 0, 255)
118 | img_np = img_np.astype(np.uint8)
119 | img = Image.fromarray(img_np)
120 | return PIL.ImageOps.solarize(img, threshold)
121 |
122 |
123 | def TranslateX(img, v, max_v, bias=0):
124 | v = _float_parameter(v, max_v) + bias
125 | if random.random() < 0.5:
126 | v = -v
127 | v = int(v * img.size[0])
128 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
129 |
130 |
131 | def TranslateY(img, v, max_v, bias=0):
132 | v = _float_parameter(v, max_v) + bias
133 | if random.random() < 0.5:
134 | v = -v
135 | v = int(v * img.size[1])
136 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
137 |
138 |
139 | def _float_parameter(v, max_v):
140 | return float(v) * max_v / PARAMETER_MAX
141 |
142 |
143 | def _int_parameter(v, max_v):
144 | return int(v * max_v / PARAMETER_MAX)
145 |
146 |
147 | def fixmatch_augment_pool():
148 | # FixMatch paper
149 | augs = [(AutoContrast, None, None),
150 | (Brightness, 0.9, 0.05),
151 | (Color, 0.9, 0.05),
152 | (Contrast, 0.9, 0.05),
153 | (Equalize, None, None),
154 | (Identity, None, None),
155 | (Posterize, 4, 4),
156 | (Rotate, 30, 0),
157 | (Sharpness, 0.9, 0.05),
158 | (ShearX, 0.3, 0),
159 | (ShearY, 0.3, 0),
160 | (Solarize, 256, 0),
161 | (TranslateX, 0.3, 0),
162 | (TranslateY, 0.3, 0)]
163 | return augs
164 |
165 |
166 | def my_augment_pool():
167 | # Test
168 | augs = [(AutoContrast, None, None),
169 | (Brightness, 1.8, 0.1),
170 | (Color, 1.8, 0.1),
171 | (Contrast, 1.8, 0.1),
172 | (Cutout, 0.2, 0),
173 | (Equalize, None, None),
174 | (Invert, None, None),
175 | (Posterize, 4, 4),
176 | (Rotate, 30, 0),
177 | (Sharpness, 1.8, 0.1),
178 | (ShearX, 0.3, 0),
179 | (ShearY, 0.3, 0),
180 | (Solarize, 256, 0),
181 | (SolarizeAdd, 110, 0),
182 | (TranslateX, 0.45, 0),
183 | (TranslateY, 0.45, 0)]
184 | return augs
185 |
186 |
187 | class RandAugmentPC(object):
188 | def __init__(self, n, m):
189 | assert n >= 1
190 | assert 1 <= m <= 10
191 | self.n = n
192 | self.m = m
193 | self.augment_pool = my_augment_pool()
194 |
195 | def __call__(self, img):
196 | ops = random.choices(self.augment_pool, k=self.n)
197 | for op, max_v, bias in ops:
198 | prob = np.random.uniform(0.2, 0.8)
199 | if random.random() + prob >= 1:
200 | img = op(img, v=self.m, max_v=max_v, bias=bias)
201 | img = CutoutAbs(img, int(32*0.5))
202 | return img
203 |
204 |
205 | class RandAugmentMC(object):
206 | def __init__(self, n, m):
207 | assert n >= 1
208 | assert 1 <= m <= 10
209 | self.n = n
210 | self.m = m
211 | self.augment_pool = fixmatch_augment_pool()
212 |
213 | def __call__(self, img):
214 | ops = random.choices(self.augment_pool, k=self.n)
215 | for op, max_v, bias in ops:
216 | v = np.random.randint(1, self.m)
217 | if random.random() < 0.5:
218 | img = op(img, v=v, max_v=max_v, bias=bias)
219 | img = CutoutAbs(img, int(32*0.5))
220 | return img
--------------------------------------------------------------------------------
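A usage sketch for the FixMatch-style strong augmentation above: `RandAugmentMC(n, m)` samples `n` ops, applies each with probability 0.5 at a random magnitude below `m`, then applies a fixed 16-pixel Cutout. The random input stands in for a CIFAR image.

```python
import numpy as np
from PIL import Image
from utils.augmentation import RandAugmentMC

strong_aug = RandAugmentMC(n=2, m=10)
img = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
img = strong_aug(img)   # PIL.Image, 32x32
```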
/cifar/TEST.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.optim as optim
4 | import torch.utils.data as data
5 |
6 | from utils.misc import *
7 | from utils.test_helpers import *
8 | from utils.prepare_dataset import *
9 |
10 | # ----------------------------------
11 | import copy
12 | import random
13 | import numpy as np
14 | from utils.contrastive import *
15 | from utils.offline import *
16 | from torch import nn
17 | import torch.nn.functional as F
18 | # ----------------------------------
19 |
20 |
21 | def compute_os_variance(os, th):
22 | """
23 | Compute Otsu's weighted intra-class variance of the OOD scores at threshold th.
24 |
25 | Parameters:
26 | os : OOD score queue.
27 | th : Given threshold to separate weak and strong OOD samples.
28 |
29 | Returns:
30 | float: Weighted variance at the given threshold th.
31 | """
32 |
33 | thresholded_os = np.zeros(os.shape)
34 | thresholded_os[os >= th] = 1
35 |
36 | # compute weights
37 | nb_pixels = os.size
38 | nb_pixels1 = np.count_nonzero(thresholded_os)
39 | weight1 = nb_pixels1 / nb_pixels
40 | weight0 = 1 - weight1
41 |
42 | # if one of the classes is empty, e.g. all scores are below or above the threshold,
43 | # that threshold will not be considered in the search for the best threshold
44 | if weight1 == 0 or weight0 == 0:
45 | return np.inf
46 |
47 | # find all pixels belonging to each class
48 | val_pixels1 = os[thresholded_os == 1]
49 | val_pixels0 = os[thresholded_os == 0]
50 |
51 | # compute variance of these classes
52 | var0 = np.var(val_pixels0) if len(val_pixels0) > 0 else 0
53 | var1 = np.var(val_pixels1) if len(val_pixels1) > 0 else 0
54 |
55 | return weight0 * var0 + weight1 * var1
56 |
57 |
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument('--dataset', default='cifar10OOD')
60 | parser.add_argument('--strong_OOD', default='noise')
61 | parser.add_argument('--strong_ratio', default=1, type=float)
62 | parser.add_argument('--dataroot', default="./data", help='path to dataset')
63 | parser.add_argument('--batch_size', default=256, type=int)
64 | parser.add_argument('--workers', default=4, type=int)
65 | parser.add_argument('--outf', help='folder to output log')
66 | parser.add_argument('--level', default=5, type=int)
67 | parser.add_argument('--N_m', default=512, type=int, help='queue length')
68 | parser.add_argument('--corruption', default='snow')
69 | parser.add_argument('--resume', default='/cluster/personal/code/TTT/TTAC-master/cifar/results/cifar10_joint_resnet50', help='directory of pretrained model')
70 | parser.add_argument('--model', default='resnet50', help='resnet50')
71 | parser.add_argument('--seed', default=0, type=int)
72 |
73 |
74 | # ----------- Args and Dataloader ------------
75 | args = parser.parse_args()
76 |
77 | print(args)
78 | print('\n')
79 |
80 |
81 |
82 |
83 | class_num = 10 if args.dataset == 'cifar10OOD' else 100
84 |
85 | net, ext, head, ssh, classifier = build_resnet50(args)
86 |
87 | teset, _ = prepare_test_data(args)
88 | teloader = data.DataLoader(teset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, worker_init_fn=seed_worker, pin_memory=True, drop_last=False)
89 |
90 | # -------------------------------
91 | print('Resuming from %s...' %(args.resume))
92 |
93 | load_resnet50(net, head, ssh, classifier, args)
94 |
95 | # ----------- Offline Feature Summarization ------------
96 | args_align = copy.deepcopy(args)
97 |
98 | _, offlineloader = prepare_train_data(args_align)
99 | ext_src_mu, ext_src_cov, ssh_src_mu, ssh_src_cov, mu_src_ext, cov_src_ext, mu_src_ssh, cov_src_ssh = offline(args,offlineloader, ext, classifier, head, class_num)
100 |
101 | ext_src_mu = torch.stack(ext_src_mu)
102 | weak_prototype = F.normalize(ext_src_mu.clone()).cuda()
103 |
104 | torch.manual_seed(args.seed)
105 | random.seed(args.seed)
106 | np.random.seed(args.seed)
107 | torch.cuda.manual_seed(args.seed)
108 | torch.cuda.manual_seed_all(args.seed)
109 |
110 | # ----------- Open-World Test-time Training ------------
111 |
112 | correct = []
113 | unseen_correct= []
114 | all_correct=[]
115 | cumulative_error = []
116 | num_open = 0
117 | predicted_list=[]
118 | label_list=[]
119 |
120 | os_inference_queue = []
121 | queue_length = args.N_m
122 |
123 | ema_total_n = 0.
124 |
125 | print('\n-----Test-Time Training with TEST-----')
126 | for te_idx, (te_inputs, te_labels) in enumerate(teloader):
127 |
128 |
129 | ####-------------------------- Test ----------------------------####
130 |
131 | with torch.no_grad():
132 | if isinstance(te_inputs,list):
133 | inputs = te_inputs[0].cuda()
134 | else:
135 | inputs = te_inputs.cuda()
136 | net.eval()
137 | feat_ext = ext(inputs) #b,2048
138 | logit = torch.mm(F.normalize(feat_ext), weak_prototype.t())
139 | update = 1
140 | softmax_logit = logit.softmax(dim=-1)
141 | pro, predicted = softmax_logit.max(dim=-1)
142 |
143 | ood_score, max_index = logit.max(1)
144 | ood_score = 1-ood_score
145 | os_inference_queue.extend(ood_score.detach().cpu().tolist())
146 | os_inference_queue = os_inference_queue[-queue_length:]
147 |
148 | threshold_range = np.arange(0,1,0.01)
149 | criterias = [compute_os_variance(np.array(os_inference_queue), th) for th in threshold_range]
150 | best_threshold = threshold_range[np.argmin(criterias)]
151 | unseen_mask = (ood_score > best_threshold)
152 | args.ts = best_threshold
153 | predicted[unseen_mask] = class_num
154 |
155 | one = torch.ones_like(te_labels)*class_num
156 | false = torch.ones_like(te_labels)*-1
157 | predicted = torch.where(predicted>class_num-1, one.cuda(), predicted)
158 | all_labels = torch.where(te_labels>class_num-1, one, te_labels)
159 | seen_labels = torch.where(te_labels>class_num-1, false, te_labels)
160 | unseen_labels = torch.where(te_labels>class_num-1, one, false)
161 | correct.append(predicted.cpu().eq(seen_labels))
162 | unseen_correct.append(predicted.cpu().eq(unseen_labels))
163 | all_correct.append(predicted.cpu().eq(all_labels))
164 |         num_open += torch.gt(te_labels, class_num-1).sum()
165 |
166 | predicted_list.append(predicted.long().cpu())
167 | label_list.append(all_labels.long().cpu())
168 |
169 |
170 | seen_acc = round(torch.cat(correct).numpy().sum() / (len(torch.cat(correct).numpy())-num_open.numpy()),4)
171 | unseen_acc = round(torch.cat(unseen_correct).numpy().sum() / num_open.numpy(),4)
172 | h_score = round((2*seen_acc*unseen_acc) / (seen_acc + unseen_acc),4)
173 | print('Batch:(', te_idx,'/',len(teloader),\
174 | '\t Cumulative Results: ACC_S:', seen_acc,\
175 | '\tACC_N:', unseen_acc,\
176 | '\tACC_H:',h_score\
177 | )
178 |
179 |
180 | print('\nTest time training result:',' ACC_S:', seen_acc,\
181 | '\tACC_N:', unseen_acc,\
182 | '\tACC_H:',h_score,'\n\n\n\n'\
183 | )
184 |
185 |
186 | if args.outf is not None:
187 |     my_makedir(args.outf)
188 |     with open(args.outf + '/results.txt', 'a') as f:
189 | f.write(str(args)+'\n')
190 | f.write(
191 | 'ACC_S:'+ str(seen_acc)+\
192 | '\tACC_N:'+ str(unseen_acc)+\
193 | '\tACC_H:'+str(h_score)+'\n\n\n\n'\
194 | )
--------------------------------------------------------------------------------
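
The threshold search inside the loop above is a 1-D Otsu scan: the OOD-score queue is swept over 100 candidate thresholds and the one minimizing the weighted within-class variance separates weak from strong OOD samples. A minimal self-contained sketch of that scan on synthetic scores (the bimodal queue below is an assumption for illustration, not repo data):

import numpy as np

def compute_os_variance(os_scores, th):
    # split the queue at th and weight each side's variance by its mass
    mask = os_scores >= th
    w1 = mask.mean()
    w0 = 1.0 - w1
    if w0 == 0 or w1 == 0:  # degenerate split: reject this threshold
        return np.inf
    return w0 * os_scores[~mask].var() + w1 * os_scores[mask].var()

# synthetic queue: weak-OOD scores near 0.2, strong-OOD scores near 0.8
rng = np.random.default_rng(0)
queue = np.concatenate([rng.normal(0.2, 0.05, 400), rng.normal(0.8, 0.05, 112)])

thresholds = np.arange(0, 1, 0.01)
criteria = [compute_os_variance(queue, th) for th in thresholds]
best = thresholds[np.argmin(criteria)]
print(f"selected threshold: {best:.2f}")  # falls between the two modes
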
/imagenet/TEST.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.optim as optim
4 | import torch.utils.data as data
5 |
6 | import torch.nn as nn
7 | from utils.test_helpers import *
8 | from utils.prepare_dataset import *
9 |
10 | # ----------------------------------
11 | import copy
12 | import random
13 | import numpy as np
14 |
15 | from utils.test_helpers import build_model, test
16 | from utils.prepare_dataset import prepare_transforms, create_dataloader, ImageNetCorruption, ImageNet_, prepare_ood_test_data,prepare_ood_test_data_r
17 | from utils.offline import offline, offline_r
18 | import torch.nn.functional as F
19 | # ----------------------------------
20 |
21 |
22 | def compute_os_variance(os, th):
23 | """
24 |     Compute Otsu's weighted within-class variance of the OOD score queue at threshold th.
25 |
26 | Parameters:
27 | os : OOD score queue.
28 | th : Given threshold to separate weak and strong OOD samples.
29 |
30 | Returns:
31 | float: Weighted variance at the given threshold th.
32 | """
33 |
34 | thresholded_os = np.zeros(os.shape)
35 | thresholded_os[os >= th] = 1
36 |
37 | # compute weights
38 | nb_pixels = os.size
39 | nb_pixels1 = np.count_nonzero(thresholded_os)
40 | weight1 = nb_pixels1 / nb_pixels
41 | weight0 = 1 - weight1
42 |
43 |     # if one of the classes is empty, e.g. all scores fall below (or above) the threshold,
44 |     # that threshold will not be considered in the search for the best threshold
45 | if weight1 == 0 or weight0 == 0:
46 | return np.inf
47 |
48 | # find all pixels belonging to each class
49 | val_pixels1 = os[thresholded_os == 1]
50 | val_pixels0 = os[thresholded_os == 0]
51 |
52 | # compute variance of these classes
53 | var0 = np.var(val_pixels0) if len(val_pixels0) > 0 else 0
54 | var1 = np.var(val_pixels1) if len(val_pixels1) > 0 else 0
55 |
56 | return weight0 * var0 + weight1 * var1
57 |
58 |
59 | parser = argparse.ArgumentParser()
60 | parser.add_argument('--dataset', default='ImageNet-C')
61 | parser.add_argument('--strong_OOD', default='noise')
62 | parser.add_argument('--strong_ratio', default=1, type=float)
63 | parser.add_argument('--dataroot', default='./data')
64 | parser.add_argument('--batch_size', default=128, type=int)
65 | parser.add_argument('--workers', default=8, type=int)
66 | parser.add_argument('--ce_scale', default=0, type=float, help='cross entropy loss scale')
67 | parser.add_argument('--outf', help='folder to output log')
68 | parser.add_argument('--level', default=5, type=int)
69 | parser.add_argument('--N_m', default=512, type=int, help='queue length')
70 | parser.add_argument('--corruption', default='snow')
71 | parser.add_argument('--offline', default='./results/offline/', help='directory of pretrained model')
72 | parser.add_argument('--model', default='resnet50', help='resnet50')
73 | parser.add_argument('--seed', default=0, type=int)
74 |
75 |
76 | # ----------- Args and Dataloader ------------
77 | args = parser.parse_args()
78 |
79 | print(args)
80 | print('\n')
81 |
82 |
83 |
84 | net, ext, classifier = build_model()
85 |
86 |
87 | train_transform, val_transform, val_corrupt_transform = prepare_transforms()
88 |
89 | source_dataset = ImageNet_(args.dataroot, 'val', transform=val_transform, is_carry_index=True)
90 |
91 | if args.dataset == 'ImageNet-C':
92 | target_dataset_test = prepare_ood_test_data(args.dataroot, args.corruption, transform=val_corrupt_transform, is_carry_index=True, OOD=args.strong_OOD,OOD_transform=val_transform)
93 | class_num = 1000
94 |
95 | elif args.dataset == 'ImageNet-R':
96 | indices_in_1k = [wnid in imagenet_r_wnids for wnid in all_wnids]
97 | target_dataset_test = prepare_ood_test_data_r(args.dataroot, args.corruption, transform=val_corrupt_transform, is_carry_index=True, OOD=args.strong_OOD,OOD_transform=val_transform)
98 | class_num = 200
99 | else:
100 | raise NotImplementedError
101 |
102 | source_dataloader = create_dataloader(source_dataset, args, True, False)
103 | target_dataloader_test = create_dataloader(target_dataset_test, args, True, False)
104 |
105 | # ----------- Offline Feature Summarization ------------
106 | if args.dataset == 'ImageNet-C':
107 | ext_mean, ext_cov, ext_mean_categories, ext_cov_categories = offline(args, source_dataloader, ext, classifier)
108 | weak_prototype = F.normalize(ext_mean_categories.clone()).cuda()
109 | else:
110 | ext_mean, ext_cov, ext_mean_categories, ext_cov_categories = offline_r(args, source_dataloader, ext, classifier)
111 | weak_prototype = F.normalize(ext_mean_categories[indices_in_1k].clone()).cuda()
112 |
113 | torch.manual_seed(args.seed)
114 | random.seed(args.seed)
115 | np.random.seed(args.seed)
116 | torch.cuda.manual_seed(args.seed)
117 | torch.cuda.manual_seed_all(args.seed)
118 |
119 | # ----------- Open-World Test-time Training ------------
120 |
121 | correct = []
122 | unseen_correct= []
123 | all_correct=[]
124 | cumulative_error = []
125 | num_open = 0
126 | predicted_list=[]
127 | label_list=[]
128 |
129 | os_inference_queue = []
130 | queue_length = args.N_m
131 |
132 | ema_total_n = 0.
133 |
134 | print('\n-----Test-Time Training with TEST-----')
135 | for te_idx, (te_inputs, te_labels) in enumerate(target_dataloader_test):
136 |
137 | if isinstance(te_inputs,list):
138 | inputs = te_inputs[0].cuda()
139 | else:
140 | inputs = te_inputs.cuda()
141 |
142 | ####-------------------------- Test ----------------------------####
143 |
144 | with torch.no_grad():
145 |
146 | net.eval()
147 | feat_ext = ext(inputs) #b,2048
148 | logit = torch.mm(F.normalize(feat_ext), weak_prototype.t())
149 |
150 |
151 | softmax_logit = logit.softmax(dim=-1)
152 | pro, predicted = softmax_logit.max(dim=-1)
153 |
154 | ood_score, max_index = logit.max(1)
155 | ood_score = 1-ood_score
156 | os_inference_queue.extend(ood_score.detach().cpu().tolist())
157 | os_inference_queue = os_inference_queue[-queue_length:]
158 |
159 | threshold_range = np.arange(0,1,0.01)
160 | criterias = [compute_os_variance(np.array(os_inference_queue), th) for th in threshold_range]
161 | best_threshold = threshold_range[np.argmin(criterias)]
162 | unseen_mask = (ood_score > best_threshold)
163 | args.ts = best_threshold
164 | predicted[unseen_mask] = class_num
165 |
166 | one = torch.ones_like(te_labels)*class_num
167 | false = torch.ones_like(te_labels)*-1
168 | predicted = torch.where(predicted>class_num-1, one.cuda(), predicted)
169 | all_labels = torch.where(te_labels>class_num-1, one, te_labels)
170 | seen_labels = torch.where(te_labels>class_num-1, false, te_labels)
171 | unseen_labels = torch.where(te_labels>class_num-1, one, false)
172 | correct.append(predicted.cpu().eq(seen_labels))
173 | unseen_correct.append(predicted.cpu().eq(unseen_labels))
174 | all_correct.append(predicted.cpu().eq(all_labels))
175 | num_open += torch.gt(te_labels, class_num-1).sum()
176 |
177 | predicted_list.append(predicted.long().cpu())
178 | label_list.append(all_labels.long().cpu())
179 |
180 |
181 | seen_acc = round(torch.cat(correct).numpy().sum() / (len(torch.cat(correct).numpy())-num_open.numpy()),4)
182 | unseen_acc = round(torch.cat(unseen_correct).numpy().sum() / num_open.numpy(),4)
183 | h_score = round((2*seen_acc*unseen_acc) / (seen_acc + unseen_acc),4)
184 | print('Batch:(', te_idx,'/',len(target_dataloader_test),\
185 | '\t Cumulative Results: ACC_S:', seen_acc,\
186 | '\tACC_N:', unseen_acc,\
187 | '\tACC_H:',h_score\
188 | )
189 |
190 |
191 | print('\nTest time training result:',' ACC_S:', seen_acc,\
192 | '\tACC_N:', unseen_acc,\
193 | '\tACC_H:',h_score,'\n\n\n\n'\
194 | )
195 |
196 |
197 | if args.outf is not None:
198 |     my_makedir(args.outf)
199 |     with open(args.outf + '/results.txt', 'a') as f:
200 | f.write(str(args)+'\n')
201 | f.write(
202 | 'ACC_S:'+ str(seen_acc)+\
203 | '\tACC_N:'+ str(unseen_acc)+\
204 | '\tACC_H:'+str(h_score)+'\n\n\n\n'\
205 | )
--------------------------------------------------------------------------------
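
Both TEST.py scripts share the bookkeeping pattern just above: strong-OOD labels are folded into one extra class `class_num`, and -1 serves as a sentinel no prediction can ever match, so a single `eq()` per view counts seen, unseen, and overall hits. A toy illustration of that accounting and the harmonic H-score (`class_num=5` is an assumption for the example):

import torch

class_num = 5                                # toy number of seen classes
predicted = torch.tensor([0, 2, 5, 5, 1])    # 5 == the "unseen" bucket
te_labels = torch.tensor([0, 3, 7, 9, 1])    # labels >= class_num are strong OOD

one = torch.full_like(te_labels, class_num)
neg = torch.full_like(te_labels, -1)
seen_labels = torch.where(te_labels > class_num - 1, neg, te_labels)  # OOD -> -1
unseen_labels = torch.where(te_labels > class_num - 1, one, neg)      # ID  -> -1
num_open = (te_labels > class_num - 1).sum().item()

seen_acc = predicted.eq(seen_labels).sum().item() / (len(te_labels) - num_open)
unseen_acc = predicted.eq(unseen_labels).sum().item() / num_open
h_score = 2 * seen_acc * unseen_acc / (seen_acc + unseen_acc)
print(seen_acc, unseen_acc, h_score)  # ~0.667 1.0 0.8 on this toy batch
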
/cifar/models/BigResNet.py:
--------------------------------------------------------------------------------
1 | """ResNet in PyTorch.
2 | ImageNet-Style ResNet
3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
5 | ResNet adapted from: https://github.com/bearpaw/pytorch-classification
6 | SupConResNet adapted from https://github.com/HobbitLong/SupContrast
7 | """
8 | from functools import partial
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | class BasicBlock(nn.Module):
14 | expansion = 1
15 |
16 | def __init__(self, in_planes, planes, stride=1, is_last=False):
17 | super(BasicBlock, self).__init__()
18 | self.is_last = is_last
19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 |
24 | self.shortcut = nn.Sequential()
25 | if stride != 1 or in_planes != self.expansion * planes:
26 | self.shortcut = nn.Sequential(
27 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
28 | nn.BatchNorm2d(self.expansion * planes)
29 | )
30 |
31 | def forward(self, x):
32 | out = F.relu(self.bn1(self.conv1(x)))
33 | out = self.bn2(self.conv2(out))
34 | out += self.shortcut(x)
35 | preact = out
36 | out = F.relu(out)
37 | if self.is_last:
38 | return out, preact
39 | else:
40 | return out
41 |
42 |
43 | class Bottleneck(nn.Module):
44 | expansion = 4
45 |
46 | def __init__(self, in_planes, planes, stride=1, is_last=False):
47 | super(Bottleneck, self).__init__()
48 | self.is_last = is_last
49 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
50 | self.bn1 = nn.BatchNorm2d(planes)
51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
52 | self.bn2 = nn.BatchNorm2d(planes)
53 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
54 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
55 |
56 | self.shortcut = nn.Sequential()
57 | if stride != 1 or in_planes != self.expansion * planes:
58 | self.shortcut = nn.Sequential(
59 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
60 | nn.BatchNorm2d(self.expansion * planes)
61 | )
62 |
63 | def forward(self, x):
64 | out = F.relu(self.bn1(self.conv1(x)))
65 | out = F.relu(self.bn2(self.conv2(out)))
66 | out = self.bn3(self.conv3(out))
67 | out += self.shortcut(x)
68 | preact = out
69 | out = F.relu(out)
70 | if self.is_last:
71 | return out, preact
72 | else:
73 | return out
74 |
75 |
76 | class ResNet(nn.Module):
77 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
78 | super(ResNet, self).__init__()
79 | self.in_planes = 64
80 |
81 | self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
82 | bias=False)
83 | self.bn1 = nn.BatchNorm2d(64)
84 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
85 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
86 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
87 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, is_last=True)
88 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
89 |
90 | for m in self.modules():
91 | if isinstance(m, nn.Conv2d):
92 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
93 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
94 | nn.init.constant_(m.weight, 1)
95 | nn.init.constant_(m.bias, 0)
96 |
97 | # Zero-initialize the last BN in each residual branch,
98 | # so that the residual branch starts with zeros, and each residual block behaves
99 | # like an identity. This improves the model by 0.2~0.3% according to:
100 | # https://arxiv.org/abs/1706.02677
101 | if zero_init_residual:
102 | for m in self.modules():
103 | if isinstance(m, Bottleneck):
104 | nn.init.constant_(m.bn3.weight, 0)
105 | elif isinstance(m, BasicBlock):
106 | nn.init.constant_(m.bn2.weight, 0)
107 |
108 | def _make_layer(self, block, planes, num_blocks, stride, is_last=False):
109 | strides = [stride] + [1] * (num_blocks - 1)
110 | layers = []
111 | for i in range(num_blocks):
112 | stride = strides[i]
113 | layers.append(block(self.in_planes, planes, stride))
114 | self.in_planes = planes * block.expansion
115 | return nn.Sequential(*layers)
116 |
117 | def forward(self, x, layer=100):
118 | out = F.relu(self.bn1(self.conv1(x)))
119 | out = self.layer1(out)
120 | out = self.layer2(out)
121 | out = self.layer3(out)
122 | out = self.layer4(out)
123 | out = self.avgpool(out)
124 | out = torch.flatten(out, 1)
125 | return out
126 |
127 |
128 | def resnet18(**kwargs):
129 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
130 |
131 |
132 | def resnet34(**kwargs):
133 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
134 |
135 |
136 | def resnet50(**kwargs):
137 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
138 |
139 |
140 | def resnet101(**kwargs):
141 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
142 |
143 |
144 | model_dict = {
145 | 'resnet18': [resnet18, 512],
146 | 'resnet34': [resnet34, 512],
147 | 'resnet50': [resnet50, 2048],
148 | 'resnet101': [resnet101, 2048],
149 | }
150 |
151 |
152 | class LinearBatchNorm(nn.Module):
153 | """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose"""
154 | def __init__(self, dim, affine=True):
155 | super(LinearBatchNorm, self).__init__()
156 | self.dim = dim
157 | self.bn = nn.BatchNorm2d(dim, affine=affine)
158 |
159 | def forward(self, x):
160 | x = x.view(-1, self.dim, 1, 1)
161 | x = self.bn(x)
162 | x = x.view(-1, self.dim)
163 | return x
164 |
165 |
166 | class SupConResNet(nn.Module):
167 | """backbone + projection head"""
168 | def __init__(self, name='resnet50', head='mlp', feat_dim=128):
169 | super(SupConResNet, self).__init__()
170 | model_fun, dim_in = model_dict[name]
171 | self.encoder = model_fun()
172 | if head == 'linear':
173 | self.head = nn.Linear(dim_in, feat_dim)
174 | elif head == 'mlp':
175 | self.head = nn.Sequential(
176 | nn.Linear(dim_in, dim_in),
177 | nn.ReLU(inplace=True),
178 | nn.Linear(dim_in, feat_dim)
179 | )
180 | else:
181 | raise NotImplementedError(
182 | 'head not supported: {}'.format(head))
183 |
184 | def forward(self, x):
185 | feat = self.encoder(x)
186 | feat = F.normalize(self.head(feat), dim=1)
187 | return feat
188 |
189 |
190 | class LinearClassifier(nn.Module):
191 | """Linear classifier"""
192 | def __init__(self, name='resnet50', num_classes=10,num_dim=None):
193 | super(LinearClassifier, self).__init__()
194 |         if num_dim is None:
195 |             _, feat_dim = model_dict[name]
196 |         else:
197 |             feat_dim = num_dim
198 | self.fc = nn.Linear(feat_dim, num_classes)
199 |
200 |     def forward(self, features, norm=False):
201 |         if not norm:
202 |             return self.fc(features)
203 |         self.weight_norm()
204 | 
205 |         return self.fc(F.normalize(features)) - self.fc.bias
206 | 
207 |     def weight_norm(self):
208 |         # normalize each class weight vector to unit L2 norm
209 |         w = self.fc.weight.data
210 | norm = w.norm(p=2, dim=1, keepdim=True)
211 | self.fc.weight.data = w.div(norm.expand_as(w))
212 |
--------------------------------------------------------------------------------
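
A quick shape check for the SupConResNet / LinearClassifier pairing defined above (CPU-only sketch; it assumes the /cifar directory is on PYTHONPATH, matching how build_resnet50 imports these classes):

import torch
from models.BigResNet import SupConResNet, LinearClassifier

ssh = SupConResNet(name='resnet50', head='mlp', feat_dim=128)
classifier = LinearClassifier(name='resnet50', num_classes=10)

x = torch.randn(2, 3, 32, 32)   # CIFAR-sized batch
feat = ssh.encoder(x)           # (2, 2048) backbone features
logits = classifier(feat)       # (2, 10) class scores
proj = ssh(x)                   # (2, 128) L2-normalized projection
print(feat.shape, logits.shape, proj.shape)
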
/imagenet/utils/prepare_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | import torchvision.transforms as transforms
5 | from torchvision.datasets import ImageNet
6 | import os
7 | import torchvision
8 |
9 | def prepare_transforms():
10 | train_transform = transforms.Compose([
11 | transforms.RandomResizedCrop(224),
12 | transforms.RandomHorizontalFlip(),
13 | transforms.ToTensor(),
14 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
15 | ])
16 | val_transform = transforms.Compose([
17 | transforms.Resize(256),
18 | transforms.CenterCrop(224),
19 | transforms.ToTensor(),
20 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
21 | ])
22 | val_corrupt_transform = transforms.Compose([
23 | transforms.ToTensor(),
24 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
25 | ])
26 | return train_transform, val_transform, val_corrupt_transform
27 |
28 | def seed_worker(worker_id):
29 | worker_seed = torch.initial_seed() % 2**32
30 | np.random.seed(worker_seed)
31 | random.seed(worker_seed)
32 |
33 | def create_dataloader(dataset, args, shuffle=False, drop_last=False):
34 | return torch.utils.data.DataLoader(dataset,
35 | batch_size=args.batch_size,
36 | shuffle=shuffle,
37 | num_workers=args.workers,
38 | worker_init_fn=seed_worker,
39 | pin_memory=True,
40 | drop_last=drop_last)
41 |
42 |
43 | class ImageNetCorruption(ImageNet):
44 | def __init__(self, root, corruption_name="gaussian_noise", transform=None, is_carry_index=False):
45 | super().__init__(root, 'val', transform=transform)
46 | self.root = root
47 | self.corruption_name = corruption_name
48 | self.transform = transform
49 | self.is_carry_index = is_carry_index
50 | self.load_data()
51 |
52 | def load_data(self):
53 | self.data = torch.load(os.path.join(self.root, 'corruption', self.corruption_name + '.pth')).numpy()
54 | self.target = [i[1] for i in self.imgs]
55 | return
56 |
57 | def __getitem__(self, index):
58 | img = self.data[index, :, :, :]
59 | target = self.target[index]
60 | if self.transform is not None:
61 | img = self.transform(img)
62 | if self.is_carry_index:
63 | img = [img, index]
64 | return img, target
65 |
66 | def __len__(self):
67 | return self.data.shape[0]
68 |
69 | class ImageNet_(ImageNet):
70 | def __init__(self, *args, is_carry_index=False, **kwargs):
71 | super().__init__(*args, **kwargs)
72 | self.is_carry_index = is_carry_index
73 |
74 | def __getitem__(self, index: int):
75 | img, target = super().__getitem__(index)
76 | if self.is_carry_index:
77 | if type(img) == list:
78 | img.append(index)
79 | else:
80 | img = [img, index]
81 | return img, target
82 |
83 |
84 | class noise_dataset(torch.utils.data.Dataset):  # must subclass torch.utils.data.Dataset
85 |     def __init__(self, transform, ratio=1):
86 |         # random-noise samples are generated on the fly, so only the sample count is stored
87 | self.number = int(50000*ratio)
88 | self.transform = transform
89 |
90 | def __getitem__(self, index:int):
91 | image = torch.randn(3,224,224)
92 | target = 1000
93 | # if self.transform is not None:
94 | # image = self.transform(image)
95 | if type(image) == list:
96 | image.append(index)
97 | else:
98 | image = [image, index]
99 |
100 | return image, target
101 |
102 | def __len__(self):
103 |
104 | return self.number
105 |
106 | class imageneta(torchvision.datasets.ImageFolder):  # must subclass torch.utils.data.Dataset
107 | def __init__(self, *args, **kwargs):
108 | super().__init__(*args, **kwargs)
109 | # self.is_carry_index = is_carry_index
110 |
111 | def __getitem__(self, index: int):
112 | img, target = super().__getitem__(index)
113 |
114 | if type(img) == list:
115 | img.append(index)
116 | else:
117 | img = [img, index]
118 | return img, target
119 |
120 |
121 | class MNIST_openset(torchvision.datasets.MNIST):
122 | def __init__(self, *args, ratio = 1 , **kwargs):
123 | super().__init__(*args, **kwargs)
124 | self.data, self.targets = self.data[:int(50000*ratio)], self.targets[:int(50000*ratio)]
125 | print(ratio)
126 | print(len(self.data))
127 | return
128 |
129 | def __getitem__(self, index: int):
130 | image, target = super().__getitem__(index)
131 | target = target + 1000
132 | if type(image) == list:
133 | image.append(index)
134 | else:
135 | image = [image, index]
136 | return image, target
137 |
138 |
139 | class SVHN_openset(torchvision.datasets.SVHN):
140 | def __init__(self, *args, ratio = 1 , **kwargs):
141 | super().__init__(*args, **kwargs)
142 | self.data, self.labels = self.data[:int(50000*ratio)], self.labels[:int(50000*ratio)]
143 | print(ratio)
144 | print(len(self.data))
145 | return
146 |
147 | def __getitem__(self, index: int):
148 | image, target = super().__getitem__(index)
149 | target = target + 1000
150 | if type(image) == list:
151 | image.append(index)
152 | else:
153 | image = [image, index]
154 | return image, target
155 |
156 |
157 | def prepare_ood_test_data_a(root, corruption_name="gaussian_noise", transform=None, is_carry_index=False, OOD = 'noise', OOD_transform=None):
158 |     teset_seen = imageneta(root=root+'/imagenet-a', transform=OOD_transform)
159 | print(len(teset_seen))
160 |     if OOD == 'noise':
161 |         teset_unseen = noise_dataset(transform, ratio=0.15)
162 |     elif OOD == 'SVHN':
163 |         teset_unseen = SVHN_openset(root=root,
164 |                                     split='train', download=True, transform=OOD_transform, ratio=0.15)
165 |     elif OOD == 'MNIST':
166 |         te_rize = transforms.Compose([transforms.Grayscale(3), OOD_transform])
167 |         teset_unseen = MNIST_openset(root=root,
168 |                                      train=True, download=True, transform=te_rize, ratio=0.15)
169 | teset = torch.utils.data.ConcatDataset([teset_seen,teset_unseen])
170 | return teset
171 |
172 | def prepare_ood_test_data_r(root, corruption_name="gaussian_noise", transform=None, is_carry_index=False, OOD = 'noise', OOD_transform=None):
173 |     teset_seen = imageneta(root=root+'/imagenet-r', transform=OOD_transform)
174 | print(len(teset_seen))
175 |     if OOD == 'noise':
176 |         teset_unseen = noise_dataset(transform, ratio=0.6)
177 |     elif OOD == 'SVHN':
178 |         teset_unseen = SVHN_openset(root=root,
179 |                                     split='train', download=True, transform=OOD_transform, ratio=0.6)
180 |     elif OOD == 'MNIST':
181 |         te_rize = transforms.Compose([transforms.Grayscale(3), OOD_transform])
182 |         teset_unseen = MNIST_openset(root=root,
183 |                                      train=True, download=True, transform=te_rize, ratio=0.6)
184 | teset = torch.utils.data.ConcatDataset([teset_seen,teset_unseen])
185 | return teset
186 |
187 | def prepare_test_data_r(root, corruption_name="gaussian_noise", transform=None, is_carry_index=False, OOD = 'noise', OOD_transform=None):
188 | teset_seen = imageneta(root=root+'/imagenet-r', transform=OOD_transform)
189 | return teset_seen
190 |
191 | def prepare_ood_test_data(root, corruption_name="gaussian_noise", transform=None, is_carry_index=False, OOD = 'noise', OOD_transform=None):
192 | teset_seen = ImageNetCorruption(root, corruption_name, transform=transform, is_carry_index=is_carry_index)
193 |     if OOD == 'noise':
194 |         teset_unseen = noise_dataset(transform)
195 |     elif OOD == 'SVHN':
196 |         teset_unseen = SVHN_openset(root=root,
197 |                                     split='train', download=True, transform=OOD_transform, ratio=1)
198 |     elif OOD == 'MNIST':
199 |         te_rize = transforms.Compose([transforms.Grayscale(3), OOD_transform])
200 |         teset_unseen = MNIST_openset(root=root,
201 |                                      train=True, download=True, transform=te_rize, ratio=1)
202 | teset = torch.utils.data.ConcatDataset([teset_seen,teset_unseen])
203 | return teset
204 |
--------------------------------------------------------------------------------
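
For orientation, the prepare_ood_* helpers above all follow one pattern: concatenate a seen split with a strong-OOD source whose labels sit past the last seen class. A hedged sketch of the same assembly using the noise source (./data is a placeholder root; ImageNetCorruption additionally expects an ImageNet val split and <root>/corruption/<name>.pth to exist, and /imagenet must be on PYTHONPATH):

import torch
from utils.prepare_dataset import prepare_transforms, ImageNetCorruption, noise_dataset

root = './data'  # placeholder path
_, val_transform, val_corrupt_transform = prepare_transforms()

seen = ImageNetCorruption(root, 'snow', transform=val_corrupt_transform, is_carry_index=True)
unseen = noise_dataset(val_transform)  # 50k random-noise images, all labeled 1000
teset = torch.utils.data.ConcatDataset([seen, unseen])
print(len(teset))  # 50000 seen + 50000 unseen samples
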
/cifar/utils/test_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from utils.misc import *
5 |
6 |
7 | def load_resnet50(net, head, ssh, classifier, args):
8 |
9 | filename = args.resume + '/ckpt.pth'
10 |
11 | ckpt = torch.load(filename)
12 | state_dict = ckpt['model']
13 |
14 | net_dict = {}
15 | head_dict = {}
16 | for k, v in state_dict.items():
17 | if k[:4] == "head":
18 | k = k.replace("head.", "")
19 | head_dict[k] = v
20 | else:
21 | k = k.replace("encoder.", "ext.")
22 | k = k.replace("fc.", "head.fc.")
23 | net_dict[k] = v
24 |
25 | net.load_state_dict(net_dict)
26 | head.load_state_dict(head_dict)
27 |
28 | print('Loaded model trained jointly on Classification and SimCLR:', filename)
29 |
30 | def load_resnet50_sne(net, head, ssh, classifier, args):
31 |
32 | filename = args.resume
33 | ckpt = torch.load(filename)
34 | net.load_state_dict(ckpt)
35 |
36 | print('Loaded model trained jointly on Classification and SimCLR:', filename)
37 |
38 | def load_robust_resnet50(net, head, ssh, classifier, args):
39 |
40 | filename = args.resume
41 | try:
42 | sd = torch.load(filename,map_location='cuda:0')['state_dict']
43 | except:
44 | sd = torch.load(filename)
45 | model_dict=net.state_dict()
46 |
47 | ckpt={}
48 | for k, v in sd.items():
49 | print(k)
50 | k = k.replace("backbone.",'')
51 | if k[:12] == "module.model":
52 |             if 'linear' not in k:
53 | k = k.replace("module.model.", "ext.")
54 | ckpt[k] = v
55 | else:
56 | k = k.replace("module.model.linear",'head.fc')
57 | ckpt[k] = v
58 | else:
59 | if not ('linear' in k or 'logits' in k):
60 |
61 |                 if 'fc' not in k:
62 | k = 'ext.'+k
63 | ckpt[k] = v
64 | else:
65 | k = 'head.'+k
66 | ckpt[k]=v
67 | else:
68 | k = k.replace("linear",'head.fc')
69 | k = k.replace("logits",'head.fc')
70 | ckpt[k] = v
71 | ckpt = {k: v for k, v in ckpt.items() if k in model_dict}
72 |
73 | net.load_state_dict(ckpt)
74 |
75 | print('Loaded robust model:', filename)
76 |
77 |
78 | def load_ttt(net, head, ssh, classifier, args, ttt=False):
79 | if ttt:
80 | filename = args.resume + '/{}_both_2_15.pth'.format(args.corruption)
81 | else:
82 | filename = args.resume + '/{}_both_15.pth'.format(args.corruption)
83 | ckpt = torch.load(filename)
84 | net.load_state_dict(ckpt['net'])
85 | head.load_state_dict(ckpt['head'])
86 | print('Loaded updated model from', filename)
87 |
88 |
89 | def corrupt_resnet50(ext, args):
90 | try:
91 | # SSL trained encoder
92 | simclr = torch.load(args.restore + '/simclr.pth')
93 | state_dict = simclr['model']
94 |
95 | ext_dict = {}
96 | for k, v in state_dict.items():
97 | if k[:7] == "encoder":
98 | k = k.replace("encoder.", "")
99 | ext_dict[k] = v
100 | ext.load_state_dict(ext_dict)
101 |
102 | print('Corrupted encoder trained by SimCLR')
103 |
104 | except:
105 | # Jointly trained encoder
106 | filename = args.resume + '/ckpt_epoch_{}.pth'.format(args.restore)
107 |
108 | ckpt = torch.load(filename)
109 | state_dict = ckpt['model']
110 |
111 | ext_dict = {}
112 | for k, v in state_dict.items():
113 | if k[:7] == "encoder":
114 | k = k.replace("encoder.", "")
115 | ext_dict[k] = v
116 | ext.load_state_dict(ext_dict)
117 |     print('Corrupted encoder jointly trained on Classification and SimCLR')
118 |
119 |
120 | def build_resnet50(args):
121 | from models.BigResNet import SupConResNet, LinearClassifier
122 | from models.SSHead import ExtractorHead
123 |
124 | print('Building ResNet50...')
125 | if args.dataset == 'cifar10+100' or args.dataset == 'cifar10OOD':
126 | classes = 10
127 | if args.dataset == 'cifar10':
128 | classes = 10
129 | elif args.dataset == 'cifar7':
130 | if not hasattr(args, 'modified') or args.modified:
131 | classes = 7
132 | else:
133 | classes = 10
134 | elif args.dataset == "cifar100" or args.dataset == "cifar100OOD":
135 | classes = 100
136 |
137 | classifier = LinearClassifier(num_classes=classes).cuda()
138 | ssh = SupConResNet().cuda()
139 | head = ssh.head
140 | ext = ssh.encoder
141 | net = ExtractorHead(ext, classifier).cuda()
142 | return net, ext, head, ssh, classifier
143 |
144 | def build_net(args):
145 | from models.BigResNet import SupConResNet, LinearClassifier
146 | from models.SSHead import ExtractorHead
147 | from models.dm import CIFAR10_MEAN, CIFAR10_STD, \
148 | DMWideResNet, Swish, DMPreActResNet
149 |
150 | print('Building '+args.net)
151 | if args.dataset == 'cifar10+100' or args.dataset == 'cifar10OOD':
152 | classes = 10
153 | if args.dataset == 'cifar10':
154 | classes = 10
155 | elif args.dataset == 'cifar7':
156 | if not hasattr(args, 'modified') or args.modified:
157 | classes = 7
158 | else:
159 | classes = 10
160 | elif args.dataset == "cifar100":
161 | classes = 100
162 |
163 | if args.net == "dm":
164 | ext=DMPreActResNet(num_classes=10,
165 | depth=18,
166 | width=0,
167 | activation_fn=Swish,
168 | mean=CIFAR10_MEAN,
169 | std=CIFAR10_STD)
170 |
171 | classifier = LinearClassifier(num_classes=classes,num_dim=ext.num_out).cuda()
172 | ssh = SupConResNet().cuda()
173 | elif args.net == "standard":
174 | from models.wide import WideResNet
175 | ext=WideResNet(depth=28, widen_factor=10)
176 |
177 | classifier = LinearClassifier(num_classes=classes,num_dim=ext.num_out).cuda()
178 | ssh = SupConResNet().cuda()
179 | elif args.net == "resnet18":
180 | classifier = LinearClassifier(num_classes=classes,num_dim=512).cuda()
181 | ssh = SupConResNet(name='resnet18').cuda()
182 | import torchvision.models as models
183 | ext = models.resnet18().cuda()
184 | ext.fc=nn.Sequential().cuda()
185 | else:
186 | classifier = LinearClassifier(num_classes=classes).cuda()
187 | ssh = SupConResNet().cuda()
188 | ext = ssh.encoder
189 |
190 | head = ssh.head
191 | #
192 | net = ExtractorHead(ext, classifier).cuda()
193 | return net, ext, head, ssh, classifier
194 |
195 |
196 | def build_model(args):
197 | from models.ResNet import ResNetCifar as ResNet
198 | from models.SSHead import ExtractorHead
199 | print('Building model...')
200 | if args.dataset == 'cifar10':
201 | classes = 10
202 | elif args.dataset == 'cifar7':
203 | if not hasattr(args, 'modified') or args.modified:
204 | classes = 7
205 | else:
206 | classes = 10
207 | elif args.dataset == "cifar100":
208 | classes = 100
209 |
210 | if args.group_norm == 0:
211 | norm_layer = nn.BatchNorm2d
212 | else:
213 | def gn_helper(planes):
214 | return nn.GroupNorm(args.group_norm, planes)
215 | norm_layer = gn_helper
216 |
217 | if hasattr(args, 'detach') and args.detach:
218 | detach = args.shared
219 | else:
220 | detach = None
221 | net = ResNet(args.depth, args.width, channels=3, classes=classes, norm_layer=norm_layer, detach=detach).cuda()
222 | if args.shared == 'none':
223 | args.shared = None
224 |
225 | if args.shared == 'layer3' or args.shared is None:
226 | from models.SSHead import extractor_from_layer3
227 | ext = extractor_from_layer3(net)
228 | if not hasattr(args, 'ssl') or args.ssl == 'rotation':
229 | head = nn.Linear(64 * args.width, 4)
230 | elif args.ssl == 'contrastive':
231 | head = nn.Sequential(
232 | nn.Linear(64 * args.width, 64 * args.width),
233 | nn.ReLU(inplace=True),
234 | nn.Linear(64 * args.width, 16 * args.width)
235 | )
236 | else:
237 | raise NotImplementedError
238 | elif args.shared == 'layer2':
239 | from models.SSHead import extractor_from_layer2, head_on_layer2
240 | ext = extractor_from_layer2(net)
241 | head = head_on_layer2(net, args.width, 4)
242 | ssh = ExtractorHead(ext, head).cuda()
243 |
244 | if hasattr(args, 'parallel') and args.parallel:
245 | net = torch.nn.DataParallel(net)
246 | ssh = torch.nn.DataParallel(ssh)
247 | return net, ext, head, ssh
248 |
249 |
250 | def test(dataloader, model, **kwargs):
251 | criterion = nn.CrossEntropyLoss(reduction='none').cuda()
252 | model.eval()
253 | correct = []
254 | losses = []
255 | for batch_idx, (inputs, labels) in enumerate(dataloader):
256 | if type(inputs) == list:
257 | inputs = inputs[0]
258 | inputs, labels = inputs.cuda(), labels.cuda()
259 | with torch.no_grad():
260 | outputs = model(inputs, **kwargs)
261 | _, predicted = outputs.max(1)
262 | correct.append(predicted.eq(labels).cpu())
263 | correct = torch.cat(correct).numpy()
264 | model.train()
265 | return 1-correct.mean(), correct, losses
266 |
267 | def prototype_test(dataloader, ext,prototype, **kwargs):
268 | # criterion = nn.CrossEntropyLoss(reduction='none').cuda()
269 | ext.eval()
270 | correct = []
271 | losses = []
272 | for batch_idx, (inputs, labels) in enumerate(dataloader):
273 | if type(inputs) == list:
274 | inputs = inputs[0]
275 | inputs, labels = inputs.cuda(), labels.cuda()
276 | with torch.no_grad():
277 | feat = ext(inputs, **kwargs)
278 | outputs = torch.mm(torch.nn.functional.normalize(feat), prototype.t())
279 | _, predicted = outputs.max(1)
280 | correct.append(predicted.eq(labels).cpu())
281 | correct = torch.cat(correct).numpy()
282 | ext.train()
283 | return 1-correct.mean(), correct, losses
284 |
285 |
286 | def pair_buckets(o1, o2):
287 | crr = np.logical_and( o1, o2 )
288 | crw = np.logical_and( o1, np.logical_not(o2) )
289 | cwr = np.logical_and( np.logical_not(o1), o2 )
290 | cww = np.logical_and( np.logical_not(o1), np.logical_not(o2) )
291 | return crr, crw, cwr, cww
292 |
293 |
294 | def count_each(tuple):
295 | return [item.sum() for item in tuple]
296 |
297 |
298 | def plot_epochs(all_err_cls, all_err_ssh, fname, use_agg=True):
299 | import matplotlib.pyplot as plt
300 | if use_agg:
301 | plt.switch_backend('agg')
302 |
303 | plt.plot(np.asarray(all_err_cls)*100, color='r', label='classifier')
304 | plt.plot(np.asarray(all_err_ssh)*100, color='b', label='self-supervised')
305 | plt.xlabel('epoch')
306 | plt.ylabel('test error (%)')
307 | plt.legend()
308 | plt.savefig(fname)
309 | plt.close()
310 |
311 |
312 | @torch.jit.script
313 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
314 | """Entropy of softmax distribution from logits."""
315 | return -(x.softmax(1) * x.log_softmax(1)).sum(1)
316 |
317 |
--------------------------------------------------------------------------------
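
prototype_test above reduces classification to cosine similarity against class prototypes, and softmax_entropy turns the resulting logits into a per-sample confidence signal. A self-contained illustration with random stand-in features (no repo checkpoint needed):

import torch
import torch.nn.functional as F

feat = torch.randn(4, 2048)                      # stand-in extractor output
prototype = F.normalize(torch.randn(10, 2048))   # one unit vector per class

logits = torch.mm(F.normalize(feat), prototype.t())  # cosine similarities in [-1, 1]
_, predicted = logits.max(1)
entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)  # as in softmax_entropy
print(predicted, entropy)  # low entropy = peaked, confident prediction
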
/cifar/models/dm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Deepmind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """WideResNet implementation in PyTorch. From:
16 | https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/pytorch/model_zoo.py
17 | """
18 |
19 | from typing import Tuple, Type, Union
20 |
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 |
25 | CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
26 | CIFAR10_STD = (0.2471, 0.2435, 0.2616)
27 | CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
28 | CIFAR100_STD = (0.2673, 0.2564, 0.2762)
29 |
30 |
31 | class _Swish(torch.autograd.Function):
32 | """Custom implementation of swish."""
33 |
34 | @staticmethod
35 | def forward(ctx, i):
36 | result = i * torch.sigmoid(i)
37 | ctx.save_for_backward(i)
38 | return result
39 |
40 | @staticmethod
41 | def backward(ctx, grad_output):
42 |         i = ctx.saved_tensors[0]
43 | sigmoid_i = torch.sigmoid(i)
44 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
45 |
46 |
47 | class Swish(nn.Module):
48 | """Module using custom implementation."""
49 |
50 | def forward(self, input_tensor):
51 | return _Swish.apply(input_tensor)
52 |
53 |
54 | class _Block(nn.Module):
55 | """WideResNet Block."""
56 |
57 | def __init__(self,
58 | in_planes,
59 | out_planes,
60 | stride,
61 | activation_fn: Type[nn.Module] = nn.ReLU):
62 | super().__init__()
63 | self.batchnorm_0 = nn.BatchNorm2d(in_planes)
64 | self.relu_0 = activation_fn()
65 | # We manually pad to obtain the same effect as `SAME` (necessary when
66 | # `stride` is different than 1).
67 | self.conv_0 = nn.Conv2d(in_planes,
68 | out_planes,
69 | kernel_size=3,
70 | stride=stride,
71 | padding=0,
72 | bias=False)
73 | self.batchnorm_1 = nn.BatchNorm2d(out_planes)
74 | self.relu_1 = activation_fn()
75 | self.conv_1 = nn.Conv2d(out_planes,
76 | out_planes,
77 | kernel_size=3,
78 | stride=1,
79 | padding=1,
80 | bias=False)
81 | self.has_shortcut = in_planes != out_planes
82 | if self.has_shortcut:
83 | self.shortcut = nn.Conv2d(in_planes,
84 | out_planes,
85 | kernel_size=1,
86 | stride=stride,
87 | padding=0,
88 | bias=False)
89 | else:
90 | self.shortcut = None
91 | self._stride = stride
92 |
93 | def forward(self, x):
94 | if self.has_shortcut:
95 | x = self.relu_0(self.batchnorm_0(x))
96 | else:
97 | out = self.relu_0(self.batchnorm_0(x))
98 | v = x if self.has_shortcut else out
99 | if self._stride == 1:
100 | v = F.pad(v, (1, 1, 1, 1))
101 | elif self._stride == 2:
102 | v = F.pad(v, (0, 1, 0, 1))
103 | else:
104 | raise ValueError('Unsupported `stride`.')
105 | out = self.conv_0(v)
106 | out = self.relu_1(self.batchnorm_1(out))
107 | out = self.conv_1(out)
108 | out = torch.add(self.shortcut(x) if self.has_shortcut else x, out)
109 | return out
110 |
111 |
112 | class _BlockGroup(nn.Module):
113 | """WideResNet block group."""
114 |
115 | def __init__(self,
116 | num_blocks,
117 | in_planes,
118 | out_planes,
119 | stride,
120 | activation_fn: Type[nn.Module] = nn.ReLU):
121 | super().__init__()
122 | block = []
123 | for i in range(num_blocks):
124 | block.append(
125 | _Block(i == 0 and in_planes or out_planes,
126 | out_planes,
127 | i == 0 and stride or 1,
128 | activation_fn=activation_fn))
129 | self.block = nn.Sequential(*block)
130 |
131 | def forward(self, x):
132 | return self.block(x)
133 |
134 |
135 | class DMWideResNet(nn.Module):
136 | """WideResNet."""
137 |
138 | def __init__(self,
139 | num_classes: int = 10,
140 | depth: int = 28,
141 | width: int = 10,
142 | activation_fn: Type[nn.Module] = nn.ReLU,
143 | mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN,
144 | std: Union[Tuple[float, ...], float] = CIFAR10_STD,
145 | padding: int = 0,
146 | num_input_channels: int = 3):
147 | super().__init__()
148 | # persistent=False to not put these tensors in the module's state_dict and not try to
149 | # load it from the checkpoint
150 | self.register_buffer('mean', torch.tensor(mean).view(num_input_channels, 1, 1),
151 | persistent=False)
152 | self.register_buffer('std', torch.tensor(std).view(num_input_channels, 1, 1),
153 | persistent=False)
154 | self.padding = padding
155 | num_channels = [16, 16 * width, 32 * width, 64 * width]
156 | self.num_out=num_channels[3]
157 | assert (depth - 4) % 6 == 0
158 | num_blocks = (depth - 4) // 6
159 | self.init_conv = nn.Conv2d(num_input_channels,
160 | num_channels[0],
161 | kernel_size=3,
162 | stride=1,
163 | padding=1,
164 | bias=False)
165 | self.layer = nn.Sequential(
166 | _BlockGroup(num_blocks,
167 | num_channels[0],
168 | num_channels[1],
169 | 1,
170 | activation_fn=activation_fn),
171 | _BlockGroup(num_blocks,
172 | num_channels[1],
173 | num_channels[2],
174 | 2,
175 | activation_fn=activation_fn),
176 | _BlockGroup(num_blocks,
177 | num_channels[2],
178 | num_channels[3],
179 | 2,
180 | activation_fn=activation_fn))
181 | self.batchnorm = nn.BatchNorm2d(num_channels[3])
182 | self.relu = activation_fn()
183 | # self.logits = nn.Linear(num_channels[3], num_classes)
184 | self.num_channels = num_channels[3]
185 |
186 | def forward(self, x):
187 | if self.padding > 0:
188 | x = F.pad(x, (self.padding,) * 4)
189 | out = (x - self.mean) / self.std
190 | out = self.init_conv(out)
191 | out = self.layer(out)
192 | out = self.relu(self.batchnorm(out))
193 | out = F.avg_pool2d(out, 8)
194 | out = out.view(-1, self.num_channels)
195 | return out
196 |
197 |
198 | class _PreActBlock(nn.Module):
199 | """Pre-activation ResNet Block."""
200 |
201 | def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU):
202 | super().__init__()
203 | self._stride = stride
204 | self.batchnorm_0 = nn.BatchNorm2d(in_planes)
205 | self.relu_0 = activation_fn()
206 | # We manually pad to obtain the same effect as `SAME` (necessary when
207 | # `stride` is different than 1).
208 | self.conv_2d_1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
209 | stride=stride, padding=0, bias=False)
210 | self.batchnorm_1 = nn.BatchNorm2d(out_planes)
211 | self.relu_1 = activation_fn()
212 | self.conv_2d_2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
213 | padding=1, bias=False)
214 | self.has_shortcut = stride != 1 or in_planes != out_planes
215 | if self.has_shortcut:
216 | self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=3,
217 | stride=stride, padding=0, bias=False)
218 |
219 | def _pad(self, x):
220 | if self._stride == 1:
221 | x = F.pad(x, (1, 1, 1, 1))
222 | elif self._stride == 2:
223 | x = F.pad(x, (0, 1, 0, 1))
224 | else:
225 | raise ValueError('Unsupported `stride`.')
226 | return x
227 |
228 | def forward(self, x):
229 | out = self.relu_0(self.batchnorm_0(x))
230 | shortcut = self.shortcut(self._pad(x)) if self.has_shortcut else x
231 | out = self.conv_2d_1(self._pad(out))
232 | out = self.conv_2d_2(self.relu_1(self.batchnorm_1(out)))
233 | return out + shortcut
234 |
235 |
236 | class DMPreActResNet(nn.Module):
237 | """Pre-activation ResNet."""
238 |
239 | def __init__(self,
240 | num_classes: int = 10,
241 | depth: int = 18,
242 | width: int = 0, # Used to make the constructor consistent.
243 | activation_fn: Type[nn.Module] = nn.ReLU,
244 | mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN,
245 | std: Union[Tuple[float, ...], float] = CIFAR10_STD,
246 | padding: int = 0,
247 | num_input_channels: int = 3,
248 | use_cuda: bool = True):
249 | super().__init__()
250 | if width != 0:
251 | raise ValueError('Unsupported `width`.')
252 | # persistent=False to not put these tensors in the module's state_dict and not try to
253 | # load it from the checkpoint
254 | self.register_buffer('mean', torch.tensor(mean).view(num_input_channels, 1, 1),
255 | persistent=False)
256 | self.register_buffer('std', torch.tensor(std).view(num_input_channels, 1, 1),
257 | persistent=False)
258 | self.mean_cuda = None
259 | self.std_cuda = None
260 | self.padding = padding
261 | self.conv_2d = nn.Conv2d(num_input_channels, 64, kernel_size=3, stride=1,
262 | padding=1, bias=False)
263 | if depth == 18:
264 | num_blocks = (2, 2, 2, 2)
265 | elif depth == 34:
266 | num_blocks = (3, 4, 6, 3)
267 | else:
268 | raise ValueError('Unsupported `depth`.')
269 | self.layer_0 = self._make_layer(64, 64, num_blocks[0], 1, activation_fn)
270 | self.layer_1 = self._make_layer(64, 128, num_blocks[1], 2, activation_fn)
271 | self.layer_2 = self._make_layer(128, 256, num_blocks[2], 2, activation_fn)
272 | self.layer_3 = self._make_layer(256, 512, num_blocks[3], 2, activation_fn)
273 | self.batchnorm = nn.BatchNorm2d(512)
274 | self.relu = activation_fn()
275 | self.num_out = 512
276 | # self.logits = nn.Linear(512, num_classes)
277 |
278 | def _make_layer(self, in_planes, out_planes, num_blocks, stride,
279 | activation_fn):
280 | layers = []
281 | for i, stride in enumerate([stride] + [1] * (num_blocks - 1)):
282 | layers.append(
283 | _PreActBlock(i == 0 and in_planes or out_planes,
284 | out_planes,
285 | stride,
286 | activation_fn))
287 | return nn.Sequential(*layers)
288 |
289 | def forward(self, x):
290 | if self.padding > 0:
291 | x = F.pad(x, (self.padding,) * 4)
292 | out = (x - self.mean) / self.std
293 | out = self.conv_2d(out)
294 | out = self.layer_0(out)
295 | out = self.layer_1(out)
296 | out = self.layer_2(out)
297 | out = self.layer_3(out)
298 | out = self.relu(self.batchnorm(out))
299 | out = F.avg_pool2d(out, 4)
300 | out = out.view(out.size(0), -1)
301 | # return self.logits(out)
302 | return out
--------------------------------------------------------------------------------
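
Because the logits layer is commented out and num_out is exposed, both DeepMind networks above act as bare feature extractors, which is how build_net in cifar/utils/test_helpers.py consumes them. A small CPU sketch (assumes /cifar is on PYTHONPATH):

import torch
from models.dm import DMPreActResNet, Swish, CIFAR10_MEAN, CIFAR10_STD

ext = DMPreActResNet(num_classes=10, depth=18, width=0, activation_fn=Swish,
                     mean=CIFAR10_MEAN, std=CIFAR10_STD)
x = torch.randn(2, 3, 32, 32)
feat = ext(x)
print(feat.shape, ext.num_out)  # torch.Size([2, 512]) 512
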
/imagenet/utils/test_helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from model.resnet import SupCEResNet
3 | import os
4 |
5 | def build_model():
6 | print("Building ResNet50...")
7 | model = SupCEResNet().cuda()
8 | ext = model.encoder
9 | classifier = model.fc
10 | return model, ext, classifier
11 |
12 |
13 | def test(dataloader, model, **kwargs):
14 | model.eval()
15 | correct = []
16 | for batch_idx, (inputs, labels) in enumerate(dataloader):
17 | if type(inputs) == list:
18 | inputs = inputs[0]
19 | inputs, labels = inputs.cuda(), labels.cuda()
20 | with torch.no_grad():
21 | outputs = model(inputs, **kwargs)
22 | _, predicted = outputs.max(1)
23 | correct.append(predicted.eq(labels).cpu())
24 | correct = torch.cat(correct).numpy()
25 | model.train()
26 | return 1-correct.mean(), correct
27 |
28 | def my_makedir(name):
29 | try:
30 | os.makedirs(name)
31 | except OSError:
32 | pass
33 |
34 | all_wnids = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 
'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 
'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 
'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141']
35 |
36 | imagenet_r_wnids = {'n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214', 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', 'n02437616', 'n02445715', 'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677'}
--------------------------------------------------------------------------------
/cifar/OURS.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.optim as optim
4 | import torch.utils.data as data
5 |
6 | from utils.misc import *
7 | from utils.test_helpers import *
8 | from utils.prepare_dataset import *
9 |
10 | # ----------------------------------
11 | import copy
12 | import random
13 | import numpy as np
14 | from utils.contrastive import *
15 | from utils.offline import *
16 | from torch import nn
17 | import torch.nn.functional as F
18 | # ----------------------------------
19 |
20 |
21 | def compute_os_variance(os, th):
22 | """
23 |     Compute the weighted within-class variance of the OOD score queue at a given threshold (Otsu's criterion).
24 |
25 | Parameters:
26 | os : OOD score queue.
27 | th : Given threshold to separate weak and strong OOD samples.
28 |
29 | Returns:
30 | float: Weighted variance at the given threshold th.
31 | """
32 |
33 | thresholded_os = np.zeros(os.shape)
34 | thresholded_os[os >= th] = 1
35 |
36 | # compute weights
37 | nb_pixels = os.size
38 | nb_pixels1 = np.count_nonzero(thresholded_os)
39 | weight1 = nb_pixels1 / nb_pixels
40 | weight0 = 1 - weight1
41 |
42 |     # if one of the classes is empty, e.g. all samples fall below or above the threshold, that threshold will not be considered
43 |     # in the search for the best threshold
44 | if weight1 == 0 or weight0 == 0:
45 | return np.inf
46 |
47 | # find all pixels belonging to each class
48 | val_pixels1 = os[thresholded_os == 1]
49 | val_pixels0 = os[thresholded_os == 0]
50 |
51 | # compute variance of these classes
52 | var0 = np.var(val_pixels0) if len(val_pixels0) > 0 else 0
53 | var1 = np.var(val_pixels1) if len(val_pixels1) > 0 else 0
54 |
55 | return weight0 * var0 + weight1 * var1
56 |
57 |
58 |
59 |
60 | class Prototype_Pool(nn.Module):
61 |
62 | """
63 | Prototype pool containing strong OOD prototypes.
64 |
65 | Methods:
66 |     __init__: Initialize the prototype pool, storing delta, the number of weak OOD categories, and the maximum number of strong OOD prototypes.
67 |     forward: Forward pass; returns the cosine similarity with the strong OOD prototypes.
68 |     update_pool: Append a new strong OOD prototype, evicting the oldest one when the pool is full.
69 | """
70 |
71 |
72 | def __init__(self, delta=0.1, class_num=10, max=100):
73 | super(Prototype_Pool, self).__init__()
74 |
75 | self.class_num=class_num
76 | self.max_length = max
77 | self.flag = 0
78 | self.delta = delta
79 |
80 |
81 | def forward(self, x, all=False):
82 |
83 | # if the flag is 0, the prototype pool is empty, return None.
84 | if not self.flag:
85 | return None
86 |
87 | # compute the cosine similarity between the features and the strong OOD prototypes.
88 | out = torch.mm(x, self.memory.t())
89 |
90 |         if all:
91 | # if all is True, return the cosine similarity with all the strong OOD prototypes.
92 | return out
93 | else:
94 | # if all is False, return the cosine similarity with the nearest strong OOD prototype.
95 | return torch.max(out/(self.delta),dim=1)[0].unsqueeze(1)
96 |
97 |
98 | def update_pool(self, feature):
99 |
100 | if not self.flag:
101 | # if the flag is 0, the prototype pool is empty, use the feature to init the prototype pool.
102 | self.register_buffer('memory', feature.detach())
103 | self.flag = 1
104 | else:
105 | if self.memory.shape[0] < self.max_length:
106 | # if the number of strong OOD prototypes is less than the maximum count of strong OOD prototypes, append the feature to the prototype pool.
107 | self.memory = torch.cat([self.memory, feature.detach()],dim=0)
108 | else:
109 |                 # otherwise, delete the earliest-appended strong OOD prototype and append the feature to the prototype pool.
110 | self.memory = torch.cat([self.memory[1:], feature.detach()],dim=0)
111 | self.memory = F.normalize(self.memory)
112 |
113 |
114 | def append_prototypes(pool, feat_ext, logit, ts, ts_pro):
115 | """
116 | Append strong OOD prototypes to the prototype pool.
117 |
118 | Parameters:
119 | pool : Prototype pool.
120 | feat_ext : Normalized features of the input images.
121 | logit : Cosine similarity between the features and the weak OOD prototypes.
122 | ts : Threshold to separate weak and strong OOD samples.
123 | ts_pro : Threshold to append strong OOD prototypes.
124 |
125 | """
126 | added_list=[]
127 | update = 1
128 |
129 | while update:
130 | feat_mat = pool(F.normalize(feat_ext),all=True)
131 |         if feat_mat is not None:
132 | new_logit = torch.cat([logit, feat_mat], 1)
133 | else:
134 | new_logit = logit
135 |
136 | r_i_pro, _ = new_logit.max(dim=-1)
137 |
138 | r_i, _ = logit.max(dim=-1)
139 |
140 |         if added_list:
141 | for add in added_list:
142 |                 # if added_list is not empty, set the similarity score of already-added features to 1 so they are not selected and appended to the prototype pool again.
143 | r_i[add]=1
144 | min_logit , min_index = r_i.min(dim=0)
145 |
146 |
147 | if (1-min_logit) > ts :
148 |             # if the OOD score (1 - max cosine similarity to the weak OOD prototypes) exceeds the threshold ts, the feature is a strong OOD sample.
149 | added_list.append(min_index)
150 | if (1-r_i_pro[min_index]) > ts_pro:
151 | # if this strong OOD sample is far away from all the strong OOD prototypes, append it to the prototype pool.
152 | pool.update_pool(F.normalize(feat_ext[min_index].unsqueeze(0)))
153 | else:
154 |             # no remaining feature qualifies as a strong OOD sample; stop the loop.
155 | update=0
156 |
157 |
158 | parser = argparse.ArgumentParser()
159 | parser.add_argument('--dataset', default='cifar10OOD')
160 | parser.add_argument('--strong_OOD', default='noise')
161 | parser.add_argument('--strong_ratio', default=1, type=float)
162 | parser.add_argument('--dataroot', default="./data", help='path to dataset')
163 | parser.add_argument('--batch_size', default=256, type=int)
164 | parser.add_argument('--workers', default=4, type=int)
165 | parser.add_argument('--lr', default=0.001, type=float)
166 | parser.add_argument('--delta', default=0.1, type=float)
167 | parser.add_argument('--ce_scale', default=0, type=float, help='cross entropy loss scale')
168 | parser.add_argument('--outf', help='folder to output log')
169 | parser.add_argument('--level', default=5, type=int)
170 | parser.add_argument('--N_m', default=512, type=int, help='queue length')
171 | parser.add_argument('--corruption', default='snow')
172 | parser.add_argument('--resume', default='./results/cifar10_joint_resnet50', help='directory of pretrained model')
173 | parser.add_argument('--da_scale', default=1, type=float, help='distribution alignment loss scale')
174 | parser.add_argument('--model', default='resnet50', help='resnet50')
175 | parser.add_argument('--seed', default=0, type=int)
176 | parser.add_argument('--max_prototypes', default=100, type=int)
177 | parser.add_argument('--save', action='store_true', default=False, help='save the model final checkpoint')
178 |
179 |
180 | # ----------- Args and Dataloader ------------
181 | args = parser.parse_args()
182 |
183 | print(args)
184 | print('\n')
185 |
186 |
187 |
188 |
189 | class_num = 10 if args.dataset == 'cifar10OOD' else 100
190 |
191 | net, ext, head, ssh, classifier = build_resnet50(args)
192 |
193 | teset, _ = prepare_test_data(args)
194 | teloader = data.DataLoader(teset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, worker_init_fn=seed_worker, pin_memory=True, drop_last=False)
195 |
196 | pool = Prototype_Pool(args.delta,class_num=class_num,max = args.max_prototypes).cuda()
197 |
198 | # -------------------------------
199 | print('Resuming from %s...' %(args.resume))
200 |
201 | load_resnet50(net, head, ssh, classifier, args)
202 |
203 | optimizer = optim.SGD(ext.parameters(), lr=args.lr, momentum=0.9)
204 |
205 | # ----------- Offline Feature Summarization ------------
206 | args_align = copy.deepcopy(args)
207 |
208 | _, offlineloader = prepare_train_data(args_align)
209 | ext_src_mu, ext_src_cov, ssh_src_mu, ssh_src_cov, mu_src_ext, cov_src_ext, mu_src_ssh, cov_src_ssh = offline(args,offlineloader, ext, classifier, head, class_num)
210 |
211 | ext_src_mu = torch.stack(ext_src_mu)
212 | ext_src_cov = torch.stack(ext_src_cov)
213 |
214 | ema_ext_mu = ext_src_mu.clone()
215 | ema_ext_cov = ext_src_cov.clone()
216 | ema_ext_total_mu = torch.zeros(2048).float()
217 | ema_ext_total_cov = torch.zeros(2048, 2048).float()
218 |
219 | if class_num == 10:
220 | loss_scale = 0.05
221 | ema_length = 128
222 | else:
223 | loss_scale = 0.05
224 | ema_length = 64
225 |
226 | ema_n = torch.zeros(class_num).cuda()
227 | ema_total_n = 0.
228 | weak_prototype = F.normalize(ext_src_mu.clone()).cuda()
229 | args.ts_pro = 0.0
230 | bias = cov_src_ext.max().item() / 30.
231 | template_ext_cov = torch.eye(2048).cuda() * bias
232 |
233 | torch.manual_seed(args.seed)
234 | random.seed(args.seed)
235 | np.random.seed(args.seed)
236 | torch.cuda.manual_seed(args.seed)
237 | torch.cuda.manual_seed_all(args.seed)
238 |
239 | # ----------- Open-World Test-time Training ------------
240 |
241 | correct = []
242 | unseen_correct= []
243 | all_correct=[]
244 | cumulative_error = []
245 | num_open = 0
246 | predicted_list=[]
247 | label_list=[]
248 |
249 | os_training_queue = []
250 | os_inference_queue = []
251 | queue_length = args.N_m
252 | ce_scale = args.ce_scale
253 |
254 | ema_total_n = 0.
255 |
256 | print('\n-----Test-Time Training with OURS-----')
257 | for te_idx, (te_inputs, te_labels) in enumerate(teloader):
258 | classifier.eval()
259 | ext.eval()
260 |
261 | optimizer.zero_grad()
262 | loss = torch.tensor(0.).cuda()
263 |
264 | if isinstance(te_inputs,list):
265 | inputs = te_inputs[0].cuda()
266 | else:
267 | inputs = te_inputs.cuda()
268 |
269 | # features extracted by backbone
270 | feat_ext = ext(inputs)
271 |
272 | # logits of the input images, used to compute the cosine similarity between the features and the weak OOD prototypes.
273 | logit = torch.mm(F.normalize(feat_ext), weak_prototype.t()) / args.delta
274 |
275 |
276 | # compute the cosine similarity between the features and the strong OOD prototypes.
277 | feat_mat = pool(F.normalize(feat_ext))
278 |     if feat_mat is not None:
279 | new_logit = torch.cat([logit, feat_mat], 1)
280 | else:
281 | new_logit = logit
282 |
283 | pro, predicted = new_logit[:,:class_num].max(dim=-1)
284 |
285 | # compute the ood score of the input images.
286 | ood_score = 1-pro*args.delta
287 | os_training_queue.extend(ood_score.detach().cpu().tolist())
288 | os_training_queue = os_training_queue[-queue_length:]
289 |
290 |
291 | threshold_range = np.arange(0,1,0.01)
292 | criterias = [compute_os_variance(np.array(os_training_queue), th) for th in threshold_range]
293 |
294 | # best threshold is the one minimizing the variance of the two classes
295 | best_threshold = threshold_range[np.argmin(criterias)]
296 | args.ts = best_threshold
297 | seen_mask = (ood_score < args.ts)
298 | unseen_mask = (ood_score >= args.ts)
299 | r_i, pseudo_labels = new_logit.max(dim=-1)
300 |
301 | if unseen_mask.sum().item()!=0:
302 | #compute ts_pro to append new strong OOD prototypes to the prototype pool.
303 |
304 | min_logit , min_index = r_i.min(dim=0)
305 |
306 | in_score = 1-r_i*args.delta
307 | threshold_range = np.arange(0,1,0.01)
308 | criterias = [compute_os_variance(in_score[unseen_mask].detach().cpu().numpy(), th) for th in threshold_range]
309 |
310 | best_threshold = threshold_range[np.argmin(criterias)]
311 | args.ts_pro = best_threshold
312 |
313 | # append new strong OOD prototypes to the prototype pool.
314 | append_prototypes(pool, feat_ext, logit.detach()*args.delta, args.ts, args.ts_pro)
315 |
316 | len_memory = len(new_logit[0])
317 |
318 |
319 | if len_memory!=class_num:
320 |
321 | if seen_mask.sum().item()!=0:
322 | pseudo_labels[seen_mask] = new_logit[seen_mask,:class_num].softmax(dim=-1).max(dim=-1)[1]
323 | if unseen_mask.sum().item()!=0:
324 | pseudo_labels[unseen_mask] = class_num
325 | else:
326 | pseudo_labels = new_logit[seen_mask,:class_num].softmax(dim=-1).max(dim=-1)[1]
327 |
328 |
329 |     # ------ distribution alignment ------
330 | if seen_mask.sum().item()!=0:
331 | ext.train()
332 | feat_global = ext(inputs[seen_mask])
333 | # Global Gaussian
334 | b = feat_global.shape[0]
335 | ema_total_n += b
336 | alpha = 1. / 1280 if ema_total_n > 1280 else 1. / ema_total_n
337 | delta_pre = (feat_global - ema_ext_total_mu.cuda())
338 | delta = alpha * delta_pre.sum(dim=0)
339 | tmp_mu = ema_ext_total_mu.cuda() + delta
340 | tmp_cov = ema_ext_total_cov.cuda() + alpha * (delta_pre.t() @ delta_pre - b * ema_ext_total_cov.cuda()) - delta[:, None] @ delta[None, :]
341 | with torch.no_grad():
342 | ema_ext_total_mu = tmp_mu.detach().cpu()
343 | ema_ext_total_cov = tmp_cov.detach().cpu()
344 |
345 | source_domain = torch.distributions.MultivariateNormal(mu_src_ext, cov_src_ext + template_ext_cov)
346 | target_domain = torch.distributions.MultivariateNormal(tmp_mu, tmp_cov + template_ext_cov)
347 | loss += args.da_scale*(torch.distributions.kl_divergence(source_domain, target_domain) + torch.distributions.kl_divergence(target_domain, source_domain)) * loss_scale
348 |
349 |
350 |     # for each batch, only the 50% of samples whose OOD scores lie farthest from the threshold τ∗ are used for prototype clustering
351 | if len_memory!=class_num and seen_mask.sum().item()!=0 and unseen_mask.sum().item()!=0:
352 | a, idx1 = torch.sort((ood_score[seen_mask]), descending=True)
353 | filter_down = a[-int(seen_mask.sum().item()*(1/2))]
354 | a, idx1 = torch.sort((ood_score[unseen_mask]), descending=True)
355 | filter_up= a[int(unseen_mask.sum().item()*(1/2))]
356 | for j in range(len(pseudo_labels)):
357 |
358 | if ood_score[j] >=filter_down and seen_mask[j]:
359 | seen_mask[j]=False
360 | if ood_score[j] <=filter_up and unseen_mask[j]:
361 | unseen_mask[j]=False
362 |
363 |
364 | if len_memory!=class_num:
365 | entropy_seen = nn.CrossEntropyLoss()(new_logit[seen_mask,:class_num],pseudo_labels[seen_mask])
366 | entropy_unseen= nn.CrossEntropyLoss()(new_logit[unseen_mask],pseudo_labels[unseen_mask])
367 | loss += ce_scale*(entropy_seen+ entropy_unseen)/2
368 |
369 | try:
370 | loss.backward()
371 | optimizer.step()
372 | optimizer.zero_grad()
373 | except:
374 |         print('cannot backward')
375 | torch.cuda.empty_cache()
376 |
377 |
378 |
379 | ####-------------------------- Test ----------------------------####
380 |
381 | with torch.no_grad():
382 |
383 | net.eval()
384 | feat_ext = ext(inputs) #b,2048
385 | logit = torch.mm(F.normalize(feat_ext), weak_prototype.t())/args.delta
386 | update = 1
387 |
388 |
389 | softmax_logit = logit.softmax(dim=-1)
390 | # _, recall_predicted = softmax_logit.max(1)
391 | pro, predicted = softmax_logit.max(dim=-1)
392 |
393 | ood_score, max_index = logit.max(1)
394 | ood_score = 1-ood_score*args.delta
395 | os_inference_queue.extend(ood_score.detach().cpu().tolist())
396 | os_inference_queue = os_inference_queue[-queue_length:]
397 |
398 | threshold_range = np.arange(0,1,0.01)
399 | criterias = [compute_os_variance(np.array(os_inference_queue), th) for th in threshold_range]
400 | best_threshold = threshold_range[np.argmin(criterias)]
401 | unseen_mask = (ood_score > best_threshold)
402 | args.ts = best_threshold
403 | predicted[unseen_mask] = class_num
404 |
405 | one = torch.ones_like(te_labels)*class_num
406 | false = torch.ones_like(te_labels)*-1
407 | predicted = torch.where(predicted>class_num-1, one.cuda(), predicted)
408 | all_labels = torch.where(te_labels>class_num-1, one, te_labels)
409 | seen_labels = torch.where(te_labels>class_num-1, false, te_labels)
410 | unseen_labels = torch.where(te_labels>class_num-1, one, false)
411 | correct.append(predicted.cpu().eq(seen_labels))
412 | unseen_correct.append(predicted.cpu().eq(unseen_labels))
413 | all_correct.append(predicted.cpu().eq(all_labels))
414 |         num_open += torch.gt(te_labels, class_num-1).sum()
415 |
416 | predicted_list.append(predicted.long().cpu())
417 | label_list.append(all_labels.long().cpu())
418 |
419 |
420 | seen_acc = round(torch.cat(correct).numpy().sum() / (len(torch.cat(correct).numpy())-num_open.numpy()),4)
421 | unseen_acc = round(torch.cat(unseen_correct).numpy().sum() / num_open.numpy(),4)
422 | h_score = round((2*seen_acc*unseen_acc) / (seen_acc + unseen_acc),4)
423 | print('Batch:(', te_idx,'/',len(teloader), ')\tloss:',"%.2f" % loss.item(),\
424 | '\t Cumulative Results: ACC_S:', seen_acc,\
425 | '\tACC_N:', unseen_acc,\
426 | '\tACC_H:',h_score\
427 | )
428 |
429 |
430 | print('\nTest time training result:',' ACC_S:', seen_acc,\
431 | '\tACC_N:', unseen_acc,\
432 | '\tACC_H:',h_score,'\n\n\n\n'\
433 | )
434 |
435 |
436 | if args.outf is not None:
437 | my_makedir(args.outf)
438 | with open (args.outf+'/results.txt','a') as f:
439 | f.write(str(args)+'\n')
440 | f.write(
441 | 'ACC_S:'+ str(seen_acc)+\
442 | '\tACC_N:'+ str(unseen_acc)+\
443 | '\tACC_H:'+str(h_score)+'\n\n\n\n'\
444 | )
445 | if args.save:
446 | torch.save(net.state_dict(), os.path.join(args.outf, 'final.pth'))
--------------------------------------------------------------------------------
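Note on the adaptive threshold used throughout OURS.py above: compute_os_variance is Otsu's method applied to the sliding OOD-score queue, and the threshold minimizing the weighted within-class variance is taken as τ∗. A minimal, self-contained sketch of that search (the synthetic two-mode score distribution is illustrative, not project data):

    import numpy as np

    def compute_os_variance(os, th):
        # weighted within-class variance of the score queue at threshold th (Otsu's criterion)
        mask = os >= th
        weight1 = mask.mean()
        weight0 = 1 - weight1
        if weight0 == 0 or weight1 == 0:
            return np.inf
        return weight0 * np.var(os[~mask]) + weight1 * np.var(os[mask])

    # synthetic queue: two score populations standing in for weak and strong OOD samples
    rng = np.random.default_rng(0)
    queue = np.concatenate([rng.normal(0.2, 0.05, 400), rng.normal(0.7, 0.05, 112)])

    threshold_range = np.arange(0, 1, 0.01)
    best_threshold = threshold_range[np.argmin([compute_os_variance(queue, th) for th in threshold_range])]
    print(best_threshold)  # falls between the two modes, cleanly separating them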
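Similarly, a toy run of the Prototype_Pool FIFO mechanics (feature dimension 4 and pool size 3 are illustrative; the actual pool stores up to max_prototypes normalized 2048-d features):

    import torch
    import torch.nn.functional as F

    pool_max, delta = 3, 0.1
    memory = None  # stands in for the lazily registered 'memory' buffer

    def update_pool(memory, feature):
        # append a normalized prototype; evict the oldest when the pool is full
        feature = F.normalize(feature.detach())
        if memory is None:
            return feature
        if memory.shape[0] < pool_max:
            return F.normalize(torch.cat([memory, feature], dim=0))
        return F.normalize(torch.cat([memory[1:], feature], dim=0))

    for _ in range(5):                  # insert 5 prototypes; only the last 3 survive
        memory = update_pool(memory, torch.randn(1, 4))

    x = F.normalize(torch.randn(2, 4))  # two query features
    sims = torch.mm(x, memory.t())      # cosine similarity to every stored prototype
    nearest = torch.max(sims / delta, dim=1)[0].unsqueeze(1)  # as in Prototype_Pool.forward
    print(memory.shape, nearest.shape)  # torch.Size([3, 4]) torch.Size([2, 1])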
/imagenet/OURS.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.optim as optim
4 | import torch.utils.data as data
5 |
6 | import torch.nn as nn
7 | from utils.test_helpers import *
8 | from utils.prepare_dataset import *
9 |
10 | # ----------------------------------
11 | import copy
12 | import random
13 | import numpy as np
14 |
15 | from utils.test_helpers import build_model, test
16 | from utils.prepare_dataset import prepare_transforms, create_dataloader, ImageNetCorruption, ImageNet_, prepare_ood_test_data,prepare_ood_test_data_r
17 | from utils.offline import offline, offline_r
18 | import torch.nn.functional as F
19 | # ----------------------------------
20 |
21 |
22 | def compute_os_variance(os, th):
23 | """
24 |     Compute the weighted within-class variance of the OOD score queue at a given threshold (Otsu's criterion).
25 |
26 | Parameters:
27 | os : OOD score queue.
28 | th : Given threshold to separate weak and strong OOD samples.
29 |
30 | Returns:
31 | float: Weighted variance at the given threshold th.
32 | """
33 |
34 | thresholded_os = np.zeros(os.shape)
35 | thresholded_os[os >= th] = 1
36 |
37 | # compute weights
38 | nb_pixels = os.size
39 | nb_pixels1 = np.count_nonzero(thresholded_os)
40 | weight1 = nb_pixels1 / nb_pixels
41 | weight0 = 1 - weight1
42 |
43 |     # if one of the classes is empty, e.g. all samples fall below or above the threshold, that threshold will not be considered
44 |     # in the search for the best threshold
45 | if weight1 == 0 or weight0 == 0:
46 | return np.inf
47 |
48 | # find all pixels belonging to each class
49 | val_pixels1 = os[thresholded_os == 1]
50 | val_pixels0 = os[thresholded_os == 0]
51 |
52 | # compute variance of these classes
53 | var0 = np.var(val_pixels0) if len(val_pixels0) > 0 else 0
54 | var1 = np.var(val_pixels1) if len(val_pixels1) > 0 else 0
55 |
56 | return weight0 * var0 + weight1 * var1
57 |
58 |
59 |
60 |
61 | class Prototype_Pool(nn.Module):
62 |
63 | """
64 | Prototype pool containing strong OOD prototypes.
65 |
66 | Methods:
67 |     __init__: Initialize the prototype pool, storing delta, the number of weak OOD categories, and the maximum number of strong OOD prototypes.
68 |     forward: Forward pass; returns the cosine similarity with the strong OOD prototypes.
69 |     update_pool: Append a new strong OOD prototype, evicting the oldest one when the pool is full.
70 | """
71 |
72 |
73 | def __init__(self, delta=0.1, class_num=10, max=100):
74 | super(Prototype_Pool, self).__init__()
75 |
76 | self.class_num=class_num
77 | self.max_length = max
78 | self.flag = 0
79 | self.delta = delta
80 |
81 |
82 | def forward(self, x, all=False):
83 |
84 | # if the flag is 0, the prototype pool is empty, return None.
85 | if not self.flag:
86 | return None
87 |
88 | # compute the cosine similarity between the features and the strong OOD prototypes.
89 | out = torch.mm(x, self.memory.t())
90 |
91 |         if all:
92 | # if all is True, return the cosine similarity with all the strong OOD prototypes.
93 | return out
94 | else:
95 | # if all is False, return the cosine similarity with the nearest strong OOD prototype.
96 | return torch.max(out/(self.delta),dim=1)[0].unsqueeze(1)
97 |
98 |
99 | def update_pool(self, feature):
100 |
101 | if not self.flag:
102 | # if the flag is 0, the prototype pool is empty, use the feature to init the prototype pool.
103 | self.register_buffer('memory', feature.detach())
104 | self.flag = 1
105 | else:
106 | if self.memory.shape[0] < self.max_length:
107 | # if the number of strong OOD prototypes is less than the maximum count of strong OOD prototypes, append the feature to the prototype pool.
108 | self.memory = torch.cat([self.memory, feature.detach()],dim=0)
109 | else:
110 |                 # otherwise, delete the earliest-appended strong OOD prototype and append the feature to the prototype pool.
111 | self.memory = torch.cat([self.memory[1:], feature.detach()],dim=0)
112 | self.memory = F.normalize(self.memory)
113 |
114 |
115 | def append_prototypes(pool, feat_ext, logit, ts, ts_pro):
116 | """
117 | Append strong OOD prototypes to the prototype pool.
118 |
119 | Parameters:
120 | pool : Prototype pool.
121 | feat_ext : Normalized features of the input images.
122 | logit : Cosine similarity between the features and the weak OOD prototypes.
123 | ts : Threshold to separate weak and strong OOD samples.
124 | ts_pro : Threshold to append strong OOD prototypes.
125 |
126 | """
127 | added_list=[]
128 | update = 1
129 |
130 | while update:
131 | feat_mat = pool(F.normalize(feat_ext),all=True)
132 |         if feat_mat is not None:
133 | new_logit = torch.cat([logit, feat_mat], 1)
134 | else:
135 | new_logit = logit
136 |
137 | r_i_pro, _ = new_logit.max(dim=-1)
138 |
139 | r_i, _ = logit.max(dim=-1)
140 |
141 |         if added_list:
142 | for add in added_list:
143 |                 # if added_list is not empty, set the similarity score of already-added features to 1 so they are not selected and appended to the prototype pool again.
144 | r_i[add]=1
145 | min_logit , min_index = r_i.min(dim=0)
146 |
147 |
148 | if (1-min_logit) > ts :
149 |             # if the OOD score (1 - max cosine similarity to the weak OOD prototypes) exceeds the threshold ts, the feature is a strong OOD sample.
150 | added_list.append(min_index)
151 | if (1-r_i_pro[min_index]) > ts_pro:
152 | # if this strong OOD sample is far away from all the strong OOD prototypes, append it to the prototype pool.
153 | pool.update_pool(F.normalize(feat_ext[min_index].unsqueeze(0)))
154 | else:
155 |             # no remaining feature qualifies as a strong OOD sample; stop the loop.
156 | update=0
157 |
158 |
159 | parser = argparse.ArgumentParser()
160 | parser.add_argument('--dataset', default='ImageNet-C')
161 | parser.add_argument('--strong_OOD', default='noise')
162 | parser.add_argument('--strong_ratio', default=1, type=float)
163 | parser.add_argument('--dataroot', default='./data')
164 | parser.add_argument('--batch_size', default=128, type=int)
165 | parser.add_argument('--workers', default=8, type=int)
166 | parser.add_argument('--lr', default=0.001, type=float)
167 | parser.add_argument('--delta', default=0.1, type=float)
168 | parser.add_argument('--ce_scale', default=0, type=float, help='cross entropy loss scale')
169 | parser.add_argument('--outf', help='folder to output log')
170 | parser.add_argument('--level', default=5, type=int)
171 | parser.add_argument('--N_m', default=512, type=int, help='queue length')
172 | parser.add_argument('--corruption', default='snow')
173 | parser.add_argument('--offline', default='./results/offline/', help='directory of pretrained model')
174 | parser.add_argument('--da_scale', default=1, type=float, help='distribution alignment loss scale')
175 | parser.add_argument('--model', default='resnet50', help='resnet50')
176 | parser.add_argument('--seed', default=0, type=int)
177 | parser.add_argument('--max_prototypes', default=100, type=int)
178 | parser.add_argument('--save', action='store_true', default=False, help='save the model final checkpoint')
179 |
180 |
181 | # ----------- Args and Dataloader ------------
182 | args = parser.parse_args()
183 |
184 | print(args)
185 | print('\n')
186 |
187 | my_makedir(args.offline)
188 |
189 |
190 | net, ext, classifier = build_model()
191 |
192 |
193 | train_transform, val_transform, val_corrupt_transform = prepare_transforms()
194 |
195 | source_dataset = ImageNet_(args.dataroot, 'val', transform=val_transform, is_carry_index=True)
196 |
197 | if args.dataset == 'ImageNet-C':
198 | target_dataset_test = prepare_ood_test_data(args.dataroot, args.corruption, transform=val_corrupt_transform, is_carry_index=True, OOD=args.strong_OOD,OOD_transform=val_transform)
199 | class_num = 1000
200 |
201 | elif args.dataset == 'ImageNet-R':
202 | indices_in_1k = [wnid in imagenet_r_wnids for wnid in all_wnids]
203 | target_dataset_test = prepare_ood_test_data_r(args.dataroot, args.corruption, transform=val_corrupt_transform, is_carry_index=True, OOD=args.strong_OOD,OOD_transform=val_transform)
204 | class_num = 200
205 | else:
206 | raise NotImplementedError
207 |
208 | source_dataloader = create_dataloader(source_dataset, args, True, False)
209 | target_dataloader_test = create_dataloader(target_dataset_test, args, True, False)
210 |
211 | pool = Prototype_Pool(args.delta,class_num=class_num,max = args.max_prototypes).cuda()
212 |
213 |
214 | # ----------- Offline Feature Summarization ------------
215 | if args.dataset == 'ImageNet-C':
216 | ext_mean, ext_cov, ext_mean_categories, ext_cov_categories = offline(args, source_dataloader, ext, classifier)
217 | weak_prototype = F.normalize(ext_mean_categories.clone()).cuda()
218 | else:
219 | ext_mean, ext_cov, ext_mean_categories, ext_cov_categories = offline_r(args, source_dataloader, ext, classifier)
220 | weak_prototype = F.normalize(ext_mean_categories[indices_in_1k].clone()).cuda()
221 |
222 |
223 |
224 | sample_predict_ema_logit = torch.zeros(len(target_dataset_test), class_num, dtype=torch.float)
225 | sample_alpha = torch.ones(len(target_dataset_test), dtype=torch.float)
226 |
227 | ema_alpha = 0.9
228 | ema_ext_mu = ext_mean_categories.clone()
229 | ema_ext_cov = ext_cov_categories.clone()
230 | ema_ext_total_mu = torch.zeros(2048).cuda()
231 | ema_ext_total_cov = torch.zeros(2048, 2048).cuda()
232 |
233 | class_ema_length = 64
234 | ema_n = torch.ones(class_num).cuda() * class_ema_length
235 | ema_total_n = 0.
236 |
237 | loss_scale = 0.05
238 | ce_scale = args.ce_scale
239 |
240 | args.ts_pro = 0.0
241 | bias = ext_cov.max().item() / 30.
242 | template_ext_cov = torch.eye(2048).cuda() * bias
243 |
244 | optimizer = optim.SGD(ext.parameters(), lr=args.lr, momentum=0.9)
245 |
246 | torch.manual_seed(args.seed)
247 | random.seed(args.seed)
248 | np.random.seed(args.seed)
249 | torch.cuda.manual_seed(args.seed)
250 | torch.cuda.manual_seed_all(args.seed)
251 |
252 | # ----------- Open-World Test-time Training ------------
253 |
254 | correct = []
255 | unseen_correct= []
256 | all_correct=[]
257 | cumulative_error = []
258 | num_open = 0
259 | predicted_list=[]
260 | label_list=[]
261 |
262 | os_training_queue = []
263 | os_inference_queue = []
264 | queue_length = args.N_m
265 | ce_scale = args.ce_scale
266 |
267 | ema_total_n = 0.
268 |
269 | print('\n-----Test-Time Training with OURS-----')
270 | for te_idx, (te_inputs, te_labels) in enumerate(target_dataloader_test):
271 | classifier.eval()
272 | ext.eval()
273 |
274 | optimizer.zero_grad()
275 | loss = torch.tensor(0.).cuda()
276 |
277 | if isinstance(te_inputs,list):
278 | inputs = te_inputs[0].cuda()
279 | else:
280 | inputs = te_inputs.cuda()
281 |
282 | # features extracted by backbone
283 | feat_ext = ext(inputs)
284 |
285 | # logits of the input images, used to compute the cosine similarity between the features and the weak OOD prototypes.
286 | logit = torch.mm(F.normalize(feat_ext), weak_prototype.t()) / args.delta
287 |
288 |
289 | # compute the cosine similarity between the features and the strong OOD prototypes.
290 | feat_mat = pool(F.normalize(feat_ext))
291 |     if feat_mat is not None:
292 | new_logit = torch.cat([logit, feat_mat], 1)
293 | else:
294 | new_logit = logit
295 |
296 | pro, predicted = new_logit[:,:class_num].max(dim=-1)
297 |
298 | # compute the ood score of the input images.
299 | ood_score = 1-pro*args.delta
300 | os_training_queue.extend(ood_score.detach().cpu().tolist())
301 | os_training_queue = os_training_queue[-queue_length:]
302 |
303 |
304 | threshold_range = np.arange(0,1,0.01)
305 | criterias = [compute_os_variance(np.array(os_training_queue), th) for th in threshold_range]
306 |
307 | # best threshold is the one minimizing the variance of the two classes
308 | best_threshold = threshold_range[np.argmin(criterias)]
309 | args.ts = best_threshold
310 | seen_mask = (ood_score < args.ts)
311 | unseen_mask = (ood_score >= args.ts)
312 | r_i, pseudo_labels = new_logit.max(dim=-1)
313 |
314 | if unseen_mask.sum().item()!=0:
315 | #compute ts_pro to append new strong OOD prototypes to the prototype pool.
316 |
317 | min_logit , min_index = r_i.min(dim=0)
318 |
319 | in_score = 1-r_i*args.delta
320 | threshold_range = np.arange(0,1,0.01)
321 | criterias = [compute_os_variance(in_score[unseen_mask].detach().cpu().numpy(), th) for th in threshold_range]
322 |
323 | best_threshold = threshold_range[np.argmin(criterias)]
324 | args.ts_pro = best_threshold
325 |
326 | # append new strong OOD prototypes to the prototype pool.
327 | append_prototypes(pool, feat_ext, logit.detach()*args.delta, args.ts, args.ts_pro)
328 |
329 | len_memory = len(new_logit[0])
330 |
331 |
332 | if len_memory!=class_num:
333 |
334 | if seen_mask.sum().item()!=0:
335 | pseudo_labels[seen_mask] = new_logit[seen_mask,:class_num].softmax(dim=-1).max(dim=-1)[1]
336 | if unseen_mask.sum().item()!=0:
337 | pseudo_labels[unseen_mask] = class_num
338 | else:
339 | pseudo_labels = new_logit[seen_mask,:class_num].softmax(dim=-1).max(dim=-1)[1]
340 |
341 |
342 |     # ------ distribution alignment ------
343 | if seen_mask.sum().item()!=0:
344 | ext.train()
345 | feat_global = ext(inputs[seen_mask])
346 | # Global Gaussian
347 | b = feat_global.shape[0]
348 | ema_total_n += b
349 | alpha = 1. / 1280 if ema_total_n > 1280 else 1. / ema_total_n
350 | delta_pre = (feat_global - ema_ext_total_mu.cuda())
351 | delta = alpha * delta_pre.sum(dim=0)
352 | tmp_mu = ema_ext_total_mu.cuda() + delta
353 | tmp_cov = ema_ext_total_cov.cuda() + alpha * (delta_pre.t() @ delta_pre - b * ema_ext_total_cov.cuda()) - delta[:, None] @ delta[None, :]
354 | with torch.no_grad():
355 | ema_ext_total_mu = tmp_mu.detach().cpu()
356 | ema_ext_total_cov = tmp_cov.detach().cpu()
357 |
358 | source_domain = torch.distributions.MultivariateNormal(ext_mean, ext_cov + template_ext_cov)
359 | target_domain = torch.distributions.MultivariateNormal(tmp_mu, tmp_cov + template_ext_cov)
360 | global_loss=(torch.distributions.kl_divergence(source_domain, target_domain) + torch.distributions.kl_divergence(target_domain, source_domain)) * loss_scale
361 |
362 | loss += args.da_scale*global_loss
363 |
364 |
365 |     # for each batch, only the 50% of samples whose OOD scores lie farthest from the threshold τ∗ are used for prototype clustering
366 | if len_memory!=class_num and seen_mask.sum().item()!=0 and unseen_mask.sum().item()!=0:
367 | a, idx1 = torch.sort((ood_score[seen_mask]), descending=True)
368 | filter_down = a[-int(seen_mask.sum().item()*(1/2))]
369 | a, idx1 = torch.sort((ood_score[unseen_mask]), descending=True)
370 | filter_up= a[int(unseen_mask.sum().item()*(1/2))]
371 | for j in range(len(pseudo_labels)):
372 |
373 | if ood_score[j] >=filter_down and seen_mask[j]:
374 | seen_mask[j]=False
375 | if ood_score[j] <=filter_up and unseen_mask[j]:
376 | unseen_mask[j]=False
377 |
378 |
379 | if len_memory!=class_num:
380 | entropy_seen = nn.CrossEntropyLoss()(new_logit[seen_mask,:class_num],pseudo_labels[seen_mask])
381 | entropy_unseen= nn.CrossEntropyLoss()(new_logit[unseen_mask],pseudo_labels[unseen_mask])
382 | loss += ce_scale*(entropy_seen+ entropy_unseen)/2
383 |
384 | try:
385 | loss.backward()
386 | optimizer.step()
387 | optimizer.zero_grad()
388 | except:
389 |         print('cannot backward')
390 | torch.cuda.empty_cache()
391 |
392 |
393 |
394 | ####-------------------------- Test ----------------------------####
395 |
396 | with torch.no_grad():
397 |
398 | net.eval()
399 | feat_ext = ext(inputs) #b,2048
400 | logit = torch.mm(F.normalize(feat_ext), weak_prototype.t())/args.delta
401 |
402 |
403 | softmax_logit = logit.softmax(dim=-1)
404 | pro, predicted = softmax_logit.max(dim=-1)
405 |
406 | ood_score, max_index = logit.max(1)
407 | ood_score = 1-ood_score*args.delta
408 | os_inference_queue.extend(ood_score.detach().cpu().tolist())
409 | os_inference_queue = os_inference_queue[-queue_length:]
410 |
411 | threshold_range = np.arange(0,1,0.01)
412 | criterias = [compute_os_variance(np.array(os_inference_queue), th) for th in threshold_range]
413 | best_threshold = threshold_range[np.argmin(criterias)]
414 | unseen_mask = (ood_score > best_threshold)
415 | args.ts = best_threshold
416 | predicted[unseen_mask] = class_num
417 |
418 | one = torch.ones_like(te_labels)*class_num
419 | false = torch.ones_like(te_labels)*-1
420 | predicted = torch.where(predicted>class_num-1, one.cuda(), predicted)
421 | all_labels = torch.where(te_labels>class_num-1, one, te_labels)
422 | seen_labels = torch.where(te_labels>class_num-1, false, te_labels)
423 | unseen_labels = torch.where(te_labels>class_num-1, one, false)
424 | correct.append(predicted.cpu().eq(seen_labels))
425 | unseen_correct.append(predicted.cpu().eq(unseen_labels))
426 | all_correct.append(predicted.cpu().eq(all_labels))
427 | num_open += torch.gt(te_labels, class_num-1).sum()
428 |
429 | predicted_list.append(predicted.long().cpu())
430 | label_list.append(all_labels.long().cpu())
431 |
432 |
433 | seen_acc = round(torch.cat(correct).numpy().sum() / (len(torch.cat(correct).numpy())-num_open.numpy()),4)
434 | unseen_acc = round(torch.cat(unseen_correct).numpy().sum() / num_open.numpy(),4)
435 | h_score = round((2*seen_acc*unseen_acc) / (seen_acc + unseen_acc),4)
436 | print('Batch:(', te_idx,'/',len(target_dataloader_test), ')\tloss:',"%.2f" % loss.item(),\
437 | '\t Cumulative Results: ACC_S:', seen_acc,\
438 | '\tACC_N:', unseen_acc,\
439 | '\tACC_H:',h_score\
440 | )
441 |
442 |
443 | print('\nTest time training result:',' ACC_S:', seen_acc,\
444 | '\tACC_N:', unseen_acc,\
445 | '\tACC_H:',h_score,'\n\n\n\n'\
446 | )
447 |
448 |
449 | if args.outf is not None:
450 | my_makedir(args.outf)
451 | with open (args.outf+'/results.txt','a') as f:
452 | f.write(str(args)+'\n')
453 | f.write(
454 | 'ACC_S:'+ str(seen_acc)+\
455 | '\tACC_N:'+ str(unseen_acc)+\
456 | '\tACC_H:'+str(h_score)+'\n\n\n\n'\
457 | )
458 | if args.save:
459 | torch.save(net.state_dict(), os.path.join(args.outf, 'final.pth'))
--------------------------------------------------------------------------------
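The global distribution-alignment step in both OURS.py files maintains a running Gaussian over seen-class target features (with an effective sample cap of 1280) and penalizes its symmetric KL divergence against the source Gaussian. A reduced-dimension sketch of that update and loss, assuming D=8 instead of 2048 (the batch and regularizer values are illustrative):

    import torch

    D, cap = 8, 1280
    mu, cov, total_n = torch.zeros(D), torch.zeros(D, D), 0.

    def streaming_gaussian_update(mu, cov, total_n, feats):
        # running mean/covariance with an effective sample cap, as in OURS.py
        b = feats.shape[0]
        total_n += b
        alpha = 1. / cap if total_n > cap else 1. / total_n
        delta_pre = feats - mu
        delta = alpha * delta_pre.sum(dim=0)
        new_mu = mu + delta
        new_cov = cov + alpha * (delta_pre.t() @ delta_pre - b * cov) - delta[:, None] @ delta[None, :]
        return new_mu, new_cov, total_n

    src_mu, src_cov = torch.zeros(D), torch.eye(D)  # stand-in source statistics
    feats = torch.randn(32, D) * 1.5 + 0.3          # one batch of "seen" target features
    mu, cov, total_n = streaming_gaussian_update(mu, cov, total_n, feats)

    eps = torch.eye(D) * 0.1                        # regularizer, like template_ext_cov
    p = torch.distributions.MultivariateNormal(src_mu, src_cov + eps)
    q = torch.distributions.MultivariateNormal(mu, cov + eps)
    loss = torch.distributions.kl_divergence(p, q) + torch.distributions.kl_divergence(q, p)
    print(loss.item())  # shrinks as the target statistics approach the source Gaussian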
/imagenet/utils/offline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import copy
3 | import statistics
4 | import os
5 |
6 | def covariance(features):
7 | assert len(features.size()) == 2, "TODO: multi-dimensional feature map covariance"
8 | n = features.shape[0]
9 | tmp = torch.ones((1, n), device=features.device) @ features
10 | cov = (features.t() @ features - (tmp.t() @ tmp) / n) / n
11 | return cov
12 |
13 | def coral(cs, ct):
14 | d = cs.shape[0]
15 | loss = (cs - ct).pow(2).sum() / (4. * d ** 2)
16 | return loss
17 |
18 |
19 | def linear_mmd(ms, mt):
20 | loss = (ms - mt).pow(2).mean()
21 | return loss
22 |
23 | def offline(args, trloader, ext, classifier, num_classes=1000):
24 | if os.path.exists(args.offline+'/offline.pth'):
25 | data = torch.load(args.offline+'/offline.pth')
26 | return data
27 |
28 | ext.eval()
29 |
30 | feat_ext_mean = torch.zeros(2048).cuda()
31 | feat_ext_variance = torch.zeros(2048, 2048).cuda()
32 |
33 | feat_ext_mean_categories = torch.zeros(num_classes, 2048).cuda() # K, D
34 | feat_ext_variance_categories = torch.zeros(num_classes, 2048).cuda()
35 |
36 | ema_n = torch.zeros(num_classes).cuda()
37 | ema_total_n = 0
38 |
39 | with torch.no_grad():
40 | for batch_idx, (inputs, labels) in enumerate(trloader):
41 | feat = ext(inputs[0].cuda()) # N, D
42 | b, d = feat.shape
43 | labels = classifier(feat).argmax(dim=-1)
44 |
45 | feat_ext_categories = torch.zeros(num_classes, b, d).cuda()
46 | feat_ext_categories.scatter_add_(dim=0, index=labels[None, :, None].expand(-1, -1, d), src=feat[None, :, :])
47 |
48 | num_categories = torch.zeros(num_classes, b, dtype=torch.int).cuda()
49 | num_categories.scatter_add_(dim=0, index=labels[None, :], src=torch.ones_like(labels[None, :], dtype=torch.int))
50 | ema_n += num_categories.sum(dim=1)
51 | alpha_categories = 1 / (ema_n + 1e-10) # K
52 | delta_pre = (feat_ext_categories - feat_ext_mean_categories[:, None, :]) * num_categories[:, :, None] # K, N, D
53 | delta = alpha_categories[:, None] * delta_pre.sum(dim=1) # K, D
54 | feat_ext_mean_categories += delta
55 | feat_ext_variance_categories += alpha_categories[:, None] * ((delta_pre ** 2).sum(dim=1) - num_categories.sum(dim=1)[:, None] * feat_ext_variance_categories) \
56 | - delta ** 2
57 |
58 | ema_total_n += b
59 | alpha = 1 / (ema_total_n + 1e-10)
60 | delta_pre = feat - feat_ext_mean[None, :] # b, d
61 | delta = alpha * (delta_pre).sum(dim=0)
62 | feat_ext_mean += delta
63 | feat_ext_variance += alpha * (delta_pre.t() @ delta_pre - b * feat_ext_variance) - delta[:, None] @ delta[None, :]
64 | print('offline process rate: %.2f%%\r' % ((batch_idx + 1) / len(trloader) * 100.), end='')
65 |
66 |
67 | torch.save((feat_ext_mean, feat_ext_variance, feat_ext_mean_categories, feat_ext_variance_categories), args.offline+'/offline.pth')
68 | return feat_ext_mean, feat_ext_variance, feat_ext_mean_categories, feat_ext_variance_categories
69 |
70 |
71 | def offline_r(args, trloader, ext, classifier, num_classes=1000):
72 | if os.path.exists(args.offline+'/offline_r.pth'):
73 | data = torch.load(args.offline+'/offline_r.pth')
74 | return data
75 |
76 | ext.eval()
77 | all_wnids = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 
'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 
'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 
'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141']
78 |
79 | imagenet_r_wnids = {'n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214', 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', 'n02437616', 'n02445715', 'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677'}
80 |
81 | indices_in_1k = [wnid in imagenet_r_wnids for wnid in all_wnids]
82 |
83 | feat_ext_mean = torch.zeros(2048).cuda()  # running global feature mean (D,)
84 | feat_ext_variance = torch.zeros(2048, 2048).cuda()  # running global feature covariance (D, D)
85 |
86 | feat_ext_mean_categories = torch.zeros(num_classes, 2048).cuda()  # K, D
87 | feat_ext_variance_categories = torch.zeros(num_classes, 2048).cuda()  # per-class diagonal variances (K, D)
88 |
89 | ema_n = torch.zeros(num_classes).cuda()  # running sample count per class
90 | ema_total_n = 0  # running sample count overall
91 |
92 | with torch.no_grad():
93 | for batch_idx, (inputs, labels) in enumerate(trloader):
94 | l = []  # images whose ground-truth class is one of the 200 ImageNet-R classes
95 | t = []  # matching labels (collected for symmetry; unused below)
96 | for i in range(len(labels)):
97 | if indices_in_1k[labels[i]]:
98 | t.append(labels[i])
99 | l.append(inputs[0][i].unsqueeze(0))
100 | inputs = torch.cat(l, dim=0)  # assumes each batch holds at least one ImageNet-R sample
101 | feat = ext(inputs.cuda()) # N, D
102 | b, d = feat.shape
103 | labels = classifier(feat).argmax(dim=-1)  # pseudo-labels from the source classifier replace the ground truth for the per-class statistics
104 |
105 | feat_ext_categories = torch.zeros(num_classes, b, d).cuda()
106 | feat_ext_categories.scatter_add_(dim=0, index=labels[None, :, None].expand(-1, -1, d), src=feat[None, :, :])
107 |
108 | num_categories = torch.zeros(num_classes, b, dtype=torch.int).cuda()
109 | num_categories.scatter_add_(dim=0, index=labels[None, :], src=torch.ones_like(labels[None, :], dtype=torch.int))
110 | ema_n += num_categories.sum(dim=1)
111 | alpha_categories = 1 / (ema_n + 1e-10) # K
112 | delta_pre = (feat_ext_categories - feat_ext_mean_categories[:, None, :]) * num_categories[:, :, None] # K, N, D
113 | delta = alpha_categories[:, None] * delta_pre.sum(dim=1) # K, D
114 | feat_ext_mean_categories += delta
115 | feat_ext_variance_categories += alpha_categories[:, None] * ((delta_pre ** 2).sum(dim=1) - num_categories.sum(dim=1)[:, None] * feat_ext_variance_categories) \
116 | - delta ** 2
117 |
118 | ema_total_n += b
119 | alpha = 1 / (ema_total_n + 1e-10)
120 | delta_pre = feat - feat_ext_mean[None, :] # b, d
121 | delta = alpha * (delta_pre).sum(dim=0)
122 | feat_ext_mean += delta
123 | feat_ext_variance += alpha * (delta_pre.t() @ delta_pre - b * feat_ext_variance) - delta[:, None] @ delta[None, :]
124 | print('offline process rate: %.2f%%\r' % ((batch_idx + 1) / len(trloader) * 100.), end='')
125 |
126 |
127 | torch.save((feat_ext_mean, feat_ext_variance, feat_ext_mean_categories, feat_ext_variance_categories), args.offline+'/offline_r.pth')
128 | return feat_ext_mean, feat_ext_variance, feat_ext_mean_categories, feat_ext_variance_categories
129 |
--------------------------------------------------------------------------------
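The statistics loop above is a batched form of Welford's online algorithm: each batch shifts the running mean by `alpha * delta_pre.sum(0)` and applies the matching correction to the covariance, so a single pass reproduces the closed-form statistics exactly (up to the `1e-10` guard on the first division). Below is a minimal standalone sketch of that update; `streaming_mean_cov` is a hypothetical helper name, not part of the repo, included only so the algebra is easy to verify:

```python
import torch

def streaming_mean_cov(batches, d):
    """Running mean/covariance over a stream of (b, d) batches,
    mirroring the global update in offline.py above."""
    mean = torch.zeros(d, dtype=torch.float64)
    cov = torch.zeros(d, d, dtype=torch.float64)
    n = 0
    for feat in batches:
        b = feat.shape[0]
        n += b
        alpha = 1.0 / n
        delta_pre = feat - mean[None, :]        # deviations from the old mean, (b, d)
        delta = alpha * delta_pre.sum(dim=0)    # how far this batch shifts the mean
        mean += delta
        # covariance correction: new scatter minus the shrinkage from the mean shift
        cov += alpha * (delta_pre.t() @ delta_pre - b * cov) - delta[:, None] @ delta[None, :]
    return mean, cov

# Sanity check against the closed-form population statistics.
x = torch.randn(1000, 8, dtype=torch.float64)
mean, cov = streaming_mean_cov(x.split(64), 8)
centered = x - x.mean(dim=0)
assert torch.allclose(mean, x.mean(dim=0))
assert torch.allclose(cov, centered.t() @ centered / x.shape[0])
```

The per-class branch in the file applies the same recurrence once per pseudo-label, but tracks only diagonal variances (K x D) rather than full D x D covariances.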
/cifar/utils/prepare_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import random
5 | import torchvision
6 | import numpy as np
7 | from PIL import Image
8 | import torch.utils.data
9 | import torchvision.transforms as transforms
10 |
11 |
12 | class CIFAR10(torchvision.datasets.CIFAR10):
13 | def __init__(self, *args, **kwargs):
14 | super().__init__(*args, **kwargs)
15 | return
16 |
17 | def __getitem__(self, index: int):
18 | image, target = super().__getitem__(index)
19 | if isinstance(image, list):
20 | image.append(index)
21 | else:
22 | image = [image, index]
23 | return image, target
24 |
25 | class CIFAR100(torchvision.datasets.CIFAR100):
26 | def __init__(self, *args, **kwargs):
27 | super().__init__(*args, **kwargs)
28 | return
29 |
30 | def __getitem__(self, index: int):
31 | image, target = super().__getitem__(index)
32 | if isinstance(image, list):
33 | image.append(index)
34 | else:
35 | image = [image, index]
36 | return image, target
37 |
38 |
39 | class CIFAR100_openset(torchvision.datasets.CIFAR100):
40 | def __init__(self, *args, ratio=1, **kwargs):
41 | super().__init__(*args, **kwargs)
42 | self.data, self.targets = self.data[:int(10000*ratio)], self.targets[:int(10000*ratio)]
43 | return
44 |
45 | def __getitem__(self, index: int):
46 | image, target = super().__getitem__(index)
47 | target = target + 1000  # shift labels out of the in-distribution range to mark strong OOD
48 | if isinstance(image, list):
49 | image.append(index)
50 | else:
51 | image = [image, index]
52 | return image, target
53 |
54 | class CIFAR10_openset(torchvision.datasets.CIFAR10):
55 | def __init__(self, *args, ratio=1, **kwargs):
56 | super().__init__(*args, **kwargs)
57 | self.data, self.targets = self.data[:int(10000*ratio)], self.targets[:int(10000*ratio)]
58 | return
59 |
60 | def __getitem__(self, index: int):
61 | image, target = super().__getitem__(index)
62 | target = target + 1000
63 | if isinstance(image, list):
64 | image.append(index)
65 | else:
66 | image = [image, index]
67 | return image, target
68 |
69 | class noise_dataset(torch.utils.data.Dataset):
70 | def __init__(self, transform, ratio=1):
71 | self.number = int(10000*ratio)  # mirror the 10k test split, scaled by the strong-OOD ratio
72 | self.transform = transform  # stored for API parity; the noise tensors below are never transformed
73 |
74 | def __getitem__(self, index:int):
75 | image = torch.randn(3, 32, 32)  # a pure Gaussian-noise "image"
76 | target = 1000  # fixed strong-OOD label, consistent with the +1000 shift used by the other wrappers
77 | if isinstance(image, list):  # never true here (image is always a tensor); kept for symmetry
78 | image.append(index)
79 | else:
80 | image = [image, index]
81 |
82 | return image, target
83 |
84 | def __len__(self):
85 |
86 | return self.number
87 |
88 | class MNIST_openset(torchvision.datasets.MNIST):
89 | def __init__(self, *args, ratio=1, **kwargs):
90 | super().__init__(*args, **kwargs)
91 | self.data, self.targets = self.data[:int(10000*ratio)], self.targets[:int(10000*ratio)]
92 | return
93 |
94 | def __getitem__(self, index: int):
95 | image, target = super().__getitem__(index)
96 | target = target + 1000
97 | if isinstance(image, list):
98 | image.append(index)
99 | else:
100 | image = [image, index]
101 | return image, target
102 |
103 | class SVHN_openset(torchvision.datasets.SVHN):
104 | def __init__(self, *args, ratio=1, **kwargs):
105 | super().__init__(*args, **kwargs)
106 | self.data, self.labels = self.data[:int(10000*ratio)], self.labels[:int(10000*ratio)]
107 | return
108 |
109 | def __getitem__(self, index: int):
110 | image, target = super().__getitem__(index)
111 | target = target + 1000
112 | if isinstance(image, list):
113 | image.append(index)
114 | else:
115 | image = [image, index]
116 | return image, target
117 |
118 |
119 | class TinyImageNet_OOD_nonoverlap(torch.utils.data.Dataset):
120 | def __init__(self, root, train=True, transform=None, list=True, ratio=1):  # 'list' wraps each sample as [image, index], matching the CIFAR/MNIST wrappers
121 | self.Train = train
122 | self.list=list
123 | self.root_dir = root
124 | self.transform = transform
125 | self.train_dir = os.path.join(self.root_dir, "train")
126 | self.val_dir = os.path.join(self.root_dir, "val")
127 | self.ratio = ratio
128 |
129 | self.class_list = ['n03544143', 'n03255030', 'n04532106', 'n02669723', 'n02321529', 'n02423022', 'n03854065', 'n02509815', 'n04133789', 'n03970156', 'n01882714', 'n04023962', 'n01768244', 'n04596742', 'n03447447', 'n03617480', 'n07720875', 'n02125311', 'n02793495', 'n04532670']
130 |
131 | if (self.Train):
132 | self._create_class_idx_dict_train()
133 | else:
134 | self._create_class_idx_dict_val()
135 |
136 | self._make_dataset(self.Train)
137 |
138 | words_file = os.path.join(self.root_dir, "words.txt")
139 | wnids_file = os.path.join(self.root_dir, "wnids.txt")
140 |
141 | self.set_nids = set()
142 |
143 | with open(wnids_file, 'r') as fo:
144 | data = fo.readlines()
145 | for entry in data:
146 | if entry.strip("\n") in self.class_list:
147 | self.set_nids.add(entry.strip("\n"))
148 |
149 | self.class_to_label = {}
150 | with open(words_file, 'r') as fo:
151 | data = fo.readlines()
152 | for entry in data:
153 | words = entry.split("\t")
154 | if words[0] in self.set_nids:
155 | self.class_to_label[words[0]] = (words[1].strip("\n").split(","))[0]
156 |
157 | def _create_class_idx_dict_train(self):
158 | if sys.version_info >= (3, 5):
159 | classes = [d.name for d in os.scandir(self.train_dir) if d.is_dir()]
160 | else:
161 | classes = [d for d in os.listdir(self.train_dir) if os.path.isdir(os.path.join(self.train_dir, d))]
162 | classes = sorted(classes)
163 | num_images = 0
164 | # per-class counters; each class is capped at 500 images below
165 | temp = [0] * len(self.class_list)
166 |
167 | for root, dirs, files in os.walk(self.train_dir):
168 | for f in files:
169 | if f.endswith(".JPEG") and f.split("_")[0] in self.class_list:
170 | for i in range(len(self.class_list)):
171 | if f.split("_")[0] == self.class_list[i]:
172 | # count at most 500 images per class
173 |
174 | if temp[i] < 500:
175 | temp[i]+=1
176 | num_images = num_images + 1
177 | break
178 | self.len_dataset = num_images
179 |
180 | self.tgt_idx_to_class = {i: classes[i] for i in range(len(classes))}
181 | self.class_to_tgt_idx = {classes[i]: i for i in range(len(classes))}
182 |
183 | def _create_class_idx_dict_val(self):
184 | val_image_dir = os.path.join(self.val_dir, "images")
185 | if sys.version_info >= (3, 5):
186 | images = [d.name for d in os.scandir(val_image_dir) if d.is_file()]
187 | else:
188 | images = [d for d in os.listdir(val_image_dir) if os.path.isfile(os.path.join(val_image_dir, d))]
189 | val_annotations_file = os.path.join(self.val_dir, "val_annotations.txt")
190 | self.val_img_to_class = {}
191 | set_of_classes = set()
192 | with open(val_annotations_file, 'r') as fo:
193 | entry = fo.readlines()
194 | for data in entry:
195 | words = data.split("\t")
196 | if words[1] in self.class_list:
197 | self.val_img_to_class[words[0]] = words[1]
198 | set_of_classes.add(words[1])
199 |
200 | self.len_dataset = len(list(self.val_img_to_class.keys()))
201 | classes = sorted(list(set_of_classes))
202 | self.class_to_tgt_idx = {classes[i]: i for i in range(len(classes))}
203 | self.tgt_idx_to_class = {i: classes[i] for i in range(len(classes))}
204 |
205 | def _make_dataset(self, Train=True):
206 | self.images = []
207 | if Train:
208 | img_root_dir = self.train_dir
209 | list_of_dirs = [target for target in self.class_to_tgt_idx.keys()]
210 | else:
211 | img_root_dir = self.val_dir
212 | list_of_dirs = ["images"]
213 | # per-class counters, mirroring _create_class_idx_dict_train
214 | temp = [0] * len(self.class_list)
215 |
216 | for tgt in list_of_dirs:
217 | dirs = os.path.join(img_root_dir, tgt)
218 | if not os.path.isdir(dirs):
219 | continue
220 |
221 | for root, _, files in sorted(os.walk(dirs)):
222 | for fname in sorted(files):
223 | if (fname.endswith(".JPEG"))and fname.split("_")[0] in self.class_list:
224 | path = os.path.join(root, fname)
225 | if Train:
226 | item = (path, self.class_to_tgt_idx[tgt])
227 | else:
228 | item = (path, self.class_to_tgt_idx[self.val_img_to_class[fname]])
229 | for i in range(len(self.class_list)):
230 | if fname.split("_")[0] == self.class_list[i]:
231 | temp[i]+=1
232 |
233 | if temp[i] <= 500:
234 | self.images.append(item)
235 | print('TinyImageNet OOD images kept:', len(self.images))
236 |
237 | def return_label(self, idx):
238 | return [self.class_to_label[self.tgt_idx_to_class[i.item()]] for i in idx]
239 |
240 | def __len__(self):
241 | return int(self.len_dataset*self.ratio)
242 |
243 | def __getitem__(self, idx:int):
244 | img_path, tgt = self.images[idx]
245 | tgt += 1000  # shift strong-OOD labels out of the in-distribution range
246 | with open(img_path, 'rb') as f:
247 | sample = Image.open(f)  # read through the managed file handle instead of reopening the path
248 | sample = sample.convert('RGB')
249 | if self.transform is not None:
250 | sample = self.transform(sample)
251 | index = idx
252 | if self.list:
253 | if isinstance(sample, list):
254 | sample.append(index)
255 | else:
256 | sample = [sample, index]
257 |
258 | return sample, tgt
259 |
260 |
261 | def prepare_transforms(dataset):
262 |
263 | if dataset == 'cifar10':
264 | mean = (0.4914, 0.4822, 0.4465)
265 | std = (0.2023, 0.1994, 0.2010)
266 | elif dataset == 'cifar10+100' or dataset == 'cifar10OOD':
267 | mean = (0.4914, 0.4822, 0.4465)
268 | std = (0.2023, 0.1994, 0.2010)
269 | elif dataset == 'cifar100' or dataset == 'cifar100OOD':
270 | mean = (0.5071, 0.4867, 0.4408)
271 | std = (0.2675, 0.2565, 0.2761)
272 | else:
273 | raise NotImplementedError
274 |
275 | normalize = transforms.Normalize(mean=mean, std=std)
276 |
277 | te_transforms = transforms.Compose([transforms.ToTensor(), normalize])
278 |
279 | tr_transforms = transforms.Compose([
280 | transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
281 | transforms.RandomHorizontalFlip(),
282 | transforms.ToTensor(),
283 | normalize])
284 |
285 | simclr_transforms = transforms.Compose([
286 | transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
287 | transforms.RandomHorizontalFlip(),
288 | transforms.RandomApply([
289 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
290 | ], p=0.8),
291 | transforms.RandomGrayscale(p=0.2),
292 | transforms.ToTensor(),
293 | normalize
294 | ])
295 |
296 | return tr_transforms, te_transforms, simclr_transforms
297 |
298 | class TwoCropTransform:
299 | """Create two crops of the same image"""
300 | def __init__(self, transform, te_transform):
301 | self.transform = transform
302 | self.te_transform = te_transform
303 |
304 | def __call__(self, x):
305 | return [self.transform(x), self.transform(x), self.te_transform(x)]
306 |
307 | # -------------------------
308 |
309 | common_corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
310 | 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
311 | 'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']
312 |
313 | def seed_worker(worker_id):
314 | worker_seed = torch.initial_seed() % 2**32
315 | np.random.seed(worker_seed)
316 | random.seed(worker_seed)
317 |
318 |
319 | def prepare_test_data(args, ttt=False, num_sample=None, align=False):
320 |
321 | tr_transforms, te_transforms, simclr_transforms = prepare_transforms(args.dataset)
322 |
323 | if args.dataset == 'cifar10OOD':
324 |
325 | tesize = 10000
326 | if args.corruption in common_corruptions:
327 |
328 | print('Test on %s level %d' %(args.corruption, args.level))
329 | teset_raw_100 = np.load(args.dataroot + '/CIFAR-100-C/%s.npy' %(args.corruption))  # note: re-loaded in the cifar100 branch below, so this array is effectively unused here
330 | teset_raw_100 = teset_raw_100[(args.level-1)*tesize: args.level*tesize]
331 | teset_raw_10 = np.load(args.dataroot + '/CIFAR-10-C/%s.npy' %(args.corruption))
332 | teset_raw_10 = teset_raw_10[(args.level-1)*tesize: args.level*tesize]
333 | teset_10 = CIFAR10(root=args.dataroot,
334 | train=False, download=True, transform=te_transforms)
335 | teset_10.data = teset_raw_10
336 |
337 | if args.strong_OOD == 'MNIST':
338 | te_rize = transforms.Compose([transforms.Resize(size=(32, 32)), transforms.Grayscale(3), te_transforms ])
339 | noise = MNIST_openset(root=args.dataroot,
340 | train=False, download=True, transform=te_rize, ratio=args.strong_ratio)
341 |
342 | teset = torch.utils.data.ConcatDataset([teset_10,noise])
343 |
344 | elif args.strong_OOD == 'noise':
345 | noise = noise_dataset(te_transforms, args.strong_ratio)
346 |
347 | teset = torch.utils.data.ConcatDataset([teset_10,noise])
348 |
349 | elif args.strong_OOD =='cifar100':
350 | teset_raw_100 = np.load(args.dataroot + '/CIFAR-100-C/snow.npy')  # the strong-OOD samples always use the fixed 'snow' corruption
351 | teset_raw_100 = teset_raw_100[(args.level-1)*tesize: args.level*tesize]
352 | teset_100 = CIFAR100_openset(root=args.dataroot,
353 | train=False, download=True, transform=te_transforms, ratio=args.strong_ratio)
354 | teset_100.data = teset_raw_100[:int(10000*args.strong_ratio)]
355 | teset = torch.utils.data.ConcatDataset([teset_10,teset_100])
356 |
357 | elif args.strong_OOD =='SVHN':
358 | te_rize = transforms.Compose([te_transforms ])
359 | noise = SVHN_openset(root=args.dataroot,
360 | split='test', download=True, transform=te_rize, ratio=args.strong_ratio)
361 |
362 | teset = torch.utils.data.ConcatDataset([teset_10,noise])
363 |
364 | elif args.strong_OOD =='Tiny':
365 |
366 | transform_test = transforms.Compose([transforms.Resize(32), te_transforms ])
367 | testset_tiny = TinyImageNet_OOD_nonoverlap(args.dataroot +'/tiny-imagenet-200', transform=transform_test, train=True)
368 | teset = torch.utils.data.ConcatDataset([teset_10,testset_tiny])
369 | print(len(teset_10),len(testset_tiny),len(teset))
370 |
371 | else:
372 | raise NotImplementedError('unknown strong_OOD option: %s' % args.strong_OOD)
373 |
374 | elif args.dataset == 'cifar100OOD':
375 |
376 | tesize = 10000
377 |
378 | if args.corruption in common_corruptions:
379 | print('Test on %s level %d' %(args.corruption, args.level))
380 | teset_raw_100 = np.load(args.dataroot + '/CIFAR-100-C/%s.npy' %(args.corruption))
381 | teset_raw_100 = teset_raw_100[(args.level-1)*tesize: args.level*tesize]
382 | teset_raw_10 = np.load(args.dataroot + '/CIFAR-10-C/%s.npy' %(args.corruption))  # note: re-loaded in the cifar10 branch below, so this array is effectively unused here
383 | teset_raw_10 = teset_raw_10[(args.level-1)*tesize: args.level*tesize]
384 | teset_100 = CIFAR100(root=args.dataroot,
385 | train=False, download=True, transform=te_transforms)
386 | teset_100.data = teset_raw_100
387 |
388 | if args.strong_OOD == 'MNIST':
389 | te_rize = transforms.Compose([transforms.Resize(size=(32, 32)), transforms.Grayscale(3), te_transforms ])
390 | noise = MNIST_openset(root=args.dataroot,
391 | train=False, download=True, transform=te_rize, ratio=args.strong_ratio)
392 |
393 | teset = torch.utils.data.ConcatDataset([teset_100,noise])
394 |
395 | elif args.strong_OOD == 'noise':
396 | noise = noise_dataset(te_transforms, args.strong_ratio)
397 |
398 | teset = torch.utils.data.ConcatDataset([teset_100,noise])
399 |
400 | elif args.strong_OOD =='cifar10':
401 | teset_raw_10 = np.load(args.dataroot + '/CIFAR-10-C/snow.npy')  # the strong-OOD samples always use the fixed 'snow' corruption
402 | teset_raw_10 = teset_raw_10[(args.level-1)*tesize: args.level*tesize]
403 | teset_10 = CIFAR10_openset(root=args.dataroot,
404 | train=False, download=True, transform=te_transforms, ratio=args.strong_ratio)
405 | teset_10.data = teset_raw_10[:int(10000*args.strong_ratio)]
406 | teset = torch.utils.data.ConcatDataset([teset_100,teset_10])
407 |
408 | elif args.strong_OOD =='SVHN':
409 | te_rize = transforms.Compose([te_transforms ])
410 | noise = SVHN_openset(root=args.dataroot,
411 | split='test', download=True, transform=te_rize, ratio=args.strong_ratio)
412 |
413 | teset = torch.utils.data.ConcatDataset([teset_100,noise])
414 |
415 | elif args.strong_OOD =='Tiny':
416 |
417 | transform_test = transforms.Compose([transforms.Resize(32), te_transforms ])
418 | testset_tiny = TinyImageNet_OOD_nonoverlap(args.dataroot +'/tiny-imagenet-200', transform=transform_test, train=True)
419 | teset = torch.utils.data.ConcatDataset([teset_100,testset_tiny])
420 |
421 | else:
422 | raise NotImplementedError('unknown strong_OOD option: %s' % args.strong_OOD)
423 |
424 | if not hasattr(args, 'workers') or args.workers < 2:
425 | pin_memory = False
426 | else:
427 | pin_memory = True
428 |
429 | if ttt:
430 | shuffle = True
431 | drop_last = True  # drop the last incomplete batch during test-time training
432 | else:
433 | shuffle = True
434 | drop_last = False
435 |
436 | try:
437 | teloader = torch.utils.data.DataLoader(teset, batch_size=args.batch_size,
438 | shuffle=shuffle, num_workers=args.workers,
439 | worker_init_fn=seed_worker, pin_memory=pin_memory, drop_last=drop_last)
440 | except Exception:  # leave the loader as None rather than crash; callers must check for it
441 | teloader = None
442 |
443 |
444 | return teset, teloader
445 |
446 | def prepare_train_data(args, num_sample=None):
447 | print('Preparing data...')
448 |
449 | tr_transforms, te_transforms, simclr_transforms = prepare_transforms(args.dataset)
450 |
451 | if args.dataset == 'cifar10' or args.dataset == 'cifar10+100' or args.dataset == 'cifar10OOD':
452 |
453 | if hasattr(args, 'ssl') and args.ssl == 'contrastive':
454 | trset = CIFAR10(root=args.dataroot,
455 | train=False, download=True,  # note: the contrastive branch draws from the test split
456 | transform=TwoCropTransform(simclr_transforms, te_transforms))
457 | if hasattr(args, 'corruption') and args.corruption in common_corruptions:
458 | print('Contrastive on %s level %d' %(args.corruption, args.level))
459 | tesize = 10000
460 | trset_raw = np.load(args.dataroot + '/CIFAR-10-C/%s.npy' %(args.corruption))
461 | trset_raw = trset_raw[(args.level-1)*tesize: args.level*tesize]
462 | trset.data = trset_raw
463 | else:
464 | print('Contrastive on cifar10 test set')
465 | else:
466 | trset = torchvision.datasets.CIFAR10(root=args.dataroot,
467 | train=True, download=True, transform=tr_transforms)
468 | print('Cifar10 training set')
469 |
470 | elif args.dataset == 'cifar100' or args.dataset == 'cifar100OOD':
471 | if hasattr(args, 'ssl') and args.ssl == 'contrastive':
472 | trset = torchvision.datasets.CIFAR100(root=args.dataroot,
473 | train=True, download=True,
474 | transform=TwoCropTransform(simclr_transforms, te_transforms))
475 | if hasattr(args, 'corruption') and args.corruption in common_corruptions:
476 | print('Contrastive on %s level %d' %(args.corruption, args.level))
477 | tesize = 10000
478 | trset_raw = np.load(args.dataroot + '/CIFAR-100-C/%s.npy' %(args.corruption))
479 | trset_raw = trset_raw[(args.level-1)*tesize: args.level*tesize]
480 | trset.data = trset_raw
481 | else:
482 | print('Contrastive on cifar100 training set')
483 | else:
484 | trset = torchvision.datasets.CIFAR100(root=args.dataroot,
485 | train=True, download=True, transform=tr_transforms)
486 | print('Cifar100 training set')
487 | else:
488 | raise Exception('Dataset not found!')
489 |
490 | if not hasattr(args, 'workers') or args.workers < 2:
491 | pin_memory = False
492 | else:
493 | pin_memory = True
494 |
495 | if num_sample and num_sample < trset.data.shape[0]:
496 | trset.data = trset.data[:num_sample]
497 | print("Truncate the training set to {:d} samples".format(num_sample))
498 |
499 | trloader = torch.utils.data.DataLoader(trset, batch_size=args.batch_size,
500 | shuffle=True, num_workers=args.workers,
501 | worker_init_fn=seed_worker, pin_memory=pin_memory, drop_last=False)
502 | return trset, trloader
503 |
--------------------------------------------------------------------------------
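Taken together, the wrappers above assemble one shuffled test stream in which strong-OOD samples are marked only by labels shifted out of the in-distribution range. Below is a hypothetical usage sketch, assuming the CIFAR-10-C numpy files and MNIST have been placed under `./data`; the field names mirror the shell scripts, but the real parser and defaults live in `TEST.py`/`OURS.py`:

```python
from argparse import Namespace
from utils.prepare_dataset import prepare_test_data

# Assumed fields; check the argparse setup in TEST.py for the actual defaults.
args = Namespace(
    dataset='cifar10OOD', dataroot='./data',
    corruption='snow', level=5,            # CIFAR-10-C corruption type and severity
    strong_OOD='MNIST', strong_ratio=1.0,  # strong-OOD source appended to the stream
    batch_size=256, workers=4,
)

teset, teloader = prepare_test_data(args)
inputs, targets = next(iter(teloader))
images, index = inputs                 # every wrapper yields [image, dataset-local index]
is_strong_ood = targets >= 1000        # strong-OOD labels carry the +1000 shift
print(images.shape, is_strong_ood.float().mean().item())
```

Because `ConcatDataset` simply chains the two wrapped datasets, the returned `index` is local to whichever sub-dataset a sample came from; the label shift is the only reliable strong-OOD marker in the stream.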