├── scripts
│   ├── train_hypo_imagenet100.sh
│   ├── train_hypo_cifar10.sh
│   ├── eval_ckpt_pacs.sh
│   ├── train_hypo_dg.sh
│   ├── eval_ckpt_cifar10.sh
│   └── download.py
├── models
│   ├── __init__.py
│   ├── densenet.py
│   ├── head_wrn_vmf.py
│   └── resnet.py
├── utils
│   ├── __init__.py
│   ├── modify_file_path.py
│   ├── plot_3dsphere.py
│   ├── measure_distances.py
│   ├── display_results.py
│   ├── losses.py
│   └── util.py
├── .gitignore
├── dataloader
│   ├── corimagenetLoader.py
│   ├── corcifarLoader.py
│   ├── OfficeHomeLoader.py
│   ├── PACSLoader.py
│   ├── TerraLoader.py
│   └── VLCSLoader.py
├── make_datasets_cifar.py
├── README.md
├── eval_hypo.py
└── train_hypo.py

/scripts/train_hypo_imagenet100.sh:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .resnet import *
4 | from .densenet import *
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .losses import *
4 |
5 | from .util import *
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints/
2 | logs/
3 | save/
4 | wandb/
5 | datasets/
6 |
7 | __pycache__/
8 | .DS_Store
9 |
10 | *debug.sh
--------------------------------------------------------------------------------
/scripts/train_hypo_cifar10.sh:
--------------------------------------------------------------------------------
1 | python train_hypo.py \
2 |     --in-dataset CIFAR-10 \
3 |     --id_loc datasets/CIFAR10 \
4 |     --gpu 6 \
5 |     --model resnet18 \
6 |     --loss hypo \
7 |     --epochs 500 \
8 |     --proto_m 0.95 \
9 |     --feat_dim 128 \
10 |     --batch_size 512 \
11 |     --w 2 \
12 |     --cosine
--------------------------------------------------------------------------------
/scripts/eval_ckpt_pacs.sh:
--------------------------------------------------------------------------------
1 | # Evaluate a PACS checkpoint; pass the checkpoint name as the first argument,
2 | # e.g. `sh scripts/eval_ckpt_pacs.sh ckpt_pacs`.
3 | NAME=$1
4 | ID_DATASET=PACS
5 | ID_LOC=datasets/PACS
6 | OOD_LOC=datasets/small_OOD_dataset
7 |
8 | python eval_hypo.py \
9 |     --model resnet50 \
10 |     --head mlp \
11 |     --gpu 0 \
12 |     --in-dataset ${ID_DATASET} \
13 |     --id_loc ${ID_LOC} \
14 |     --ood_loc ${OOD_LOC} \
15 |     --target_domain sketch \
16 |     --ckpt_loc checkpoints/PACS \
17 |     --ckpt_name ${NAME}
--------------------------------------------------------------------------------
/utils/modify_file_path.py:
--------------------------------------------------------------------------------
1 | # Rewrite the absolute root directory recorded in the PACS file lists so that
2 | # the image paths resolve under the local `datasets/` folder. Each list is
3 | # rewritten in place (input and output point to the same file).
4 | ORIGINAL_ROOT_DIR = '/nobackup/hybai/datasets/Generalization/PACS'
5 | NEW_ROOT_DIR = 'datasets'
6 |
7 | domains = ['art_painting', 'cartoon', 'photo', 'sketch']
8 | for domain in domains:
9 |     input_file = f"{NEW_ROOT_DIR}/pacs_data/{domain}.txt"
10 |     output_file = f"{NEW_ROOT_DIR}/pacs_data/{domain}.txt"
11 |     with open(input_file, "r") as f:
12 |         lines = f.readlines()
13 |     updated_lines = [line.replace(ORIGINAL_ROOT_DIR, NEW_ROOT_DIR) for line in lines]
14 |     with open(output_file, "w") as f:
15 |         f.writelines(updated_lines)
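# Illustration with a hypothetical list entry (the concrete path is made up):
# a line such as
#   /nobackup/hybai/datasets/Generalization/PACS/art_painting/dog/pic_001.jpg 0
# becomes
#   datasets/art_painting/dog/pic_001.jpg 0
# which the loaders' _dataset_info() then splits into (file path, integer label).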
--------------------------------------------------------------------------------
/scripts/train_hypo_dg.sh:
--------------------------------------------------------------------------------
1 | # Per-target-domain hyperparameters for PACS:
2 | #   cartoon:      learning_rate 0.0005, batch_size 32, w 4
3 | #   photo:        learning_rate 0.0001, batch_size 32, w 1
4 | #   sketch:       learning_rate 0.002,  batch_size 32, w 2
5 | #   art_painting: learning_rate 0.0005, batch_size 32, w 1
6 | python train_hypo.py \
7 |     --in-dataset PACS \
8 |     --id_loc datasets/PACS \
9 |     --gpu 1 \
10 |     --model resnet50 \
11 |     --loss hypo \
12 |     --epochs 50 \
13 |     --proto_m 0.95 \
14 |     --learning_rate 0.0005 \
15 |     --feat_dim 512 \
16 |     --batch_size 32 \
17 |     --target_domain cartoon \
18 |     --head mlp \
19 |     --w 4 \
20 |     --cosine
--------------------------------------------------------------------------------
/scripts/eval_ckpt_cifar10.sh:
--------------------------------------------------------------------------------
1 | ID_DATASET=CIFAR-10
2 |
3 | # Adjust these roots to wherever your datasets and checkpoints live.
4 | ID_LOC=/nobackup2/yf/datasets
5 | OOD_LOC=/nobackup2/yf/datasets
6 |
7 | CKPT_LOC=/nobackup2/yf/checkpoints/CIFAR-10
8 | CKPT_NAME=ckpt_hypo_resnet18_cifar10
9 |
10 | # Example of pointing at a custom training run instead (overrides the above):
11 | # CKPT_LOC=/nobackup2/yf/checkpoints/hypo_cr/CIFAR-10/09_04_20:02_hypo_resnet18_lr_0.0005_cosine_True_bsz_512_head_mlp_wd_2.0_500_128_trial_0_temp_0.1_CIFAR-10_pm_0.95
12 | # CKPT_NAME=checkpoint_max
13 |
14 | for cortype in 'gaussian_noise' 'zoom_blur' 'impulse_noise' 'defocus_blur' 'snow' 'brightness' 'contrast' 'elastic_transform' 'fog' 'frost' 'gaussian_blur' 'glass_blur' 'jpeg_compression' 'motion_blur' 'pixelate' 'saturate' 'shot_noise' 'spatter' 'speckle_noise'
15 | do
16 |     python eval_hypo.py --model resnet18 --head mlp --gpu 0 --cortype=$cortype --in-dataset ${ID_DATASET} --id_loc ${ID_LOC} --ood_loc ${OOD_LOC} --ckpt_name ${CKPT_NAME} --ckpt_loc ${CKPT_LOC}
17 | done
--------------------------------------------------------------------------------
/utils/plot_3dsphere.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from mpl_toolkits.mplot3d import Axes3D
4 |
5 | fig = plt.figure(figsize=(8, 6))
6 | ax = fig.add_subplot(111, projection='3d')
7 |
8 | # Plot unit sphere with sparser gridlines
9 | phi = np.linspace(0, np.pi, 30)        # reduced from 100 points for sparser gridlines
10 | theta = np.linspace(0, 2 * np.pi, 30)  # reduced from 100 points for sparser gridlines
11 | phi, theta = np.meshgrid(phi, theta)
12 |
13 | x = np.sin(phi) * np.cos(theta)
14 | y = np.sin(phi) * np.sin(theta)
15 | z = np.cos(phi)
16 |
17 | ax.plot_surface(x, y, z, color='c', alpha=0.1, linewidth=0, antialiased=False)
18 |
19 | # Set axis labels
20 | ax.set_xlabel('X-axis', fontsize=12, labelpad=10)
21 | ax.set_ylabel('Y-axis', fontsize=12, labelpad=10)
22 | ax.set_zlabel('Z-axis', fontsize=12, labelpad=10)
23 |
24 | # Customize grid and background
25 | ax.xaxis.pane.fill = False
26 | ax.yaxis.pane.fill = False
27 | ax.zaxis.pane.fill = False
28 | ax.xaxis.pane.set_edgecolor('gray')
29 | ax.yaxis.pane.set_edgecolor('gray')
30 | ax.zaxis.pane.set_edgecolor('gray')
31 | ax.grid(color='gray', linestyle='-', linewidth=0.5, alpha=0.5)
32 |
33 | # Disable the X, Y, Z axes
34 | ax.set_axis_off()
35 |
36 | # Set the background transparent
37 | fig.patch.set_alpha(0)
38 | ax.patch.set_alpha(0)
39 |
40 | # Save the plot with a transparent background
41 | plt.savefig('3d_plot_unit_sphere_no_title_sparser_grid_transparent_background.png', dpi=300, bbox_inches='tight', transparent=True)
--------------------------------------------------------------------------------
/dataloader/corimagenetLoader.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data as data 4 | import torchvision 5 | 6 | from torchvision import datasets 7 | import torch 8 | from torchvision import transforms 9 | from torch.utils.data import DataLoader, ConcatDataset 10 | 11 | import os.path as osp 12 | from torch.utils.data import Dataset 13 | from tqdm import tqdm 14 | 15 | import os 16 | 17 | from random import sample, random 18 | 19 | CorIMAGENET100_train_path = './ImageNet-100-C/train' 20 | CorIMAGENET100_test_path = './ImageNet-100-C/val' 21 | 22 | class CorIMAGENETDataset(data.Dataset): 23 | def __init__(self, set_name, cortype): 24 | 25 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 26 | std=[0.229, 0.224, 0.225]) 27 | 28 | self._image_transformer = transforms.Compose([ 29 | transforms.ToTensor(), 30 | normalize 31 | ]) 32 | 33 | self.set_name = set_name 34 | 35 | if set_name == 'train': 36 | images = np.load(os.path.join(CorIMAGENET100_train_path, cortype + '.npy')) 37 | labels = np.load(os.path.join(CorIMAGENET100_train_path, 'labels.npy')) 38 | elif set_name == 'test': 39 | images = np.load(os.path.join(CorIMAGENET100_test_path, cortype + '.npy')) 40 | labels = np.load(os.path.join(CorIMAGENET100_test_path, 'labels.npy')) 41 | 42 | self.data = images 43 | self.label = labels 44 | 45 | self.num_class = 100 46 | 47 | def __getitem__(self, index): 48 | img, label = self.data[index], self.label[index] 49 | img = self._image_transformer(img) 50 | 51 | return img, label 52 | 53 | def __len__(self): 54 | return len(self.data) 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /dataloader/corcifarLoader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data as data 4 | import torchvision 5 | 6 | from torchvision import datasets 7 | import torch 8 | from torchvision import transforms 9 | from torch.utils.data import DataLoader, ConcatDataset 10 | 11 | import os.path as osp 12 | from torch.utils.data import Dataset 13 | from tqdm import tqdm 14 | 15 | import os 16 | 17 | from random import sample, random 18 | 19 | 20 | 21 | class CorCIFARDataset(data.Dataset): 22 | def __init__(self, ood_loc, set_name, cortype, dataset): 23 | 24 | if dataset == 'CIFAR-10': 25 | print('loading CorCIFAR-10') 26 | 27 | # CorCIFAR_train_path = './data/cifar10_trainc' 28 | # CorCIFAR_test_path = './data/cifar10_testc' 29 | CorCIFAR_train_path = f'{ood_loc}/CorCIFAR10_train' 30 | CorCIFAR_test_path = f'{ood_loc}/CorCIFAR10_test' 31 | self.num_class = 10 32 | 33 | # elif dataset in ['CIFAR-100']: 34 | # print('loading CorCIFAR-100') 35 | 36 | # CorCIFAR_train_path = './data/cifar100_trainc' 37 | # CorCIFAR_test_path = './data/cifar100_testc' 38 | # self.num_class = 100 39 | 40 | mean = [x / 255 for x in [125.3, 123.0, 113.9]] 41 | std = [x / 255 for x in [63.0, 62.1, 66.7]] 42 | 43 | self._image_transformer = transforms.Compose([ 44 | transforms.ToTensor(), 45 | transforms.Normalize(mean, std) 46 | ]) 47 | 48 | if set_name == 'train': 49 | images=np.load(os.path.join(CorCIFAR_train_path, cortype + '.npy')) 50 | labels = np.load(os.path.join(CorCIFAR_train_path, 'labels.npy')) 51 | elif set_name == 'test': 52 | images = np.load(os.path.join(CorCIFAR_test_path, cortype + '.npy')) 53 | labels = np.load(os.path.join(CorCIFAR_test_path, 'labels.npy')) 54 
| 55 | self.data = images 56 | self.label = labels 57 | 58 | def __getitem__(self, index): 59 | img, label = self.data[index], self.label[index] 60 | img = self._image_transformer(img) 61 | 62 | return img, label 63 | 64 | def __len__(self): 65 | return len(self.data) 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /make_datasets_cifar.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pickle 4 | import argparse 5 | import time 6 | import torch 7 | import torch.nn as nn 8 | import torch.backends.cudnn as cudnn 9 | 10 | import torchvision 11 | 12 | import torchvision.transforms as trn 13 | import torchvision.datasets as dset 14 | import torch.nn.functional as F 15 | import matplotlib.pyplot as plt 16 | from torch.utils.data import TensorDataset 17 | 18 | from os import path 19 | import sys 20 | sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) 21 | 22 | import pathlib 23 | 24 | ''' 25 | This script makes the datasets used in eval cifar. The main function is make_datasets. 26 | ''' 27 | 28 | 29 | 30 | def load_CIFAR(id_loc, dataset, classes=[]): 31 | 32 | mean = [x / 255 for x in [125.3, 123.0, 113.9]] 33 | std = [x / 255 for x in [63.0, 62.1, 66.7]] 34 | 35 | # train_transform = trn.Compose([trn.RandomHorizontalFlip(), trn.RandomCrop(32, padding=4), 36 | # trn.ToTensor(), trn.Normalize(mean, std)]) 37 | train_transform = trn.Compose([trn.ToTensor(), trn.Normalize(mean, std)]) 38 | test_transform = trn.Compose([trn.ToTensor(), trn.Normalize(mean, std)]) 39 | 40 | if dataset == 'CIFAR-10': 41 | cifar10_path = f'{id_loc}/cifar10' 42 | print('loading CIFAR-10') 43 | train_data = dset.CIFAR10( 44 | cifar10_path, train=True, transform=train_transform, download=True) 45 | test_data = dset.CIFAR10( 46 | cifar10_path, train=False, transform=test_transform, download=True) 47 | 48 | # elif dataset in ['CIFAR-100']: 49 | # cifar100_path = f'{id_loc}/cifar100' 50 | # print('loading CIFAR-100') 51 | # train_data = dset.CIFAR100( 52 | # cifar100_path, train=True, transform=train_transform, download=True) 53 | # test_data = dset.CIFAR100( 54 | # cifar100_path, train=False, transform=test_transform, download=True) 55 | 56 | return train_data, test_data 57 | 58 | def load_CorCifar(ood_loc, dataset, cortype): 59 | 60 | if dataset == 'CIFAR-10': 61 | print('loading CorCIFAR-10') 62 | 63 | from dataloader.corcifarLoader import CorCIFARDataset as Dataset 64 | 65 | train_data = Dataset(ood_loc, 'train', cortype, dataset) 66 | test_data = Dataset(ood_loc, 'test', cortype, dataset) 67 | 68 | # elif dataset in ['CIFAR-100']: 69 | # print('loading CorCIFAR-100') 70 | 71 | # from dataloader.corcifar100Loader import CorCIFARDataset as Dataset 72 | 73 | # train_data = Dataset('train', cortype, dataset) 74 | # test_data = Dataset('test', cortype, dataset) 75 | 76 | return train_data, test_data 77 | 78 | def make_datasets(id_loc, ood_loc, in_dset, state, cortype): 79 | #rng = np.random.default_rng(state['seed']) 80 | 81 | print('building datasets...') 82 | train_in_data, test_in_data = load_CIFAR(id_loc, in_dset) 83 | train_cor_data, test_cor_data = load_CorCifar(ood_loc, in_dset, cortype) 84 | 85 | test_loader_in = torch.utils.data.DataLoader( 86 | test_in_data, 87 | batch_size=state['batch_size'], shuffle=False, 88 | num_workers=state['prefetch'], pin_memory=True) 89 | 90 | test_loader_cor = torch.utils.data.DataLoader( 91 | 
test_cor_data, 92 | batch_size=state['batch_size'], shuffle=False, 93 | num_workers=state['prefetch'], pin_memory=True) 94 | 95 | return test_loader_in, test_loader_cor 96 | -------------------------------------------------------------------------------- /dataloader/OfficeHomeLoader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data as data 4 | import torchvision 5 | from PIL import Image 6 | from random import sample, random 7 | 8 | 9 | from torchvision import datasets 10 | import torch 11 | from torchvision import transforms 12 | from torch.utils.data import DataLoader, ConcatDataset 13 | 14 | 15 | import os.path as osp 16 | from torch.utils.data import Dataset 17 | from tqdm import tqdm 18 | 19 | import os 20 | 21 | from random import sample, random 22 | 23 | ROOT_PATH = 'datasets/OfficeHomeDataset/' 24 | 25 | class OfficeHomeDataset(data.Dataset): 26 | def __init__(self, setname, target_domain, augment_full = True): 27 | 28 | self._image_transformer_test = transforms.Compose([ 29 | transforms.Resize((224,224)), 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 32 | ]) 33 | 34 | if augment_full: 35 | self._image_transformer_train = transforms.Compose([ 36 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 37 | transforms.RandomHorizontalFlip(), 38 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 39 | transforms.RandomGrayscale(), 40 | transforms.ToTensor(), 41 | transforms.Normalize( 42 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 43 | ]) 44 | else: 45 | self._image_transformer_train = self._image_transformer_test 46 | 47 | if setname != 'test': 48 | fulldata = [] 49 | label_name = [] 50 | fullconcept = [] 51 | i = 0 52 | domain = ['Art', 'Clipart', 'Product', 'RealWorld'] 53 | domain.remove(target_domain) 54 | for domain_name in domain: 55 | txt_path = os.path.join(ROOT_PATH, domain_name + '.txt') 56 | images, labels = self._dataset_info(txt_path) 57 | concept = [i] * len(labels) 58 | fulldata.extend(images) 59 | label_name.extend(labels) 60 | fullconcept.extend(concept) 61 | i += 1 62 | 63 | name_train, name_val, labels_train, labels_val, concept_train, concept_val = self.get_random_subset(fulldata, label_name, fullconcept, 0.1) 64 | 65 | if setname == 'train': 66 | self.data = name_train 67 | self.label = labels_train 68 | self.concept = concept_train 69 | elif setname == 'val': 70 | self.data = name_val 71 | self.label = labels_val 72 | self.concept = concept_val 73 | else: 74 | domain_name = target_domain 75 | txt_path = os.path.join(ROOT_PATH, domain_name + '.txt') 76 | self.data, self.label = self._dataset_info(txt_path) 77 | self.concept = [-1] * len(self.label) 78 | 79 | self.setname = setname 80 | 81 | 82 | def _dataset_info(self, txt_labels): 83 | with open(txt_labels, 'r') as f: 84 | images_list = f.readlines() 85 | 86 | file_names = [] 87 | labels = [] 88 | for row in images_list: 89 | row = row.split(' ') 90 | path = os.path.join(row[0]) 91 | path = path.replace('\\', '/') 92 | 93 | file_names.append(path) 94 | labels.append(int(row[1])) 95 | 96 | return file_names, labels 97 | 98 | def get_random_subset(self, names, labels, concepts, percent): 99 | """ 100 | :param names: list of names 101 | :param labels: list of labels 102 | :param percent: 0 < float < 1 103 | :return: 104 | """ 105 | samples = len(names) 106 | amount = int(samples * percent) 107 | random_index = sample(range(samples), amount) 108 | 
name_val = [names[k] for k in random_index]
109 |         name_train = [v for k, v in enumerate(names) if k not in random_index]
110 |         labels_val = [labels[k] for k in random_index]
111 |         labels_train = [v for k, v in enumerate(labels) if k not in random_index]
112 |         concepts_val = [concepts[k] for k in random_index]
113 |         concepts_train = [v for k, v in enumerate(concepts) if k not in random_index]
114 |
115 |         return name_train, name_val, labels_train, labels_val, concepts_train, concepts_val
116 |
117 |
118 |
119 |
120 |     def __getitem__(self, index):
121 |         data, label, concept = self.data[index], self.label[index], self.concept[index]
122 |
123 |         _img = Image.open(data).convert('RGB')
124 |         if self.setname == 'val' or self.setname == 'test':
125 |             img = self._image_transformer_test(_img)
126 |             return img, label, concept
127 |         else:
128 |             img = self._image_transformer_train(_img)
129 |             return img, label, concept
130 |
131 |     def __len__(self):
132 |         return len(self.data)
133 |
134 |
--------------------------------------------------------------------------------
/dataloader/PACSLoader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.utils.data as data
4 | import torchvision
5 | from PIL import Image
6 | from random import sample, random
7 |
8 |
9 | from torchvision import datasets
10 | import torch
11 | from torchvision import transforms
12 | from torch.utils.data import DataLoader, ConcatDataset
13 |
14 |
15 | import os.path as osp
16 | from torch.utils.data import Dataset
17 | from tqdm import tqdm
18 |
19 | import os
20 |
21 | from random import sample, random
22 |
23 | ROOT_PATH = 'datasets/pacs_data/'  # file lists rewritten by utils/modify_file_path.py live here
24 |
25 | class PACSDataset(data.Dataset):
26 |     def __init__(self, setname, target_domain, augment_full = True):
27 |
28 |         self._image_transformer_test = transforms.Compose([
29 |             transforms.Resize(224, Image.BILINEAR),
30 |             transforms.ToTensor(),
31 |             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
32 |         ])
33 |
34 |         if augment_full:
35 |             self._image_transformer_train = transforms.Compose([
36 |                 transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
37 |                 transforms.RandomHorizontalFlip(),
38 |                 transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
39 |                 transforms.RandomGrayscale(),
40 |                 transforms.ToTensor(),
41 |                 transforms.Normalize(
42 |                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
43 |             ])
44 |         else:
45 |             self._image_transformer_train = self._image_transformer_test
46 |
47 |         if setname != 'test':
48 |             fulldata = []
49 |             label_name = []
50 |             fullconcept = []
51 |             i = 0
52 |             domain = ['cartoon', 'art_painting', 'photo', 'sketch']
53 |             domain.remove(target_domain)
54 |             for domain_name in domain:
55 |                 txt_path = os.path.join(ROOT_PATH, domain_name + '.txt')
56 |                 images, labels = self._dataset_info(txt_path)
57 |                 concept = [i] * len(labels)
58 |                 fulldata.extend(images)
59 |                 label_name.extend(labels)
60 |                 fullconcept.extend(concept)
61 |                 i += 1
62 |
63 |             name_train, name_val, labels_train, labels_val, concept_train, concept_val = self.get_random_subset(fulldata, label_name, fullconcept, 0.1)
64 |
65 |             if setname == 'train':
66 |                 self.data = name_train
67 |                 self.label = labels_train
68 |                 self.concept = concept_train
69 |             elif setname == 'val':
70 |                 self.data = name_val
71 |                 self.label = labels_val
72 |                 self.concept = concept_val
73 |         else:
74 |             domain_name = target_domain
75 |             txt_path = os.path.join(ROOT_PATH, domain_name + '.txt')
76 |             self.data, self.label = self._dataset_info(txt_path)
77
| self.concept = [-1] * len(self.label) #dummy 78 | 79 | self.setname = setname 80 | 81 | 82 | def _dataset_info(self, txt_labels): 83 | with open(txt_labels, 'r') as f: 84 | images_list = f.readlines() 85 | 86 | file_names = [] 87 | labels = [] 88 | for row in images_list: 89 | row = row.split(' ') 90 | path = os.path.join(row[0]) 91 | path = path.replace('\\', '/') 92 | 93 | file_names.append(path) 94 | labels.append(int(row[1])) 95 | 96 | return file_names, labels 97 | 98 | def get_random_subset(self, names, labels, concepts, percent): 99 | """ 100 | :param names: list of names 101 | :param labels: list of labels 102 | :param percent: 0 < float < 1 103 | :return: 104 | """ 105 | samples = len(names) 106 | amount = int(samples * percent) 107 | random_index = sample(range(samples), amount) 108 | name_val = [names[k] for k in random_index] 109 | name_train = [v for k, v in enumerate(names) if k not in random_index] 110 | labels_val = [labels[k] for k in random_index] 111 | labels_train = [v for k, v in enumerate(labels) if k not in random_index] 112 | concepts_val = [concepts[k] for k in random_index] 113 | concepts_train = [v for k, v in enumerate(concepts) if k not in random_index] 114 | 115 | return name_train, name_val, labels_train, labels_val, concepts_train, concepts_val 116 | 117 | 118 | 119 | 120 | def __getitem__(self, index): 121 | data, label, concept= self.data[index], self.label[index], self.concept[index] 122 | 123 | _img = Image.open(data).convert('RGB') 124 | if self.setname == 'val' or self.setname == 'test': 125 | img = self._image_transformer_test(_img) 126 | return img, label, concept 127 | else: 128 | img = self._image_transformer_train(_img) 129 | return img, label, concept 130 | 131 | def __len__(self): 132 | return len(self.data) 133 | 134 | -------------------------------------------------------------------------------- /dataloader/TerraLoader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data as data 4 | import torchvision 5 | from PIL import Image 6 | from random import sample, random 7 | 8 | 9 | from torchvision import datasets 10 | import torch 11 | from torchvision import transforms 12 | from torch.utils.data import DataLoader, ConcatDataset 13 | 14 | 15 | import os.path as osp 16 | from torch.utils.data import Dataset 17 | from tqdm import tqdm 18 | 19 | import os 20 | 21 | from random import sample, random 22 | 23 | ROOT_PATH = '/datasets/terra_incognita' 24 | 25 | class TerraDataset(data.Dataset): 26 | def __init__(self, setname, target_domain, augment_full = True): 27 | 28 | self._image_transformer_test = transforms.Compose([ 29 | transforms.Resize(224, Image.BILINEAR), 30 | transforms.ToTensor(), 31 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 32 | ]) 33 | 34 | if augment_full: 35 | self._image_transformer_train = transforms.Compose([ 36 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 37 | transforms.RandomHorizontalFlip(), 38 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 39 | transforms.RandomGrayscale(), 40 | transforms.ToTensor(), 41 | transforms.Normalize( 42 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 43 | ]) 44 | else: 45 | self._image_transformer_train = self._image_transformer_test 46 | 47 | if setname != 'test': 48 | fulldata = [] 49 | label_name = [] 50 | fullconcept = [] 51 | i = 0 52 | domain = ['location_38', 'location_43', 'location_46', 'location_100'] 53 | domain.remove(target_domain) 54 | for 
domain_name in domain: 55 | txt_path = os.path.join(ROOT_PATH, domain_name + '.txt') 56 | images, labels = self._dataset_info(txt_path) 57 | concept = [i] *len(labels) 58 | fulldata.extend(images) 59 | label_name.extend(labels) 60 | fullconcept.extend(concept) 61 | i += 1 62 | 63 | name_train, name_val, labels_train, labels_val, concept_train, concept_val = self.get_random_subset(fulldata, label_name, fullconcept, 0.1) 64 | 65 | if setname == 'train': 66 | self.data = name_train 67 | self.label = labels_train 68 | self.concept = concept_train 69 | elif setname == 'val': 70 | self.data = name_val 71 | self.label = labels_val 72 | self.concept = concept_val 73 | else: 74 | domain_name = target_domain 75 | txt_path = os.path.join(ROOT_PATH, domain_name + '.txt') 76 | self.data, self.label = self._dataset_info(txt_path) 77 | self.concept = [-1] *len(self.label) 78 | 79 | self.setname = setname 80 | 81 | 82 | def _dataset_info(self, txt_labels): 83 | with open(txt_labels, 'r') as f: 84 | images_list = f.readlines() 85 | 86 | file_names = [] 87 | labels = [] 88 | for row in images_list: 89 | row = row.split(' ') 90 | path = os.path.join(row[0]) 91 | path = path.replace('\\', '/') 92 | 93 | file_names.append(path) 94 | labels.append(int(row[1])) 95 | 96 | return file_names, labels 97 | 98 | def get_random_subset(self, names, labels, concepts, percent): 99 | """ 100 | :param names: list of names 101 | :param labels: list of labels 102 | :param percent: 0 < float < 1 103 | :return: 104 | """ 105 | samples = len(names) 106 | amount = int(samples * percent) 107 | random_index = sample(range(samples), amount) 108 | name_val = [names[k] for k in random_index] 109 | name_train = [v for k, v in enumerate(names) if k not in random_index] 110 | labels_val = [labels[k] for k in random_index] 111 | labels_train = [v for k, v in enumerate(labels) if k not in random_index] 112 | concepts_val = [concepts[k] for k in random_index] 113 | concepts_train = [v for k, v in enumerate(concepts) if k not in random_index] 114 | 115 | return name_train, name_val, labels_train, labels_val, concepts_train, concepts_val 116 | 117 | 118 | 119 | 120 | def __getitem__(self, index): 121 | data, label, concept= self.data[index], self.label[index], self.concept[index] 122 | 123 | _img = Image.open(data).convert('RGB') 124 | if self.setname == 'val' or self.setname == 'test': 125 | img = self._image_transformer_test(_img) 126 | return img, label, concept 127 | else: 128 | img = self._image_transformer_train(_img) 129 | return img, label, concept 130 | 131 | def __len__(self): 132 | return len(self.data) 133 | 134 | -------------------------------------------------------------------------------- /dataloader/VLCSLoader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data as data 4 | import torchvision 5 | from PIL import Image 6 | from random import sample, random 7 | 8 | 9 | from torchvision import datasets 10 | import torch 11 | from torchvision import transforms 12 | from torch.utils.data import DataLoader, ConcatDataset 13 | 14 | 15 | import os.path as osp 16 | from torch.utils.data import Dataset 17 | from tqdm import tqdm 18 | 19 | import os 20 | 21 | from random import sample, random 22 | 23 | ROOT_PATH = '/datasets/VLCSDataset/' 24 | 25 | from PIL import ImageFile 26 | ImageFile.LOAD_TRUNCATED_IMAGES = True 27 | 28 | class VLCSDataset(data.Dataset): 29 | def __init__(self, setname, target_domain, augment_full = True): 30 | 31 | 
self._image_transformer_test = transforms.Compose([ 32 | transforms.Resize((224,224)), 33 | transforms.ToTensor(), 34 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 35 | ]) 36 | 37 | if augment_full: 38 | self._image_transformer_train = transforms.Compose([ 39 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 42 | transforms.RandomGrayscale(), 43 | transforms.ToTensor(), 44 | transforms.Normalize( 45 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 46 | ]) 47 | else: 48 | self._image_transformer_train = self._image_transformer_test 49 | 50 | if setname != 'test': 51 | fulldata = [] 52 | label_name = [] 53 | fullconcept = [] 54 | i = 0 55 | domain = ['Caltech101', 'LabelMe', 'SUN09', 'VOC2007'] 56 | domain.remove(target_domain) 57 | for domain_name in domain: 58 | txt_path = os.path.join(ROOT_PATH, domain_name + '.txt') 59 | images, labels = self._dataset_info(txt_path) 60 | concept = [i] *len(labels) 61 | fulldata.extend(images) 62 | label_name.extend(labels) 63 | fullconcept.extend(concept) 64 | i += 1 65 | 66 | name_train, name_val, labels_train, labels_val, concept_train, concept_val = self.get_random_subset(fulldata, label_name, fullconcept, 0.1) 67 | 68 | if setname == 'train': 69 | self.data = name_train 70 | self.label = labels_train 71 | self.concept = concept_train 72 | elif setname == 'val': 73 | self.data = name_val 74 | self.label = labels_val 75 | self.concept = concept_val 76 | else: 77 | domain_name = target_domain 78 | txt_path = os.path.join(ROOT_PATH, domain_name + '.txt') 79 | self.data, self.label = self._dataset_info(txt_path) 80 | self.concept = [-1]*len(self.label) 81 | 82 | self.setname = setname 83 | 84 | 85 | def _dataset_info(self, txt_labels): 86 | with open(txt_labels, 'r') as f: 87 | images_list = f.readlines() 88 | 89 | file_names = [] 90 | labels = [] 91 | for row in images_list: 92 | row = row.split(' ') 93 | path = os.path.join(row[0]) 94 | path = path.replace('\\', '/') 95 | 96 | file_names.append(path) 97 | labels.append(int(row[1])) 98 | 99 | return file_names, labels 100 | 101 | def get_random_subset(self, names, labels, concepts, percent): 102 | """ 103 | :param names: list of names 104 | :param labels: list of labels 105 | :param percent: 0 < float < 1 106 | :return: 107 | """ 108 | samples = len(names) 109 | amount = int(samples * percent) 110 | random_index = sample(range(samples), amount) 111 | name_val = [names[k] for k in random_index] 112 | name_train = [v for k, v in enumerate(names) if k not in random_index] 113 | labels_val = [labels[k] for k in random_index] 114 | labels_train = [v for k, v in enumerate(labels) if k not in random_index] 115 | concepts_val = [concepts[k] for k in random_index] 116 | concepts_train = [v for k, v in enumerate(concepts) if k not in random_index] 117 | 118 | return name_train, name_val, labels_train, labels_val, concepts_train, concepts_val 119 | 120 | 121 | 122 | 123 | def __getitem__(self, index): 124 | data, label, concept= self.data[index], self.label[index], self.concept[index] 125 | 126 | _img = Image.open(data).convert('RGB') 127 | if self.setname == 'val' or self.setname == 'test': 128 | img = self._image_transformer_test(_img) 129 | return img, label, concept 130 | else: 131 | img = self._image_transformer_train(_img) 132 | return img, label, concept 133 | 134 | def __len__(self): 135 | return len(self.data) 136 | 137 | 
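The four domain loaders above share one interface: items are `(image, label, concept)` triples, where `concept` is the index of the source training domain and a dummy `-1` for the held-out test domain. A minimal usage sketch (illustrative only; it assumes the PACS file lists are already in place under the loader's `ROOT_PATH`):

from torch.utils.data import DataLoader
from dataloader.PACSLoader import PACSDataset

# Hold out 'sketch'; train/val are drawn from the remaining three domains,
# with get_random_subset() carving out a random 10% validation split.
# Note: each instantiation draws its own independent random split.
train_set = PACSDataset('train', target_domain='sketch')
val_set = PACSDataset('val', target_domain='sketch')
test_set = PACSDataset('test', target_domain='sketch')

train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
img, label, concept = train_set[0]  # concept is 0, 1, or 2: the source-domain index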
--------------------------------------------------------------------------------
/utils/measure_distances.py:
--------------------------------------------------------------------------------
1 | import scipy.io as io
2 | import numpy as np
3 | import ot
4 | from itertools import combinations
5 | import torch
6 | from geomloss import SamplesLoss
7 | import matplotlib.pyplot as plt
8 |
9 |
10 | def compute_wasserstein_distance(X1, X2, metric = 'euclidean', reg=1e-2):
11 |     '''
12 |     numerically unstable with the POT library; kept for reference
13 |     '''
14 |     # Calculate the cost matrix, which is the pairwise Euclidean distance
15 |     cost_matrix = ot.dist(X1, X2, metric=metric)
16 |
17 |     # Compute the empirical distribution weights, assuming all samples have equal probability
18 |     n1_samples, n2_samples = X1.shape[0], X2.shape[0]
19 |     weights_x1 = np.ones(n1_samples) / n1_samples
20 |     weights_x2 = np.ones(n2_samples) / n2_samples
21 |
22 |     # Compute the Wasserstein distance using the Sinkhorn algorithm
23 |     wasserstein_distance = ot.sinkhorn2(weights_x1, weights_x2, cost_matrix, reg=reg, stopThr=1e-8)
24 |
25 |     return wasserstein_distance
26 |
27 |
28 | def compute_wasserstein_distance2(X1, X2, p=2, reg=0.01):
29 |     # https://www.kernel-operations.io/geomloss/api/pytorch-api.html
30 |     device = "cuda" if torch.cuda.is_available() else "cpu"
31 |     X1_torch = torch.tensor(X1, dtype=torch.float32, device=device)
32 |     X2_torch = torch.tensor(X2, dtype=torch.float32, device=device)
33 |
34 |     # Initialize the SamplesLoss function with the Wasserstein distance (p=2).
35 |     # loss="sinkhorn": (un-biased) Sinkhorn divergence, which interpolates between
36 |     # Wasserstein (blur=0) and kernel (blur=+inf) distances.
37 |     loss = SamplesLoss(loss="sinkhorn", p=p, blur=reg)
38 |
39 |     # Compute the Wasserstein distance
40 |     wd = loss(X1_torch, X2_torch)
41 |
42 |     # Convert the result to a scalar value
43 |     wd_scalar = wd.item()
44 |     return wd_scalar
45 |
46 |
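# Quick sanity check (illustrative only, not part of the evaluation pipeline):
# identical clouds should give ~0, shifted clouds a clearly positive value.
#   X = np.random.randn(200, 8).astype('float32')
#   print(compute_wasserstein_distance2(X, X))        # ~0 (the divergence is debiased)
#   print(compute_wasserstein_distance2(X, X + 5.0))  # >> 0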
47 | def plot_heatmap(data, save_name):
48 |     # Find the maximum values of y and e
49 |     max_y = max(data.keys())
50 |     # max_e = max(max(pair for pair in data[y].keys()) for y in data.keys())  # e.g. (2,3)
51 |
52 |     unique_pairs = list(data[0].keys())
53 |     print(unique_pairs)
54 |     # Initialize an empty matrix to store the data
55 |     matrix = np.zeros((max_y + 1, len(unique_pairs)))
56 |
57 |     for y, pairs in data.items():
58 |         for (e1, e2), value in pairs.items():
59 |             pair_index = unique_pairs.index((e1, e2))
60 |             matrix[y, pair_index] = value
61 |
62 |     # Plot the heatmap using plt.imshow
63 |     fig, ax = plt.subplots(figsize=(8, 7))
64 |
65 |     im = ax.imshow(matrix, interpolation="nearest", aspect="auto")
66 |     ax.set_xlabel("(e1, e2)", fontsize=14)
67 |     ax.set_xticks(range(len(unique_pairs)), unique_pairs, rotation=45, fontsize=12)
68 |     ax.set_ylabel("y", fontsize=14)
69 |     ax.tick_params(axis='y', labelsize=12)  # y tick label size
70 |
71 |     # Add numbers to each block of the heatmap
72 |     for i in range(matrix.shape[0]):
73 |         for j in range(matrix.shape[1]):
74 |             value = matrix[i, j]
75 |             ax.text(j, i, f"{value:.2f}", ha="center", va="center", color="w", fontsize=10)
76 |
77 |     plt.tight_layout()
78 |     fig.colorbar(im, ax=ax, label="Wasserstein distance (sinkhorn divergence)")
79 |     plt.title("")
80 |     plt.tight_layout()
81 |     plt.savefig(f'plots/{save_name}.png', dpi = 300, bbox_inches='tight')
82 |
83 |
84 | if __name__ == '__main__':
85 |     # classwise
86 |     normalize = True
87 |     for loss in ['ce', 'cider']:
88 |         for normalize in ['no_norm', 'norm']:
89 |             save_name = f'feature_y_e_{loss}_penultimate_{normalize}'
90 |             res = {}
91 |             feature_y_e = io.loadmat(f'features/{save_name}.mat')
92 |             all_features = feature_y_e['feature']
93 |             all_y = feature_y_e['y'].squeeze()
94 |             all_e = feature_y_e['e'].squeeze()
95 |             unique_y = np.unique(all_y)
96 |             unique_e = np.unique(all_e)
97 |             unique_e_pairs = list(combinations(unique_e, 2))
98 |             for y in unique_y:
99 |                 res[y] = {}
100 |                 for (e1, e2) in unique_e_pairs:
101 |                     feat_1 = all_features[(all_y == y) & (all_e == e1)]  # combine the boolean masks with the element-wise & (bitwise AND) operator; parentheses around each condition are required since & binds tighter than ==.
102 | feat_2 = all_features[(all_y==y) & (all_e == e2)] 103 | # print(feat_1.shape, feat_2.shape) 104 | # wd = compute_wasserstein_distance2(feat_1.astype('float64'), feat_2.astype('float64')) 105 | wd = compute_wasserstein_distance2(feat_1, feat_2) 106 | print(f'y = {y}, estimated W-dist for ({e1}, {e2}) is {wd}') 107 | res[y][(e1,e2)] = wd 108 | 109 | print(res) 110 | # debug test 111 | # data = {0: {(0, 1): 0.15530866384506226, (0, 2): 0.21642754971981049, (0, 3): 0.23963293433189392, (1, 2): 0.20890945196151733, (1, 3): 0.26223963499069214, (2, 3): 0.2195998877286911}, 1: {(0, 1): 0.14590714871883392, (0, 2): 0.23728705942630768, (0, 3): 0.28713685274124146, (1, 2): 0.23684373497962952, (1, 3): 0.3055537939071655, (2, 3): 0.20455986261367798}, 2: {(0, 1): 0.15946874022483826, (0, 2): 0.1995120495557785, (0, 3): 0.1753748506307602, (1, 2): 0.20574313402175903, (1, 3): 0.2013503611087799, (2, 3): 0.17896652221679688}, 3: {(0, 1): 0.13759663701057434, (0, 2): 0.18846842646598816, (0, 3): 0.1469249725341797, (1, 2): 0.15395350754261017, (1, 3): 0.093973807990551, (2, 3): 0.12456567585468292}, 4: {(0, 1): 0.16651654243469238, (0, 2): 0.23089328408241272, (0, 3): 0.2536894679069519, (1, 2): 0.22801926732063293, (1, 3): 0.2667630910873413, (2, 3): 0.23195475339889526}, 5: {(0, 1): 0.10270039737224579, (0, 2): 0.20019754767417908, (0, 3): 0.16466757655143738, (1, 2): 0.1986047923564911, (1, 3): 0.16104549169540405, (2, 3): 0.17241033911705017}, 6: {(0, 1): 0.2291284203529358, (0, 2): 0.24648775160312653, (0, 3): 0.24859079718589783, (1, 2): 0.2752886712551117, (1, 3): 0.3062085807323456, (2, 3): 0.25627970695495605}} 112 | plot_heatmap(res,save_name) -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | import torchvision.datasets as dset 10 | import torchvision.transforms as transforms 11 | from torch.utils.data import DataLoader 12 | 13 | import torchvision.models as models 14 | 15 | import sys 16 | import math 17 | 18 | class Bottleneck(nn.Module): 19 | def __init__(self, nChannels, growthRate): 20 | super(Bottleneck, self).__init__() 21 | interChannels = 4*growthRate 22 | self.bn1 = nn.BatchNorm2d(nChannels) 23 | self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, 24 | bias=False) 25 | self.bn2 = nn.BatchNorm2d(interChannels) 26 | self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, 27 | padding=1, bias=False) 28 | 29 | def forward(self, x): 30 | out = self.conv1(F.relu(self.bn1(x))) 31 | out = self.conv2(F.relu(self.bn2(out))) 32 | out = torch.cat((x, out), 1) 33 | return out 34 | 35 | class SingleLayer(nn.Module): 36 | def __init__(self, nChannels, growthRate): 37 | super(SingleLayer, self).__init__() 38 | self.bn1 = nn.BatchNorm2d(nChannels) 39 | self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, 40 | padding=1, bias=False) 41 | 42 | def forward(self, x): 43 | out = self.conv1(F.relu(self.bn1(x))) 44 | out = torch.cat((x, out), 1) 45 | return out 46 | 47 | class Transition(nn.Module): 48 | def __init__(self, nChannels, nOutChannels): 49 | super(Transition, self).__init__() 50 | self.bn1 = nn.BatchNorm2d(nChannels) 51 | self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, 52 | bias=False) 53 | 54 | def forward(self, x): 55 | out = 
self.conv1(F.relu(self.bn1(x))) 56 | out = F.avg_pool2d(out, 2) 57 | return out 58 | 59 | class DenseNet(nn.Module): 60 | def __init__(self, growthRate = 12, depth = 100, reduction = 0.5, bottleneck = True): 61 | super(DenseNet, self).__init__() 62 | 63 | nDenseBlocks = (depth-4) // 3 64 | if bottleneck: 65 | nDenseBlocks //= 2 66 | 67 | nChannels = 2*growthRate 68 | self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, 69 | bias=False) 70 | self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 71 | nChannels += nDenseBlocks*growthRate 72 | nOutChannels = int(math.floor(nChannels*reduction)) 73 | self.trans1 = Transition(nChannels, nOutChannels) 74 | 75 | nChannels = nOutChannels 76 | self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 77 | nChannels += nDenseBlocks*growthRate 78 | nOutChannels = int(math.floor(nChannels*reduction)) 79 | self.trans2 = Transition(nChannels, nOutChannels) 80 | 81 | nChannels = nOutChannels 82 | self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 83 | nChannels += nDenseBlocks*growthRate 84 | 85 | self.bn1 = nn.BatchNorm2d(nChannels) 86 | # self.fc = nn.Linear(nChannels, nClasses) 87 | 88 | for m in self.modules(): 89 | if isinstance(m, nn.Conv2d): 90 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 91 | m.weight.data.normal_(0, math.sqrt(2. / n)) 92 | elif isinstance(m, nn.BatchNorm2d): 93 | m.weight.data.fill_(1) 94 | m.bias.data.zero_() 95 | elif isinstance(m, nn.Linear): 96 | m.bias.data.zero_() 97 | 98 | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): 99 | layers = [] 100 | for i in range(int(nDenseBlocks)): 101 | if bottleneck: 102 | layers.append(Bottleneck(nChannels, growthRate)) 103 | else: 104 | layers.append(SingleLayer(nChannels, growthRate)) 105 | nChannels += growthRate 106 | return nn.Sequential(*layers) 107 | 108 | def forward(self, x): 109 | out = self.conv1(x) 110 | out = self.trans1(self.dense1(out)) 111 | out = self.trans2(self.dense2(out)) 112 | out = self.dense3(out) 113 | out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8)) 114 | # out = F.log_softmax(self.fc(out)) 115 | return out 116 | 117 | def intermediate_forward(self, x, layer_index): 118 | out = self.conv1(x) 119 | out = self.trans1(self.dense1(out)) 120 | out = self.trans2(self.dense2(out)) 121 | out = self.dense3(out) 122 | out = F.relu(self.bn1(out)) 123 | # out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8)) 124 | # out = F.log_softmax(self.fc(out)) 125 | return out 126 | 127 | def feature_list(self, x): 128 | out_list = [] 129 | out = self.conv1(x) 130 | out = self.trans1(self.dense1(out)) 131 | out = self.trans2(self.dense2(out)) 132 | out = self.dense3(out) 133 | out = F.relu(self.bn1(out)) 134 | out_list.append(out) 135 | # out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8)) 136 | # out = F.log_softmax(self.fc(out)) 137 | return out_list 138 | 139 | model_dict = { 140 | 'densenet100': [DenseNet, 342], 141 | } 142 | 143 | class SupCEHeadDenseNet(nn.Module): 144 | """encoder + classifier""" 145 | def __init__(self, name='densenet100', head='linear', feat_dim = 128, num_classes=100, multiplier = 1): 146 | super(SupCEHeadDenseNet, self).__init__() 147 | model_fun, dim_in = model_dict[name] 148 | self.encoder = model_fun() 149 | self.fc = nn.Linear(dim_in, num_classes) 150 | self.multiplier = multiplier 151 | 152 | if head == 'linear': 153 | self.head = nn.Linear(dim_in, feat_dim) 154 | elif head == 'mlp': 155 | self.head = 
nn.Sequential(
156 |                 nn.Linear(dim_in, dim_in),
157 |                 nn.ReLU(inplace=True),
158 |                 nn.Linear(dim_in, feat_dim)
159 |             )
160 |
161 |
162 |     def forward(self, x):
163 |         features = self.encoder(x)
164 |         return self.fc(features)
165 |
166 |     def intermediate_forward(self, x, layer_index):
167 |         if layer_index == 0:
168 |             return self.encoder.intermediate_forward(x, layer_index)
169 |         elif layer_index == 1:
170 |             feat = self.encoder(x)
171 |             feat = self.multiplier * F.normalize(self.head(feat), dim=1)
172 |             return feat
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HYPO: Hyperspherical Out-of-Distribution Generalization
2 |
3 | This codebase provides a PyTorch implementation for the paper: [HYPO: Hyperspherical Out-of-Distribution Generalization](https://openreview.net/pdf?id=VXak3CZZGC) by Haoyue Bai*, Yifei Ming*, Julian Katz-Samuels, and Yixuan Li.
4 |
5 | **Remarks**: The current codebase is available for preview purposes only and is still under development. We are actively working on eliminating hard-coded links, removing unused arguments, and streamlining the processes for loading models and datasets. Please stay tuned for forthcoming updates.
6 |
7 | ### Abstract
8 |
9 | Out-of-distribution (OOD) generalization is critical for machine learning models deployed in the real world. However, achieving this can be fundamentally challenging, as it requires the ability to learn invariant features across different domains or environments. In this paper, we propose a novel framework HYPO (HYPerspherical OOD generalization) that provably learns domain-invariant representations in a hyperspherical space. In particular, our hyperspherical learning algorithm is guided by intra-class variation and inter-class separation principles---ensuring that features from the same class (across different training domains) are closely aligned with their class prototypes, while different class prototypes are maximally separated. We further provide theoretical justifications on how our prototypical learning objective improves the OOD generalization bound. Through extensive experiments on challenging OOD benchmarks, we demonstrate that our approach outperforms competitive baselines and achieves superior performance.
10 |
11 |
12 | ## Quick Start
13 |
14 | ### Data Preparation
15 | In this work, we evaluate the OOD generalization performance over a range of environmental discrepancies such as domains, image corruptions, and perturbations.
16 |
17 | **OOD generalization across domains**: The default root directory for ID and OOD datasets is `datasets/`. We consider [PACS](https://arxiv.org/abs/1710.03077), [Office-Home](https://arxiv.org/abs/1706.07522), [VLCS](https://openaccess.thecvf.com/content_iccv_2013/papers/Fang_Unbiased_Metric_Learning_2013_ICCV_paper.pdf), [Terra Incognita](https://arxiv.org/abs/1807.04975). You may use `scripts/download.py` (from [DomainBed](https://github.com/facebookresearch/DomainBed)) to download and prepare the datasets for domain generalization, as shown below.
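For example, assuming the script keeps DomainBed's command-line interface (the `--data_dir` flag below is DomainBed's, not something this repository documents):

```
python scripts/download.py --data_dir datasets
```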
18 |
19 | **OOD generalization across common corruptions**: The default root directory for ID and OOD datasets is `datasets/`. We consider
20 | [CIFAR-10](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf) & [CIFAR-10-C](https://arxiv.org/abs/1903.12261) and ImageNet-100 & [ImageNet-100-C](https://arxiv.org/abs/1903.12261).
21 | In alignment with prior works on the [ImageNet-100](https://github.com/deeplearning-wisc/MCM/tree/main) subset, the script for generating the subset is provided [here](https://github.com/deeplearning-wisc/MCM/blob/main/create_imagenet_subset.py).
22 |
23 | #### CIFAR-10 & CIFAR-10-C
24 |
25 | - Create a folder named `cifar-10/` and a folder named `cifar-10-c/` under `datasets/`.
26 | - Download the [CIFAR-10](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf) dataset and extract the training and validation sets to `datasets/cifar-10/`.
27 | - Download the [CIFAR-10-C](https://arxiv.org/abs/1903.12261) dataset and extract the training and test sets to `datasets/cifar-10-c/`. The directory structure should look like
28 | ```
29 | cifar-10/
30 | |–– train/
31 | |–– val/
32 | cifar-10-c/
33 | |–– CorCIFAR10_train/
34 | |–– CorCIFAR10_test/
35 | ```
36 |
37 | #### ImageNet-100 & ImageNet-100-C
38 |
39 | - Create a folder named `imagenet-100/` and a folder named `imagenet-100-c/` under `datasets/`.
40 | - Create `images/` under `imagenet-100/` and `imagenet-100-c/`.
41 | - Download the [ImageNet-100](https://image-net.org/index.php) subset (see the [MCM repository](https://github.com/deeplearning-wisc/MCM/tree/main)) and extract the training and validation sets to `datasets/imagenet-100/images`.
42 | - Download the [ImageNet-100-C](https://arxiv.org/abs/1903.12261) dataset and extract the training and validation sets to `datasets/imagenet-100-c/images`. The directory structure should look like
43 | ```
44 | imagenet-100/
45 | |–– images/
46 | |   |–– train/   # contains 100 folders like n01440764, n01443537, etc.
47 | |   |–– val/
48 | imagenet-100-c/
49 | |–– images/
50 | |   |–– train/
51 | |   |–– val/
52 | ```
53 |
54 | ## Training and Evaluation
55 |
56 | ### Model Checkpoints
57 |
58 | **Evaluate pre-trained checkpoints**
59 |
60 | Our checkpoints can be downloaded [here](https://drive.google.com/file/d/1nflCX3YUTwX54QR_jiLlPq9q6Hni2YMe/view?usp=drive_link). Create a directory named `checkpoints/[ID_DATASET]` in the root directory of the project and put the downloaded checkpoints here. For example, for CIFAR-10 and PACS:
61 |
62 | ```
63 | checkpoints/
64 | ---CIFAR-10/
65 | ------checkpoint_hypo_resnet18_cifar10.pth.tar
66 | ---PACS/
67 | ------checkpoint_hypo_resnet50_td_photo.pth.tar
68 | ------checkpoint_hypo_resnet50_td_cartoon.pth.tar
69 | ------checkpoint_hypo_resnet50_td_sketch.pth.tar
70 | ------checkpoint_hypo_resnet50_td_art_painting.pth.tar
71 | ```
72 |
73 | The following scripts can be used to evaluate the OOD generalization performance:
74 |
75 | ```
76 | sh scripts/eval_ckpt_cifar10.sh ckpt_c10 # for CIFAR-10
77 | sh scripts/eval_ckpt_pacs.sh ckpt_pacs   # for PACS
78 | ```
79 |
80 |
81 |
82 | **Evaluate custom checkpoints**
83 |
84 | If the default directory to save checkpoints is not `checkpoints`, create a softlink to the directory where the actual checkpoints are saved and name it `checkpoints` (a sample command follows the structure below). For example, checkpoints for CIFAR-100 (ID) are structured as follows:
85 |
86 | ```
87 | checkpoints/
88 | ---CIFAR-100/
89 | ------name_of_ckpt/
90 | ---------checkpoint_500.pth.tar
91 | ```
92 |
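A minimal way to create that softlink (the source path is a placeholder for wherever your checkpoints actually live):

```
ln -s /path/to/your/checkpoint/root checkpoints
```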
93 |
94 | **Train from scratch**
95 |
96 | We provide sample scripts to train from scratch. Feel free to modify the hyperparameters and training configurations.
97 |
98 | ```
99 | sh scripts/train_hypo_cifar10.sh
100 | sh scripts/train_hypo_dg.sh
101 | ```
102 |
103 | **Fine-tune from ImageNet pre-trained models**
104 |
105 | We also provide fine-tuning scripts on large-scale datasets such as ImageNet-100.
106 |
107 | ```
108 | sh scripts/train_hypo_imagenet100.sh
109 | ```
110 |
111 |
112 |
113 | ### Citation
114 |
115 | If you find our work useful, please consider citing our paper:
116 | ```
117 | @inproceedings{
118 | hypo2024,
119 | title={Provable Out-of-Distribution Generalization in Hypersphere},
120 | author={Haoyue Bai and Yifei Ming and Julian Katz-Samuels and Yixuan Li},
121 | booktitle={The Twelfth International Conference on Learning Representations (ICLR)},
122 | year={2024},
123 | }
124 | ```
125 |
126 |
127 | ### Further discussions
128 | For more discussions on the method and extensions, feel free to drop an email at hbai39@wisc.edu or ming5@wisc.edu.
129 |
--------------------------------------------------------------------------------
/utils/display_results.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sklearn.metrics as sk
4 | import seaborn as sns
5 | import matplotlib.pyplot as plt
6 |
7 | recall_level_default = 0.95
8 |
9 |
10 | def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
11 |     """Use high precision for cumsum and check that final value matches sum
12 |     Parameters
13 |     ----------
14 |     arr : array-like
15 |         To be cumulatively summed as flat
16 |     rtol : float
17 |         Relative tolerance, see ``np.allclose``
18 |     atol : float
19 |         Absolute tolerance, see ``np.allclose``
20 |     """
21 |     out = np.cumsum(arr, dtype=np.float64)
22 |     expected = np.sum(arr, dtype=np.float64)
23 |     if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
24 |         raise RuntimeError('cumsum was found to be unstable: '
25 |                            'its last element does not correspond to sum')
26 |     return out
27 |
28 |
29 | def fpr_and_fdr_at_recall(y_true, y_score, recall_level=recall_level_default, pos_label=None):
30 |     classes = np.unique(y_true)
31 |     if (pos_label is None and
32 |             not (np.array_equal(classes, [0, 1]) or
33 |                  np.array_equal(classes, [-1, 1]) or
34 |                  np.array_equal(classes, [0]) or
35 |                  np.array_equal(classes, [-1]) or
36 |                  np.array_equal(classes, [1]))):
37 |         raise ValueError("Data is not binary and pos_label is not specified")
38 |     elif pos_label is None:
39 |         pos_label = 1.
40 |
41 |     # make y_true a boolean vector
42 |     y_true = (y_true == pos_label)
43 |
44 |     # sort scores and corresponding truth values
45 |     desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
46 |     y_score = y_score[desc_score_indices]
47 |     y_true = y_true[desc_score_indices]
48 |
49 |     # y_score typically has many tied values. Here we extract
50 |     # the indices associated with the distinct values. We also
51 |     # concatenate a value for the end of the curve.
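# (np.diff is nonzero exactly where the sorted scores change, so the indices
#  computed below mark the last occurrence of each distinct score, i.e. one
#  candidate threshold per unique value.)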
52 | distinct_value_indices = np.where(np.diff(y_score))[0] 53 | threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1] 54 | 55 | # accumulate the true positives with decreasing threshold 56 | tps = stable_cumsum(y_true)[threshold_idxs] 57 | fps = 1 + threshold_idxs - tps # add one because of zero-based indexing 58 | 59 | thresholds = y_score[threshold_idxs] 60 | 61 | recall = tps / tps[-1] 62 | 63 | last_ind = tps.searchsorted(tps[-1]) 64 | sl = slice(last_ind, None, -1) # [last_ind::-1] 65 | recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl] 66 | 67 | cutoff = np.argmin(np.abs(recall - recall_level)) 68 | 69 | return fps[cutoff] / (np.sum(np.logical_not(y_true))) # , fps[cutoff]/(fps[cutoff] + tps[cutoff]) 70 | 71 | 72 | def get_measures(_pos, _neg, recall_level=recall_level_default): 73 | pos = np.array(_pos[:]).reshape((-1, 1)) 74 | neg = np.array(_neg[:]).reshape((-1, 1)) 75 | examples = np.squeeze(np.vstack((pos, neg))) 76 | labels = np.zeros(len(examples), dtype=np.int32) 77 | labels[:len(pos)] += 1 78 | 79 | auroc = sk.roc_auc_score(labels, examples) 80 | aupr = sk.average_precision_score(labels, examples) 81 | fpr = fpr_and_fdr_at_recall(labels, examples, recall_level) 82 | 83 | return auroc, aupr, fpr 84 | 85 | 86 | def show_performance(pos, neg, method_name='Ours', recall_level=recall_level_default): 87 | ''' 88 | :param pos: 1's class, class to detect, outliers, or wrongly predicted 89 | example scores 90 | :param neg: 0's class scores 91 | ''' 92 | 93 | auroc, aupr, fpr = get_measures(pos[:], neg[:], recall_level) 94 | 95 | print(f'\t\t\t {method_name}') 96 | print('FPR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fpr)) 97 | print('AUROC:\t\t\t{:.2f}'.format(100 * auroc)) 98 | print('AUPR:\t\t\t{:.2f}'.format(100 * aupr)) 99 | 100 | def print_measures(log, auroc, aupr, fpr, method_name='Ours', recall_level=recall_level_default): 101 | if log == None: 102 | print('FPR{:d}:\t\t\t{:.2f}'.format(int(100 * recall_level), 100 * fpr)) 103 | print('AUROC: \t\t\t{:.2f}'.format(100 * auroc)) 104 | print('AUPR: \t\t\t{:.2f}'.format(100 * aupr)) 105 | else: 106 | log.debug('\t\t\t\t' + method_name) 107 | log.debug(' FPR{:d} AUROC AUPR'.format(int(100*recall_level))) 108 | log.debug('& {:.2f} & {:.2f} & {:.2f}'.format(100*fpr, 100*auroc, 100*aupr)) 109 | 110 | 111 | def print_measures_with_std(log, aurocs, auprs, fprs, method_name='Ours', recall_level=recall_level_default): 112 | if log: 113 | log.debug('\t\t\t\t' + method_name) 114 | log.debug(' FPR{:d} AUROC AUPR'.format(int(100*recall_level))) 115 | log.debug('& {:.2f} & {:.2f} & {:.2f}'.format(100*np.mean(fprs), 100*np.mean(aurocs), 100*np.mean(auprs))) 116 | log.debug('& {:.2f} & {:.2f} & {:.2f}'.format(100*np.std(fprs), 100*np.std(aurocs), 100*np.std(auprs))) 117 | else: 118 | print('FPR{:d}:\t\t\t{:.2f}\t+/- {:.2f}'.format(int(100 * recall_level), 100 * np.mean(fprs), 100 * np.std(fprs))) 119 | print('AUROC: \t\t\t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(aurocs), 100 * np.std(aurocs))) 120 | print('AUPR: \t\t\t{:.2f}\t+/- {:.2f}'.format(100 * np.mean(auprs), 100 * np.std(auprs))) 121 | 122 | def plot_distribution(args, id_scores, ood_scores, out_dataset): 123 | sns.set(style="white", palette="muted") 124 | palette = ['#A8BAE3', '#55AB83'] 125 | sns.displot({"ID": -1 * id_scores, "OOD": -1 * ood_scores}, label="id", kind = "kde", palette=palette, fill = True, alpha = 0.8) 126 | plt.savefig(os.path.join(args.log_directory,f"KNN_{out_dataset}.png"), 
bbox_inches='tight') 127 | 128 | def save_as_dataframe(args, out_datasets, fpr_list, auroc_list, aupr_list): 129 | fpr_list = [float('{:.2f}'.format(100*fpr)) for fpr in fpr_list] 130 | auroc_list = [float('{:.2f}'.format(100*auroc)) for auroc in auroc_list] 131 | aupr_list = [float('{:.2f}'.format(100*aupr)) for aupr in aupr_list] 132 | import pandas as pd 133 | data = {k:v for k,v in zip(out_datasets, zip(fpr_list,auroc_list,aupr_list))} 134 | data['AVG'] = [np.mean(fpr_list),np.mean(auroc_list),np.mean(aupr_list) ] 135 | data['AVG'] = [float('{:.2f}'.format(metric)) for metric in data['AVG']] 136 | # Specify orient='index' to create the DataFrame using dictionary keys as rows 137 | df = pd.DataFrame.from_dict(data, orient='index', 138 | columns=['FPR95', 'AUROC', 'AUPR']) 139 | df.to_csv(os.path.join(args.log_directory,f'{args.name}.csv')) 140 | 141 | def show_performance_comparison(pos_base, neg_base, pos_ours, neg_ours, baseline_name='Baseline', 142 | method_name='Ours', recall_level=recall_level_default): 143 | ''' 144 | :param pos_base: 1's class, class to detect, outliers, or wrongly predicted 145 | example scores from the baseline 146 | :param neg_base: 0's class scores generated by the baseline 147 | ''' 148 | auroc_base, aupr_base, fpr_base = get_measures(pos_base[:], neg_base[:], recall_level) 149 | auroc_ours, aupr_ours, fpr_ours = get_measures(pos_ours[:], neg_ours[:], recall_level) 150 | 151 | print('\t\t\t' + baseline_name + '\t' + method_name) 152 | print('FPR{:d}:\t\t\t{:.2f}\t\t{:.2f}'.format( 153 | int(100 * recall_level), 100 * fpr_base, 100 * fpr_ours)) 154 | print('AUROC:\t\t\t{:.2f}\t\t{:.2f}'.format( 155 | 100 * auroc_base, 100 * auroc_ours)) 156 | print('AUPR:\t\t\t{:.2f}\t\t{:.2f}'.format( 157 | 100 * aupr_base, 100 * aupr_ours)) -------------------------------------------------------------------------------- /eval_hypo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import time 5 | from datetime import datetime 6 | import logging 7 | # import tensorboard_logger as tb_logger 8 | import pprint 9 | 10 | import torch 11 | import torch.nn.parallel 12 | import torch.nn.functional as F 13 | import torch.optim 14 | import torch.utils.data 15 | import numpy as np 16 | 17 | 18 | from make_datasets_cifar import * 19 | 20 | from sklearn.metrics import accuracy_score 21 | 22 | from utils import (CompLoss, DisLoss, DisLPLoss, set_loader_small, set_loader_ImageNet, set_model) 23 | 24 | parser = argparse.ArgumentParser(description='Eval HYPO') 25 | parser.add_argument('--gpu', default=7, type=int, help='which GPU to use') 26 | parser.add_argument('--seed', default=4, type=int, help='random seed') # original 4 27 | parser.add_argument('--w', default=2, type=float, 28 | help='loss scale') 29 | parser.add_argument('--proto_m', default= 0.99, type=float, 30 | help='weight of prototype update') 31 | parser.add_argument('--feat_dim', default = 128, type=int, 32 | help='feature dim') 33 | parser.add_argument('--in-dataset', default="CIFAR-10", type=str, help='ID dataset name', choices=['PACS', 'CIFAR-10', 'ImageNet-100']) 34 | parser.add_argument('--id_loc', default="datasets", type=str, help='location of ID dataset') 35 | parser.add_argument('--ood_loc', default="datasets", type=str, help='location of OOD dataset') 36 | parser.add_argument('--model', default='resnet18', type=str, help='model architecture: [resnet18, wrt40, wrt28, densenet100, resnet50, resnet34]') 37 | 
-------------------------------------------------------------------------------- /eval_hypo.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 | import time
5 | from datetime import datetime
6 | import logging
7 | # import tensorboard_logger as tb_logger
8 | import pprint
9 | 
10 | import torch
11 | import torch.nn.parallel
12 | import torch.nn.functional as F
13 | import torch.optim
14 | import torch.utils.data
15 | import numpy as np
16 | 
17 | 
18 | from make_datasets_cifar import *
19 | 
20 | from sklearn.metrics import accuracy_score
21 | 
22 | from utils import (CompLoss, DisLoss, DisLPLoss, set_loader_small, set_loader_ImageNet, set_model)
23 | 
24 | parser = argparse.ArgumentParser(description='Eval HYPO')
25 | parser.add_argument('--gpu', default=7, type=int, help='which GPU to use')
26 | parser.add_argument('--seed', default=4, type=int, help='random seed')  # original 4
27 | parser.add_argument('--w', default=2, type=float,
28 |                     help='loss scale')
29 | parser.add_argument('--proto_m', default=0.99, type=float,
30 |                     help='weight of prototype update')
31 | parser.add_argument('--feat_dim', default=128, type=int,
32 |                     help='feature dim')
33 | parser.add_argument('--in-dataset', default="CIFAR-10", type=str, help='ID dataset name', choices=['PACS', 'CIFAR-10', 'ImageNet-100'])
34 | parser.add_argument('--id_loc', default="datasets", type=str, help='location of ID dataset')
35 | parser.add_argument('--ood_loc', default="datasets", type=str, help='location of OOD dataset')
36 | parser.add_argument('--model', default='resnet18', type=str, help='model architecture: [resnet18, wrn40, wrn28, densenet100, resnet50, resnet34]')
37 | parser.add_argument('--head', default='mlp', type=str, help='either mlp or linear head')
38 | parser.add_argument('--loss', default='hypo', type=str, choices=['hypo'],
39 |                     help='training loss the checkpoint was trained with')
40 | 
41 | parser.add_argument('--ckpt_name', type=str, default='ckpt_hypo_resnet18_cifar10',
42 |                     help='name of the model checkpoint')
43 | 
44 | parser.add_argument('--ckpt_loc', type=str, default='checkpoints/CIFAR-10',
45 |                     help='location of the model checkpoint')
46 | parser.add_argument('-b', '--batch_size', default=128, type=int,
47 |                     help='mini-batch size (default: 128)')
48 | parser.add_argument('--print-freq', '-p', default=10, type=int,
49 |                     help='print frequency (default: 10)')
50 | parser.add_argument('--temp', type=float, default=0.1,
51 |                     help='temperature for loss function')
52 | parser.add_argument('--normalize', action='store_true',
53 |                     help='normalize feat embeddings')
54 | parser.add_argument('--prefetch', type=int, default=4, help='pre-fetching threads')
55 | parser.add_argument('--target_domain', type=str, default='sketch', choices=['sketch', 'photo', 'art_painting', 'cartoon'])
56 | parser.add_argument('--cortype', type=str, default='gaussian_noise', help='corruption type for the corrupted test sets')
57 | 
58 | parser.set_defaults(bottleneck=True)
59 | parser.set_defaults(augment=True)
60 | 
61 | args = parser.parse_args()
62 | 
63 | 
64 | state = {k: v for k, v in args._get_kwargs()}
65 | 
66 | date_time = datetime.now().strftime("%d_%m_%H:%M")
67 | 
68 | args.log_directory = "logs/eval/{in_dataset}/{name}/".format(in_dataset=args.in_dataset, name=args.ckpt_name)
69 | if not os.path.exists(args.log_directory):
70 |     os.makedirs(args.log_directory)
71 | 
72 | 
73 | # init log
74 | log = logging.getLogger(__name__)
75 | formatter = logging.Formatter('%(asctime)s : %(message)s')
76 | fileHandler = logging.FileHandler(os.path.join(args.log_directory, "eval_info.log"), mode='w')
77 | fileHandler.setFormatter(formatter)
78 | streamHandler = logging.StreamHandler()
79 | streamHandler.setFormatter(formatter)
80 | log.setLevel(logging.DEBUG)
81 | log.addHandler(fileHandler)
82 | log.addHandler(streamHandler)
83 | 
84 | log.debug(state)
85 | 
86 | if args.in_dataset == "CIFAR-10":
87 |     args.n_cls = 10
88 | elif args.in_dataset == "PACS":
89 |     args.n_cls = 7
90 | elif args.in_dataset == "VLCS":
91 |     args.n_cls = 5
92 | elif args.in_dataset == "OfficeHome":
93 |     args.n_cls = 65
94 | elif args.in_dataset == 'terra_incognita':
95 |     args.n_cls = 10
96 | elif args.in_dataset in ["CIFAR-100", "ImageNet-100"]:
97 |     args.n_cls = 100
98 | 
99 | 
100 | # set seeds
101 | torch.manual_seed(args.seed)
102 | torch.cuda.manual_seed(args.seed)
103 | np.random.seed(args.seed)
104 | log.debug(f"Evaluating {args.ckpt_name}")
105 | 
106 | 
107 | def to_np(x): return x.data.cpu().numpy()
108 | 
109 | 
110 | if args.in_dataset == 'CIFAR-10':
111 |     val_loader, test_loader_ood = make_datasets(args.id_loc, args.ood_loc, args.in_dataset, state, args.cortype)
112 | 
113 | else:
114 |     train_loader, val_loader, test_loader_ood = set_loader_small(args)
115 | 
116 | 
117 | print("\n len(loader_in.dataset) {}, " \
118 |       "len(test_loader_ood.dataset) {}".format(
119 |       len(val_loader.dataset),
120 |       len(test_loader_ood.dataset)))
121 | 
122 | 
123 | def main():
124 | 
125 |     model = set_model(args)
126 | 
127 |     model_name = f'{args.ckpt_loc}/{args.ckpt_name}.pth.tar'
128 |     model.load_state_dict(torch.load(model_name)['state_dict'])
129 | 
130 |     criterion_dis = DisLoss(args, model, val_loader, temperature=args.temp).cuda()  # V2: prototypes with EMA-style update
131 | 
132 |     criterion_dis.load_state_dict(torch.load(model_name)['dis_state_dict'])
133 | 
134 | 
135 |     model.eval()
136 | 
137 | 
138 |     print("computing accuracy on the in-distribution (ID) test set. \n")
139 |     with torch.no_grad():
140 |         accuracies_in = []
141 |         for data, target in val_loader:
142 |             data, target = data.cuda(), target.cuda()
143 | 
144 |             # the penultimate encoder features are not needed here: the
145 |             # prototype-based prediction only uses the normalized head output
146 | 
147 |             features = model(data)
148 | 
149 |             feat_dot_prototype = torch.div(torch.matmul(features, criterion_dis.prototypes.T), args.temp)
150 | 
151 |             # for numerical stability (the shift does not change the argmax)
152 |             logits_max, _ = torch.max(feat_dot_prototype, dim=1, keepdim=True)
153 |             logits = feat_dot_prototype - logits_max.detach()
154 | 
155 |             pred = logits.data.max(1)[1]
156 |             accuracies_in.append(accuracy_score(list(to_np(pred)), list(to_np(target))))
157 | 
158 |     acc = sum(accuracies_in) / len(accuracies_in)  # mean of per-batch accuracies
159 |     print("ID accuracy: {}".format(acc))
160 | 
161 |     print("computing accuracy on the corrupted (OOD) test set. \n")
162 |     with torch.no_grad():
163 |         accuracies_cor = []
164 |         for data, target in test_loader_ood:
165 |             data, target = data.cuda(), target.cuda()
166 | 
167 |             # as above, only the normalized head output is used
168 | 
169 | 
170 |             features = model(data)
171 | 
172 |             feat_dot_prototype = torch.div(torch.matmul(features, criterion_dis.prototypes.T), args.temp)
173 | 
174 |             # for numerical stability (the shift does not change the argmax)
175 |             logits_max, _ = torch.max(feat_dot_prototype, dim=1, keepdim=True)
176 |             logits = feat_dot_prototype - logits_max.detach()
177 | 
178 |             pred = logits.data.max(1)[1]
179 |             accuracies_cor.append(accuracy_score(list(to_np(pred)), list(to_np(target))))
180 | 
181 |     acc_cor = sum(accuracies_cor) / len(accuracies_cor)
182 | 
183 |     if args.in_dataset == 'CIFAR-10':
184 |         print("OOD generalization accuracy: {}, corruption type: {}".format(acc_cor, args.cortype))
185 |     else:
186 |         print("OOD generalization accuracy: {}, target domain: {}".format(acc_cor, args.target_domain))
187 | 
188 | 
189 | 
190 | 
191 | if __name__ == '__main__':
192 |     main()
193 | 
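# Both evaluation loops above apply the same nearest-prototype rule; a
# distilled sketch of that rule (not repository code; tensor shapes are
# assumptions based on the code above).
def prototype_predict(features, prototypes, temp=0.1):
    """features: (B, d) L2-normalized embeddings; prototypes: (C, d) class prototypes."""
    logits = torch.matmul(features, prototypes.T) / temp
    # the per-row max shift used above is for numerical stability only;
    # it does not change which class attains the maximum
    logits = logits - logits.max(dim=1, keepdim=True).values.detach()
    return logits.argmax(dim=1)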
-------------------------------------------------------------------------------- /models/head_wrn_vmf.py: --------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import os
6 | 
7 | #from wrn_vmf import *
8 | 
9 | class BasicBlock(nn.Module):
10 |     def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
11 |         super(BasicBlock, self).__init__()
12 |         self.bn1 = nn.BatchNorm2d(in_planes)
13 |         self.relu1 = nn.ReLU(inplace=True)
14 |         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
15 |                                padding=1, bias=False)
16 |         self.bn2 = nn.BatchNorm2d(out_planes)
17 |         self.relu2 = nn.ReLU(inplace=True)
18 |         self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
19 |                                padding=1, bias=False)
20 |         self.droprate = dropRate
21 |         self.equalInOut = (in_planes == out_planes)
22 |         self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
23 |                                padding=0, bias=False) or None
24 | 
25 |     def forward(self, x):
26 |         if not self.equalInOut:
27 |             x = self.relu1(self.bn1(x))
28 |         else:
29 |             out = self.relu1(self.bn1(x))
30 |         if self.equalInOut:
31 |             out = self.relu2(self.bn2(self.conv1(out)))
32 |         else:
33 |             out = self.relu2(self.bn2(self.conv1(x)))
34 |         if self.droprate > 0:
35 |             out = F.dropout(out, p=self.droprate, training=self.training)
36 |         out = self.conv2(out)
37 |         if not self.equalInOut:
38 |             return torch.add(self.convShortcut(x), out)
39 |         else:
40 |             return torch.add(x, out)
41 | 
42 | 
43 | class NetworkBlock(nn.Module):
44 |     def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
45 |         super(NetworkBlock, self).__init__()
46 |         self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
47 | 
48 |     def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
49 |         layers = []
50 |         for i in range(nb_layers):
51 |             layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))  # only the first block of a stage changes channels/stride
52 |         return nn.Sequential(*layers)
53 | 
54 |     def forward(self, x):
55 |         return self.layer(x)
56 | 
57 | 
58 | class WideResNet(nn.Module):
59 |     def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
60 |         super(WideResNet, self).__init__()
61 |         nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
62 |         assert ((depth - 4) % 6 == 0)
63 |         n = (depth - 4) // 6
64 |         block = BasicBlock
65 |         # 1st conv before any network block
66 |         self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
67 |                                padding=1, bias=False)
68 |         # 1st block
69 |         self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
70 |         # 2nd block
71 |         self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
72 |         # 3rd block
73 |         self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
74 |         # global average pooling and classifier
75 |         self.bn1 = nn.BatchNorm2d(nChannels[3])
76 |         self.relu = nn.ReLU(inplace=True)
77 |         self.fc = nn.Linear(nChannels[3], num_classes)
78 |         self.nChannels = nChannels[3]
79 | 
80 |         for m in self.modules():
81 |             if isinstance(m, nn.Conv2d):
82 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
83 |                 m.weight.data.normal_(0, math.sqrt(2.
/ n)) 84 | elif isinstance(m, nn.BatchNorm2d): 85 | m.weight.data.fill_(1) 86 | m.bias.data.zero_() 87 | elif isinstance(m, nn.Linear): 88 | m.bias.data.zero_() 89 | 90 | def forward(self, x): 91 | #import pdb 92 | #pdb.set_trace() 93 | 94 | out = self.conv1(x) 95 | out = self.block1(out) 96 | out = self.block2(out) 97 | out = self.block3(out) 98 | out = self.relu(self.bn1(out)) 99 | out = F.avg_pool2d(out, 8) 100 | out = out.view(-1, self.nChannels) # 64 x 128 101 | features = out.view(-1, self.nChannels) # 64 x 128 102 | #out = self.fc(out) 103 | #return self.fc(out) 104 | #return out 105 | return self.fc(features), features 106 | 107 | def intermediate_forward(self, x, layer_index): 108 | out = self.conv1(x) 109 | out = self.block1(out) 110 | out = self.block2(out) 111 | out = self.block3(out) 112 | out = self.relu(self.bn1(out)) 113 | out = F.avg_pool2d(out, 8) 114 | out = out.view(-1, self.nChannels) 115 | return out 116 | 117 | def feature_list(self, x): 118 | out_list = [] 119 | out = self.conv1(x) 120 | out = self.block1(out) 121 | out = self.block2(out) 122 | out = self.block3(out) 123 | out = self.relu(self.bn1(out)) 124 | out_list.append(out) 125 | out = F.avg_pool2d(out, 8) 126 | out = out.view(-1, self.nChannels) 127 | return self.fc(out), out_list 128 | 129 | 130 | 131 | 132 | 133 | class HeadWideResNet(nn.Module): 134 | """encoder + head""" 135 | #def __init__(self, args, num_classes): 136 | def __init__(self, args): 137 | super(HeadWideResNet, self).__init__() 138 | #model_fun, dim_in = model_dict[args.model] 139 | #if args.in_dataset == 'ImageNet-100': 140 | # model = models.resnet34(pretrained=True) 141 | # for name, p in model.named_parameters(): 142 | # if not name.startswith('layer4'): 143 | # p.requires_grad = False 144 | 145 | num_classes = 10 146 | layers = 40 147 | widen_factor = 2 148 | droprate = 0.3 149 | 150 | #model = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate) 151 | model = WideResNet(layers, num_classes, widen_factor, droprate) 152 | model_name = '/u/b/a/baihaoyue/private/ood_detection/woods_ood/CIFAR/snapshots/pretrained/cifar10_wrn_pretrained_epoch_99.pt' 153 | if os.path.isfile(model_name): 154 | print('found pretrained model') 155 | #net.load_state_dict(torch.load(model_name)['params']) 156 | model.load_state_dict(torch.load(model_name)) 157 | print('Model restored!') 158 | dim_in = model.fc.in_features 159 | #modules=list(model.children())[:-1] # remove last linear layer 160 | #modules=list(model.children()) 161 | #self.encoder =nn.Sequential(*modules) 162 | self.encoder = model 163 | #else: 164 | # self.encoder = model_fun() 165 | #self.fc = nn.Linear(dim_in, args.n_cls) 166 | self.fc = nn.Linear(dim_in, num_classes) 167 | #self.multiplier = multiplier 168 | if args.head == 'linear': 169 | self.head = nn.Linear(dim_in, args.feat_dim) 170 | elif args.head == 'mlp': 171 | self.head = nn.Sequential( 172 | nn.Linear(dim_in, dim_in), 173 | nn.ReLU(inplace=True), 174 | nn.Linear(dim_in, args.feat_dim) 175 | ) 176 | 177 | def forward(self, x): 178 | #import pdb 179 | #pdb.set_trace() 180 | #feat = self.encoder(x).squeeze() 181 | x, feat = self.encoder(x) 182 | unnorm_features = self.head(feat) 183 | features= F.normalize(unnorm_features, dim=1) 184 | return features 185 | 186 | def intermediate_forward(self, x): 187 | feat = self.encoder(x).squeeze() 188 | return F.normalize(feat, dim=1) 189 | 190 | 191 | 192 | 193 | -------------------------------------------------------------------------------- /scripts/download.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | from torchvision.datasets import MNIST 4 | import xml.etree.ElementTree as ET 5 | from zipfile import ZipFile 6 | import argparse 7 | import tarfile 8 | import shutil 9 | import gdown 10 | import uuid 11 | import json 12 | import os 13 | import urllib 14 | 15 | 16 | # utils ####################################################################### 17 | 18 | def stage_path(data_dir, name): 19 | full_path = os.path.join(data_dir, name) 20 | 21 | if not os.path.exists(full_path): 22 | os.makedirs(full_path) 23 | 24 | return full_path 25 | 26 | 27 | def download_and_extract(url, dst, remove=True): 28 | gdown.download(url, dst, quiet=False) 29 | 30 | if dst.endswith(".tar.gz"): 31 | tar = tarfile.open(dst, "r:gz") 32 | tar.extractall(os.path.dirname(dst)) 33 | tar.close() 34 | 35 | if dst.endswith(".tar"): 36 | tar = tarfile.open(dst, "r:") 37 | tar.extractall(os.path.dirname(dst)) 38 | tar.close() 39 | 40 | if dst.endswith(".zip"): 41 | zf = ZipFile(dst, "r") 42 | zf.extractall(os.path.dirname(dst)) 43 | zf.close() 44 | 45 | if remove: 46 | os.remove(dst) 47 | 48 | 49 | # VLCS ######################################################################## 50 | 51 | # Slower, but builds dataset from the original sources 52 | # 53 | # def download_vlcs(data_dir): 54 | # full_path = stage_path(data_dir, "VLCS") 55 | # 56 | # tmp_path = os.path.join(full_path, "tmp/") 57 | # if not os.path.exists(tmp_path): 58 | # os.makedirs(tmp_path) 59 | # 60 | # with open("domainbed/misc/vlcs_files.txt", "r") as f: 61 | # lines = f.readlines() 62 | # files = [line.strip().split() for line in lines] 63 | # 64 | # download_and_extract("http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar", 65 | # os.path.join(tmp_path, "voc2007_trainval.tar")) 66 | # 67 | # download_and_extract("https://drive.google.com/uc?id=1I8ydxaAQunz9R_qFFdBFtw6rFTUW9goz", 68 | # os.path.join(tmp_path, "caltech101.tar.gz")) 69 | # 70 | # download_and_extract("http://groups.csail.mit.edu/vision/Hcontext/data/sun09_hcontext.tar", 71 | # os.path.join(tmp_path, "sun09_hcontext.tar")) 72 | # 73 | # tar = tarfile.open(os.path.join(tmp_path, "sun09.tar"), "r:") 74 | # tar.extractall(tmp_path) 75 | # tar.close() 76 | # 77 | # for src, dst in files: 78 | # class_folder = os.path.join(data_dir, dst) 79 | # 80 | # if not os.path.exists(class_folder): 81 | # os.makedirs(class_folder) 82 | # 83 | # dst = os.path.join(class_folder, uuid.uuid4().hex + ".jpg") 84 | # 85 | # if "labelme" in src: 86 | # # download labelme from the web 87 | # gdown.download(src, dst, quiet=False) 88 | # else: 89 | # src = os.path.join(tmp_path, src) 90 | # shutil.copyfile(src, dst) 91 | # 92 | # shutil.rmtree(tmp_path) 93 | 94 | 95 | def download_vlcs(data_dir): 96 | # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017 97 | full_path = stage_path(data_dir, "VLCS") 98 | 99 | download_and_extract("https://drive.google.com/uc?id=1skwblH1_okBwxWxmRsp9_qi15hyPpxg8", 100 | os.path.join(data_dir, "VLCS.tar.gz")) 101 | 102 | 103 | # PACS ######################################################################## 104 | 105 | def download_pacs(data_dir): 106 | # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017 107 | full_path = stage_path(data_dir, "PACS") 108 | 109 | download_and_extract("https://drive.google.com/uc?id=1JFr8f805nMUelQWWmfnJR3y4_SYoN5Pd", 110 | os.path.join(data_dir, "PACS.zip")) 111 
| 112 | os.rename(os.path.join(data_dir, "kfold"), 113 | full_path) 114 | 115 | 116 | # Office-Home ################################################################# 117 | 118 | def download_office_home(data_dir): 119 | # Original URL: http://hemanthdv.org/OfficeHome-Dataset/ 120 | full_path = stage_path(data_dir, "office_home") 121 | 122 | download_and_extract("https://drive.google.com/uc?id=1uY0pj7oFsjMxRwaD3Sxy0jgel0fsYXLC", 123 | os.path.join(data_dir, "office_home.zip")) 124 | 125 | os.rename(os.path.join(data_dir, "OfficeHomeDataset_10072016"), 126 | full_path) 127 | 128 | 129 | # DomainNET ################################################################### 130 | 131 | def download_domain_net(data_dir): 132 | # Original URL: http://ai.bu.edu/M3SDA/ 133 | full_path = stage_path(data_dir, "domain_net") 134 | 135 | urls = [ 136 | "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip", 137 | "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip", 138 | "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip", 139 | "http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip", 140 | "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip", 141 | "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip" 142 | ] 143 | 144 | for url in urls: 145 | download_and_extract(url, os.path.join(full_path, url.split("/")[-1])) 146 | 147 | with open("domainbed/misc/domain_net_duplicates.txt", "r") as f: 148 | for line in f.readlines(): 149 | try: 150 | os.remove(os.path.join(full_path, line.strip())) 151 | except OSError: 152 | pass 153 | 154 | 155 | # TerraIncognita ############################################################## 156 | 157 | def download_terra_incognita(data_dir): 158 | # Original URL: https://beerys.github.io/CaltechCameraTraps/ 159 | # New URL: http://lila.science/datasets/caltech-camera-traps 160 | 161 | full_path = stage_path(data_dir, "terra_incognita") 162 | 163 | download_and_extract( 164 | "https://lilablobssc.blob.core.windows.net/caltechcameratraps/eccv_18_all_images_sm.tar.gz", 165 | os.path.join(full_path, "terra_incognita_images.tar.gz")) 166 | 167 | download_and_extract( 168 | "https://lilablobssc.blob.core.windows.net/caltechcameratraps/labels/caltech_camera_traps.json.zip", 169 | os.path.join(full_path, "caltech_camera_traps.json.zip")) 170 | 171 | include_locations = ["38", "46", "100", "43"] 172 | 173 | include_categories = [ 174 | "bird", "bobcat", "cat", "coyote", "dog", "empty", "opossum", "rabbit", 175 | "raccoon", "squirrel" 176 | ] 177 | 178 | images_folder = os.path.join(full_path, "eccv_18_all_images_sm/") 179 | annotations_file = os.path.join(full_path, "caltech_images_20210113.json") 180 | destination_folder = full_path 181 | 182 | stats = {} 183 | 184 | if not os.path.exists(destination_folder): 185 | os.mkdir(destination_folder) 186 | 187 | with open(annotations_file, "r") as f: 188 | data = json.load(f) 189 | 190 | category_dict = {} 191 | for item in data['categories']: 192 | category_dict[item['id']] = item['name'] 193 | 194 | for image in data['images']: 195 | image_location = image['location'] 196 | 197 | if image_location not in include_locations: 198 | continue 199 | 200 | loc_folder = os.path.join(destination_folder, 201 | 'location_' + str(image_location) + '/') 202 | 203 | if not os.path.exists(loc_folder): 204 | os.mkdir(loc_folder) 205 | 206 | image_id = image['id'] 207 | image_fname = image['file_name'] 208 | 209 | for annotation in data['annotations']: 210 | if annotation['image_id'] == image_id: 211 | 
if image_location not in stats:
212 |                     stats[image_location] = {}
213 | 
214 |                 category = category_dict[annotation['category_id']]
215 | 
216 |                 if category not in include_categories:
217 |                     continue
218 | 
219 |                 if category not in stats[image_location]:
220 |                     stats[image_location][category] = 1  # count the first occurrence as well
221 |                 else:
222 |                     stats[image_location][category] += 1
223 | 
224 |                 loc_cat_folder = os.path.join(loc_folder, category + '/')
225 | 
226 |                 if not os.path.exists(loc_cat_folder):
227 |                     os.mkdir(loc_cat_folder)
228 | 
229 |                 dst_path = os.path.join(loc_cat_folder, image_fname)
230 |                 src_path = os.path.join(images_folder, image_fname)
231 | 
232 |                 shutil.copyfile(src_path, dst_path)
233 | 
234 |     shutil.rmtree(images_folder)
235 |     os.remove(annotations_file)
236 | 
237 | 
238 | 
239 | 
240 | if __name__ == "__main__":
241 |     parser = argparse.ArgumentParser(description='Download datasets')
242 |     parser.add_argument('--data_dir', type=str, required=True)
243 |     args = parser.parse_args()
244 | 
245 |     # download_pacs(args.data_dir)
246 |     # download_office_home(args.data_dir)
247 |     # download_domain_net(args.data_dir)
248 |     # download_vlcs(args.data_dir)
249 |     download_terra_incognita(args.data_dir)
250 | 
251 | 
-------------------------------------------------------------------------------- /models/resnet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision.models as models
5 | 
6 | class BasicBlock(nn.Module):
7 |     expansion = 1
8 | 
9 |     def __init__(self, in_planes, planes, stride=1, is_last=False):
10 |         super(BasicBlock, self).__init__()
11 |         self.is_last = is_last
12 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
13 |         self.bn1 = nn.BatchNorm2d(planes)
14 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
15 |         self.bn2 = nn.BatchNorm2d(planes)
16 | 
17 |         self.shortcut = nn.Sequential()
18 |         if stride != 1 or in_planes != self.expansion * planes:
19 |             self.shortcut = nn.Sequential(
20 |                 nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
21 |                 nn.BatchNorm2d(self.expansion * planes)
22 |             )
23 | 
24 |     def forward(self, x):
25 |         out = F.relu(self.bn1(self.conv1(x)))
26 |         out = self.bn2(self.conv2(out))
27 |         out += self.shortcut(x)
28 |         preact = out
29 |         out = F.relu(out)
30 |         if self.is_last:
31 |             return out, preact
32 |         else:
33 |             return out
34 | 
35 | 
36 | class Bottleneck(nn.Module):
37 |     expansion = 4
38 | 
39 |     def __init__(self, in_planes, planes, stride=1, is_last=False):
40 |         super(Bottleneck, self).__init__()
41 |         self.is_last = is_last
42 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
43 |         self.bn1 = nn.BatchNorm2d(planes)
44 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
45 |         self.bn2 = nn.BatchNorm2d(planes)
46 |         self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
47 |         self.bn3 = nn.BatchNorm2d(self.expansion * planes)
48 | 
49 |         self.shortcut = nn.Sequential()
50 |         if stride != 1 or in_planes != self.expansion * planes:
51 |             self.shortcut = nn.Sequential(
52 |                 nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
53 |                 nn.BatchNorm2d(self.expansion * planes)
54 |             )
55 | 
56 |     def forward(self, x):
57 |         out = F.relu(self.bn1(self.conv1(x)))
58 |         out = F.relu(self.bn2(self.conv2(out)))
59 |         out = self.bn3(self.conv3(out))
60 |         out += self.shortcut(x)
61 |         preact = out
62 |         out = F.relu(out)
63 |         if self.is_last:
64 |             return out, preact
65 |         else:
66 |             return out
67 | 
68 | 
69 | class ResNet(nn.Module):
70 |     def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
71 |         super(ResNet, self).__init__()
72 |         self.in_planes = 64
73 | 
74 |         self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
75 |                                bias=False)
76 |         self.bn1 = nn.BatchNorm2d(64)
77 |         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
78 |         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
79 |         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
80 |         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
81 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
82 | 
83 |         for m in self.modules():
84 |             if isinstance(m, nn.Conv2d):
85 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
86 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
87 |                 nn.init.constant_(m.weight, 1)
88 |                 nn.init.constant_(m.bias, 0)
89 | 
90 |         # Zero-initialize the last BN in each residual branch,
91 |         # so that the residual branch starts with zeros, and each residual block behaves
92 |         # like an identity. This improves the model by 0.2~0.3% according to:
93 |         # https://arxiv.org/abs/1706.02677
94 |         if zero_init_residual:
95 |             for m in self.modules():
96 |                 if isinstance(m, Bottleneck):
97 |                     nn.init.constant_(m.bn3.weight, 0)
98 |                 elif isinstance(m, BasicBlock):
99 |                     nn.init.constant_(m.bn2.weight, 0)
100 | 
101 |     def _make_layer(self, block, planes, num_blocks, stride):
102 |         strides = [stride] + [1] * (num_blocks - 1)
103 |         layers = []
104 |         for i in range(num_blocks):
105 |             stride = strides[i]
106 |             layers.append(block(self.in_planes, planes, stride))
107 |             self.in_planes = planes * block.expansion
108 |         return nn.Sequential(*layers)
109 | 
110 |     def forward(self, x, layer=100):  # `layer` is unused; kept for interface compatibility
111 |         out = F.relu(self.bn1(self.conv1(x)))
112 |         out = self.layer1(out)
113 |         out = self.layer2(out)
114 |         out = self.layer3(out)
115 |         out = self.layer4(out)
116 |         out = self.avgpool(out)
117 |         # out dim is now: batch_size, 512, 1, 1
118 |         out = torch.flatten(out, 1)  # start_dim = 1
119 |         return out
120 | 
121 |     # function to extract the feature map of the final residual stage
122 |     def intermediate_forward(self, x):
123 |         out = F.relu(self.bn1(self.conv1(x)))
124 |         out = self.layer1(out)
125 |         out = self.layer2(out)
126 |         out = self.layer3(out)
127 |         out = self.layer4(out)
128 |         return out
129 | 
130 |     # function to extract a list of intermediate feature maps (currently only the final stage)
131 |     def feature_list(self, x):
132 |         out_list = []
133 |         out = F.relu(self.bn1(self.conv1(x)))
134 |         out = self.layer1(out)
135 |         out = self.layer2(out)
136 |         out = self.layer3(out)
137 |         out = self.layer4(out)
138 |         out_list.append(out)
139 |         return out_list
140 | 
141 | 
142 | def resnet18(**kwargs):
143 |     return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
144 | 
145 | 
146 | def resnet34(**kwargs):
147 |     return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
148 | 
149 | 
150 | def resnet50(**kwargs):
151 |     return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
152 | 
153 | 
154 | def resnet101(**kwargs):
155 |     return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
156 | 
157 | 
158 | model_dict = {
159 |     'resnet18': [resnet18, 512],
160 |     'resnet34': [resnet34, 512],
161 |     'resnet50': [resnet50, 2048],
162 |     'resnet101': [resnet101, 2048],
163 | }
164 | 
165 | 
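# A quick usage sketch (not repository code): model_dict maps an architecture
# name to (constructor, penultimate feature dimension); the input size below
# is an assumption (CIFAR-style 32x32 images).
def _demo_model_dict():
    model_fun, dim_in = model_dict['resnet18']   # dim_in == 512
    encoder = model_fun()
    feats = encoder(torch.randn(4, 3, 32, 32))   # pooled and flattened features
    assert feats.shape == (4, dim_in)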
166 | class LinearBatchNorm(nn.Module):
167 |     """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose"""
168 |     def __init__(self, dim, affine=True):
169 |         super(LinearBatchNorm, self).__init__()
170 |         self.dim = dim
171 |         self.bn = nn.BatchNorm2d(dim, affine=affine)
172 | 
173 |     def forward(self, x):
174 |         x = x.view(-1, self.dim, 1, 1)
175 |         x = self.bn(x)
176 |         x = x.view(-1, self.dim)
177 |         return x
178 | 
179 | class SupCEResNet(nn.Module):
180 |     """encoder + classifier"""
181 |     def __init__(self, name='resnet18', normalize=False, num_classes=10):
182 |         super(SupCEResNet, self).__init__()
183 |         model_fun, dim_in = model_dict[name]
184 |         self.encoder = model_fun()
185 |         self.fc = nn.Linear(dim_in, num_classes)
186 |         self.normalize = normalize
187 | 
188 |     def forward(self, x):
189 |         features = self.encoder(x)
190 |         if self.normalize:
191 |             features = F.normalize(features, dim=1)
192 |         return self.fc(features)
193 | 
194 | class SupCEHeadResNet(nn.Module):
195 |     """encoder + head"""
196 |     def __init__(self, args, multiplier=1):
197 |         super(SupCEHeadResNet, self).__init__()
198 |         model_fun, dim_in = model_dict[args.model]
199 | 
200 |         if args.model == 'resnet50':
201 |             model = models.resnet50(pretrained=True)
202 |             #for name, p in model.named_parameters():
203 |             #    if not name.startswith('layer4'):
204 |             #        p.requires_grad = False
205 |             for module in model.modules():  # freeze BatchNorm params and statistics of the pretrained backbone
206 |                 if isinstance(module, nn.BatchNorm2d):
207 |                     if hasattr(module, 'weight'):
208 |                         module.weight.requires_grad_(False)
209 |                     if hasattr(module, 'bias'):
210 |                         module.bias.requires_grad_(False)
211 |                     module.eval()
212 |             modules = list(model.children())[:-1]  # remove last linear layer
213 |             self.encoder = nn.Sequential(*modules)
214 |         elif args.model == 'resnet34':
215 |             model = models.resnet34(pretrained=True)
216 |             for name, p in model.named_parameters():  # fine-tune only layer4
217 |                 if not name.startswith('layer4'):
218 |                     p.requires_grad = False
219 |             modules = list(model.children())[:-1]  # remove last linear layer
220 |             self.encoder = nn.Sequential(*modules)
221 | 
222 |         else:
223 |             self.encoder = model_fun()
224 | 
225 |         self.fc = nn.Linear(dim_in, args.n_cls)
226 |         self.multiplier = multiplier
227 |         self.dropout = nn.Dropout(0.5)
228 | 
229 |         if args.head == 'linear':
230 |             self.head = nn.Sequential(nn.Linear(dim_in, args.feat_dim))
231 |         elif args.head == 'mlp':
232 |             self.head = nn.Sequential(
233 |                 nn.Linear(dim_in, dim_in),
234 |                 nn.ReLU(inplace=True),
235 |                 nn.Linear(dim_in, args.feat_dim)
236 |             )
237 | 
238 | 
239 |     def forward(self, x):
240 | 
241 |         feat = self.encoder(x).squeeze()
242 |         unnorm_features = self.head(feat)
243 |         features = F.normalize(unnorm_features, dim=1)
244 |         return features
245 | 
246 |     def intermediate_forward(self, x):
247 |         feat = self.encoder(x).squeeze()
248 |         return F.normalize(feat, dim=1)
249 | 
250 | class LinearClassifier(nn.Module):
251 |     """Linear classifier"""
252 |     def __init__(self, name='resnet18', num_classes=10):
253 |         super(LinearClassifier, self).__init__()
254 |         _, feat_dim = model_dict[name]
255 |         self.fc = nn.Linear(feat_dim, num_classes)
256 | 
257 |     def forward(self, features):
258 |         return self.fc(features)
-------------------------------------------------------------------------------- /utils/losses.py: --------------------------------------------------------------------------------
1 | 
2 | """
3 | Adapted from SupCon: https://github.com/HobbitLong/SupContrast/
4 | """
5 | from __future__ import print_function
6 | 
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import time
11 | 
12 | def binarize(T, nb_classes):
13 |     T = T.cpu().numpy()
14 |     import sklearn.preprocessing
15 |     T = sklearn.preprocessing.label_binarize(
16 |         T, classes=range(0, nb_classes)
17 |     )
18 |     T = torch.FloatTensor(T).cuda()
19 |     return T
20 | 
21 | def l2_norm(input):
22 |     input_size = input.size()
23 |     buffer = torch.pow(input, 2)
24 |     normp = torch.sum(buffer, 1).add_(1e-12)
25 |     norm = torch.sqrt(normp)
26 |     _output = torch.div(input, norm.view(-1, 1).expand_as(input))
27 |     output = _output.view(input_size)
28 |     return output
29 | 
30 | 
31 | class Proxy_Anchor(torch.nn.Module):
32 |     def __init__(self, nb_classes, sz_embed, mrg=0.1, alpha=32):
33 |         torch.nn.Module.__init__(self)
34 |         # Proxy Anchor initialization
35 |         self.proxies = torch.nn.Parameter(torch.randn(nb_classes, sz_embed).cuda())
36 |         nn.init.kaiming_normal_(self.proxies, mode='fan_out')
37 | 
38 |         self.nb_classes = nb_classes
39 |         self.sz_embed = sz_embed
40 |         self.mrg = mrg
41 |         self.alpha = alpha
42 | 
43 |     def forward(self, X, T):
44 |         P = self.proxies
45 | 
46 |         cos = F.linear(l2_norm(X), l2_norm(P))  # calculate cosine similarity
47 |         P_one_hot = binarize(T=T, nb_classes=self.nb_classes)
48 |         N_one_hot = 1 - P_one_hot
49 | 
50 |         pos_exp = torch.exp(-self.alpha * (cos - self.mrg))
51 |         neg_exp = torch.exp(self.alpha * (cos + self.mrg))
52 | 
53 |         with_pos_proxies = torch.nonzero(P_one_hot.sum(dim=0) != 0).squeeze(dim=1)  # the set of positive proxies of data in the batch
54 |         num_valid_proxies = len(with_pos_proxies)  # the number of positive proxies
55 | 
56 |         P_sim_sum = torch.where(P_one_hot == 1, pos_exp, torch.zeros_like(pos_exp)).sum(dim=0)
57 |         N_sim_sum = torch.where(N_one_hot == 1, neg_exp, torch.zeros_like(neg_exp)).sum(dim=0)
58 | 
59 |         pos_term = torch.log(1 + P_sim_sum).sum() / num_valid_proxies
60 |         neg_term = torch.log(1 + N_sim_sum).sum() / self.nb_classes
61 |         loss = pos_term + neg_term
62 | 
63 |         return loss
64 | 
65 | 
66 | class CompLoss(nn.Module):
67 |     '''
68 |     Compactness Loss with class-conditional prototypes
69 |     '''
70 |     def __init__(self, args, temperature=0.07, base_temperature=0.07, use_domain=False):
71 |         super(CompLoss, self).__init__()
72 |         self.args = args
73 |         self.temperature = temperature
74 |         self.base_temperature = base_temperature
75 |         self.use_domain = use_domain
76 | 
77 |     def forward(self, features, prototypes, labels, domains):
78 | 
79 |         prototypes = F.normalize(prototypes, dim=1)
80 |         proxy_labels = torch.arange(0, self.args.n_cls).cuda()
81 |         labels = labels.contiguous().view(-1, 1)
82 |         mask = torch.eq(labels, proxy_labels.T).float().cuda()  # (bz, cls)
83 | 
84 |         # compute logits
85 |         feat_dot_prototype = torch.div(
86 |             torch.matmul(features, prototypes.T),
87 |             self.temperature)
88 |         if self.use_domain:
89 |             domains = domains.contiguous().view(-1, 1)
90 |             label_mask = torch.eq(labels, labels.T).float().cuda()  # (bz, bz)
91 |             neg_label_mask = 1 - label_mask
92 |             domain_mask = torch.eq(domains, domains.T).float().cuda()  # (bz, bz)
93 |             feat_dot_feat = torch.div(
94 |                 torch.matmul(features, features.T),
95 |                 self.temperature)
96 | 
97 |             # for numerical stability: shift by the max over both logit blocks
98 |             logits_max, _ = torch.max(feat_dot_prototype, dim=1, keepdim=True)
99 |             feat_logits_max, _ = torch.max(feat_dot_feat, dim=1, keepdim=True)
100 |             logits_max = torch.max(logits_max, feat_logits_max)
101 | 
102 |             prot_logits = feat_dot_prototype - logits_max.detach()
103 |             feat_logits = feat_dot_feat - logits_max.detach()
104 | 
105 |             exp_prot_logits = torch.exp(prot_logits)
106 |             exp_feat_logits = torch.exp(feat_logits)
107 | 
108 |             pos_part = (prot_logits * mask).sum(1, keepdim=True)  # (bz, 1)
109 | 
110 |             prot_neg_pairs = 
exp_prot_logits.sum(1, keepdim=True) 111 | same_domain_neg_pairs = (neg_label_mask * domain_mask * exp_feat_logits).sum(1, keepdim=True) 112 | neg_part = torch.log(prot_neg_pairs + same_domain_neg_pairs) # (bz, 1) 113 | 114 | loss = - (self.temperature / self.base_temperature) *(pos_part - neg_part).mean() 115 | 116 | else: 117 | # for numerical stability 118 | logits_max, _ = torch.max(feat_dot_prototype, dim=1, keepdim=True) 119 | logits = feat_dot_prototype - logits_max.detach() 120 | 121 | # compute log_prob 122 | exp_logits = torch.exp(logits) 123 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 124 | 125 | # compute mean of log-likelihood over positive 126 | mean_log_prob_pos = (mask * log_prob).sum(1) 127 | 128 | # loss 129 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos.mean() 130 | return loss 131 | 132 | 133 | class CompNGLoss(nn.Module): 134 | ''' 135 | Compactness Loss with class-conditional prototypes (without negative pairs) 136 | ''' 137 | def __init__(self, args, temperature=0.1, base_temperature=0.1): 138 | super(CompNGLoss, self).__init__() 139 | self.args = args 140 | self.temperature = temperature 141 | self.base_temperature = base_temperature 142 | 143 | def forward(self, features, prototypes, labels): 144 | prototypes = F.normalize(prototypes, dim=1) 145 | proxy_labels = torch.arange(0, self.args.n_cls).cuda() 146 | labels = labels.contiguous().view(-1, 1) 147 | mask = torch.eq(labels, proxy_labels.T).float().cuda() #bz, cls 148 | # compute logits 149 | feat_dot_prototype = torch.div( 150 | torch.matmul(features, prototypes.T), 151 | self.temperature) 152 | 153 | # compute mean of log-likelihood over positive 154 | mean_log_prob_pos = (mask * feat_dot_prototype).sum(1) 155 | # loss 156 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos.mean() 157 | return loss 158 | 159 | class DisLPLoss(nn.Module): 160 | ''' 161 | Dispersion Loss with learnable prototypes 162 | ''' 163 | def __init__(self, args, model, loader, temperature= 0.1, base_temperature=0.1): 164 | super(DisLPLoss, self).__init__() 165 | self.args = args 166 | self.temperature = temperature 167 | self.base_temperature = base_temperature 168 | self.model = model 169 | self.loader = loader 170 | self.init_class_prototypes() 171 | 172 | def compute(self): 173 | num_cls = self.args.n_cls 174 | # l2-normalize the prototypes if not normalized 175 | prototypes = F.normalize(self.prototypes, dim=1) 176 | 177 | labels = torch.arange(0, num_cls).cuda() 178 | labels = labels.contiguous().view(-1, 1) 179 | 180 | mask = (1- torch.eq(labels, labels.T).float()).cuda() 181 | 182 | logits = torch.div( 183 | torch.matmul(prototypes, prototypes.T), 184 | self.temperature) 185 | 186 | mean_prob_neg = torch.log((mask * torch.exp(logits)).sum(1) / mask.sum(1)) 187 | mean_prob_neg = mean_prob_neg[~torch.isnan(mean_prob_neg)] 188 | # loss 189 | loss = self.temperature / self.base_temperature * mean_prob_neg.mean() 190 | 191 | return loss 192 | 193 | def init_class_prototypes(self): 194 | """Initialize class prototypes""" 195 | self.model.eval() 196 | start = time.time() 197 | prototype_counts = [0]*self.args.n_cls 198 | with torch.no_grad(): 199 | prototypes = torch.zeros(self.args.n_cls,self.args.feat_dim).cuda() 200 | #for input, target in self.loader: 201 | for i, (input, target, domain) in enumerate(self.loader): 202 | input, target = input.cuda(), target.cuda() 203 | features = self.model(input) # extract normalized features 204 | for j, feature in 
enumerate(features):
205 |                     prototypes[target[j].item()] += feature
206 |                     prototype_counts[target[j].item()] += 1
207 |             for cls in range(self.args.n_cls):
208 |                 prototypes[cls] /= prototype_counts[cls]
209 |             # measure elapsed time
210 |             duration = time.time() - start
211 |             print(f'Time to initialize prototypes: {duration:.3f}')
212 |             prototypes = F.normalize(prototypes, dim=1)
213 |             self.prototypes = torch.nn.Parameter(prototypes)
214 | 
215 | class DisLoss(nn.Module):
216 |     '''
217 |     Dispersion Loss with EMA prototypes
218 |     '''
219 |     def __init__(self, args, model, loader, temperature=0.1, base_temperature=0.1):
220 |         super(DisLoss, self).__init__()
221 |         self.args = args
222 |         self.temperature = temperature
223 |         self.base_temperature = base_temperature
224 |         self.register_buffer("prototypes", torch.zeros(self.args.n_cls, self.args.feat_dim))
225 |         self.model = model
226 |         self.loader = loader
227 |         self.init_class_prototypes()
228 | 
229 |     def forward(self, features, labels):
230 | 
231 |         prototypes = self.prototypes
232 |         num_cls = self.args.n_cls
233 |         for j in range(len(features)):  # EMA update of each class prototype seen in the batch
234 |             prototypes[labels[j].item()] = F.normalize(prototypes[labels[j].item()] * self.args.proto_m + features[j] * (1 - self.args.proto_m), dim=0)
235 |         self.prototypes = prototypes.detach()
236 |         labels = torch.arange(0, num_cls).cuda()
237 |         labels = labels.contiguous().view(-1, 1)
238 | 
239 |         mask = (1 - torch.eq(labels, labels.T).float()).cuda()
240 | 
241 |         logits = torch.div(
242 |             torch.matmul(prototypes, prototypes.T),
243 |             self.temperature)
244 | 
245 |         logits_mask = torch.scatter(  # zero out the diagonal (prototype self-similarity)
246 |             torch.ones_like(mask),
247 |             1,
248 |             torch.arange(num_cls).view(-1, 1).cuda(),
249 |             0
250 |         )
251 |         mask = mask * logits_mask
252 |         mean_prob_neg = torch.log((mask * torch.exp(logits)).sum(1) / mask.sum(1))
253 |         mean_prob_neg = mean_prob_neg[~torch.isnan(mean_prob_neg)]
254 |         loss = self.temperature / self.base_temperature * mean_prob_neg.mean()
255 |         return loss
256 | 
257 |     def init_class_prototypes(self):
258 |         """Initialize class prototypes"""
259 |         self.model.eval()
260 |         start = time.time()
261 |         prototype_counts = [0] * self.args.n_cls
262 |         with torch.no_grad():
263 |             prototypes = torch.zeros(self.args.n_cls, self.args.feat_dim).cuda()
264 |             for i, values in enumerate(self.loader):
265 |                 if len(values) == 3:
266 |                     input, target, domain = values
267 |                 elif len(values) == 2:
268 |                     input, target = values
269 |                     domain = None
270 |                 input, target = input.cuda(), target.cuda()
271 |                 features = self.model(input)
272 |                 for j, feature in enumerate(features):
273 |                     prototypes[target[j].item()] += feature
274 |                     prototype_counts[target[j].item()] += 1
275 |             for cls in range(self.args.n_cls):
276 |                 prototypes[cls] /= prototype_counts[cls]
277 |             # measure elapsed time
278 |             duration = time.time() - start
279 |             print(f'Time to initialize prototypes: {duration:.3f}')
280 |             prototypes = F.normalize(prototypes, dim=1)
281 |             self.prototypes = prototypes
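# A minimal sketch (not from the repository) of how the two losses above
# combine into the HYPO training objective; it mirrors train_hypo.py below.
# `model` is assumed to map images to L2-normalized (B, feat_dim) features.
def hypo_objective(args, model, images, labels, criterion_comp, criterion_dis):
    features = model(images)
    dis_loss = criterion_dis(features, labels)  # spread class prototypes apart (EMA update happens inside)
    comp_loss = criterion_comp(features, criterion_dis.prototypes, labels, None)  # pull features toward their class prototype
    return args.w * comp_loss + dis_loss        # weighted sum, as in train_hypo.py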
-------------------------------------------------------------------------------- /train_hypo.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 | import time
5 | from datetime import datetime
6 | import logging
7 | import tensorboard_logger as tb_logger
8 | import pprint
9 | 
10 | import torch
11 | import torch.nn.parallel
12 | import torch.nn.functional as F
13 | import torch.optim
14 | import torch.utils.data
15 | import numpy as np
16 | 
17 | 
18 | import wandb
19 | wandb.login()
20 | 
21 | from sklearn.metrics import accuracy_score
22 | 
23 | from utils import (CompLoss, CompNGLoss, DisLoss, DisLPLoss,
24 |                    AverageMeter, adjust_learning_rate, warmup_learning_rate,
25 |                    set_loader_small, set_loader_ImageNet, set_model)
26 | 
27 | parser = argparse.ArgumentParser(description='Script for training with HYPO')
28 | parser.add_argument('--gpu', default=6, type=int, help='which GPU to use')
29 | parser.add_argument('--seed', default=4, type=int, help='random seed')  # original 4
30 | parser.add_argument('--w', default=2, type=float,
31 |                     help='loss scale')
32 | parser.add_argument('--proto_m', default=0.95, type=float,
33 |                     help='weight of prototype update')
34 | parser.add_argument('--feat_dim', default=128, type=int,
35 |                     help='feature dim')
36 | parser.add_argument('--in-dataset', default="CIFAR-10", type=str, help='in-distribution dataset', choices=['PACS', 'VLCS', 'CIFAR-10', 'CIFAR-100', 'ImageNet-100', 'OfficeHome', 'terra_incognita'])
37 | parser.add_argument('--id_loc', default="datasets/CIFAR10", type=str, help='location of in-distribution dataset')
38 | parser.add_argument('--model', default='resnet18', type=str, help='model architecture: [resnet18, wrn40, wrn28, wrn16, densenet100, resnet50, resnet34]')
39 | parser.add_argument('--head', default='mlp', type=str, help='either mlp or linear head')
40 | parser.add_argument('--loss', default='hypo', type=str, choices=['hypo', 'erm'],
41 |                     help='training objective')
42 | parser.add_argument('--epochs', default=50, type=int,
43 |                     help='number of total epochs to run')
44 | parser.add_argument('--trial', type=str, default='0',
45 |                     help='id for recording multiple runs')
46 | parser.add_argument('--save-epoch', default=100, type=int,
47 |                     help='save the model every save_epoch')
48 | parser.add_argument('--start-epoch', default=0, type=int,
49 |                     help='manual epoch number (useful on restarts)')
50 | parser.add_argument('-b', '--batch_size', default=32, type=int,  # 512 for CIFAR-10 (see scripts/train_hypo_cifar10.sh)
51 |                     help='mini-batch size (default: 32)')
52 | parser.add_argument('--learning_rate', default=5e-4, type=float,
53 |                     help='initial learning rate')
54 | # if linear lr schedule
55 | parser.add_argument('--lr_decay_epochs', type=str, default='100,150,180',
56 |                     help='where to decay lr, can be a list')
57 | parser.add_argument('--lr_decay_rate', type=float, default=0.1,
58 |                     help='decay rate for learning rate')
59 | # if cosine lr schedule
60 | parser.add_argument('--cosine', action='store_true',
61 |                     help='using cosine annealing')
62 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
63 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
64 |                     help='weight decay (default: 0.0001)')
65 | parser.add_argument('--print-freq', '-p', default=10, type=int,
66 |                     help='print frequency (default: 10)')
67 | parser.add_argument('--temp', type=float, default=0.1,
68 |                     help='temperature for loss function')
69 | parser.add_argument('--warm', action='store_true',
70 |                     help='warm-up for large batch training')
71 | parser.add_argument('--normalize', action='store_true',
72 |                     help='normalize feat embeddings')
73 | 
74 | # PACS-specific options
75 | parser.add_argument('--prefetch', type=int, default=4, help='pre-fetching threads')
76 | parser.add_argument('--target_domain', type=str, default='cartoon')
77 | 
78 | # debug
79 | parser.add_argument('--use_domain', action='store_true', help='use in-domain negative pairs in the compactness loss')  # store_true: `type=bool` would parse any non-empty string, including "False", as True
80 | parser.add_argument('--mode', default='online', choices=['online', 'disabled'], help='whether to disable wandb logging')
81 | 
82 | parser.set_defaults(bottleneck=True)
83 | parser.set_defaults(augment=True)
84 | 
85 | args = parser.parse_args()
86 | torch.cuda.set_device(args.gpu)
87 | state = {k: v for k, v in args._get_kwargs()}
88 | 
89 | date_time = datetime.now().strftime("%d_%m_%H:%M")
90 | 
91 | # process the comma-separated string into a list for linear lr scheduling
92 | args.lr_decay_epochs = [int(step) for step in args.lr_decay_epochs.split(',')]
93 | 
94 | 
95 | if args.in_dataset == "ImageNet-100" or args.in_dataset == 'CIFAR-10':
96 |     args.name = (f"{date_time}_{args.loss}_{args.model}_lr_{args.learning_rate}_cosine_"
97 |                  f"{args.cosine}_bsz_{args.batch_size}_head_{args.head}_wd_{args.w}_{args.epochs}_{args.feat_dim}_"
98 |                  f"trial_{args.trial}_temp_{args.temp}_{args.in_dataset}_pm_{args.proto_m}")
99 | 
100 | else:
101 |     args.name = (f"{date_time}_{args.loss}_std_{args.model}_lr_{args.learning_rate}_cosine_"
102 |                  f"{args.cosine}_bsz_{args.batch_size}_td_{args.target_domain}_head_{args.head}_wd_{args.w}_{args.epochs}_{args.feat_dim}_"
103 |                  f"trial_{args.trial}_temp_{args.temp}_{args.in_dataset}_pm_{args.proto_m}")
104 | 
105 | args.log_directory = "logs/{in_dataset}/{name}/".format(in_dataset=args.in_dataset, name=args.name)
106 | 
107 | args.model_directory = "/nobackup2/yf/checkpoints/hypo_cr/{in_dataset}/{name}/".format(in_dataset=args.in_dataset, name=args.name)
108 | 
109 | 
110 | args.tb_path = './save/hypo/{}_tensorboard'.format(args.in_dataset)
111 | if not os.path.exists(args.model_directory):
112 |     os.makedirs(args.model_directory)
113 | if not os.path.exists(args.log_directory):
114 |     os.makedirs(args.log_directory)
115 | args.tb_folder = os.path.join(args.tb_path, args.name)
116 | if not os.path.isdir(args.tb_folder):
117 |     os.makedirs(args.tb_folder)
118 | 
119 | # save args
120 | with open(os.path.join(args.log_directory, 'train_args.txt'), 'w') as f:
121 |     f.write(pprint.pformat(state))
122 | 
123 | # init log
124 | log = logging.getLogger(__name__)
125 | formatter = logging.Formatter('%(asctime)s : %(message)s')
126 | fileHandler = logging.FileHandler(os.path.join(args.log_directory, "train_info.log"), mode='w')
127 | fileHandler.setFormatter(formatter)
128 | streamHandler = logging.StreamHandler()
129 | streamHandler.setFormatter(formatter)
130 | log.setLevel(logging.DEBUG)
131 | log.addHandler(fileHandler)
132 | log.addHandler(streamHandler)
133 | 
134 | log.debug(state)
135 | 
136 | if args.in_dataset == "CIFAR-10":
137 |     args.n_cls = 10
138 | elif args.in_dataset == "PACS":
139 |     args.n_cls = 7
140 | elif args.in_dataset == "VLCS":
141 |     args.n_cls = 5
142 | elif args.in_dataset == "OfficeHome":
143 |     args.n_cls = 65
144 | elif args.in_dataset == 'terra_incognita':
145 |     args.n_cls = 10
146 | elif args.in_dataset in ["CIFAR-100", "ImageNet-100"]:
147 |     args.n_cls = 100
148 | 
149 | # set seeds
150 | torch.manual_seed(args.seed)
151 | torch.cuda.manual_seed(args.seed)
152 | np.random.seed(args.seed)
153 | log.debug(f"{args.name}")
154 | 
155 | # warm-up for large-batch training
156 | if args.batch_size > 256:
157 |     args.warm = True
158 | if args.warm:
159 |     args.warmup_from = 0.001
160 |     args.warm_epochs = 10
161 |     if args.cosine:
162 |         eta_min = args.learning_rate * (args.lr_decay_rate ** 3)
163 |         args.warmup_to = eta_min + (args.learning_rate - eta_min) * (
164 |                 1 + math.cos(math.pi * args.warm_epochs / args.epochs)) / 2
165 |     else:
166 |         args.warmup_to = args.learning_rate
167 | 
168 | def to_np(x): return x.data.cpu().numpy()
169 | 
170 | def main():
171 |     tb_log = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)
172 | 
173 |     wandb.init(
174 |         # Set the project where this run will be logged
175 |         project="hypo",
176 |         # We pass a run name (otherwise it'll be randomly assigned, like sunshine-lollypop-10)
177 |         name=args.name,
178 |         # Track hyperparameters and run metadata
179 |         mode=args.mode,
180 |         config=args)
181 | 
182 |     if args.in_dataset == "ImageNet-100":
183 |         train_loader, val_loader, test_loader = set_loader_ImageNet(args)
184 |     else:
185 |         # CIFAR and all domain-generalization benchmarks share the same loader
186 |         train_loader, val_loader, test_loader = set_loader_small(args)
187 | 
188 | 
189 | 
190 | 
191 |     model = set_model(args)
192 | 
193 |     criterion_comp = CompLoss(args, temperature=args.temp, use_domain=args.use_domain).cuda()
194 |     criterion_dis = DisLoss(args, model, val_loader, temperature=args.temp).cuda()
195 | 
196 |     optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
197 |                                 momentum=args.momentum,
198 |                                 nesterov=True,
199 |                                 weight_decay=args.weight_decay)
200 | 
201 |     acc_best = 0.0
202 |     acc_best_id = 0.0
203 |     acc_test_best = 0.0
204 |     acc_test_best_id = 0.0
205 |     epoch_test_best = 0
206 |     epoch_val_best = 0
207 |     for epoch in range(args.start_epoch, args.epochs):
208 |         adjust_learning_rate(args, optimizer, epoch)
209 |         # train for one epoch
210 |         train_sloss, train_uloss, train_dloss, acc, acc_cor = train_hypo(args, train_loader, val_loader, test_loader, model, criterion_comp, criterion_dis, optimizer, epoch, log)
211 | 
212 |         tb_log.log_value('train_uni_loss', train_uloss, epoch)
213 |         tb_log.log_value('train_dis_loss', train_dloss, epoch)
214 |         wandb.log({'Comp Loss Ep': train_uloss, 'Dis Loss Ep': train_dloss})
215 |         tb_log.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)
216 |         wandb.log({'current lr': optimizer.param_groups[0]['lr'], 'acc': acc, 'acc cor': acc_cor})
217 | 
218 |         # save a checkpoint whenever the ID validation accuracy reaches a new best
219 |         if acc >= acc_best_id:
220 |             acc_best = acc_cor
221 |             acc_best_id = acc
222 |             epoch_val_best = epoch
223 |             wandb.log({'val best ood acc': acc_best, 'val best id acc': acc_best_id})
224 |             print('best accuracy {} at epoch {}'.format(acc_best_id, epoch))
225 |             print('accuracy cor {} at epoch {}'.format(acc_best, epoch))
226 |             save_checkpoint(args, {
227 |                 'epoch': epoch + 1,
228 |                 'state_dict': model.state_dict(),
229 |                 'opt_state_dict': optimizer.state_dict(),
230 |                 'dis_state_dict': criterion_dis.state_dict(),
231 |                 'uni_state_dict': criterion_comp.state_dict(),
232 |             }, epoch + 1, save_best=True)
233 | 
234 |         if acc_cor >= acc_test_best:
235 |             acc_test_best = acc_cor
236 |             acc_test_best_id = acc
237 |             epoch_test_best = epoch
238 |             wandb.log({'test best ood acc': acc_test_best, 'test best id acc': acc_test_best_id})
239 |             print('best test accuracy {} at epoch {}'.format(acc_test_best, epoch))
240 | 
241 |     print('total val best ood accuracy {} id accuracy {} at epoch {}'.format(acc_best, acc_best_id, epoch_val_best))
242 |     print('total test best ood accuracy {} id accuracy {} at epoch {}'.format(acc_test_best, acc_test_best_id, epoch_test_best))
243 |     print('last epoch ood accuracy {} id accuracy {} at epoch {}'.format(acc_cor, acc, epoch))
244 | 
245 |     summary_metrics = {
246 |         'val best ood accuracy': acc_best,
247 |         'val best id accuracy': acc_best_id,
248 |         'val best epoch': epoch_val_best,
249 |         'test best ood accuracy': acc_test_best,
250 |         'test best id accuracy': acc_test_best_id,
251 |         'test best epoch':
epoch_test_best, 252 | 'last ood accuracy': acc_cor, 253 | 'last id accuracy': acc, 254 | 'last epoch': epoch 255 | } 256 | 257 | for metric_name, metric_value in summary_metrics.items(): 258 | wandb.summary[metric_name] = metric_value 259 | 260 | def train_hypo(args, train_loader, val_loader, test_loader, model, criterion_comp, criterion_dis, optimizer, epoch, log): 261 | """Train for one epoch on the training set""" 262 | batch_time = AverageMeter() 263 | supcon_losses = AverageMeter() 264 | comp_losses = AverageMeter() 265 | dis_losses = AverageMeter() 266 | losses = AverageMeter() 267 | 268 | model.train() 269 | end = time.time() 270 | for i, values in enumerate(train_loader): 271 | if len(values) == 3: 272 | input, target, domain = values 273 | elif len(values) == 2: 274 | input, target = values 275 | domain = None 276 | 277 | warmup_learning_rate(args, epoch, i, len(train_loader), optimizer) 278 | bsz = target.shape[0] 279 | 280 | input = input.cuda() 281 | target = target.cuda() 282 | 283 | penultimate = model.encoder(input).squeeze() 284 | 285 | if args.normalize: # default: False 286 | penultimate= F.normalize(penultimate, dim=1) 287 | features= model.head(penultimate) 288 | features= F.normalize(features, dim=1) 289 | 290 | dis_loss = criterion_dis(features, target) 291 | comp_loss = criterion_comp(features, criterion_dis.prototypes, target, None) 292 | 293 | loss = args.w * comp_loss + dis_loss 294 | 295 | dis_losses.update(dis_loss.data, input.size(0)) 296 | comp_losses.update(comp_loss.data, input.size(0)) 297 | 298 | optimizer.zero_grad() 299 | loss.backward() 300 | optimizer.step() 301 | 302 | batch_time.update(time.time() - end) 303 | end = time.time() 304 | if i % args.print_freq == 0: 305 | 306 | log.debug('Epoch: [{0}][{1}/{2}]\t' 307 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 308 | 'Dis Loss {dloss.val:.4f} ({dloss.avg:.4f})\t' 309 | 'Comp Loss {uloss.val:.4f} ({uloss.avg:.4f})\t'.format( 310 | epoch, i, len(train_loader), batch_time=batch_time, dloss=dis_losses, uloss = comp_losses)) 311 | 312 | wandb.log({'Dis Loss' : dis_losses.val, 'Comp Loss' : comp_losses.val}) 313 | 314 | 315 | model.eval() 316 | with torch.no_grad(): 317 | accuracies = [] 318 | for i, values in enumerate(val_loader): 319 | if len(values) == 3: 320 | input, target, domain = values 321 | elif len(values) == 2: 322 | input, target = values 323 | domain = None 324 | input = input.cuda() 325 | target = target.cuda() 326 | 327 | features = model.forward(input) 328 | feat_dot_prototype = torch.div(torch.matmul(features, criterion_dis.prototypes.T), args.temp) 329 | 330 | # for numerical stability 331 | logits_max, _ = torch.max(feat_dot_prototype, dim=1, keepdim=True) 332 | logits = feat_dot_prototype - logits_max.detach() 333 | 334 | pred = logits.data.max(1)[1] 335 | 336 | accuracies.append(accuracy_score(list(to_np(pred)), list(to_np(target)))) 337 | 338 | acc = sum(accuracies) / len(accuracies) 339 | 340 | if test_loader is None: 341 | acc_cor = 0. 
342 | else: 343 | with torch.no_grad(): 344 | accuracies_cor = [] 345 | for i, values in enumerate(test_loader): 346 | if len(values) == 3: 347 | input, target, domain = values 348 | elif len(values) == 2: 349 | input, target = values 350 | domain = None 351 | input = input.cuda() 352 | target = target.cuda() 353 | 354 | features = model.forward(input) 355 | feat_dot_prototype = torch.div(torch.matmul(features, criterion_dis.prototypes.T), args.temp) 356 | 357 | # for numerical stability 358 | logits_max, _ = torch.max(feat_dot_prototype, dim=1, keepdim=True) 359 | logits = feat_dot_prototype - logits_max.detach() 360 | 361 | pred = logits.data.max(1)[1] 362 | accuracies_cor.append(accuracy_score(list(to_np(pred)), list(to_np(target)))) 363 | 364 | acc_cor = sum(accuracies_cor) / len(accuracies_cor) 365 | 366 | # measure elapsed time 367 | return supcon_losses.avg, comp_losses.avg, dis_losses.avg, acc, acc_cor 368 | 369 | 370 | 371 | def save_checkpoint(args, state, epoch, save_best = False): 372 | """Saves checkpoint to disk""" 373 | if save_best: 374 | filename = args.model_directory + 'checkpoint_max.pth.tar' 375 | else: 376 | filename = args.model_directory + f'checkpoint_{epoch}.pth.tar' 377 | torch.save(state, filename) 378 | 379 | 380 | if __name__ == '__main__': 381 | main() 382 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import math 4 | import os 5 | 6 | from torch import nn 7 | from models.resnet import SupCEHeadResNet 8 | 9 | 10 | import numpy as np 11 | import torch 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from torchvision import datasets, transforms 15 | import torchvision.transforms as transforms 16 | 17 | class TwoCropTransform: 18 | """Create two crops of the same image""" 19 | def __init__(self, transform): 20 | self.transform = transform 21 | 22 | def __call__(self, x): 23 | return [self.transform(x), self.transform(x)] 24 | 25 | class TwoCrop: 26 | def __init__(self): 27 | self.transform = 0 28 | def __call__(self, x): 29 | return [x, x] 30 | 31 | 32 | 33 | class AverageMeter(object): 34 | """Computes and stores the average and current value""" 35 | def __init__(self): 36 | self.reset() 37 | 38 | def reset(self): 39 | self.val = 0 40 | self.avg = 0 41 | self.sum = 0 42 | self.count = 0 43 | 44 | def update(self, val, n=1): 45 | self.val = val 46 | self.sum += val * n 47 | self.count += n 48 | self.avg = self.sum / self.count 49 | 50 | 51 | def accuracy(output, labels, topk=(1, 5)): 52 | """Computes the top-k accuracy for a given set of model outputs and labels. 53 | 54 | Args: 55 | output (torch.Tensor): The model outputs, with shape (batch_size, num_classes). 56 | labels (torch.Tensor): The true labels, with shape (batch_size,). 57 | topk (tuple, optional): The top-k values for which to compute the accuracy. Default is (1, 5). 58 | 59 | Returns: 60 | list: A list of the top-k accuracies. 
51 | def accuracy(output, labels, topk=(1, 5)):
52 |     """Computes the top-k accuracy for a given set of model outputs and labels.
53 | 
54 |     Args:
55 |         output (torch.Tensor): The model outputs, with shape (batch_size, num_classes).
56 |         labels (torch.Tensor): The true labels, with shape (batch_size,).
57 |         topk (tuple, optional): The top-k values for which to compute the accuracy. Default is (1, 5).
58 | 
59 |     Returns:
60 |         list: A list of the top-k accuracies, in percent.
61 |     """
62 |     with torch.no_grad():
63 |         maxk = max(topk)
64 |         batch_size = labels.size(0)
65 | 
66 |         _, pred = output.topk(maxk, 1, True, True)
67 |         pred = pred.t()
68 |         correct = pred.eq(labels.view(1, -1).expand_as(pred))
69 | 
70 |         res = []
71 |         for k in topk:
72 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
73 |             res.append(correct_k.mul_(100.0 / batch_size).item())
74 |         return res
75 | 
76 | def adjust_learning_rate(args, optimizer, epoch):
77 |     lr = args.learning_rate
78 |     if args.cosine:
79 |         eta_min = lr * (args.lr_decay_rate ** 3)  # floor of the cosine schedule
80 |         lr = eta_min + (lr - eta_min) * (
81 |                 1 + math.cos(math.pi * epoch / args.epochs)) / 2
82 |     else:
83 |         steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))  # number of decay milestones already passed
84 |         if steps > 0:
85 |             lr = lr * (args.lr_decay_rate ** steps)
86 | 
87 |     for param_group in optimizer.param_groups:
88 |         param_group['lr'] = lr
89 | 
90 | def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
91 |     if args.warm and epoch <= args.warm_epochs:
92 |         p = (batch_id + (epoch - 1) * total_batches) / \
93 |             (args.warm_epochs * total_batches)
94 |         lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)
95 | 
96 |         for param_group in optimizer.param_groups:
97 |             param_group['lr'] = lr
98 | 
99 | 
100 | def set_optimizer(opt, model):
101 |     optimizer = optim.SGD(model.parameters(),
102 |                           lr=opt.learning_rate,
103 |                           momentum=opt.momentum,
104 |                           weight_decay=opt.weight_decay)
105 |     return optimizer
106 | 
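# --- Editor's note (worked example, not part of the original file) ---
# With illustrative values learning_rate=0.5, lr_decay_rate=0.1, epochs=500,
# the cosine branch of adjust_learning_rate gives:
#
#     eta_min = 0.5 * 0.1**3 = 0.0005                     # schedule floor
#     epoch 0:   0.0005 + 0.4995 * (1 + cos(0)) / 2     = 0.5
#     epoch 250: 0.0005 + 0.4995 * (1 + cos(pi/2)) / 2  ≈ 0.2503
#     epoch 500: 0.0005 + 0.4995 * (1 + cos(pi)) / 2    = 0.0005
#
# warmup_learning_rate then linearly ramps the LR from args.warmup_from to
# args.warmup_to over the first args.warm_epochs epochs, overriding the
# schedule above batch by batch during warmup.
# ---------------------------------------------------------------------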
107 | def set_loader_small(args, eval=False, batch_size=None, img_size=32):  # note: 'eval' shadows the Python built-in
108 |     root = args.id_loc
109 |     if batch_size is None:
110 |         batch_size = args.batch_size
111 |     normalize = transforms.Normalize(mean=[x/255.0 for x in [125.3, 123.0, 113.9]],
112 |                                      std=[x/255.0 for x in [63.0, 62.1, 66.7]])
113 |     # data augmentations for supcon
114 |     train_transform_supcon = transforms.Compose([
115 |         transforms.RandomResizedCrop(size=img_size, scale=(0.2, 1.)),
116 |         transforms.RandomHorizontalFlip(),
117 |         transforms.RandomApply([
118 |             transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
119 |         ], p=0.8),
120 |         transforms.RandomGrayscale(p=0.2),
121 |         transforms.ToTensor(),
122 |         normalize,
123 |     ])
124 | 
125 |     transform_test = transforms.Compose([
126 |         transforms.ToTensor(),
127 |         normalize
128 |     ])
129 | 
130 |     kwargs = {'num_workers': 4, 'pin_memory': True}
131 |     if args.in_dataset == "CIFAR-10":
132 |         # Data loading code
133 |         if eval:
134 |             dataset = datasets.CIFAR10(root, train=True, download=True, transform=transform_test)
135 |             if args.subset:
136 |                 dataset = torch.utils.data.Subset(dataset, np.random.choice(len(dataset), 20000, replace=False))
137 |             train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, **kwargs)
138 |         else:
139 |             train_data = datasets.CIFAR10(root, train=True, download=True,
140 |                                           transform=train_transform_supcon)
141 | 
142 |             train_loader = torch.utils.data.DataLoader(
143 |                 train_data,
144 |                 batch_size=args.batch_size, shuffle=True, **kwargs)
145 | 
146 | 
147 |         val_loader = torch.utils.data.DataLoader(
148 |             datasets.CIFAR10(root, train=False, transform=transform_test),
149 |             batch_size=args.batch_size, shuffle=False, **kwargs)
150 |         test_loader = torch.utils.data.DataLoader(
151 |             datasets.CIFAR10(root, train=False, transform=transform_test),
152 |             batch_size=args.batch_size, shuffle=False, **kwargs)
153 |     elif args.in_dataset == "CIFAR-100":
154 |         # Data loading code
155 |         if eval:
156 |             dataset = datasets.CIFAR100(root, train=True, download=True, transform=transform_test)
157 |             if args.subset:
158 |                 dataset = torch.utils.data.Subset(dataset, np.random.choice(len(dataset), 20000, replace=False))
159 |             train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, **kwargs)
160 |         else:
161 |             train_loader = torch.utils.data.DataLoader(
162 |                 datasets.CIFAR100(root, train=True, download=True,
163 |                                   transform=train_transform_supcon),
164 |                 batch_size=args.batch_size, shuffle=True, **kwargs)
165 |         val_loader = torch.utils.data.DataLoader(
166 |             datasets.CIFAR100(root, train=False, transform=transform_test),
167 |             batch_size=args.batch_size, shuffle=False, **kwargs)
168 |         test_loader = torch.utils.data.DataLoader(
169 |             datasets.CIFAR100(root, train=False, transform=transform_test),
170 |             batch_size=args.batch_size, shuffle=False, **kwargs)
171 | 
172 |     elif args.in_dataset == "ImageNet-100":
173 |         traindir = os.path.join('./ImageNet-100', 'train')
174 |         valdir = os.path.join('./ImageNet-100', 'val')
175 | 
176 |         mean = [0.485, 0.456, 0.406]
177 |         std = [0.229, 0.224, 0.225]
178 | 
179 |         train_data = datasets.ImageFolder(
180 |             traindir,
181 |             TwoCropTransform(transforms.Compose([
182 |                 transforms.RandomResizedCrop(224),
183 |                 transforms.RandomHorizontalFlip(),
184 |                 transforms.ToTensor(),
185 |                 transforms.Normalize(mean, std),
186 |             ]))
187 |         )
188 | 
189 |         val_data = datasets.ImageFolder(
190 |             valdir,
191 |             transforms.Compose([
192 |                 transforms.Resize(256),
193 |                 transforms.CenterCrop(224),
194 |                 transforms.ToTensor(),
195 |                 transforms.Normalize(mean, std),
196 |             ]))
197 | 
198 |         from dataloader.corimagenetLoader import CorIMAGENETDataset as Dataset
199 |         test_data = Dataset('test', 'gaussian_noise')  # fix: 'cortype' was undefined in this scope; default corruption mirrors set_loader_ImageNet below
200 | 
201 |         train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True)
202 |         val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, shuffle=False, pin_memory=True)
203 |         test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True)
204 | 
205 |     elif args.in_dataset == "PACS":
206 |         print('loading PACS')
207 |         from dataloader.PACSLoader import PACSDataset as Dataset
208 | 
209 |         train_data = Dataset('train', args.target_domain)
210 |         val_data = Dataset('val', args.target_domain)
211 |         test_data = Dataset('test', args.target_domain)
212 | 
213 |         train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, drop_last=True)
214 |         val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, drop_last=True)
215 |         test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True)
216 | 
217 |     elif args.in_dataset == "OfficeHome":
218 |         print('loading OfficeHome')
219 |         from dataloader.OfficeHomeLoader import OfficeHomeDataset as Dataset
220 | 
221 |         train_data = Dataset('train', args.target_domain)
222 |         val_data = Dataset('val', args.target_domain)
223 |         test_data = Dataset('test', args.target_domain)
224 | 
225 |         train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True)
226 |         val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, shuffle=False, pin_memory=True)
227 |         test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True)
228 | 
229 |     elif args.in_dataset == "VLCS":
230 |         print('loading VLCS')
231 |         from dataloader.VLCSLoader import VLCSDataset as Dataset
232 | 
233 |         train_data = Dataset('train', args.target_domain)
234 |         val_data = Dataset('val', args.target_domain)
235 |         test_data = Dataset('test', args.target_domain)
236 | 
237 |         train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True)
238 |         val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, shuffle=False, pin_memory=True)
239 |         test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True)
240 | 
241 |     elif args.in_dataset == "terra_incognita":
242 |         print('loading terra incognita')
243 |         from dataloader.TerraLoader import TerraDataset as Dataset
244 | 
245 |         train_data = Dataset('train', args.target_domain)
246 |         val_data = Dataset('val', args.target_domain)
247 |         test_data = Dataset('test', args.target_domain)
248 | 
249 |         train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True)
250 |         val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, shuffle=False, pin_memory=True)
251 |         test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True)
252 | 
253 |     return train_loader, val_loader, test_loader
254 | 
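# --- Editor's sketch (illustrative usage, not part of the original file) ---
# For the domain-generalization datasets (PACS, OfficeHome, VLCS,
# terra_incognita), set_loader_small appears to follow the usual
# leave-one-domain-out protocol: train/val batches come from the source
# domains and the test loader from the held-out args.target_domain. Judging
# from the 3-tuple unpacking in train_hypo.py, those loaders yield
# (image, label, domain) batches:
#
#     train_loader, val_loader, test_loader = set_loader_small(args)
#     images, labels, domains = next(iter(train_loader))    # DG datasets
#     images, labels = next(iter(train_loader))             # CIFAR variants
# ---------------------------------------------------------------------------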
255 | def set_loader_ImageNet(args, eval=False, batch_size=None):
256 |     #root = args.id_loc
257 |     root = './ImageNet-100'  # hard-coded; overrides args.id_loc for this loader
258 |     if batch_size is None:
259 |         batch_size = args.batch_size
260 |     # for ImageNet
261 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
262 |                                      std=[0.229, 0.224, 0.225])
263 | 
264 |     train_transform_supcon = transforms.Compose([
265 |         transforms.RandomResizedCrop(size=224, scale=(0.4, 1.)),
266 |         transforms.RandomHorizontalFlip(),
267 |         transforms.RandomApply([
268 |             transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
269 |         ], p=0.8),
270 |         transforms.RandomGrayscale(p=0.2),
271 |         transforms.ToTensor(),
272 |         normalize,
273 |     ])
274 |     transform_test = transforms.Compose([
275 |         transforms.Resize(224),
276 |         transforms.CenterCrop(224),
277 |         transforms.ToTensor(),
278 |         normalize
279 |     ])
280 |     kwargs = {'num_workers': 4, 'pin_memory': True}
281 |     if batch_size is not None:  # always true here (batch_size was defaulted above); keeps args.batch_size in sync with the argument
282 |         args.batch_size = batch_size
283 | 
284 |     # Data loading code
285 |     if eval:
286 |         dataset = datasets.ImageFolder(os.path.join(root, 'train'), transform=transform_test)
287 |         if args.subset:
288 |             dataset = torch.utils.data.Subset(dataset, np.random.choice(len(dataset), 20000, replace=False))
289 |         train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, **kwargs)
290 |     else:
291 |         dataset = datasets.ImageFolder(os.path.join(root, 'train'),
292 |                                        transform=train_transform_supcon)
293 |         train_loader = torch.utils.data.DataLoader(
294 |             dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
295 |     val_loader = torch.utils.data.DataLoader(
296 |         datasets.ImageFolder(os.path.join(root, 'val'), transform=transform_test),
297 |         batch_size=args.batch_size, shuffle=False, **kwargs)
298 | 
299 |     from dataloader.corimagenetLoader import CorIMAGENETDataset as Dataset
300 |     cortype = 'gaussian_noise'
301 |     test_data = Dataset('test', cortype)
302 |     test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True)
303 | 
304 |     return train_loader, val_loader, test_loader
305 | 
306 | 
307 | def set_model(args):
308 | 
309 |     # create model
310 |     model = SupCEHeadResNet(args)
311 |     # get the number of model parameters
312 |     print('Number of model parameters: {}'.format(
313 |         sum(p.numel() for p in model.parameters())))
314 |     torch.backends.cudnn.deterministic = True
315 |     torch.backends.cudnn.benchmark = True  # note: benchmark=True can reintroduce nondeterminism despite the flag above
316 |     model = model.cuda()
317 | 
318 | 
319 |     return model
320 | 
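# --- Editor's sketch (illustrative wiring, not part of the original file) ---
# Typical composition of the helpers above. The argparse fields are
# assumptions based on scripts/train_hypo_cifar10.sh and are not exhaustive;
# SupCEHeadResNet may require additional fields:
#
#     from types import SimpleNamespace
#     args = SimpleNamespace(in_dataset='CIFAR-10', id_loc='datasets/CIFAR10',
#                            model='resnet18', head='mlp', feat_dim=128,
#                            batch_size=512, learning_rate=0.5, momentum=0.9,
#                            weight_decay=1e-4, subset=False)
#     model = set_model(args)                   # SupCEHeadResNet, moved to GPU
#     optimizer = set_optimizer(args, model)    # plain SGD with momentum
#     loaders = set_loader_small(args)          # (train, val, test) loaders
# ---------------------------------------------------------------------------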
321 | def sample_estimator(model, classifier, num_classes, feature_list, train_loader):
322 |     """
323 |     compute sample mean and precision (inverse of covariance)
324 |     return: sample_class_mean: list of class means
325 |             precision: list of precisions
326 |     """
327 |     import sklearn.covariance
328 | 
329 |     model.eval()
330 |     group_lasso = sklearn.covariance.EmpiricalCovariance(assume_centered=False)  # empirical covariance; variable name kept from the Mahalanobis reference code
331 |     correct, total = 0, 0
332 |     num_output = len(feature_list)
333 |     num_sample_per_class = np.empty(num_classes)
334 |     num_sample_per_class.fill(0)
335 |     list_features = []
336 |     for i in range(num_output):
337 |         temp_list = []
338 |         for j in range(num_classes):
339 |             temp_list.append(0)
340 |         list_features.append(temp_list)
341 | 
342 |     for data, target in train_loader:
343 |         total += data.size(0)
344 |         data = data.cuda()
345 |         penultimate, out_features = model.encoder.feature_list(data)
346 |         output = classifier(penultimate)
347 |         # output, out_features = model.module.feature_list(data)
348 | 
349 |         # get hidden features, averaged over spatial locations
350 |         for i in range(num_output):
351 |             out_features[i] = out_features[i].view(out_features[i].size(0), out_features[i].size(1), -1)
352 |             out_features[i] = torch.mean(out_features[i].data, 2)
353 |             #TEMP
354 |             # out_features[-1] = out_features[i] / out_features[i].norm(p=2, dim=1, keepdim=True)
355 |             out_features[-1] = F.normalize(out_features[-1], dim=1)  # only the last feature map is L2-normalized (re-applied each iteration; harmless, since normalize is idempotent)
356 |         # compute the accuracy
357 |         pred = output.data.max(1)[1]
358 |         equal_flag = pred.eq(target.cuda()).cpu()
359 |         correct += equal_flag.sum()
360 | 
361 |         # construct the sample matrix
362 |         for i in range(data.size(0)):
363 |             label = target[i]
364 |             if num_sample_per_class[label] == 0:
365 |                 out_count = 0
366 |                 for out in out_features:
367 |                     list_features[out_count][label] = out[i].view(1, -1)
368 |                     out_count += 1
369 |             else:
370 |                 out_count = 0
371 |                 for out in out_features:
372 |                     list_features[out_count][label] \
373 |                         = torch.cat((list_features[out_count][label], out[i].view(1, -1)), 0)
374 |                     out_count += 1
375 |             num_sample_per_class[label] += 1
376 | 
377 |     sample_class_mean = []
378 |     out_count = 0
379 |     for num_feature in feature_list:
380 |         temp_list = torch.Tensor(num_classes, int(num_feature)).cuda()
381 |         for j in range(num_classes):
382 |             temp_list[j] = torch.mean(list_features[out_count][j], 0)
383 |         sample_class_mean.append(temp_list)
384 |         out_count += 1
385 | 
386 |     precision = []
387 |     for k in range(num_output):
388 |         X = 0
389 |         for i in range(num_classes):
390 |             if i == 0:
391 |                 X = list_features[k][i] - sample_class_mean[k][i]
392 |             else:
393 |                 X = torch.cat((X, list_features[k][i] - sample_class_mean[k][i]), 0)
394 | 
395 |         # find inverse of the shared, class-centered covariance
396 |         group_lasso.fit(X.cpu().numpy())
397 |         temp_precision = group_lasso.precision_
398 |         temp_precision = torch.from_numpy(temp_precision).float().cuda()
399 |         precision.append(temp_precision)
400 | 
401 |     print('\n Training Accuracy:({:.2f}%)\n'.format(100. * float(correct) / total))
402 | 
403 |     return sample_class_mean, precision
404 | 
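# --- Editor's sketch (hypothetical consumer, not part of the original file) ---
# The (sample_class_mean, precision) pair returned above is the classic input
# to a Mahalanobis-style confidence score (Lee et al., 2018): for a feature f,
# score(f) = max_c -(f - mu_c)^T P (f - mu_c). A minimal sketch; the function
# name and signature are assumptions, not the repo's API:
def mahalanobis_score_sketch(features, sample_class_mean, precision, layer_idx=-1):
    mus = sample_class_mean[layer_idx]                 # (num_classes, feat_dim)
    P = precision[layer_idx]                           # (feat_dim, feat_dim)
    diffs = features.unsqueeze(1) - mus.unsqueeze(0)   # (batch, num_classes, feat_dim)
    # quadratic form -(f - mu_c)^T P (f - mu_c) for every class c
    scores = -torch.einsum('bcd,de,bce->bc', diffs, P, diffs)
    return scores.max(dim=1).values                    # per-sample confidence
# -----------------------------------------------------------------------------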
405 | 
406 | def estimate_dataset_mean_std(name='cifar10'):  # 'name' is currently unused; CIFAR-10 is hard-coded below
407 |     data = datasets.CIFAR10(root='./datasets/cifar10', train=True, download=True,
408 |                             transform=transforms.ToTensor()).data
409 |     data = data.astype(np.float32)/255.
410 | 
411 |     means = []
412 |     stdevs = []
413 |     for i in range(3):
414 |         pixels = data[:, :, :, i].ravel()  # all pixels of channel i across the training set
415 |         means.append(np.mean(pixels))
416 |         stdevs.append(np.std(pixels))
417 | 
418 |     print("means: {}".format(means))
419 |     print("stdevs: {}".format(stdevs))
420 |     print('transforms.Normalize(mean = {}, std = {})'.format(means, stdevs))
421 | 
422 | if __name__ == '__main__':
423 |     estimate_dataset_mean_std()
424 | 
--------------------------------------------------------------------------------
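# --- Editor's note (verification, not part of the original repo) ---
# Running estimate_dataset_mean_std() on CIFAR-10 should print approximately
# means [0.4914, 0.4822, 0.4467] and stdevs [0.2470, 0.2435, 0.2616], which is
# consistent with the hard-coded constants in set_loader_small:
# 125.3/255 ≈ 0.4914, 123.0/255 ≈ 0.4824, 113.9/255 ≈ 0.4467, and
# 63.0/255 ≈ 0.2471, 62.1/255 ≈ 0.2435, 66.7/255 ≈ 0.2616.
# --------------------------------------------------------------------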