├── models ├── __init__.py └── Res.py ├── utils ├── __init__.py ├── cli_utils.py ├── third_party.py ├── utils.py └── metrics.py ├── dataset ├── __init__.py ├── generate_shifted_sample_indices.py └── selectedRotateImageFolder.py ├── requirements.txt ├── .ipynb_checkpoints └── README-checkpoint.md ├── start.sh ├── start-open.sh ├── sam.py ├── tent.py ├── tent_come.py ├── README.md ├── sar.py ├── sar_come.py ├── eata.py ├── eata_come.py └── main.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | huggingface-hub==0.16.4 2 | numpy==1.23.2 3 | timm==0.9.12 4 | torch==1.13.1 5 | torchvision==0.14.1 6 | pycm==4.0 7 | foolbox==3.3.4 8 | scikit-learn==1.0.2 -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # ImageNet1K Path (if EATA method is included) 3 | data='/path/to/dataset/Imagenet1K' 4 | # ImageNet_C Path (necessary) 5 | data_corruption='/path/to/dataset/ImageNet_C' 6 | # Log File Output Path 7 | output='/path/to/output/result' 8 | # corrupt level of ImageNet_C 9 | level=5 10 | exp_type="normal" 11 | step=1 12 | # backbone model 13 | model="vitbase_timm" 14 | seed=2024 15 | ood_rate=0.0 16 | test_batch_size=64 17 | export CUDA_VISIBLE_DEVICES=1 18 | 19 | run_experiment () { 20 | local method=$1 21 | local scoring_function=$2 22 | local name="experiment_${method}_${model}_ood${ood_rate}_level${level}_seed${seed}_${exp_type}" 23 | echo "Running $name with seed: $seed" 24 | python3 main.py --data $data --data_corruption $data_corruption --output $output \ 25 | --method $method --level $level --exp_type $exp_type --step $step\ 26 | --ood_rate $ood_rate --scoring_function $scoring_function --model $model --seed $seed --test_batch_size $test_batch_size 27 | } 28 | 29 | methods=("no_adapt" "Tent" "EATA" "SAR" "Tent_COME" "EATA_COME" "SAR_COME") 30 | scoring_functions=( "msp" "msp" "msp" "msp" "dirichlet" "dirichlet" "dirichlet" ) 31 | 32 | for i in ${!methods[@]}; do 33 | run_experiment "${methods[$i]}" "${scoring_functions[$i]}" 34 | done 35 | 36 | 37 | -------------------------------------------------------------------------------- /start-open.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # ImageNet1K Path (if EATA method is included) 3 | data='/path/to/dataset/Imagenet1K' 4 | # ImageNet_C Path (necessary) 5 | data_corruption='/path/to/dataset/ImageNet_C' 6 | # Log File Output Path 7 |
output='/path/to/output/result' 8 | # open-world data path 9 | ood_root='/data1/kongxinke/datasets/' 10 | # corrupt level of ImageNet_C 11 | level=3 12 | exp_type="open-world" 13 | step=1 14 | # backbone model 15 | model="vitbase_timm" 16 | seed=2024 17 | ood_rate=0.5 18 | test_batch_size=64 19 | export CUDA_VISIBLE_DEVICES=1 20 | 21 | run_experiment () { 22 | local method=$1 23 | local scoring_function=$2 24 | local name="experiment_${method}_${model}_ood${ood_rate}_level${level}_seed${seed}_${exp_type}" 25 | echo "Running $name with seed: $seed" 26 | python3 main.py --data $data --data_corruption $data_corruption --output $output --ood_root $ood_root\ 27 | --method $method --level $level --exp_type $exp_type --step $step\ 28 | --ood_rate $ood_rate --scoring_function $scoring_function --model $model --seed $seed --test_batch_size $test_batch_size 29 | } 30 | methods=("no_adapt" "Tent" "EATA" "SAR" "Tent_COME" "EATA_COME" "SAR_COME") 31 | scoring_functions=( "msp" "msp" "msp" "msp" "dirichlet" "dirichlet" "dirichlet" ) 32 | 33 | for i in ${!methods[@]}; do 34 | run_experiment "${methods[$i]}" "${scoring_functions[$i]}" 35 | done 36 | -------------------------------------------------------------------------------- /sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class SAM(torch.optim.Optimizer): 5 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): 6 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 7 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 8 | super(SAM, self).__init__(params, defaults) 9 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 10 | self.param_groups = self.base_optimizer.param_groups 11 | self.defaults.update(self.base_optimizer.defaults) 12 | 13 | @torch.no_grad() 14 | def first_step(self, zero_grad=False): 15 | grad_norm = self._grad_norm() 16 | for group in self.param_groups: 17 | scale = group["rho"] / (grad_norm + 1e-12) 18 | for p in group["params"]: 19 | if p.grad is None: continue 20 | self.state[p]["old_p"] = p.data.clone() 21 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 22 | p.add_(e_w) 23 | if zero_grad: self.zero_grad() 24 | 25 | @torch.no_grad() 26 | def second_step(self, zero_grad=False): 27 | for group in self.param_groups: 28 | for p in group["params"]: 29 | if p.grad is None: continue 30 | p.data = self.state[p]["old_p"] 31 | self.base_optimizer.step() 32 | if zero_grad: self.zero_grad() 33 | 34 | @torch.no_grad() 35 | def step(self, closure=None): 36 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 37 | closure = torch.enable_grad()(closure) 38 | self.first_step(zero_grad=True) 39 | closure() 40 | self.second_step() 41 | 42 | def _grad_norm(self): 43 | shared_device = self.param_groups[0]["params"][0].device 44 | norm = torch.norm( 45 | torch.stack([ 46 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 47 | for group in self.param_groups for p in group["params"] 48 | if p.grad is not None 49 | ]), 50 | p=2 51 | ) 52 | return norm 53 | 54 | def load_state_dict(self, state_dict): 55 | super().load_state_dict(state_dict) 56 | self.base_optimizer.param_groups = self.param_groups -------------------------------------------------------------------------------- /dataset/generate_shifted_sample_indices.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | import random 3 | 4 | seed = 2023 5 | 6 | random.seed(seed) 7 | np.random.seed(seed) 8 | 9 | 10 | def monotone_shift_constructor(q1, q2): 11 | def monotone_shift(T): 12 | lamb = 1.0 / (T-1) 13 | return np.concatenate([np.expand_dims(q1 * (1 - lamb * t) + q2 * lamb * t, axis=0) for t in range(T)], axis=0) 14 | return monotone_shift 15 | 16 | 17 | def generate_sample_indices_and_ys(q_all, dataset_name='imagenet1k'): 18 | np.random.seed(2022) 19 | if dataset_name == 'imagenet1k': 20 | num_classes = 1000 21 | tset_length = 50000 22 | num_each_class = 50 23 | elif dataset_name == 'cifar10': 24 | num_classes = 10 25 | tset_length = 10000 26 | num_each_class = 1000 27 | else: 28 | raise NotImplementedError("dataset not supported; only imagenet1k and cifar10 are supported") 29 | ys = np.squeeze(np.asarray([np.random.choice(num_classes, 1, p=q) for q in q_all])) 30 | 31 | print((ys == 3).sum()) 32 | print((ys == 5).sum()) 33 | print((ys == 6).sum()) 34 | 35 | print(q_all[:3,:]) 36 | print(ys[:100]) 37 | 38 | 39 | num_tests = len(ys) 40 | tset_indices = np.array([i for i in range(tset_length)]) 41 | tset_ys = np.array([i // num_each_class for i in range(tset_length)]) 42 | 43 | generated_indices = np.zeros([num_tests]) 44 | 45 | for i in range(num_classes): 46 | num_i = (ys == i).sum() 47 | if num_i == 0: 48 | continue 49 | num_test_i = (tset_ys == i).sum() 50 | sampled_indices = np.random.randint(0, num_test_i, num_i) 51 | sampled_indices = tset_indices[tset_ys == i][sampled_indices] 52 | 53 | generated_indices[ys == i] = sampled_indices 54 | return generated_indices 55 | 56 | for myir in [10]: 57 | shift_process_name = "per_class_shift" 58 | T = 100000 59 | dataset_name = 'imagenet1k' 60 | 61 | if dataset_name == 'imagenet1k': 62 | num_classes = 1000 63 | elif dataset_name == 'cifar10': 64 | num_classes = 10 65 | 66 | 67 | 68 | if shift_process_name == "per_class_shift" and dataset_name == "imagenet1k": 69 | imbalance_ratio = myir 70 | shuffle_class_order = "yes" 71 | minor_class_prob = 1 / (imbalance_ratio + num_classes - 1) 72 | major_class_prob = minor_class_prob * imbalance_ratio 73 | q_for_all_classes = np.ones([num_classes, num_classes]) * minor_class_prob 74 | print(q_for_all_classes.shape) 75 | for i in range(num_classes): 76 | q_for_all_classes[i, i] = major_class_prob 77 | if shuffle_class_order == "yes": 78 | indices = list(range(num_classes)) 79 | random.shuffle(indices) 80 | q_for_all_classes = q_for_all_classes[indices,:] 81 | def shift_process(T): 82 | num_for_repeat_each_q = T // num_classes 83 | assert num_for_repeat_each_q > 0, "T should be greater than the number of classes" 84 | return np.concatenate([np.expand_dims(q_for_all_classes[i,:], axis=0) for i in range(num_classes) for _ in range(num_for_repeat_each_q)], axis=0) 85 | else: 86 | raise NotImplementedError("unsupported shift process / dataset combination") 87 | 88 | q_all = shift_process(T) 89 | 90 | print(q_all.shape) 91 | 92 | simulated_indices = generate_sample_indices_and_ys(q_all, dataset_name=dataset_name) 93 | 94 | print(simulated_indices.shape) 95 | print(simulated_indices[:100]) 96 | 97 | print(list(simulated_indices[:10])) 98 | 99 | np.save('seed{}_total_{}_ir_{}_class_order_shuffle_{}'.format(seed, T, imbalance_ratio, shuffle_class_order), simulated_indices) 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /utils/cli_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F
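# Illustrative usage of the meters defined below (a sketch with hypothetical values, not from this repo):
#   losses = AverageMeter('Loss', ':.4e')
#   top1 = AverageMeter('Acc@1', ':6.2f')
#   progress = ProgressMeter(100, [losses, top1], prefix='Test: ')
#   losses.update(0.75, n=64)   # running sum 48.0 over 64 samples -> avg 0.75
#   top1.update(81.25, n=64)
#   progress.display(10)        # roughly prints: Test: [ 10/100]  Loss 7.5000e-01 (7.5000e-01)  Acc@1  81.25 ( 81.25)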
7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | def __init__(self, name, fmt=':f'): 11 | self.name = name 12 | self.fmt = fmt 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | def __str__(self): 28 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 29 | return fmtstr.format(**self.__dict__) 30 | 31 | 32 | class ProgressMeter(object): 33 | def __init__(self, num_batches, meters, prefix=""): 34 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 35 | self.meters = meters 36 | self.prefix = prefix 37 | 38 | def display(self, batch): 39 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 40 | entries += [str(meter) for meter in self.meters] 41 | print('\t'.join(entries)) 42 | 43 | def _get_batch_fmtstr(self, num_batches): 44 | num_digits = len(str(num_batches // 1)) 45 | fmt = '{:' + str(num_digits) + 'd}' 46 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 47 | 48 | 49 | def save_checkpoint(state, is_best, save_dir=None): 50 | checkpoint_path = os.path.join(save_dir, 'checkpoint.pth.tar') 51 | torch.save(state, checkpoint_path) 52 | if is_best: 53 | best_checkpoint_path = os.path.join(save_dir, 'model_best.pth.tar') 54 | shutil.copyfile(checkpoint_path, best_checkpoint_path) 55 | 56 | 57 | def adjust_learning_rate(optimizer, epoch, args): 58 | """Sets the learning rate to the initial LR decayed by 10 every 10 epochs""" 59 | lr = args.lr * (0.1 ** (epoch // 5)) 60 | for param_group in optimizer.param_groups: 61 | param_group['lr'] = lr 62 | 63 | 64 | def accuracy(output, target, topk=(1,)): 65 | """Computes the accuracy over the k top predictions for the specified values of k""" 66 | with torch.no_grad(): 67 | maxk = max(topk) 68 | batch_size = target.size(0) 69 | 70 | _, pred = output.topk(maxk, 1, True, True) 71 | pred = pred.t() 72 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 73 | 74 | res = [] 75 | for k in topk: 76 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 77 | res.append(correct_k.mul_(100.0 / batch_size)) 78 | return res 79 | 80 | 81 | class entropy_loss(nn.Module): 82 | def __init__(self): 83 | super(entropy_loss, self).__init__() 84 | 85 | def forward(self, x): 86 | 87 | softmax_x = x 88 | return -1 * torch.mean(torch.sum(softmax_x * torch.log(softmax_x), dim=1)) 89 | 90 | 91 | class LabelSmoothingCrossEntropy(nn.Module): 92 | """ 93 | NLL loss with label smoothing. 94 | """ 95 | def __init__(self, smoothing=0.1): 96 | """ 97 | Constructor for the LabelSmoothing module. 98 | :param smoothing: label smoothing factor 99 | """ 100 | super(LabelSmoothingCrossEntropy, self).__init__() 101 | assert smoothing < 1.0 102 | self.smoothing = smoothing 103 | self.confidence = 1. 
- smoothing 104 | 105 | def forward(self, x, target): 106 | logprobs = F.log_softmax(x, dim=-1) 107 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 108 | nll_loss = nll_loss.squeeze(1) 109 | smooth_loss = -logprobs.mean(dim=-1) 110 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 111 | return loss.mean() -------------------------------------------------------------------------------- /utils/third_party.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | from functools import partial 9 | 10 | from PIL import ImageOps, Image 11 | from torchvision import transforms 12 | 13 | mean = [0.485, 0.456, 0.406] 14 | std = [0.229, 0.224, 0.225] 15 | preprocess = transforms.Compose([ 16 | transforms.ToTensor(), 17 | transforms.Normalize(mean, std) 18 | ]) 19 | preaugment = transforms.Compose([ 20 | transforms.RandomResizedCrop(224), 21 | transforms.RandomHorizontalFlip(), 22 | ]) 23 | def _augmix_aug(x_orig): 24 | x_orig = preaugment(x_orig) 25 | x_processed = preprocess(x_orig) 26 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0])) 27 | m = np.float32(np.random.beta(1.0, 1.0)) 28 | 29 | mix = torch.zeros_like(x_processed) 30 | for i in range(3): 31 | x_aug = x_orig.copy() 32 | for _ in range(np.random.randint(1, 4)): 33 | x_aug = np.random.choice(augmentations)(x_aug) 34 | mix += w[i] * preprocess(x_aug) 35 | mix = m * x_processed + (1 - m) * mix 36 | return mix 37 | 38 | aug = _augmix_aug 39 | 40 | 41 | def autocontrast(pil_img, level=None): 42 | return ImageOps.autocontrast(pil_img) 43 | 44 | def equalize(pil_img, level=None): 45 | return ImageOps.equalize(pil_img) 46 | 47 | def rotate(pil_img, level): 48 | degrees = int_parameter(rand_lvl(level), 30) 49 | if np.random.uniform() > 0.5: 50 | degrees = -degrees 51 | return pil_img.rotate(degrees, resample=Image.BILINEAR, fillcolor=128) 52 | 53 | def solarize(pil_img, level): 54 | level = int_parameter(rand_lvl(level), 256) 55 | return ImageOps.solarize(pil_img, 256 - level) 56 | 57 | def shear_x(pil_img, level): 58 | level = float_parameter(rand_lvl(level), 0.3) 59 | if np.random.uniform() > 0.5: 60 | level = -level 61 | return pil_img.transform((224, 224), Image.AFFINE, (1, level, 0, 0, 1, 0), resample=Image.BILINEAR, fillcolor=128) 62 | 63 | def shear_y(pil_img, level): 64 | level = float_parameter(rand_lvl(level), 0.3) 65 | if np.random.uniform() > 0.5: 66 | level = -level 67 | return pil_img.transform((224, 224), Image.AFFINE, (1, 0, 0, level, 1, 0), resample=Image.BILINEAR, fillcolor=128) 68 | 69 | def translate_x(pil_img, level): 70 | level = int_parameter(rand_lvl(level), 224 / 3) 71 | if np.random.random() > 0.5: 72 | level = -level 73 | return pil_img.transform((224, 224), Image.AFFINE, (1, 0, level, 0, 1, 0), resample=Image.BILINEAR, fillcolor=128) 74 | 75 | def translate_y(pil_img, level): 76 | level = int_parameter(rand_lvl(level), 224 / 3) 77 | if np.random.random() > 0.5: 78 | level = -level 79 | return pil_img.transform((224, 224), Image.AFFINE, (1, 0, 0, 0, 1, level), resample=Image.BILINEAR, fillcolor=128) 80 | 81 | def posterize(pil_img, level): 82 | level = int_parameter(rand_lvl(level), 4) 83 | return ImageOps.posterize(pil_img, 4 - level) 84 | 85 | 86 | def int_parameter(level, maxval): 87 | """Helper function to scale `val` between 0 and maxval . 88 | Args: 89 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 
90 | maxval: Maximum value that the operation can have. This will be scaled 91 | to level/PARAMETER_MAX. 92 | Returns: 93 | An int that results from scaling `maxval` according to `level`. 94 | """ 95 | return int(level * maxval / 10) 96 | 97 | def float_parameter(level, maxval): 98 | """Helper function to scale `val` between 0 and maxval . 99 | Args: 100 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 101 | maxval: Maximum value that the operation can have. This will be scaled 102 | to level/PARAMETER_MAX. 103 | Returns: 104 | A float that results from scaling `maxval` according to `level`. 105 | """ 106 | return float(level) * maxval / 10. 107 | 108 | def rand_lvl(n): 109 | return np.random.uniform(low=0.1, high=n) 110 | 111 | 112 | augmentations = [ 113 | autocontrast, 114 | equalize, 115 | lambda x: rotate(x, 1), 116 | lambda x: solarize(x, 1), 117 | lambda x: shear_x(x, 1), 118 | lambda x: shear_y(x, 1), 119 | lambda x: translate_x(x, 1), 120 | lambda x: translate_y(x, 1), 121 | lambda x: posterize(x, 1), 122 | ] 123 | 124 | -------------------------------------------------------------------------------- /tent.py: -------------------------------------------------------------------------------- 1 | # https://github.com/mr-eggplant/SAR/blob/main/tent.py 2 | from copy import deepcopy 3 | import torch 4 | import torch.nn as nn 5 | import torch.jit 6 | 7 | class Tent(nn.Module): 8 | """Tent adapts a model by entropy minimization during testing. 9 | Once tented, a model adapts itself by updating on every forward. 10 | """ 11 | def __init__(self, model, optimizer, steps=1, episodic=False): 12 | super().__init__() 13 | self.model = model 14 | self.optimizer = optimizer 15 | self.steps = steps 16 | assert steps > 0, "tent requires >= 1 step(s) to forward and update" 17 | self.episodic = episodic 18 | self.model_state, self.optimizer_state = \ 19 | copy_model_and_optimizer(self.model, self.optimizer) 20 | 21 | def forward(self, x): 22 | if self.episodic: 23 | self.reset() 24 | for _ in range(self.steps): 25 | outputs = forward_and_adapt(x, self.model, self.optimizer) 26 | return outputs 27 | 28 | def reset(self): 29 | if self.model_state is None or self.optimizer_state is None: 30 | raise Exception("cannot reset without saved model/optimizer state") 31 | load_model_and_optimizer(self.model, self.optimizer, 32 | self.model_state, self.optimizer_state) 33 | 34 | @torch.jit.script 35 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 36 | """Entropy of softmax distribution from logits.""" 37 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 38 | 39 | @torch.enable_grad() 40 | def forward_and_adapt(x, model, optimizer): 41 | """Forward and adapt model on batch of data. 42 | Measure entropy of the model prediction, take gradients, and update params. 43 | """ 44 | outputs = model(x) 45 | loss = softmax_entropy(outputs).mean(0) 46 | loss.backward() 47 | optimizer.step() 48 | optimizer.zero_grad() 49 | return outputs 50 | 51 | def collect_params(model): 52 | """Collect the affine scale + shift parameters from batch norms. 53 | Walk the model's modules and collect all batch normalization parameters. 54 | Return the parameters and their names. 
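For ViT backbones this yields the LayerNorm affine parameters; for ResNet backbones, the BatchNorm affines (GroupNorm is matched as well).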
55 | """ 56 | params = [] 57 | names = [] 58 | for nm, m in model.named_modules(): 59 | if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)): 60 | for np, p in m.named_parameters(): 61 | if np in ['weight', 'bias']: 62 | params.append(p) 63 | names.append(f"{nm}.{np}") 64 | return params, names 65 | 66 | def copy_model_and_optimizer(model, optimizer): 67 | """Copy the model and optimizer states for resetting after adaptation.""" 68 | model_state = deepcopy(model.state_dict()) 69 | optimizer_state = deepcopy(optimizer.state_dict()) 70 | return model_state, optimizer_state 71 | 72 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): 73 | """Restore the model and optimizer states from copies.""" 74 | model.load_state_dict(model_state, strict=True) 75 | optimizer.load_state_dict(optimizer_state) 76 | 77 | def configure_model(model): 78 | """Configure model for use with tent.""" 79 | model.train() 80 | model.requires_grad_(False) 81 | for m in model.modules(): 82 | if isinstance(m, nn.BatchNorm2d): 83 | m.requires_grad_(True) 84 | m.track_running_stats = False 85 | m.running_mean = None 86 | m.running_var = None 87 | if isinstance(m, (nn.GroupNorm, nn.LayerNorm)): 88 | m.requires_grad_(True) 89 | return model 90 | 91 | def check_model(model): 92 | """Check model for compatibility with tent.""" 93 | is_training = model.training 94 | assert is_training, "tent needs train mode: call model.train()" 95 | param_grads = [p.requires_grad for p in model.parameters()] 96 | has_any_params = any(param_grads) 97 | has_all_params = all(param_grads) 98 | assert has_any_params, "tent needs params to update: " \ 99 | "check which require grad" 100 | assert not has_all_params, "tent should not update all params: " \ 101 | "check which require grad" 102 | has_bn = any([isinstance(m, nn.BatchNorm2d) for m in model.modules()]) 103 | assert has_bn, "tent needs normalization for its optimization" -------------------------------------------------------------------------------- /tent_come.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright to COME Authors, ICLR 2025 3 | built upon the Tent code. 4 | # https://github.com/mr-eggplant/SAR/blob/main/tent.py 5 | """ 6 | from copy import deepcopy 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.jit 11 | import torch.nn.functional as F 12 | 13 | class Tent_COME(nn.Module): 14 | """Tent_COME adapts a model by entropy minimization during testing.
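Unlike Tent, the objective is the entropy of a Dirichlet opinion (see dirichlet_entropy below) rather than plain softmax entropy, which keeps confidence conservative on unreliable samples.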
15 | """ 16 | def __init__(self, model, optimizer,args, steps=1, episodic=False): 17 | super().__init__() 18 | self.model = model 19 | self.optimizer = optimizer 20 | self.steps = steps 21 | assert steps > 0, "tent_come requires >= 1 step(s) to forward and update" 22 | self.episodic = episodic 23 | self.args = args 24 | self.model_state, self.optimizer_state = \ 25 | copy_model_and_optimizer(self.model, self.optimizer) 26 | 27 | def forward(self, x): 28 | if self.episodic: 29 | self.reset() 30 | for _ in range(self.steps): 31 | outputs = forward_and_adapt(x, self.model, self.optimizer,self.args) 32 | return outputs 33 | 34 | def reset(self): 35 | if self.model_state is None or self.optimizer_state is None: 36 | raise Exception("cannot reset without saved model/optimizer state") 37 | load_model_and_optimizer(self.model, self.optimizer, 38 | self.model_state, self.optimizer_state) 39 | 40 | @torch.jit.script 41 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 42 | """Entropy of softmax distribution from logits.""" 43 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 44 | 45 | @torch.jit.script 46 | def dirichlet_entropy(x: torch.Tensor):#key component of COME 47 | x = x / torch.norm(x, p=2, dim=-1, keepdim=True) * torch.norm(x, p=2, dim=-1, keepdim=True).detach() 48 | brief = torch.exp(x)/(torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000) 49 | uncertainty = 1000 / (torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000) 50 | probability = torch.cat([brief, uncertainty], dim=1) + 1e-7 51 | entropy = -(probability * torch.log(probability)).sum(1) 52 | return entropy 53 | 54 | @torch.enable_grad() 55 | def forward_and_adapt(x, model, optimizer,args): 56 | """Forward and adapt model on batch of data. 57 | Measure entropy of the model prediction, take gradients, and update params. 58 | """ 59 | outputs = model(x) 60 | # COME: replace softmax_entropy with dirichlet_entropy 61 | entropy = dirichlet_entropy(outputs) 62 | loss = entropy 63 | loss = loss.mean(0) 64 | loss.backward() 65 | optimizer.step() 66 | optimizer.zero_grad() 67 | return outputs 68 | 69 | def collect_params(model): 70 | """Collect the affine scale + shift parameters from batch norms. 71 | Walk the model's modules and collect all batch normalization parameters. 72 | Return the parameters and their names. 
73 | """ 74 | params = [] 75 | names = [] 76 | for nm, m in model.named_modules(): 77 | if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)): 78 | for np, p in m.named_parameters(): 79 | if np in ['weight', 'bias']: 80 | params.append(p) 81 | names.append(f"{nm}.{np}") 82 | return params, names 83 | 84 | def copy_model_and_optimizer(model, optimizer): 85 | """Copy the model and optimizer states for resetting after adaptation.""" 86 | model_state = deepcopy(model.state_dict()) 87 | optimizer_state = deepcopy(optimizer.state_dict()) 88 | return model_state, optimizer_state 89 | 90 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): 91 | """Restore the model and optimizer states from copies.""" 92 | model.load_state_dict(model_state, strict=True) 93 | optimizer.load_state_dict(optimizer_state) 94 | 95 | def configure_model(model): 96 | """Configure model for use with tent_come.""" 97 | model.train() 98 | model.requires_grad_(False) 99 | for m in model.modules(): 100 | if isinstance(m, nn.BatchNorm2d): 101 | m.requires_grad_(True) 102 | m.track_running_stats = False 103 | m.running_mean = None 104 | m.running_var = None 105 | if isinstance(m, (nn.GroupNorm, nn.LayerNorm)): 106 | m.requires_grad_(True) 107 | return model 108 | 109 | def check_model(model): 110 | """Check model for compatibility with tent_come.""" 111 | is_training = model.training 112 | assert is_training, "tent_come needs train mode: call model.train()" 113 | param_grads = [p.requires_grad for p in model.parameters()] 114 | has_any_params = any(param_grads) 115 | has_all_params = all(param_grads) 116 | assert has_any_params, "tent_come needs params to update: " \ 117 | "check which require grad" 118 | assert not has_all_params, "tent_come should not update all params: " \ 119 | "check which require grad" 120 | has_bn = any([isinstance(m, nn.BatchNorm2d) for m in model.modules()]) 121 | assert has_bn, "tent_come needs normalization for its optimization" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # COME: Test-time Adaption by Conservatively Minimizing Entropy (ICLR'25) 2 | 3 | ![visitor count](https://komarev.com/ghpvc/?username=BlueWhaleLab&repo=COME) 4 | 5 | This is the official implementation of [COME: Test-time Adaption by Conservatively Minimizing Entropy](https://arxiv.org/abs/2410.10894) at ICLR 2025. We propose Conservatively Minimizing Entropy (COME) as a simple drop-in refinement of Entropy Minimization for test-time adaptation. 6 | 7 | ## Installation Requirements 8 | 9 | To get started with this repository, follow these installation steps: 10 | 11 | ``` 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## Data preparation 16 | 17 | We follow [Robustness bench](https://github.com/hendrycks/robustness) and [OpenOOD](https://github.com/Jingkang50/OpenOOD) to prepare the datasets. We provide the links to download each dataset: 18 | 19 | - ImageNet-C: download it from [here 🔗](https://zenodo.org/record/2235448#.YpCSLxNBxAc).
20 | 21 | The following datasets are only used in the open-world TTA setting: 22 | 23 | - iNaturalist: download it from [this link](https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz) 24 | - NINCO, SSB_Hard, Texture, Open-ImageNet: download them from [this link](https://drive.google.com/drive/folders/1IFb4pPWTHsvWV6ezzbmGkIR64_VnOdSh?usp=drive_link) 25 | 26 | ## Usage Example 27 | 28 | COME can be implemented by simply replacing the loss function of previous TTA algorithms (i.e., Tent, EATA, and SAR) from **softmax entropy** to the **entropy of opinion**. 29 | 30 | ```python 31 | def entropy_of_opinion(x: torch.Tensor): # key component of COME 32 | x = x / torch.norm(x, p=2, dim=-1, keepdim=True) * torch.norm(x, p=2, dim=-1, keepdim=True).detach() 33 | brief = torch.exp(x)/(torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000)  # 1000 = K, the number of classes 34 | uncertainty = 1000 / (torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000) 35 | probability = torch.cat([brief, uncertainty], dim=1) + 1e-7 36 | entropy = -(probability * torch.log(probability)).sum(1) 37 | return entropy 38 | 39 | def forward_and_adapt(x, model, optimizer, args): 40 | """Forward and adapt model on batch of data. 41 | Measure entropy of the model prediction, take gradients, and update params. 42 | """ 43 | outputs = model(x) 44 | # COME: replace softmax_entropy with entropy_of_opinion 45 | loss = entropy_of_opinion(outputs) 46 | loss = loss.mean(0) 47 | loss.backward() 48 | optimizer.step() 49 | optimizer.zero_grad() 50 | return outputs 51 | ``` 52 | 53 | ## Reproduce the results 54 | 55 | ### Baselines 56 | 57 | 1. **no_adapt (source)**: The original model without any adaptation. 58 | 2. **Tent** & **EATA** & **SAR**: Previous TTA methods using Entropy Minimization. 59 | 3. **Tent_COME** & **EATA_COME** & **SAR_COME**: The enhanced COME versions. 60 | 61 | ### Running the Code 62 | 63 | To run the experiments, execute the script [start.sh](./start.sh). 64 | 65 | ### Results 66 | 67 | Classification Accuracy Comparison on ImageNet-C (Level 5). Substantial (≥ 0.5) improvements compared to the baseline are marked with +. 68 | 69 | | Methods | COME | Gauss. | Shot | Impul. | Defoc | Glass | Motion | Zoom | Snow | Frost | Fog | Brit. | Contr. | Elast. | Pixel | JPEG | Avg.
Acc↑ | 70 | | -------- | ------- | ------ | ---- | ------ | ----- | ----- | ------ | ---- | ----- | ----- | ---- | ----- | ------ | ------ | ----- | ---- | --------- | 71 | | no_adapt | ✗ | 35.1 | 32.2 | 35.9 | 31.4 | 25.3 | 39.4 | 31.6 | 24.5 | 30.1 | 54.7 | 64.5 | 49.0 | 34.2 | 53.2 | 56.5 | 39.8 | 72 | | **Tent** | ✗ | 52.4 | 51.8 | 53.3 | 53.0 | 47.6 | 56.8 | 47.6 | 10.6 | 28.0 | 67.5 | 74.2 | 67.4 | 50.2 | 66.7 | 64.6 | 52.8 | 73 | | | ✓ | 53.8 | 53.7 | 55.3 | 55.7 | 51.7 | 59.7 | 52.7 | 59.0 | 61.7 | 71.3 | 78.2 | 68.7 | 57.7 | 70.5 | 68.2 | 61.2 | 74 | | | Improve | +1.4 | +1.9 | +1.9 | +2.7 | +4.1 | +2.9 | +5.0 | +48.4 | +33.6 | +3.9 | +4.0 | +1.3 | +7.5 | +3.8 | +3.6 | +8.4 | 75 | | **EATA** | ✗ | 55.9 | 56.5 | 57.1 | 54.1 | 53.3 | 61.9 | 58.7 | 62.1 | 60.2 | 71.3 | 75.4 | 68.5 | 62.8 | 69.3 | 66.6 | 62.2 | 76 | | | ✓ | 56.2 | 56.6 | 57.2 | 58.1 | 57.6 | 62.5 | 59.5 | 65.5 | 63.9 | 72.5 | 78.1 | 69.7 | 66.5 | 72.4 | 70.7 | 64.5 | 77 | | | Improve | +0.3 | +0.2 | +0.1 | +4.1 | +4.3 | +0.6 | +0.7 | +3.5 | +3.7 | +1.2 | +2.7 | +1.2 | +3.7 | +3.1 | +4.0 | +2.2 | 78 | | **SAR** | ✗ | 52.7 | 52.1 | 53.6 | 53.5 | 48.9 | 56.7 | 48.8 | 22.5 | 51.9 | 67.5 | 73.4 | 66.8 | 52.7 | 66.3 | 64.5 | 55.5 | 79 | | | ✓ | 56.2 | 56.5 | 57.5 | 58.3 | 56.7 | 62.9 | 58.2 | 65.3 | 64.8 | 72.6 | 78.5 | 69.3 | 64.4 | 71.9 | 69.5 | 64.2 | 80 | | | Improve | +3.5 | +4.4 | +3.8 | +4.8 | +7.7 | +6.2 | +9.5 | +42.9 | +12.8 | +5.0 | +5.1 | +2.5 | +11.6 | +5.6 | +5.0 | +8.7 | 81 | 82 | 83 | 84 | ## Citation 85 | 86 | If you find COME helpful in your research, please consider citing our paper: 87 | 88 | ``` 89 | @inproceedings{zhang2025come, 90 | title={COME: Test-time Adaption by Conservatively Minimizing Entropy}, 91 | author={Qingyang Zhang and Yatao Bian and Xinke Kong and Peilin Zhao and Changqing Zhang}, 92 | booktitle = {International Conference on Learning Representations}, 93 | year = {2025} 94 | } 95 | ``` 96 | 97 | 98 | 99 | 100 | 101 | ## Acknowledgment 102 | 103 | This repo is developed upon [SAR 🔗](https://github.com/mr-eggplant/SAR). 104 | 105 | 106 | 107 | For any additional questions, feel free to email [qingyangzhang@tju.edu.cn](mailto:qingyangzhang@tju.edu.cn). -------------------------------------------------------------------------------- /sar.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright to SAR Authors, ICLR 2023 Oral (notable-top-5%) 3 | built upon the Tent code. 4 | # https://github.com/mr-eggplant/SAR/blob/main/sar.py 5 | """ 6 | from copy import deepcopy 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.jit 11 | import math 12 | import numpy as np 13 | 14 | def update_ema(ema, new_data): 15 | if ema is None: 16 | return new_data 17 | else: 18 | with torch.no_grad(): 19 | return 0.9 * ema + (1 - 0.9) * new_data 20 | 21 | class SAR(nn.Module): 22 | """SAR online adapts a model by Sharpness-Aware and Reliable entropy minimization during testing. 23 | Once SARed, a model adapts itself by updating on every forward.
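Reliable (low-entropy) samples are selected by a margin, parameters are updated with SAM's two-step procedure (first_step/second_step), and a moving average of the entropy triggers a model reset if adaptation collapses.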
24 | """ 25 | def __init__(self, model, optimizer, steps=1, episodic=False, margin_e0=0.4*math.log(1000), reset_constant_em=0.2): 26 | super().__init__() 27 | self.model = model 28 | self.optimizer = optimizer 29 | self.steps = steps 30 | assert steps > 0, "SAR requires >= 1 step(s) to forward and update" 31 | self.episodic = episodic 32 | self.margin_e0 = margin_e0 33 | self.reset_constant_em = reset_constant_em 34 | self.ema = None 35 | self.model_state, self.optimizer_state = \ 36 | copy_model_and_optimizer(self.model, self.optimizer) 37 | 38 | def forward(self, x): 39 | if self.episodic: 40 | self.reset() 41 | for _ in range(self.steps): 42 | outputs, ema, reset_flag = forward_and_adapt_sar(x, self.model, self.optimizer, self.margin_e0, self.reset_constant_em, self.ema) 43 | if reset_flag: 44 | self.reset() 45 | self.ema = ema 46 | return outputs 47 | 48 | def reset(self): 49 | if self.model_state is None or self.optimizer_state is None: 50 | raise Exception("cannot reset without saved model/optimizer state") 51 | load_model_and_optimizer(self.model, self.optimizer, 52 | self.model_state, self.optimizer_state) 53 | self.ema = None 54 | 55 | @torch.jit.script 56 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 57 | """Entropy of softmax distribution from logits.""" 58 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 59 | 60 | @torch.enable_grad() 61 | def forward_and_adapt_sar(x, model, optimizer, margin, reset_constant, ema): 62 | """Forward and adapt model input data. 63 | Measure entropy of the model prediction, take gradients, and update params. 64 | """ 65 | optimizer.zero_grad() 66 | outputs = model(x) 67 | entropys = softmax_entropy(outputs) 68 | filter_ids_1 = torch.where(entropys < margin) 69 | entropys = entropys[filter_ids_1] 70 | loss = entropys.mean(0) 71 | loss.backward() 72 | optimizer.first_step(zero_grad=True) 73 | entropys2 = softmax_entropy(model(x)) 74 | entropys2 = entropys2[filter_ids_1] 75 | loss_second_value = entropys2.clone().detach().mean(0) 76 | filter_ids_2 = torch.where(entropys2 < margin) 77 | loss_second = entropys2[filter_ids_2].mean(0) 78 | if not np.isnan(loss_second.item()): 79 | ema = update_ema(ema, loss_second.item()) 80 | loss_second.backward() 81 | optimizer.second_step(zero_grad=True) 82 | reset_flag = False 83 | if ema is not None: 84 | if ema < 0.2: 85 | print("ema < 0.2, now reset the model") 86 | reset_flag = True 87 | return outputs, ema, reset_flag 88 | 89 | def collect_params(model): 90 | """Collect the affine scale + shift parameters from batch norms. 91 | Walk the model's modules and collect all batch normalization parameters. 92 | Return the parameters and their names. 
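Note: unlike the variant in sar_come.py, this version does not skip the top blocks (layer4 for ResNets; blocks 9-11 and the final norm for ViT-Base).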
93 | """ 94 | params = [] 95 | names = [] 96 | for nm, m in model.named_modules(): 97 | if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)): 98 | for np, p in m.named_parameters(): 99 | if np in ['weight', 'bias']: 100 | params.append(p) 101 | names.append(f"{nm}.{np}") 102 | return params, names 103 | 104 | def copy_model_and_optimizer(model, optimizer): 105 | """Copy the model and optimizer states for resetting after adaptation.""" 106 | model_state = deepcopy(model.state_dict()) 107 | optimizer_state = deepcopy(optimizer.state_dict()) 108 | return model_state, optimizer_state 109 | 110 | 111 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): 112 | """Restore the model and optimizer states from copies.""" 113 | model.load_state_dict(model_state, strict=True) 114 | optimizer.load_state_dict(optimizer_state) 115 | 116 | def configure_model(model): 117 | """Configure model for use with SAR.""" 118 | model.train() 119 | model.requires_grad_(False) 120 | for m in model.modules(): 121 | if isinstance(m, nn.BatchNorm2d): 122 | m.requires_grad_(True) 123 | m.track_running_stats = False 124 | m.running_mean = None 125 | m.running_var = None 126 | if isinstance(m, (nn.LayerNorm, nn.GroupNorm)): 127 | m.requires_grad_(True) 128 | return model 129 | 130 | 131 | def check_model(model): 132 | """Check model for compatibility with SAR.""" 133 | is_training = model.training 134 | assert is_training, "SAR needs train mode: call model.train()" 135 | param_grads = [p.requires_grad for p in model.parameters()] 136 | has_any_params = any(param_grads) 137 | has_all_params = all(param_grads) 138 | assert has_any_params, "SAR needs params to update: " \ 139 | "check which require grad" 140 | assert not has_all_params, "SAR should not update all params: " \ 141 | "check which require grad" 142 | has_norm = any([isinstance(m, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)) for m in model.modules()]) 143 | assert has_norm, "SAR needs normalization layer parameters for its optimization" 144 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/README-checkpoint.md: -------------------------------------------------------------------------------- 1 | # COME: Test-time Adaption by Conservatively Minimizing Entropy 2 | 3 | This is the project repository for [COME: Test-time Adaption by Conservatively Minimizing Entropy](https://arxiv.org/abs/2410.10894) by Qingyang Zhang, Yatao Bian, Xinke Kong, Peilin Zhao and Changqing Zhang (ICLR 2025). 4 | 5 | Machine learning models must continuously adjust themselves to novel data distributions in the open world. As the predominant principle, entropy minimization (EM) has been proven to be a simple yet effective cornerstone of existing test-time adaptation (TTA) methods. Unfortunately, its fatal limitation (i.e., overconfidence) tends to result in model collapse. To address this issue, we propose to conservatively minimize the entropy (COME), a simple drop-in replacement of traditional EM that elegantly addresses this limitation. In essence, COME explicitly models the uncertainty by characterizing a Dirichlet prior distribution over model predictions during TTA. By doing so, COME naturally regularizes the model to favor conservative confidence on unreliable samples. Theoretically, we provide a preliminary analysis revealing COME's ability to enhance optimization stability by introducing a data-adaptive lower bound on the entropy.
Empirically, our method achieves state-of-the-art performance on commonly used benchmarks, showing significant improvements in classification accuracy and uncertainty estimation under various settings, including standard, life-long and open-world TTA. 6 | 7 | We provide **[example code](#1)** in PyTorch to illustrate the **COME** method and the fully test-time adaptation setting. 8 | 9 | We provide implementations of the classic EM algorithms (Tent, EATA, and SAR) along with their enhanced COME versions. Additionally, we offer two scripts for running the fully test-time and open-world test-time adaptation settings. You're welcome to experiment with your own datasets and models as well! 10 | **Installation**: 11 | 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | **Data preparation**: 17 | 18 | This repository contains code for evaluation on [ImageNet-C 🔗](https://arxiv.org/abs/1903.12261) with ViT-Base and ResNet. 19 | 20 | - Step 1: Download the [ImageNet-C 🔗](https://github.com/hendrycks/robustness) dataset from [here 🔗](https://zenodo.org/record/2235448#.YpCSLxNBxAc). 21 | 22 | - Step 2: Put the ImageNet-C path at "--data_corruption" in main.py or "data_corruption" in start.sh. 23 | 24 | - Step 3: Put the output path at "--output" in main.py or "output" in start.sh. 25 | 26 | - Step 4 [optional, for EATA]: Put the ImageNet **test/val set** at "--data" in main.py or "data" in start.sh. 27 | 28 | - Step 5 [optional, for open-world setting]: Put NINCO, iNaturalist, SSB_Hard, Texture, Openimage_O at "--ood_root" in main.py or "ood_root" in start-open.sh. 29 | 30 | 31 | **COME** depends on 32 | 33 | - Python 3 34 | - [PyTorch](https://pytorch.org/) >= 1.0 35 | 36 | 37 | **Usage**: 38 | 39 | ``` 40 | import tent_come 41 | 42 | net = backbone_net() 43 | net = tent_come.configure_model(net) 44 | params, param_names = tent_come.collect_params(net) 45 | optimizer = torch.optim.SGD(params, args.lr, momentum=args.momentum) 46 | adapt_model = tent_come.Tent_COME(net, optimizer, args) 47 | 48 | outputs = adapt_model(inputs) # now it infers and adapts! 49 | ``` 50 | 51 | ## Example: TTA setting 52 | 53 | This example demonstrates how to adapt an ImageNet1K classifier to handle image corruptions on the ImageNet_C dataset. 54 | 55 | ### Methods Compared 56 | 57 | 1. **no_adapt (source)**: The original model without any adaptation. 58 | 2. **Tent** & **EATA** & **SAR**: Classic EM algorithms that adapt the model at test time using entropy minimization. 59 | 3. **Tent_COME** & **EATA_COME** & **SAR_COME**: The same algorithms enhanced with COME, adapting the model at test time by conservatively minimizing entropy. 60 | 61 | ### Dataset 62 | 63 | The dataset used is [ImageNet_C](https://github.com/hendrycks/robustness/), containing 15 corruption types, each with 5 levels of severity. 64 | 65 | ### Running the Code 66 | 67 | To run the experiments, execute the script [start.sh](./start.sh). 68 | 69 | ### Result: Classification Accuracy Comparison on ImageNet-C (Level 5) 70 | 71 | Substantial (≥ 0.5) improvements compared to the baseline are marked with +. We only report average FPR↓ in the appendix. 72 | 73 | | Methods | COME | Gauss. | Shot | Impul. | Defoc | Glass | Motion | Zoom | Snow | Frost | Fog | Brit. | Contr. | Elast. | Pixel | JPEG | Avg.
Acc↑ | 74 | |----------|------|--------|-------|--------|-------|-------|--------|-------|-------|-------|-------|-------|--------|--------|-------|-------|-----------| 75 | | no_adapt | ✗ | 35.1 | 32.2 | 35.9 | 31.4 | 25.3 | 39.4 | 31.6 | 24.5 | 30.1 | 54.7 | 64.5 | 49.0 | 34.2 | 53.2 | 56.5 | 39.8 | 76 | | **Tent** | ✗ | 52.4 | 51.8 | 53.3 | 53.0 | 47.6 | 56.8 | 47.6 | 10.6 | 28.0 | 67.5 | 74.2 | 67.4 | 50.2 | 66.7 | 64.6 | 52.8 | 77 | | | ✓ | 53.8 | 53.7 | 55.3 | 55.7 | 51.7 | 59.7 | 52.7 | 59.0 | 61.7 | 71.3 | 78.2 | 68.7 | 57.7 | 70.5 | 68.2 | 61.2 | 78 | | |Improve| +1.4 | +1.9 | +1.9 | +2.7 | +4.1 | +2.9 | +5.0 | +48.4 | +33.6 | +3.9 | +4.0 | +1.3 | +7.5 | +3.8 | +3.6 | +8.4 | 79 | | **EATA** | ✗ | 55.9 | 56.5 | 57.1 | 54.1 | 53.3 | 61.9 | 58.7 | 62.1 | 60.2 | 71.3 | 75.4 | 68.5 | 62.8 | 69.3 | 66.6 | 62.2 | 80 | | | ✓ | 56.2 | 56.6 | 57.2 | 58.1 | 57.6 | 62.5 | 59.5 | 65.5 | 63.9 | 72.5 | 78.1 | 69.7 | 66.5 | 72.4 | 70.7 | 64.5 | 81 | | |Improve| +0.3 | +0.2 | +0.1 | +4.1 | +4.3 | +0.6 | +0.7 | +3.5 | +3.7 | +1.2 | +2.7 | +1.2 | +3.7 | +3.1 | +4.0 | +2.2 | 82 | | **SAR** | ✗ | 52.7 | 52.1 | 53.6 | 53.5 | 48.9 | 56.7 | 48.8 | 22.5 | 51.9 | 67.5 | 73.4 | 66.8 | 52.7 | 66.3 | 64.5 | 55.5 | 83 | | | ✓ | 56.2 | 56.5 | 57.5 | 58.3 | 56.7 | 62.9 | 58.2 | 65.3 | 64.8 | 72.6 | 78.5 | 69.3 | 64.4 | 71.9 | 69.5 | 64.2 | 84 | | |Improve| +3.5 | +4.4 | +3.8 | +4.8 | +7.7 | +6.2 | +9.5 | +42.9 | +12.8 | +5.0 | +5.1 | +2.5 | +11.6 | +5.6 | +5.0 | +8.7 | 85 | 86 | ## Citation 87 | If our COME method is helpful in your research, please consider citing our paper: 88 | ``` 89 | @inproceedings{zhang2025come, 90 | title={COME: Test-time Adaption by Conservatively Minimizing Entropy}, 91 | author={Qingyang Zhang and Yatao Bian and Xinke Kong and Peilin Zhao and Changqing Zhang}, 92 | booktitle = {International Conference on Learning Representations}, 93 | year = {2025} 94 | } 95 | ``` 96 | ## Acknowledgment 97 | The code is inspired by the [SAR 🔗](https://github.com/mr-eggplant/SAR). 98 | 99 | 100 | -------------------------------------------------------------------------------- /sar_come.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright to COME Authors, ICLR 2025 3 | built upon the SAR code. 4 | """ 5 | from copy import deepcopy 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.jit 10 | import math 11 | import numpy as np 12 | 13 | def update_ema(ema, new_data): 14 | if ema is None: 15 | return new_data 16 | else: 17 | with torch.no_grad(): 18 | return 0.9 * ema + (1 - 0.9) * new_data 19 | 20 | class SAR_COME(nn.Module): 21 | """SAR_COME online adapts a model by Sharpness-Aware and Reliable entropy minimization during testing, using COME's Dirichlet-opinion entropy as the objective. 22 | Once adapted, the model updates itself on every forward. 23 | """ 24 | def __init__(self, model, optimizer, steps=1, episodic=False, margin_e0=0.4*math.log(1000), reset_constant_em=0.2): 25 | super().__init__() 26 | self.model = model 27 | self.optimizer = optimizer 28 | self.steps = steps 29 | assert steps > 0, "SAR requires >= 1 step(s) to forward and update" 30 | self.episodic = episodic 31 | self.margin_e0 = margin_e0 # margin E_0 for reliable entropy minimization, Eqn.
(2) 32 | self.reset_constant_em = reset_constant_em # threshold e_m for model recovery scheme 33 | self.ema = None # to record the moving average of model output entropy, as model recovery criteria 34 | 35 | # note: if the model is never reset, like for continual adaptation, 36 | # then skipping the state copy would save memory 37 | self.model_state, self.optimizer_state = \ 38 | copy_model_and_optimizer(self.model, self.optimizer) 39 | 40 | def forward(self, x): 41 | if self.episodic: 42 | self.reset() 43 | for _ in range(self.steps): 44 | outputs, ema, reset_flag = forward_and_adapt_sar_come(x, self.model, self.optimizer, self.margin_e0, self.reset_constant_em, self.ema) 45 | if reset_flag: 46 | self.reset() 47 | self.ema = ema # update moving average value of loss 48 | return outputs 49 | 50 | def reset(self): 51 | if self.model_state is None or self.optimizer_state is None: 52 | raise Exception("cannot reset without saved model/optimizer state") 53 | load_model_and_optimizer(self.model, self.optimizer, 54 | self.model_state, self.optimizer_state) 55 | self.ema = None 56 | 57 | @torch.jit.script 58 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 59 | """Entropy of softmax distribution from logits.""" 60 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 61 | 62 | @torch.jit.script 63 | def dirichlet_entropy(x: torch.Tensor):#key component of COME 64 | x = x / torch.norm(x, p=2, dim=-1, keepdim=True) * torch.norm(x, p=2, dim=-1, keepdim=True).detach() 65 | brief = torch.exp(x)/(torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000) 66 | uncertainty = 1000 / (torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000) 67 | probability = torch.cat([brief, uncertainty], dim=1) + 1e-7 68 | entropy = -(probability * torch.log(probability)).sum(1) 69 | return entropy 70 | 71 | @torch.enable_grad() # ensure grads in possible no grad context for testing 72 | def forward_and_adapt_sar_come(x, model, optimizer, margin, reset_constant, ema): 73 | """Forward and adapt model input data. 74 | Measure entropy of the model prediction, take gradients, and update params. 75 | """ 76 | optimizer.zero_grad() 77 | # forward 78 | outputs = model(x) 79 | # adapt 80 | # filtering reliable samples/gradients for further adaptation; first time forward 81 | # COME: replace softmax_entropy with dirichlet_entropy 82 | entropys = dirichlet_entropy(outputs) 83 | filter_ids_1 = torch.where(entropys < margin) 84 | entropys = entropys[filter_ids_1] 85 | loss = entropys.mean(0) 86 | loss.backward() 87 | optimizer.first_step(zero_grad=True) 88 | entropys2 = dirichlet_entropy(model(x)) 89 | entropys2 = entropys2[filter_ids_1] 90 | loss_second_value = entropys2.clone().detach().mean(0) 91 | filter_ids_2 = torch.where(entropys2 < margin) 92 | loss_second = entropys2[filter_ids_2].mean(0) 93 | if not np.isnan(loss_second.item()): 94 | ema = update_ema(ema, loss_second.item()) 95 | loss_second.backward() 96 | optimizer.second_step(zero_grad=True) 97 | reset_flag = False 98 | if ema is not None: 99 | if ema < 0.2: 100 | print("ema < 0.2, now reset the model") 101 | reset_flag = True 102 | return outputs, ema, reset_flag 103 | 104 | def collect_params(model): 105 | """Collect the affine scale + shift parameters from norm layers. 106 | Walk the model's modules and collect all normalization parameters. 107 | Return the parameters and their names. 108 | Note: other choices of parameterization are possible! 
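Following SAR, the top blocks are skipped below: layer4 for ResNets; blocks 9-11 and the final norm for ViT-Base.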
109 | """ 110 | params = [] 111 | names = [] 112 | for nm, m in model.named_modules(): 113 | # skip top layers for adaptation: layer4 for ResNets and blocks 9-11 for ViT-Base 114 | if 'layer4' in nm: 115 | continue 116 | if 'blocks.9' in nm: 117 | continue 118 | if 'blocks.10' in nm: 119 | continue 120 | if 'blocks.11' in nm: 121 | continue 122 | if 'norm.' in nm: 123 | continue 124 | if nm in ['norm']: 125 | continue 126 | if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)): 127 | for np, p in m.named_parameters(): 128 | if np in ['weight', 'bias']: 129 | params.append(p) 130 | names.append(f"{nm}.{np}") 131 | return params, names 132 | 133 | def copy_model_and_optimizer(model, optimizer): 134 | """Copy the model and optimizer states for resetting after adaptation.""" 135 | model_state = deepcopy(model.state_dict()) 136 | optimizer_state = deepcopy(optimizer.state_dict()) 137 | return model_state, optimizer_state 138 | 139 | 140 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): 141 | """Restore the model and optimizer states from copies.""" 142 | model.load_state_dict(model_state, strict=True) 143 | optimizer.load_state_dict(optimizer_state) 144 | 145 | def configure_model(model): 146 | """Configure model for use with SAR_COME.""" 147 | model.train() 148 | model.requires_grad_(False) 149 | for m in model.modules(): 150 | if isinstance(m, nn.BatchNorm2d): 151 | m.requires_grad_(True) 152 | m.track_running_stats = False 153 | m.running_mean = None 154 | m.running_var = None 155 | if isinstance(m, (nn.LayerNorm, nn.GroupNorm)): 156 | m.requires_grad_(True) 157 | return model 158 | 159 | 160 | def check_model(model): 161 | """Check model for compatibility with SAR_COME.""" 162 | is_training = model.training 163 | assert is_training, "SAR_COME needs train mode: call model.train()" 164 | param_grads = [p.requires_grad for p in model.parameters()] 165 | has_any_params = any(param_grads) 166 | has_all_params = all(param_grads) 167 | assert has_any_params, "SAR_COME needs params to update: " \ 168 | "check which require grad" 169 | assert not has_all_params, "SAR_COME should not update all params: " \ 170 | "check which require grad" 171 | has_norm = any([isinstance(m, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)) for m in model.modules()]) 172 | assert has_norm, "SAR_COME needs normalization layer parameters for its optimization" 173 | -------------------------------------------------------------------------------- /eata.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright to EATA ICML 2022 Authors, 2022.03.20 3 | Based on Tent ICLR 2021 Spotlight. 4 | """ 5 | from argparse import ArgumentDefaultsHelpFormatter 6 | from copy import deepcopy 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.jit 11 | 12 | import math 13 | import torch.nn.functional as F 14 | 15 | 16 | class EATA(nn.Module): 17 | """EATA adapts a model by entropy minimization during testing. 18 | Once EATAed, a model adapts itself by updating on every forward.
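Adaptation only uses samples that are reliable (entropy below e_margin) and non-redundant (output direction differs from the moving average of predictions), and an optional Fisher regularizer (Eqn. 9) protects against forgetting.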
19 | """ 20 | def __init__(self, model, optimizer, fishers=None, fisher_beta=2000.0, steps=1, episodic=False, e_margin=math.log(1000)/2-1, d_margin=0.05): 21 | super().__init__() 22 | self.model = model 23 | self.optimizer = optimizer 24 | self.steps = steps 25 | assert steps > 0, "EATA requires >= 1 step(s) to forward and update" 26 | self.episodic = episodic 27 | 28 | self.num_samples_update_1 = 0 # number of samples after First filtering, exclude unreliable samples 29 | self.num_samples_update_2 = 0 # number of samples after Second filtering, exclude both unreliable and redundant samples 30 | self.e_margin = e_margin # hyper-parameter E_0 (Eqn. 3) 31 | self.d_margin = d_margin # hyper-parameter \epsilon for cosine similarity thresholding (Eqn. 5) 32 | 33 | self.current_model_probs = None # the moving average of probability vector (Eqn. 4) 34 | 35 | self.fishers = fishers # fisher regularizer items for anti-forgetting, need to be calculated before model adaptation (Eqn. 9) 36 | self.fisher_beta = fisher_beta # trade-off \beta for two losses (Eqn. 8) 37 | 38 | # note: if the model is never reset, like for continual adaptation, 39 | # then skipping the state copy would save memory 40 | self.model_state, self.optimizer_state = \ 41 | copy_model_and_optimizer(self.model, self.optimizer) 42 | 43 | def forward(self, x): 44 | if self.episodic: 45 | self.reset() 46 | if self.steps > 0: 47 | for _ in range(self.steps): 48 | outputs, num_counts_2, num_counts_1, updated_probs = forward_and_adapt_eata(x, self.model, self.optimizer, self.fishers, self.e_margin, self.current_model_probs, fisher_beta=self.fisher_beta, num_samples_update=self.num_samples_update_2, d_margin=self.d_margin) 49 | self.num_samples_update_2 += num_counts_2 50 | self.num_samples_update_1 += num_counts_1 51 | self.reset_model_probs(updated_probs) 52 | else: 53 | self.model.eval() 54 | with torch.no_grad(): 55 | outputs = self.model(x) 56 | return outputs 57 | 58 | def reset(self): 59 | if self.model_state is None or self.optimizer_state is None: 60 | raise Exception("cannot reset without saved model/optimizer state") 61 | load_model_and_optimizer(self.model, self.optimizer, 62 | self.model_state, self.optimizer_state) 63 | 64 | def reset_steps(self, new_steps): 65 | self.steps = new_steps 66 | 67 | def reset_model_probs(self, probs): 68 | self.current_model_probs = probs 69 | def test(self, x): 70 | outputs = self.model(x) 71 | return outputs 72 | 73 | @torch.jit.script 74 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 75 | """Entropy of softmax distribution from logits.""" 76 | temperature = 1 77 | x = x / temperature 78 | x = -(x.softmax(1) * x.log_softmax(1)).sum(1) 79 | return x 80 | 81 | 82 | @torch.enable_grad() # ensure grads in possible no grad context for testing 83 | def forward_and_adapt_eata(x, model, optimizer, fishers, e_margin, current_model_probs, fisher_beta=50.0, d_margin=0.05, scale_factor=2, num_samples_update=0): 84 | """Forward and adapt model on batch of data. 85 | Measure entropy of the model prediction, take gradients, and update params. 86 | Return: 87 | 1. model outputs; 88 | 2. the number of reliable and non-redundant samples; 89 | 3. the number of reliable samples; 90 | 4.
the moving average probability vector over all previous samples 91 | """ 92 | # forward 93 | outputs = model(x) 94 | # adapt 95 | entropys = softmax_entropy(outputs) 96 | # filter unreliable samples 97 | filter_ids_1 = torch.where(entropys < e_margin) 98 | ids1 = filter_ids_1 99 | ids2 = torch.where(ids1[0]>-0.1) 100 | entropys = entropys[filter_ids_1] 101 | # filter redundant samples 102 | if current_model_probs is not None: 103 | cosine_similarities = F.cosine_similarity(current_model_probs.unsqueeze(dim=0), outputs[filter_ids_1].softmax(1), dim=1) 104 | filter_ids_2 = torch.where(torch.abs(cosine_similarities) < d_margin) 105 | entropys = entropys[filter_ids_2] 106 | ids2 = filter_ids_2 107 | updated_probs = update_model_probs(current_model_probs, outputs[filter_ids_1][filter_ids_2].softmax(1)) 108 | else: 109 | updated_probs = update_model_probs(current_model_probs, outputs[filter_ids_1].softmax(1)) 110 | coeff = 1 / (torch.exp(entropys.clone().detach() - e_margin)) 111 | # implementation version 1, compute loss, all samples backward (some unselected are masked) 112 | entropys = entropys.mul(coeff) # reweight entropy losses for diff. samples 113 | loss = entropys.mean(0) 114 | """ 115 | # implementation version 2, compute loss, forward all batch, forward and backward selected samples again. 116 | # loss = 0 117 | # if x[ids1][ids2].size(0) != 0: 118 | # loss = softmax_entropy(model(x[ids1][ids2])).mul(coeff).mean(0) # reweight entropy losses for diff. samples 119 | """ 120 | if fishers is not None: 121 | ewc_loss = 0 122 | for name, param in model.named_parameters(): 123 | if name in fishers: 124 | ewc_loss += fisher_beta * (fishers[name][0] * (param - fishers[name][1])**2).sum() 125 | loss += ewc_loss 126 | if x[ids1][ids2].size(0) != 0: 127 | loss.backward() 128 | optimizer.step() 129 | optimizer.zero_grad() 130 | return outputs, entropys.size(0), filter_ids_1[0].size(0), updated_probs 131 | 132 | 133 | def update_model_probs(current_model_probs, new_probs): 134 | if current_model_probs is None: 135 | if new_probs.size(0) == 0: 136 | return None 137 | else: 138 | with torch.no_grad(): 139 | return new_probs.mean(0) 140 | else: 141 | if new_probs.size(0) == 0: 142 | with torch.no_grad(): 143 | return current_model_probs 144 | else: 145 | with torch.no_grad(): 146 | return 0.9 * current_model_probs + (1 - 0.9) * new_probs.mean(0) 147 | 148 | 149 | def collect_params(model): 150 | """Collect the affine scale + shift parameters from batch norms. 151 | Walk the model's modules and collect all batch normalization parameters. 152 | Return the parameters and their names. 153 | Note: other choices of parameterization are possible! 
154 |     """
155 |     params = []
156 |     names = []
157 |     for nm, m in model.named_modules():
158 |         if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
159 |             for np, p in m.named_parameters():
160 |                 if np in ['weight', 'bias']:  # weight is scale, bias is shift
161 |                     params.append(p)
162 |                     names.append(f"{nm}.{np}")
163 |     return params, names
164 | 
165 | 
166 | def copy_model_and_optimizer(model, optimizer):
167 |     """Copy the model and optimizer states for resetting after adaptation."""
168 |     model_state = deepcopy(model.state_dict())
169 |     optimizer_state = deepcopy(optimizer.state_dict())
170 |     return model_state, optimizer_state
171 | 
172 | 
173 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state):
174 |     """Restore the model and optimizer states from copies."""
175 |     model.load_state_dict(model_state, strict=True)
176 |     optimizer.load_state_dict(optimizer_state)
177 | 
178 | 
179 | def configure_model(model):
180 |     """Configure model for use with eata."""
181 |     # train mode, because eata optimizes the model to minimize entropy
182 |     model.train()
183 |     # disable grad, to (re-)enable only what eata updates
184 |     model.requires_grad_(False)
185 |     # configure norm for eata updates: enable grad + force batch statistics
186 |     for m in model.modules():
187 |         if isinstance(m, nn.BatchNorm2d):
188 |             m.requires_grad_(True)
189 |             # force use of batch stats in train and eval modes
190 |             m.track_running_stats = False
191 |             m.running_mean = None
192 |             m.running_var = None
193 |         if isinstance(m, (nn.GroupNorm, nn.LayerNorm)):
194 |             m.requires_grad_(True)
195 |     return model
196 | 
197 | 
198 | def check_model(model):
199 |     """Check model for compatibility with eata."""
200 |     is_training = model.training
201 |     assert is_training, "eata needs train mode: call model.train()"
202 |     param_grads = [p.requires_grad for p in model.parameters()]
203 |     has_bn = any([isinstance(m, nn.BatchNorm2d) for m in model.modules()])
204 |     assert has_bn, "eata needs normalization for its optimization"
--------------------------------------------------------------------------------
/eata_come.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright to COME Authors, ICLR 2025
3 | built upon the EATA code.
4 | # https://github.com/mr-eggplant/EATA
5 | """
6 | from argparse import ArgumentDefaultsHelpFormatter
7 | from copy import deepcopy
8 | 
9 | import torch
10 | import torch.nn as nn
11 | import torch.jit
12 | 
13 | import math
14 | import torch.nn.functional as F
15 | 
16 | 
17 | class EATA_COME(nn.Module):
18 |     """EATA adapts a model by entropy minimization during testing.
19 |     Once EATAed, a model adapts itself by updating on every forward.
20 |     """
21 |     def __init__(self, model, optimizer, fishers=None, fisher_beta=2000.0, steps=1, episodic=False, e_margin=math.log(1000)/2-1, d_margin=0.05):
22 |         super().__init__()
23 |         self.model = model
24 |         self.optimizer = optimizer
25 |         self.steps = steps
26 |         assert steps > 0, "EATA requires >= 1 step(s) to forward and update"
27 |         self.episodic = episodic
28 | 
29 |         self.num_samples_update_1 = 0  # number of samples after first filtering, excluding unreliable samples
30 |         self.num_samples_update_2 = 0  # number of samples after second filtering, excluding both unreliable and redundant samples
31 |         self.e_margin = e_margin  # hyper-parameter E_0 (Eqn. 3)
32 |         self.d_margin = d_margin  # hyper-parameter \epsilon for cosine similarity thresholding (Eqn. 5)
33 | 
34 |         self.current_model_probs = None  # the moving average of the probability vector (Eqn. 4)
35 | 
36 |         self.fishers = fishers  # Fisher regularizer terms for anti-forgetting; must be computed before model adaptation (Eqn. 9)
37 |         self.fisher_beta = fisher_beta  # trade-off \beta between the two losses (Eqn. 8)
38 | 
39 |         # note: if the model is never reset, like for continual adaptation,
40 |         # then skipping the state copy would save memory
41 |         self.model_state, self.optimizer_state = \
42 |             copy_model_and_optimizer(self.model, self.optimizer)
43 | 
44 |     def forward(self, x):
45 |         if self.episodic:
46 |             self.reset()
47 |         if self.steps > 0:
48 |             for _ in range(self.steps):
49 |                 outputs, num_counts_2, num_counts_1, updated_probs = forward_and_adapt_eata(x, self.model, self.optimizer, self.fishers, self.e_margin, self.current_model_probs, fisher_beta=self.fisher_beta, num_samples_update=self.num_samples_update_2, d_margin=self.d_margin)
50 |                 self.num_samples_update_2 += num_counts_2
51 |                 self.num_samples_update_1 += num_counts_1
52 |                 self.reset_model_probs(updated_probs)
53 |         else:
54 |             self.model.eval()
55 |             with torch.no_grad():
56 |                 outputs = self.model(x)
57 |         return outputs
58 | 
59 |     def reset(self):
60 |         if self.model_state is None or self.optimizer_state is None:
61 |             raise Exception("cannot reset without saved model/optimizer state")
62 |         load_model_and_optimizer(self.model, self.optimizer,
63 |                                  self.model_state, self.optimizer_state)
64 | 
65 |     def reset_steps(self, new_steps):
66 |         self.steps = new_steps
67 | 
68 |     def reset_model_probs(self, probs):
69 |         self.current_model_probs = probs
70 |     def test(self, x):
71 |         outputs = self.model(x)
72 |         return outputs
73 | 
74 | @torch.jit.script
75 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
76 |     """Entropy of softmax distribution from logits."""
77 |     temperature = 1
78 |     x = x / temperature
79 |     norms = torch.norm(x, p=2, dim=-1, keepdim=True) + 1e-7
80 |     x = torch.div(x, norms) * norms.clone().detach()
81 |     x = -(x.softmax(1) * x.log_softmax(1)).sum(1)
82 |     return x
83 | 
84 | @torch.jit.script
85 | def dirichlet_entropy(x):  # key component of COME
86 |     x = x / torch.norm(x, p=2, dim=-1, keepdim=True) * torch.norm(x, p=2, dim=-1, keepdim=True).detach()  # adapt only the direction of the logits; their norm is detached
87 |     # exp(logits) is treated as Dirichlet evidence; 1000 (the number of classes) acts as the prior weight, yielding per-class belief masses plus an explicit uncertainty mass
88 |     brief = torch.exp(x)/(torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000)
89 |     uncertainty = 1000 / (torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000)
90 |     probability = torch.cat([brief, uncertainty], dim=1) + 1e-7
91 | 
92 |     return -(probability * torch.log(probability)).sum(1)
93 | 
94 | 
95 | @torch.enable_grad()  # ensure grads in possible no-grad context for testing
96 | def forward_and_adapt_eata(x, model, optimizer, fishers, e_margin, current_model_probs, fisher_beta=50.0, d_margin=0.05, scale_factor=2, num_samples_update=0):
97 |     """Forward and adapt model on batch of data.
98 |     Measure entropy of the model prediction, take gradients, and update params.
99 |     Return:
100 |     1. model outputs;
101 |     2. the number of reliable and non-redundant samples;
102 |     3. the number of reliable samples;
103 |     4. 
the moving average probability vector over all previous samples 104 | """ 105 | # forward 106 | outputs = model(x) 107 | # adapt 108 | # COME: replace softmax_entropy with dirichlet_entropy 109 | entropys = dirichlet_entropy(outputs) 110 | # filter unreliable samples 111 | filter_ids_1 = torch.where(entropys < e_margin) 112 | ids1 = filter_ids_1 113 | ids2 = torch.where(ids1[0]>-0.1) 114 | entropys = entropys[filter_ids_1] 115 | # filter redundant samples 116 | if current_model_probs is not None: 117 | cosine_similarities = F.cosine_similarity(current_model_probs.unsqueeze(dim=0), outputs[filter_ids_1].softmax(1), dim=1) 118 | filter_ids_2 = torch.where(torch.abs(cosine_similarities) < d_margin) 119 | entropys = entropys[filter_ids_2] 120 | ids2 = filter_ids_2 121 | updated_probs = update_model_probs(current_model_probs, outputs[filter_ids_1][filter_ids_2].softmax(1)) 122 | else: 123 | updated_probs = update_model_probs(current_model_probs, outputs[filter_ids_1].softmax(1)) 124 | coeff = 1 / (torch.exp(entropys.clone().detach() - e_margin)) 125 | # implementation version 1, compute loss, all samples backward (some unselected are masked) 126 | entropys = entropys.mul(coeff) # reweight entropy losses for diff. samples 127 | loss = entropys.mean(0) 128 | """ 129 | # implementation version 2, compute loss, forward all batch, forward and backward selected samples again. 130 | # loss = 0 131 | # if x[ids1][ids2].size(0) != 0: 132 | # loss = softmax_entropy(model(x[ids1][ids2])).mul(coeff).mean(0) # reweight entropy losses for diff. samples 133 | """ 134 | if fishers is not None: 135 | ewc_loss = 0 136 | for name, param in model.named_parameters(): 137 | if name in fishers: 138 | ewc_loss += fisher_beta * (fishers[name][0] * (param - fishers[name][1])**2).sum() 139 | loss += ewc_loss 140 | if x[ids1][ids2].size(0) != 0: 141 | loss.backward() 142 | optimizer.step() 143 | optimizer.zero_grad() 144 | return outputs, entropys.size(0), filter_ids_1[0].size(0), updated_probs 145 | 146 | 147 | def update_model_probs(current_model_probs, new_probs): 148 | if current_model_probs is None: 149 | if new_probs.size(0) == 0: 150 | return None 151 | else: 152 | with torch.no_grad(): 153 | return new_probs.mean(0) 154 | else: 155 | if new_probs.size(0) == 0: 156 | with torch.no_grad(): 157 | return current_model_probs 158 | else: 159 | with torch.no_grad(): 160 | return 0.9 * current_model_probs + (1 - 0.9) * new_probs.mean(0) 161 | 162 | 163 | def collect_params(model): 164 | """Collect the affine scale + shift parameters from batch norms. 165 | Walk the model's modules and collect all batch normalization parameters. 166 | Return the parameters and their names. 167 | Note: other choices of parameterization are possible! 
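    This helper is identical to the one in eata.py. Note that for the ViT backbone
    used in start.sh ("vitbase_timm") the collected parameters are the LayerNorm
    affine weights and biases, since that model contains no BatchNorm2d layers.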
168 | """ 169 | params = [] 170 | names = [] 171 | for nm, m in model.named_modules(): 172 | if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)): 173 | for np, p in m.named_parameters(): 174 | if np in ['weight', 'bias']: # weight is scale, bias is shift 175 | params.append(p) 176 | names.append(f"{nm}.{np}") 177 | return params, names 178 | 179 | 180 | def copy_model_and_optimizer(model, optimizer): 181 | """Copy the model and optimizer states for resetting after adaptation.""" 182 | model_state = deepcopy(model.state_dict()) 183 | optimizer_state = deepcopy(optimizer.state_dict()) 184 | return model_state, optimizer_state 185 | 186 | 187 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): 188 | """Restore the model and optimizer states from copies.""" 189 | model.load_state_dict(model_state, strict=True) 190 | optimizer.load_state_dict(optimizer_state) 191 | 192 | 193 | def configure_model(model): 194 | """Configure model for use with eata.""" 195 | # train mode, because eata optimizes the model to minimize entropy 196 | model.train() 197 | # disable grad, to (re-)enable only what eata updates 198 | model.requires_grad_(False) 199 | # configure norm for eata updates: enable grad + force batch statisics 200 | for m in model.modules(): 201 | if isinstance(m, nn.BatchNorm2d): 202 | m.requires_grad_(True) 203 | # force use of batch stats in train and eval modes 204 | m.track_running_stats = False 205 | m.running_mean = None 206 | m.running_var = None 207 | if isinstance(m, (nn.GroupNorm, nn.LayerNorm)): 208 | m.requires_grad_(True) 209 | return model 210 | 211 | 212 | def check_model(model): 213 | """Check model for compatability with eata.""" 214 | is_training = model.training 215 | assert is_training, "eata needs train mode: call model.train()" 216 | param_grads = [p.requires_grad for p in model.parameters()] 217 | has_any_params = any(param_grads) 218 | has_all_params = all(param_grads) 219 | assert has_any_params, "eata needs params to update: " \ 220 | "check which require grad" 221 | assert not has_all_params, "eata should not update all params: " \ 222 | "check which require grad" 223 | has_bn = any([isinstance(m, nn.BatchNorm2d) for m in model.modules()]) 224 | assert has_bn, "eata needs normalization for its optimization" -------------------------------------------------------------------------------- /dataset/selectedRotateImageFolder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import random 4 | import math 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.transforms as transforms 10 | import torchvision.datasets as datasets 11 | import torchvision.models as models 12 | import torch.utils.data 13 | from utils.third_party import _augmix_aug as tr_transforms_memo 14 | 15 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 16 | tr_transforms = transforms.Compose([transforms.RandomResizedCrop(224), 17 | transforms.RandomHorizontalFlip(), 18 | 19 | transforms.ToTensor(), 20 | normalize]) 21 | te_transforms = transforms.Compose([transforms.Resize(256), 22 | transforms.CenterCrop(224), 23 | transforms.ToTensor(), 24 | normalize]) 25 | 26 | 27 | te_transforms_imageC = transforms.Compose([transforms.CenterCrop(224), 28 | transforms.ToTensor(), 29 | normalize]) 30 | ad_transforms_imageC = transforms.Compose([transforms.CenterCrop(224), 31 | transforms.ToTensor(),]) 32 | 33 | rotation_tr_transforms = 
tr_transforms 34 | rotation_te_transforms = te_transforms 35 | 36 | common_corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur', 37 | 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', 38 | 'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'] 39 | 40 | 41 | class ImagePathFolder(datasets.ImageFolder): 42 | def __init__(self, traindir, train_transform): 43 | super(ImagePathFolder, self).__init__(traindir, train_transform) 44 | 45 | def __getitem__(self, index): 46 | path, _ = self.imgs[index] 47 | img = self.loader(path) 48 | if self.transform is not None: 49 | img = self.transform(img) 50 | path, pa = os.path.split(path) 51 | path, pb = os.path.split(path) 52 | return img, 'val/%s/%s' %(pb, pa) 53 | 54 | 55 | 56 | 57 | def tensor_rot_90(x): 58 | return x.flip(2).transpose(1, 2) 59 | 60 | def tensor_rot_180(x): 61 | return x.flip(2).flip(1) 62 | 63 | def tensor_rot_270(x): 64 | return x.transpose(1, 2).flip(2) 65 | 66 | def rotate_single_with_label(img, label): 67 | if label == 1: 68 | img = tensor_rot_90(img) 69 | elif label == 2: 70 | img = tensor_rot_180(img) 71 | elif label == 3: 72 | img = tensor_rot_270(img) 73 | return img 74 | 75 | def rotate_batch_with_labels(batch, labels): 76 | images = [] 77 | for img, label in zip(batch, labels): 78 | img = rotate_single_with_label(img, label) 79 | images.append(img.unsqueeze(0)) 80 | return torch.cat(images) 81 | 82 | def rotate_batch(batch, label='rand'): 83 | if label == 'rand': 84 | labels = torch.randint(4, (len(batch),), dtype=torch.long) 85 | else: 86 | assert isinstance(label, int) 87 | labels = torch.zeros((len(batch),), dtype=torch.long) + label 88 | return rotate_batch_with_labels(batch, labels), labels 89 | 90 | 91 | 92 | 93 | class SelectedRotateImageFolder(datasets.ImageFolder): 94 | def __init__(self, root, train_transform, original=True, rotation=True, rotation_transform=None): 95 | super(SelectedRotateImageFolder, self).__init__(root, train_transform) 96 | self.original = original 97 | self.rotation = rotation 98 | self.rotation_transform = rotation_transform 99 | 100 | self.original_samples = self.samples 101 | 102 | def __getitem__(self, index): 103 | path, target = self.samples[index] 104 | img_input = self.loader(path) 105 | 106 | if self.transform is not None: 107 | if isinstance(self.transform, list): 108 | img = self.transform[1](img_input) 109 | img_aug = self.transform[0](img_input) 110 | else: 111 | img = self.transform(img_input) 112 | else: 113 | img = img_input 114 | 115 | results = [] 116 | if self.original: 117 | results.append(img) 118 | results.append(target) 119 | if isinstance(self.transform, list): 120 | results.append(img_aug) 121 | if self.rotation: 122 | if self.rotation_transform is not None: 123 | img = self.rotation_transform(img_input) 124 | target_ssh = np.random.randint(0, 4, 1)[0] 125 | img_ssh = rotate_single_with_label(img, target_ssh) 126 | results.append(img_ssh) 127 | results.append(target_ssh) 128 | return results 129 | 130 | def switch_mode(self, original, rotation): 131 | self.original = original 132 | self.rotation = rotation 133 | 134 | def set_target_class_dataset(self, target_class_index, logger=None): 135 | self.target_class_index = target_class_index 136 | self.samples = [(path, idx) for (path, idx) in self.original_samples if idx in self.target_class_index] 137 | self.targets = [s[1] for s in self.samples] 138 | 139 | def set_dataset_size(self, subset_size): 140 | num_train = len(self.targets) 141 | indices = 
list(range(num_train)) 142 | random.shuffle(indices) 143 | self.samples = [self.samples[i] for i in indices[:subset_size]] 144 | self.targets = [self.targets[i] for i in indices[:subset_size]] 145 | return len(self.targets) 146 | 147 | def set_specific_subset(self, indices): 148 | self.samples = [self.original_samples[i] for i in indices] 149 | self.targets = [s[1] for s in self.samples] 150 | 151 | 152 | def reset_data_sampler(sampler, dset_length, dset): 153 | sampler.dataset = dset 154 | if dset_length % sampler.num_replicas != 0 and False: 155 | sampler.num_samples = math.ceil((dset_length - sampler.num_replicas) / sampler.num_replicas) 156 | else: 157 | sampler.num_samples = math.ceil(dset_length / sampler.num_replicas) 158 | sampler.total_size = sampler.num_samples * sampler.num_replicas 159 | 160 | 161 | def prepare_train_dataset(args, use_transforms=True): 162 | print('Preparing training data (ori imagenet train)...') 163 | tr_transforms_local = tr_transforms if use_transforms else None 164 | traindir = os.path.join(args.data, 'train') 165 | trset = SelectedRotateImageFolder(traindir, tr_transforms_local, original=True, rotation=args.rotation, 166 | rotation_transform=rotation_tr_transforms) 167 | return trset 168 | 169 | 170 | def prepare_train_dataloader(args, trset=None, sampler=None): 171 | if sampler is None: 172 | trloader = torch.utils.data.DataLoader(trset, batch_size=args.train_batch_size, shuffle=True, 173 | num_workers=args.workers, pin_memory=False) 174 | train_sampler = None 175 | else: 176 | train_sampler = torch.utils.data.distributed.DistributedSampler(trset) 177 | trloader = torch.utils.data.DataLoader( 178 | trset, batch_size=args.batch_size, 179 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) 180 | return trloader, train_sampler 181 | def obtain_train_loader(args): 182 | args.corruption = 'original' 183 | train_dataset, train_loader = prepare_test_data(args) 184 | train_dataset.switch_mode(True, False) 185 | return train_dataset, train_loader 186 | 187 | def prepare_test_data(args, use_transforms=True): 188 | if args.corruption == 'original': 189 | te_transforms_local = te_transforms if use_transforms else None 190 | elif args.corruption in common_corruptions: 191 | te_transforms_local = te_transforms_imageC if use_transforms else None 192 | else: 193 | assert False, NotImplementedError 194 | 195 | if args.exp_type == 'adversial_attack': 196 | te_transforms_local = ad_transforms_imageC 197 | 198 | if not hasattr(args, 'corruption') or args.corruption == 'original': 199 | print('Test on the original test set') 200 | validdir = os.path.join(args.data, 'val') 201 | teset = SelectedRotateImageFolder(validdir, te_transforms_local, original=False, rotation=False, 202 | rotation_transform=rotation_te_transforms) 203 | elif args.corruption in common_corruptions: 204 | print('Test on %s level %d' %(args.corruption, args.level)) 205 | validdir = os.path.join(args.data_corruption, args.corruption, str(args.level)) 206 | teset = SelectedRotateImageFolder(validdir, te_transforms_local, original=False, rotation=False, 207 | rotation_transform=rotation_te_transforms) 208 | else: 209 | raise Exception('Corruption not found!') 210 | 211 | 212 | if not hasattr(args, 'workers'): 213 | args.workers = 1 214 | teloader = torch.utils.data.DataLoader(teset, batch_size=args.test_batch_size, shuffle=args.if_shuffle, 215 | num_workers=args.workers, pin_memory=True) 216 | if args.method=="MEMO" or args.method=="MEMO_dirichlet" : 217 | te_transforms_local 
= None 218 | if not hasattr(args, 'corruption') or args.corruption == 'original': 219 | print('Test on the original test set') 220 | validdir = os.path.join(args.data, 'val') 221 | elif args.corruption in common_corruptions: 222 | print('Test on %s level %d' %(args.corruption, args.level)) 223 | validdir = os.path.join(args.data_corruption, args.corruption, str(args.level)) 224 | else: 225 | raise Exception('Corruption not found!') 226 | 227 | teset = datasets.ImageFolder(validdir, te_transforms_local) 228 | if not hasattr(args, 'workers'): 229 | args.workers = 8 230 | collate_fn = None if use_transforms else lambda x: x 231 | teloader = torch.utils.data.DataLoader(teset, batch_size=args.test_batch_size, shuffle=args.if_shuffle, 232 | num_workers=args.workers, pin_memory=True, collate_fn=collate_fn) 233 | return teset, teloader 234 | te_transforms_inc = transforms.Compose([transforms.CenterCrop(224), 235 | transforms.ToTensor(), 236 | normalize]) 237 | def custom_collate_fn(batch): 238 | 239 | images, labels = zip(*batch) 240 | 241 | 242 | transformed_images_inc = [te_transforms_inc(img) for img in images] 243 | 244 | 245 | transformed_images_tr = [tr_transforms_memo(img) for img in images] 246 | 247 | 248 | return ( 249 | torch.utils.data.dataloader.default_collate(list(zip(transformed_images_inc, labels))), 250 | torch.utils.data.dataloader.default_collate(list(zip(transformed_images_tr, labels))) 251 | ) 252 | 253 | def prepare_test_data_for_train(args, use_transforms=True): 254 | te_transforms_local = tr_transforms if use_transforms else None 255 | if args.corruption in common_corruptions: 256 | print('Test on %s level %d' %(args.corruption, args.level)) 257 | validdir = os.path.join(args.data_corruption, args.corruption, str(args.level)) 258 | teset = SelectedRotateImageFolder(validdir, te_transforms_local, original=False, rotation=False, 259 | rotation_transform=rotation_te_transforms) 260 | else: 261 | raise Exception('Corruption not found!') 262 | 263 | if not hasattr(args, 'workers'): 264 | args.workers = 1 265 | teloader = torch.utils.data.DataLoader(teset, batch_size=64, shuffle=True, 266 | num_workers=args.workers, pin_memory=True) 267 | return teset, teloader -------------------------------------------------------------------------------- /models/Res.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | try: 3 | from torch.hub import load_state_dict_from_url 4 | except ImportError: 5 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 19 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 20 | } 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 24 | """3x3 convolution with padding""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=dilation, 
groups=groups, bias=False, dilation=dilation) 27 | 28 | 29 | def conv1x1(in_planes, out_planes, stride=1): 30 | """1x1 convolution""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | expansion = 1 36 | 37 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 38 | base_width=64, dilation=1, norm_layer=None): 39 | super(BasicBlock, self).__init__() 40 | if norm_layer is None: 41 | norm_layer = nn.BatchNorm2d 42 | if groups != 1 or base_width != 64: 43 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 44 | if dilation > 1: 45 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 46 | 47 | self.conv1 = conv3x3(inplanes, planes, stride) 48 | self.bn1 = norm_layer(planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | self.conv2 = conv3x3(planes, planes) 51 | self.bn2 = norm_layer(planes) 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | identity = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out += identity 69 | out = self.relu(out) 70 | 71 | return out 72 | 73 | 74 | class Bottleneck(nn.Module): 75 | expansion = 4 76 | 77 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 78 | base_width=64, dilation=1, norm_layer=None): 79 | super(Bottleneck, self).__init__() 80 | if norm_layer is None: 81 | norm_layer = nn.BatchNorm2d 82 | width = int(planes * (base_width / 64.)) * groups 83 | 84 | self.conv1 = conv1x1(inplanes, width) 85 | self.bn1 = norm_layer(width) 86 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 87 | self.bn2 = norm_layer(width) 88 | self.conv3 = conv1x1(width, planes * self.expansion) 89 | self.bn3 = norm_layer(planes * self.expansion) 90 | self.relu = nn.ReLU(inplace=True) 91 | self.downsample = downsample 92 | self.stride = stride 93 | 94 | def forward(self, x): 95 | identity = x 96 | 97 | out = self.conv1(x) 98 | out = self.bn1(out) 99 | out = self.relu(out) 100 | 101 | out = self.conv2(out) 102 | out = self.bn2(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv3(out) 106 | out = self.bn3(out) 107 | 108 | if self.downsample is not None: 109 | identity = self.downsample(x) 110 | 111 | out += identity 112 | out = self.relu(out) 113 | 114 | return out 115 | 116 | 117 | class ResNet(nn.Module): 118 | 119 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 120 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 121 | norm_layer=None): 122 | super(ResNet, self).__init__() 123 | if norm_layer is None: 124 | norm_layer = nn.BatchNorm2d 125 | self._norm_layer = norm_layer 126 | 127 | self.inplanes = 64 128 | self.dilation = 1 129 | if replace_stride_with_dilation is None: 130 | 131 | 132 | replace_stride_with_dilation = [False, False, False] 133 | if len(replace_stride_with_dilation) != 3: 134 | raise ValueError("replace_stride_with_dilation should be None " 135 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 136 | self.groups = groups 137 | self.base_width = width_per_group 138 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 139 | bias=False) 140 | self.bn1 = norm_layer(self.inplanes) 141 | self.relu = nn.ReLU(inplace=True) 142 | self.maxpool = nn.MaxPool2d(kernel_size=3, 
stride=2, padding=1) 143 | self.layer1 = self._make_layer(block, 64, layers[0]) 144 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 145 | dilate=replace_stride_with_dilation[0]) 146 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 147 | dilate=replace_stride_with_dilation[1]) 148 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 149 | dilate=replace_stride_with_dilation[2]) 150 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 151 | self.fc = nn.Linear(512 * block.expansion, num_classes) 152 | 153 | for m in self.modules(): 154 | if isinstance(m, nn.Conv2d): 155 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 156 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 157 | nn.init.constant_(m.weight, 1) 158 | nn.init.constant_(m.bias, 0) 159 | 160 | 161 | 162 | 163 | if zero_init_residual: 164 | for m in self.modules(): 165 | if isinstance(m, Bottleneck): 166 | nn.init.constant_(m.bn3.weight, 0) 167 | elif isinstance(m, BasicBlock): 168 | nn.init.constant_(m.bn2.weight, 0) 169 | 170 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 171 | norm_layer = self._norm_layer 172 | downsample = None 173 | previous_dilation = self.dilation 174 | if dilate: 175 | self.dilation *= stride 176 | stride = 1 177 | if stride != 1 or self.inplanes != planes * block.expansion: 178 | downsample = nn.Sequential( 179 | conv1x1(self.inplanes, planes * block.expansion, stride), 180 | norm_layer(planes * block.expansion), 181 | ) 182 | 183 | layers = [] 184 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 185 | self.base_width, previous_dilation, norm_layer)) 186 | self.inplanes = planes * block.expansion 187 | for _ in range(1, blocks): 188 | layers.append(block(self.inplanes, planes, groups=self.groups, 189 | base_width=self.base_width, dilation=self.dilation, 190 | norm_layer=norm_layer)) 191 | 192 | return nn.Sequential(*layers) 193 | 194 | def forward(self, x, return_feature=False, return_feature_only=False): 195 | x = self.conv1(x) 196 | x = self.bn1(x) 197 | x = self.relu(x) 198 | x = self.maxpool(x) 199 | 200 | x = self.layer1(x) 201 | x = self.layer2(x) 202 | x = self.layer3(x) 203 | 204 | 205 | x = self.layer4(x) 206 | x = self.avgpool(x) 207 | x = x.reshape(x.size(0), -1) 208 | if return_feature: 209 | feature = x 210 | x = self.fc(x) 211 | 212 | if return_feature: 213 | if return_feature_only: 214 | return feature 215 | else: 216 | return x, feature 217 | 218 | else: 219 | return x 220 | 221 | 222 | def _resnet(arch, block, layers, pretrained, progress, norm_layer, **kwargs): 223 | model = ResNet(block, layers, norm_layer=norm_layer, **kwargs) 224 | if pretrained: 225 | state_dict = load_state_dict_from_url(model_urls[arch], 226 | progress=progress) 227 | model.load_state_dict(state_dict) 228 | return model 229 | 230 | 231 | def resnet18(pretrained=False, progress=True, **kwargs): 232 | """Constructs a ResNet-18 model. 233 | Args: 234 | pretrained (bool): If True, returns a model pre-trained on ImageNet 235 | progress (bool): If True, displays a progress bar of the download to stderr 236 | """ 237 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 238 | **kwargs) 239 | 240 | 241 | def resnet34(pretrained=False, progress=True, **kwargs): 242 | """Constructs a ResNet-34 model. 
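    A minimal usage sketch (assuming an internet connection for the weight download):
    net = resnet34(pretrained=True) builds the model and fetches the ImageNet
    checkpoint from model_urls['resnet34'] via load_state_dict_from_url.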
243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | progress (bool): If True, displays a progress bar of the download to stderr 246 | """ 247 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 248 | **kwargs) 249 | 250 | 251 | def resnet50(pretrained=False, progress=True, norm_layer=None, **kwargs): 252 | """Constructs a ResNet-50 model. 253 | Args: 254 | pretrained (bool): If True, returns a model pre-trained on ImageNet 255 | progress (bool): If True, displays a progress bar of the download to stderr 256 | """ 257 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, norm_layer, 258 | **kwargs) 259 | 260 | 261 | def resnet101(pretrained=False, progress=True, **kwargs): 262 | """Constructs a ResNet-101 model. 263 | Args: 264 | pretrained (bool): If True, returns a model pre-trained on ImageNet 265 | progress (bool): If True, displays a progress bar of the download to stderr 266 | """ 267 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 268 | **kwargs) 269 | 270 | 271 | def resnet152(pretrained=False, progress=True, **kwargs): 272 | """Constructs a ResNet-152 model. 273 | Args: 274 | pretrained (bool): If True, returns a model pre-trained on ImageNet 275 | progress (bool): If True, displays a progress bar of the download to stderr 276 | """ 277 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 278 | **kwargs) 279 | 280 | 281 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 282 | """Constructs a ResNeXt-50 32x4d model. 283 | Args: 284 | pretrained (bool): If True, returns a model pre-trained on ImageNet 285 | progress (bool): If True, displays a progress bar of the download to stderr 286 | """ 287 | kwargs['groups'] = 32 288 | kwargs['width_per_group'] = 4 289 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 290 | pretrained, progress, **kwargs) 291 | 292 | 293 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 294 | """Constructs a ResNeXt-101 32x8d model. 
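    This is resnet101 with kwargs fixed to groups=32 and width_per_group=8 below,
    which turns the 3x3 convolution in every Bottleneck into a 32-group convolution
    (see conv3x3 and the Bottleneck width computation above).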
295 | Args: 296 | pretrained (bool): If True, returns a model pre-trained on ImageNet 297 | progress (bool): If True, displays a progress bar of the download to stderr 298 | """ 299 | kwargs['groups'] = 32 300 | kwargs['width_per_group'] = 8 301 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 302 | pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import logging 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | 9 | import torch.nn as nn 10 | from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset 11 | 12 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 13 | 14 | 15 | def mean(items): 16 | return sum(items)/len(items) 17 | 18 | 19 | def max_with_index(values): 20 | best_v = values[0] 21 | best_i = 0 22 | for i, v in enumerate(values): 23 | if v > best_v: 24 | best_v = v 25 | best_i = i 26 | return best_v, best_i 27 | 28 | 29 | def shuffle(*items): 30 | example, *_ = items 31 | batch_size, *_ = example.size() 32 | index = torch.randperm(batch_size, device=example.device) 33 | 34 | return [item[index] for item in items] 35 | 36 | 37 | def to_device(*items): 38 | return [item.to(device=device) for item in items] 39 | 40 | 41 | def set_reproducible(seed=0): 42 | ''' 43 | To ensure the reproducibility, refer to https://pytorch.org/docs/stable/notes/randomness.html. 44 | Note that completely reproducible results are not guaranteed. 45 | ''' 46 | random.seed(seed) 47 | np.random.seed(seed) 48 | torch.manual_seed(seed) 49 | torch.backends.cudnn.deterministic = True 50 | torch.backends.cudnn.benchmark = False 51 | 52 | 53 | def get_logger(name: str, output_directory: str, log_name: str, debug: str) -> logging.Logger: 54 | logger = logging.getLogger(name) 55 | 56 | formatter = logging.Formatter( 57 | "%(asctime)s %(levelname)-8s: %(message)s" 58 | ) 59 | 60 | console_handler = logging.StreamHandler(sys.stdout) 61 | console_handler.setFormatter(formatter) 62 | logger.addHandler(console_handler) 63 | 64 | if output_directory is not None: 65 | file_handler = logging.FileHandler(os.path.join(output_directory, log_name)) 66 | file_handler.setFormatter(formatter) 67 | logger.addHandler(file_handler) 68 | 69 | if debug: 70 | logger.setLevel(logging.DEBUG) 71 | else: 72 | logger.setLevel(logging.INFO) 73 | 74 | logger.propagate = False 75 | return logger 76 | 77 | 78 | def _sign(number): 79 | if isinstance(number, (list, tuple)): 80 | return [_sign(v) for v in number] 81 | if number >= 0.0: 82 | return 1 83 | elif number < 0.0: 84 | return -1 85 | 86 | 87 | def compute_kendall_tau(a, b): 88 | assert len(a) == len(b), "Sequence a and b should have the same length while computing kendall tau." 
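    # Worked example (hypothetical inputs): a = [1, 2, 3], b = [1, 3, 2].
    # Pairs (0,1) and (0,2) are concordant (+1 each) while pair (1,2) is discordant (-1),
    # so count = 1, total = 3 and the returned Kendall tau is 1/3 ≈ 0.33.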
89 | length = len(a) 90 | count = 0 91 | total = 0 92 | for i in range(length-1): 93 | for j in range(i+1, length): 94 | count += _sign(a[i] - a[j]) * _sign(b[i] - b[j]) 95 | total += 1 96 | Ktau = count / total 97 | return Ktau 98 | 99 | 100 | def infer(buffer, ranker): 101 | matrix0 = [] 102 | ops0 = [] 103 | matrix1 = [] 104 | ops1 = [] 105 | for ((matrix0_, ops0_), (matrix1_, ops1_)) in buffer: 106 | matrix0.append(matrix0_) 107 | ops0.append(ops0_) 108 | 109 | matrix1.append(matrix1_) 110 | ops1.append(ops1_) 111 | 112 | matrix0 = torch.stack(matrix0, dim=0) 113 | ops0 = torch.stack(ops0, dim=0) 114 | 115 | matrix1 = torch.stack(matrix1, dim=0) 116 | ops1 = torch.stack(ops1, dim=0) 117 | 118 | with torch.no_grad(): 119 | outputs = ranker((matrix0, ops0), (matrix1, ops1)) 120 | 121 | return _sign((outputs-0.5).cpu().tolist()) 122 | 123 | 124 | def select(items, index): 125 | return [items[i] for i in index] 126 | 127 | 128 | def list_select(items, index): 129 | return [item[index] for item in items] 130 | 131 | 132 | def transpose_l(items): 133 | return list(map(list, zip(*items))) 134 | 135 | 136 | def index_generate(m, n, up_triangular=False, max_batch_size=1024): 137 | if up_triangular: 138 | indexs = [] 139 | for i in range(m-1): 140 | for j in range(i+1, n): 141 | indexs.append((i, j)) 142 | if len(indexs) == max_batch_size: 143 | yield indexs 144 | indexs = [] 145 | if indexs: 146 | yield indexs 147 | else: 148 | indexs = [] 149 | for i in range(m): 150 | for j in range(n): 151 | indexs.append((i, j)) 152 | if len(indexs) == max_batch_size: 153 | yield indexs 154 | indexs = [] 155 | if indexs: 156 | yield indexs 157 | 158 | 159 | def batchify(items): 160 | if isinstance(items[0], (list, tuple)): 161 | transposed_items = transpose_l(items) 162 | return [torch.stack(item, dim=0) for item in transposed_items] 163 | else: 164 | return torch.stack(items, dim=0) 165 | 166 | 167 | 168 | def cartesian_traverse(arch0, arch1, ranker, up_triangular=False): 169 | m, n = len(arch0), len(arch1) 170 | outputs = [] 171 | with torch.no_grad(): 172 | for index in index_generate(m, n, up_triangular): 173 | i, j = transpose_l(index) 174 | a = batchify(select(arch0, i)) 175 | b = batchify(select(arch1, j)) 176 | output = ranker(a, b) 177 | outputs.append(output) 178 | outputs = torch.cat(outputs, dim=0) 179 | if up_triangular: 180 | return outputs 181 | else: 182 | return outputs.view(m, n) 183 | 184 | 185 | def compute_kendall_tau_AR(ranker, archs, performances): 186 | 187 | length = len(performances) 188 | count = 0 189 | total = 0 190 | 191 | archs = transpose_l([torch.unbind(item) for item in archs]) 192 | outputs = cartesian_traverse(archs, archs, ranker, up_triangular=True) 193 | 194 | p_combination = _sign((outputs-0.5).cpu().tolist()) 195 | 196 | for i in range(length-1): 197 | for j in range(i+1, length): 198 | count += p_combination[total] * _sign(performances[i]-performances[j]) 199 | total += 1 200 | 201 | assert len(p_combination) == total 202 | Ktau = count / total 203 | return Ktau 204 | 205 | 206 | def concat(a, b): 207 | return [torch.cat([item0, item1]) for item0, item1 in zip(a, b)] 208 | 209 | 210 | def list_concat(a, b): 211 | if len(a) == 3: 212 | a0, a1, a2 = a 213 | b0, b1, b2 = b 214 | rev0 = concat(a0, b0) 215 | rev1 = concat(a1, b1) 216 | rev2 = torch.cat([a2, b2], dim=0) 217 | return rev0, rev1, rev2 218 | else: 219 | a0, a1 = a 220 | b0, b1 = b 221 | rev0 = concat(a0, b0) 222 | rev1 = concat(a1, b1) 223 | return rev0, rev1 224 | 225 | def compute_flops(module: nn.Module, 
size, skip_pattern, device):
226 | 
227 |     def size_hook(module: nn.Module, input: torch.Tensor, output: torch.Tensor):
228 |         *_, h, w = output.shape
229 |         module.output_size = (h, w)
230 |     hooks = []
231 |     for name, m in module.named_modules():
232 |         if isinstance(m, nn.Conv2d):
233 | 
234 |             hooks.append(m.register_forward_hook(size_hook))
235 |     with torch.no_grad():
236 |         training = module.training
237 |         module.eval()
238 |         module(torch.rand(size).to(device))
239 |         module.train(mode=training)
240 | 
241 |     for hook in hooks:
242 |         hook.remove()
243 | 
244 |     flops = 0
245 |     for name, m in module.named_modules():
246 |         if skip_pattern in name:
247 |             continue
248 |         if isinstance(m, nn.Conv2d):
249 | 
250 |             h, w = m.output_size
251 |             kh, kw = m.kernel_size
252 |             flops += h * w * m.in_channels * m.out_channels * kh * kw / m.groups
253 |         if isinstance(m, nn.Linear):  # fixed: check the submodule m, not the root module, so fully-connected layers are counted
254 |             flops += m.in_features * m.out_features
255 |     return flops
256 | 
257 | def compute_nparam(module: nn.Module, skip_pattern):
258 |     n_param = 0
259 |     for name, p in module.named_parameters():
260 |         if skip_pattern not in name:
261 |             n_param += p.numel()
262 |     return n_param
263 | 
264 | 
265 | def set_random_seed(seed):
266 |     torch.manual_seed(seed)
267 |     random.seed(seed)
268 |     np.random.seed(seed)
269 |     if torch.cuda.is_available():
270 |         torch.cuda.manual_seed(seed)
271 |         torch.cuda.manual_seed_all(seed)
272 | 
273 | def generate_mix_data(id_data, ood_data, ood_rate):
274 |     """
275 |     Merge two PyTorch datasets with a specified out-of-distribution rate.
276 | 
277 |     Args:
278 |         id_data (torch.utils.data.Dataset): PyTorch dataset containing in-distribution samples.
279 |         ood_data (torch.utils.data.Dataset): PyTorch dataset containing out-of-distribution samples.
280 |         ood_rate (float): The rate of out-of-distribution samples relative to in-distribution samples.
281 | 
282 |     Returns:
283 |         torch.utils.data.Dataset: Merged dataset with a flag indicating the origin of each sample.
284 |     """
285 | 
286 |     ood_num_samples = int(len(id_data) * ood_rate)
287 | 
288 | 
289 | 
290 |     ood_subset_indices = random.sample(range(len(ood_data)), min(ood_num_samples, len(ood_data)))
291 |     ood_subset = Subset(ood_data, ood_subset_indices)
292 | 
293 | 
294 |     id_flag = torch.zeros(len(id_data), dtype=torch.int64)
295 |     ood_flag = torch.ones(len(ood_subset), dtype=torch.int64)
296 | 
297 | 
298 |     merged_data = ConcatDataset([id_data, ood_subset])
299 |     flags = torch.cat([id_flag, ood_flag])
300 | 
301 |     class MergedDataset(torch.utils.data.Dataset):
302 |         def __init__(self, data, flags):
303 |             self.data = data
304 |             self.flags = flags
305 | 
306 |         def __getitem__(self, index):
307 |             sample, flag = self.data[index], self.flags[index]
308 |             return sample, flag
309 | 
310 |         def __len__(self):
311 |             return len(self.data)
312 | 
313 | 
314 |     return MergedDataset(merged_data, flags)
315 | def generate_balanced_data(id_data, ood_data, ood_rate=0.5):
316 |     """
317 |     Merge two PyTorch datasets at a given ID/OOD ratio (1:1 by default).
318 | 
319 |     Args:
320 |         id_data (torch.utils.data.Dataset): PyTorch dataset containing in-distribution samples.
321 |         ood_data (torch.utils.data.Dataset): PyTorch dataset containing out-of-distribution samples.
322 | 
323 |     Returns:
324 |         torch.utils.data.Dataset: Merged dataset with a flag indicating the origin of each sample.
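    Example (hypothetical sizes): with len(id_data) = 1000, len(ood_data) = 300 and
    ood_rate = 0.5, sum_simple = min(1000/0.5, 300/0.5) = 600, so 300 ID and 300 OOD
    samples are kept; each item is returned as (sample, flag) with flag 0 for ID and
    1 for OOD.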
325 | """ 326 | if (ood_rate > 1): 327 | raise ValueError("ood_rate must be between 0 and 1") 328 | id_rate = 1-ood_rate 329 | 330 | if(ood_rate==0): 331 | sum_simple = len(id_data) 332 | else: 333 | sum_simple = min(len(id_data)/id_rate,len(ood_data)/ood_rate) 334 | 335 | id_sample_count = int(sum_simple * id_rate) 336 | ood_sample_count = int(sum_simple * ood_rate) 337 | 338 | 339 | if len(id_data) > id_sample_count: 340 | id_subset_indices = random.sample(range(len(id_data)), id_sample_count) 341 | id_subset = Subset(id_data, id_subset_indices) 342 | else: 343 | id_subset = id_data 344 | 345 | 346 | if len(ood_data) > ood_sample_count: 347 | ood_subset_indices = random.sample(range(len(ood_data)), ood_sample_count) 348 | ood_subset = Subset(ood_data, ood_subset_indices) 349 | else: 350 | ood_subset = ood_data 351 | 352 | 353 | id_flag = torch.zeros(len(id_subset), dtype=torch.int64) 354 | ood_flag = torch.ones(len(ood_subset), dtype=torch.int64) 355 | 356 | 357 | merged_data = ConcatDataset([id_subset, ood_subset]) 358 | flags = torch.cat([id_flag, ood_flag]) 359 | 360 | 361 | class MergedDataset(Dataset): 362 | def __init__(self, data, flags): 363 | self.data = data 364 | self.flags = flags 365 | 366 | def __getitem__(self, index): 367 | sample, flag = self.data[index], self.flags[index] 368 | return sample, flag 369 | 370 | def __len__(self): 371 | return len(self.data) 372 | 373 | 374 | return MergedDataset(merged_data, flags) 375 | 376 | def merge_datasets(dataset_list): 377 | """ 378 | Merge a list of PyTorch datasets into a single dataset with equal sampling from each dataset. 379 | 380 | Args: 381 | dataset_list (list): A list of PyTorch datasets. 382 | 383 | Returns: 384 | ConcatDataset: A concatenated dataset. 385 | """ 386 | 387 | min_num_samples = min(len(dataset) for dataset in dataset_list) 388 | 389 | 390 | subset_list = [Subset(dataset, torch.randperm(len(dataset))[:min_num_samples]) for dataset in dataset_list] 391 | 392 | 393 | merged_dataset = ConcatDataset(subset_list) 394 | 395 | return merged_dataset -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import random 4 | import numpy as np 5 | import sys 6 | import os 7 | import argparse 8 | import torch 9 | import csv 10 | import foolbox as fb 11 | import torch.nn.functional as F 12 | from sklearn.metrics import det_curve, accuracy_score, roc_auc_score, auc, precision_recall_curve 13 | from utils.third_party import _augmix_aug as tr_transforms 14 | from tqdm import tqdm 15 | 16 | def compute_fnr(out_scores, in_scores, fpr_cutoff=.05): 17 | ''' 18 | compute fnr at 05 19 | ''' 20 | in_labels = np.zeros(len(in_scores)) 21 | out_labels = np.ones(len(out_scores)) 22 | y_true = np.concatenate([in_labels, out_labels]) 23 | y_score = np.concatenate([in_scores, out_scores]) 24 | 25 | if len(set(y_true)) < 2: 26 | print("Warning: Only one class present in y_true. 
Skipping DET curve calculation.") 27 | return 1.0 28 | 29 | fpr, fnr, thresholds = det_curve(y_true=y_true, y_score=y_score) 30 | 31 | idx = np.argmin(np.abs(fpr - fpr_cutoff)) 32 | 33 | fpr_at_fpr_cutoff = fpr[idx] 34 | fnr_at_fpr_cutoff = fnr[idx] 35 | thresholds95 = thresholds[idx] 36 | 37 | if fpr_at_fpr_cutoff > 0.1: 38 | fnr_at_fpr_cutoff = 1.0 39 | 40 | 41 | return fnr_at_fpr_cutoff, thresholds95 42 | 43 | 44 | 45 | 46 | def compute_auroc(out_scores, in_scores): 47 | in_labels = np.zeros(len(in_scores)) 48 | out_labels = np.ones(len(out_scores)) 49 | y_true = np.concatenate([in_labels, out_labels]) 50 | y_score = np.concatenate([in_scores, out_scores]) 51 | auroc = roc_auc_score(y_true=y_true, y_score=y_score) 52 | 53 | return auroc 54 | 55 | 56 | 57 | def compute_aupr(out_scores, in_scores): 58 | in_labels = np.zeros(len(in_scores)) 59 | out_labels = np.ones(len(out_scores)) 60 | y_true = np.concatenate([in_labels, out_labels]) 61 | y_score = np.concatenate([in_scores, out_scores]) 62 | precision, recall, _ = precision_recall_curve(y_true, y_score) 63 | aupr = auc(recall, precision) 64 | 65 | return aupr 66 | 67 | 68 | def eval_ood(in_scores, out_scores): 69 | fpr,_ = compute_fnr(out_scores, in_scores) 70 | auroc = compute_auroc(out_scores, in_scores) 71 | aupr = compute_aupr(out_scores, in_scores) 72 | 73 | return fpr, auroc, aupr 74 | 75 | def eval_ood_95(in_scores, out_scores): 76 | fpr,fprs95 = compute_fnr(out_scores, in_scores) 77 | auroc = compute_auroc(out_scores, in_scores) 78 | aupr = compute_aupr(out_scores, in_scores) 79 | 80 | return fpr, auroc, aupr,fprs95 81 | def dirichlet_probability(x): 82 | x = x / torch.norm(x, p=2, dim=-1, keepdim=True) * torch.norm(x, p=2, dim=-1, keepdim=True).detach() 83 | """Entropy of softmax distribution from logits.""" 84 | brief = torch.exp(x)/(torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000) 85 | uncertainty = 1000 / (torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000) 86 | 87 | 88 | return brief 89 | 90 | def ood_metric(output, scoring_function='dirichlet'): 91 | to_np = lambda x: x.data.cpu().numpy() 92 | 93 | if scoring_function == 'entropy': 94 | score = to_np(output.mean(1) - torch.logsumexp(output, dim=1)) 95 | elif scoring_function == 'energy': 96 | score = to_np(-torch.logsumexp(output, dim=1)) 97 | elif scoring_function == 'msp': 98 | score = -np.max(to_np(F.softmax(output, dim=1)), axis=1) 99 | elif scoring_function == 'dirichlet': 100 | score = -np.max(to_np(dirichlet_probability(output).cpu()), axis=1) 101 | else: 102 | raise ValueError(f"Unknown scoring function: {scoring_function}") 103 | 104 | if not isinstance(score, np.ndarray): 105 | score = np.array([score]) 106 | 107 | return score 108 | def get_scores(args,net, test_loader): 109 | _in_score, _out_score = [], [] 110 | correct = 0 111 | total = 0 112 | 113 | loader = test_loader 114 | 115 | for batch in tqdm(loader, total=len(loader), disable=False): 116 | test_set = batch 117 | sample, ood_label = test_set[0], test_set[1] 118 | 119 | test_data, target, ood_label = sample[0].cuda(), sample[1].cuda(), ood_label.cuda() 120 | 121 | output = net(test_data) 122 | 123 | score = ood_metric(output, scoring_function=args.scoring_function) 124 | 125 | 126 | mask_in = (ood_label == 0) 127 | mask_in = mask_in.cpu() 128 | 129 | if mask_in.any(): 130 | predictions = output.max(1)[1] 131 | mask_right = ((predictions == target).cpu() & mask_in).cpu() 132 | mask_wrong = ((predictions != target).cpu() & mask_in).cpu() 133 | _in_score.extend(score[mask_right].tolist()) 134 | 135 | 
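            # Note: only correctly classified ID samples enter _in_score; misclassified
            # ID samples are treated as OOD-like negatives and appended to _out_score below.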
correct += (predictions[mask_in]==target[mask_in]).sum().item() 136 | total += mask_in.sum().item() 137 | 138 | 139 | mask_out = ood_label == 1 140 | mask_out = mask_out.cpu() 141 | if mask_out.any(): 142 | _out_score.extend(score[mask_out].tolist()) 143 | 144 | if mask_wrong.any(): 145 | _out_score.extend(score[mask_wrong].tolist()) 146 | 147 | accuracy = correct / total if total > 0 else 0 148 | return _in_score, _out_score, accuracy 149 | def get_scores_test(args,net, test_loader): 150 | _in_score, _out_score = [], [] 151 | correct = 0 152 | total = 0 153 | 154 | loader = test_loader 155 | 156 | for batch in tqdm(loader, total=len(loader), disable=False): 157 | test_set = batch 158 | sample, ood_label = test_set[0], test_set[1] 159 | 160 | test_data, target, ood_label = sample[0].cuda(), sample[1].cuda(), ood_label.cuda() 161 | 162 | output = net(test_data) 163 | 164 | score = ood_metric(output, scoring_function=args.scoring_function) 165 | 166 | 167 | mask_in = (ood_label == 0) 168 | mask_in = mask_in.cpu() 169 | 170 | if mask_in.any(): 171 | predictions = output.max(1)[1] 172 | mask_right = ((predictions == target).cpu() & mask_in).cpu() 173 | mask_wrong = ((predictions != target).cpu() & mask_in).cpu() 174 | _in_score.extend(score[mask_right].tolist()) 175 | 176 | correct += (predictions[mask_in]==target[mask_in]).sum().item() 177 | total += mask_in.sum().item() 178 | 179 | 180 | mask_out = ood_label == 1 181 | mask_out = mask_out.cpu() 182 | if mask_out.any(): 183 | _out_score.extend(score[mask_out].tolist()) 184 | 185 | if mask_wrong.any(): 186 | _out_score.extend(score[mask_wrong].tolist()) 187 | 188 | accuracy = correct / total if total > 0 else 0 189 | return _in_score, _out_score, accuracy 190 | def get_scores_memo(args, net, test_loader): 191 | _in_score, _out_score = [], [] 192 | correct = 0 193 | total = 0 194 | 195 | loader = test_loader 196 | with tqdm(total=len(loader), disable=True) as _tqdm: 197 | _tqdm.set_description('Processing batches') 198 | for batch in loader: 199 | test_set = batch 200 | sample, ood_label = test_set[0], test_set[1] 201 | 202 | test_data, target = sample 203 | ood_label = ood_label 204 | 205 | output = net(test_data) 206 | 207 | score = ood_metric(output, scoring_function=args.scoring_function) 208 | 209 | 210 | mask_in = (ood_label == 0).item() 211 | if mask_in: 212 | predictions = output.max(1)[1] 213 | mask_right = (predictions == target).cpu().item() 214 | mask_wrong = (predictions != target).cpu().item() 215 | if mask_right: 216 | _in_score.extend(score.tolist()) 217 | elif mask_wrong: 218 | _out_score.extend(score.tolist()) 219 | correct += mask_right 220 | total += 1 221 | 222 | mask_out = (ood_label == 1).item() 223 | if mask_out: 224 | _out_score.extend(score.tolist()) 225 | 226 | 227 | accuracy = correct / total if total > 0 else 0 228 | _tqdm.set_postfix(accuracy='{:.4f}'.format(accuracy)) 229 | _tqdm.update(1) 230 | 231 | accuracy = correct / total if total > 0 else 0 232 | return _in_score, _out_score, accuracy 233 | 234 | def dirichlet_entropy(x): 235 | x = x / torch.norm(x, p=2, dim=-1, keepdim=True) * torch.norm(x, p=2, dim=-1, keepdim=True).detach() 236 | 237 | """Entropy of softmax distribution from logits.""" 238 | brief = torch.exp(x)/(torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000) 239 | uncertainty = 1000 / (torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000) 240 | probability = torch.cat([brief, uncertainty], dim=1) + 1e-7 241 | return -(probability * torch.log(probability)).sum(1) 242 | def dirichlet_uncertainty(x): 
243 | x = x / torch.norm(x, p=2, dim=-1, keepdim=True) * torch.norm(x, p=2, dim=-1, keepdim=True).detach() 244 | 245 | """Entropy of softmax distribution from logits.""" 246 | brief = torch.exp(x)/(torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000) 247 | uncertainty = 1000 / (torch.sum(torch.exp(x), dim=1, keepdim=True) + 1000) 248 | return uncertainty 249 | def softmax_entropy(x: torch.Tensor) -> torch.Tensor: 250 | """Entropy of softmax distribution from logits.""" 251 | x = -(x.softmax(1) * x.log_softmax(1)).sum(1) 252 | return x 253 | 254 | def get_scores_pot(args, net, test_loader): 255 | save_path="pot/"+args.method + "_" + args.corruption + "_level"+str(args.level)+".csv" 256 | _in_score, _out_score = [], [] 257 | correct = 0 258 | total = 0 259 | batch_accuracies = [] 260 | batch_fprs = [] 261 | batch_softmax_max = [] 262 | batch_uncertainty = [] 263 | softmax_Entropys = [] 264 | dirichlet_Entropys = [] 265 | 266 | loader = test_loader 267 | 268 | for batch in tqdm(loader, total=len(loader), disable=True): 269 | test_set = batch 270 | sample, ood_label = test_set[0], test_set[1] 271 | 272 | test_data, target, ood_label = sample[0].cuda(), sample[1].cuda(), ood_label.cuda() 273 | 274 | output = net(test_data) 275 | uncertainty = dirichlet_uncertainty(output).mean(0).item() 276 | softmax_Entropy = softmax_entropy(output).mean(0).item() 277 | dirichlet_Entropy = dirichlet_entropy(output).mean(0).item() 278 | 279 | score = ood_metric(output, scoring_function=args.scoring_function) 280 | 281 | 282 | mask_in = (ood_label == 0) 283 | mask_in = mask_in.cpu() 284 | 285 | if mask_in.any(): 286 | predictions = output.max(1)[1] 287 | mask_right = ((predictions == target).cpu() & mask_in).cpu() 288 | mask_wrong = ((predictions != target).cpu() & mask_in).cpu() 289 | _in_score.extend(score[mask_right].tolist()) 290 | correct += (predictions[mask_in]==target[mask_in]).sum().item() 291 | total += mask_in.sum().item() 292 | 293 | 294 | accuracy = correct / total if total > 0 else 0 295 | softmax_max = torch.softmax(output, dim=1).max(1)[0].mean().item() 296 | 297 | 298 | 299 | mask_out = ood_label == 1 300 | mask_out = mask_out.cpu() 301 | if mask_out.any(): 302 | _out_score.extend(score[mask_out].tolist()) 303 | 304 | if mask_wrong.any(): 305 | _out_score.extend(score[mask_wrong].tolist()) 306 | fpr = compute_fnr(_out_score, _in_score) 307 | batch_fprs.append(fpr) 308 | batch_accuracies.append(accuracy) 309 | batch_softmax_max.append(softmax_max) 310 | softmax_Entropys.append(softmax_Entropy) 311 | dirichlet_Entropys.append(dirichlet_Entropy) 312 | batch_uncertainty.append(uncertainty) 313 | 314 | 315 | with open(save_path, 'w', newline='') as csvfile: 316 | 317 | fieldnames = ['batch', 'accuracy','fpr', 'softmax_max','softmax_Entropy'] 318 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 319 | writer.writeheader() 320 | for i, (acc,fpr, softmax,se,de,u) in enumerate(zip(batch_accuracies,batch_fprs, batch_softmax_max,softmax_Entropys,dirichlet_Entropys,batch_uncertainty)): 321 | 322 | writer.writerow({'batch': i + 1, 'accuracy': acc, 'fpr':fpr,'softmax_max': softmax,'softmax_Entropy':se}) 323 | accuracy = correct / total if total > 0 else 0 324 | return _in_score, _out_score, accuracy 325 | 326 | 327 | def get_scores_pot_u(args, net, test_loader): 328 | save_path = "pot_u/" + args.method + "_" + args.corruption + "_level" + str(args.level) + ".csv" 329 | _in_score, _out_score = [], [] 330 | correct = 0 331 | total = 0 332 | 333 | 334 | sample_info = [] 335 | 336 | loader = test_loader 
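    # Flag scheme for the per-sample CSV written below: 0 = in-distribution and
    # correctly classified, 1 = in-distribution but misclassified, 2 = out-of-distribution.
    # The Dirichlet uncertainty column 'u' is only recorded for methods whose name
    # contains 'dirichlet'.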
337 | sample_id = 0 338 | 339 | for batch in tqdm(loader, total=len(loader), disable=True): 340 | test_set = batch 341 | sample, ood_label = test_set[0], test_set[1] 342 | 343 | test_data, target, ood_label = sample[0].cuda(), sample[1].cuda(), ood_label.cuda() 344 | 345 | output= net(test_data) 346 | uncertainty = dirichlet_uncertainty(output) 347 | 348 | score = ood_metric(output, scoring_function=args.scoring_function) 349 | predictions = output.max(1)[1] 350 | softmax_scores = torch.softmax(output, dim=1).max(1)[0] 351 | 352 | mask_in = (ood_label == 0).cpu() 353 | mask_out = (ood_label == 1).cpu() 354 | 355 | mask_right = ((predictions == target).cpu() & mask_in) 356 | mask_wrong = ((predictions != target).cpu() & mask_in) 357 | 358 | 359 | 360 | for i in range(len(test_data)): 361 | if mask_in[i] and mask_right[i]: 362 | flag = 0 363 | elif mask_in[i] and mask_wrong[i]: 364 | flag = 1 365 | elif mask_out[i]: 366 | flag = 2 367 | else: 368 | continue 369 | 370 | entry = { 371 | 'id': sample_id, 372 | 'softmax_max': softmax_scores[i].item(), 373 | 'flag': flag 374 | } 375 | 376 | 377 | if 'dirichlet' in args.method.lower(): 378 | entry['u'] = uncertainty[i].item() 379 | 380 | sample_info.append(entry) 381 | sample_id += 1 382 | 383 | if mask_in.any(): 384 | _in_score.extend(score[mask_right].tolist()) 385 | correct += (predictions[mask_in]==target[mask_in]).sum().item() 386 | total += mask_in.sum().item() 387 | 388 | if mask_out.any(): 389 | _out_score.extend(score[mask_out].tolist()) 390 | 391 | if mask_wrong.any(): 392 | _out_score.extend(score[mask_wrong].tolist()) 393 | 394 | 395 | with open(save_path, 'w', newline='') as csvfile: 396 | fieldnames = ['id', 'softmax_max', 'u', 'flag'] if 'dirichlet' in args.method.lower() else ['id', 'softmax_max', 'flag'] 397 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 398 | writer.writeheader() 399 | for entry in sample_info: 400 | writer.writerow(entry) 401 | accuracy = correct / total if total > 0 else 0 402 | return _in_score, _out_score, accuracy 403 | 404 | 405 | 406 | def normalize_tensor(tensor, mean, std): 407 | mean = torch.tensor(mean).view(1, 3, 1, 1).cuda() 408 | std = torch.tensor(std).view(1, 3, 1, 1).cuda() 409 | normalized_tensor = (tensor - mean) / std 410 | return normalized_tensor 411 | 412 | def merge_tensors(tensor1, tensor2, p): 413 | assert tensor1.shape == tensor2.shape, "Tensors must have the same shape" 414 | mask = torch.rand(tensor1.shape[0], 1, 1, 1).to(tensor1.device) > p 415 | mask = mask.expand_as(tensor1) 416 | merged_tensor = torch.where(mask, tensor1, tensor2) 417 | return merged_tensor 418 | 419 | def get_scores_adversial(net, test_loader, fmodel, attack, args): 420 | _in_score, _out_score = [], [] 421 | correct = 0 422 | total = 0 423 | 424 | loader = test_loader 425 | 426 | count = 0 427 | 428 | 429 | for batch in tqdm(loader, desc="Processing"): 430 | count += 1 431 | 432 | test_set = batch 433 | sample, ood_label = test_set[0], test_set[1] 434 | 435 | test_data, target, ood_label = sample[0].cuda(), sample[1].cuda(), ood_label.cuda() 436 | 437 | fpredict = fmodel(test_data).max(1)[1] 438 | 439 | 440 | _, advs, _ = attack(fmodel, test_data, fpredict, epsilons=[args.epsilon]) 441 | advs = advs[0] 442 | 443 | mix_data = merge_tensors(test_data, advs, p=args.ad_rate) 444 | 445 | 446 | mix_data = normalize_tensor(mix_data, mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]) 447 | 448 | output = net(mix_data) 449 | 450 | score = ood_metric(output, scoring_function=args.scoring_function) 451 | 
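        # ood_metric returns scores oriented so that larger values mean "more OOD-like"
        # (e.g. 'msp' is the negated maximum softmax probability). As in the other
        # get_scores_* helpers, misclassified in-distribution samples are pooled into
        # _out_score below, so the OOD metrics also penalize confident wrong predictions.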
        predictions = output.max(1)[1]
        mask_in = (ood_label == 0).cpu()
        mask_out = (ood_label == 1).cpu()
        # Computed unconditionally so mask_wrong exists even for a batch
        # without ID samples (it is used outside the `if mask_in.any()` block).
        mask_right = ((predictions == target).cpu() & mask_in)
        mask_wrong = ((predictions != target).cpu() & mask_in)

        if mask_in.any():
            _in_score.extend(score[mask_right].tolist())
            correct += (predictions[mask_in] == target[mask_in]).sum().item()
            total += mask_in.sum().item()

        if mask_out.any():
            _out_score.extend(score[mask_out].tolist())

        if mask_wrong.any():
            _out_score.extend(score[mask_wrong].tolist())

    accuracy = correct / total if total > 0 else 0
    return _in_score, _out_score, accuracy
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
"""
Copyright to COME Authors, ICLR 2025
built upon the Tent codebase.
"""
import os
import time
import argparse
import random
import numpy as np
from copy import deepcopy
import math
from dataset.selectedRotateImageFolder import prepare_test_data, obtain_train_loader
from utils.metrics import eval_ood, eval_ood_95, get_scores
from utils.utils import get_logger, set_random_seed, generate_mix_data, merge_datasets, generate_balanced_data
from torchvision import datasets as dset
import torch
import torch.nn.functional as F
import tent, sar, tent_come, sar_come, eata, eata_come
from sam import SAM
import timm
import models.Res as Resnet
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, ConcatDataset, Subset
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
torch.set_num_threads(8)


def get_args():
    parser = argparse.ArgumentParser(description='exps')
    # paths
    parser.add_argument('--data', default='/path/to/dataset/Imagenet1K', help='path to dataset')
    parser.add_argument('--data_corruption', default='/path/to/dataset/ImageNet_C', help='path to corruption dataset')
    parser.add_argument('--ood_root', default='/path/to/dataset/', help='path to open-world dataset')
    parser.add_argument('--output', default='/path/to/output/result', help='the output directory of this experiment')
    # dataloader
    parser.add_argument('--workers', default=8, type=int, help='number of data loading workers')
    parser.add_argument('--test_batch_size', default=64, type=int, help='batch size for testing')
    # caution: argparse's type=bool converts any non-empty string to True, so
    # '--if_shuffle False' (and '--debug False' below) still evaluate truthy;
    # these flags are effectively toggled by editing the defaults.
    parser.add_argument('--if_shuffle', default=True, type=bool, help='if shuffle the test set.')
    # corruption settings
    parser.add_argument('--level', default=5, type=int, help='corruption level of test(val) set.')
    parser.add_argument('--corruption', default='gaussian_noise', type=str, help='corruption type of test(val) set.')
    # experiment settings
    parser.add_argument('--seed', default=2021, type=int, help='seed for initializing training.')
    parser.add_argument('--gpu', default=0, type=int, help='GPU id to use.')
    parser.add_argument('--debug', default=False, type=bool, help='debug or not.')
    parser.add_argument('--method', default='Tent', type=str, help='no_adapt, Tent, EATA, SAR, Tent_COME, EATA_COME, SAR_COME')
    parser.add_argument('--model', default='resnet50_bn_torch', type=str, help='resnet50_bn_torch or vitbase_timm')
    parser.add_argument('--model_path', default='path/to/resnet50_bn_torch', type=str, help='path to resnet50_bn_torch or vitbase_timm')
    parser.add_argument('--exp_type', default='normal', type=str)
    parser.add_argument('--scoring_function', default='msp', type=str)
    parser.add_argument('--ood_rate', type=float, default=0.0)
    parser.add_argument('--steps', type=int, default=1)
    # SAR parameters
    parser.add_argument('--sar_margin_e0', default=math.log(1000) * 0.40, type=float, help='the threshold for reliable minimization in SAR')
    # EATA settings
    parser.add_argument('--fisher_size', default=2000, type=int, help='number of samples to compute fisher information matrix.')
    parser.add_argument('--fisher_beta', type=float, default=2000., help='the trade-off between entropy and regularization loss')
    parser.add_argument('--e_margin', type=float, default=math.log(1000)*0.40, help='entropy margin E_0 for filtering reliable samples')
    parser.add_argument('--d_margin', type=float, default=0.05, help='epsilon for filtering redundant samples')
    return parser.parse_args()


def get_model(args):
    bs = args.test_batch_size
    # Learning rates follow a linear-scaling rule relative to batch size 64.
    if args.model == "vitbase_timm":
        net = timm.create_model('vit_base_patch16_224', pretrained=True)
        args.lr = (0.001 / 64) * bs
    elif args.model == "resnet50_bn_torch":
        net = Resnet.__dict__['resnet50'](pretrained=True)
        args.lr = (0.00025 / 64) * bs * 2 if bs < 32 else 0.00025
    else:
        raise NotImplementedError(f"unknown model: {args.model}")
    net = net.cuda()
    net.eval()
    net.requires_grad_(False)

    return net


def get_adapt_model(net, args):
    if args.method == "no_adapt":
        adapt_model = net.eval()
    elif args.method == "Tent":
        net = tent.configure_model(net)
        params, param_names = tent.collect_params(net)
        optimizer = torch.optim.SGD(params, args.lr, momentum=0.9)
        adapt_model = tent.Tent(net, optimizer, steps=args.steps)
    elif args.method == "Tent_COME":
        net = tent_come.configure_model(net)
        params, param_names = tent_come.collect_params(net)
        optimizer = torch.optim.SGD(params, args.lr, momentum=0.9)
        adapt_model = tent_come.Tent_COME(net, optimizer, steps=args.steps, args=args)

    elif args.method == 'SAR':
        net = sar.configure_model(net)
        params, param_names = sar.collect_params(net)
        base_optimizer = torch.optim.SGD
        optimizer = SAM(params, base_optimizer, lr=args.lr, momentum=0.9)
        adapt_model = sar.SAR(net, optimizer, margin_e0=args.sar_margin_e0)

    elif args.method == 'SAR_COME':
        net = sar_come.configure_model(net)
        params, param_names = sar_come.collect_params(net)
        base_optimizer = torch.optim.SGD
        optimizer = SAM(params, base_optimizer, lr=args.lr, momentum=0.9)
        adapt_model = sar_come.SAR_COME(net, optimizer, margin_e0=args.sar_margin_e0)

    elif args.method == "EATA":
        # Compute the Fisher information matrix on clean (uncorrupted) samples.
        temp = args.corruption
        args.corruption = 'original'
        fisher_dataset, fisher_loader = prepare_test_data(args)
        fisher_dataset.set_dataset_size(args.fisher_size)
        fisher_dataset.switch_mode(True, False)
        args.corruption = temp

        net = eata.configure_model(net)
        params, param_names = eata.collect_params(net)
        ewc_optimizer = torch.optim.SGD(params, 0.001)
        fishers = {}
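        # Diagonal Fisher information estimate: for each clean sample the
        # model's own predictions serve as pseudo-labels, per-parameter
        # squared gradients are accumulated over the loader, and the running
        # sums are averaged on the final iteration. EATA uses these weights,
        # together with the parameter snapshot stored alongside them, as an
        # EWC-style anti-forgetting regularizer scaled by --fisher_beta.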
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        for iter_, (images, targets) in enumerate(fisher_loader, start=1):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                targets = targets.cuda(args.gpu, non_blocking=True)
            outputs = net(images)
            # Use the model's own predictions as pseudo-labels.
            _, targets = outputs.max(1)
            loss = train_loss_fn(outputs, targets)
            loss.backward()
            for name, param in net.named_parameters():
                if param.grad is not None:
                    if iter_ > 1:
                        fisher = param.grad.data.clone().detach() ** 2 + fishers[name][0]
                    else:
                        fisher = param.grad.data.clone().detach() ** 2
                    if iter_ == len(fisher_loader):
                        fisher = fisher / iter_
                    fishers.update({name: [fisher, param.data.clone().detach()]})
            ewc_optimizer.zero_grad()
        # `logger` is the module-level logger created in the __main__ block below.
        logger.info("compute fisher matrices finished")
        del ewc_optimizer

        optimizer = torch.optim.SGD(params, args.lr, momentum=0.9)
        adapt_model = eata.EATA(net, optimizer, fishers, args.fisher_beta, e_margin=args.e_margin, d_margin=args.d_margin)

    elif args.method == "EATA_COME":
        # Same Fisher computation as the EATA branch above.
        temp = args.corruption
        args.corruption = 'original'
        fisher_dataset, fisher_loader = prepare_test_data(args)
        fisher_dataset.set_dataset_size(args.fisher_size)
        fisher_dataset.switch_mode(True, False)
        args.corruption = temp
        net = eata_come.configure_model(net)
        params, param_names = eata_come.collect_params(net)
        ewc_optimizer = torch.optim.SGD(params, 0.001)
        fishers = {}
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        for iter_, (images, targets) in enumerate(fisher_loader, start=1):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                targets = targets.cuda(args.gpu, non_blocking=True)
            outputs = net(images)
            _, targets = outputs.max(1)
            loss = train_loss_fn(outputs, targets)
            loss.backward()
            for name, param in net.named_parameters():
                if param.grad is not None:
                    if iter_ > 1:
                        fisher = param.grad.data.clone().detach() ** 2 + fishers[name][0]
                    else:
                        fisher = param.grad.data.clone().detach() ** 2
                    if iter_ == len(fisher_loader):
                        fisher = fisher / iter_
                    fishers.update({name: [fisher, param.data.clone().detach()]})
            ewc_optimizer.zero_grad()
        logger.info("compute fisher matrices finished")
        del ewc_optimizer

        optimizer = torch.optim.SGD(params, args.lr, momentum=0.9)
        adapt_model = eata_come.EATA_COME(net, optimizer, fishers, args.fisher_beta, e_margin=args.e_margin, d_margin=args.d_margin)
    else:
        raise NotImplementedError(f"unknown method: {args.method}")

    return adapt_model


def create_ood_dataset(ood_root):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # Note there is no Resize before CenterCrop: images smaller than 224px on
    # a side are zero-padded by torchvision rather than scaled up.
    transform_pipeline = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])
    datasets_dict = {
        "Ninco": dset.ImageFolder(root=os.path.join(ood_root, "NINCO/NINCO_OOD_classes"), transform=transform_pipeline),
        "iNaturalist": dset.ImageFolder(root=os.path.join(ood_root, "iNaturalist/train_val_images"), transform=transform_pipeline),
        "SSB_Hard": dset.ImageFolder(root=os.path.join(ood_root, "ssb_hard_3"), transform=transform_pipeline),
"Texture": dset.ImageFolder(root=os.path.join(ood_root, "dtd/images"), transform=transform_pipeline), 207 | "Openimage_O": dset.ImageFolder(root=os.path.join(ood_root, "openimage_o_3"), transform=transform_pipeline) 208 | } 209 | 210 | OOD_dataset = merge_datasets(list(datasets_dict.values())) 211 | 212 | return OOD_dataset, datasets_dict 213 | 214 | 215 | if __name__ == '__main__': 216 | 217 | args = get_args() 218 | 219 | # set random seeds 220 | if args.seed is not None: 221 | random.seed(args.seed) 222 | np.random.seed(args.seed) 223 | torch.manual_seed(args.seed) 224 | 225 | if not os.path.exists(args.output): 226 | os.makedirs(args.output, exist_ok=True) 227 | 228 | args.logger_name=time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())+"-{}-{}-level{}-seed{}-ood_rate{}-{}.txt".format(args.method, args.model, args.level, args.seed,args.ood_rate,args.exp_type) 229 | logger = get_logger(name="project", output_directory=args.output, log_name=args.logger_name, debug=False) 230 | 231 | common_corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', 'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'] 232 | 233 | if args.exp_type == 'normal': 234 | cpt_name,accs, fprs, aurocs = [], [], [],[] 235 | for corrupt in common_corruptions: 236 | net = get_model(args) 237 | adapt_model = get_adapt_model(net, args) 238 | args.corruption = corrupt 239 | 240 | ID_dataset, _ = prepare_test_data(args) 241 | ID_dataset.switch_mode(True, False) 242 | mixed_data = generate_mix_data(ID_dataset,[],0) 243 | mixed_loader = DataLoader(mixed_data, batch_size = args.test_batch_size, shuffle=True, 244 | num_workers = args.workers, pin_memory = True) 245 | in_score, out_score, acc = get_scores(args,adapt_model, mixed_loader) 246 | fpr, auroc, aupr = eval_ood(in_score, out_score) 247 | cpt_name.append(corrupt) 248 | accs.append(acc) 249 | fprs.append(fpr) 250 | aurocs.append(auroc) 251 | logger.info(f"Result under {corrupt}. 
        logger.info("\n")
        logger.info(args_str)
        logger.info(f"Completed corruptions: {cpt_name}")
        logger.info(f"Accuracies: {accs}")
        logger.info(f"FPRs: {fprs}")
        logger.info(f"AUROCs: {aurocs}")
    elif args.exp_type == 'open-world':
        logger.info(args)
        args.corruption = 'gaussian_noise'
        ID_dataset, _ = prepare_test_data(args)

        ID_dataset.switch_mode(True, False)
        _, individual_datasets = create_ood_dataset(args.ood_root)
        OOD_datasets = ['None', 'Ninco', 'iNaturalist', 'SSB_Hard', 'Texture', 'Openimage_O']
        ood_names, accs, fprs, aurocs, thresholds95s = [], [], [], [], []
        for ood_name in OOD_datasets:
            net = get_model(args)
            adapt_model = get_adapt_model(net, args)
            # 'None' is the ID-only control run; otherwise OOD samples are
            # mixed in at the requested ood_rate.
            if ood_name == 'None':
                mixed_data = generate_balanced_data(ID_dataset, [], 0)
            else:
                mixed_data = generate_balanced_data(ID_dataset, individual_datasets[ood_name], args.ood_rate)

            mixed_loader = DataLoader(mixed_data, batch_size=args.test_batch_size, shuffle=True,
                                      num_workers=args.workers, pin_memory=True)
            in_score, out_score, acc = get_scores(args, adapt_model, mixed_loader)

            fpr, auroc, aupr, thresholds95 = eval_ood_95(in_score, out_score)
            ood_names.append(ood_name)
            accs.append(acc)
            fprs.append(fpr)
            aurocs.append(auroc)
            thresholds95s.append(thresholds95)
            logger.info(f"Result under {ood_name}. Accuracy: {acc:.5f}, fpr: {fpr:.5f}, AUROC: {auroc:.5f}, thresholds95: {thresholds95}")
        logger.info("\n")
        logger.info(args_str)
        logger.info(f"Completed: {ood_names}")
        logger.info(f"Accuracies: {accs}")
        logger.info(f"FPRs: {fprs}")
        logger.info(f"thresholds95: {thresholds95s}")
        logger.info(f"AUROCs: {aurocs}")
    elif args.exp_type == 'imbalanced':
        cpt_name, accs, fprs, aurocs = [], [], [], []
        for corrupt in common_corruptions:
            net = get_model(args)
            adapt_model = get_adapt_model(net, args)
            args.corruption = corrupt
            ID_dataset, _ = prepare_test_data(args)

            ID_dataset.switch_mode(True, False)
            mixed_data = generate_mix_data(ID_dataset, [], 0)
            mixed_loader = DataLoader(mixed_data, batch_size=args.test_batch_size, shuffle=False,
                                      num_workers=args.workers, pin_memory=True)
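            # Unlike the other settings, this stream is not shuffled, so
            # samples arrive in the dataset's stored, class-grouped order;
            # the label distribution each batch sees is therefore skewed over
            # time, which appears to be the point of this protocol.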
            in_score, out_score, acc = get_scores(args, adapt_model, mixed_loader)

            fpr, auroc, aupr = eval_ood(in_score, out_score)
            cpt_name.append(corrupt)
            accs.append(acc)
            fprs.append(fpr)
            aurocs.append(auroc)
            logger.info(f"Result under {corrupt}. Accuracy: {acc:.5f}, fpr: {fpr:.5f}, AUROC: {auroc:.5f}")
        logger.info("\n")
        logger.info(args_str)
        logger.info(f"Completed: {cpt_name}")
        logger.info(f"Accuracies: {accs}")
        logger.info(f"FPRs: {fprs}")
        logger.info(f"AUROCs: {aurocs}")
    elif args.exp_type == 'mix-shift':
        # A single model adapts to a shuffled mixture of all 15 corruptions.
        net = get_model(args)
        adapt_model = get_adapt_model(net, args)
        ID_datasets = []
        for corrupt in common_corruptions:
            args.corruption = corrupt
            ID_dataset, _ = prepare_test_data(args)
            ID_dataset.switch_mode(True, False)
            ID_datasets.append(ID_dataset)
        ID_dataset = ConcatDataset(ID_datasets)
        mixed_data = generate_mix_data(ID_dataset, [], 0)
        mixed_loader = DataLoader(mixed_data, batch_size=args.test_batch_size, shuffle=True,
                                  num_workers=args.workers, pin_memory=True)
        in_score, out_score, acc = get_scores(args, adapt_model, mixed_loader)
        fpr, auroc, aupr = eval_ood(in_score, out_score)

        logger.info("\n")
        logger.info(args_str)
        logger.info("Completed: mix-shift")
        logger.info(f"Accuracy: {acc}")
        logger.info(f"FPR: {fpr}")
        logger.info(f"AUROC: {auroc}")

    elif args.exp_type == 'life-long':
        # One model adapts continually across all corruptions without resets;
        # OOD metrics are computed on the scores accumulated so far.
        cpt_name, accs, fprs, aurocs = [], [], [], []
        in_scores, out_scores = [], []
        net = get_model(args)
        adapt_model = get_adapt_model(net, args)
        for corrupt in common_corruptions:
            args.corruption = corrupt
            ID_dataset, _ = prepare_test_data(args)
            ID_dataset.switch_mode(True, False)
            mixed_data = generate_mix_data(ID_dataset, [], 0)
            mixed_loader = DataLoader(mixed_data, batch_size=args.test_batch_size, shuffle=True,
                                      num_workers=args.workers, pin_memory=True)
            in_score, out_score, acc = get_scores(args, adapt_model, mixed_loader)
            in_scores.extend(in_score)
            out_scores.extend(out_score)
            fpr, auroc, aupr = eval_ood(in_scores, out_scores)
            cpt_name.append(corrupt)
            accs.append(acc)
            fprs.append(fpr)
            aurocs.append(auroc)
            logger.info(f"Result under {corrupt}. Accuracy: {acc:.5f}, fpr: {fpr:.5f}, AUROC: {auroc:.5f}")
        logger.info("\n")
        logger.info(args_str)
        logger.info(f"Completed: {cpt_name}")
        logger.info(f"Accuracies: {accs}")
        logger.info(f"FPRs: {fprs}")
        logger.info(f"AUROCs: {aurocs}")
    else:
        raise NotImplementedError(f"unknown exp_type: {args.exp_type}")
--------------------------------------------------------------------------------
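The eval_ood and eval_ood_95 functions imported at the top of main.py live in utils/metrics.py, outside the portion shown above. From their call sites, eval_ood returns (fpr, auroc, aupr) and eval_ood_95 additionally returns the score threshold at 95% TPR. Below is a minimal sketch of the usual computation with scikit-learn (already in requirements.txt); the function name eval_ood_sketch and its internals are illustrative assumptions, not the repository's actual implementation, and it assumes higher scores mean more in-distribution, as with the 'msp' scoring function.

# Minimal sketch, NOT the repository's utils/metrics.py.
import numpy as np
from sklearn import metrics as skm


def eval_ood_sketch(in_score, out_score):
    """Return (FPR@95TPR, AUROC, AUPR) for ID-vs-OOD score lists."""
    # Label ID samples 1 and OOD samples 0; higher score = more in-distribution.
    labels = np.concatenate([np.ones(len(in_score)), np.zeros(len(out_score))])
    scores = np.concatenate([np.asarray(in_score), np.asarray(out_score)])
    auroc = skm.roc_auc_score(labels, scores)
    aupr = skm.average_precision_score(labels, scores)
    fpr, tpr, _ = skm.roc_curve(labels, scores)
    # False positive rate at the first operating point with >= 95% TPR.
    fpr95 = float(fpr[np.argmax(tpr >= 0.95)])
    return fpr95, auroc, aupr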