├── LICENSE
├── README.md
├── init_ref_model.py
├── load_data.py
├── main.py
├── model.py
├── projnorm.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Yaodong Yu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Projection Norm (ProjNorm)

This is the code for the [ICML 2022 paper](https://arxiv.org/abs/2202.05834):

### *Predicting Out-of-Distribution Error with the Projection Norm*

by Yaodong Yu*, Zitong Yang*, Alexander Wei, Yi Ma, and Jacob Steinhardt from UC Berkeley (*equal contribution).

## Prerequisites
* Python
* PyTorch (1.10.0)
* CUDA
* numpy


## How to compute ProjNorm to study model performance under distribution shift
We use [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) (the in-distribution dataset) and [CIFAR10-C](https://arxiv.org/abs/1903.12261) (the out-of-distribution datasets) to demonstrate how to compute ProjNorm.

### Step 0: Download OOD data
```bash
mkdir -p ./data/cifar
curl -O https://zenodo.org/record/2535967/files/CIFAR-10-C.tar
tar -xvf CIFAR-10-C.tar -C data/cifar/
```

### Step 1: Initialize the base model and reference model
```bash
python init_ref_model.py --arch resnet18 --train_epoch 20 --pseudo_iters 500 --lr 0.001 --batch_size 128 --seed 1
```
#### Arguments:
* ```arch```: network architecture
* ```train_epoch```: number of epochs for training the base model
* ```pseudo_iters```: number of iterations for training the reference model
* ```lr```: learning rate
* ```batch_size```: mini-batch size
* ```seed```: random seed

#### Output:

The base model (```base_model.pt```) and reference model (```ref_model.pt```) are saved to ```./checkpoints/<arch>/```.
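
To reuse these checkpoints elsewhere, reload them with `torch.load`, the same way `main.py` does in **Step 2**; a minimal sketch, assuming the default save path and ```--arch resnet18```:

```python
import torch

arch = 'resnet18'  # must match the --arch used in Step 1
base_model = torch.load('./checkpoints/{}/base_model.pt'.format(arch)).eval()
ref_model = torch.load('./checkpoints/{}/ref_model.pt'.format(arch)).eval()
```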

### Step 2: Compute ProjNorm for in-distribution and out-of-distribution data
```bash
python main.py --arch resnet18 --corruption snow --severity 5 --pseudo_iters 500 --lr 0.001 --batch_size 128 --seed 1
```
#### Arguments:
* ```arch```: network architecture (use the same architecture as in **Step 1**)
* ```corruption```: corruption type
* ```severity```: corruption severity (1-5)
* ```pseudo_iters```: number of iterations for training the pseudo models
* ```lr```: learning rate
* ```batch_size```: mini-batch size
* ```seed```: random seed (use the same random seed as in **Step 1**)

#### Output:

The script prints two pairs:

(```in-distribution test error```, ```in-distribution ProjNorm value```)

(```out-of-distribution test error```, ```out-of-distribution ProjNorm value```)
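
ProjNorm is designed to *predict* OOD test error, so a typical experiment sweeps corruption types and severities and then checks how well the two quantities correlate. A minimal post-processing sketch, where `results` is a hypothetical list of `(test error, ProjNorm value)` pairs collected from runs of `main.py`:

```python
import numpy as np

# hypothetical (test error, ProjNorm) pairs from several main.py runs
results = [(5.2, 0.71), (18.9, 1.34), (31.5, 1.92)]

errors, projnorms = map(np.array, zip(*results))
print('correlation:', np.corrcoef(projnorms, errors)[0, 1])
```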

## Reference
For more experimental and technical details, please check our [paper](https://arxiv.org/abs/2202.05834). If you find this useful for your work, please consider citing:
```
@InProceedings{pmlr-v162-yu22i,
  title = {Predicting Out-of-Distribution Error with the Projection Norm},
  author = {Yu, Yaodong and Yang, Zitong and Wei, Alexander and Ma, Yi and Steinhardt, Jacob},
  booktitle = {Proceedings of the 39th International Conference on Machine Learning},
  pages = {25721--25746},
  year = {2022},
  editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
  volume = {162},
  series = {Proceedings of Machine Learning Research},
  month = {17--23 Jul},
  publisher = {PMLR},
  pdf = {https://proceedings.mlr.press/v162/yu22i/yu22i.pdf}
}
```
--------------------------------------------------------------------------------
/init_ref_model.py:
--------------------------------------------------------------------------------
import argparse
import os

import torch
import torch.nn as nn
import torch.optim as optim

from projnorm import *
from load_data import *
from model import ResNet18

"""# Configuration"""
parser = argparse.ArgumentParser(description='ProjNorm.')
parser.add_argument('--arch', default='resnet18', type=str)
parser.add_argument('--cifar_data_path', default='./data', type=str)
parser.add_argument('--cifar_corruption_path', default='./data/cifar/CIFAR-10-C', type=str)
parser.add_argument('--pseudo_iters', default=50, type=int)
parser.add_argument('--num_classes', default=10, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--train_epoch', default=2, type=int)
parser.add_argument('--seed', default=1, type=int)
args = vars(parser.parse_args())


def train(net, trainloader):
    """Train the base model on clean CIFAR10 with SGD and a cosine schedule."""
    net.train()
    optimizer = optim.SGD(net.parameters(), lr=args['lr'], momentum=0.9, weight_decay=0.0)
    # T_max counts iterations, so the scheduler is stepped once per batch below
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=args['train_epoch'] * len(trainloader))
    criterion = nn.CrossEntropyLoss()

    for epoch in range(args['train_epoch']):
        train_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if batch_idx % 20 == 0:
                for param_group in optimizer.param_groups:
                    current_lr = param_group['lr']
                print('Epoch: ', epoch, '(', batch_idx, '/', len(trainloader), ')',
                      'Loss: %.3f | Acc: %.3f%% (%d/%d)| Lr: %.5f' % (
                          train_loss / (batch_idx + 1), 100. * correct / total, correct, total, current_lr))
            scheduler.step()
    net.eval()

    return net


if __name__ == "__main__":
    # save path
    save_dir_path = './checkpoints/{}'.format(args['arch'])
    if not os.path.exists(save_dir_path):
        os.makedirs(save_dir_path)

    # setup train/val_iid loaders
    trainset = load_cifar10_image(corruption_type='clean',
                                  clean_cifar_path=args['cifar_data_path'],
                                  corruption_cifar_path=args['cifar_corruption_path'],
                                  corruption_severity=0,
                                  datatype='train')
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args['batch_size'],
                                              shuffle=True)

    # init and train base model
    if args['arch'] == 'resnet18':
        base_model = ResNet18(num_classes=args['num_classes'], seed=args['seed']).cuda()
    else:
        raise ValueError('incorrect model name')

    base_model = train(base_model, trainloader)
    base_model.eval()
    torch.save(base_model, '{}/base_model.pt'.format(save_dir_path))
    print('base model saved to', '{}/base_model.pt'.format(save_dir_path))

    # init ProjNorm
    PN = ProjNorm(base_model=base_model)

    # train iid reference model
    if args['arch'] == 'resnet18':
        ref_model = ResNet18(num_classes=args['num_classes'], seed=args['seed']).cuda()
    else:
        raise ValueError('incorrect model name')

    PN.update_ref_model(trainloader,
                        ref_model,
                        lr=args['lr'],
                        pseudo_iters=args['pseudo_iters'])
    torch.save(PN.reference_model.eval(), '{}/ref_model.pt'.format(save_dir_path))
    print('reference model saved to', '{}/ref_model.pt'.format(save_dir_path))

    print('========finished========')
--------------------------------------------------------------------------------
/load_data.py:
--------------------------------------------------------------------------------
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
import numpy as np


def load_cifar10_image(corruption_type,
                       clean_cifar_path,
                       corruption_cifar_path,
                       corruption_severity=0,
                       datatype='test',
                       num_samples=50000,
                       seed=1):
    """
    Returns:
        a pytorch dataset object: clean CIFAR10 when corruption_severity=0,
        otherwise CIFAR10-C test data at the given corruption severity (1-5)
    """
    assert datatype == 'test' or datatype == 'train'
    training_flag = True if datatype == 'train' else False

    # ImageNet normalization statistics, matching the pretrained backbone
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    dataset = datasets.CIFAR10(clean_cifar_path,
                               train=training_flag,
                               transform=transform,
                               download=True)

    if corruption_severity > 0:
        assert not training_flag
        path_images = os.path.join(corruption_cifar_path, corruption_type + '.npy')
        path_labels = os.path.join(corruption_cifar_path, 'labels.npy')

        # each CIFAR-10-C file stacks 5 severities of 10000 test images each
        dataset.data = np.load(path_images)[(corruption_severity - 1) * 10000:corruption_severity * 10000]
        dataset.targets = list(np.load(path_labels)[(corruption_severity - 1) * 10000:corruption_severity * 10000])
        dataset.targets = [int(item) for item in dataset.targets]

    # randomly permute data
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    number_samples = dataset.data.shape[0]
    index_permute = torch.randperm(number_samples)
    dataset.data = dataset.data[index_permute]
    dataset.targets = np.array([int(item) for item in dataset.targets])
    dataset.targets = dataset.targets[index_permute].tolist()

    # randomly subsample data
    if datatype == 'train' and num_samples < 50000:
        indices = torch.randperm(50000)[:num_samples]
        dataset = torch.utils.data.Subset(dataset, indices)
        print('number of training data: ', len(dataset))
    if datatype == 'test' and num_samples < 10000:
        indices = torch.randperm(10000)[:num_samples]
        dataset = torch.utils.data.Subset(dataset, indices)
        print('number of test data: ', len(dataset))

    return dataset
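

# A minimal usage sketch (an illustrative addition, not part of the original
# pipeline): load severity-3 "fog" test data, assuming the default data paths
# from the README.
if __name__ == "__main__":
    testset = load_cifar10_image(corruption_type='fog',
                                 clean_cifar_path='./data',
                                 corruption_cifar_path='./data/cifar/CIFAR-10-C',
                                 corruption_severity=3,
                                 datatype='test')
    print('loaded {} corrupted test images'.format(len(testset)))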
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse

import torch

from projnorm import *
from load_data import *
from model import ResNet18
from utils import evaluation

"""# Configuration"""
parser = argparse.ArgumentParser(description='ProjNorm.')
parser.add_argument('--arch', default='resnet18', type=str)
parser.add_argument('--cifar_data_path', default='./data', type=str)
parser.add_argument('--cifar_corruption_path', default='./data/cifar/CIFAR-10-C', type=str)
parser.add_argument('--corruption', default='snow', type=str)
parser.add_argument('--severity', default=5, type=int)
parser.add_argument('--pseudo_iters', default=50, type=int)
parser.add_argument('--num_classes', default=10, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--use_base_model', action='store_true',
                    default=False, help='apply base_model for computing ProjNorm')
args = vars(parser.parse_args())

if __name__ == "__main__":
    # setup valset_iid/val_ood loaders with independent shuffle seeds
    random_seeds = torch.randint(0, 10000, (2,))
    valset_iid = load_cifar10_image(corruption_type='clean',
                                    clean_cifar_path=args['cifar_data_path'],
                                    corruption_cifar_path=args['cifar_corruption_path'],
                                    corruption_severity=0,
                                    datatype='test',
                                    seed=random_seeds[0])
    val_iid_loader = torch.utils.data.DataLoader(valset_iid,
                                                 batch_size=args['batch_size'],
                                                 shuffle=True)

    valset_ood = load_cifar10_image(corruption_type=args['corruption'],
                                    clean_cifar_path=args['cifar_data_path'],
                                    corruption_cifar_path=args['cifar_corruption_path'],
                                    corruption_severity=args['severity'],
                                    datatype='test',
                                    seed=random_seeds[1])
    val_ood_loader = torch.utils.data.DataLoader(valset_ood,
                                                 batch_size=args['batch_size'],
                                                 shuffle=True)

    # init ProjNorm with the models trained in init_ref_model.py
    save_dir_path = './checkpoints/{}'.format(args['arch'])

    base_model = torch.load('{}/base_model.pt'.format(save_dir_path))
    base_model.eval()
    PN = ProjNorm(base_model=base_model)

    if not args['use_base_model']:
        ref_model = torch.load('{}/ref_model.pt'.format(save_dir_path))
        ref_model.eval()
        PN.reference_model = ref_model

    ################ train iid pseudo model ################
    if args['arch'] == 'resnet18':
        pseudo_model = ResNet18(num_classes=args['num_classes'], seed=args['seed']).cuda()
    else:
        raise ValueError('incorrect model name')

    PN.update_pseudo_model(val_iid_loader,
                           pseudo_model,
                           lr=args['lr'],
                           pseudo_iters=args['pseudo_iters'])

    # compute IID ProjNorm
    projnorm_value_iid = PN.compute_projnorm(PN.reference_model, PN.pseudo_model)

    ################ train ood pseudo model ################
    if args['arch'] == 'resnet18':
        pseudo_model = ResNet18(num_classes=args['num_classes'], seed=args['seed']).cuda()
    else:
        raise ValueError('incorrect model name')

    PN.update_pseudo_model(val_ood_loader,
                           pseudo_model,
                           lr=args['lr'],
                           pseudo_iters=args['pseudo_iters'])

    # compute OOD ProjNorm
    projnorm_value = PN.compute_projnorm(PN.reference_model, PN.pseudo_model)

    print('=============in-distribution=============')
    print('(in-distribution) ProjNorm value: ', projnorm_value_iid)
    test_loss_iid, test_error_iid = evaluation(net=base_model, testloader=val_iid_loader)
    print('(in-distribution) test error: ', test_error_iid)

    print('===========out-of-distribution===========')
    print('(out-of-distribution) ProjNorm value: ', projnorm_value)
    test_loss_ood, test_error_ood = evaluation(net=base_model, testloader=val_ood_loader)
    print('(out-of-distribution) test error: ', test_error_ood)
    print('========finished========')
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision.models as models


def ResNet18(num_classes=10, seed=123):
    # seed the init of the new classification head for reproducibility
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    resnet18 = models.resnet18(pretrained=True)
    resnet18.fc = nn.Linear(512, num_classes)
    return resnet18


def ResNet50(num_classes=10, seed=123):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    resnet50 = models.resnet50(pretrained=True).cuda()
    resnet50.fc = nn.Linear(2048, num_classes).cuda()
    return resnet50
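

# A minimal shape check (an illustrative addition, not part of the original
# pipeline); inputs are resized to 224x224 in load_data.py to match the
# ImageNet-pretrained backbone.
if __name__ == "__main__":
    net = ResNet18(num_classes=10)
    x = torch.randn(2, 3, 224, 224)
    print(net(x).shape)  # expected: torch.Size([2, 10])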
--------------------------------------------------------------------------------
/projnorm.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import torch.optim as optim
import torch.nn as nn
import copy


def _weight_diff_norm_init(net, net_baseline):
    """
    Returns:
        the l2 norm of the parameter difference between the two networks
    """
    diff = 0
    for param1, param2 in zip(net.parameters(), net_baseline.parameters()):
        diff += (torch.norm(param1.flatten() - param2.flatten()) ** 2).cpu().detach().numpy()
    return np.sqrt(diff)


class ProjNorm(torch.nn.Module):
    """
    Projection Norm (ProjNorm)
    """
    def __init__(self, base_model):
        super(ProjNorm, self).__init__()
        self.base_model = copy.deepcopy(base_model)
        self.reference_model = copy.deepcopy(base_model)
        self.pseudo_model = None
        self.max_epochs = 1000

    def update_pseudo_model(self, data_loader, pseudo_model, lr, pseudo_iters):
        """Fine-tune pseudo_model on pseudo-labels produced by the base model."""
        optimizer = optim.SGD(pseudo_model.parameters(),
                              lr=lr,
                              momentum=0.9,
                              weight_decay=0.0)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=pseudo_iters)
        criterion = nn.CrossEntropyLoss().cuda()
        trainloader_iterator = iter(data_loader)

        for iteration in range(1, pseudo_iters + 1):
            pseudo_model.train()

            try:
                inputs, targets = next(trainloader_iterator)
            except StopIteration:
                # restart the loader once it is exhausted
                trainloader_iterator = iter(data_loader)
                inputs, targets = next(trainloader_iterator)
            if iteration == 1:
                print('targets[:10]:', targets[:10])

            inputs = inputs.cuda()
            # pseudo-label by base_model; the true targets are never used here
            _, pseudo_labels = self.base_model(inputs).max(1)
            pseudo_labels = pseudo_labels.detach()

            optimizer.zero_grad()
            outputs = pseudo_model(inputs)
            loss = criterion(outputs, pseudo_labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            train_loss = loss.item()
            _, predicted = outputs.max(1)
            total = pseudo_labels.size(0)
            correct = predicted.eq(pseudo_labels).sum().item()
            if iteration % 20 == 0:
                current_lr = 0.0
                for param_group in optimizer.param_groups:
                    current_lr = param_group['lr']
                print('iteration {}: train loss: {:.6f}, train acc: {:.6f}, current lr: {:.6f}'.format(
                    iteration, train_loss, correct / total, current_lr))
        pseudo_model.eval()
        self.pseudo_model = copy.deepcopy(pseudo_model)
        print('========Pseudo-training finished========')

    def update_ref_model(self, data_loader, ref_model, lr, pseudo_iters):
        """Fine-tune ref_model on the true labels of the in-distribution data."""
        optimizer = optim.SGD(ref_model.parameters(),
                              lr=lr,
                              momentum=0.9,
                              weight_decay=0.0)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=pseudo_iters)
        criterion = nn.CrossEntropyLoss().cuda()
        trainloader_iterator = iter(data_loader)

        for iteration in range(1, pseudo_iters + 1):
            ref_model.train()
            try:
                inputs, targets = next(trainloader_iterator)
            except StopIteration:
                trainloader_iterator = iter(data_loader)
                inputs, targets = next(trainloader_iterator)
            if iteration == 1:
                print('targets[:10]:', targets[:10])
            inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            outputs = ref_model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss = loss.item()
            _, predicted = outputs.max(1)
            total = targets.size(0)
            correct = predicted.eq(targets).sum().item()

            if iteration % 20 == 0:
                current_lr = 0.0
                for param_group in optimizer.param_groups:
                    current_lr = param_group['lr']
                print('iteration {}: train loss: {:.6f}, train acc: {:.6f}, current lr: {:.6f}'.format(
                    iteration, train_loss, correct / total, current_lr))
        ref_model.eval()
        self.reference_model = copy.deepcopy(ref_model)
        print('========Pseudo-training (reference model) finished========')

    def compute_projnorm(self, model_ref, model_ood):
        return _weight_diff_norm_init(model_ref, model_ood)
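

# A minimal CPU illustration (an addition, not part of the original module):
# ProjNorm is the l2 distance between parameter vectors, so two identical
# networks give 0, and shifting all 8 weights of a 4x2 layer by 1.0 gives sqrt(8).
if __name__ == "__main__":
    torch.manual_seed(0)
    net_a = nn.Linear(4, 2)
    net_b = copy.deepcopy(net_a)
    print(_weight_diff_norm_init(net_a, net_b))  # 0.0
    with torch.no_grad():
        net_b.weight.add_(1.0)
    print(_weight_diff_norm_init(net_a, net_b))  # sqrt(8) ~ 2.828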
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


def evaluation(net, testloader):
    """Returns (average test loss, top-1 test error in %)."""
    net.eval()
    criterion = nn.CrossEntropyLoss()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)
    # report the error rate rather than the accuracy, since callers in
    # main.py treat this value as `test_error`
    return test_loss / total, 100. * (1. - correct / total)
--------------------------------------------------------------------------------