├── NC-TTT.png ├── utils ├── visdatest.py ├── create_model.py ├── utils.py └── prepare_dataset.py ├── README.md ├── models ├── ResNet.py ├── simplenet_cifar.py └── simplenet.py ├── configuration.py ├── joint_training.py └── adapt.py /NC-TTT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GustavoVargasHakim/NCTTT/HEAD/NC-TTT.png -------------------------------------------------------------------------------- /utils/visdatest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | from typing import Callable, Optional 5 | 6 | 7 | class VisdaTest(Dataset): 8 | def __init__(self, root: str, transforms: Optional[Callable] = None): 9 | self.root = root 10 | self.transforms = transforms 11 | self.img_list = np.loadtxt(root + 'image_list.txt', dtype=str) 12 | 13 | def __len__(self): 14 | return self.img_list.shape[0] 15 | 16 | def __getitem__(self, idx): 17 | name = self.img_list[idx][0] 18 | label = int(self.img_list[idx][1]) 19 | 20 | img = Image.open(self.root + 'test/' + name) 21 | if self.transforms is not None: 22 | img = self.transforms(img) 23 | 24 | return img, label -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NC-TTT 2 | 3 | Official repository of the CVPR 2024 paper "NC-TTT: A Noise Constrastive Approach for Test-Time Training", by David Osowiechi, Gustavo A. Vargas Hakim, Mehrdad Noori, Milad Cheraghalikhani, Ali Bahri, Moslem Yazdanpanah, Ismail Ben Ayed, and Christian Desrosiers. 4 | The whole article can be found [here](https://openaccess.thecvf.com/content/ICCV2023/html/****html). 5 | This work was greatly inspired by the code in [ClusT3]([(https://github.com/dosowiechi/ClusT3.git)]). 
6 | 7 | We propose a novel unsupervised TTT technique based on the discrimination of noisy feature maps. By learning to classify noisy views of projected feature maps, and then adapting the model accordingly on new domains, classification performance can be recovered by an important margin. 8 | 9 | ![Diagram](https://github.com/GustavoVargasHakim/NCTTT/blob/master/NC-TTT.png) 10 | 11 | ## Datasets 12 | 13 | The experiments utilize the CIFAR-10 training split as the source dataset. It can be downloaded from 14 | [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz), or can also be done using torchvision 15 | datasets: `train_data = torchvision.datasets.CIFAR10(root='Your/Path/To/Data', train=True, download=True)`. 16 | The same line of code can be used to load the data if it is already downloaded, just by changing the 17 | argument `download` to `False`. 18 | 19 | At test-time, we use CIFAR-10-C and CIFAR-10-new. The first one can be downloaded from [CIFAR-10-C]( 20 | https://zenodo.org/record/2535967#.YzHFMXbMJPY). For the second one, please download the files 21 | `cifar10.1_v6_data.npy` and `cifar10.1_v6_labels.npy` from [CIFAR-10-new](https://github.com/modestyachts/CIFAR-10.1/tree/master/datasets). 22 | All the data should be placed in a common folder from which they can be loaded, e.g., `/datasets/`. 23 | 24 | The training works the same way on CIFAR-100 dataset and it can be downloaded from [CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz). 25 | At test-time, we use CIFAR-100-C which can be downloaded from [CIFAR-100-C](https://zenodo.org/record/3555552/files/CIFAR-100-C.tar?download=1). 26 | 27 | ## Citation 28 | 29 | If you found this repository, or its related paper useful for your research, you can cite this work as: 30 | 31 | ``` 32 | @inproceedings{NCTTT2024, 33 | title={NC-TTT: A Noise Constrastive Approach for Test-Time Training}, 34 | author={David Osowiechi and Gustavo A. 
Vargas Hakim and Mehrdad Noori and Milad Cheraghalikhani and Ali Bahri and Moslem Yazdanpanah and Ismail Ben Ayed and Christian Desrosiers}, 35 | booktitle={***}, 36 | pages={}, 37 | month={June}, 38 | year={2024} 39 | } 40 | ``` 41 | 42 | -------------------------------------------------------------------------------- /utils/create_model.py: -------------------------------------------------------------------------------- 1 | from models import ResNet, simplenet, simplenet_cifar 2 | import torch.nn as nn 3 | import torch 4 | import types 5 | import timm 6 | import os 7 | from utils import utils 8 | 9 | def model_sizes(args, layer): 10 | if args.dataset in ['imagenet', 'visdaC']: 11 | if layer == 0: 12 | channels, resolution = 64, 112 13 | if layer == 1: 14 | channels, resolution = 256, 56 15 | if layer == 2: 16 | channels, resolution = 512, 28 17 | if layer == 3: 18 | channels, resolution = 1024, 14 19 | if layer == 4: 20 | channels, resolution = 2048, 7 21 | 22 | elif args.dataset == 'cifar10' or args.dataset == 'cifar100': 23 | if layer == 0: 24 | channels, resolution = 64, 32 25 | if layer == 1: 26 | channels, resolution = 256, 32 27 | if layer == 2: 28 | channels, resolution = 512, 16 29 | if layer == 3: 30 | channels, resolution = 1024, 8 31 | if layer == 4: 32 | channels, resolution = 2048, 4 33 | 34 | return channels, resolution 35 | 36 | 37 | def get_part(model,layer): 38 | if layer == 1: 39 | extractor = [model.net.conv1, model.net.bn1, nn.ReLU(inplace=True), model.net.layer1] 40 | elif layer == 2: 41 | extractor = [model.net.conv1, model.net.bn1, nn.ReLU(inplace=True), model.net.layer1, model.net.layer2] 42 | elif layer == 3: 43 | extractor = [model.net.conv1, model.net.bn1, nn.ReLU(inplace=True), model.net.layer1, model.net.layer2, model.net.layer3] 44 | elif layer == 4: 45 | extractor = [model.net.conv1, model.net.bn1, nn.ReLU(inplace=True), model.net.layer1, model.net.layer2, model.net.layer3, model.net.layer4] 46 | return nn.Sequential(*extractor) 47 
| 48 | 49 | class ExtractorHead(nn.Module): 50 | def __init__(self, net, head): 51 | super(ExtractorHead, self).__init__() 52 | self.net = net 53 | self.head = head 54 | 55 | def forward(self, x, features=False, train=True, use_entropy=False, **kwargs): 56 | entropy = utils.Entropy() 57 | loss_ent = 0.0 58 | out, feature = self.net(x, feature=True) 59 | ssh_loss = self.head(feature, train=train) 60 | if use_entropy: 61 | loss_ent = entropy(out) 62 | if features: 63 | return out, feature, ssh_loss + loss_ent 64 | else: 65 | return out, ssh_loss + loss_ent 66 | 67 | def visda_forward(self, x, feature=True): 68 | features = [] 69 | x = self.conv1(x) 70 | x = self.bn1(x) 71 | x = self.act1(x) 72 | x = self.maxpool(x) 73 | features.append(x) 74 | 75 | x = self.layer1(x) 76 | features.append(x) 77 | x = self.layer2(x) 78 | features.append(x) 79 | x = self.layer3(x) 80 | features.append(x) 81 | x = self.layer4(x) 82 | features.append(x) 83 | x = self.global_pool(x) 84 | x = x.view(x.size(0), -1) 85 | x = self.fc(x) 86 | if feature: 87 | return x, features 88 | else: 89 | return x 90 | 91 | # This is the modified forward_features method from the timm model (special case for CIFAR-10/100) 92 | # Ignore the error, as the function checkpoint_seq is out of context, but is correct inside timm model 93 | def create_model(args, device='cpu', weights=None): 94 | func_type = types.MethodType 95 | # Creating model based on dataset 96 | if args.dataset == 'visdaC': 97 | num_classes = 12 98 | net = timm.create_model('resnet50', features_only=True, pretrained=False) 99 | net.global_pool = nn.AdaptiveAvgPool2d((1, 1)) 100 | net.fc = nn.Linear(2048, num_classes) 101 | net.forward = func_type(visda_forward, net) 102 | weights_path = os.path.join(args.root, 'weights', 'resnet50.pth') 103 | pretraining = torch.load(weights_path) 104 | del pretraining['fc.weight'] 105 | del pretraining['fc.bias'] 106 | net.load_state_dict(pretraining, strict=False) 107 | ssh = 
simplenet.SimpleNet(args.layers, args.embed_size, std1=args.std, std2=args.std2, dataset=args.dataset, 108 | device=device).to(device) 109 | elif args.dataset in ['cifar10','cifar100']: 110 | num_classes = 10 if args.dataset == 'cifar10' else 100 111 | net = ResNet.resnet50(num_classes) 112 | ssh = simplenet_cifar.SimpleNet(args.layers, args.embed_size, std1=args.std, std2=args.std2, hidden=args.hidden, 113 | dataset=args.dataset, device=device).to(device) 114 | model = ExtractorHead(net, ssh) 115 | 116 | # Loading weights 117 | if weights is not None: 118 | model.load_state_dict(weights, strict=False) 119 | 120 | return model 121 | -------------------------------------------------------------------------------- /models/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def get_part(model,layer): 7 | if layer ==1: 8 | extractor = [model.conv1, model.bn1, nn.ReLU(inplace=True), model.layer1] 9 | elif layer ==2: 10 | extractor = [model.conv1, model.bn1, nn.ReLU(inplace=True), model.layer1, model.layer2] 11 | elif layer ==3: 12 | extractor = [model.conv1, model.bn1, nn.ReLU(inplace=True), model.layer1, model.layer2, model.layer3] 13 | elif layer ==4: 14 | extractor = [model.conv1, model.bn1, nn.ReLU(inplace=True), model.layer1, model.layer2, model.layer3, model.layer4] 15 | return nn.Sequential(*extractor) 16 | 17 | class BasicBlock(nn.Module): 18 | expansion = 1 19 | 20 | def __init__(self, in_planes, planes, stride=1): 21 | super(BasicBlock, self).__init__() 22 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | 27 | self.shortcut = nn.Sequential() 28 | if stride != 1 or in_planes != self.expansion*planes: 29 | 
self.shortcut = nn.Sequential( 30 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 31 | nn.BatchNorm2d(self.expansion*planes) 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1): 46 | super(Bottleneck, self).__init__() 47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn2 = nn.BatchNorm2d(planes) 51 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 52 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 53 | 54 | self.shortcut = nn.Sequential() 55 | if stride != 1 or in_planes != self.expansion*planes: 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 58 | nn.BatchNorm2d(self.expansion*planes) 59 | ) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = F.relu(self.bn2(self.conv2(out))) 64 | out = self.bn3(self.conv3(out)) 65 | out += self.shortcut(x) 66 | out = F.relu(out) 67 | return out 68 | 69 | 70 | class ResNet(nn.Module): 71 | def __init__(self, block, num_blocks, num_classes=10): 72 | super(ResNet, self).__init__() 73 | self.in_planes = 64 74 | 75 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 76 | self.bn1 = nn.BatchNorm2d(64) 77 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) #64 78 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) #128 79 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) #256 80 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 
#512 81 | self.fc = nn.Linear(512*block.expansion, num_classes) 82 | 83 | def _make_layer(self, block, planes, num_blocks, stride): 84 | strides = [stride] + [1]*(num_blocks-1) 85 | layers = [] 86 | for stride in strides: 87 | layers.append(block(self.in_planes, planes, stride)) 88 | self.in_planes = planes * block.expansion 89 | return nn.Sequential(*layers) 90 | 91 | def forward(self, x, feature=False): 92 | features = [] 93 | out = F.relu(self.bn1(self.conv1(x))) 94 | features.append(out) 95 | out = self.layer1(out) 96 | features.append(out) 97 | out = self.layer2(out) 98 | features.append(out) 99 | out = self.layer3(out) 100 | features.append(out) 101 | out = self.layer4(out) 102 | out = F.avg_pool2d(out, 4) 103 | features.append(out) 104 | out = out.view(out.size(0), -1) 105 | out = self.fc(out) 106 | if feature: 107 | return out, features 108 | else: 109 | return out 110 | 111 | def resnet50(num_classes = 10, **kwargs): 112 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, **kwargs) 113 | return model 114 | -------------------------------------------------------------------------------- /configuration.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def argparser(): 4 | parser = argparse.ArgumentParser() 5 | 6 | #Directories 7 | parser.add_argument('--root', type=str, default='/home/vhakim/scratch/Projects/NCTTT/', help='Base path') 8 | parser.add_argument('--dataroot', type=str, default='/home/davidoso/Documents/Data/') 9 | parser.add_argument('--save', type=str, default='work/', help='Path for base training weights') 10 | parser.add_argument('--save-iter', type=str, default='work/', help='Path for base training weights') 11 | 12 | #General settings 13 | parser.add_argument('--seed', type=int, default=42, help='Random seed') 14 | parser.add_argument('--print-freq', type=int, default=10, help='Number of epochs to print progress') 15 | 16 | #Model 17 | 
parser.add_argument('--model', type=str, default='resnet50') 18 | parser.add_argument('--eval', action='store_true', help='Using Eval at training') 19 | parser.add_argument('--load', type=str) 20 | 21 | #SimpleNet 22 | parser.add_argument('--patchsize', type=int, default=3, help='Patching size') 23 | parser.add_argument('--patchstride', type=int, default=1, help='Patching stride') 24 | parser.add_argument('--layers', type=int, nargs='+', default=[1], help='Layer blocks to put additional modules on (very common in TTT methods)') 25 | parser.add_argument('--embed-size', type=int, default=1536, help='Embedding size') 26 | parser.add_argument('--th-dsc', type=float, default=0.5, help='Threshold for discriminator') 27 | parser.add_argument('--std', type=float, default=0.1, help='Noise standard deviation') 28 | parser.add_argument('--std2', type=float, default=0.015, help='Noise standard deviation') 29 | parser.add_argument('--hidden', type=int, default=4, help='Noise standard deviation') 30 | 31 | #Dataset 32 | parser.add_argument('--dataset', type=str, default='cifar10', choices=('cifar10', 'cifar100', 'imagenet', 'visdaC')) 33 | parser.add_argument('--target', type=str, default='cifar10') 34 | parser.add_argument('--workers', type=int, default=6, help='Number of workers for dataloader') 35 | 36 | #Source training 37 | parser.add_argument('--method', type=str, default='original', help='Type of task', choices=('original', 'margin', 'margin2', 'margin3', 'margin4','margin5','margin6','margin7')) 38 | parser.add_argument('--epochs', type=int, default=100, help='Number of base training epochs') 39 | parser.add_argument('--start-epoch', type=int, default=0, help='Manual epoch number for restarts') 40 | parser.add_argument('--batch-size', type=int, default=128, help='Batch size for base training') 41 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') 42 | parser.add_argument('--weight', type=float, default=1.0, help='Weight for loss function') 43 
| parser.add_argument('--full', action='store_true', help='To use all the features') 44 | parser.add_argument('--separate', action='store_true', help='To use all the features') 45 | parser.add_argument('--optimizer', default='sgd', type=str, help='Optimizer to use', choices=('sgd', 'adam')) 46 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for optimizer') 47 | parser.add_argument('--weight-decay', type=float, default=1e-4, help='Weight decay for optimizer') 48 | parser.add_argument('--evaluate', action='store_true', help='Evaluating on evaluation set') 49 | parser.add_argument('--resume', default='', type=str, help='Path to latest checkpoint') 50 | parser.add_argument('--lbd', type=float, default=1.0, help='Lambda') 51 | 52 | #Test-Time Adaptation 53 | parser.add_argument('--adapt', action='store_true', help='To adapt or not') 54 | parser.add_argument('--source', action='store_true', help='To use source training') 55 | parser.add_argument('--use-entropy', action='store_true', help='To use entropy loss at test-time') 56 | parser.add_argument('--level', default=5, type=int) 57 | parser.add_argument('--corruption', default='gaussian_noise') 58 | parser.add_argument('--val-times', default=1, type=int) 59 | parser.add_argument('--split', default='eval', type=str, help='To use the evaluation set or the test set from VisdaC') 60 | parser.add_argument('--adapt-lr', default=0.00001, type=float) 61 | parser.add_argument('--optim', default='adam', type=str, help='Optimizer to use', choices=('sgd', 'adam')) 62 | parser.add_argument('--niter', default=50, type=int) 63 | parser.add_argument('--best', action='store_true', help='Using best pretraining weights or not') 64 | parser.add_argument('--K', type=int, default=10, help='Num of classes') 65 | parser.add_argument('--use-mean', action='store_true', help='Use mean to stop iterate') 66 | parser.add_argument('--two-std', action='store_true', help='Using two noises to create true and fake samples') 67 
| 68 | #Distributed 69 | parser.add_argument('--distributed', action='store_true', help='Activate distributed training') 70 | parser.add_argument('--init-method', type=str, default='tcp://127.0.0.1:3456', help='url for distributed training') 71 | parser.add_argument('--dist-backend', default='gloo', type=str, help='distributed backend') 72 | parser.add_argument('--world-size', type=int, default=1, help='Number of nodes for training') 73 | 74 | return parser.parse_args() 75 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import random 5 | import torch.nn.functional as F 6 | import torchvision 7 | 8 | def save_checkpoint(state, is_best, args): 9 | string = '_' 10 | if args.separate: 11 | separate = '_separate' 12 | else: 13 | separate = '' 14 | for layer in args.layers: 15 | string += str(layer) 16 | if is_best: 17 | torch.save(state, args.save + args.dataset + '_' + args.model + '_' + args.optimizer + '_std1' + str(args.std) + '_std2' + str(args.std2) + '_lr' + str(args.lr) + string + separate + '_best.pth') 18 | else: 19 | torch.save(state, args.save + args.dataset + '_' + args.model + '_' + args.optimizer + '_std1' + str(args.std) + '_std2' + str(args.std2) + '_lr' + str(args.lr) + string + separate + '.pth') 20 | 21 | def get_path(args, is_best=False): 22 | if args.separate: 23 | separate = '_separate' 24 | else: 25 | separate = '' 26 | if args.source: 27 | path = args.dataset + '_source.pth' 28 | else: 29 | string = '_' 30 | for layer in args.layers: 31 | string += str(layer) 32 | path = args.dataset + '_' + args.model + '_' + args.optimizer+ '_std1' + str(args.std) + '_std2' + str(args.std2) + '_lr' + str(args.lr) + string + separate 33 | if is_best: 34 | path += '_best.pth' 35 | else: 36 | path += '.pth' 37 | return path 38 | 39 | class AverageMeter(object): 40 
| """Computes and stores the average and current value""" 41 | def __init__(self, name, fmt=':f'): 42 | self.name = name 43 | self.fmt = fmt 44 | self.reset() 45 | 46 | def reset(self): 47 | self.val = 0 48 | self.avg = 0 49 | self.sum = 0 50 | self.count = 0 51 | 52 | def update(self, val, n=1): 53 | self.val = val 54 | self.sum += val * n 55 | self.count += n 56 | self.avg = self.sum / self.count 57 | 58 | def __str__(self): 59 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 60 | return fmtstr.format(**self.__dict__) 61 | 62 | def accuracy(output, target, topk=(1,)): 63 | """Computes the accuracy over the k top predictions for the specified values of k""" 64 | with torch.no_grad(): 65 | maxk = max(topk) 66 | batch_size = target.size(0) 67 | 68 | _, pred = output.topk(maxk, 1, True, True) 69 | pred = pred.t() 70 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 71 | 72 | res = [] 73 | for k in topk: 74 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 75 | res.append(correct_k.mul_(100.0 / batch_size)) 76 | return res 77 | 78 | '''--------------------Adaptation Function-----------------------------''' 79 | def adapt_batch(net, niter, inputs, opt, iterations, save_iter, use_mean = False, train=False, **kwargs): 80 | entropy = kwargs['entropy'] 81 | net.inference = False 82 | mean_global = 0.023 83 | net.train() 84 | for iteration in range(niter): 85 | _, loss = net(inputs, train=train, use_entropy=entropy) 86 | loss.backward() 87 | opt.step() 88 | opt.zero_grad(set_to_none=True) 89 | if iteration + 1 in iterations and not use_mean: 90 | weights = {'weights': net.state_dict()} 91 | torch.save(weights, save_iter + 'weights_iter_' + str(iteration+1) +'.pkl') 92 | if use_mean: 93 | net.eval() 94 | _, scores = net(inputs, train=True) 95 | mean_scores = np.mean(scores) 96 | if mean_scores <= mean_global: 97 | return iteration 98 | break 99 | net.train() 100 | net.eval() 101 | 102 | '''--------------------Testing 
Function-----------------------------''' 103 | def test_batch(net, inputs, labels, source=False, full=False, **kwargs): 104 | net.eval() 105 | net.inference = False 106 | with torch.no_grad(): 107 | if source: 108 | outputs = net(inputs) 109 | else: 110 | outputs, _ = net(inputs, train=False) 111 | acc = accuracy(outputs, labels) 112 | predicted = torch.argmax(outputs, dim=1) 113 | correctness = predicted.eq(labels).cpu() 114 | return correctness, acc 115 | 116 | '''-------------------Loss Functions----------------------------------''' 117 | class Entropy(torch.nn.Module): 118 | def __init__(self): 119 | super(Entropy, self).__init__() 120 | 121 | def forward(self, x): 122 | return -(x.softmax(0)*x.log_softmax(0)).sum(0).mean() 123 | 124 | '''-------------------Getting Adapters Parameters---------------------''' 125 | def get_parameters(layers, model): 126 | parameters = [] 127 | if layers[0] is not None: 128 | parameters += list(model.mask1.parameters()) 129 | if layers[1] is not None: 130 | parameters += list(model.mask2.parameters()) 131 | if layers[2] is not None: 132 | parameters += list(model.mask3.parameters()) 133 | if layers[3] is not None: 134 | parameters += list(model.mask4.parameters()) 135 | return parameters 136 | 137 | def extractor_from_layer2(net): 138 | layers = [net.conv1, net.bn1, nn.ReLU(inplace=True), net.layer1, net.layer2] 139 | return nn.Sequential(*layers) 140 | 141 | def neg_log_likelihood_2d(target, z, log_det): 142 | log_likelihood_per_dim = target.log_prob(z) + log_det 143 | return -log_likelihood_per_dim.mean() 144 | 145 | def entropy_energy(Y, unary, pairwise, bound_lambda): 146 | E = (unary * Y - bound_lambda * pairwise * Y + Y * torch.log(Y.clip(1e-20))).sum() 147 | return E 148 | -------------------------------------------------------------------------------- /models/simplenet_cifar.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | import torch.nn as nn 4 | 
import torch.nn.functional as F 5 | from torch.distributions.multivariate_normal import MultivariateNormal 6 | import scipy.ndimage as ndimage 7 | from scipy.stats import multivariate_normal 8 | import numpy as np 9 | 10 | def get_shapes(layer, dataset='cifar10'): 11 | if dataset == 'visdaC': 12 | if layer == 1: 13 | return 256, 56 14 | if layer == 2: 15 | return 512, 28 16 | if layer == 3: 17 | return 1024, 14 18 | if layer == 4: 19 | return 2048, 7 20 | elif dataset in ['cifar10','cifar100']: 21 | if layer == 1: 22 | return 256, 32 23 | if layer == 2: 24 | return 512, 16 25 | if layer == 3: 26 | return 1024, 8 27 | if layer == 4: 28 | return 2048, 1 29 | 30 | def conv_out_size(c, k): 31 | return (c + k) + 1 32 | 33 | def score(x): 34 | was_numpy = False 35 | if isinstance(x, np.ndarray): 36 | was_numpy = True 37 | x = torch.from_numpy(x) 38 | while x.ndim > 2: 39 | x = torch.max(x, dim=-1).values 40 | if x.ndim == 2: 41 | x = torch.max(x, dim=1).values 42 | if was_numpy: 43 | return x.numpy() 44 | return x 45 | 46 | def init_weight(m): 47 | if isinstance(m, torch.nn.Linear): 48 | torch.nn.init.xavier_normal_(m.weight) 49 | elif isinstance(m, torch.nn.Conv2d): 50 | torch.nn.init.xavier_normal_(m.weight) 51 | 52 | class Projection(torch.nn.Module): 53 | def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0, stride=1): 54 | super(Projection, self).__init__() 55 | 56 | if out_planes is None: 57 | out_planes = in_planes 58 | self.layers = torch.nn.Sequential() 59 | _in = None 60 | _out = None 61 | for i in range(n_layers): 62 | _in = in_planes if i == 0 else _out 63 | _out = out_planes 64 | if stride == 2: 65 | self.layers.add_module(f"{i}fc", 66 | torch.nn.Conv2d(_in, _out, kernel_size=1, stride=2)) 67 | else: 68 | self.layers.add_module(f"{i}fc", 69 | torch.nn.Conv2d(_in, _out, kernel_size=1)) 70 | if i < n_layers - 1: 71 | # if layer_type > 0: 72 | # self.layers.add_module(f"{i}bn", 73 | # torch.nn.BatchNorm1d(_out)) 74 | if layer_type > 1: 75 | 
self.layers.add_module(f"{i}relu", 76 | torch.nn.LeakyReLU(.2)) 77 | self.apply(init_weight) 78 | 79 | def forward(self, x): 80 | 81 | # x = .1 * self.layers(x) + x 82 | x = self.layers(x) 83 | return x 84 | 85 | class Discriminator(torch.nn.Module): 86 | def __init__(self, in_planes, n_layers=1, hidden=None): 87 | super(Discriminator, self).__init__() 88 | 89 | _hidden = in_planes if hidden is None else hidden 90 | self.body = torch.nn.Sequential() 91 | for i in range(n_layers - 1): 92 | _in = in_planes if i == 0 else _hidden 93 | _hidden = int(_hidden // 1.5) if hidden is None else hidden 94 | self.body.add_module('block%d' % (i + 1), 95 | torch.nn.Sequential( 96 | torch.nn.Linear(_in, _hidden), 97 | torch.nn.BatchNorm1d(_hidden), 98 | torch.nn.LeakyReLU(0.2) 99 | )) 100 | self.tail = torch.nn.Linear(_hidden, 1, bias=False) 101 | self.apply(init_weight) 102 | 103 | def forward(self,x): 104 | x = self.body(x) 105 | x = self.tail(x) 106 | return x 107 | 108 | class SimpleNet(nn.Module): 109 | def __init__(self, layer, embed_dimension, std1=0.015, std2=0.05, hidden = 4, dataset='cifar10', device='cpu'): 110 | super(SimpleNet, self).__init__() 111 | self.device = device 112 | self.layer = layer[0] 113 | self.std1 = std1 114 | self.std2 = std2 115 | self.projector = [] 116 | in_planes = get_shapes(self.layer, dataset)[0] 117 | self.projector = Projection(in_planes, embed_dimension, 1, 0).to(self.device) 118 | self.discriminator = Discriminator(embed_dimension*16*16, n_layers=2, hidden=hidden) 119 | 120 | def forward(self, features, train=False): 121 | # Getting features 122 | features = features[self.layer] 123 | 124 | # Projection (adapter) 125 | true_feats = self.projector(features).reshape(len(features), -1) 126 | crossentropy = nn.BCELoss().to(self.device) 127 | 128 | if train: 129 | # Noise addition 130 | D = true_feats.shape[1] 131 | noise_idxs = torch.randint(0, 1, torch.Size([true_feats.shape[0]])) 132 | noise_one_hot = torch.nn.functional.one_hot(noise_idxs, 
num_classes=1).to(self.device) 133 | if self.std1 != 0: 134 | N1 = torch.normal(0, self.std1, true_feats.shape).to(self.device) 135 | N1 = (N1 * noise_one_hot.unsqueeze(-1)).sum(1) 136 | N2 = torch.normal(0, self.std2, true_feats.shape).to(self.device) 137 | N2 = (N2 * noise_one_hot.unsqueeze(-1)).sum(1) 138 | if self.std1 != 0: 139 | N = torch.cat([N1, N2], dim=0) 140 | z = 0.5 * ((1 / self.std2) ** 2 - (1 / (self.std1 + 1e-6)) ** 2) * (N * N).sum(dim=1) - D * np.log( 141 | self.std1 / self.std2) 142 | Y = 1 / (1 + torch.exp(-z)) 143 | Y = torch.where(Y < 1e-6, 0, Y) 144 | X1 = true_feats + N1.to(self.device) 145 | else: 146 | X1 = true_feats 147 | Y1 = torch.zeros(X1.shape[0]).to(self.device) 148 | Y2 = torch.ones(N2.shape[0]).to(self.device) 149 | Y = torch.cat([Y1, Y2]).float() 150 | X2 = true_feats + N2.to(self.device) 151 | X = torch.cat([X1, X2], dim=0) 152 | scores = self.discriminator(X) 153 | scores = F.sigmoid(scores) 154 | loss = crossentropy(scores.squeeze(), Y.to(self.device)) 155 | return loss 156 | 157 | else: 158 | scores = -self.discriminator(true_feats) 159 | scores = F.sigmoid(scores.squeeze(1)) 160 | Y = torch.ones_like(scores).to(self.device) 161 | loss = crossentropy(scores, Y) 162 | 163 | return loss -------------------------------------------------------------------------------- /joint_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import torch.distributed as dist 6 | import torch.backends.cudnn as cudnn 7 | import math 8 | 9 | import configuration 10 | from utils import create_model, prepare_dataset, utils 11 | 12 | best_uns = math.inf 13 | lbd = 1 14 | 15 | def main(args): 16 | global best_uns 17 | global lbd 18 | 19 | #-----------------DISTRIBUTED TRAINING------------------------------------------------------------------------------ 20 | ngpus_per_node = torch.cuda.device_count() 21 | local_rank = 
int(os.environ.get("SLURM_LOCALID")) 22 | rank = int(os.environ.get("SLURM_NODEID")) * ngpus_per_node + local_rank 23 | current_device = local_rank 24 | torch.cuda.set_device(current_device) 25 | if rank == 0: 26 | print('From Rank: {}, ==> Initializing Process Group...'.format(rank)) 27 | dist.init_process_group(backend=args.dist_backend, init_method=args.init_method, world_size=args.world_size, 28 | rank=rank) 29 | 30 | args.batch_size = int(args.batch_size / ngpus_per_node) 31 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 32 | 33 | if rank == 0: 34 | print('From Rank: {}, ==> Making model..'.format(rank)) 35 | print('Test on original data') 36 | print('Optimizer: ', args.optimizer) 37 | print('Layers: ', args.layers) 38 | print('Noise std1: ', args.std) 39 | print('Noise std2: ', args.std2) 40 | 41 | #------------------DATALOADERS-------------------------------------------------------------------------------------- 42 | cudnn.benchmark = True 43 | trloader, trsampler, teloader, tesampler = prepare_dataset.prepare_train_data(args) 44 | if args.dataset in ['cifar10', 'cifar100']: 45 | args.corruption = 'original' 46 | teloader, tesampler = prepare_dataset.prepare_test_data(args) 47 | input_size, _ = next(enumerate(teloader))[1] 48 | args.input_size = input_size.size(-1) 49 | 50 | #------------------CREATE MODEL------------------------------------------------------------------------------------- 51 | model = create_model.create_model(args, device = 'cuda').cuda() 52 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[current_device]) 53 | 54 | if args.resume: 55 | checkpoint = torch.load(args.resume) 56 | args.start_epoch = checkpoint['epoch'] 57 | best_acc1 = checkpoint['best_acc1'] 58 | 59 | if args.optimizer == 'sgd': 60 | optimizer = torch.optim.SGD(model.module.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 61 | else: 62 | optimizer = torch.optim.Adam(model.module.parameters(), 
args.lr) 63 | #scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs) 64 | 65 | if args.resume: 66 | optimizer.load_state_dict(checkpoint['optimizer']) 67 | 68 | if rank == 0: 69 | print('\t\tTrain Loss \t\t Sup Loss \t\t SSH Loss \t\t Train Accuracy \t\t Val Loss \t\t Val Accuracy') 70 | 71 | # ------------------JOINT TRAINING---------------------------------------------------------------------------------- 72 | for epoch in range(args.start_epoch, args.epochs): 73 | trsampler.set_epoch(epoch) 74 | tesampler.set_epoch(epoch) 75 | acc_train, loss_train, loss_sup, loss_ssh = train(model, optimizer, trloader, args) 76 | acc_val, loss_val = 0.0, 0.0 77 | if args.evaluate: 78 | acc_val, loss_val = validate(model, teloader, args) 79 | 80 | if rank == 0: 81 | print(('Epoch %d/%d:' % (epoch, args.epochs)).ljust(24) + 82 | '%.2f\t\t%.2f\t\t%.2f\t\t%.2f\t\t%.2f\t\t%.2f' % (loss_train, loss_sup, loss_ssh, acc_train, loss_val, acc_val)) 83 | 84 | if args.evaluate: 85 | is_best = loss_val < best_uns 86 | else: 87 | is_best = False 88 | best_uns = max(loss_val, best_uns) 89 | 90 | if rank == 0: 91 | dict = {'epoch': epoch + 1, 92 | 'arch': args.model, 93 | 'state_dict': model.module.state_dict(), 94 | 'best_uns': best_uns, 95 | 'optimizer': optimizer.state_dict(), 96 | } 97 | utils.save_checkpoint(dict, is_best, args) 98 | 99 | def train(model, optimizer, train_loader, args): 100 | batch_time = utils.AverageMeter('Time', ':6.3f') 101 | data_time = utils.AverageMeter('Data', ':6.3f') 102 | sup_losses = utils.AverageMeter('Sup Loss', ':.4e') 103 | ssh_losses = utils.AverageMeter('SSH Loss', ':.4e') 104 | losses = utils.AverageMeter('SSH Loss', ':.4e') 105 | top1 = utils.AverageMeter('Acc@1', ':6.2f') 106 | 107 | model.train() 108 | end = time.time() 109 | cross_entropy = nn.CrossEntropyLoss() 110 | for i, (images, target) in enumerate(train_loader): 111 | data_time.update(time.time() - end) 112 | 113 | images = images.cuda(non_blocking=True) 114 | target = 
target.cuda(non_blocking=True) 115 | 116 | #Compute output and loss 117 | output, ssh_loss = model(images, train=True) 118 | sup_loss = cross_entropy(output, target) 119 | loss = sup_loss + args.weight*ssh_loss 120 | 121 | #Compute accuracy 122 | acc1 = utils.accuracy(output, target, topk=(1,)) 123 | losses.update(loss.item(), images.size(0)) 124 | sup_losses.update(sup_loss.item(), images.size(0)) 125 | ssh_losses.update(ssh_loss.item(), images.size(0)) 126 | top1.update(acc1[0], images.size(0)) 127 | 128 | # Backward pass 129 | optimizer.zero_grad(set_to_none=True) 130 | loss.backward() 131 | optimizer.step() 132 | 133 | # Measure elapsed time 134 | batch_time.update(time.time() - end) 135 | 136 | return top1.avg, losses.avg, sup_losses.avg, ssh_losses.avg 137 | 138 | 139 | def validate(model, val_loader, args): 140 | batch_time = utils.AverageMeter('Time', ':6.3f') 141 | losses = utils.AverageMeter('Loss', ':.4e') 142 | top1 = utils.AverageMeter('Acc@1', ':6.2f') 143 | 144 | # switch to evaluate mode 145 | model.eval() 146 | cross_entropy = nn.CrossEntropyLoss() 147 | with torch.no_grad(): 148 | end = time.time() 149 | for i, (images, target) in enumerate(val_loader): 150 | images = images.cuda(non_blocking=True) 151 | target = target.cuda(non_blocking=True) 152 | 153 | # compute output 154 | output, ssh_loss = model(images) 155 | loss = cross_entropy(output, target) + ssh_loss 156 | 157 | # measure accuracy and record loss 158 | acc1 = utils.accuracy(output, target, topk=(1,)) 159 | losses.update(loss.item(), images.size(0)) 160 | top1.update(acc1[0], images.size(0)) 161 | 162 | # measure elapsed time 163 | batch_time.update(time.time() - end) 164 | end = time.time() 165 | 166 | return top1.avg, losses.avg 167 | 168 | 169 | if __name__=='__main__': 170 | args = configuration.argparser() 171 | main(args) 172 | -------------------------------------------------------------------------------- /models/simplenet.py: 
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# NOTE(review): the original also imported timm, scipy.ndimage,
# scipy.stats.multivariate_normal and torch.distributions.MultivariateNormal;
# none of them is used anywhere in this (fully visible) module, so the dead
# imports were dropped.


# (channels, spatial size) of each ResNet-50 stage output, per input resolution.
_LAYER_SHAPES = {
    'visdaC': {1: (256, 56), 2: (512, 28), 3: (1024, 14), 4: (2048, 7)},
    'cifar10': {1: (256, 32), 2: (512, 16), 3: (1024, 8), 4: (2048, 1)},
    'cifar100': {1: (256, 32), 2: (512, 16), 3: (1024, 8), 4: (2048, 1)},
}


def get_shapes(layer, dataset='cifar10'):
    """Return (channels, spatial_size) of backbone layer `layer` (1-4).

    Raises ValueError for an unknown dataset/layer pair instead of silently
    returning None as the original did (which crashed later with a cryptic
    'NoneType is not subscriptable').
    """
    try:
        return _LAYER_SHAPES[dataset][layer]
    except KeyError:
        raise ValueError('Unknown dataset/layer combination: %r, layer %r' % (dataset, layer))


def conv_out_size(c, k):
    """Output size of a stride-1, unpadded ('valid') convolution.

    BUG FIX: the standard formula is (input - kernel) + 1; the original
    returned (c + k) + 1.
    """
    return (c - k) + 1


def score(x):
    """Reduce a (B, ..., H, W) activation map to one max value per sample.

    Repeatedly max-reduces the trailing dimension until a (B,) vector of
    per-sample maxima remains. Accepts a torch tensor or a numpy array and
    returns the same type it was given.
    """
    was_numpy = False
    if isinstance(x, np.ndarray):
        was_numpy = True
        x = torch.from_numpy(x)
    while x.ndim > 2:
        x = torch.max(x, dim=-1).values
    if x.ndim == 2:
        x = torch.max(x, dim=1).values
    return x.numpy() if was_numpy else x


def init_weight(m):
    """Xavier-initialize Linear/Conv2d weights (applied via Module.apply)."""
    if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
        torch.nn.init.xavier_normal_(m.weight)


class Projection(torch.nn.Module):
    """Stack of 1x1 convolutions projecting backbone features to `out_planes`.

    `layer_type > 1` inserts a LeakyReLU between intermediate layers;
    `stride == 2` halves the spatial resolution at every layer (any other
    stride value is treated as 1, matching the original behaviour).
    """

    def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0, stride=1):
        super(Projection, self).__init__()
        if out_planes is None:
            out_planes = in_planes
        self.layers = torch.nn.Sequential()
        for i in range(n_layers):
            _in = in_planes if i == 0 else out_planes
            self.layers.add_module(
                f"{i}fc",
                torch.nn.Conv2d(_in, out_planes, kernel_size=1,
                                stride=2 if stride == 2 else 1))
            # Non-linearity only *between* layers, and only for layer_type > 1.
            if i < n_layers - 1 and layer_type > 1:
                self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(.2))
        self.apply(init_weight)

    def forward(self, x):
        return self.layers(x)


class Discriminator(torch.nn.Module):
    """Per-location noisy-vs-clean classifier: 1x1 conv blocks + 1-channel head."""

    def __init__(self, in_planes, n_layers=1, hidden=None):
        super(Discriminator, self).__init__()
        _hidden = in_planes if hidden is None else hidden
        self.body = torch.nn.Sequential()
        for i in range(n_layers - 1):
            _in = in_planes if i == 0 else _hidden
            # Without an explicit width, shrink the channel count by 1.5x per block.
            _hidden = int(_hidden // 1.5) if hidden is None else hidden
            self.body.add_module('block%d' % (i + 1),
                                 torch.nn.Sequential(
                                     torch.nn.Conv2d(_in, _hidden, kernel_size=1),
                                     torch.nn.BatchNorm2d(_hidden),
                                     torch.nn.LeakyReLU(0.2)
                                 ))
        self.tail = torch.nn.Conv2d(_hidden, 1, kernel_size=1)
        self.apply(init_weight)

    def forward(self, x):
        return self.tail(self.body(x))


class SimpleNet(nn.Module):
    """Noise-contrastive TTT head: projector + discriminator over one backbone layer.

    `layer` is a list whose first element selects the backbone stage;
    `std1`/`std2` are the two isotropic noise levels used to build the
    contrastive classification problem.
    NOTE(review): only the projector is moved to `device` here; the
    discriminator relies on the caller moving the whole module — confirm.
    """

    def __init__(self, layer, embed_dimension, std1=0.015, std2=0.05, dataset='cifar10', device='cpu'):
        super(SimpleNet, self).__init__()
        self.device = device
        self.layer = layer[0]
        self.std1 = std1
        self.std2 = std2
        in_planes = get_shapes(self.layer, dataset)[0]
        # (The original first assigned `self.projector = []`, a dead statement
        # immediately overwritten — removed.)
        self.projector = Projection(in_planes, embed_dimension, 1, 0).to(self.device)
        self.discriminator = Discriminator(embed_dimension, n_layers=2, hidden=4)

    def forward(self, features, train=False):
        """Compute the noise-contrastive loss on `features[self.layer]`.

        train=True: add Gaussian noise at two levels to the projected features
        and train the discriminator to tell them apart (targets are the
        analytic posterior when std1 != 0, hard 0/1 labels otherwise).
        train=False: test-time objective pushing the discriminator towards the
        'clean' decision on un-noised features.
        """
        features = features[self.layer]

        # Projection (adapter)
        true_feats = self.projector(features)
        crossentropy = nn.BCELoss().to(self.device)

        if train:
            B = true_feats.shape[0]
            D = true_feats.shape[1]
            H = W = true_feats.shape[2]
            mean = torch.zeros(D).to(self.device)
            I = torch.eye(D).to(self.device)
            # Isotropic Gaussians at the two noise levels. gaussian2/N2 are
            # needed in both branches below, so they are unconditional.
            if self.std1 != 0:
                gaussian1 = torch.distributions.MultivariateNormal(mean, (self.std1 ** 2) * I)
            gaussian2 = torch.distributions.MultivariateNormal(mean, (self.std2 ** 2) * I)
            if self.std1 != 0:
                N1 = gaussian1.sample((B * W * H,)).reshape(B, D, H * W).transpose(2, 1).reshape(B, D, W, H)
            N2 = gaussian2.sample((B * W * H,)).reshape(B, D, H * W).transpose(2, 1).reshape(B, D, W, H)
            # NOTE(review): with a single noise level, randint(0, 1) is always 0
            # and one_hot(num_classes=1) is all ones, so this selector is a
            # no-op; moreover broadcasting the (B, 1, 1, 1, 1) mask against
            # (B, D, W, H) and then .sum(1) sums the noise over the batch
            # dimension. Preserved verbatim — confirm this is intended.
            noise_idxs = torch.randint(0, 1, torch.Size([true_feats.shape[0]]))
            noise_one_hot = torch.nn.functional.one_hot(noise_idxs, num_classes=1).to(self.device)
            N2 = (N2 * noise_one_hot.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)).sum(1)
            if self.std1 != 0:
                N1 = (N1 * noise_one_hot.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)).sum(1)
                # Analytic posterior that a noise sample came from the std2
                # Gaussian, used as a soft target.
                N = torch.cat([N1, N2], dim=0)
                z = 0.5 * ((1 / self.std2) ** 2 - (1 / (self.std1 + 1e-6)) ** 2) * (N * N).sum(dim=1) - D * np.log(
                    self.std1 / self.std2)
                Y = 1 / (1 + torch.exp(-z))
                Y = torch.where(Y < 1e-6, 0, Y)
                X1 = true_feats + N1.to(self.device)
            else:
                # Degenerate case: clean features (label 0) vs std2-noised (label 1).
                X1 = true_feats
                Y1 = torch.zeros((X1.shape[0], X1.shape[2], X1.shape[3])).to(self.device)
                Y2 = torch.ones((N2.shape[0], N2.shape[2], N2.shape[3])).to(self.device)
                Y = torch.cat([Y1, Y2]).float()
            X2 = true_feats + N2.to(self.device)
            X = torch.cat([X1, X2], dim=0)
            scores = self.discriminator(X)
            # F.sigmoid is deprecated; torch.sigmoid is the exact equivalent.
            scores = torch.sigmoid(scores.squeeze(1))
            loss = crossentropy(scores, Y.to(self.device))
            return loss

        # Test-time: push discriminator scores towards 'clean' everywhere.
        scores = torch.sigmoid(-self.discriminator(true_feats).squeeze(1))
        Y = torch.ones_like(scores).to(self.device)
        return crossentropy(scores, Y)
# ------------------------------------------------------------------------------
# /utils/prepare_dataset.py:
# ------------------------------------------------------------------------------
import math
import torch
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
import numpy as np
from torchvision.datasets import ImageFolder
# Explicit import instead of the original `from utils.visdatest import *`.
from utils.visdatest import VisdaTest

# CIFAR normalisation statistics.
NORM = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
te_transforms = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(*NORM)])
tr_transforms = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                    transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize(*NORM)])

# ImageNet-style pipelines for VisDA-C.
visda_train = transforms.Compose([transforms.Resize((256, 256)),
                                  transforms.RandomCrop((224, 224)),
                                  transforms.RandomHorizontalFlip(),
                                  transforms.ToTensor(),
                                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

visda_val = transforms.Compose([transforms.Resize((256, 256)),
                                transforms.CenterCrop((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Heavier augmentation used for Office-Home training.
augment_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.RandomGrayscale(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

office_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


# The 15 CIFAR-10-C / CIFAR-100-C corruption types.
common_corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
                      'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
                      'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']


def prepare_test_data(args):
    """Build the test loader (and distributed sampler) for args.dataset.

    For corrupted CIFAR variants, the clean test set's data array is replaced
    by the requested corruption at severity args.level.
    """
    if args.dataset == 'cifar10':
        tesize = 10000
        if not hasattr(args, 'corruption') or args.corruption == 'original':
            teset = torchvision.datasets.CIFAR10(root=args.dataroot,
                                                 train=False, download=False, transform=te_transforms)
        elif args.corruption in common_corruptions:
            teset_raw = np.load(args.dataroot + '/CIFAR-10-C/%s.npy' % (args.corruption))
            # Each corruption file stacks 5 severity levels of 10k images.
            teset_raw = teset_raw[(args.level - 1) * tesize: args.level * tesize]
            teset = torchvision.datasets.CIFAR10(root=args.dataroot,
                                                 train=False, download=False, transform=te_transforms)
            teset.data = teset_raw
        elif args.corruption == 'cifar_new':
            from utils.cifar_new import CIFAR_New
            teset = CIFAR_New(root=args.dataroot + '/CIFAR10.1', transform=te_transforms)
        else:
            raise Exception('Corruption not found!')

    elif args.dataset == 'cifar100':
        tesize = 10000
        if not hasattr(args, 'corruption') or args.corruption == 'original':
            teset = torchvision.datasets.CIFAR100(root=args.dataroot,
                                                  train=False, download=False, transform=te_transforms)
        elif args.corruption in common_corruptions:
            teset_raw = np.load(args.dataroot + '/CIFAR-100-C/%s.npy' % (args.corruption))
            teset_raw = teset_raw[(args.level - 1) * tesize: args.level * tesize]
            teset = torchvision.datasets.CIFAR100(root=args.dataroot,
                                                  train=False, download=True, transform=te_transforms)
            teset.data = teset_raw
        else:
            # BUG FIX: the original had no else-branch here, so an unknown
            # corruption fell through to an unbound `teset` (NameError).
            # Mirror the cifar10 branch's explicit error.
            raise Exception('Corruption not found!')
    elif args.dataset == 'visdaC':
        teset = VisdaTest(args.dataroot, transforms=visda_val)
    elif args.dataset == 'office':
        teset = ImageFolder(root=args.dataroot + 'OfficeHomeDataset_10072016/' + args.category, transform=office_test)
    else:
        raise Exception('Dataset not found!')

    if args.distributed:
        te_sampler = torch.utils.data.distributed.DistributedSampler(teset)
    else:
        te_sampler = None

    if not hasattr(args, 'workers'):
        args.workers = 1
    if args.distributed:
        teloader = torch.utils.data.DataLoader(teset, batch_size=args.batch_size,
                                               shuffle=(te_sampler is None), num_workers=args.workers,
                                               pin_memory=True, sampler=te_sampler)
    else:
        teloader = torch.utils.data.DataLoader(teset, batch_size=args.batch_size,
                                               shuffle=True, num_workers=args.workers, pin_memory=True)

    return teloader, te_sampler


def prepare_val_data(args):
    """Build the validation loader (VisDA-C only)."""
    if args.dataset == 'visdaC':
        vset = ImageFolder(root=args.dataroot + 'validation/', transform=visda_val)
    else:
        raise Exception('Dataset not found!')

    if args.distributed:
        v_sampler = torch.utils.data.distributed.DistributedSampler(vset)
    else:
        v_sampler = None
    if not hasattr(args, 'workers'):
        args.workers = 1
    vloader = torch.utils.data.DataLoader(vset, batch_size=args.batch_size,
                                          shuffle=(v_sampler is None), num_workers=args.workers,
                                          pin_memory=True, sampler=v_sampler)
    return vloader, v_sampler


def prepare_train_data(args):
    """Build train (and, where defined, validation) loaders and samplers."""
    if args.dataset == 'cifar10':
        trset = torchvision.datasets.CIFAR10(root=args.dataroot,
                                             train=True, download=False, transform=tr_transforms)
        vset = None
    elif args.dataset == 'cifar100':
        trset = torchvision.datasets.CIFAR100(root=args.dataroot,
                                              train=True, download=False, transform=tr_transforms)
        vset = None
    elif args.dataset == 'visdaC':
        dataset = ImageFolder(root=args.dataroot + 'train/', transform=visda_train)
        # Fixed 70/30-style split of the 152,397 VisDA-C training images.
        trset, vset = random_split(dataset, [106678, 45719], generator=torch.Generator().manual_seed(args.seed))
    elif args.dataset == 'office':
        dataset = ImageFolder(root=args.dataroot + 'OfficeHomeDataset_10072016/' + args.category,
                              transform=augment_transforms)
        long = len(dataset)
        trset, vset = random_split(dataset, [math.floor(long * 0.8), math.floor(long * 0.2)],
                                   generator=torch.Generator().manual_seed(args.seed))
    else:
        raise Exception('Dataset not found!')

    if args.distributed:
        tr_sampler = torch.utils.data.distributed.DistributedSampler(trset)
        if args.dataset == 'visdaC':
            v_sampler = torch.utils.data.distributed.DistributedSampler(vset)
        else:
            v_sampler = None
    else:
        tr_sampler = None
        v_sampler = None

    if not hasattr(args, 'workers'):
        args.workers = 1
    trloader = torch.utils.data.DataLoader(trset, batch_size=args.batch_size,
                                           shuffle=(tr_sampler is None), num_workers=args.workers,
                                           pin_memory=True, sampler=tr_sampler, drop_last=True)
    if args.dataset == 'visdaC':
        vloader = torch.utils.data.DataLoader(vset, batch_size=args.batch_size,
                                              shuffle=(v_sampler is None), num_workers=args.workers,
                                              pin_memory=True, sampler=v_sampler, drop_last=True)
    else:
        vloader = None
    return trloader, tr_sampler, vloader, v_sampler
# ------------------------------------------------------------------------------
# /adapt.py:
# ------------------------------------------------------------------------------
import os
import timm
import torch
import torch.backends.cudnn as cudnn
from tqdm import tqdm

import configuration
import numpy as np
from utils import utils, create_model, prepare_dataset
import copy


def experiment(args):
    """Run NC-TTT test-time adaptation and report per-sample transition stats.

    For each test batch the source weights are restored, the model is adapted
    for args.niter steps, and the change in per-sample correctness
    (good->good, good->bad, bad->good, bad->bad) is accumulated. Two modes:
    a fixed-iteration sweep (args.niter in the checkpointed iteration list)
    repeated args.val_times times, or a single adaptive run otherwise.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cudnn.benchmark = True

    '''-------------------Loading Dataset----------------------------'''
    if args.split == 'test' or args.dataset in ['cifar10', 'cifar100']:
        teloader, _ = prepare_dataset.prepare_test_data(args)
    else:
        teloader, _ = prepare_dataset.prepare_val_data(args)
    # Infer the spatial input size from the first batch.
    input_size, _ = next(enumerate(teloader))[1]
    args.input_size = input_size.size(-1)

    '''--------------------Loading Model-----------------------------'''
    print('Loading model')
    print('Dataset: ', args.dataset)
    print('Corruption: ', args.corruption if args.dataset in ['cifar10', 'cifar100'] else 'N/A')
    print('Training optimizer: ', args.optimizer)
    print('Adaptation optimizer: ', args.optim)
    print('Layers: ', args.layers)
    print('Std 1: ', args.std)
    print('Std 2: ', args.std2)

    if args.source:
        # Source-only baseline: plain ResNet-50 without the NC-TTT head.
        model = timm.create_model('resnet50', num_classes=12).cuda()
    else:
        model = create_model.create_model(args, device=device).to(device)
        path = utils.get_path(args, is_best=args.best)
        checkpoint = torch.load(os.path.join(args.root, 'weights', path))
        model.load_state_dict(checkpoint['state_dict'])

    # Pristine source weights, restored before adapting each batch.
    state = copy.deepcopy(model.state_dict())
    print('Number of iterations:', args.niter)

    '''-------------------Optimizer----------------------------------'''
    if args.source:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.adapt_lr)
    else:
        # Only the encoder part up to the deepest adapted layer is updated.
        extractor = create_model.get_part(model, args.layers[-1])
        if args.optim == 'adam':
            optimizer = torch.optim.Adam(extractor.parameters(), lr=args.adapt_lr)
        else:
            optimizer = torch.optim.SGD(extractor.parameters(), lr=args.adapt_lr)

    '''--------------------Test-Time Adaptation----------------------'''
    print('Test-Time Adaptation')
    iteration = [1, 3, 5, 10, 15, 20, 50, 100]
    # (The original also created unused `scores_before`/`scores_after` lists.)
    if args.niter in iteration and not args.use_mean:
        # Fixed-iteration mode: evaluate at every checkpointed step count up
        # to args.niter, repeated args.val_times times.
        validation = args.val_times
        indice = iteration.index(args.niter)
        good_good_V = np.zeros([indice + 1, validation])
        good_bad_V = np.zeros([indice + 1, validation])
        bad_good_V = np.zeros([indice + 1, validation])
        bad_bad_V = np.zeros([indice + 1, validation])
        accuracy_V = np.zeros([indice + 1, validation])
        for val in range(validation):
            good_good = np.zeros([indice + 1, len(teloader.dataset)])
            good_bad = np.zeros([indice + 1, len(teloader.dataset)])
            bad_good = np.zeros([indice + 1, len(teloader.dataset)])
            bad_bad = np.zeros([indice + 1, len(teloader.dataset)])
            correct = np.zeros(indice + 1)
            for batch_idx, (inputs, labels) in tqdm(enumerate(teloader)):
                inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
                # Reset to source weights before adapting on this batch.
                model.load_state_dict(state)
                # BUG FIX: the per-batch optimizer must follow the *adaptation*
                # optimizer choice (args.optim), as in the initial setup above;
                # the original tested args.optimizer (the training optimizer).
                if args.source:
                    optimizer = torch.optim.Adam(model.parameters(), lr=args.adapt_lr)
                else:
                    if args.optim == 'adam':
                        optimizer = torch.optim.Adam(extractor.parameters(), lr=args.adapt_lr)
                    else:
                        optimizer = torch.optim.SGD(extractor.parameters(), lr=args.adapt_lr)
                correctness, _ = utils.test_batch(model, inputs, labels, source=args.source)

                if args.adapt:
                    utils.adapt_batch(model, args.niter, inputs, optimizer, iteration, args.save_iter,
                                      train=False, entropy=args.use_entropy)

                    # Re-evaluate from every intermediate checkpoint saved by
                    # adapt_batch and tally per-sample transitions.
                    for k in range(len(iteration[:indice + 1])):
                        ckpt = torch.load(args.save_iter + 'weights_iter_' + str(iteration[k]) + '.pkl')
                        model.load_state_dict(ckpt['weights'])
                        correctness_new, _ = utils.test_batch(model, inputs, labels, source=args.source)
                        for i in range(len(correctness_new.tolist())):
                            if correctness[i] == True and correctness_new[i] == True:
                                good_good[k, i + batch_idx * args.batch_size] = 1
                            elif correctness[i] == True and correctness_new[i] == False:
                                good_bad[k, i + batch_idx * args.batch_size] = 1
                            elif correctness[i] == False and correctness_new[i] == True:
                                bad_good[k, i + batch_idx * args.batch_size] = 1
                            elif correctness[i] == False and correctness_new[i] == False:
                                bad_bad[k, i + batch_idx * args.batch_size] = 1
                else:
                    # No adaptation: just count baseline correctness.
                    correct += correctness.sum().item()

            for k in range(len(iteration[:indice + 1])):
                correct[k] += np.sum(good_good[k]) + np.sum(bad_good[k])
                accuracy = correct[k] / len(teloader.dataset)
                good_good_V[k, val] = np.sum(good_good[k])
                good_bad_V[k, val] = np.sum(good_bad[k])
                bad_good_V[k, val] = np.sum(bad_good[k])
                bad_bad_V[k, val] = np.sum(bad_bad[k])
                accuracy_V[k, val] = accuracy

        for k in range(len(iteration[:indice + 1])):
            print('--------------------RESULTS----------------------')
            print('Perturbation: ', args.corruption)
            print('Number of iterations: ', iteration[k])
            print('Good first, good after: ', str(good_good_V[k].mean()) + '+/-' + str(good_good_V[k].std()))
            print('Good first, bad after: ', str(good_bad_V[k].mean()) + '+/-' + str(good_bad_V[k].std()))
            print('Bad first, good after: ', str(bad_good_V[k].mean()) + '+/-' + str(bad_good_V[k].std()))
            print('Bad first, bad after: ', str(bad_bad_V[k].mean()) + '+/-' + str(bad_bad_V[k].std()))
            print('Accuracy: ', str(np.round(accuracy_V[k].mean() * 100, 2)) + '+/-' +
                  str(np.round(accuracy_V[k].std() * 100, 2)))

    else:
        # Single-run mode (adaptive stopping / use_mean).
        validation = 1
        good_good_V = np.zeros([1, validation])
        good_bad_V = np.zeros([1, validation])
        bad_good_V = np.zeros([1, validation])
        bad_bad_V = np.zeros([1, validation])
        accuracy_V = np.zeros([1, validation])
        nb_iteration_V = np.zeros([1, validation])
        for val in range(validation):
            good_good = []
            good_bad = []
            bad_good = []
            bad_bad = []
            correct = 0
            nb_iteration = []
            for batch_idx, (inputs, labels) in tqdm(enumerate(teloader)):
                inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
                model.load_state_dict(state)
                # NOTE(review): this call unpacks three values while the
                # post-adaptation call below unpacks two with the same
                # arguments — confirm utils.test_batch's return arity for this
                # argument combination.
                correctness, _, _ = utils.test_batch(model, inputs, labels, source=args.source,
                                                     q1=args.q, q2=1 - args.q, method=args.method)

                if args.adapt:
                    nb_iteration.append(utils.adapt_batch(model, args.niter, inputs, optimizer, iteration,
                                                          args.save_iter, train=False, use_mean=args.use_mean,
                                                          two_std=args.two_std, entropy=args.use_entropy))
                    correctness_new, _ = utils.test_batch(model, inputs, labels, source=args.source,
                                                          q1=args.q, q2=1 - args.q, method=args.method)
                    for i in range(len(correctness_new.tolist())):
                        if correctness[i] == True and correctness_new[i] == True:
                            good_good.append(1)
                        elif correctness[i] == True and correctness_new[i] == False:
                            good_bad.append(1)
                        elif correctness[i] == False and correctness_new[i] == True:
                            bad_good.append(1)
                        elif correctness[i] == False and correctness_new[i] == False:
                            bad_bad.append(1)
                else:
                    # No adaptation: just count baseline correctness.
                    correct += correctness.sum().item()

            correct += np.sum(good_good) + np.sum(bad_good)
            accuracy = correct / len(teloader.dataset)
            good_good_V[0, val] = np.sum(good_good)
            good_bad_V[0, val] = np.sum(good_bad)
            bad_good_V[0, val] = np.sum(bad_good)
            bad_bad_V[0, val] = np.sum(bad_bad)
            accuracy_V[0, val] = accuracy
            nb_iteration_V[0, val] = np.mean(nb_iteration)

        print('--------------------RESULTS----------------------')
        print('Perturbation: ', args.corruption)
        if args.adapt:
            print('Iteration: ', str(nb_iteration_V[0].mean()) + '+/-' + str(nb_iteration_V[0].std()))
        print('Good first, good after: ', str(good_good_V[0].mean()) + '+/-' + str(good_good_V[0].std()))
        print('Good first, bad after: ', str(good_bad_V[0].mean()) + '+/-' + str(good_bad_V[0].std()))
        print('Bad first, good after: ', str(bad_good_V[0].mean()) + '+/-' + str(bad_good_V[0].std()))
        print('Bad first, bad after: ', str(bad_bad_V[0].mean()) + '+/-' + str(bad_bad_V[0].std()))
        print('Accuracy: ', str(np.round(accuracy_V[0].mean() * 100, 2)) + '+/-' +
              str(np.round(accuracy_V[0].std() * 100, 2)))


if __name__ == '__main__':
    args = configuration.argparser()
    experiment(args)
--------------------------------------------------------------------------------