├── README.md ├── collect_metrics.py ├── configs ├── basic.json ├── linear.json ├── linear │ ├── cifar10.json │ ├── imagenet.json │ ├── svhn.json │ └── tinyimagenet.json ├── pcssr.json ├── pcssr │ ├── cifar10.json │ ├── imagenet.json │ ├── svhn.json │ └── tinyimagenet.json ├── rcssr.json └── rcssr │ ├── cifar10.json │ ├── imagenet.json │ ├── svhn.json │ └── tinyimagenet.json ├── dataset.py ├── exps ├── cifar10 │ ├── spl_a.json │ ├── spl_b.json │ ├── spl_c.json │ ├── spl_d.json │ └── spl_e.json ├── imagenet │ └── vs_inaturalist.json ├── svhn │ ├── spl_a.json │ ├── spl_b.json │ ├── spl_c.json │ ├── spl_d.json │ └── spl_e.json └── tinyimagenet │ ├── spl_a.json │ ├── spl_b.json │ ├── spl_c.json │ ├── spl_d.json │ └── spl_e.json ├── main.py ├── methods ├── augtools.py ├── cssr.py ├── cssr_ft.py ├── resnet.py ├── util.py └── wideresnet.py ├── metrics.py └── run.sh /README.md: -------------------------------------------------------------------------------- 1 | # Class Specific Semantic Reconstruction for Open Set Recognition [TPAMI 2022] 2 | 3 | Official PyTorch implementation of [Class Specific Semantic Reconstruction for Open Set Recognition](https://ieeexplore.ieee.org/document/9864101). 4 | 5 | ## 1. Train 6 | 7 | Before training, please set up the dataset directories in `dataset.py`: 8 | ``` 9 | DATA_PATH = '' # path for cifar10, svhn 10 | TINYIMAGENET_PATH = '' # path for tinyimagenet 11 | LARGE_OOD_PATH = '' # path for ood datasets, e.g., iNaturalist in imagenet experiment 12 | IMAGENET_PATH = '' # path for imagenet-1k datasets 13 | ``` 14 | 15 | To train models from scratch, run the following command: 16 | ``` 17 | python main.py --gpu 0 --ds {DATASET} --config {MODEL} --save {SAVING_NAME} --method cssr 18 | ``` 19 | 20 | Command options: 21 | - **DATASET:** Experiment configuration file specifying the datasets and random class splits, e.g., `./exps/$dataset/spl_$s.json`. 22 | - **MODEL:** OSR model configuration file specifying the model parameters, e.g., `./configs/$model/$dataset.json`.
`$model` includes linear/pcssr/rcssr, which corresponds to the baseline and the proposed model. 23 | 24 | Or simply run bash file `sh run.sh` to run all experiments simultaneously. 25 | 26 | To train models by finetuning pretrained backbones, like experiments for imagenet-1k, run command: 27 | ``` 28 | python main.py --gpu 0 --ds ./exps/imagenet/vs_inaturalist.json --config ./configs/rcssr/imagenet.json --save imagenet1k_rcssr --method cssr_ft 29 | ``` 30 | 31 | ## 2. Evaluation 32 | 33 | Add `--test` on training commands to restore and evaluate a pretrained model on specified data setup, e.g., 34 | ``` 35 | python main.py --gpu 0 --ds {DATASET} --config {MODEL} --save {SAVING_NAME} --method cssr --test 36 | ``` 37 | 38 | With models trained by `sh run.sh`, script `collect_metrics.py` helps collect and present experimental results: `python collect_metrics.py` 39 | 40 | 41 | ## 3. Citation 42 | ``` 43 | @ARTICLE{9864101, 44 | author={Huang, Hongzhi and Wang, Yu and Hu, Qinghua and Cheng, Ming-Ming}, 45 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 46 | title={Class-Specific Semantic Reconstruction for Open Set Recognition}, 47 | year={2022}, 48 | doi={10.1109/TPAMI.2022.3200384} 49 | } 50 | ``` -------------------------------------------------------------------------------- /collect_metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | methods = ['linear','pcssr','rcssr'] 4 | datasets = ['cifar10','svhn','tinyimagenet'] 5 | splits = ['a','b','c','d', 'e'] 6 | 7 | def get_metric(file,metric,is_last = True): 8 | with open(file,'r') as f: 9 | hist = json.load(f) 10 | # last metric 11 | if is_last: 12 | res = hist[-1] 13 | for m in metric.split('.'): 14 | if m not in res: 15 | res = -1 16 | break 17 | res = res[m] 18 | # best metric 19 | else: 20 | res = 0 21 | for epoch in hist: 22 | for m in metric.split('.'): 23 | if not m in epoch: 24 | epoch = -1 25 | break 26 | epoch = 
epoch[m] 27 | res = max(res,epoch) 28 | return res 29 | 30 | def generate_tables(use_last = True): 31 | for ds in datasets: 32 | print("\nDataset",ds,"Last Epoch" if use_last else "Best Epoch") 33 | print('method','average',*splits,sep='\t') 34 | for mth in methods: 35 | metrics = [get_metric(f'./save/{mth}_{ds}_{s}/hist.json','open_detection.auroc',use_last) for s in splits] 36 | # print(metrics) 37 | avg = sum(metrics) / len(metrics) 38 | metrics = [avg] + metrics 39 | metrics = list(map(lambda x:'%.04f' % x,metrics)) 40 | print(mth,*metrics,sep='\t') 41 | 42 | generate_tables(True) 43 | generate_tables(False) -------------------------------------------------------------------------------- /configs/basic.json: -------------------------------------------------------------------------------- 1 | { 2 | "learn_rate" : 0.4, 3 | "epoch_num" : 200, 4 | "lr_decay" : 0.1, 5 | "milestones" : [120,190], 6 | "lr_schedule" : "multi_step", 7 | 8 | "optimizer" : "sgd", 9 | "warmup_epoch": 0, 10 | "batch_size" : 128, 11 | 12 | "classification_hidden_dim" : 256, 13 | "backbone" : "resnet18", 14 | 15 | "rot_augmentation" : "weak", 16 | "cat_augmentation" : "strong", 17 | 18 | "strong_option" : "CUST", 19 | "cate_rotaug_strategy" : "none", 20 | 21 | "cust_aug_crop_withresize" : false, 22 | "customize_augment_pool" : { 23 | "AutoContrast" : false, 24 | "BrightnessDark" : false, 25 | "BrightnessLight": false, 26 | "BrightnessOverall" : true, 27 | "Color" : true, 28 | "ContrastLow": false, 29 | "ContrastHigh": false, 30 | "ContrastOverall": false, 31 | "Equalize": true, 32 | "Identity": true, 33 | "Posterize": true, 34 | "Rotate": true, 35 | "Sharpness": false, 36 | "SharpnessLarge": true, 37 | "Shear": true, 38 | "Solarize": true 39 | }, 40 | "customize_augment_postprocess" : "cutout", 41 | "manual_contrast" : true, 42 | 43 | "category_model" : { 44 | "model" : "proj", 45 | "ae_hidden" : [], 46 | "ae_latent" : 64, 47 | "gamma" : 0.1, 48 | "error_measure" : "L1" 49 | }, 50 | 51 | 
"cat_weight" : 1, 52 | 53 | "arch_type" : "softmax_avg", 54 | 55 | "abs_logits" : false, 56 | "aug_sublabel" : 1, 57 | "score" : "R[0]", 58 | "energy_T" : 100, 59 | 60 | "use_mpn_pooling" : false, 61 | "mpn_group" : 16, 62 | "integrate_score" : "S[0]", 63 | "enable_gram" : true, 64 | "extra_augmentation" : "none", 65 | "oemth" : "none" 66 | } -------------------------------------------------------------------------------- /configs/linear.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/basic.json", 3 | "projection_dim" : -1, 4 | 5 | "category_model" : { 6 | "model" : "linear", 7 | "ae_hidden" : [], 8 | "ae_latent" : 64, 9 | "projection_dim" : -1, 10 | "gamma" : 1, 11 | "error_measure" : "L1" 12 | }, 13 | "arch_type" : "avg_softmax", 14 | 15 | "score" : "R[4]" 16 | } -------------------------------------------------------------------------------- /configs/linear/cifar10.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/linear.json", 3 | "backbone" : "wideresnet40-4" 4 | } -------------------------------------------------------------------------------- /configs/linear/imagenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/linear.json", 3 | "epoch_num" : 4, 4 | "milestones" : [3], 5 | "warmup_epoch":1, 6 | "cat_augmentation" : "simple", 7 | "backbone" : "prt_pytorchr18", 8 | 9 | "category_model" : { 10 | "model" : "linear", 11 | "ae_hidden" : [], 12 | "ae_latent" : 32, 13 | "projection_dim" : -1, 14 | "gamma" : 1, 15 | "simmeasure" : "L1", 16 | "rc_act" : true 17 | } 18 | } -------------------------------------------------------------------------------- /configs/linear/svhn.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/linear.json", 3 | "backbone" : "wideresnet40-4" 4 | } 
-------------------------------------------------------------------------------- /configs/linear/tinyimagenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/linear.json", 3 | "epoch_num" : 235, 4 | "milestones" : [150,225], 5 | "warmup_epoch": 1, 6 | "backbone" : "resnet18" 7 | } -------------------------------------------------------------------------------- /configs/pcssr.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/basic.json", 3 | "projection_dim" : -1, 4 | 5 | "category_model" : { 6 | "model" : "pcssr", 7 | "ae_hidden" : [], 8 | "ae_latent" : 64, 9 | "projection_dim" : -1, 10 | "gamma" : 0.1, 11 | "error_measure" : "L1" 12 | }, 13 | 14 | "score" : "R[0]/R[1]/R[1]", 15 | "integrate_score": "S[0]+S[1]+S[2]" 16 | } -------------------------------------------------------------------------------- /configs/pcssr/cifar10.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/pcssr.json", 3 | "backbone" : "wideresnet40-4" 4 | } -------------------------------------------------------------------------------- /configs/pcssr/imagenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/pcssr.json", 3 | "epoch_num" : 4, 4 | "milestones" : [3], 5 | "warmup_epoch":1, 6 | "cat_augmentation" : "simple", 7 | "backbone" : "prt_pytorchr18", 8 | "score" : "R[0]/R[2]", 9 | 10 | "category_model" : { 11 | "model" : "pcssr", 12 | "ae_hidden" : [], 13 | "ae_latent" : 32, 14 | "projection_dim" : 512, 15 | "gamma" : 0.1, 16 | "simmeasure" : "L1", 17 | "rc_act" : false 18 | } 19 | } -------------------------------------------------------------------------------- /configs/pcssr/svhn.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/pcssr.json", 
3 | "backbone" : "wideresnet40-4" 4 | } -------------------------------------------------------------------------------- /configs/pcssr/tinyimagenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/pcssr.json", 3 | "epoch_num" : 235, 4 | "milestones" : [150,225], 5 | "warmup_epoch": 1, 6 | "backbone" : "resnet18" 7 | } -------------------------------------------------------------------------------- /configs/rcssr.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/basic.json", 3 | "projection_dim" : -1, 4 | 5 | "category_model" : { 6 | "model" : "rcssr", 7 | "ae_hidden" : [], 8 | "ae_latent" : 64, 9 | "projection_dim" : -1, 10 | "gamma" : 0.1, 11 | "error_measure" : "L1" 12 | }, 13 | 14 | "score" : "R[0]", 15 | "integrate_score": "S[0]+S[1]+S[2]" 16 | } -------------------------------------------------------------------------------- /configs/rcssr/cifar10.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/rcssr.json", 3 | "backbone" : "wideresnet40-4", 4 | 5 | "category_model" : { 6 | "model" : "rcssr", 7 | "ae_hidden" : [], 8 | "ae_latent" : 64, 9 | "projection_dim" : -1, 10 | "gamma" : 0.1, 11 | "simmeasure" : "L1", 12 | "rc_act" : false 13 | } 14 | } -------------------------------------------------------------------------------- /configs/rcssr/imagenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/rcssr.json", 3 | "epoch_num" : 4, 4 | "milestones" : [3], 5 | "warmup_epoch":1, 6 | "cat_augmentation" : "simple", 7 | "backbone" : "prt_pytorchr18", 8 | "category_model" : { 9 | "model" : "rcssr", 10 | "ae_hidden" : [], 11 | "ae_latent" : 16, 12 | "projection_dim" : -1, 13 | "gamma" : 0.1, 14 | "simmeasure" : "L1", 15 | "rc_act" : false 16 | } 17 | } 
-------------------------------------------------------------------------------- /configs/rcssr/svhn.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/rcssr.json", 3 | "backbone" : "wideresnet40-4" 4 | } -------------------------------------------------------------------------------- /configs/rcssr/tinyimagenet.json: -------------------------------------------------------------------------------- 1 | { 2 | "inherit" : "./configs/rcssr.json", 3 | "epoch_num" : 235, 4 | "milestones" : [150,225], 5 | "warmup_epoch": 1, 6 | "backbone" : "resnet18" 7 | } -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | from torch.utils.data import Dataset 6 | import torchvision.datasets as datasets 7 | import methods.util as util 8 | 9 | UNKNOWN_LABEL = -1 10 | 11 | imagenet_mean = [0.485, 0.456, 0.406] 12 | imagenet_std = [0.229, 0.224, 0.225] 13 | 14 | cifar_mean = (0.5,0.5,0.5) 15 | cifar_std = (0.25,0.25,0.25) 16 | 17 | tiny_mean = (0.5,0.5,0.5) 18 | tiny_std = (0.25,0.25,0.25) 19 | 20 | svhn_mean = (0.5,0.5,0.5) 21 | svhn_std = (0.25,0.25,0.25) 22 | 23 | workers = 6 24 | test_workers = 6 25 | use_droplast = True 26 | require_org_image = True 27 | no_test_transform = False 28 | 29 | DATA_PATH = '/HOME/scz1838/run/data' 30 | TINYIMAGENET_PATH = DATA_PATH + '/tiny-imagenet-200/' 31 | LARGE_OOD_PATH = '/HOME/scz1838/run/largeoodds' 32 | IMAGENET_PATH = '/data/public/imagenet2012' 33 | 34 | 35 | class tinyimagenet_data(Dataset): 36 | 37 | def __init__(self, _type, transform): 38 | if _type == 'train': 39 | self.ds = datasets.ImageFolder(f'{TINYIMAGENET_PATH}/train/', transform=transform) 40 | self.labels = [self.ds.samples[i][1] for i in range(len(self.ds))] 41 | elif _type == 'test': 42 
| tmp_ds = datasets.ImageFolder(f'{TINYIMAGENET_PATH}/train/', transform=transform) 43 | cls2idx = tmp_ds.class_to_idx 44 | self.ds = datasets.ImageFolder(f'{TINYIMAGENET_PATH}/val/', transform=transform) 45 | with open(f'{TINYIMAGENET_PATH}/val/val_annotations.txt','r') as f: 46 | file2cls = {} 47 | for line in f.readlines(): 48 | line = line.strip().split('\t') 49 | file2cls[line[0]] = line[1] 50 | self.labels = [] 51 | for i in range(len(self.ds)): 52 | filename = self.ds.samples[i][0].split('/')[-1] 53 | self.labels.append(cls2idx[file2cls[filename]]) 54 | # print("test labels",self.labels) 55 | 56 | def __len__(self): 57 | return len(self.ds) 58 | 59 | def __getitem__(self,idx): 60 | return self.ds[idx][0],self.labels[idx] 61 | 62 | class Imagenet1000(Dataset): 63 | 64 | lab_cvt = None 65 | 66 | def __init__(self,istrain, transform): 67 | 68 | set = "train" if istrain else "val" 69 | self.ds = datasets.ImageFolder(f'{IMAGENET_PATH}/{set}/', transform=transform) 70 | self.labels = [self.ds.samples[i][1] for i in range(len(self.ds))] 71 | 72 | def __len__(self): 73 | return len(self.ds) 74 | 75 | def __getitem__(self,idx): 76 | return self.ds[idx] 77 | 78 | class LargeOODDataset(Dataset): 79 | 80 | def __init__(self,ds_name,transform) -> None: 81 | super().__init__() 82 | data_path = f'{LARGE_OOD_PATH}/{ds_name}/' 83 | self.ds = datasets.ImageFolder(data_path, transform=transform) 84 | self.labels = [-1] * len(self.ds) 85 | 86 | def __len__(self,): 87 | return len(self.ds) 88 | 89 | def __getitem__(self, index): 90 | return self.ds[index] 91 | 92 | 93 | class PartialDataset(Dataset): 94 | 95 | def __init__(self,knwon_ds,lab_keep = None,lab_cvt = None) -> None: 96 | super().__init__() 97 | self.known_ds = knwon_ds 98 | labels = knwon_ds.labels 99 | if lab_cvt is None: # by default, identity mapping 100 | lab_cvt = [i for i in range(1999)] 101 | if lab_keep is None: # by default, keep positive labels 102 | lab_keep = [x for x in lab_cvt if x > -1] 103 | keep = {x 
for x in lab_keep} 104 | self.sample_indexes = [i for i in range(len(knwon_ds)) if lab_cvt[labels[i]] in keep] 105 | self.labels = [lab_cvt[labels[i]] for i in range(len(knwon_ds)) if lab_cvt[labels[i]] in keep] 106 | self.labrefl = lab_cvt 107 | 108 | def __len__(self) -> int: 109 | return len(self.sample_indexes) 110 | 111 | def __getitem__(self, index: int): 112 | inp,lb = self.known_ds[self.sample_indexes[index]] 113 | return inp,self.labrefl[lb],index 114 | 115 | class UnionDataset(Dataset): 116 | 117 | def __init__(self,ds_list) -> None: 118 | super().__init__() 119 | self.dslist = ds_list 120 | self.totallen = sum([len(ds) for ds in ds_list]) 121 | self.labels = [] 122 | for x in ds_list: 123 | self.labels += x.labels 124 | 125 | def __len__(self) -> int: 126 | return self.totallen 127 | 128 | def __getitem__(self, index: int): 129 | orgindex = index 130 | for ds in self.dslist: 131 | if index < len(ds): 132 | a,b,c = ds[index] 133 | return a,b,orgindex 134 | index -= len(ds) 135 | return None 136 | 137 | 138 | def gen_transform(mean,std,crop = False,toPIL = False,imgsize = 32,testmode = False): 139 | t = [] 140 | if toPIL: 141 | t.append(transforms.ToPILImage()) 142 | if not testmode: 143 | return transforms.Compose(t) 144 | if crop: 145 | if imgsize > 200: 146 | t += [transforms.Resize(256),transforms.CenterCrop(imgsize)] 147 | else: 148 | t.append(transforms.CenterCrop(imgsize)) 149 | # print(t) 150 | return transforms.Compose(t + [transforms.ToTensor(), transforms.Normalize(mean, std)]) 151 | 152 | 153 | def gen_cifar_transform(crop = False, toPIL = False,testmode = False): 154 | return gen_transform(cifar_mean,cifar_std,crop,toPIL=toPIL,imgsize=32,testmode = testmode) 155 | 156 | def gen_tinyimagenet_transform(crop = False,testmode = False): 157 | return gen_transform(tiny_mean,tiny_std,crop,False,imgsize=64,testmode = testmode) 158 | 159 | def gen_imagenet_transform(crop = False, testmode = False): 160 | return 
gen_transform(imagenet_mean,imagenet_std,crop,False,imgsize=224,testmode = testmode) 161 | 162 | def gen_svhn_transform(crop = False,toPIL = False,testmode = False): 163 | return gen_transform(svhn_mean,svhn_std,crop,toPIL=toPIL,imgsize=32,testmode = testmode) 164 | 165 | def get_cifar10(settype): 166 | if settype == 'train': 167 | trans = gen_cifar_transform() 168 | ds = torchvision.datasets.CIFAR10(root=DATA_PATH, train=True, download=True, transform=trans) 169 | else: 170 | ds = torchvision.datasets.CIFAR10(root=DATA_PATH, train=False, download=True, transform=gen_cifar_transform(testmode=True)) 171 | ds.labels = ds.targets 172 | return ds 173 | 174 | def get_cifar100(settype): 175 | if settype == 'train': 176 | trans = gen_cifar_transform() 177 | ds = torchvision.datasets.CIFAR100(root=DATA_PATH, train=True, download=True, transform=trans) 178 | else: 179 | ds = torchvision.datasets.CIFAR100(root=DATA_PATH, train=False, download=True, transform=gen_cifar_transform(testmode=True)) 180 | ds.labels = ds.targets 181 | return ds 182 | 183 | def get_svhn(settype): 184 | if settype == 'train': 185 | trans = gen_svhn_transform() 186 | ds = torchvision.datasets.SVHN(root=DATA_PATH, split='train', download=True, transform=trans) 187 | else : 188 | ds = torchvision.datasets.SVHN(root=DATA_PATH, split='test', download=True, transform=gen_svhn_transform(testmode=True)) 189 | return ds 190 | 191 | def get_tinyimagenet(settype): 192 | if settype == 'train': 193 | trans = gen_tinyimagenet_transform() 194 | ds = tinyimagenet_data('train',trans) 195 | else: 196 | ds = tinyimagenet_data('test',gen_tinyimagenet_transform(testmode=True)) 197 | return ds 198 | 199 | def get_imagenet1000(settype): 200 | if settype == 'train': 201 | trans = gen_imagenet_transform() 202 | ds = Imagenet1000(True,trans) 203 | else: 204 | ds = Imagenet1000(False,gen_imagenet_transform(crop = True, testmode=True)) 205 | return ds 206 | 207 | def get_ood_inaturalist(settype): 208 | if settype == 'train': 
209 | raise Exception("OOD iNaturalist cannot be used as train set.") 210 | else: 211 | return LargeOODDataset('iNaturalist',gen_imagenet_transform(crop = True, testmode=True)) 212 | 213 | ds_dict = { 214 | "cifarova" : get_cifar10, 215 | "cifar10" : get_cifar10, 216 | "cifar100" : get_cifar100, 217 | "svhn" : get_svhn, 218 | "tinyimagenet" : get_tinyimagenet, 219 | "imagenet" : get_imagenet1000, 220 | 'oodinaturalist' : get_ood_inaturalist, 221 | } 222 | 223 | cache_base_ds = { 224 | 225 | } 226 | 227 | def get_ds_with_name(settype,ds_name): 228 | global cache_base_ds 229 | key = str(settype) + ds_name 230 | if key not in cache_base_ds.keys(): 231 | cache_base_ds[key] = ds_dict[ds_name](settype) 232 | return cache_base_ds[key] 233 | 234 | def get_partialds_with_name(settype,ds_name,label_cvt,label_keep): 235 | ds = get_ds_with_name(settype,ds_name) 236 | return PartialDataset(ds,label_keep,label_cvt) 237 | 238 | # setting list [[ds_name, sample partition list, label convertion list],...] 
239 | def get_combined_dataset(settype,setting_list): 240 | ds_list = [] 241 | for setting in setting_list: 242 | ds = get_partialds_with_name(settype,setting['dataset'],setting['convert_class'],setting['keep_class']) 243 | if ds.__len__() > 0: 244 | ds_list.append(ds) 245 | return UnionDataset(ds_list) if len(ds_list) > 0 else None 246 | 247 | def get_combined_dataloaders(args,settings): 248 | istrain_mode = True 249 | print("Load with train mode :",istrain_mode) 250 | train_labeled = get_combined_dataset('train',settings['train']) 251 | test = get_combined_dataset('test',settings['test']) 252 | return torch.utils.data.DataLoader(train_labeled, batch_size=args.bs, shuffle=istrain_mode, num_workers=workers,pin_memory=True,drop_last = use_droplast) if train_labeled is not None else None,\ 253 | torch.utils.data.DataLoader(test, batch_size=args.bs, shuffle=False, num_workers=test_workers,pin_memory=args.gpu != 'cpu') if test is not None else None 254 | 255 | ds_classnum_dict = { 256 | 'cifar10' : 6, 257 | 'svhn' : 6, 258 | 'tinyimagenet' : 20, 259 | "imagenet" : 1000, 260 | } 261 | 262 | imgsize_dict = { 263 | 'cifar10' : 32, 264 | 'svhn' : 32, 265 | 'tinyimagenet' : 64, 266 | "imagenet" : 224, 267 | } 268 | 269 | def load_partitioned_dataset(args,ds): 270 | with open(ds,'r') as f: 271 | settings = json.load(f) 272 | util.img_size = imgsize_dict[settings['name']] 273 | a,b = get_combined_dataloaders(args,settings) 274 | return a,b,ds_classnum_dict[settings['name']] 275 | 276 | -------------------------------------------------------------------------------- /exps/cifar10/spl_a.json: -------------------------------------------------------------------------------- 1 | { 2 | "name" : "cifar10", 3 | "train" : [ 4 | {"dataset":"cifar10","convert_class":[-1,-1,0,1,2,3,4,5,-1,-1],"keep_class":[0,1,2,3,4,5]} 5 | ], 6 | "test" : [ 7 | {"dataset":"cifar10","convert_class":[-1,-1,0,1,2,3,4,5,-1,-1],"keep_class":[-1,0,1,2,3,4,5]} 8 | ] 9 | } 
-------------------------------------------------------------------------------- /exps/cifar10/spl_b.json: -------------------------------------------------------------------------------- 1 | { 2 | "name" : "cifar10", 3 | "train" : [ 4 | {"dataset":"cifar10","convert_class":[-1,0,1,-1,2,-1,3,4,5,-1],"keep_class":[0,1,2,3,4,5]} 5 | ], 6 | "test" : [ 7 | {"dataset":"cifar10","convert_class":[-1,0,1,-1,2,-1,3,4,5,-1],"keep_class":[-1,0,1,2,3,4,5]} 8 | ] 9 | } -------------------------------------------------------------------------------- /exps/cifar10/spl_c.json: -------------------------------------------------------------------------------- 1 | { 2 | "name" : "cifar10", 3 | "train" : [ 4 | {"dataset":"cifar10","convert_class":[0,-1,-1,1,2,3,4,-1,5,-1],"keep_class":[0,1,2,3,4,5]} 5 | ], 6 | "test" : [ 7 | {"dataset":"cifar10","convert_class":[0,-1,-1,1,2,3,4,-1,5,-1],"keep_class":[-1,0,1,2,3,4,5]} 8 | ] 9 | } -------------------------------------------------------------------------------- /exps/cifar10/spl_d.json: -------------------------------------------------------------------------------- 1 | { 2 | "name" : "cifar10", 3 | "train" : [ 4 | {"dataset":"cifar10","convert_class":[0,1,-1,-1,2,-1,3,-1,4,5],"keep_class":[0,1,2,3,4,5]} 5 | ], 6 | "test" : [ 7 | {"dataset":"cifar10","convert_class":[0,1,-1,-1,2,-1,3,-1,4,5],"keep_class":[-1,0,1,2,3,4,5]} 8 | ] 9 | } -------------------------------------------------------------------------------- /exps/cifar10/spl_e.json: -------------------------------------------------------------------------------- 1 | { 2 | "name" : "cifar10", 3 | "train" : [ 4 | {"dataset":"cifar10","convert_class":[0,-1,1,2,-1,3,4,5,-1,-1],"keep_class":[0,1,2,3,4,5]} 5 | ], 6 | "test" : [ 7 | {"dataset":"cifar10","convert_class":[0,-1,1,2,-1,3,4,5,-1,-1],"keep_class":[-1,0,1,2,3,4,5]} 8 | ] 9 | } -------------------------------------------------------------------------------- /exps/imagenet/vs_inaturalist.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "imagenet", 3 | "train": [{ 4 | "dataset": "imagenet", 5 | "convert_class": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 
386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 
786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999], 6 | "keep_class": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 
204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 
604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999] 7 | }], 8 | "test": 
[{ 9 | "dataset": "imagenet", 10 | "convert_class": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 
411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 
811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999], 11 | "keep_class": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 
228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 
628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999] 12 | }, { 13 | "dataset": "oodinaturalist", 14 | "convert_class": [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], 15 | "keep_class": [-1] 16 | }] 17 | } 
-------------------------------------------------------------------------------- /exps/svhn/spl_a.json: -------------------------------------------------------------------------------- 1 | { 2 | "name" : "svhn", 3 | "train" : [ 4 | {"dataset":"svhn","convert_class":[-1,-1,0,1,2,3,4,5,-1,-1],"keep_class":[0,1,2,3,4,5]} 5 | ], 6 | "test" : [ 7 | {"dataset":"svhn","convert_class":[-1,-1,0,1,2,3,4,5,-1,-1],"keep_class":[-1,0,1,2,3,4,5]} 8 | ] 9 | } -------------------------------------------------------------------------------- /exps/svhn/spl_b.json: -------------------------------------------------------------------------------- 1 | { 2 | "name" : "svhn", 3 | "train" : [ 4 | {"dataset":"svhn","convert_class":[-1,0,1,-1,2,-1,3,4,5,-1],"keep_class":[0,1,2,3,4,5]} 5 | ], 6 | "test" : [ 7 | {"dataset":"svhn","convert_class":[-1,0,1,-1,2,-1,3,4,5,-1],"keep_class":[-1,0,1,2,3,4,5]} 8 | ] 9 | } -------------------------------------------------------------------------------- /exps/svhn/spl_c.json: -------------------------------------------------------------------------------- 1 | { 2 | "name" : "svhn", 3 | "train" : [ 4 | {"dataset":"svhn","convert_class":[0,-1,-1,1,2,3,4,-1,5,-1],"keep_class":[0,1,2,3,4,5]} 5 | ], 6 | "test" : [ 7 | {"dataset":"svhn","convert_class":[0,-1,-1,1,2,3,4,-1,5,-1],"keep_class":[-1,0,1,2,3,4,5]} 8 | ] 9 | } -------------------------------------------------------------------------------- /exps/svhn/spl_d.json: -------------------------------------------------------------------------------- 1 | { 2 | "name" : "svhn", 3 | "train" : [ 4 | {"dataset":"svhn","convert_class":[0,1,-1,-1,2,-1,3,-1,4,5],"keep_class":[0,1,2,3,4,5]} 5 | ], 6 | "test" : [ 7 | {"dataset":"svhn","convert_class":[0,1,-1,-1,2,-1,3,-1,4,5],"keep_class":[-1,0,1,2,3,4,5]} 8 | ] 9 | } -------------------------------------------------------------------------------- /exps/svhn/spl_e.json: -------------------------------------------------------------------------------- 1 | { 2 | "name" : 
"svhn", 3 | "train" : [ 4 | {"dataset":"svhn","convert_class":[0,-1,1,2,-1,3,4,5,-1,-1],"keep_class":[0,1,2,3,4,5]} 5 | ], 6 | "test" : [ 7 | {"dataset":"svhn","convert_class":[0,-1,1,2,-1,3,4,5,-1,-1],"keep_class":[-1,0,1,2,3,4,5]} 8 | ] 9 | } -------------------------------------------------------------------------------- /exps/tinyimagenet/spl_a.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tinyimagenet", 3 | "train": [{ 4 | "dataset": "tinyimagenet", 5 | "convert_class": [0, -1, -1, -1, -1, -1, -1, 1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 3, -1, 4, -1, -1, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1, -1, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, 9, -1, -1, -1, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 11, -1, -1, -1, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 13, -1, -1, -1, -1, -1, -1, -1, 14, -1, 15, -1, -1, -1, -1, 16, -1, -1, -1, -1, -1, -1, -1, 17, -1, 18, -1, -1, -1, -1, -1, -1, 19, -1, -1, -1, -1], 6 | "keep_class": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 7 | }], 8 | "test": [{ 9 | "dataset": "tinyimagenet", 10 | "convert_class": [0, -1, -1, -1, -1, -1, -1, 1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 3, -1, 4, -1, -1, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1, -1, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, 9, -1, -1, -1, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 11, -1, -1, -1, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 13, -1, -1, -1, -1, -1, -1, -1, 14, -1, 15, -1, -1, -1, -1, 16, -1, -1, -1, -1, -1, -1, -1, 17, -1, 18, -1, -1, -1, -1, -1, -1, 19, -1, -1, -1, -1], 11 | "keep_class": [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 12 | }] 13 | } -------------------------------------------------------------------------------- /exps/tinyimagenet/spl_b.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tinyimagenet", 3 | "train": [{ 4 | "dataset": "tinyimagenet", 5 | "convert_class": [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2, -1, 3, 4, -1, -1, -1, -1, -1, -1, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 6, -1, -1, 7, -1, -1, -1, -1, 8, -1, -1, -1, -1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 10, -1, -1, -1, -1, -1, -1, -1, 11, -1, -1, 12, -1, -1, -1, -1, -1, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 14, -1, -1, -1, -1, -1, -1, -1, -1, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 16, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 17, -1, -1, -1, -1, -1, -1, 18, -1, -1, -1, -1, 19, -1, -1, -1, -1, -1, -1], 6 | "keep_class": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 7 | }], 8 | "test": [{ 9 | "dataset": "tinyimagenet", 10 | "convert_class": [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2, -1, 3, 4, -1, -1, -1, -1, -1, -1, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 6, -1, -1, 7, -1, -1, -1, -1, 8, -1, -1, -1, -1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
10, -1, -1, -1, -1, -1, -1, -1, 11, -1, -1, 12, -1, -1, -1, -1, -1, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 14, -1, -1, -1, -1, -1, -1, -1, -1, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 16, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 17, -1, -1, -1, -1, -1, -1, 18, -1, -1, -1, -1, 19, -1, -1, -1, -1, -1, -1], 11 | "keep_class": [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 12 | }] 13 | } -------------------------------------------------------------------------------- /exps/tinyimagenet/spl_c.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tinyimagenet", 3 | "train": [{ 4 | "dataset": "tinyimagenet", 5 | "convert_class": [-1, -1, -1, -1, -1, -1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2, -1, -1, -1, 3, -1, -1, -1, 4, -1, -1, -1, -1, -1, 5, -1, 6, 7, -1, 8, -1, 9, -1, 10, 11, -1, -1, -1, -1, -1, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 14, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 15, -1, -1, -1, -1, -1, -1, -1, -1, 16, -1, -1, -1, -1, -1, -1, -1, -1, 17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 18, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 19], 6 | "keep_class": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 7 | }], 8 | "test": [{ 9 | "dataset": "tinyimagenet", 10 | "convert_class": [-1, -1, -1, -1, -1, -1, -1, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2, -1, -1, -1, 3, -1, -1, 
-1, 4, -1, -1, -1, -1, -1, 5, -1, 6, 7, -1, 8, -1, 9, -1, 10, 11, -1, -1, -1, -1, -1, 12, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 14, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 15, -1, -1, -1, -1, -1, -1, -1, -1, 16, -1, -1, -1, -1, -1, -1, -1, -1, 17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 18, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 19], 11 | "keep_class": [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 12 | }] 13 | } -------------------------------------------------------------------------------- /exps/tinyimagenet/spl_d.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tinyimagenet", 3 | "train": [{ 4 | "dataset": "tinyimagenet", 5 | "convert_class": [-1, -1, 0, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2, -1, -1, -1, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 6, -1, -1, 7, -1, 8, -1, -1, -1, -1, -1, -1, -1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 11, -1, 12, -1, -1, 13, -1, -1, -1, -1, -1, 14, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 15, -1, -1, -1, -1, 16, -1, -1, -1, -1, 17, -1, -1, -1, -1, -1, -1, -1, -1, 18, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 19], 6 | "keep_class": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 7 | }], 8 | "test": [{ 9 | "dataset": "tinyimagenet", 10 | "convert_class": [-1, -1, 0, 1, -1, -1, -1, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2, -1, -1, -1, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 6, -1, -1, 7, -1, 8, -1, -1, -1, -1, -1, -1, -1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 11, -1, 12, -1, -1, 13, -1, -1, -1, -1, -1, 14, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 15, -1, -1, -1, -1, 16, -1, -1, -1, -1, 17, -1, -1, -1, -1, -1, -1, -1, -1, 18, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 19], 11 | "keep_class": [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 12 | }] 13 | } -------------------------------------------------------------------------------- /exps/tinyimagenet/spl_e.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "tinyimagenet", 3 | "train": [{ 4 | "dataset": "tinyimagenet", 5 | "convert_class": [-1, -1, 0, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 3, -1, -1, -1, -1, -1, -1, -1, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 5, -1, -1, -1, -1, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 9, -1, -1, 10, -1, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 14, -1, -1, 15, -1, -1, -1, -1, 16, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 17, -1, -1, -1, -1, -1, 18, 19, -1, -1, -1, -1, -1, -1, -1, -1], 6 | "keep_class": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 
10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 7 | }], 8 | "test": [{ 9 | "dataset": "tinyimagenet", 10 | "convert_class": [-1, -1, 0, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 3, -1, -1, -1, -1, -1, -1, -1, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 5, -1, -1, -1, -1, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 9, -1, -1, 10, -1, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 14, -1, -1, 15, -1, -1, -1, -1, 16, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 17, -1, -1, -1, -1, -1, 18, 19, -1, -1, -1, -1, -1, -1, -1, -1], 11 | "keep_class": [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 12 | }] 13 | } -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import argparse 4 | 5 | import dataset 6 | import json 7 | import metrics 8 | import methods.cssr 9 | import methods.cssr_ft 10 | 11 | from methods import * 12 | import os 13 | import sys 14 | import methods.util as util 15 | 16 | import warnings 17 | 18 | warnings.filterwarnings('ignore') 19 | 20 | 21 | def save_everything(subfix = ""): 22 | # save model 23 | if subfix == "": 24 | mth.save_model(saving_path + 'model.pth') 25 | # save training process data 26 | with open(saving_path + "hist.json",'w') as f: 27 | json.dump(history, f) 28 | 29 | def load_everything(subfix = ""): 30 | global history,best_auroc,best_acc 31 | # load the model 32 | if subfix != "" : 33 | mth.load_model(saving_path + f'model_{subfix}.pth') 34 | else: 35 | mth.load_model(saving_path + 
'model.pth') 36 | # load history 37 | history = np.load(saving_path + 'hist.npy',allow_pickle=True).tolist() 38 | 39 | bac,bau = get_best_acc_auc() 40 | best_acc = bac[1] 41 | best_auroc = bau[1] 42 | 43 | def log_history(epoch,data_dict): 44 | item = { 45 | 'epoch' : epoch 46 | } 47 | item.update(data_dict) 48 | if isinstance(history,list): 49 | history.append(item) 50 | print(f"Epoch:{epoch}") 51 | for key in data_dict.keys(): 52 | print(" ",key,":",data_dict[key]) 53 | 54 | best_acc = -1 55 | best_auroc = -1 56 | last_acc = -1 57 | last_auroc = -1 58 | last_f1 = -1 59 | cwauc = -1 60 | 61 | def training_main(): 62 | tot_epoch = config['epoch_num'] 63 | global best_acc,best_auroc 64 | 65 | for epoch in range(mth.epoch,tot_epoch): 66 | sys.stdout.flush() 67 | losses = mth.train_epoch() 68 | acc = 0 69 | auroc = 0 70 | if epoch % 1 == 0: 71 | save_everything(f'ckpt{epoch}') 72 | 73 | if epoch % test_interval == test_interval - 1 : 74 | # big test with aurocs 75 | scores,thresh,pred = mth.knownpred_unknwonscore_test(test_loader) 76 | acc = evaluation.close_accuracy(pred) 77 | open_detection = evaluation.open_detection_indexes(scores,thresh) 78 | auroc = open_detection['auroc'] 79 | log_history(epoch,{ 80 | "loss" : losses, 81 | "close acc" : acc, 82 | "open_detection" : open_detection, 83 | "open_reco" : evaluation.open_reco_indexes(scores,thresh,pred) 84 | }) 85 | else: 86 | # close_pred = mth.known_prediction_test(train_labeled_loader,train_unlabeled_loader,test_loader) 87 | # acc = evaluation.close_accuracy(close_pred) 88 | log_history(epoch,{ 89 | "loss" : losses, 90 | # "close acc" : acc, 91 | }) 92 | # if epoch % 10 == 0: 93 | save_everything() 94 | if acc > best_acc: 95 | best_acc = acc 96 | save_everything("acc") 97 | if auroc > best_auroc: 98 | best_auroc = auroc 99 | save_everything("auroc") 100 | 101 | def get_best_acc_auc(): 102 | best_auc,best_acc = [0,0],[0,0] 103 | for itm in history: 104 | epoch = itm['epoch'] 105 | if 'close acc' in itm.keys(): 106 
| acc = itm['close acc'] 107 | if acc > best_acc[1]: 108 | best_acc = [epoch,acc] 109 | if not 'open_detection' in itm.keys(): 110 | continue 111 | auc = itm['open_detection']['auroc'] 112 | if auc > best_auc[1]: 113 | best_auc = [epoch,auc] 114 | return best_acc,best_auc 115 | 116 | def overall_testing(): 117 | global train_loader,test_loader 118 | global last_acc,last_auroc,last_f1,cwauc,best_acc,best_auroc 119 | 120 | scores,thresh,pred = mth.knownpred_unknwonscore_test(test_loader) 121 | last_acc = evaluation.close_accuracy(pred) 122 | indexes = evaluation.open_detection_indexes(scores,thresh) 123 | last_auroc = indexes['auroc'] 124 | osr_indexes = evaluation.open_reco_indexes(scores,thresh,pred) 125 | last_f1 = osr_indexes['macro_f1'] 126 | log_history(-1,{ 127 | "close acc" : last_acc, 128 | "open_detection" :indexes, 129 | "open_reco" : osr_indexes 130 | }) 131 | print("Metrics", {\ 132 | "close acc" : last_acc, 133 | "open_detection" :indexes, 134 | "open_reco" : osr_indexes}) 135 | 136 | 137 | 138 | def update_config_keyvalues(config,update): 139 | if update == "": 140 | return config 141 | spls = update.split(",") 142 | for spl in spls: 143 | key,val = spl.split(':') 144 | key_parts = key.split('.') 145 | sconfig = config 146 | for i in range(len(key_parts) - 1): 147 | sconfig = sconfig[key_parts[i]] 148 | org = sconfig[key_parts[-1]] 149 | if isinstance(org,bool): 150 | sconfig[key_parts[-1]] = val == 'True' 151 | elif isinstance(org,int): 152 | sconfig[key_parts[-1]] = int(val) 153 | elif isinstance(org,float): 154 | sconfig[key_parts[-1]] = float(val) 155 | else: 156 | sconfig[key_parts[-1]] = val 157 | print("Updating",key,"with",val,"results in",sconfig[key_parts[-1]]) 158 | return config 159 | 160 | def update_subconfig(cfg,u): 161 | for k in u.keys(): 162 | if not k in cfg.keys() or not isinstance(cfg[k],dict): 163 | cfg[k] = u[k] 164 | else: 165 | update_subconfig(cfg[k],u[k]) 166 | 167 | def load_config(file): 168 | with open(file,"r") as f : 169 
| config = json.load(f) 170 | if 'inherit' in config.keys(): 171 | inheritfile = config['inherit'] 172 | if inheritfile != 'None': 173 | parent = load_config(inheritfile) 174 | update_subconfig(parent,config) 175 | config = parent 176 | return config 177 | 178 | def set_up_gpu(args): 179 | if args.gpu != 'cpu': 180 | args.gpu = ",".join([c for c in args.gpu]) 181 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 182 | 183 | if __name__ == "__main__": 184 | import torch 185 | torch.backends.cudnn.benchmark = True 186 | 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument('--gpu', type=str, required=False,default="1", help='GPU number') 189 | parser.add_argument('--ds', type=str, required=False,default="None", help='dataset setting, choose file from ./exps') 190 | parser.add_argument('--config', type=str, required=False,default="None", help='model configuration, choose from ./configs') 191 | parser.add_argument('--save', type=str, required=False,default="None", help='Saving folder name') 192 | parser.add_argument('--method', type=str, required=False,default="ours", help='Methods : ' + ",".join(util.method_list.keys())) 193 | parser.add_argument('--test', action="store_true",help='Evaluation mode') 194 | parser.add_argument('--configupdate', type=str, required=False,default="", help='Update several key values in config') 195 | parser.add_argument('--test_interval', type=int, required=False,default=1, help='The frequency of model evaluation') 196 | 197 | args = parser.parse_args() 198 | 199 | test_interval = args.test_interval 200 | if not args.save.endswith("/"): 201 | args.save += "/" 202 | 203 | set_up_gpu(args) 204 | 205 | saving_path = "./save/" + args.save 206 | util.setup_dir(saving_path) 207 | 208 | if args.config != "None" : 209 | config = load_config(args.config) 210 | else: 211 | config = {} 212 | config = update_config_keyvalues(config,args.configupdate) 213 | args.bs = config['batch_size'] 214 | print('Config:',config) 215 | 216 | train_loader 
, test_loader ,classnum = dataset.load_partitioned_dataset(args,args.ds) 217 | mth = util.method_list[args.method](config,classnum,train_loader.dataset) 218 | 219 | history = [] 220 | evaluation = metrics.OSREvaluation(test_loader) 221 | 222 | if not args.test: 223 | print(f"TotalEpochs:{config['epoch_num']}") 224 | training_main() 225 | save_everything() 226 | overall_testing() 227 | print("Overall: LastAcc",last_acc," LastAuroc", last_auroc," BestAcc",best_acc," BestAuroc",best_auroc,"CWAuroc",cwauc) 228 | else: 229 | load_everything() 230 | overall_testing() 231 | 232 | -------------------------------------------------------------------------------- /methods/augtools.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from 2 | # https://github.com/kekmodel/FixMatch-pytorch/blob/master/dataset/randaugment.py 3 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 4 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py 5 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py 6 | import logging 7 | import random 8 | 9 | 10 | import numpy as np 11 | import PIL 12 | import PIL.ImageOps 13 | import PIL.ImageEnhance 14 | import PIL.ImageDraw 15 | from PIL import Image 16 | import methods.util as util 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | PARAMETER_MAX = 10 21 | 22 | 23 | def AutoContrast(img, **kwarg): 24 | return PIL.ImageOps.autocontrast(img) 25 | 26 | 27 | def Brightness(img, v, max_v, bias=0): 28 | v = _float_parameter(v, max_v) + bias 29 | return PIL.ImageEnhance.Brightness(img).enhance(v) 30 | 31 | 32 | def Color(img, v, max_v, bias=0): 33 | v = _float_parameter(v, max_v) + bias 34 | return PIL.ImageEnhance.Color(img).enhance(v) 35 | 36 | 37 | def Contrast(img, v, max_v, bias=0): 38 | v = _float_parameter(v, max_v) + bias 39 | return 
PIL.ImageEnhance.Contrast(img).enhance(v) 40 | 41 | 42 | def Cutout(img, v, max_v, bias=0): 43 | if v == 0: 44 | return img 45 | v = _float_parameter(v, max_v) + bias 46 | v = int(v * min(img.size)) 47 | return CutoutAbs(img, v) 48 | 49 | 50 | def CutoutAbs(img, v, **kwarg): 51 | w, h = img.size 52 | x0 = np.random.uniform(0, w) 53 | y0 = np.random.uniform(0, h) 54 | x0 = int(max(0, x0 - v / 2.)) 55 | y0 = int(max(0, y0 - v / 2.)) 56 | x1 = int(min(w, x0 + v)) 57 | y1 = int(min(h, y0 + v)) 58 | xy = (x0, y0, x1, y1) 59 | # gray 60 | color = (127, 127, 127) 61 | img = img.copy() 62 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 63 | return img 64 | 65 | 66 | def Equalize(img, **kwarg): 67 | return PIL.ImageOps.equalize(img) 68 | 69 | 70 | def Identity(img, **kwarg): 71 | return img 72 | 73 | 74 | def Invert(img, **kwarg): 75 | return PIL.ImageOps.invert(img) 76 | 77 | 78 | def Posterize(img, v, max_v, bias=0): 79 | v = _int_parameter(v, max_v) + bias 80 | return PIL.ImageOps.posterize(img, v) 81 | 82 | 83 | def Rotate(img, v, max_v, bias=0): 84 | v = _int_parameter(v, max_v) + bias 85 | if random.random() < 0.5 and max_v > 0: 86 | v = -v 87 | return img.rotate(v) 88 | 89 | 90 | def Sharpness(img, v, max_v, bias=0): 91 | v = _float_parameter(v, max_v) + bias 92 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 93 | 94 | 95 | def ShearX(img, v, max_v, bias=0): 96 | v = _float_parameter(v, max_v) + bias 97 | if random.random() < 0.5 and max_v > 0: 98 | v = -v 99 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 100 | 101 | 102 | def ShearY(img, v, max_v, bias=0): 103 | v = _float_parameter(v, max_v) + bias 104 | if random.random() < 0.5 and max_v > 0: 105 | v = -v 106 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 107 | 108 | 109 | def Solarize(img, v, max_v, bias=0): 110 | v = _int_parameter(v, max_v) + bias 111 | return PIL.ImageOps.solarize(img, 256 - v) 112 | 113 | 114 | def SolarizeAdd(img, v, max_v, bias=0, 
threshold=128): 115 | v = _int_parameter(v, max_v) + bias 116 | if random.random() < 0.5 and max_v > 0: 117 | v = -v 118 | img_np = np.array(img).astype(np.int) 119 | img_np = img_np + v 120 | img_np = np.clip(img_np, 0, 255) 121 | img_np = img_np.astype(np.uint8) 122 | img = Image.fromarray(img_np) 123 | return PIL.ImageOps.solarize(img, threshold) 124 | 125 | 126 | def TranslateX(img, v, max_v, bias=0): 127 | v = _float_parameter(v, max_v) + bias 128 | if random.random() < 0.5 and max_v > 0: 129 | v = -v 130 | v = int(v * img.size[0]) 131 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 132 | 133 | 134 | def TranslateY(img, v, max_v, bias=0): 135 | v = _float_parameter(v, max_v) + bias 136 | if random.random() < 0.5 and max_v > 0: 137 | v = -v 138 | v = int(v * img.size[1]) 139 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 140 | 141 | 142 | def _float_parameter(v, max_v): 143 | return float(v) * max_v / PARAMETER_MAX 144 | 145 | 146 | def _int_parameter(v, max_v): 147 | return int(v * max_v / PARAMETER_MAX) 148 | 149 | 150 | def fixmatch_augment_pool(): 151 | # FixMatch paper 152 | augs = [(AutoContrast, None, None), 153 | (Brightness, 0.9, 0.05), 154 | (Color, 0.9, 0.05), 155 | (Contrast, 0.9, 0.05), 156 | (Equalize, None, None), 157 | (Identity, None, None), 158 | (Posterize, 4, 4), 159 | (Rotate, 30, 0), 160 | (Sharpness, 0.9, 0.05), 161 | (ShearX, 0.3, 0), 162 | (ShearY, 0.3, 0), 163 | (Solarize, 256, 0), 164 | (TranslateX, 0.3, 0), 165 | (TranslateY, 0.3, 0)] 166 | return augs 167 | 168 | 169 | 170 | class RandAugmentMC(object): 171 | def __init__(self, n, m, useCutout = True): 172 | assert n >= 1 173 | assert 1 <= m <= 10 174 | self.n = n 175 | self.m = m 176 | self.augment_pool = fixmatch_augment_pool() 177 | self.cutout = useCutout 178 | 179 | def __call__(self, img): 180 | ops = random.choices(self.augment_pool, k=self.n) 181 | for op, max_v, bias in ops: 182 | v = np.random.randint(1, self.m) 183 | if 
random.random() < 0.5: 184 | img = op(img, v=v, max_v=max_v, bias=bias) 185 | if self.cutout: 186 | img = CutoutAbs(img, int(util.img_size*0.5)) 187 | return img 188 | 189 | class CustomizeAugment(object): 190 | 191 | def __init__(self, n, m, pool,cutout_ratio = 0.25): 192 | assert n >= 1 193 | assert 1 <= m <= 10 194 | self.n = n 195 | self.m = m 196 | self.augment_pool = pool 197 | self.cutout_ratio = cutout_ratio 198 | 199 | def __call__(self, img): 200 | ops = random.choices(self.augment_pool, k=self.n) 201 | for op, max_v, bias in ops: 202 | v = np.random.randint(1, self.m) 203 | if random.random() < 0.5: 204 | img = op(img, v=v, max_v=max_v, bias=bias) 205 | # img = CutoutAbs(img, int(util.img_size * self.cutout_ratio)) 206 | return img 207 | 208 | class GaussianBlur(object): 209 | """blur a single image on CPU""" 210 | def __init__(self, kernel_size): 211 | self.radius = kernel_size // 2 212 | 213 | def __call__(self, img): 214 | act_radius = np.random.randint(0,self.radius) 215 | if act_radius > 0: 216 | img = img.filter(PIL.ImageFilter.GaussianBlur(radius=act_radius)) 217 | return img 218 | 219 | class CutoutTrans(object): 220 | 221 | def __init__(self,imgsize) -> None: 222 | super().__init__() 223 | self.size = int(imgsize * 0.5) 224 | 225 | def __call__(self, img): 226 | return CutoutAbs(img,self.size) 227 | 228 | class HighlyCustomizableAugment(object): 229 | 230 | def __init__(self,n,m,numcls,orgds,config) -> None: 231 | super().__init__() 232 | components = { 233 | 'AutoContrast' : [(AutoContrast, None, None)], 234 | 'BrightnessDark' : [(Brightness, 0.9, 0.05)], 235 | 'BrightnessLight' : [(Brightness, 0.9, 1.05)], 236 | 'BrightnessOverall' : [(Brightness, 1.8, 0.1)], 237 | 'Color' : [(Color, 0.9, 0.05)], 238 | 'ContrastLow' : [(Contrast, 0.9, 0.05)], 239 | 'ContrastHigh' : [(Contrast, 0.9, 1.05)], 240 | 'ContrastOverall' : [(Contrast, 1.9, 0.05)], 241 | 'Equalize' : [(Equalize, None, None)], 242 | 'Identity' : [(Identity, None, None)], 243 | 
'Posterize' : [(Posterize, 4, 4)], 244 | 'Rotate' : [(Rotate, 15, 0)], 245 | 'Sharpness' : [(Sharpness, 0.9, 0.05)], 246 | 'SharpnessLarge' : [(Sharpness, 1.8, 0.05)], 247 | 'Shear' : [(ShearX, 0.3, 0),(ShearY, 0.3, 0)], 248 | 'Solarize' : [(Solarize, 256, 0)], 249 | 'Translate' : [(TranslateX, 0.3, 0),(TranslateY, 0.3, 0)] 250 | } 251 | self.orgds = orgds 252 | cfg_pool = config['customize_augment_pool'] 253 | pool = [] 254 | for k in cfg_pool.keys(): 255 | if cfg_pool[k] : 256 | pool += components[k] 257 | self.pool = pool 258 | self.postprc = { 259 | 'none' : lambda x : x, 260 | 'cutout' : self.postproc_cutout, 261 | 'mixup' : self.postproc_mixup, 262 | 'cutmix' : self.postproc_cutmix 263 | }[config['customize_augment_postprocess']] 264 | self.n = min(n,len(pool)) 265 | self.m = m 266 | self.numcls = numcls 267 | 268 | def get_onehot(self,y): 269 | r = np.zeros([self.numcls]) 270 | r[y] = 1 271 | return r 272 | 273 | def postproc_cutout(self,x): 274 | x = CutoutAbs(x, int(util.img_size*0.5)) 275 | return x 276 | 277 | def postproc_mixup(self,x,y): 278 | x1,y1 = self.orgds[random.randint(0,len(self.orgds)-1)] 279 | y1 = self.get_onehot(y1) 280 | lam = np.random.uniform(0,1) 281 | x =Image.blend(x1,x,lam) 282 | return x, y * lam + y1 * (1-lam) 283 | 284 | 285 | def postproc_cutmix(self,x,y): 286 | pass 287 | 288 | def __call__(self, img): 289 | ops = random.choices(self.pool, k=self.n) 290 | # print(ops) 291 | for op, max_v, bias in ops: 292 | v = np.random.randint(1, self.m) 293 | if random.random() < 0.5: 294 | img = op(img, v=v, max_v=max_v, bias=bias) 295 | return self.postprc(img) 296 | 297 | -------------------------------------------------------------------------------- /methods/cssr.py: -------------------------------------------------------------------------------- 1 | import methods.wideresnet as wideresnet 2 | from methods.augtools import HighlyCustomizableAugment, RandAugmentMC 3 | import methods.util as util 4 | import tqdm 5 | import numpy as np 6 | 
import torch 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | import torch.utils.data as data 10 | import random 11 | from methods.util import AverageMeter 12 | import time 13 | from torchvision.transforms import transforms 14 | from methods.resnet import ResNet 15 | from torchvision import models as torchvision_models 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | 20 | def conv1x1(in_planes, out_planes, stride=1): 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 22 | 23 | 24 | class GramRecoder(nn.Module): 25 | 26 | def __init__(self): 27 | super().__init__() 28 | self.gram_feats = [] 29 | self.collecting = False 30 | 31 | def begin_collect(self,): 32 | self.gram_feats.clear() 33 | self.collecting = True 34 | # print("begin collect") 35 | 36 | def record(self,ft): 37 | if self.collecting: 38 | self.gram_feats.append(ft) 39 | # print("record") 40 | 41 | def obtain_gram_feats(self,): 42 | tmp = self.gram_feats 43 | self.collecting = False 44 | self.gram_feats = [] 45 | # print("record") 46 | return tmp 47 | 48 | 49 | class PretrainedResNet(nn.Module): 50 | 51 | def __init__(self,rawname,pretrain_path = None) -> None: 52 | super().__init__() 53 | if pretrain_path == 'default': 54 | self.model = torchvision_models.__dict__[rawname](pretrained = True) 55 | self.output_dim = self.model.fc.weight.shape[1] 56 | self.model.fc = nn.Identity() 57 | else: 58 | self.model = torchvision_models.__dict__[rawname]() 59 | self.output_dim = self.model.fc.weight.shape[1] 60 | self.model.fc = nn.Identity() 61 | if pretrain_path is not None: 62 | sd = torch.load(pretrain_path) 63 | self.model.load_state_dict(sd,strict = True) 64 | 65 | def forward(self,x): 66 | x = self.model.conv1(x) 67 | x = self.model.bn1(x) 68 | x = self.model.relu(x) 69 | x = self.model.maxpool(x) 70 | 71 | x = self.model.layer1(x) 72 | x 
= self.model.layer2(x) 73 | x = self.model.layer3(x) 74 | x = self.model.layer4(x) 75 | return x 76 | 77 | 78 | class Backbone(nn.Module): 79 | 80 | def __init__(self,config,inchan): 81 | super().__init__() 82 | 83 | if config['backbone'] == 'wideresnet28-2': 84 | self.backbone = wideresnet.WideResNetBackbone(None,28,2,0,config['category_model']['projection_dim']) 85 | elif config['backbone'] == 'wideresnet40-4': 86 | self.backbone = wideresnet.WideResNetBackbone(None,40,4,0,config['category_model']['projection_dim']) 87 | elif config['backbone'] == 'wideresnet16-8': 88 | self.backbone = wideresnet.WideResNetBackbone(None,16,8,0.4,config['category_model']['projection_dim']) 89 | elif config['backbone'] == 'wideresnet28-10': 90 | self.backbone = wideresnet.WideResNetBackbone(None,28,10,0.3,config['category_model']['projection_dim']) 91 | elif config['backbone'] == 'resnet18': 92 | self.backbone = ResNet(output_dim=config['category_model']['projection_dim'],inchan = inchan) 93 | elif config['backbone'] == 'resnet18a': 94 | self.backbone = ResNet(output_dim=config['category_model']['projection_dim'],resfirststride=2,inchan = inchan) 95 | elif config['backbone'] == 'resnet18b': 96 | self.backbone = ResNet(output_dim=config['category_model']['projection_dim'],resfirststride=2,inchan = inchan) 97 | elif config['backbone'] == 'resnet34': 98 | self.backbone = ResNet(output_dim=config['category_model']['projection_dim'],num_block=[3,4,6,3],inchan=inchan) 99 | elif config['backbone'] in ['prt_r18','prt_r34','prt_r50']: 100 | self.backbone = PretrainedResNet( 101 | {'prt_r18':'resnet18','prt_r34':'resnet34','prt_r50':'resnet50'}[config['backbone']]) 102 | elif config['backbone'] in ['prt_pytorchr18','prt_pytorchr34','prt_pytorchr50']: 103 | name,path = { 104 | 'prt_pytorchr18':('resnet18','default'), 105 | 'prt_pytorchr34':('resnet34','default'), 106 | 'prt_pytorchr50':('resnet50','default') 107 | }[config['backbone']] 108 | self.backbone = PretrainedResNet(name,path) 109 | 
elif config['backbone'] in ['prt_dinor18','prt_dinor34','prt_dinor50']: 110 | name,path = { 111 | 'prt_dinor50':('resnet50','./model_weights/dino_resnet50_pretrain.pth') 112 | }[config['backbone']] 113 | self.backbone = PretrainedResNet(name,path) 114 | else: 115 | bkb = config['backbone'] 116 | raise Exception(f'Backbone \"{bkb}\" is not defined.') 117 | 118 | # types : ae_softmax_avg , ae_avg_softmax , avg_ae_softmax 119 | self.output_dim = self.backbone.output_dim 120 | # self.classifier = CRFClassifier(self.backbone.output_dim,numclss,config) 121 | 122 | def forward(self,x): 123 | x = self.backbone(x) 124 | # latent , global prob , logits 125 | return x 126 | 127 | 128 | class LinearClassifier(nn.Module): 129 | 130 | def __init__(self,inchannels,num_class, config): 131 | super().__init__() 132 | self.gamma = config['gamma'] 133 | self.cls = nn.Conv2d(inchannels, num_class , 1,padding= 0, bias=False) 134 | 135 | def forward(self,x): 136 | x = self.cls(x) 137 | return x * self.gamma 138 | 139 | 140 | def sim_conv_layer(input_channel,output_channel,kernel_size=1,padding =0,use_activation = True): 141 | if use_activation : 142 | res = nn.Sequential( 143 | nn.Conv2d(input_channel, output_channel, kernel_size,padding= padding, bias=False), 144 | nn.Tanh()) 145 | else: 146 | res = nn.Conv2d(input_channel, output_channel, kernel_size,padding= padding, bias=False) 147 | return res 148 | 149 | 150 | class AutoEncoder(nn.Module): 151 | 152 | def __init__(self,inchannel,hidden_layers,latent_chan): 153 | super().__init__() 154 | layer_block = sim_conv_layer 155 | self.latent_size = latent_chan 156 | if latent_chan > 0: 157 | self.encode_convs = [] 158 | self.decode_convs = [] 159 | for i in range(len(hidden_layers)): 160 | h = hidden_layers[i] 161 | ecv = layer_block(inchannel,h,) 162 | dcv = layer_block(h,inchannel,use_activation = i != 0) 163 | inchannel = h 164 | self.encode_convs.append(ecv) 165 | self.decode_convs.append(dcv) 166 | self.encode_convs = 
nn.ModuleList(self.encode_convs) 167 | self.decode_convs.reverse() 168 | self.decode_convs = nn.ModuleList(self.decode_convs) 169 | self.latent_conv = layer_block(inchannel,latent_chan) 170 | self.latent_deconv = layer_block(latent_chan,inchannel,use_activation = (len(hidden_layers) > 0)) 171 | else: 172 | self.center = nn.Parameter(torch.rand([inchannel,1,1]),True) 173 | 174 | def forward(self,x): 175 | if self.latent_size > 0: 176 | output = x 177 | for cv in self.encode_convs: 178 | output = cv(output) 179 | latent = self.latent_conv(output) 180 | output = self.latent_deconv(latent) 181 | for cv in self.decode_convs: 182 | output = cv(output) 183 | return output,latent 184 | else: 185 | return self.center,self.center 186 | 187 | 188 | class CSSRClassifier(nn.Module): 189 | 190 | def __init__(self,inchannels,num_class, config): 191 | super().__init__() 192 | ae_hidden = config['ae_hidden'] 193 | ae_latent = config['ae_latent'] 194 | self.class_aes = [] 195 | for i in range(num_class): 196 | ae = AutoEncoder(inchannels,ae_hidden,ae_latent) 197 | self.class_aes.append(ae) 198 | self.class_aes = nn.ModuleList(self.class_aes) 199 | # self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 200 | self.useL1 = config['error_measure'] == 'L1' 201 | 202 | self.reduction = -1 if config['model'] == 'pcssr' else 1 203 | self.reduction *= config['gamma'] 204 | 205 | 206 | def ae_error(self,rc,x): 207 | if self.useL1: 208 | # return torch.sum(torch.abs(rc-x) * self.reduction,dim=1,keepdim=True) 209 | return torch.norm(rc - x,p = 1,dim = 1,keepdim=True) * self.reduction 210 | else: 211 | return torch.norm(rc - x,p = 2,dim = 1,keepdim=True) ** 2 * self.reduction 212 | 213 | clip_len = 100 214 | 215 | def forward(self,x): 216 | cls_ers = [] 217 | for i in range(len(self.class_aes)): 218 | rc,lt = self.class_aes[i](x) 219 | cls_er = self.ae_error(rc,x) 220 | if CSSRClassifier.clip_len > 0: 221 | cls_er = torch.clamp(cls_er,-CSSRClassifier.clip_len,CSSRClassifier.clip_len) 222 | 
cls_ers.append(cls_er) 223 | logits = torch.cat(cls_ers,dim=1) 224 | return logits 225 | 226 | def G_p(ob, p): 227 | temp = ob.detach() 228 | temp = temp**p 229 | temp = temp.reshape(temp.shape[0],temp.shape[1],-1) 230 | temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1))))# 231 | temp = temp.reshape([temp.shape[0],-1])#.sum(dim=2) 232 | temp = (temp.sign()*torch.abs(temp)**(1/p)).reshape(temp.shape[0],-1) 233 | 234 | return temp 235 | 236 | 237 | def G_p_pro(ob, p = 8): 238 | temp = ob.detach() 239 | temp = temp**p 240 | temp = temp.reshape(temp.shape[0],temp.shape[1],-1) 241 | temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1))))# 242 | # temp = temp.reshape([temp.shape[0],-1])#.sum(dim=2) 243 | temp = (temp.sign()*torch.abs(temp)**(1/p))#.reshape(temp.shape[0],ob.shape[1],ob.shape[1]) 244 | 245 | return temp 246 | 247 | def G_p_inf(ob,p = 1): 248 | temp = ob.detach() 249 | temp = temp**p 250 | # print(temp.shape) 251 | temp = temp.reshape([temp.shape[0],temp.shape[1],-1]).transpose(dim0=2,dim1=1).reshape([-1,temp.shape[1],1]) 252 | # print(temp.shape) 253 | temp = ((torch.matmul(temp,temp.transpose(dim0=2,dim1=1))))# 254 | temp = (temp.sign()*torch.abs(temp)**(1/p)) 255 | # print(temp.shape) 256 | return temp.reshape(ob.shape[0],ob.shape[2],ob.shape[3],ob.shape[1],ob.shape[1]) 257 | 258 | # import methods.pooling.MPNConv as MPN 259 | 260 | class BackboneAndClassifier(nn.Module): 261 | 262 | def __init__(self,num_classes,config): 263 | super().__init__() 264 | clsblock = {'linear':LinearClassifier,'pcssr':CSSRClassifier,'rcssr' : CSSRClassifier} 265 | self.backbone = Backbone(config,3) 266 | cat_config = config['category_model'] 267 | self.cat_cls = clsblock[cat_config['model']](self.backbone.output_dim,num_classes,cat_config) 268 | 269 | def forward(self,x,feature_only = False): 270 | x = self.backbone(x) 271 | if feature_only: 272 | return x 273 | return x, self.cat_cls(x) 274 | 275 | 276 | class CSSRModel(nn.Module): 277 | 278 | def 
__init__(self,num_classes,config,crt): 279 | super().__init__() 280 | self.crt = crt 281 | 282 | # ------ New Arch 283 | self.backbone_cs = BackboneAndClassifier(num_classes,config) 284 | 285 | self.config = config 286 | self.mins = {i : [] for i in range(num_classes)} 287 | self.maxs = {i : [] for i in range(num_classes)} 288 | self.num_classes = num_classes 289 | 290 | self.avg_feature = [[0,0] for i in range(num_classes)] 291 | self.avg_gram = [[[0,0] for i in range(num_classes)] for i in self.powers] 292 | self.enable_gram = config['enable_gram'] 293 | 294 | def update_minmax(self,feat_list,power = [],ypred = None): 295 | # feat_list = self.gram_feature_list(batch) 296 | for pr in range(self.num_classes): 297 | cond = ypred == pr 298 | if not cond.any(): 299 | continue 300 | for L,feat_L in enumerate(feat_list): 301 | if L==len(self.mins[pr]): 302 | self.mins[pr].append([None]*len(power)) 303 | self.maxs[pr].append([None]*len(power)) 304 | 305 | for p,P in enumerate(power): 306 | g_p = G_p(feat_L[cond],P) 307 | 308 | current_min = g_p.min(dim=0,keepdim=True)[0] 309 | current_max = g_p.max(dim=0,keepdim=True)[0] 310 | 311 | if self.mins[pr][L][p] is None: 312 | self.mins[pr][L][p] = current_min 313 | self.maxs[pr][L][p] = current_max 314 | else: 315 | self.mins[pr][L][p] = torch.min(current_min,self.mins[pr][L][p]) 316 | self.maxs[pr][L][p] = torch.max(current_max,self.maxs[pr][L][p]) 317 | 318 | def get_deviations(self,feat_list,power,ypred): 319 | batch_deviations = None 320 | for pr in range(self.num_classes): 321 | mins,maxs = self.mins[pr],self.maxs[pr] 322 | cls_batch_deviations = [] 323 | cond = ypred==pr 324 | if not cond.any(): 325 | continue 326 | for L,feat_L in enumerate(feat_list): 327 | dev = 0 328 | for p,P in enumerate(power): 329 | g_p = G_p(feat_L[cond],P) 330 | # print(L,len(mins)) 331 | # print(p,len(mins[L])) 332 | dev += (F.relu(mins[L][p]-g_p)/torch.abs(mins[L][p]+10**-6)).sum(dim=1,keepdim=True) 333 | dev += 
(F.relu(g_p-maxs[L][p])/torch.abs(maxs[L][p]+10**-6)).sum(dim=1,keepdim=True) 334 | cls_batch_deviations.append(dev.cpu().detach().numpy()) 335 | cls_batch_deviations = np.concatenate(cls_batch_deviations,axis=1) 336 | if batch_deviations is None: 337 | batch_deviations = np.zeros([ypred.shape[0],cls_batch_deviations.shape[1]]) 338 | batch_deviations[cond] = cls_batch_deviations 339 | return batch_deviations 340 | 341 | powers = [8] 342 | 343 | def cal_feature_prototype(self,feat,ypred): 344 | feat = torch.abs(feat) 345 | for pr in range(self.num_classes): 346 | cond = ypred==pr 347 | if not cond.any(): 348 | continue 349 | csfeat = feat[cond] 350 | cf = csfeat.mean(dim = [0,2,3])#.cpu().numpy() 351 | # print(cf.shape) 352 | ct = cond.sum() 353 | ft = self.avg_feature[pr] 354 | self.avg_feature[pr] = [ft[0] + ct, (ft[1] * ft[0] + cf * ct)/(ft[0] + ct)] 355 | if self.enable_gram: 356 | for p in range(len(self.powers)): 357 | gram = G_p_pro(csfeat,self.powers[p]).mean(dim = 0) 358 | gm = self.avg_gram[p][pr] 359 | self.avg_gram[p][pr] = [gm[0] + ct, (gm[1] * gm[0] + gram * ct)/(gm[0] + ct)] 360 | 361 | 362 | def obtain_usable_feature_prototype(self): 363 | if isinstance(self.avg_feature,list): 364 | clsft_lost = [] 365 | exm = None 366 | for x in self.avg_feature: 367 | if x[0] > 0: 368 | clsft_lost.append(x[1]) 369 | exm = x[1] 370 | else: 371 | clsft_lost.append(None) 372 | clsft = torch.stack([torch.zeros_like(exm) if x is None else x for x in clsft_lost]) 373 | # print(clsft.shape) 374 | clsft /= clsft.sum(dim = 0) #**2 375 | # clsft /= clsft.sum(dim = 1,keepdim = True) 376 | # print(clsft) 377 | self.avg_feature = clsft.reshape([clsft.shape[0],1,clsft.shape[1],1,1]) 378 | if self.enable_gram: 379 | for i in range(len(self.powers)): 380 | self.avg_gram[i] = torch.stack([x[1] if x[0] > 0 else torch.zeros([exm.shape[0],exm.shape[0]]).cuda() for x in self.avg_gram[i]]) 381 | # self.avg_gram /= self.avg_gram.sum(dim = 0) 382 | # print(self.avg_gram.shape) 383 | 
return self.avg_feature,self.avg_gram 384 | 385 | def get_feature_prototype_deviation(self,feat,ypred): 386 | # feat = torch.abs(feat) 387 | avg_feature,_ = self.obtain_usable_feature_prototype() 388 | scores = np.zeros([feat.shape[0],feat.shape[2],feat.shape[3]]) 389 | for pr in range(self.num_classes): 390 | cond = ypred==pr 391 | if not cond.any(): 392 | continue 393 | scores[cond] = (avg_feature[pr] * feat[cond]).mean(axis = 1).cpu().numpy() 394 | return scores 395 | 396 | def get_feature_gram_deviation(self,feat,ypred): 397 | _,avg_gram = self.obtain_usable_feature_prototype() 398 | scores = np.zeros([feat.shape[0],feat.shape[2],feat.shape[3]]) 399 | for pr in range(self.num_classes): 400 | cond = ypred==pr 401 | if not cond.any(): 402 | continue 403 | res = 0 404 | for i in range(len(self.powers)): 405 | gm = G_p_pro(feat[cond],p=self.powers[i]) 406 | # scores[cond] = (gm / gm.mean(dim = [3,4],keepdim = True) * avg_gram[pr]).sum(dim = [3,4]).cpu().numpy() 407 | res += (gm * avg_gram[i][pr]).sum(dim = [1,2],keepdim = True).cpu().numpy() 408 | scores[cond] = res 409 | return scores 410 | 411 | def pred_by_feature_gram(self,feat): 412 | _,avg_gram = self.obtain_usable_feature_prototype() 413 | scores = np.zeros([self.num_classes, feat.shape[0]]) 414 | gm = G_p_pro(feat) 415 | for pr in range(self.num_classes): 416 | # scores[cond] = (gm / gm.mean(dim = [3,4],keepdim = True) * avg_gram[pr]).sum(dim = [3,4]).cpu().numpy() 417 | scores[pr] = (gm * avg_gram[pr]).sum(dim = [1,2]).cpu().numpy() 418 | return scores.argmax(axis = 0) 419 | 420 | def forward(self,x,ycls = None,reqpredauc = False,prepareTest = False,reqfeature = False): 421 | 422 | # ----- New Arch 423 | x = self.backbone_cs(x,feature_only = reqfeature) 424 | if reqfeature: 425 | return x 426 | x,xcls_raw = x 427 | 428 | def pred_score(xcls): 429 | score_reduce = lambda x : x.reshape([x.shape[0],-1]).mean(axis = 1) 430 | x_detach = x.detach() 431 | probs = self.crt(xcls,prob = True).cpu().numpy() 432 | 
pred = probs.argmax(axis = 1) 433 | max_prob = probs.max(axis = 1) 434 | 435 | cls_scores = xcls.cpu().numpy()[[i for i in range(pred.shape[0])],pred] 436 | rep_scores = torch.abs(x_detach).mean(dim = 1).cpu().numpy() 437 | if not self.training and not prepareTest and (not isinstance(self.avg_feature,list) or self.avg_feature[0][0] != 0): 438 | rep_cspt = self.get_feature_prototype_deviation(x_detach,pred) 439 | if self.enable_gram: 440 | rep_gram = self.get_feature_gram_deviation(x_detach,pred) 441 | else: 442 | rep_gram = np.zeros_like(cls_scores) 443 | else: 444 | rep_cspt = np.zeros_like(cls_scores) 445 | rep_gram = np.zeros_like(cls_scores) 446 | R = [cls_scores,rep_scores,rep_cspt,rep_gram,max_prob] 447 | 448 | scores = np.stack([score_reduce(eval(self.config['score'])),score_reduce(rep_cspt),score_reduce(rep_gram)],axis = 1) 449 | return pred,scores 450 | 451 | if self.training: 452 | xcls = self.crt(xcls_raw,ycls) 453 | if reqpredauc : 454 | pred,score = pred_score(xcls_raw.detach()) 455 | return xcls,pred,score 456 | else: 457 | xcls = xcls_raw 458 | # xrot = self.rot_cls(x) 459 | if reqpredauc: 460 | pred,score = pred_score(xcls) 461 | deviations = None 462 | # powers = range(1,10) 463 | if prepareTest: 464 | if not isinstance(self.avg_feature,list): 465 | self.avg_feature = [[0,0] for i in range(self.num_classes)] 466 | self.avg_gram = [[[0,0] for i in range(self.num_classes)] for i in self.powers] 467 | # hdfts = self.backbone.backbone.obtain_gram_feats() 468 | # self.update_minmax(hdfts + [x] + clslatents,powers,pred) 469 | self.cal_feature_prototype(x,pred) 470 | # else: 471 | # deviations = self.get_deviations(self.backbone.backbone.obtain_gram_feats() + [x]+ clslatents,powers,pred) 472 | return pred,score,deviations 473 | 474 | return xcls 475 | 476 | 477 | 478 | class CSSRCriterion(nn.Module): 479 | 480 | def get_onehot_label(self,y,clsnum): 481 | y = torch.reshape(y,[-1,1]).long() 482 | return torch.zeros(y.shape[0], clsnum).cuda().scatter_(1, y, 
1) 483 | 484 | def __init__(self,avg_order,enable_sigma = True): 485 | super().__init__() 486 | self.avg_order = {"avg_softmax":1,"softmax_avg":2}[avg_order] 487 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 488 | self.enable_sigma = enable_sigma 489 | 490 | def forward(self,x,y = None,prob = False,pred = False): 491 | if self.avg_order == 1: 492 | g = self.avg_pool(x).view(x.shape[0],-1) 493 | g = torch.softmax(g,dim=1) 494 | elif self.avg_order == 2: 495 | g = torch.softmax(x,dim=1) 496 | g = self.avg_pool(g).view(x.size(0), -1) 497 | if prob: return g 498 | if pred: return torch.argmax(g,dim = 1) 499 | loss = -torch.sum(self.get_onehot_label(y,g.shape[1]) * torch.log(g),dim=1).mean() 500 | return loss 501 | 502 | 503 | def manual_contrast(x): 504 | s = random.uniform(0.1,2) 505 | return x * s 506 | 507 | 508 | class WrapDataset(data.Dataset): 509 | 510 | def __init__(self,labeled_ds,config,inchan_num = 3) -> None: 511 | super().__init__() 512 | self.labeled_ds = labeled_ds 513 | 514 | __mean = [0.5,0.5,0.5][:inchan_num] 515 | __std = [0.25,0.25,0.25][:inchan_num] 516 | 517 | trans = [transforms.RandomHorizontalFlip()] 518 | if config['cust_aug_crop_withresize']: 519 | trans.append(transforms.RandomResizedCrop(size = util.img_size,scale = (0.25,1))) 520 | elif util.img_size > 200: 521 | trans += [transforms.Resize(256),transforms.RandomResizedCrop(util.img_size)] 522 | else: 523 | trans.append(transforms.RandomCrop(size=util.img_size, 524 | padding=int(util.img_size*0.125), 525 | padding_mode='reflect')) 526 | if config['strong_option'] == 'RA': 527 | trans.append(RandAugmentMC(n=2, m=10)) 528 | elif config['strong_option'] == 'CUST': 529 | trans.append(HighlyCustomizableAugment(2,10,-1,labeled_ds,config)) 530 | elif config['strong_option'] == 'NONE': 531 | pass 532 | else: 533 | raise NotImplementedError() 534 | trans += [transforms.ToTensor(), 535 | transforms.Normalize(mean=__mean, std=__std)] 536 | 537 | if config['manual_contrast']: 538 | 
trans.append(manual_contrast) 539 | strong = transforms.Compose(trans) 540 | 541 | if util.img_size > 200: 542 | self.simple = [transforms.RandomResizedCrop(util.img_size)] 543 | else: 544 | self.simple = [transforms.RandomCrop(size=util.img_size, 545 | padding=int(util.img_size*0.125), 546 | padding_mode='reflect')] 547 | self.simple = transforms.Compose(([transforms.RandomHorizontalFlip()]) + self.simple + [ 548 | transforms.ToTensor(), 549 | transforms.Normalize(mean=__mean, std=__std)] + ([manual_contrast] if config['manual_contrast'] else [])) 550 | 551 | self.test_normalize = transforms.Compose([ 552 | transforms.CenterCrop(util.img_size), 553 | transforms.ToTensor(), 554 | transforms.Normalize(mean=__mean, std=__std)]) 555 | 556 | td = {'strong' : strong, 'simple' : self.simple} 557 | self.aug = td[config['cat_augmentation']] 558 | self.test_mode = False 559 | 560 | def __len__(self) -> int: 561 | return len(self.labeled_ds) 562 | 563 | def __getitem__(self, index: int) : 564 | img,lb,_ = self.labeled_ds[index] 565 | if self.test_mode: 566 | img = self.test_normalize(img) 567 | else: 568 | img = self.aug(img) 569 | return img,lb,index 570 | 571 | 572 | @util.regmethod('cssr') 573 | class CSSRMethod: 574 | 575 | def get_cfg(self,key,default): 576 | return self.config[key] if key in self.config else default 577 | 578 | def __init__(self, config, clssnum, train_set) -> None: 579 | self.config = config 580 | self.epoch = 0 581 | self.lr = config['learn_rate'] 582 | self.batch_size = config['batch_size'] 583 | 584 | self.clsnum = clssnum 585 | self.crt = CSSRCriterion(config['arch_type'],False) 586 | self.model = CSSRModel(self.clsnum,config,self.crt).cuda() 587 | self.modelopt = torch.optim.SGD(self.model.parameters(), lr=self.lr,weight_decay=5e-4) 588 | 589 | self.wrap_ds = WrapDataset(train_set,self.config,inchan_num=3,) 590 | self.wrap_loader = data.DataLoader(self.wrap_ds, 591 | batch_size=self.config['batch_size'], shuffle=True,pin_memory=True, 
num_workers=6) 592 | self.lr_schedule = util.get_scheduler(self.config,self.wrap_loader) 593 | 594 | self.prepared = -999 595 | 596 | def train_epoch(self): 597 | data_time = AverageMeter() 598 | batch_time = AverageMeter() 599 | train_acc = AverageMeter() 600 | 601 | running_loss = AverageMeter() 602 | 603 | self.model.train() 604 | 605 | endtime = time.time() 606 | for i, data in enumerate(tqdm.tqdm(self.wrap_loader)): 607 | data_time.update(time.time() - endtime) 608 | 609 | self.lr = self.lr_schedule.get_lr(self.epoch,i,self.lr) 610 | util.set_lr([self.modelopt],self.lr) 611 | sx, lb = data[0].cuda(),data[1].cuda() 612 | 613 | loss,pred,scores = self.model(sx,lb,reqpredauc = True) 614 | self.modelopt.zero_grad() 615 | loss.backward() 616 | self.modelopt.step() 617 | 618 | nplb = data[1].numpy() 619 | train_acc.update((pred == nplb).sum() / pred.shape[0],pred.shape[0]) 620 | running_loss.update(loss.item()) 621 | batch_time.update(time.time() - endtime) 622 | endtime = time.time() 623 | self.epoch += 1 624 | training_res = \ 625 | {"Loss" : running_loss.avg, 626 | "TrainAcc" : train_acc.avg, 627 | "Learn Rate" : self.lr, 628 | "DataTime" : data_time.avg, 629 | "BatchTime" : batch_time.avg} 630 | 631 | return training_res 632 | 633 | 634 | def known_prediction_test(self,test_loader): 635 | self.model.eval() 636 | pred,scores,_,_ = self.scoring(test_loader) 637 | return pred 638 | 639 | def scoring(self,loader,prepare = False): 640 | gts = [] 641 | deviations = [] 642 | 643 | scores = [] 644 | prediction = [] 645 | with torch.no_grad(): 646 | for d in tqdm.tqdm(loader): 647 | x1 = d[0].cuda(non_blocking = True) 648 | gt = d[1].numpy() 649 | pred,scr,dev = self.model(x1,reqpredauc = True,prepareTest = prepare) 650 | prediction.append(pred) 651 | scores.append(scr) 652 | gts.append(gt) 653 | 654 | prediction = np.concatenate(prediction) 655 | scores = np.concatenate(scores) 656 | gts = np.concatenate(gts) 657 | 658 | return prediction,scores,deviations,gts 659 | 660 
| def knownpred_unknwonscore_test(self,test_loader): 661 | self.model.eval() 662 | if self.prepared != self.epoch: 663 | self.wrap_ds.test_mode = True 664 | tpred,tscores,_,_ = self.scoring(self.wrap_loader,True) 665 | self.wrap_ds.test_mode = False 666 | self.prepared = self.epoch 667 | pred,scores,devs,gts = self.scoring(test_loader) 668 | 669 | if self.config['integrate_score'] != "S[0]": 670 | tpred,tscores,_,_ = self.scoring(self.wrap_loader,False) 671 | mean,std = tscores.mean(axis = 0),tscores.std(axis = 0) 672 | scores = (scores - mean)/(std + 1e-8) 673 | S = scores.T 674 | return eval(self.config['integrate_score']),-9999999,pred 675 | 676 | def save_model(self,path): 677 | save_dict = { 678 | 'model' : self.model.state_dict(), 679 | 'config': self.config, 680 | 'optimzer' : self.modelopt.state_dict(), 681 | 'epoch' : self.epoch, 682 | } 683 | torch.save(save_dict,path) 684 | 685 | def load_model(self,path): 686 | save_dict = torch.load(path) 687 | self.model.load_state_dict(save_dict['model']) 688 | if 'optimzer' in save_dict and self.modelopt is not None: 689 | self.modelopt.load_state_dict(save_dict['optimzer']) 690 | self.epoch = save_dict['epoch'] 691 | -------------------------------------------------------------------------------- /methods/cssr_ft.py: -------------------------------------------------------------------------------- 1 | 2 | from methods.augtools import HighlyCustomizableAugment, RandAugmentMC 3 | import methods.util as util 4 | import tqdm 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.utils.data as data 9 | import random 10 | from methods.util import AverageMeter 11 | import time 12 | from torchvision.transforms import transforms 13 | from methods.cssr import Backbone,AutoEncoder 14 | 15 | class LinearClassifier(nn.Module): 16 | 17 | def __init__(self,inchannels,num_class, config): 18 | super().__init__() 19 | if config['projection_dim'] > 0: 20 | self.proj = nn.Sequential( 21 | 
nn.Conv2d(inchannels, config['projection_dim'], 1,padding= 0, bias=False), 22 | nn.BatchNorm2d(config['projection_dim']), 23 | nn.LeakyReLU(inplace=True)) 24 | inchannels = config['projection_dim'] 25 | else: 26 | self.proj = nn.Identity() 27 | self.cls = nn.Conv2d(inchannels, num_class , 1,padding= 0, bias=False) 28 | 29 | 30 | def forward(self,x): 31 | x = self.proj(x) 32 | x1 = self.cls(x) 33 | return x,x1 34 | 35 | 36 | class CSSRClassifier(nn.Module): 37 | 38 | def __init__(self,inchannels,num_class, config): 39 | super().__init__() 40 | if config['projection_dim'] > 0: 41 | self.proj = nn.Sequential( 42 | nn.Conv2d(inchannels, config['projection_dim'], 1,padding= 0, bias=False), 43 | nn.BatchNorm2d(config['projection_dim']), 44 | nn.LeakyReLU(inplace=True)) 45 | inchannels = config['projection_dim'] 46 | else: 47 | self.proj = nn.Identity() 48 | 49 | ae_hidden = config['ae_hidden'] 50 | ae_latent = config['ae_latent'] 51 | self.class_aes = [] 52 | for i in range(num_class): 53 | ae = AutoEncoder(inchannels,ae_hidden,ae_latent) 54 | self.class_aes.append(ae) 55 | self.class_aes = nn.ModuleList(self.class_aes) 56 | self.useL1 = config['error_measure'] == 'L1' 57 | 58 | self.reduction = -1 if config['model'] == 'pcssr' else 1 59 | self.reduction *= config['gamma'] 60 | 61 | 62 | def ae_error(self,rc,x): 63 | if self.useL1: 64 | return torch.norm(rc - x,p = 1,dim = 1,keepdim=True) * self.reduction 65 | else: 66 | return torch.norm(rc - x,p = 2,dim = 1,keepdim=True) ** 2 * self.reduction 67 | 68 | clip_len = 100 69 | 70 | def forward(self,x): 71 | x = self.proj(x) 72 | cls_ers = [] 73 | for ae in self.class_aes: 74 | rc,_ = ae(x) 75 | cls_er = self.ae_error(rc,x) 76 | if CSSRClassifier.clip_len > 0: 77 | cls_er = torch.clamp(cls_er,-CSSRClassifier.clip_len,CSSRClassifier.clip_len) 78 | cls_ers.append(cls_er) 79 | logits = torch.cat(cls_ers,dim=1) 80 | return x,logits 81 | 82 | 83 | 84 | class BaselineModel(nn.Module): 85 | 86 | def 
__init__(self,num_classes,config,crt): 87 | super().__init__() 88 | self.backbone = Backbone(config,3) 89 | self.crt = crt 90 | 91 | clsblock = {'linear':LinearClassifier, 'pcssr':CSSRClassifier,'rcssr' : CSSRClassifier} 92 | mod_config = config['category_model'] 93 | self.cls = clsblock[mod_config['model']](self.backbone.output_dim,num_classes,mod_config).cuda() 94 | self.config = config 95 | 96 | def forward(self,x,ycls = None,fixbackbone=False): 97 | if fixbackbone: 98 | with torch.no_grad(): 99 | x = self.backbone(x) 100 | else: 101 | x = self.backbone(x) 102 | 103 | def pred_score(xcls): 104 | score_reduce = lambda x : x.reshape([x.shape[0],-1]).mean(axis = 1) 105 | 106 | probs = self.crt(xcls,prob = True).cpu().numpy() 107 | pred = probs.argmax(axis = 1) 108 | max_prob = probs.max(axis = 1) 109 | 110 | cls_scores = xcls.cpu().numpy()[[i for i in range(pred.shape[0])],pred] 111 | rep_scores = torch.abs(x.detach()).mean(dim = 1).cpu().numpy() 112 | R = [cls_scores,rep_scores,0,0,max_prob] 113 | 114 | scores = score_reduce(eval(self.config['score'])) 115 | return pred,scores 116 | 117 | if self.training: 118 | x,logitcls = self.cls(x) 119 | return logitcls 120 | else: 121 | x,xcls = self.cls(x) 122 | pred,scores = pred_score(xcls) 123 | return pred,scores 124 | 125 | 126 | class CSSRCriterion(nn.Module): 127 | 128 | def get_onehot_label(self,y,clsnum): 129 | y = torch.reshape(y,[-1,1]) 130 | return torch.zeros(y.shape[0], clsnum).cuda().scatter_(1, y, 1) 131 | 132 | def __init__(self,avg_order,enable_sigma = True): 133 | super().__init__() 134 | self.avg_order = {"avg_softmax":1,"softmax_avg":2}[avg_order] 135 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 136 | self.enable_sigma = enable_sigma 137 | 138 | def forward(self,x,y = None, prob = False,pred = False,dontreduce = False): 139 | if self.avg_order == 1: 140 | g = self.avg_pool(x).view(x.shape[0],-1) 141 | g = torch.softmax(g,dim=1) 142 | elif self.avg_order == 2: 143 | g = torch.softmax(x,dim=1) 144 | g = 
self.avg_pool(g).view(x.size(0), -1) 145 | if prob: return g 146 | if pred: return torch.argmax(g,dim = 1) 147 | if not dontreduce: 148 | loss = -torch.sum(self.get_onehot_label(y,g.shape[1]) * torch.log(g),dim=1).mean() 149 | else: 150 | loss = -torch.sum(self.get_onehot_label(y,g.shape[1]) * torch.log(g),dim=1) 151 | 152 | return loss 153 | 154 | 155 | def manual_contrast(x): 156 | s = random.uniform(0.1,2) 157 | return x * s 158 | 159 | 160 | class WrapDataset(data.Dataset): 161 | 162 | def __init__(self,labeled_ds,config,inchan_num = 3) -> None: 163 | super().__init__() 164 | self.labeled_ds = labeled_ds 165 | 166 | __mean = [0.5,0.5,0.5][:inchan_num] 167 | __std = [0.25,0.25,0.25][:inchan_num] 168 | 169 | trans = [transforms.RandomHorizontalFlip()] 170 | if config['cust_aug_crop_withresize']: 171 | trans.append(transforms.RandomResizedCrop(size = util.img_size,scale = (0.25,1))) 172 | elif util.img_size > 200: 173 | trans += [transforms.Resize(256),transforms.RandomResizedCrop(util.img_size)] 174 | else: 175 | trans.append(transforms.RandomCrop(size=util.img_size, 176 | padding=int(util.img_size*0.125), 177 | padding_mode='reflect')) 178 | if config['strong_option'] == 'RA': 179 | trans.append(RandAugmentMC(n=2, m=10)) 180 | elif config['strong_option'] == 'CUST': 181 | trans.append(HighlyCustomizableAugment(2,10,-1,labeled_ds,config)) 182 | elif config['strong_option'] == 'NONE': 183 | pass 184 | else: 185 | raise NotImplementedError() 186 | trans += [transforms.ToTensor(), 187 | transforms.Normalize(mean=__mean, std=__std)] 188 | 189 | if config['manual_contrast']: 190 | trans.append(manual_contrast) 191 | strong = transforms.Compose(trans) 192 | 193 | self.simple = transforms.Compose(([transforms.RandomHorizontalFlip()]) + [ 194 | transforms.Resize(256), 195 | transforms.RandomResizedCrop(size = util.img_size,scale = (0.25,1)), 196 | transforms.ToTensor(), 197 | transforms.Normalize(mean=__mean, std=__std)]) 198 | # self.testaug = gen_testaug_transform() 
199 | self.normalize = transforms.Compose([transforms.ToTensor(), 200 | transforms.Normalize(mean=__mean, std=__std)]) 201 | 202 | td = {'strong' : strong, 'simple' : self.simple} 203 | self.aug = td[config['cat_augmentation']] 204 | self.test_mode = False 205 | 206 | def __len__(self) -> int: 207 | return len(self.labeled_ds) 208 | 209 | def __getitem__(self, index: int) : 210 | img,lb,_ = self.labeled_ds[index] 211 | if self.test_mode: 212 | img = self.normalize(img) 213 | else: 214 | img = self.aug(img) 215 | return img,lb,index 216 | 217 | @util.regmethod('cssr_ft') 218 | class CSSRFTMethod: 219 | 220 | def get_cfg(self,key,default): 221 | return self.config[key] if key in self.config else default 222 | 223 | def __init__(self, config, clssnum, train_set) -> None: 224 | self.config = config 225 | self.epoch = 0 226 | self.clsnum = clssnum 227 | self.crt = CSSRCriterion(config['arch_type'],False) 228 | self.model = BaselineModel(self.clsnum,config,self.crt).cuda() 229 | # ---- Training Related 230 | self.batch_size = config['batch_size'] 231 | self.lr = config['learn_rate'] * (self.batch_size / 128) 232 | 233 | self.modelopt = torch.optim.SGD([ 234 | { 'params': self.model.cls.parameters(), 'lr' : self.lr, 'weight_decay':1e-4 }, 235 | ], lr=self.lr,weight_decay=5e-4) 236 | 237 | # ---- schedules 238 | self.lrdecay = self.config['lr_decay'] 239 | self.wrap_ds = WrapDataset(train_set,self.config,3) 240 | self.wrap_loader = data.DataLoader(self.wrap_ds, 241 | batch_size=self.config['batch_size'], shuffle=True,pin_memory=True, num_workers=6) 242 | self.lr_schedule = util.get_scheduler(self.config,self.wrap_loader) 243 | 244 | def train_epoch(self): 245 | data_time = AverageMeter() 246 | batch_time = AverageMeter() 247 | train_acc = AverageMeter() 248 | running_loss = AverageMeter() 249 | 250 | self.model.train() 251 | self.model.backbone.eval() 252 | progress_bar = tqdm.tqdm(self.wrap_loader) 253 | endtime = time.time() 254 | for i, data in enumerate(progress_bar): 
255 | data_time.update(time.time() - endtime) 256 | progress_bar.set_description('epoch ' + str(self.epoch)) 257 | self.lr = self.lr_schedule.get_lr(self.epoch,i,self.lr) 258 | util.set_lr([self.modelopt],self.lr) 259 | sx, lb = data[0].cuda(),data[1].cuda() 260 | 261 | cls_logits = self.model(sx,ycls = lb,fixbackbone=True) 262 | loss = self.crt(cls_logits,lb) 263 | pred = self.crt(cls_logits,pred = True).cpu().numpy() 264 | 265 | self.modelopt.zero_grad() 266 | loss.backward() 267 | self.modelopt.step() 268 | 269 | nplb = data[1].numpy() 270 | train_acc.update((pred == nplb).sum() / pred.shape[0],pred.shape[0]) 271 | running_loss.update(loss.item()) 272 | batch_time.update(time.time() - endtime) 273 | endtime = time.time() 274 | 275 | progress_bar.set_postfix( 276 | acc='%.4f' % train_acc.avg, 277 | loss='%.4f' % running_loss.avg, 278 | datatime = '%.4f' % data_time.avg, 279 | batchtime = '%.4f' % batch_time.avg, 280 | learnrate = '%.4f' % self.lr, 281 | ) 282 | if i % 200 == 0: 283 | print("Itr",i,'TrainAcc:%.4f' % train_acc.avg,'loss:%.4f' % (running_loss.avg),'learnrate:%.4f' % self.lr) 284 | 285 | training_res = \ 286 | {"Loss" : running_loss.avg, 287 | "TrainAcc" : train_acc.avg, 288 | "Learn Rate" : self.lr, 289 | "DataTime" : data_time.avg, 290 | "BatchTime" : batch_time.avg} 291 | 292 | return training_res 293 | 294 | def known_prediction_test(self,test_loader): 295 | self.model.eval() 296 | pred,scores = self.scoring(test_loader) 297 | return pred 298 | 299 | def scoring(self,loader): 300 | scores = [] 301 | prediction = [] 302 | with torch.no_grad(): 303 | for d in tqdm.tqdm(loader): 304 | x1 = d[0].cuda(non_blocking = True) 305 | pred,scr = self.model(x1) 306 | prediction.append(pred) 307 | scores.append(scr) 308 | 309 | prediction = np.concatenate(prediction) 310 | scores = np.concatenate(scores) 311 | return prediction,scores 312 | 313 | 314 | def knownpred_unknwonscore_test(self, test_loader): 315 | self.model.eval() 316 | pred,scores = 
self.scoring(test_loader) 317 | return scores,-9999999,pred 318 | 319 | def save_model(self,path): 320 | save_dict = { 321 | 'model' : self.model.state_dict(), 322 | 'config': self.config, 323 | 'optimzer' : self.modelopt.state_dict(), 324 | 'epoch' : self.epoch 325 | } 326 | torch.save(save_dict,path) 327 | 328 | def load_model(self,path): 329 | save_dict = torch.load(path) 330 | print("The loading model has config") 331 | print(save_dict['config']) 332 | self.model.load_state_dict(save_dict['model']) 333 | if 'optimzer' in save_dict: 334 | self.modelopt.load_state_dict(save_dict['optimzer']) 335 | self.epoch = save_dict['epoch'] 336 | -------------------------------------------------------------------------------- /methods/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def conv3x3(in_planes, out_planes, stride=1): 4 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 5 | 6 | def conv1x1(in_planes, out_planes, stride=1): 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 8 | 9 | class BasicBlock(nn.Module): 10 | 11 | """Basic Block for resnet 18 and resnet 34 12 | 13 | """ 14 | #BasicBlock and BottleNeck block 15 | #have different output size 16 | #we use class attribute expansion 17 | #to distinct 18 | expansion = 1 19 | def __init__(self,mastermodel, in_channels, out_channels, stride=1): 20 | super().__init__() 21 | #residual function 22 | self.residual_function = nn.Sequential( 23 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 24 | nn.BatchNorm2d(out_channels), 25 | nn.LeakyReLU(inplace=True), # org LeakyRelu 26 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), 27 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 28 | ) 29 | #shortcut 30 | self.shortcut = nn.Sequential() 31 | 32 | #the shortcut 
output dimension is not the same with residual function 33 | #use 1*1 convolution to match the dimension 34 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels: 35 | self.shortcut = nn.Sequential( 36 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), 37 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 38 | ) 39 | self.mastermodel = mastermodel 40 | 41 | def forward(self, x): 42 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 43 | 44 | 45 | class ResNet(nn.Module): 46 | def __init__(self, block = BasicBlock, num_block = [2,2,2,2],avg_output = False,output_dim = 256,preprocessstride = 1,resfirststride = 1,inchan = 3): 47 | super().__init__() 48 | img_chan = inchan 49 | self.conv1 = nn.Sequential( 50 | nn.Conv2d(img_chan, 64, kernel_size=3, padding=1, bias=False,stride=preprocessstride), 51 | nn.BatchNorm2d(64), 52 | nn.LeakyReLU()) 53 | self.in_channels = 64 54 | self.conv2_x = self._make_layer(block, 64, num_block[0], resfirststride) 55 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2) 56 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2) 57 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2) 58 | self.conv6_x = nn.Identity() if output_dim <= 0 else self.conv_layer(512,output_dim,1,0) 59 | self.conv6_is_identity = output_dim <= 0 60 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 61 | if output_dim > -1: 62 | self.output_dim = output_dim 63 | else: 64 | self.output_dim = 512 * block.expansion 65 | 66 | self.avg_output = avg_output 67 | 68 | 69 | def conv_layer(self,input_channel,output_channel,kernel_size=3,padding =1): 70 | print("conv layer input",input_channel,"output",output_channel) 71 | res = nn.Sequential( 72 | nn.Conv2d(input_channel, output_channel, kernel_size,1, padding, bias=False), 73 | nn.BatchNorm2d(output_channel), 74 | nn.LeakyReLU(0.2)) 75 | return res 76 | 77 | def _make_layer(self, block, out_channels, 
num_blocks, stride): 78 | # we have num_block blocks per layer, the first block 79 | # could be 1 or 2, other blocks would always be 1 80 | print("Making resnet layer with channel",out_channels,"block",num_blocks,"stride",stride) 81 | 82 | strides = [stride] + [1] * (num_blocks - 1) 83 | layers = [] 84 | for stride in strides: 85 | layers.append(block(None,self.in_channels, out_channels, stride)) 86 | self.in_channels = out_channels * block.expansion 87 | return nn.Sequential(*layers) 88 | 89 | def forward(self, x): 90 | output = self.conv1(x) 91 | output = self.conv2_x(output) 92 | output = self.conv3_x(output) 93 | output = self.conv4_x(output) 94 | output = self.conv5_x(output) 95 | output = self.conv6_x(output) 96 | if self.avg_output: 97 | output = self.avg_pool(output) 98 | output = output.view(output.size(0), -1) 99 | return output 100 | -------------------------------------------------------------------------------- /methods/util.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import numpy as np 5 | 6 | img_size = 32 7 | 8 | def setup_dir(dir_path): 9 | if not os.path.exists(dir_path): 10 | os.mkdir(dir_path) 11 | 12 | 13 | def set_lr(opts,lr): 14 | for op in opts : 15 | for param_group in op.param_groups: 16 | param_group['lr'] = lr 17 | 18 | 19 | class AverageMeter(object): 20 | """Computes and stores the average and current value 21 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 22 | """ 23 | def __init__(self): 24 | self.reset() 25 | 26 | def reset(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | 39 | class WarmUpLrSchedule: 40 | 41 | def __init__(self, warm_epoch, epoch_tot_steps, init_lr): 42 | self.ep_steps = epoch_tot_steps 43 | self.tgtstep = warm_epoch * 
epoch_tot_steps 44 | self.init_lr = init_lr 45 | self.warm_epoch = warm_epoch 46 | 47 | def get_lr(self,epoch,step,lr): 48 | tstep = epoch * self.ep_steps + step 49 | if self.tgtstep > 0 and tstep <= self.tgtstep: 50 | lr = self.init_lr * tstep / self.tgtstep 51 | return lr 52 | 53 | class MultiStepLrSchedule: 54 | 55 | def __init__(self,milestones,lrdecays,start_lr,warmup_schedule = None): 56 | super().__init__() 57 | self.milestones = milestones 58 | self.warmup = warmup_schedule 59 | self.lrdecays = lrdecays 60 | self.start_lr = start_lr 61 | 62 | # step 表示epoch中已经输入过的样本数 63 | def get_lr(self,epoch,step,lr): 64 | lr = self.start_lr 65 | # if step == 0 : # update learning rate 66 | for m in self.milestones: 67 | if epoch >= m: 68 | lr *= self.lrdecays 69 | # print("LEARNRATE",lr) 70 | if self.warmup is not None: 71 | lr = self.warmup.get_lr(epoch,step,lr) 72 | # print("LEARNRATE",lr) 73 | return lr 74 | 75 | # cosine_s,cosine_e = 0,0 76 | 77 | 78 | # epoch wise 79 | class EpochwiseCosineAnnealingLrSchedule: 80 | 81 | def __init__(self,startlr,milestones,lrdecay,epoch_num,warmup = None): 82 | super().__init__() 83 | self.cosine_s,self.cosine_e = 0,0 84 | self.milestones = milestones 85 | self.lrdecay = lrdecay 86 | self.warmup = warmup 87 | self.warmup_epoch = 0 if warmup is None else warmup.warm_epoch 88 | self.epoch_num = epoch_num 89 | self.startlr = startlr 90 | self.ms = [self.warmup_epoch] + self.milestones + [self.epoch_num] 91 | self.ref = {self.ms[i] : self.ms[i+1] for i in range(len(self.ms)-1)} 92 | 93 | def get_lr(self,epoch,step,lr): 94 | #global cosine_s,cosine_e 95 | if self.warmup is not None: 96 | lr = self.warmup.get_lr(epoch,step,lr) 97 | if step != 0 : 98 | return lr 99 | if epoch in self.ms: 100 | if epoch != self.warmup_epoch: 101 | self.startlr *= self.lrdecay 102 | self.cosine_s = epoch 103 | self.cosine_e = self.ref[epoch] 104 | #print("calc lr",epoch,self.ms,self.cosine_s,self.cosine_e) 105 | if self.cosine_e > 0: 106 | lr = self.startlr 
* (np.cos((epoch - self.cosine_s) / (self.cosine_e - self.cosine_s) * 3.14159)+1) * 0.5 107 | 108 | return lr 109 | 110 | 111 | # Step wise 112 | class StepwiseCosineAnnealingLrSchedule: 113 | 114 | def __init__(self,startlr,epoch_tot_steps,milestones,lrdecay,epoch_num,warmup = None): 115 | super().__init__() 116 | self.cosine_s,self.cosine_e = 0,0 117 | self.milestones = milestones 118 | self.lrdecay = lrdecay 119 | self.warmup = warmup 120 | self.warmup_epoch = 0 if warmup is None else warmup.warm_epoch 121 | self.epoch_num = epoch_num 122 | self.startlr = startlr 123 | self.ms = [self.warmup_epoch] + self.milestones + [self.epoch_num] 124 | self.ref = {self.ms[i] : self.ms[i+1] for i in range(len(self.ms)-1)} 125 | self.ep_steps = epoch_tot_steps 126 | 127 | # step wise 128 | def get_lr(self,epoch,step,lr): 129 | if self.warmup is not None: 130 | lr = self.warmup(epoch,step,lr) 131 | if step == 0 and epoch in self.ms: 132 | if epoch != self.warmup_epoch: 133 | self.startlr *= self.lrdecay 134 | self.cosine_s = epoch 135 | self.cosine_e = self.ref[epoch] 136 | if self.cosine_e > 0: 137 | steps = step + (epoch - self.cosine_s) * self.epoch_tot_steps 138 | lr = self.startlr * (np.cos( steps / (self.cosine_e - self.cosine_s) / self.epoch_tot_steps * 3.14159)+1) * 0.5 139 | return lr 140 | 141 | def get_scheduler(config,train_loader): 142 | if config['lr_schedule'] == 'multi_step': 143 | warmup = WarmUpLrSchedule(config['warmup_epoch'], len(train_loader) ,config['learn_rate']) 144 | return MultiStepLrSchedule(config["milestones"],config['lr_decay'],config['learn_rate'],warmup) 145 | elif config['lr_schedule'] == 'cosine': 146 | warmup = WarmUpLrSchedule(config['warmup_epoch'], len(train_loader) ,config['learn_rate']) 147 | return EpochwiseCosineAnnealingLrSchedule(config['learn_rate'],config["milestones"],config['lr_decay'],config['epoch_num'],warmup) 148 | 149 | 150 | method_list = {} 151 | class regmethod: 152 | 153 | def __init__(self,name) -> None: 154 | 
self.name = name 155 | 156 | def __call__(self,func, *args, **kwds): 157 | global method_list 158 | method_list[self.name] = func 159 | print("Registering",self.name) 160 | return func -------------------------------------------------------------------------------- /methods/wideresnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001) 11 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001) 15 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | self.activate_before_residual = activate_before_residual 23 | def forward(self, x): 24 | if not self.equalInOut and self.activate_before_residual == True: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | class NetworkBlock(nn.Module): 35 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, 
activate_before_residual=False): 36 | super(NetworkBlock, self).__init__() 37 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual) 38 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual): 39 | layers = [] 40 | for i in range(int(nb_layers)): 41 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate, activate_before_residual)) 42 | return nn.Sequential(*layers) 43 | def forward(self, x): 44 | return self.layer(x) 45 | 46 | 47 | 48 | class WideResNetBackbone(nn.Module): 49 | def __init__(self, args, depth=28, widen_factor=2, dropRate=0.0,req_output_dim = -1): 50 | super(WideResNetBackbone, self).__init__() 51 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 52 | assert((depth - 4) % 6 == 0) 53 | n = (depth - 4) / 6 54 | block = BasicBlock 55 | # 1st conv before any network block 56 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 57 | padding=1, bias=False) 58 | # 1st block 59 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True) 60 | # 2nd block 61 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 62 | # 3rd block 63 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 64 | # global average pooling and classifier 65 | self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001) 66 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 67 | # self.fc = nn.Linear(nChannels[3], num_classes) 68 | self.output_dim = nChannels[3] 69 | if req_output_dim > 0 and req_output_dim != self.output_dim: 70 | self.dim_map = nn.Sequential( 71 | nn.Conv2d(self.output_dim, req_output_dim, 1,1, 0, bias=False), 72 | nn.BatchNorm2d(req_output_dim), 73 | nn.LeakyReLU(0.2)) 74 | self.output_dim = req_output_dim 75 | else: 76 | self.dim_map = None 77 | 78 | for m in 
self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 81 | m.weight.data.normal_(0, math.sqrt(2. / n)) 82 | elif isinstance(m, nn.BatchNorm2d): 83 | m.weight.data.fill_(1) 84 | m.bias.data.zero_() 85 | elif isinstance(m, nn.Linear): 86 | nn.init.xavier_normal_(m.weight.data) 87 | m.bias.data.zero_() 88 | 89 | def forward(self, x): 90 | out = self.conv1(x) 91 | out = self.block1(out) 92 | out = self.block2(out) 93 | out = self.block3(out) 94 | out = self.relu(self.bn1(out)) 95 | if self.dim_map is not None: 96 | out = self.dim_map(out) 97 | return out 98 | 99 | 100 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from sklearn import metrics 4 | 5 | 6 | class OSREvaluation(): 7 | 8 | def __init__(self,test_loader) -> None: 9 | super().__init__() 10 | labels = test_loader.dataset.labels 11 | self.test_labels = np.array(labels,np.int) 12 | 13 | self.close_samples = self.test_labels >= 0 14 | self.close_samples_ct = np.sum(self.close_samples) 15 | 16 | def close_accuracy(self,prediction): 17 | return np.sum((prediction == self.test_labels) & self.close_samples) / self.close_samples_ct 18 | 19 | 20 | def open_detection_indexes(self,scores,thresh): 21 | if np.isnan(scores).any() or np.isinf(scores).any(): 22 | return {"auroc" : -1} 23 | fpr, tpr, thresholds = metrics.roc_curve(self.close_samples, scores) 24 | auroc = metrics.auc(fpr,tpr) 25 | precision, recall, _ = metrics.precision_recall_curve(self.close_samples, scores) 26 | aupr_in = metrics.auc(recall,precision) 27 | precision, recall, _ = metrics.precision_recall_curve(np.bitwise_not(self.close_samples), -scores) 28 | aupr_out = metrics.auc(recall,precision) 29 | 30 | det_acc = .5 * (tpr + 1.-fpr).max() 31 | 32 | if thresh < -99999: 33 | tidx = np.abs(np.array(tpr) - 0.95).argmin() 34 | thresh = 
thresholds[tidx] 35 | predicts = scores >= thresh 36 | ys = self.close_samples 37 | accuracy = metrics.accuracy_score(ys,predicts) 38 | f1 = metrics.f1_score(ys,predicts) 39 | recall = metrics.recall_score(ys,predicts) 40 | precision = metrics.precision_score(ys,predicts) 41 | fpr_at_tpr95 = fpr[tidx] 42 | return { 43 | "auroc" : auroc, 44 | "auprIN" : aupr_in, 45 | "auprOUT" : aupr_out, 46 | "accuracy" : accuracy, 47 | 'f1' : f1, 48 | 'recall' : recall, 49 | 'precision' : precision, 50 | "fpr@tpr95" : fpr_at_tpr95, 51 | "bestdetacc" : det_acc, 52 | } 53 | 54 | # modified from APRL 55 | def compute_oscr(self,pred, scores): 56 | unk_scores = scores[self.test_labels < 0] 57 | kn_cond = self.test_labels >= 0 58 | kn_ct = kn_cond.sum() 59 | unk_ct = scores.shape[0] - kn_ct 60 | kn_scores = scores[kn_cond] 61 | kn_correct_pred = pred[kn_cond] == self.test_labels[kn_cond] 62 | 63 | def fpr(thr): 64 | return (unk_scores > thr).sum() / unk_ct 65 | 66 | def ccr(thr): 67 | ac_cond = (kn_scores > thr) & (kn_correct_pred) 68 | return ac_cond.sum() / kn_ct 69 | 70 | sorted_scores = -np.sort(-scores) 71 | # Cutoffs are of prediction values 72 | 73 | CCR = [0] 74 | FPR = [0] 75 | 76 | for s in sorted_scores: 77 | CCR.append(ccr(s)) 78 | FPR.append(fpr(s)) 79 | CCR += [1] 80 | FPR += [1] 81 | 82 | # Positions of ROC curve (FPR, TPR) 83 | ROC = sorted(zip(FPR, CCR), reverse=True) 84 | OSCR = 0 85 | # Compute AUROC Using Trapezoidal Rule 86 | for j in range(len(CCR)-1): 87 | h = ROC[j][0] - ROC[j+1][0] 88 | w = (ROC[j][1] + ROC[j+1][1]) / 2.0 89 | 90 | OSCR = OSCR + h*w 91 | 92 | return OSCR 93 | 94 | # score > thresh 表示是一个开放样本 95 | def open_reco_indexes(self,scores,thresh,rawpredicts): 96 | if np.isnan(scores).any() or np.isinf(scores).any(): 97 | return {} 98 | predicts = rawpredicts.copy() 99 | if thresh < -99999: 100 | fpr, tpr, thresholds = metrics.roc_curve(self.close_samples, scores) 101 | thresh = thresholds[np.abs(np.array(tpr) - 0.95).argmin()] 102 | predicts[scores <= 
thresh] = -1 103 | ys = self.test_labels.copy() 104 | ys[ys < 0] = -1 105 | accuracy = metrics.accuracy_score(ys,predicts) 106 | macro_f1 = metrics.f1_score(ys,predicts,average='macro') 107 | weighted_f1 = metrics.f1_score(ys,predicts,average='weighted') 108 | macro_recall = metrics.recall_score(ys,predicts,average='macro') 109 | weighted_recall = metrics.recall_score(ys,predicts,average='weighted') 110 | macro_precision = metrics.precision_score(ys,predicts,average='macro') 111 | weighted_precision = metrics.precision_score(ys,predicts,average='weighted') 112 | oscr = self.compute_oscr(rawpredicts,scores) 113 | closeacc_withrej = np.sum((predicts == self.test_labels) & self.close_samples) / self.close_samples_ct 114 | 115 | clswise = {} 116 | tot = 0 117 | numclas = np.max(self.test_labels) + 1 118 | if numclas < 50: 119 | for c in range(numclas): 120 | cond = rawpredicts == c 121 | sbc = self.close_samples[cond] 122 | ck = sbc.sum() 123 | if ck == 0 or ck == sbc.shape[0]: 124 | clswise[f'class{c}'] = 0.5 125 | else: 126 | fpr, tpr, thresholds = metrics.roc_curve(sbc, scores[cond]) 127 | auroc = metrics.auc(fpr,tpr) 128 | clswise[f'class{c}'] = auroc 129 | tot += auroc * cond.sum() 130 | # print("AUROC for class",c,"is",auroc,", with sample number",np.sum(cond)) 131 | clswise['mean'] = tot / rawpredicts.shape[0] 132 | else: 133 | numclas = 'Too many to analyse' 134 | 135 | return { 136 | 'closeacc_withrej':closeacc_withrej, 137 | 'accuracy' : accuracy, 138 | 'macro_f1' : macro_f1, 139 | 'oscr' : oscr, 140 | 'weighted_f1' : weighted_f1, 141 | 'macro_recall' : macro_recall, 142 | 'weighted_recall' : weighted_recall, 143 | 'macro_precision' : macro_precision, 144 | 'weighted_precision' : weighted_precision, 145 | 'classwise_auc' : clswise 146 | } 147 | 148 | def openscore_distribution(self,scores,pred,savename): 149 | import matplotlib.pyplot as plt 150 | import matplotlib as mpl 151 | mpl.use('Agg') 152 | percentiles = np.array([0.25,99.75]) 153 | ptiles = 
np.percentile(scores, percentiles) 154 | 155 | bins = np.linspace(ptiles[0], ptiles[1], 80) 156 | plt.hist(scores[self.test_labels <= -1], bins=bins, facecolor="red", alpha=0.4,label='test unknown',density = True) 157 | 158 | plt.hist(scores[self.test_labels > -1], bins=bins, facecolor="green", alpha=0.4,label='test known',density = True) 159 | 160 | fs = 16 161 | plt.legend(fontsize = fs) 162 | plt.yticks([]) 163 | plt.xticks(fontsize = fs - 2) 164 | plt.xlabel('Score Value',fontsize = fs) 165 | plt.tight_layout() 166 | savename = savename.replace("/",'#') 167 | plt.savefig("./test_figs/"+savename+".jpg") 168 | plt.cla() 169 | 170 | def known_unknown_confusion(self,scores,pred): 171 | numcls = np.max(self.test_labels) + 1 172 | numukn = -np.min(self.test_labels) 173 | kvu_confusion = np.zeros([numcls,numukn]) 174 | upk_confusion = np.zeros([numukn,numcls]) 175 | for i in range(numcls): 176 | for j in range(numukn): 177 | uj = -j-1 178 | cond = (self.test_labels == i) | (self.test_labels == uj) 179 | sbsc = scores[cond] 180 | sbgt = self.test_labels[cond] >= 0 181 | fpr, tpr, thresholds = metrics.roc_curve(sbgt, sbsc) 182 | auroc = metrics.auc(fpr,tpr) 183 | kvu_confusion[i,j] = auroc 184 | 185 | for j in range(numukn): 186 | uj = -j-1 187 | cond = self.test_labels == uj 188 | p = pred[cond] 189 | for i in range(numcls): 190 | upk_confusion[j,i] = (p == i).sum() 191 | return kvu_confusion,upk_confusion 192 | 193 | def openset_recognition_curve(self,scores,pred): 194 | unk_scores = scores[self.test_labels < 0] 195 | kn_cond = self.test_labels >= 0 196 | kn_ct = kn_cond.sum() 197 | unk_ct = scores.shape[0] - kn_ct 198 | kn_scores = scores[kn_cond] 199 | kn_correct_pred = pred[kn_cond] == self.test_labels[kn_cond] 200 | 201 | def fpr(thr): 202 | return (unk_scores > thr).sum() / unk_ct 203 | 204 | def ccr(thr): 205 | ac_cond = (kn_scores > thr) & (kn_correct_pred) 206 | return ac_cond.sum() / kn_ct 207 | 208 | results = [] 209 | sorted_scores = -np.sort(-scores) 210 | 
for s in sorted_scores: 211 | results.append([fpr(s),ccr(s)]) 212 | return np.array(results) 213 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | 2 | for dataset in cifar10 svhn tinyimagenet 3 | do 4 | for s in a b c d e 5 | do 6 | python3 main.py --gpu 0 --ds ./exps/$dataset/spl_$s.json --config ./configs/linear/$dataset.json --save linear_$dataset_$s --method cssr --test_interval 2 7 | python3 main.py --gpu 0 --ds ./exps/$dataset/spl_$s.json --config ./configs/pcssr/$dataset.json --save pcssr_$dataset_$s --method cssr --test_interval 2 8 | python3 main.py --gpu 0 --ds ./exps/$dataset/spl_$s.json --config ./configs/rcssr/$dataset.json --save rcssr_$dataset_$s --method cssr --test_interval 2 9 | done 10 | done 11 | 12 | # imagenet 13 | # python3 main.py --gpu 0 --ds ./exps/imagenet/vs_inaturalist.json --config ./configs/rcssr/imagenet.json --save imagenet1k_rcssr --method cssr_ft --------------------------------------------------------------------------------