├── LICENSE
├── README.md
├── assets
│   └── method.jpg
├── config.py
├── data
│   ├── augmentations
│   │   ├── __init__.py
│   │   ├── cut_out.py
│   │   └── randaugment.py
│   ├── cifar.py
│   ├── corrupt_data.py
│   ├── cub.py
│   ├── data_utils.py
│   ├── fgvc_aircraft.py
│   ├── get_datasets.py
│   ├── herbarium_19.py
│   ├── imagenet.py
│   └── stanford_cars.py
├── models
│   ├── __init__.py
│   ├── loss.py
│   ├── model.py
│   └── vision_transformer.py
├── my_utils
│   ├── __init__.py
│   ├── cluster_and_log_utils.py
│   ├── general_utils.py
│   └── ood_utils.py
├── test_ood_cifar.py
├── test_ood_imagenet.py
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Shijie Ma

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# ProtoGCD: Unified and Unbiased Prototype Learning for Generalized Category Discovery

Official implementation of our TPAMI 2025 paper "ProtoGCD: Unified and Unbiased Prototype Learning for Generalized Category Discovery".
![method](assets/method.jpg)

## :running: Running

### Dependencies

```
loguru
numpy
pandas
scikit_learn
scipy
torch==1.10.0
torchvision==0.11.1
tqdm
```

### Datasets

We conduct experiments on 7 datasets:

* Generic datasets: CIFAR-10, CIFAR-100, ImageNet-100
* Fine-grained datasets: [CUB](https://drive.google.com/drive/folders/1kFzIqZL_pEBVR7Ca_8IKibfWoeZc3GT1), [Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html), [FGVC-Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/), [Herbarium19](https://www.kaggle.com/c/herbarium-2019-fgvc6)

### Config

Set the paths to the datasets in `config.py`.

### Training ProtoGCD

CIFAR-100:

```shell
CUDA_VISIBLE_DEVICES=0 python train.py --dataset_name 'cifar100' --batch_size 128 --epochs 200 --num_workers 4 --use_ssb_splits --weight_decay 5e-5 --lr 0.1 --eval_funcs 'v2' --weight_sup 0.35 --weight_entropy_reg 2 --weight_proto_sep 0.1 --temp_logits 0.1 --temp_teacher_logits 0.05 --wait_ratio_epochs 0 --ramp_ratio_teacher_epochs 100 --init_ratio 0.0 --final_ratio 1.0 --exp_name cifar100_protogcd
```

CUB:

```shell
CUDA_VISIBLE_DEVICES=0 python train.py --dataset_name 'cub' --batch_size 128 --epochs 200 --num_workers 2 --use_ssb_splits --weight_decay 5e-5 --lr 0.1 --eval_funcs 'v2' --weight_sup 0.35 --weight_entropy_reg 2 --weight_proto_sep 0.05 --temp_logits 0.1 --temp_teacher_logits 0.05 --wait_ratio_epochs 0 --ramp_ratio_teacher_epochs 100 --init_ratio 0.0 --final_ratio 1.0 --exp_name cub_protogcd
```

### Evaluate OOD detection

CIFAR:

```shell
CUDA_VISIBLE_DEVICES=0 python test_ood_cifar.py --dataset_name 'cifar100' --batch_size 128 --num_workers 4 --use_ssb_splits --num_to_avg 10 --score msp --ckpts_date YOUR_CKPTS_NAME --temp_logits 0.1
```

ImageNet:

```shell
CUDA_VISIBLE_DEVICES=0 python test_ood_imagenet.py --dataset_name 'imagenet_100' --batch_size 128 --num_workers 4 --use_ssb_splits --num_to_avg 10 --score msp --ckpts_date YOUR_CKPTS_NAME --temp_logits 0.1
```

Both scripts score samples with the maximum softmax probability (`--score msp`); a minimal sketch of this scoring step is shown after the License section below.

## :clipboard: Citing this work

```bibtex
@ARTICLE{10948388,
  author={Ma, Shijie and Zhu, Fei and Zhang, Xu-Yao and Liu, Cheng-Lin},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  title={ProtoGCD: Unified and Unbiased Prototype Learning for Generalized Category Discovery},
  year={2025},
  volume={},
  number={},
  pages={1-17},
  keywords={Prototypes;Adaptation models;Contrastive learning;Training;Magnetic heads;Feature extraction;Estimation;Automobiles;Accuracy;Pragmatics;Generalized category discovery;open-world learning;prototype learning;semi-supervised learning},
  doi={10.1109/TPAMI.2025.3557502}
}
```

## :gift: Acknowledgements

In building the ProtoGCD codebase, we referenced [SimGCD](https://github.com/CVMI-Lab/SimGCD).

## :white_check_mark: License

This project is licensed under the MIT License - see the [LICENSE](https://github.com/mashijie1028/ProtoGCD/blob/main/LICENSE) file for details.
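For reference, the MSP scoring used by `test_ood_cifar.py` and `test_ood_imagenet.py` reduces to the sketch below. The function name `msp_score` and the tensor shapes are illustrative assumptions, not the scripts' literal code:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def msp_score(logits: torch.Tensor, temp: float = 0.1) -> torch.Tensor:
    """Maximum softmax probability (MSP): higher = more likely in-distribution.

    `logits` is assumed to be a (batch, num_prototypes) tensor of prototype
    similarities from the ProtoGCD head; `temp` mirrors the --temp_logits flag.
    """
    probs = F.softmax(logits / temp, dim=-1)
    return probs.max(dim=-1).values
```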
## :email: Contact

If you have further questions or discussions, feel free to contact me:

Shijie Ma (mashijie2021@ia.ac.cn)
--------------------------------------------------------------------------------

/assets/method.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mashijie1028/ProtoGCD/8835a4d24662c65be125a42d815b52f62ae1482e/assets/method.jpg
--------------------------------------------------------------------------------

/config.py:
--------------------------------------------------------------------------------
# -----------------
# DATASET ROOTS
# -----------------
cifar_10_root = '/data4/sjma/dataset/CIFAR/'
cifar_100_root = '/data4/sjma/dataset/CIFAR/'
cub_root = '/data4/sjma/dataset/CUB/'
aircraft_root = '/data4/sjma/dataset/FGVC-Aircraft/fgvc-aircraft-2013b/'
car_root = "/data4/sjma/dataset/Stanford-Cars/"
herbarium_dataroot = '/data4/sjma/dataset/Herbarium19-Small/'
# imagenet_root = '/lustre/datasharing/sjma/ImageNet/ILSVRC2012/imagenet/'
imagenet_root = '/data4/sjma/dataset/ImageNet/ILSVRC2012/imagenet/'

# OSR Split dir
osr_split_dir = '/data4/sjma/dataset/ssb_splits/'

# -----------------
# OTHER PATHS
# -----------------
exp_root = 'dev_outputs'  # All logs and checkpoints will be saved here
--------------------------------------------------------------------------------

/data/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms 2 | from data.augmentations.cut_out import * 3 | from data.augmentations.randaugment import RandAugment 4 | 5 | def get_transform(transform_type='imagenet', image_size=32, args=None): 6 | 7 | if transform_type == 'imagenet': 8 | 9 | mean = (0.485, 0.456, 0.406) 10 | std = (0.229, 0.224, 0.225) 11 | interpolation = args.interpolation 12 | crop_pct = args.crop_pct 13 | 14 | train_transform = transforms.Compose([ 15 | transforms.Resize(int(image_size / crop_pct), interpolation), 16 | transforms.RandomCrop(image_size), 17 | transforms.RandomHorizontalFlip(p=0.5), 18 | transforms.ColorJitter(), 19 | transforms.ToTensor(), 20 | transforms.Normalize( 21 | mean=torch.tensor(mean), 22 | std=torch.tensor(std)) 23 | ]) 24 | 25 | test_transform = transforms.Compose([ 26 | transforms.Resize(int(image_size / crop_pct), interpolation), 27 | transforms.CenterCrop(image_size), 28 | transforms.ToTensor(), 29 | transforms.Normalize( 30 | mean=torch.tensor(mean), 31 | std=torch.tensor(std)) 32 | ]) 33 | 34 | elif transform_type == 'pytorch-cifar': 35 | 36 | mean = (0.4914, 0.4822, 0.4465) 37 | std = (0.2023, 0.1994, 0.2010) 38 | 39 | train_transform = transforms.Compose([ 40 | transforms.RandomCrop(image_size, padding=4), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | transforms.Normalize(mean=mean, std=std), 44 | ]) 45 | 46 | test_transform = transforms.Compose([ 47 | transforms.Resize((image_size, image_size)), 48 | transforms.ToTensor(), 49 | transforms.Normalize(mean=mean, std=std), 50 | ]) 51 | 52 | elif transform_type == 'herbarium_default': 53 | 54 | train_transform = transforms.Compose([ 55 | transforms.Resize((image_size, image_size)), 56 | transforms.RandomResizedCrop(image_size, scale=(args.resize_lower_bound, 1)), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.ToTensor(), 59 | ]) 60 | 61 | test_transform = transforms.Compose([ 62 |
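# Note: unlike the 'imagenet' branch above, evaluation here resizes straight to a square with no centre-crop (presumably to handle the varied aspect ratios of herbarium sheets).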
transforms.Resize((image_size, image_size)), 63 | transforms.ToTensor(), 64 | ]) 65 | 66 | elif transform_type == 'cutout': 67 | 68 | mean = np.array([0.4914, 0.4822, 0.4465]) 69 | std = np.array([0.2470, 0.2435, 0.2616]) 70 | 71 | train_transform = transforms.Compose([ 72 | transforms.RandomCrop(image_size, padding=4), 73 | transforms.RandomHorizontalFlip(), 74 | normalize(mean, std), 75 | cutout(mask_size=int(image_size / 2), 76 | p=1, 77 | cutout_inside=False), 78 | to_tensor(), 79 | ]) 80 | test_transform = transforms.Compose([ 81 | transforms.Resize((image_size, image_size)), 82 | transforms.ToTensor(), 83 | transforms.Normalize(mean, std), 84 | ]) 85 | 86 | elif transform_type == 'rand-augment': 87 | 88 | mean = (0.485, 0.456, 0.406) 89 | std = (0.229, 0.224, 0.225) 90 | 91 | train_transform = transforms.Compose([ 92 | transforms.Resize((image_size, image_size)), 93 | transforms.RandomCrop(image_size, padding=4), 94 | transforms.RandomHorizontalFlip(), 95 | transforms.ToTensor(), 96 | transforms.Normalize(mean=mean, std=std), 97 | ]) 98 | 99 | train_transform.transforms.insert(0, RandAugment(args.rand_aug_n, args.rand_aug_m, args=None)) 100 | 101 | test_transform = transforms.Compose([ 102 | transforms.Resize((image_size, image_size)), 103 | transforms.ToTensor(), 104 | transforms.Normalize(mean=mean, std=std), 105 | ]) 106 | 107 | elif transform_type == 'random_affine': 108 | 109 | mean = (0.485, 0.456, 0.406) 110 | std = (0.229, 0.224, 0.225) 111 | interpolation = args.interpolation 112 | crop_pct = args.crop_pct 113 | 114 | train_transform = transforms.Compose([ 115 | transforms.Resize((image_size, image_size), interpolation), 116 | transforms.RandomAffine(degrees=(-45, 45), 117 | translate=(0.1, 0.1), shear=(-15, 15), scale=(0.7, args.crop_pct)), 118 | transforms.ColorJitter(), 119 | transforms.ToTensor(), 120 | transforms.Normalize( 121 | mean=torch.tensor(mean), 122 | std=torch.tensor(std)) 123 | ]) 124 | 125 | test_transform = transforms.Compose([ 126 | transforms.Resize(int(image_size / crop_pct), interpolation), 127 | transforms.CenterCrop(image_size), 128 | transforms.ToTensor(), 129 | transforms.Normalize( 130 | mean=torch.tensor(mean), 131 | std=torch.tensor(std)) 132 | ]) 133 | 134 | else: 135 | 136 | raise NotImplementedError 137 | 138 | return (train_transform, test_transform) -------------------------------------------------------------------------------- /data/augmentations/cut_out.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/hysts/pytorch_cutout 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | 8 | def cutout(mask_size, p, cutout_inside, mask_color=(0, 0, 0)): 9 | mask_size_half = mask_size // 2 10 | offset = 1 if mask_size % 2 == 0 else 0 11 | 12 | def _cutout(image): 13 | image = np.asarray(image).copy() 14 | 15 | if np.random.random() > p: 16 | return image 17 | 18 | h, w = image.shape[:2] 19 | 20 | if cutout_inside: 21 | cxmin, cxmax = mask_size_half, w + offset - mask_size_half 22 | cymin, cymax = mask_size_half, h + offset - mask_size_half 23 | else: 24 | cxmin, cxmax = 0, w + offset 25 | cymin, cymax = 0, h + offset 26 | 27 | cx = np.random.randint(cxmin, cxmax) 28 | cy = np.random.randint(cymin, cymax) 29 | xmin = cx - mask_size_half 30 | ymin = cy - mask_size_half 31 | xmax = xmin + mask_size 32 | ymax = ymin + mask_size 33 | xmin = max(0, xmin) 34 | ymin = max(0, ymin) 35 | xmax = min(w, xmax) 36 | ymax = min(h, ymax) 37 | image[ymin:ymax, xmin:xmax] = mask_color 38 | return image 
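# The configured closure below is handed back so it can be composed with other transforms (see the 'cutout' branch in data/augmentations/__init__.py).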
39 | 40 | return _cutout 41 | 42 | def to_tensor(): 43 | def _to_tensor(image): 44 | if len(image.shape) == 3: 45 | return torch.from_numpy( 46 | image.transpose(2, 0, 1).astype(float)) 47 | else: 48 | return torch.from_numpy(image[None, :, :].astype(float)) 49 | 50 | return _to_tensor 51 | 52 | def normalize(mean, std): 53 | 54 | mean = np.array(mean) 55 | std = np.array(std) 56 | 57 | def _normalize(image): 58 | image = np.asarray(image).astype(float) / 255. 59 | image = (image - mean) / std 60 | return image 61 | 62 | return _normalize 63 | -------------------------------------------------------------------------------- /data/augmentations/randaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 3 | """ 4 | 5 | import random 6 | 7 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | 12 | 13 | def ShearX(img, v): # [-0.3, 0.3] 14 | assert -0.3 <= v <= 0.3 15 | if random.random() > 0.5: 16 | v = -v 17 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 18 | 19 | 20 | def ShearY(img, v): # [-0.3, 0.3] 21 | assert -0.3 <= v <= 0.3 22 | if random.random() > 0.5: 23 | v = -v 24 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 25 | 26 | 27 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 28 | assert -0.45 <= v <= 0.45 29 | if random.random() > 0.5: 30 | v = -v 31 | v = v * img.size[0] 32 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 33 | 34 | 35 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 36 | assert 0 <= v 37 | if random.random() > 0.5: 38 | v = -v 39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 40 | 41 | 42 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 43 | assert -0.45 <= v <= 0.45 44 | if random.random() > 0.5: 45 | v = -v 46 | v = v * img.size[1] 47 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 48 | 49 | 50 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 51 | assert 0 <= v 52 | if random.random() > 0.5: 53 | v = -v 54 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 55 | 56 | 57 | def Rotate(img, v): # [-30, 30] 58 | assert -30 <= v <= 30 59 | if random.random() > 0.5: 60 | v = -v 61 | return img.rotate(v) 62 | 63 | 64 | def AutoContrast(img, _): 65 | return PIL.ImageOps.autocontrast(img) 66 | 67 | 68 | def Invert(img, _): 69 | return PIL.ImageOps.invert(img) 70 | 71 | 72 | def Equalize(img, _): 73 | return PIL.ImageOps.equalize(img) 74 | 75 | 76 | def Flip(img, _): # not from the paper 77 | return PIL.ImageOps.mirror(img) 78 | 79 | 80 | def Solarize(img, v): # [0, 256] 81 | assert 0 <= v <= 256 82 | return PIL.ImageOps.solarize(img, v) 83 | 84 | 85 | def SolarizeAdd(img, addition=0, threshold=128): 86 | img_np = np.array(img).astype(np.int) 87 | img_np = img_np + addition 88 | img_np = np.clip(img_np, 0, 255) 89 | img_np = img_np.astype(np.uint8) 90 | img = Image.fromarray(img_np) 91 | return PIL.ImageOps.solarize(img, threshold) 92 | 93 | 94 | def Posterize(img, v): # [4, 8] 95 | v = int(v) 96 | v = max(1, v) 97 | return PIL.ImageOps.posterize(img, v) 98 | 99 | 100 | def Contrast(img, v): # [0.1,1.9] 101 | assert 0.1 <= v <= 1.9 102 | return PIL.ImageEnhance.Contrast(img).enhance(v) 103 | 104 | 105 | def Color(img, v): # [0.1,1.9] 106 | 
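# Saturation adjustment via PIL.ImageEnhance.Color: v=0.0 gives a grayscale image, v=1.0 returns the original.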
assert 0.1 <= v <= 1.9 107 | return PIL.ImageEnhance.Color(img).enhance(v) 108 | 109 | 110 | def Brightness(img, v): # [0.1,1.9] 111 | assert 0.1 <= v <= 1.9 112 | return PIL.ImageEnhance.Brightness(img).enhance(v) 113 | 114 | 115 | def Sharpness(img, v): # [0.1,1.9] 116 | assert 0.1 <= v <= 1.9 117 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 118 | 119 | 120 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 121 | assert 0.0 <= v <= 0.2 122 | if v <= 0.: 123 | return img 124 | 125 | v = v * img.size[0] 126 | return CutoutAbs(img, v) 127 | 128 | 129 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 130 | # assert 0 <= v <= 20 131 | if v < 0: 132 | return img 133 | w, h = img.size 134 | x0 = np.random.uniform(w) 135 | y0 = np.random.uniform(h) 136 | 137 | x0 = int(max(0, x0 - v / 2.)) 138 | y0 = int(max(0, y0 - v / 2.)) 139 | x1 = min(w, x0 + v) 140 | y1 = min(h, y0 + v) 141 | 142 | xy = (x0, y0, x1, y1) 143 | color = (125, 123, 114) 144 | # color = (0, 0, 0) 145 | img = img.copy() 146 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 147 | return img 148 | 149 | 150 | def SamplePairing(imgs): # [0, 0.4] 151 | def f(img1, v): 152 | i = np.random.choice(len(imgs)) 153 | img2 = PIL.Image.fromarray(imgs[i]) 154 | return PIL.Image.blend(img1, img2, v) 155 | 156 | return f 157 | 158 | 159 | def Identity(img, v): 160 | return img 161 | 162 | 163 | def augment_list(): # 16 oeprations and their ranges 164 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 165 | # l = [ 166 | # (Identity, 0., 1.0), 167 | # (ShearX, 0., 0.3), # 0 168 | # (ShearY, 0., 0.3), # 1 169 | # (TranslateX, 0., 0.33), # 2 170 | # (TranslateY, 0., 0.33), # 3 171 | # (Rotate, 0, 30), # 4 172 | # (AutoContrast, 0, 1), # 5 173 | # (Invert, 0, 1), # 6 174 | # (Equalize, 0, 1), # 7 175 | # (Solarize, 0, 110), # 8 176 | # (Posterize, 4, 8), # 9 177 | # # (Contrast, 0.1, 1.9), # 10 178 | # (Color, 0.1, 1.9), # 11 179 | # (Brightness, 0.1, 1.9), # 12 180 | # (Sharpness, 0.1, 1.9), # 13 181 | # # (Cutout, 0, 0.2), # 14 182 | # # (SamplePairing(imgs), 0, 0.4), # 15 183 | # ] 184 | 185 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 186 | l = [ 187 | (AutoContrast, 0, 1), 188 | (Equalize, 0, 1), 189 | (Invert, 0, 1), 190 | (Rotate, 0, 30), 191 | (Posterize, 0, 4), 192 | (Solarize, 0, 256), 193 | (SolarizeAdd, 0, 110), 194 | (Color, 0.1, 1.9), 195 | (Contrast, 0.1, 1.9), 196 | (Brightness, 0.1, 1.9), 197 | (Sharpness, 0.1, 1.9), 198 | (ShearX, 0., 0.3), 199 | (ShearY, 0., 0.3), 200 | (CutoutAbs, 0, 40), 201 | (TranslateXabs, 0., 100), 202 | (TranslateYabs, 0., 100), 203 | ] 204 | 205 | return l 206 | 207 | def augment_list_svhn(): # 16 oeprations and their ranges 208 | 209 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 210 | l = [ 211 | (AutoContrast, 0, 1), 212 | (Equalize, 0, 1), 213 | (Invert, 0, 1), 214 | (Posterize, 0, 4), 215 | (Solarize, 0, 256), 216 | (SolarizeAdd, 0, 110), 217 | (Color, 0.1, 1.9), 218 | (Contrast, 0.1, 1.9), 219 | (Brightness, 0.1, 1.9), 220 | (Sharpness, 0.1, 1.9), 221 | (ShearX, 0., 0.3), 222 | (ShearY, 0., 0.3), 223 | (CutoutAbs, 0, 40), 224 | ] 225 | 226 | return l 227 | 228 | 229 | class Lighting(object): 230 | """Lighting noise(AlexNet - style PCA - based noise)""" 231 | 232 | def __init__(self, alphastd, eigval, eigvec): 233 | self.alphastd = alphastd 234 | self.eigval = 
torch.Tensor(eigval) 235 | self.eigvec = torch.Tensor(eigvec) 236 | 237 | def __call__(self, img): 238 | if self.alphastd == 0: 239 | return img 240 | 241 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 242 | rgb = self.eigvec.type_as(img).clone() \ 243 | .mul(alpha.view(1, 3).expand(3, 3)) \ 244 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 245 | .sum(1).squeeze() 246 | 247 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 248 | 249 | 250 | class CutoutDefault(object): 251 | """ 252 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 253 | """ 254 | def __init__(self, length): 255 | self.length = length 256 | 257 | def __call__(self, img): 258 | h, w = img.size(1), img.size(2) 259 | mask = np.ones((h, w), float) 260 | y = np.random.randint(h) 261 | x = np.random.randint(w) 262 | 263 | y1 = np.clip(y - self.length // 2, 0, h) 264 | y2 = np.clip(y + self.length // 2, 0, h) 265 | x1 = np.clip(x - self.length // 2, 0, w) 266 | x2 = np.clip(x + self.length // 2, 0, w) 267 | 268 | mask[y1: y2, x1: x2] = 0. 269 | mask = torch.from_numpy(mask) 270 | mask = mask.expand_as(img) 271 | img *= mask 272 | return img 273 | 274 | 275 | class RandAugment: 276 | def __init__(self, n, m, args=None): 277 | self.n = n # [1, 2] 278 | self.m = m # [0...30] 279 | 280 | if args is None: 281 | self.augment_list = augment_list() 282 | 283 | elif args.dataset == 'svhn' or args.dataset == 'mnist': 284 | self.augment_list = augment_list_svhn() 285 | 286 | else: 287 | self.augment_list = augment_list() 288 | 289 | def __call__(self, img): 290 | ops = random.choices(self.augment_list, k=self.n) 291 | for op, minval, maxval in ops: 292 | val = (float(self.m) / 30) * float(maxval - minval) + minval 293 | img = op(img, val) 294 | 295 | return img 296 | -------------------------------------------------------------------------------- /data/cifar.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10, CIFAR100 2 | from copy import deepcopy 3 | import numpy as np 4 | 5 | from data.data_utils import subsample_instances 6 | from config import cifar_10_root, cifar_100_root 7 | 8 | 9 | class CustomCIFAR10(CIFAR10): 10 | 11 | def __init__(self, *args, **kwargs): 12 | 13 | super(CustomCIFAR10, self).__init__(*args, **kwargs) 14 | 15 | self.uq_idxs = np.array(range(len(self))) 16 | 17 | def __getitem__(self, item): 18 | 19 | img, label = super().__getitem__(item) 20 | uq_idx = self.uq_idxs[item] 21 | 22 | return img, label, uq_idx 23 | 24 | def __len__(self): 25 | return len(self.targets) 26 | 27 | 28 | class CustomCIFAR100(CIFAR100): 29 | 30 | def __init__(self, *args, **kwargs): 31 | super(CustomCIFAR100, self).__init__(*args, **kwargs) 32 | 33 | self.uq_idxs = np.array(range(len(self))) 34 | 35 | def __getitem__(self, item): 36 | img, label = super().__getitem__(item) 37 | uq_idx = self.uq_idxs[item] 38 | 39 | return img, label, uq_idx 40 | 41 | def __len__(self): 42 | return len(self.targets) 43 | 44 | 45 | def subsample_dataset(dataset, idxs): 46 | 47 | # Allow for setting in which all empty set of indices is passed 48 | 49 | if len(idxs) > 0: 50 | 51 | dataset.data = dataset.data[idxs] 52 | dataset.targets = np.array(dataset.targets)[idxs].tolist() 53 | dataset.uq_idxs = dataset.uq_idxs[idxs] 54 | 55 | return dataset 56 | 57 | else: 58 | 59 | return None 60 | 61 | 62 | def subsample_classes(dataset, include_classes=(0, 1, 8, 9)): 63 | 64 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes] 65 | 66 | 
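# Map the retained labels to contiguous ids (0 .. len(include_classes)-1); the local target_transform below is commented out, and relabelling is instead applied in data/get_datasets.py.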
target_xform_dict = {} 67 | for i, k in enumerate(include_classes): 68 | target_xform_dict[k] = i 69 | 70 | dataset = subsample_dataset(dataset, cls_idxs) 71 | 72 | # dataset.target_transform = lambda x: target_xform_dict[x] 73 | 74 | return dataset 75 | 76 | 77 | def get_train_val_indices(train_dataset, val_split=0.2): 78 | 79 | train_classes = np.unique(train_dataset.targets) 80 | 81 | # Get train/test indices 82 | train_idxs = [] 83 | val_idxs = [] 84 | for cls in train_classes: 85 | 86 | cls_idxs = np.where(train_dataset.targets == cls)[0] 87 | 88 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 89 | t_ = [x for x in cls_idxs if x not in v_] 90 | 91 | train_idxs.extend(t_) 92 | val_idxs.extend(v_) 93 | 94 | return train_idxs, val_idxs 95 | 96 | 97 | def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9), 98 | prop_train_labels=0.8, split_train_val=False, seed=0): 99 | 100 | np.random.seed(seed) 101 | 102 | # Init entire training set 103 | whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True) 104 | 105 | # Get labelled training set which has subsampled classes, then subsample some indices from that 106 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 107 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 108 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 109 | 110 | # Split into training and validation sets 111 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 112 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 113 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 114 | val_dataset_labelled_split.transform = test_transform 115 | 116 | # Get unlabelled data 117 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 118 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 119 | 120 | # Get test set for all classes 121 | test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False) 122 | 123 | # Either split train into train and val or use test set as val 124 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 125 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 126 | 127 | all_datasets = { 128 | 'train_labelled': train_dataset_labelled, 129 | 'train_unlabelled': train_dataset_unlabelled, 130 | 'val': val_dataset_labelled, 131 | 'test': test_dataset, 132 | } 133 | 134 | return all_datasets 135 | 136 | 137 | def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80), 138 | prop_train_labels=0.8, split_train_val=False, seed=0): 139 | 140 | np.random.seed(seed) 141 | 142 | # Init entire training set 143 | whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True) 144 | 145 | # Get labelled training set which has subsampled classes, then subsample some indices from that 146 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 147 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 148 | train_dataset_labelled = 
subsample_dataset(train_dataset_labelled, subsample_indices) 149 | 150 | # Split into training and validation sets 151 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 152 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 153 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 154 | val_dataset_labelled_split.transform = test_transform 155 | 156 | # Get unlabelled data 157 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 158 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 159 | 160 | # Get test set for all classes 161 | test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False) 162 | 163 | # Either split train into train and val or use test set as val 164 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 165 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 166 | 167 | all_datasets = { 168 | 'train_labelled': train_dataset_labelled, 169 | 'train_unlabelled': train_dataset_unlabelled, 170 | 'val': val_dataset_labelled, 171 | 'test': test_dataset, 172 | } 173 | 174 | return all_datasets 175 | 176 | 177 | if __name__ == '__main__': 178 | 179 | x = get_cifar_100_datasets(None, None, split_train_val=False, 180 | train_classes=range(80), prop_train_labels=0.5) 181 | 182 | print('Printing lens...') 183 | for k, v in x.items(): 184 | if v is not None: 185 | print(f'{k}: {len(v)}') 186 | 187 | print('Printing labelled and unlabelled overlap...') 188 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 189 | print('Printing total instances in train...') 190 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 191 | 192 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') 193 | print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].targets))}') 194 | print(f'Len labelled set: {len(x["train_labelled"])}') 195 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /data/corrupt_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | import torchvision.transforms.functional as tfunc 8 | 9 | ROOT_DIR_10 = "/data4/sjma/dataset/Corruptions/CIFAR-10-C/" 10 | ROOT_DIR_100 = "/data4/sjma/dataset/Corruptions/CIFAR-100-C/" 11 | 12 | 13 | class DatasetFromTorchTensor(Dataset): 14 | def __init__(self, data, target, transform=None): 15 | # Data type handling must be done beforehand. It is too difficult at this point. 
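# In this repo, get_data() below supplies uint8 image tensors of shape (N, 3, 32, 32) and a 1-D label tensor.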
16 | self.data = data 17 | self.target = target 18 | if len(self.target.shape)==1: 19 | self.target = target.long() 20 | self.transform = transform 21 | 22 | def __getitem__(self, index): 23 | x = self.data[index] 24 | y = self.target[index] 25 | if self.transform: 26 | x = tfunc.to_pil_image(x) 27 | x = self.transform(x) 28 | return x, y 29 | 30 | def __len__(self): 31 | return len(self.data) 32 | 33 | 34 | def get_data(data_name, dataset, test_transform=None, severity=1): 35 | if data_name == 'cifar10': 36 | ROOT_DIR = ROOT_DIR_10 37 | if data_name == 'cifar100': 38 | ROOT_DIR = ROOT_DIR_100 39 | data_path = os.path.join(ROOT_DIR, dataset+'.npy') 40 | label_path = os.path.join(ROOT_DIR, 'labels.npy') 41 | data = torch.tensor(np.transpose(np.load(data_path), (0,3,1,2))) 42 | labels = torch.tensor(np.load(label_path)) 43 | start = 10000 * (severity - 1) 44 | 45 | data = data[start:start+10000] 46 | labels = labels[start:start+10000] 47 | test_data = DatasetFromTorchTensor(data, labels, transform=test_transform) 48 | 49 | return test_data 50 | 51 | 52 | 53 | if __name__ =='__main__': 54 | train, test = get_data('snow') 55 | print(len(train), len(test)) 56 | -------------------------------------------------------------------------------- /data/cub.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | from torchvision.datasets.folder import default_loader 7 | from torchvision.datasets.utils import download_url 8 | from torch.utils.data import Dataset 9 | 10 | from data.data_utils import subsample_instances 11 | from config import cub_root 12 | 13 | 14 | class CustomCub2011(Dataset): 15 | base_folder = 'CUB_200_2011/images' 16 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' 17 | filename = 'CUB_200_2011.tgz' 18 | tgz_md5 = '97eceeb196236b17998738112f37df78' 19 | 20 | def __init__(self, root, train=True, transform=None, target_transform=None, loader=default_loader, download=False): 21 | 22 | self.root = os.path.expanduser(root) 23 | self.transform = transform 24 | self.target_transform = target_transform 25 | 26 | self.loader = loader 27 | self.train = train 28 | 29 | 30 | if download: 31 | self._download() 32 | 33 | if not self._check_integrity(): 34 | raise RuntimeError('Dataset not found or corrupted.' 
+ 35 | ' You can use download=True to download it') 36 | 37 | self.uq_idxs = np.array(range(len(self))) 38 | 39 | def _load_metadata(self): 40 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ', 41 | names=['img_id', 'filepath']) 42 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'), 43 | sep=' ', names=['img_id', 'target']) 44 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), 45 | sep=' ', names=['img_id', 'is_training_img']) 46 | 47 | data = images.merge(image_class_labels, on='img_id') 48 | self.data = data.merge(train_test_split, on='img_id') 49 | 50 | if self.train: 51 | self.data = self.data[self.data.is_training_img == 1] 52 | else: 53 | self.data = self.data[self.data.is_training_img == 0] 54 | 55 | def _check_integrity(self): 56 | try: 57 | self._load_metadata() 58 | except Exception: 59 | return False 60 | 61 | for index, row in self.data.iterrows(): 62 | filepath = os.path.join(self.root, self.base_folder, row.filepath) 63 | if not os.path.isfile(filepath): 64 | print(filepath) 65 | return False 66 | return True 67 | 68 | def _download(self): 69 | import tarfile 70 | 71 | if self._check_integrity(): 72 | print('Files already downloaded and verified') 73 | return 74 | 75 | download_url(self.url, self.root, self.filename, self.tgz_md5) 76 | 77 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: 78 | tar.extractall(path=self.root) 79 | 80 | def __len__(self): 81 | return len(self.data) 82 | 83 | def __getitem__(self, idx): 84 | sample = self.data.iloc[idx] 85 | path = os.path.join(self.root, self.base_folder, sample.filepath) 86 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0 87 | img = self.loader(path) 88 | 89 | if self.transform is not None: 90 | img = self.transform(img) 91 | 92 | if self.target_transform is not None: 93 | target = self.target_transform(target) 94 | 95 | return img, target, self.uq_idxs[idx] 96 | 97 | 98 | def subsample_dataset(dataset, idxs): 99 | 100 | mask = np.zeros(len(dataset)).astype('bool') 101 | mask[idxs] = True 102 | 103 | dataset.data = dataset.data[mask] 104 | dataset.uq_idxs = dataset.uq_idxs[mask] 105 | 106 | return dataset 107 | 108 | 109 | def subsample_classes(dataset, include_classes=range(160)): 110 | 111 | include_classes_cub = np.array(include_classes) + 1 # CUB classes are indexed 1 --> 200 instead of 0 --> 199 112 | cls_idxs = [x for x, (_, r) in enumerate(dataset.data.iterrows()) if int(r['target']) in include_classes_cub] 113 | 114 | # TODO: For now have no target transform 115 | target_xform_dict = {} 116 | for i, k in enumerate(include_classes): 117 | target_xform_dict[k] = i 118 | 119 | dataset = subsample_dataset(dataset, cls_idxs) 120 | 121 | dataset.target_transform = lambda x: target_xform_dict[x] 122 | 123 | return dataset 124 | 125 | 126 | def get_train_val_indices(train_dataset, val_split=0.2): 127 | 128 | train_classes = np.unique(train_dataset.data['target']) 129 | 130 | # Get train/test indices 131 | train_idxs = [] 132 | val_idxs = [] 133 | for cls in train_classes: 134 | 135 | cls_idxs = np.where(train_dataset.data['target'] == cls)[0] 136 | 137 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 138 | t_ = [x for x in cls_idxs if x not in v_] 139 | 140 | train_idxs.extend(t_) 141 | val_idxs.extend(v_) 142 | 143 | return train_idxs, val_idxs 144 | 145 | 146 | def get_cub_datasets(train_transform, 
test_transform, train_classes=range(160), prop_train_labels=0.8, 147 | split_train_val=False, seed=0, download=False): 148 | 149 | np.random.seed(seed) 150 | 151 | # Init entire training set 152 | whole_training_set = CustomCub2011(root=cub_root, transform=train_transform, train=True, download=download) 153 | 154 | # Get labelled training set which has subsampled classes, then subsample some indices from that 155 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 156 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 157 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 158 | 159 | # Split into training and validation sets 160 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 161 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 162 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 163 | val_dataset_labelled_split.transform = test_transform 164 | 165 | # Get unlabelled data 166 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 167 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 168 | 169 | # Get test set for all classes 170 | test_dataset = CustomCub2011(root=cub_root, transform=test_transform, train=False) 171 | 172 | # Either split train into train and val or use test set as val 173 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 174 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 175 | 176 | all_datasets = { 177 | 'train_labelled': train_dataset_labelled, 178 | 'train_unlabelled': train_dataset_unlabelled, 179 | 'val': val_dataset_labelled, 180 | 'test': test_dataset, 181 | } 182 | 183 | return all_datasets 184 | 185 | if __name__ == '__main__': 186 | 187 | x = get_cub_datasets(None, None, split_train_val=False, 188 | train_classes=range(100), prop_train_labels=0.5) 189 | 190 | print('Printing lens...') 191 | for k, v in x.items(): 192 | if v is not None: 193 | print(f'{k}: {len(v)}') 194 | 195 | print('Printing labelled and unlabelled overlap...') 196 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 197 | print('Printing total instances in train...') 198 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 199 | 200 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].data["target"].values))}') 201 | print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].data["target"].values))}') 202 | print(f'Len labelled set: {len(x["train_labelled"])}') 203 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset 3 | 4 | def subsample_instances(dataset, prop_indices_to_subsample=0.8): 5 | 6 | np.random.seed(0) 7 | subsample_indices = np.random.choice(range(len(dataset)), replace=False, 8 | size=(int(prop_indices_to_subsample * len(dataset)),)) 9 | 10 | return subsample_indices 11 | 12 | 13 | class MergedDataset(Dataset): 14 | 15 | """ 16 | Takes two datasets (labelled_dataset, unlabelled_dataset) and 
merges them 17 | Allows you to iterate over them in parallel 18 | """ 19 | 20 | def __init__(self, labelled_dataset, unlabelled_dataset): 21 | 22 | self.labelled_dataset = labelled_dataset 23 | self.unlabelled_dataset = unlabelled_dataset 24 | self.target_transform = None 25 | 26 | def __getitem__(self, item): 27 | 28 | if item < len(self.labelled_dataset): 29 | img, label, uq_idx = self.labelled_dataset[item] 30 | labeled_or_not = 1 31 | 32 | else: 33 | 34 | img, label, uq_idx = self.unlabelled_dataset[item - len(self.labelled_dataset)] 35 | labeled_or_not = 0 36 | 37 | 38 | return img, label, uq_idx, np.array([labeled_or_not]) 39 | 40 | def __len__(self): 41 | return len(self.unlabelled_dataset) + len(self.labelled_dataset) 42 | -------------------------------------------------------------------------------- /data/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from copy import deepcopy 4 | 5 | from torchvision.datasets.folder import default_loader 6 | from torch.utils.data import Dataset 7 | 8 | from data.data_utils import subsample_instances 9 | from config import aircraft_root 10 | 11 | def make_dataset(dir, image_ids, targets): 12 | assert(len(image_ids) == len(targets)) 13 | images = [] 14 | dir = os.path.expanduser(dir) 15 | for i in range(len(image_ids)): 16 | item = (os.path.join(dir, 'data', 'images', 17 | '%s.jpg' % image_ids[i]), targets[i]) 18 | images.append(item) 19 | return images 20 | 21 | 22 | def find_classes(classes_file): 23 | 24 | # read classes file, separating out image IDs and class names 25 | image_ids = [] 26 | targets = [] 27 | f = open(classes_file, 'r') 28 | for line in f: 29 | split_line = line.split(' ') 30 | image_ids.append(split_line[0]) 31 | targets.append(' '.join(split_line[1:])) 32 | f.close() 33 | 34 | # index class names 35 | classes = np.unique(targets) 36 | class_to_idx = {classes[i]: i for i in range(len(classes))} 37 | targets = [class_to_idx[c] for c in targets] 38 | 39 | return (image_ids, targets, classes, class_to_idx) 40 | 41 | 42 | class FGVCAircraft(Dataset): 43 | 44 | """`FGVC-Aircraft `_ Dataset. 45 | 46 | Args: 47 | root (string): Root directory path to dataset. 48 | class_type (string, optional): The level of FGVC-Aircraft fine-grain classification 49 | to label data with (i.e., ``variant``, ``family``, or ``manufacturer``). 50 | transform (callable, optional): A function/transform that takes in a PIL image 51 | and returns a transformed version. E.g. ``transforms.RandomCrop`` 52 | target_transform (callable, optional): A function/transform that takes in the 53 | target and transforms it. 54 | loader (callable, optional): A function to load an image given its path. 55 | download (bool, optional): If true, downloads the dataset from the internet and 56 | puts it in the root directory. If dataset is already downloaded, it is not 57 | downloaded again. 58 | """ 59 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' 60 | class_types = ('variant', 'family', 'manufacturer') 61 | splits = ('train', 'val', 'trainval', 'test') 62 | 63 | def __init__(self, root, class_type='variant', split='train', transform=None, 64 | target_transform=None, loader=default_loader, download=False): 65 | if split not in self.splits: 66 | raise ValueError('Split "{}" not found. 
Valid splits are: {}'.format( 67 | split, ', '.join(self.splits), 68 | )) 69 | if class_type not in self.class_types: 70 | raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( 71 | class_type, ', '.join(self.class_types), 72 | )) 73 | self.root = os.path.expanduser(root) 74 | self.class_type = class_type 75 | self.split = split 76 | self.classes_file = os.path.join(self.root, 'data', 77 | 'images_%s_%s.txt' % (self.class_type, self.split)) 78 | 79 | if download: 80 | self.download() 81 | 82 | (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) 83 | samples = make_dataset(self.root, image_ids, targets) 84 | 85 | self.transform = transform 86 | self.target_transform = target_transform 87 | self.loader = loader 88 | 89 | self.samples = samples 90 | self.classes = classes 91 | self.class_to_idx = class_to_idx 92 | self.train = True if split == 'train' else False 93 | 94 | self.uq_idxs = np.array(range(len(self))) 95 | 96 | def __getitem__(self, index): 97 | """ 98 | Args: 99 | index (int): Index 100 | 101 | Returns: 102 | tuple: (sample, target) where target is class_index of the target class. 103 | """ 104 | 105 | path, target = self.samples[index] 106 | sample = self.loader(path) 107 | if self.transform is not None: 108 | sample = self.transform(sample) 109 | if self.target_transform is not None: 110 | target = self.target_transform(target) 111 | 112 | return sample, target, self.uq_idxs[index] 113 | 114 | def __len__(self): 115 | return len(self.samples) 116 | 117 | def __repr__(self): 118 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 119 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 120 | fmt_str += ' Root Location: {}\n'.format(self.root) 121 | tmp = ' Transforms (if any): ' 122 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 123 | tmp = ' Target Transforms (if any): ' 124 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 125 | return fmt_str 126 | 127 | def _check_exists(self): 128 | return os.path.exists(os.path.join(self.root, 'data', 'images')) and \ 129 | os.path.exists(self.classes_file) 130 | 131 | def download(self): 132 | """Download the FGVC-Aircraft data if it doesn't exist already.""" 133 | from six.moves import urllib 134 | import tarfile 135 | 136 | if self._check_exists(): 137 | return 138 | 139 | # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz 140 | print('Downloading %s ... (may take a few minutes)' % self.url) 141 | parent_dir = os.path.abspath(os.path.join(self.root, os.pardir)) 142 | tar_name = self.url.rpartition('/')[-1] 143 | tar_path = os.path.join(parent_dir, tar_name) 144 | data = urllib.request.urlopen(self.url) 145 | 146 | # download .tar.gz file 147 | with open(tar_path, 'wb') as f: 148 | f.write(data.read()) 149 | 150 | # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b 151 | data_folder = tar_path.strip('.tar.gz') 152 | print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder)) 153 | tar = tarfile.open(tar_path) 154 | tar.extractall(parent_dir) 155 | 156 | # if necessary, rename data folder to self.root 157 | if not os.path.samefile(data_folder, self.root): 158 | print('Renaming %s to %s ...' % (data_folder, self.root)) 159 | os.rename(data_folder, self.root) 160 | 161 | # delete .tar.gz file 162 | print('Deleting %s ...' 
% tar_path) 163 | os.remove(tar_path) 164 | 165 | print('Done!') 166 | 167 | 168 | def subsample_dataset(dataset, idxs): 169 | 170 | mask = np.zeros(len(dataset)).astype('bool') 171 | mask[idxs] = True 172 | 173 | dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs] 174 | dataset.uq_idxs = dataset.uq_idxs[mask] 175 | 176 | return dataset 177 | 178 | 179 | def subsample_classes(dataset, include_classes=range(60)): 180 | 181 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes] 182 | 183 | # TODO: Don't transform targets for now 184 | target_xform_dict = {} 185 | for i, k in enumerate(include_classes): 186 | target_xform_dict[k] = i 187 | 188 | dataset = subsample_dataset(dataset, cls_idxs) 189 | 190 | dataset.target_transform = lambda x: target_xform_dict[x] 191 | 192 | return dataset 193 | 194 | 195 | def get_train_val_indices(train_dataset, val_split=0.2): 196 | 197 | all_targets = [t for i, (p, t) in enumerate(train_dataset.samples)] 198 | train_classes = np.unique(all_targets) 199 | 200 | # Get train/test indices 201 | train_idxs = [] 202 | val_idxs = [] 203 | for cls in train_classes: 204 | cls_idxs = np.where(all_targets == cls)[0] 205 | 206 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 207 | t_ = [x for x in cls_idxs if x not in v_] 208 | 209 | train_idxs.extend(t_) 210 | val_idxs.extend(v_) 211 | 212 | return train_idxs, val_idxs 213 | 214 | 215 | def get_aircraft_datasets(train_transform, test_transform, train_classes=range(50), prop_train_labels=0.8, 216 | split_train_val=False, seed=0): 217 | 218 | np.random.seed(seed) 219 | 220 | # Init entire training set 221 | whole_training_set = FGVCAircraft(root=aircraft_root, transform=train_transform, split='trainval') 222 | 223 | # Get labelled training set which has subsampled classes, then subsample some indices from that 224 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 225 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 226 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 227 | 228 | # Split into training and validation sets 229 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 230 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 231 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 232 | val_dataset_labelled_split.transform = test_transform 233 | 234 | # Get unlabelled data 235 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 236 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 237 | 238 | # Get test set for all classes 239 | test_dataset = FGVCAircraft(root=aircraft_root, transform=test_transform, split='test') 240 | 241 | # Either split train into train and val or use test set as val 242 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 243 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 244 | 245 | all_datasets = { 246 | 'train_labelled': train_dataset_labelled, 247 | 'train_unlabelled': train_dataset_unlabelled, 248 | 'val': val_dataset_labelled, 249 | 'test': test_dataset, 250 | } 251 | 252 | return all_datasets 253 | 254 | if __name__ == '__main__': 255 | 256 | x 
= get_aircraft_datasets(None, None, split_train_val=False) 257 | 258 | print('Printing lens...') 259 | for k, v in x.items(): 260 | if v is not None: 261 | print(f'{k}: {len(v)}') 262 | 263 | print('Printing labelled and unlabelled overlap...') 264 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 265 | print('Printing total instances in train...') 266 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 267 | print('Printing number of labelled classes...') 268 | print(len(set([i[1] for i in x['train_labelled'].samples]))) 269 | print('Printing total number of classes...') 270 | print(len(set([i[1] for i in x['train_unlabelled'].samples]))) 271 | -------------------------------------------------------------------------------- /data/get_datasets.py: -------------------------------------------------------------------------------- 1 | from data.data_utils import MergedDataset 2 | 3 | from data.cifar import get_cifar_10_datasets, get_cifar_100_datasets 4 | from data.herbarium_19 import get_herbarium_datasets 5 | from data.stanford_cars import get_scars_datasets 6 | from data.imagenet import get_imagenet_100_datasets, get_imagenet_1k_datasets 7 | from data.cub import get_cub_datasets 8 | from data.fgvc_aircraft import get_aircraft_datasets 9 | 10 | from copy import deepcopy 11 | import pickle 12 | import os 13 | 14 | from config import osr_split_dir 15 | 16 | 17 | get_dataset_funcs = { 18 | 'cifar10': get_cifar_10_datasets, 19 | 'cifar100': get_cifar_100_datasets, 20 | 'imagenet_100': get_imagenet_100_datasets, 21 | 'imagenet_1k': get_imagenet_1k_datasets, 22 | 'herbarium_19': get_herbarium_datasets, 23 | 'cub': get_cub_datasets, 24 | 'aircraft': get_aircraft_datasets, 25 | 'scars': get_scars_datasets 26 | } 27 | 28 | 29 | def get_datasets(dataset_name, train_transform, test_transform, args): 30 | 31 | """ 32 | :return: train_dataset: MergedDataset which concatenates labelled and unlabelled 33 | test_dataset, 34 | unlabelled_train_examples_test, 35 | datasets 36 | """ 37 | 38 | # 39 | if dataset_name not in get_dataset_funcs.keys(): 40 | raise ValueError 41 | 42 | # Get datasets 43 | get_dataset_f = get_dataset_funcs[dataset_name] 44 | datasets = get_dataset_f(train_transform=train_transform, test_transform=test_transform, 45 | train_classes=args.train_classes, 46 | prop_train_labels=args.prop_train_labels, 47 | split_train_val=False) 48 | # Set target transforms: 49 | target_transform_dict = {} 50 | for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)): 51 | target_transform_dict[cls] = i 52 | target_transform = lambda x: target_transform_dict[x] 53 | 54 | for dataset_name, dataset in datasets.items(): 55 | if dataset is not None: 56 | dataset.target_transform = target_transform 57 | 58 | # Train split (labelled and unlabelled classes) for training 59 | train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']), 60 | unlabelled_dataset=deepcopy(datasets['train_unlabelled'])) 61 | 62 | test_dataset = datasets['test'] 63 | unlabelled_train_examples_test = deepcopy(datasets['train_unlabelled']) 64 | unlabelled_train_examples_test.transform = test_transform 65 | 66 | return train_dataset, test_dataset, unlabelled_train_examples_test, datasets 67 | 68 | 69 | def get_class_splits(args): 70 | 71 | # For FGVC datasets, optionally return bespoke splits 72 | if args.dataset_name in ('scars', 'cub', 'aircraft'): 73 | if hasattr(args, 'use_ssb_splits'): 74 | use_ssb_splits 
= args.use_ssb_splits 75 | else: 76 | use_ssb_splits = False 77 | 78 | # ------------- 79 | # GET CLASS SPLITS 80 | # ------------- 81 | if args.dataset_name == 'cifar10': 82 | 83 | args.image_size = 32 84 | args.train_classes = range(5) 85 | args.unlabeled_classes = range(5, 10) 86 | 87 | elif args.dataset_name == 'cifar100': 88 | 89 | args.image_size = 32 90 | args.train_classes = range(80) 91 | args.unlabeled_classes = range(80, 100) 92 | 93 | elif args.dataset_name == 'herbarium_19': 94 | 95 | args.image_size = 224 96 | herb_path_splits = os.path.join(osr_split_dir, 'herbarium_19_class_splits.pkl') 97 | 98 | with open(herb_path_splits, 'rb') as handle: 99 | class_splits = pickle.load(handle) 100 | 101 | args.train_classes = class_splits['Old'] 102 | args.unlabeled_classes = class_splits['New'] 103 | 104 | elif args.dataset_name == 'imagenet_100': 105 | 106 | args.image_size = 224 107 | args.train_classes = range(50) 108 | args.unlabeled_classes = range(50, 100) 109 | 110 | elif args.dataset_name == 'imagenet_1k': 111 | 112 | args.image_size = 224 113 | args.train_classes = range(500) 114 | args.unlabeled_classes = range(500, 1000) 115 | 116 | elif args.dataset_name == 'scars': 117 | 118 | args.image_size = 224 119 | 120 | if use_ssb_splits: 121 | 122 | split_path = os.path.join(osr_split_dir, 'scars_osr_splits.pkl') 123 | with open(split_path, 'rb') as handle: 124 | class_info = pickle.load(handle) 125 | 126 | args.train_classes = class_info['known_classes'] 127 | open_set_classes = class_info['unknown_classes'] 128 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] 129 | 130 | else: 131 | 132 | args.train_classes = range(98) 133 | args.unlabeled_classes = range(98, 196) 134 | 135 | elif args.dataset_name == 'aircraft': 136 | 137 | args.image_size = 224 138 | if use_ssb_splits: 139 | 140 | split_path = os.path.join(osr_split_dir, 'aircraft_osr_splits.pkl') 141 | with open(split_path, 'rb') as handle: 142 | class_info = pickle.load(handle) 143 | 144 | args.train_classes = class_info['known_classes'] 145 | open_set_classes = class_info['unknown_classes'] 146 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] 147 | 148 | else: 149 | 150 | args.train_classes = range(50) 151 | args.unlabeled_classes = range(50, 100) 152 | 153 | elif args.dataset_name == 'cub': 154 | 155 | args.image_size = 224 156 | 157 | if use_ssb_splits: 158 | 159 | split_path = os.path.join(osr_split_dir, 'cub_osr_splits.pkl') 160 | with open(split_path, 'rb') as handle: 161 | class_info = pickle.load(handle) 162 | 163 | args.train_classes = class_info['known_classes'] 164 | open_set_classes = class_info['unknown_classes'] 165 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy'] 166 | 167 | else: 168 | 169 | args.train_classes = range(100) 170 | args.unlabeled_classes = range(100, 200) 171 | 172 | else: 173 | 174 | raise NotImplementedError 175 | 176 | return args 177 | -------------------------------------------------------------------------------- /data/herbarium_19.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torchvision 4 | import numpy as np 5 | from copy import deepcopy 6 | 7 | from data.data_utils import subsample_instances 8 | from config import herbarium_dataroot 9 | 10 | class HerbariumDataset19(torchvision.datasets.ImageFolder): 11 | 12 | def __init__(self, *args, **kwargs): 
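# Thin wrapper over torchvision's ImageFolder: class sub-folders under root define the targets, and uq_idxs gives each sample a stable index used when building the labelled/unlabelled splits.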
13 | 14 | # Process metadata json for training images into a DataFrame 15 | super().__init__(*args, **kwargs) 16 | 17 | self.uq_idxs = np.array(range(len(self))) 18 | 19 | def __getitem__(self, idx): 20 | 21 | img, label = super().__getitem__(idx) 22 | uq_idx = self.uq_idxs[idx] 23 | 24 | return img, label, uq_idx 25 | 26 | 27 | def subsample_dataset(dataset, idxs): 28 | 29 | mask = np.zeros(len(dataset)).astype('bool') 30 | mask[idxs] = True 31 | 32 | dataset.samples = np.array(dataset.samples)[mask].tolist() 33 | dataset.targets = np.array(dataset.targets)[mask].tolist() 34 | 35 | dataset.uq_idxs = dataset.uq_idxs[mask] 36 | 37 | dataset.samples = [[x[0], int(x[1])] for x in dataset.samples] 38 | dataset.targets = [int(x) for x in dataset.targets] 39 | 40 | return dataset 41 | 42 | 43 | def subsample_classes(dataset, include_classes=range(250)): 44 | 45 | cls_idxs = [x for x, l in enumerate(dataset.targets) if l in include_classes] 46 | 47 | target_xform_dict = {} 48 | for i, k in enumerate(include_classes): 49 | target_xform_dict[k] = i 50 | 51 | dataset = subsample_dataset(dataset, cls_idxs) 52 | 53 | dataset.target_transform = lambda x: target_xform_dict[x] 54 | 55 | return dataset 56 | 57 | 58 | def get_train_val_indices(train_dataset, val_instances_per_class=5): 59 | 60 | train_classes = list(set(train_dataset.targets)) 61 | 62 | # Get train/test indices 63 | train_idxs = [] 64 | val_idxs = [] 65 | for cls in train_classes: 66 | 67 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0] 68 | 69 | # Have a balanced test set 70 | v_ = np.random.choice(cls_idxs, replace=False, size=(val_instances_per_class,)) 71 | t_ = [x for x in cls_idxs if x not in v_] 72 | 73 | train_idxs.extend(t_) 74 | val_idxs.extend(v_) 75 | 76 | return train_idxs, val_idxs 77 | 78 | 79 | def get_herbarium_datasets(train_transform, test_transform, train_classes=range(500), prop_train_labels=0.8, 80 | seed=0, split_train_val=False): 81 | 82 | np.random.seed(seed) 83 | 84 | # Init entire training set 85 | train_dataset = HerbariumDataset19(transform=train_transform, 86 | root=os.path.join(herbarium_dataroot, 'small-train')) 87 | 88 | # Get labelled training set which has subsampled classes, then subsample some indices from that 89 | # TODO: Subsampling unlabelled set in uniform random fashion from training data, will contain many instances of dominant class 90 | train_dataset_labelled = subsample_classes(deepcopy(train_dataset), include_classes=train_classes) 91 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 92 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 93 | 94 | # Split into training and validation sets 95 | if split_train_val: 96 | 97 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled, 98 | val_instances_per_class=5) 99 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 100 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 101 | val_dataset_labelled_split.transform = test_transform 102 | 103 | else: 104 | 105 | train_dataset_labelled_split, val_dataset_labelled_split = None, None 106 | 107 | # Get unlabelled data 108 | unlabelled_indices = set(train_dataset.uq_idxs) - set(train_dataset_labelled.uq_idxs) 109 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset), np.array(list(unlabelled_indices))) 110 | 111 | # Get test dataset 112 | test_dataset = 
HerbariumDataset19(transform=test_transform, 113 | root=os.path.join(herbarium_dataroot, 'small-validation')) 114 | 115 | # Transform dict 116 | unlabelled_classes = list(set(train_dataset.targets) - set(train_classes)) 117 | target_xform_dict = {} 118 | for i, k in enumerate(list(train_classes) + unlabelled_classes): 119 | target_xform_dict[k] = i 120 | 121 | test_dataset.target_transform = lambda x: target_xform_dict[x] 122 | train_dataset_unlabelled.target_transform = lambda x: target_xform_dict[x] 123 | 124 | # Either split train into train and val or use test set as val 125 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 126 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 127 | 128 | all_datasets = { 129 | 'train_labelled': train_dataset_labelled, 130 | 'train_unlabelled': train_dataset_unlabelled, 131 | 'val': val_dataset_labelled, 132 | 'test': test_dataset, 133 | } 134 | 135 | return all_datasets 136 | 137 | if __name__ == '__main__': 138 | 139 | np.random.seed(0) 140 | train_classes = np.random.choice(range(683), size=(int(683 / 2)), replace=False) 141 | 142 | x = get_herbarium_datasets(None, None, train_classes=train_classes, 143 | prop_train_labels=0.5) 144 | 145 | assert set(x['train_unlabelled'].targets) == set(range(683)) 146 | 147 | print('Printing lens...') 148 | for k, v in x.items(): 149 | if v is not None: 150 | print(f'{k}: {len(v)}') 151 | 152 | print('Printing labelled and unlabelled overlap...') 153 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 154 | print('Printing total instances in train...') 155 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 156 | print('Printing number of labelled classes...') 157 | print(len(set(x['train_labelled'].targets))) 158 | print('Printing total number of classes...') 159 | print(len(set(x['train_unlabelled'].targets))) 160 | 161 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') 162 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}') 163 | print(f'Len labelled set: {len(x["train_labelled"])}') 164 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') --------------------------------------------------------------------------------
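Every dataset file in `data/` repeats the same recipe: subsample the labelled classes, keep a proportion of their instances as the labelled split, and recover the unlabelled split as the complement of `uq_idxs`. A toy sketch of that bookkeeping (illustration only, not repo code):

```python
import numpy as np

uq_idxs_all = np.arange(10)                # unique indices of the whole training set
uq_idxs_labelled = np.array([0, 2, 3, 7])  # survivors of class/instance subsampling

# The unlabelled split is everything the labelled split did not claim.
unlabelled = np.array(sorted(set(uq_idxs_all) - set(uq_idxs_labelled)))
print(unlabelled)  # [1 4 5 6 8 9]
```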
/data/imagenet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import numpy as np 3 | 4 | import os 5 | 6 | from copy import deepcopy 7 | from data.data_utils import subsample_instances 8 | from config import imagenet_root 9 | 10 | 11 | class ImageNetBase(torchvision.datasets.ImageFolder): 12 | 13 | def __init__(self, root, transform): 14 | 15 | super(ImageNetBase, self).__init__(root, transform) 16 | 17 | self.uq_idxs = np.array(range(len(self))) 18 | 19 | def __getitem__(self, item): 20 | 21 | img, label = super().__getitem__(item) 22 | uq_idx = self.uq_idxs[item] 23 | 24 | return img, label, uq_idx 25 | 26 | 27 | def subsample_dataset(dataset, idxs): 28 | 29 | imgs_ = [] 30 | for i in idxs: 31 | imgs_.append(dataset.imgs[i]) 32 | dataset.imgs = imgs_ 33 | 34 | samples_ = [] 35 | for i in idxs: 36 | samples_.append(dataset.samples[i]) 37 | dataset.samples = samples_ 38 | 39 | # dataset.imgs = [x for i, x in enumerate(dataset.imgs) if i in idxs] 40 | # dataset.samples = [x for i, x in enumerate(dataset.samples) if i in idxs] 41 | 42 | dataset.targets = np.array(dataset.targets)[idxs].tolist() 43 | dataset.uq_idxs = dataset.uq_idxs[idxs] 44 | 45 | return dataset 46 | 47 | 48 | def subsample_classes(dataset, include_classes=list(range(1000))): 49 | 50 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes] 51 | 52 | target_xform_dict = {} 53 | for i, k in enumerate(include_classes): 54 | target_xform_dict[k] = i 55 | 56 | dataset = subsample_dataset(dataset, cls_idxs) 57 | dataset.target_transform = lambda x: target_xform_dict[x] 58 | 59 | return dataset 60 | 61 | 62 | def get_train_val_indices(train_dataset, val_split=0.2): 63 | 64 | train_classes = list(set(train_dataset.targets)) 65 | 66 | # Get train/test indices 67 | train_idxs = [] 68 | val_idxs = [] 69 | for cls in train_classes: 70 | 71 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0] 72 | 73 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 74 | t_ = [x for x in cls_idxs if x not in v_] 75 | 76 | train_idxs.extend(t_) 77 | val_idxs.extend(v_) 78 | 79 | return train_idxs, val_idxs 80 | 81 | 82 | def get_imagenet_100_datasets(train_transform, test_transform, train_classes=range(80), 83 | prop_train_labels=0.8, split_train_val=False, seed=0): 84 | 85 | np.random.seed(seed) 86 | 87 | # Subsample imagenet dataset initially to include 100 classes 88 | subsampled_100_classes = np.random.choice(range(1000), size=(100,), replace=False) 89 | subsampled_100_classes = np.sort(subsampled_100_classes) 90 | print(f'Constructing ImageNet-100 dataset from the following classes: {subsampled_100_classes.tolist()}') 91 | cls_map = {i: j for i, j in zip(subsampled_100_classes, range(100))} 92 | 93 | # Init entire training set 94 | imagenet_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform) 95 | whole_training_set = subsample_classes(imagenet_training_set, include_classes=subsampled_100_classes) 96 | 97 | # Reset dataset 98 | whole_training_set.samples = [(s[0], cls_map[s[1]]) for s in whole_training_set.samples] 99 | whole_training_set.targets = [s[1] for s in whole_training_set.samples] 100 | whole_training_set.uq_idxs = np.array(range(len(whole_training_set))) 101 | whole_training_set.target_transform = None 102 | 103 | # Get labelled training set which has subsampled classes, then subsample some indices from that 104 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 105 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 106 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 107 | 108 | # Split into training and validation sets 109 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 110 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 111 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 112 | val_dataset_labelled_split.transform = test_transform 113 | 114 | # Get unlabelled data 115 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 116 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 117 | 118 | # Get test set for all classes 119 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform) 120 | test_dataset = subsample_classes(test_dataset, include_classes=subsampled_100_classes) 121 | 122 | # Reset test set 123 | test_dataset.samples = 
[(s[0], cls_map[s[1]]) for s in test_dataset.samples] 124 | test_dataset.targets = [s[1] for s in test_dataset.samples] 125 | test_dataset.uq_idxs = np.array(range(len(test_dataset))) 126 | test_dataset.target_transform = None 127 | 128 | # Either split train into train and val or use test set as val 129 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 130 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 131 | 132 | all_datasets = { 133 | 'train_labelled': train_dataset_labelled, 134 | 'train_unlabelled': train_dataset_unlabelled, 135 | 'val': val_dataset_labelled, 136 | 'test': test_dataset, 137 | } 138 | 139 | return all_datasets 140 | 141 | 142 | def get_imagenet_1k_datasets(train_transform, test_transform, train_classes=range(500), 143 | prop_train_labels=0.5, split_train_val=False, seed=0): 144 | 145 | np.random.seed(seed) 146 | 147 | # Init entire training set 148 | whole_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform) 149 | 150 | # Get labelled training set which has subsampled classes, then subsample some indices from that 151 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 152 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 153 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 154 | 155 | # Split into training and validation sets 156 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 157 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 158 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 159 | val_dataset_labelled_split.transform = test_transform 160 | 161 | # Get unlabelled data 162 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 163 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 164 | 165 | # Get test set for all classes 166 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform) 167 | 168 | # Either split train into train and val or use test set as val 169 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 170 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 171 | 172 | all_datasets = { 173 | 'train_labelled': train_dataset_labelled, 174 | 'train_unlabelled': train_dataset_unlabelled, 175 | 'val': val_dataset_labelled, 176 | 'test': test_dataset, 177 | } 178 | 179 | return all_datasets 180 | 181 | 182 | 183 | if __name__ == '__main__': 184 | 185 | x = get_imagenet_100_datasets(None, None, split_train_val=False, 186 | train_classes=range(50), prop_train_labels=0.5) 187 | 188 | print('Printing lens...') 189 | for k, v in x.items(): 190 | if v is not None: 191 | print(f'{k}: {len(v)}') 192 | 193 | print('Printing labelled and unlabelled overlap...') 194 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 195 | print('Printing total instances in train...') 196 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 197 | 198 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') 199 | print(f'Num Unlabelled Classes: 
{len(set(x["train_unlabelled"].targets))}') 200 | print(f'Len labelled set: {len(x["train_labelled"])}') 201 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /data/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from copy import deepcopy 5 | from scipy import io as mat_io 6 | 7 | from torchvision.datasets.folder import default_loader 8 | from torch.utils.data import Dataset 9 | 10 | from data.data_utils import subsample_instances 11 | from config import car_root 12 | 13 | class CarsDataset(Dataset): 14 | """ 15 | Cars Dataset 16 | """ 17 | def __init__(self, train=True, limit=0, data_dir=car_root, transform=None): 18 | 19 | metas = os.path.join(data_dir, 'devkit/cars_train_annos.mat') if train else os.path.join(data_dir, 'devkit/cars_test_annos_withlabels.mat') 20 | data_dir = os.path.join(data_dir, 'cars_train/') if train else os.path.join(data_dir, 'cars_test/') 21 | 22 | self.loader = default_loader 23 | self.data_dir = data_dir 24 | self.data = [] 25 | self.target = [] 26 | self.train = train 27 | 28 | self.transform = transform 29 | 30 | if not isinstance(metas, str): 31 | raise Exception("Train metas must be string location !") 32 | labels_meta = mat_io.loadmat(metas) 33 | 34 | for idx, img_ in enumerate(labels_meta['annotations'][0]): 35 | if limit: 36 | if idx > limit: 37 | break 38 | 39 | # self.data.append(img_resized) 40 | self.data.append(data_dir + img_[5][0]) 41 | # if self.mode == 'train': 42 | self.target.append(img_[4][0][0]) 43 | 44 | self.uq_idxs = np.array(range(len(self))) 45 | self.target_transform = None 46 | 47 | def __getitem__(self, idx): 48 | 49 | image = self.loader(self.data[idx]) 50 | target = self.target[idx] - 1 51 | 52 | if self.transform is not None: 53 | image = self.transform(image) 54 | 55 | if self.target_transform is not None: 56 | target = self.target_transform(target) 57 | 58 | idx = self.uq_idxs[idx] 59 | 60 | return image, target, idx 61 | 62 | def __len__(self): 63 | return len(self.data) 64 | 65 | 66 | def subsample_dataset(dataset, idxs): 67 | 68 | dataset.data = np.array(dataset.data)[idxs].tolist() 69 | dataset.target = np.array(dataset.target)[idxs].tolist() 70 | dataset.uq_idxs = dataset.uq_idxs[idxs] 71 | 72 | return dataset 73 | 74 | 75 | def subsample_classes(dataset, include_classes=range(160)): 76 | 77 | include_classes_cars = np.array(include_classes) + 1 # SCars classes are indexed 1 --> 196 instead of 0 --> 195 78 | cls_idxs = [x for x, t in enumerate(dataset.target) if t in include_classes_cars] 79 | 80 | target_xform_dict = {} 81 | for i, k in enumerate(include_classes): 82 | target_xform_dict[k] = i 83 | 84 | dataset = subsample_dataset(dataset, cls_idxs) 85 | 86 | # dataset.target_transform = lambda x: target_xform_dict[x] 87 | 88 | return dataset 89 | 90 | def get_train_val_indices(train_dataset, val_split=0.2): 91 | 92 | train_classes = np.unique(train_dataset.target) 93 | 94 | # Get train/test indices 95 | train_idxs = [] 96 | val_idxs = [] 97 | for cls in train_classes: 98 | 99 | cls_idxs = np.where(train_dataset.target == cls)[0] 100 | 101 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) 102 | t_ = [x for x in cls_idxs if x not in v_] 103 | 104 | train_idxs.extend(t_) 105 | val_idxs.extend(v_) 106 | 107 | return train_idxs, val_idxs 108 | 109 | 110 | def 
get_scars_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8, 111 | split_train_val=False, seed=0): 112 | 113 | np.random.seed(seed) 114 | 115 | # Init entire training set 116 | whole_training_set = CarsDataset(data_dir=car_root, transform=train_transform, train=True) 117 | 118 | # Get labelled training set which has subsampled classes, then subsample some indices from that 119 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) 120 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) 121 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) 122 | 123 | # Split into training and validation sets 124 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) 125 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) 126 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) 127 | val_dataset_labelled_split.transform = test_transform 128 | 129 | # Get unlabelled data 130 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) 131 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) 132 | 133 | # Get test set for all classes 134 | test_dataset = CarsDataset(data_dir=car_root, transform=test_transform, train=False) 135 | 136 | # Either split train into train and val or use test set as val 137 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled 138 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None 139 | 140 | all_datasets = { 141 | 'train_labelled': train_dataset_labelled, 142 | 'train_unlabelled': train_dataset_unlabelled, 143 | 'val': val_dataset_labelled, 144 | 'test': test_dataset, 145 | } 146 | 147 | return all_datasets 148 | 149 | if __name__ == '__main__': 150 | 151 | x = get_scars_datasets(None, None, train_classes=range(98), prop_train_labels=0.5, split_train_val=False) 152 | 153 | print('Printing lens...') 154 | for k, v in x.items(): 155 | if v is not None: 156 | print(f'{k}: {len(v)}') 157 | 158 | print('Printing labelled and unlabelled overlap...') 159 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) 160 | print('Printing total instances in train...') 161 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) 162 | 163 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].target))}') 164 | print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].target))}') 165 | print(f'Len labelled set: {len(x["train_labelled"])}') 166 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}') -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mashijie1028/ProtoGCD/8835a4d24662c65be125a42d815b52f62ae1482e/models/__init__.py -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import math 6 | 7 | 8 | class SupConLoss(torch.nn.Module): 9 | """Supervised Contrastive 
Learning: https://arxiv.org/pdf/2004.11362.pdf. 10 | It also supports the unsupervised contrastive loss in SimCLR 11 | From: https://github.com/HobbitLong/SupContrast""" 12 | def __init__(self, temperature=0.07, contrast_mode='all', 13 | base_temperature=0.07): 14 | super(SupConLoss, self).__init__() 15 | self.temperature = temperature 16 | self.contrast_mode = contrast_mode 17 | self.base_temperature = base_temperature 18 | 19 | def forward(self, features, labels=None, mask=None): 20 | """Compute loss for model. If both `labels` and `mask` are None, 21 | it degenerates to SimCLR unsupervised loss: 22 | https://arxiv.org/pdf/2002.05709.pdf 23 | Args: 24 | features: hidden vector of shape [bsz, n_views, ...]. 25 | labels: ground truth of shape [bsz]. 26 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 27 | has the same class as sample i. Can be asymmetric. 28 | Returns: 29 | A loss scalar. 30 | """ 31 | 32 | device = (torch.device('cuda') 33 | if features.is_cuda 34 | else torch.device('cpu')) 35 | 36 | if len(features.shape) < 3: 37 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 38 | 'at least 3 dimensions are required') 39 | if len(features.shape) > 3: 40 | features = features.view(features.shape[0], features.shape[1], -1) 41 | 42 | batch_size = features.shape[0] 43 | if labels is not None and mask is not None: 44 | raise ValueError('Cannot define both `labels` and `mask`') 45 | elif labels is None and mask is None: 46 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 47 | elif labels is not None: 48 | labels = labels.contiguous().view(-1, 1) 49 | if labels.shape[0] != batch_size: 50 | raise ValueError('Num of labels does not match num of features') 51 | mask = torch.eq(labels, labels.T).float().to(device) 52 | else: 53 | mask = mask.float().to(device) 54 | 55 | contrast_count = features.shape[1] 56 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 57 | if self.contrast_mode == 'one': 58 | anchor_feature = features[:, 0] 59 | anchor_count = 1 60 | elif self.contrast_mode == 'all': 61 | anchor_feature = contrast_feature 62 | anchor_count = contrast_count 63 | else: 64 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 65 | 66 | # compute logits 67 | anchor_dot_contrast = torch.div( 68 | torch.matmul(anchor_feature, contrast_feature.T), 69 | self.temperature) 70 | 71 | # for numerical stability 72 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 73 | logits = anchor_dot_contrast - logits_max.detach() 74 | 75 | # tile mask 76 | mask = mask.repeat(anchor_count, contrast_count) 77 | # mask-out self-contrast cases 78 | logits_mask = torch.scatter( 79 | torch.ones_like(mask), 80 | 1, 81 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 82 | 0 83 | ) 84 | mask = mask * logits_mask 85 | 86 | # compute log_prob 87 | exp_logits = torch.exp(logits) * logits_mask 88 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 89 | 90 | # compute mean of log-likelihood over positive 91 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 92 | 93 | # loss 94 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 95 | loss = loss.view(anchor_count, batch_size).mean() 96 | 97 | return loss 98 | 99 | 100 |
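A small usage sketch for SupConLoss (illustration only): features must be L2-normalised and shaped `[bsz, n_views, dim]`; with labels it is the supervised contrastive loss, without them it degenerates to the SimCLR objective:

```python
import torch
import torch.nn.functional as F

criterion = SupConLoss(temperature=0.07)

feats = F.normalize(torch.randn(8, 2, 256), dim=-1)  # 8 images, 2 views each
labels = torch.randint(0, 3, (8,))

sup_loss = criterion(feats, labels=labels)  # same-class samples are positives
unsup_loss = criterion(feats)               # only the other view is a positive
```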
101 | def info_nce_logits(features, n_views=2, temperature=1.0, device='cuda'): 102 | 103 | b_ = 0.5 * int(features.size(0)) 104 | 105 | labels = torch.cat([torch.arange(b_) for i in range(n_views)], dim=0) 106 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() 107 | labels = labels.to(device) 108 | 109 | features = F.normalize(features, dim=1) 110 | 111 | similarity_matrix = torch.matmul(features, features.T) 112 | 113 | # discard the main diagonal from both: labels and similarities matrix 114 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device) 115 | labels = labels[~mask].view(labels.shape[0], -1) 116 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1) 117 | 118 | # select and combine multiple positives 119 | positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1) 120 | 121 | # select only the negatives 122 | negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1) 123 | 124 | logits = torch.cat([positives, negatives], dim=1) 125 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device) 126 | 127 | logits = logits / temperature 128 | return logits, labels 129 | 130 | 131 | def entropy_regularization_loss(logits, temperature): 132 | avg_probs = (logits / temperature).softmax(dim=1).mean(dim=0) 133 | entropy_reg_loss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs))) 134 | return entropy_reg_loss 135 | 136 | 137 | def prototype_separation_loss(prototypes, temperature=0.1, base_temperature=0.1, device='cuda'): 138 | num_classes = prototypes.size(0) 139 | labels = torch.arange(0, num_classes).to(device) 140 | labels = labels.contiguous().view(-1, 1) 141 | 142 | mask = (1 - torch.eq(labels, labels.T).float()).to(device) 143 | 144 | logits = torch.div(torch.matmul(prototypes, prototypes.T), temperature) 145 | 146 | mean_prob_neg = torch.log((mask * torch.exp(logits)).sum(1) / mask.sum(1)) 147 | mean_prob_neg = mean_prob_neg[~torch.isnan(mean_prob_neg)] 148 | 149 | # loss 150 | loss = temperature / base_temperature * mean_prob_neg.mean() 151 | 152 | return loss 153 | 154 | 155 | 156 | class DistillLoss_ratio(nn.Module): 157 | def __init__(self, num_classes=100, wait_ratio_epochs=0, ramp_ratio_teacher_epochs=100, 158 | nepochs=200, ncrops=2, init_ratio=0.0, final_ratio=1.0, 159 | temp_logits=0.1, temp_teacher_logits=0.05, device='cuda'): 160 | super().__init__() 161 | self.device = device 162 | self.num_classes = num_classes 163 | self.temp_logits = temp_logits 164 | self.temp_teacher_logits = temp_teacher_logits 165 | self.ncrops = ncrops 166 | self.ratio_schedule = np.concatenate(( 167 | np.zeros(wait_ratio_epochs), 168 | np.linspace(init_ratio, 169 | final_ratio, ramp_ratio_teacher_epochs), 170 | np.ones(nepochs - wait_ratio_epochs - ramp_ratio_teacher_epochs) * final_ratio 171 | )) 172 | 173 | def forward(self, student_output, teacher_output, epoch): 174 | """ 175 | Cross-entropy between softmax outputs of the teacher and student networks. 
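A fraction ratio_schedule[epoch] of the most confident teacher predictions per view (ranked by the top-1 / top-2 probability ratio) is sharpened into hard one-hot pseudo-labels; the remaining samples keep their soft teacher targets.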
176 | """ 177 | student_out = student_output / self.temp_logits 178 | student_out = student_out.chunk(self.ncrops) 179 | 180 | # confidence filtering 181 | ratio_epoch = self.ratio_schedule[epoch] 182 | teacher_out = F.softmax(teacher_output / self.temp_teacher_logits, dim=-1) 183 | teacher_out = teacher_out.detach().chunk(self.ncrops) 184 | 185 | teacher_label = [] 186 | for i in range(self.ncrops): 187 | top2 = torch.topk(teacher_out[i], k=2, dim=-1, largest=True)[0] 188 | top2_div = top2[:, 0] / (top2[:, 1] + 1e-6) 189 | filter_number = int(len(teacher_out[i]) * ratio_epoch) 190 | topk_filter = torch.topk(top2_div, k=filter_number, largest=True)[1] 191 | pseudo_label = F.one_hot(teacher_out[i].argmax(dim=-1), num_classes=self.num_classes) 192 | pseudo_label = pseudo_label.float() 193 | teacher_out[i][topk_filter] = pseudo_label[topk_filter] 194 | teacher_label.append(teacher_out[i]) 195 | 196 | total_loss = 0 197 | n_loss_terms = 0 198 | for iq, q in enumerate(teacher_label): 199 | #for v in range(len(student_out)): 200 | for iv, v in enumerate(student_out): 201 | if iv == iq: 202 | # we skip cases where student and teacher operate on the same view 203 | continue 204 | #loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) 205 | loss = torch.sum(-q * F.log_softmax(v, dim=-1), dim=-1) 206 | total_loss += loss.mean() 207 | n_loss_terms += 1 208 | total_loss /= n_loss_terms 209 | return total_loss 210 | 211 | 212 | 213 | class DistillLoss_ratio_ramp(nn.Module): 214 | def __init__(self, num_classes=100, wait_ratio_epochs=0, ramp_ratio_teacher_epochs=100, 215 | nepochs=200, ncrops=2, init_ratio=0.2, final_ratio=1.0, 216 | temp_logits=0.1, temp_teacher_logits_init=0.07, temp_teacher_logits_final=0.04, ramp_temp_teacher_epochs=30, 217 | device='cuda'): 218 | super().__init__() 219 | self.device = device 220 | self.num_classes = num_classes 221 | self.temp_logits = temp_logits 222 | # self.temp_teacher_logits_init = temp_teacher_logits_init 223 | # self.temp_teacher_logits_final = temp_teacher_logits_final 224 | self.ncrops = ncrops 225 | self.teacher_temp_schedule = np.concatenate(( 226 | np.linspace(temp_teacher_logits_init, 227 | temp_teacher_logits_final, ramp_temp_teacher_epochs), 228 | np.ones(nepochs - ramp_temp_teacher_epochs) * temp_teacher_logits_final 229 | )) 230 | self.ratio_schedule = np.concatenate(( 231 | np.zeros(wait_ratio_epochs), 232 | np.linspace(init_ratio, 233 | final_ratio, ramp_ratio_teacher_epochs), 234 | np.ones(nepochs - wait_ratio_epochs - ramp_ratio_teacher_epochs) * final_ratio 235 | )) 236 | 237 | def forward(self, student_output, teacher_output, epoch): 238 | """ 239 | Cross-entropy between softmax outputs of the teacher and student networks. 
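Identical to DistillLoss_ratio, except that the teacher temperature is also scheduled: it ramps from temp_teacher_logits_init to temp_teacher_logits_final over ramp_temp_teacher_epochs epochs and then stays constant.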
240 | """ 241 | student_out = student_output / self.temp_logits 242 | student_out = student_out.chunk(self.ncrops) 243 | 244 | # confidence filtering 245 | temp_teacher_epoch = self.teacher_temp_schedule[epoch] 246 | ratio_epoch = self.ratio_schedule[epoch] 247 | teacher_out = F.softmax(teacher_output / temp_teacher_epoch, dim=-1) 248 | teacher_out = teacher_out.detach().chunk(self.ncrops) 249 | 250 | teacher_label = [] 251 | for i in range(self.ncrops): 252 | top2 = torch.topk(teacher_out[i], k=2, dim=-1, largest=True)[0] 253 | top2_div = top2[:, 0] / (top2[:, 1] + 1e-6) 254 | filter_number = int(len(teacher_out[i]) * ratio_epoch) 255 | topk_filter = torch.topk(top2_div, k=filter_number, largest=True)[1] 256 | pseudo_label = F.one_hot(teacher_out[i].argmax(dim=-1), num_classes=self.num_classes) 257 | pseudo_label = pseudo_label.float() 258 | teacher_out[i][topk_filter] = pseudo_label[topk_filter] 259 | teacher_label.append(teacher_out[i]) 260 | 261 | total_loss = 0 262 | n_loss_terms = 0 263 | for iq, q in enumerate(teacher_label): 264 | #for v in range(len(student_out)): 265 | for iv, v in enumerate(student_out): 266 | if iv == iq: 267 | # we skip cases where student and teacher operate on the same view 268 | continue 269 | #loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) 270 | loss = torch.sum(-q * F.log_softmax(v, dim=-1), dim=-1) 271 | total_loss += loss.mean() 272 | n_loss_terms += 1 273 | total_loss /= n_loss_terms 274 | return total_loss 275 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class DINOHead(nn.Module): 8 | def __init__(self, in_dim, out_dim, use_bn=False, init_prototypes=None, 9 | nlayers=3, hidden_dim=2048, bottleneck_dim=256, num_labeled_classes=50): 10 | super().__init__() 11 | nlayers = max(nlayers, 1) 12 | if nlayers == 1: 13 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 14 | elif nlayers != 0: 15 | layers = [nn.Linear(in_dim, hidden_dim)] 16 | if use_bn: 17 | layers.append(nn.BatchNorm1d(hidden_dim)) 18 | layers.append(nn.GELU()) 19 | for _ in range(nlayers - 2): 20 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 21 | if use_bn: 22 | layers.append(nn.BatchNorm1d(hidden_dim)) 23 | layers.append(nn.GELU()) 24 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 25 | self.mlp = nn.Sequential(*layers) 26 | self.apply(self._init_weights) 27 | 28 | # prototypes 29 | self.prototype_layer = nn.utils.weight_norm(nn.Linear(in_dim, out_dim, bias=False)) 30 | self.prototype_layer.weight_g.data.fill_(1) 31 | self.prototype_layer.weight_g.requires_grad = False 32 | print('prototype size: ', self.prototype_layer.weight_v.size()) 33 | 34 | if init_prototypes is not None: 35 | print('initialize templates with labeled means and k-means centroids...') 36 | print(init_prototypes.size()) 37 | print(init_prototypes) 38 | #self.prototype_layer.weight_v.data.copy_(init_prototypes) 39 | self.prototype_layer.weight_v.data[:num_labeled_classes].copy_(init_prototypes[:num_labeled_classes]) 40 | print(self.prototype_layer.weight_v) 41 | else: 42 | print('randomly initialize prototypes...') 43 | print(self.prototype_layer.weight_v) 44 | 45 | 46 | def _init_weights(self, m): 47 | if isinstance(m, nn.Linear): 48 | torch.nn.init.trunc_normal_(m.weight, std=.02) 49 | if isinstance(m, nn.Linear) and m.bias is not None: 50 | 
nn.init.constant_(m.bias, 0) 51 | 52 | def forward(self, x): 53 | x_proj = self.mlp(x) 54 | x = F.normalize(x, dim=-1, p=2) 55 | # x = x.detach() 56 | logits = self.prototype_layer(x) 57 | 58 | prototypes = self.prototype_layer.weight_v.clone() 59 | normed_prototypes = F.normalize(prototypes, dim=-1, p=2) 60 | 61 | return x_proj, logits, normed_prototypes 62 | 63 | 64 | 65 | class DINOHead_k(nn.Module): 66 | ''' 67 | DINOHead for estimating k. 68 | difference with DINOHead: `forward()`, return one more `x` 69 | date: 20230515 70 | ''' 71 | def __init__(self, in_dim, out_dim, use_bn=False, init_prototypes=None, 72 | nlayers=3, hidden_dim=2048, bottleneck_dim=256, num_labeled_classes=50): 73 | super().__init__() 74 | nlayers = max(nlayers, 1) 75 | if nlayers == 1: 76 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 77 | elif nlayers != 0: 78 | layers = [nn.Linear(in_dim, hidden_dim)] 79 | if use_bn: 80 | layers.append(nn.BatchNorm1d(hidden_dim)) 81 | layers.append(nn.GELU()) 82 | for _ in range(nlayers - 2): 83 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 84 | if use_bn: 85 | layers.append(nn.BatchNorm1d(hidden_dim)) 86 | layers.append(nn.GELU()) 87 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 88 | self.mlp = nn.Sequential(*layers) 89 | self.apply(self._init_weights) 90 | 91 | # prototypes 92 | self.prototype_layer = nn.utils.weight_norm(nn.Linear(in_dim, out_dim, bias=False)) 93 | self.prototype_layer.weight_g.data.fill_(1) 94 | self.prototype_layer.weight_g.requires_grad = False 95 | print('prototype size: ', self.prototype_layer.weight_v.size()) 96 | 97 | if init_prototypes is not None: 98 | print('initialize templates with labeled means and k-means centroids...') 99 | print(init_prototypes.size()) 100 | print(init_prototypes) 101 | #self.prototype_layer.weight_v.data.copy_(init_prototypes) 102 | self.prototype_layer.weight_v.data[:num_labeled_classes].copy_(init_prototypes[:num_labeled_classes]) 103 | print(self.prototype_layer.weight_v) 104 | else: 105 | print('randomly initialize prototypes...') 106 | print(self.prototype_layer.weight_v) 107 | 108 | 109 | def _init_weights(self, m): 110 | if isinstance(m, nn.Linear): 111 | torch.nn.init.trunc_normal_(m.weight, std=.02) 112 | if isinstance(m, nn.Linear) and m.bias is not None: 113 | nn.init.constant_(m.bias, 0) 114 | 115 | def forward(self, x): 116 | x_proj = self.mlp(x) 117 | x = F.normalize(x, dim=-1, p=2) 118 | # x = x.detach() 119 | logits = self.prototype_layer(x) 120 | 121 | prototypes = self.prototype_layer.weight_v.clone() 122 | normed_prototypes = F.normalize(prototypes, dim=-1, p=2) 123 | 124 | return x, x_proj, logits, normed_prototypes 125 | 126 | 127 | 128 | class ContrastiveLearningViewGenerator(object): 129 | """Take two random crops of one image as the query and key.""" 130 | 131 | def __init__(self, base_transform, n_views=2): 132 | self.base_transform = base_transform 133 | self.n_views = n_views 134 | 135 | def __call__(self, x): 136 | if not isinstance(self.base_transform, list): 137 | return [self.base_transform(x) for i in range(self.n_views)] 138 | else: 139 | return [self.base_transform[i](x) for i in range(self.n_views)] 140 | 141 | 142 | def get_params_groups(model): 143 | regularized = [] 144 | not_regularized = [] 145 | for name, param in model.named_parameters(): 146 | if not param.requires_grad: 147 | continue 148 | # we do not regularize biases nor Norm parameters 149 | if name.endswith(".bias") or len(param.shape) == 1: 150 | not_regularized.append(param) 151 | else: 152 | 
regularized.append(param) 153 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] 154 | -------------------------------------------------------------------------------- /models/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Mostly copy-paste from timm library. 16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 17 | """ 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | import warnings 25 | 26 | 27 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 28 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 29 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 30 | def norm_cdf(x): 31 | # Computes standard normal cumulative distribution function 32 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 33 | 34 | if (mean < a - 2 * std) or (mean > b + 2 * std): 35 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 36 | "The distribution of values may be incorrect.", 37 | stacklevel=2) 38 | 39 | with torch.no_grad(): 40 | # Values are generated by using a truncated uniform distribution and 41 | # then using the inverse CDF for the normal distribution. 42 | # Get upper and lower cdf values 43 | l = norm_cdf((a - mean) / std) 44 | u = norm_cdf((b - mean) / std) 45 | 46 | # Uniformly fill tensor with values from [l, u], then translate to 47 | # [2l-1, 2u-1]. 48 | tensor.uniform_(2 * l - 1, 2 * u - 1) 49 | 50 | # Use inverse cdf transform for normal distribution to get truncated 51 | # standard normal 52 | tensor.erfinv_() 53 | 54 | # Transform to proper mean, std 55 | tensor.mul_(std * math.sqrt(2.)) 56 | tensor.add_(mean) 57 | 58 | # Clamp to ensure it's in the proper range 59 | tensor.clamp_(min=a, max=b) 60 | return tensor 61 | 62 | 63 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 64 | # type: (Tensor, float, float, float, float) -> Tensor 65 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 66 | 67 | 68 | def drop_path(x, drop_prob: float = 0., training: bool = False): 69 | if drop_prob == 0. or not training: 70 | return x 71 | keep_prob = 1 - drop_prob 72 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 73 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 74 | random_tensor.floor_() # binarize 75 | output = x.div(keep_prob) * random_tensor 76 | return output 77 | 78 | 79 | class DropPath(nn.Module): 80 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
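During training, each sample's residual branch is zeroed with probability drop_prob and the surviving branches are rescaled by 1 / keep_prob so the expected output is unchanged; at evaluation time the module is the identity.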
81 | """ 82 | def __init__(self, drop_prob=None): 83 | super(DropPath, self).__init__() 84 | self.drop_prob = drop_prob 85 | 86 | def forward(self, x): 87 | return drop_path(x, self.drop_prob, self.training) 88 | 89 | 90 | class Mlp(nn.Module): 91 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 92 | super().__init__() 93 | out_features = out_features or in_features 94 | hidden_features = hidden_features or in_features 95 | self.fc1 = nn.Linear(in_features, hidden_features) 96 | self.act = act_layer() 97 | self.fc2 = nn.Linear(hidden_features, out_features) 98 | self.drop = nn.Dropout(drop) 99 | 100 | def forward(self, x): 101 | x = self.fc1(x) 102 | x = self.act(x) 103 | x = self.drop(x) 104 | x = self.fc2(x) 105 | x = self.drop(x) 106 | return x 107 | 108 | 109 | class Attention(nn.Module): 110 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 111 | super().__init__() 112 | self.num_heads = num_heads 113 | head_dim = dim // num_heads 114 | self.scale = qk_scale or head_dim ** -0.5 115 | 116 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 117 | self.attn_drop = nn.Dropout(attn_drop) 118 | self.proj = nn.Linear(dim, dim) 119 | self.proj_drop = nn.Dropout(proj_drop) 120 | 121 | def forward(self, x): 122 | B, N, C = x.shape 123 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 124 | q, k, v = qkv[0], qkv[1], qkv[2] 125 | 126 | attn = (q @ k.transpose(-2, -1)) * self.scale 127 | attn = attn.softmax(dim=-1) 128 | attn = self.attn_drop(attn) 129 | 130 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 131 | x = self.proj(x) 132 | x = self.proj_drop(x) 133 | return x, attn 134 | 135 | 136 | class Block(nn.Module): 137 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 138 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 139 | super().__init__() 140 | self.norm1 = norm_layer(dim) 141 | self.attn = Attention( 142 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 143 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 144 | self.norm2 = norm_layer(dim) 145 | mlp_hidden_dim = int(dim * mlp_ratio) 146 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 147 | 148 | def forward(self, x, return_attention=False): 149 | y, attn = self.attn(self.norm1(x)) 150 | if return_attention: 151 | return attn 152 | x = x + self.drop_path(y) 153 | x = x + self.drop_path(self.mlp(self.norm2(x))) 154 | return x 155 | 156 | 157 | class PatchEmbed(nn.Module): 158 | """ Image to Patch Embedding 159 | """ 160 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 161 | super().__init__() 162 | num_patches = (img_size // patch_size) * (img_size // patch_size) 163 | self.img_size = img_size 164 | self.patch_size = patch_size 165 | self.num_patches = num_patches 166 | 167 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 168 | 169 | def forward(self, x): 170 | B, C, H, W = x.shape 171 | x = self.proj(x).flatten(2).transpose(1, 2) 172 | return x 173 | 174 | 175 | class VisionTransformer(nn.Module): 176 | """ Vision Transformer """ 177 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, 178 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 179 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): 180 | super().__init__() 181 | self.num_features = self.embed_dim = embed_dim 182 | 183 | self.patch_embed = PatchEmbed( 184 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 185 | num_patches = self.patch_embed.num_patches 186 | 187 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 188 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 189 | self.pos_drop = nn.Dropout(p=drop_rate) 190 | 191 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 192 | self.blocks = nn.ModuleList([ 193 | Block( 194 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 195 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 196 | for i in range(depth)]) 197 | self.norm = norm_layer(embed_dim) 198 | 199 | # Classifier head 200 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 201 | 202 | trunc_normal_(self.pos_embed, std=.02) 203 | trunc_normal_(self.cls_token, std=.02) 204 | self.apply(self._init_weights) 205 | 206 | def _init_weights(self, m): 207 | if isinstance(m, nn.Linear): 208 | trunc_normal_(m.weight, std=.02) 209 | if isinstance(m, nn.Linear) and m.bias is not None: 210 | nn.init.constant_(m.bias, 0) 211 | elif isinstance(m, nn.LayerNorm): 212 | nn.init.constant_(m.bias, 0) 213 | nn.init.constant_(m.weight, 1.0) 214 | 215 | def interpolate_pos_encoding(self, x, w, h): 216 | npatch = x.shape[1] - 1 217 | N = self.pos_embed.shape[1] - 1 218 | if npatch == N and w == h: 219 | return self.pos_embed 220 | class_pos_embed = self.pos_embed[:, 0] 221 | patch_pos_embed = self.pos_embed[:, 1:] 222 | dim = x.shape[-1] 223 | w0 = w // self.patch_embed.patch_size 224 | h0 = h // self.patch_embed.patch_size 225 | # we add a small number to avoid floating point error in the interpolation 226 | # see discussion at https://github.com/facebookresearch/dino/issues/8 227 | w0, h0 = w0 + 0.1, h0 + 0.1 228 | patch_pos_embed = nn.functional.interpolate( 229 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), 
dim).permute(0, 3, 1, 2), 230 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 231 | mode='bicubic', 232 | ) 233 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 234 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 235 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) 236 | 237 | def prepare_tokens(self, x): 238 | B, nc, w, h = x.shape 239 | x = self.patch_embed(x) # patch linear embedding 240 | 241 | # add the [CLS] token to the embed patch tokens 242 | cls_tokens = self.cls_token.expand(B, -1, -1) 243 | x = torch.cat((cls_tokens, x), dim=1) 244 | 245 | # add positional encoding to each token 246 | x = x + self.interpolate_pos_encoding(x, w, h) 247 | 248 | return self.pos_drop(x) 249 | 250 | def forward(self, x): 251 | x = self.prepare_tokens(x) 252 | for blk in self.blocks: 253 | x = blk(x) 254 | x = self.norm(x) 255 | return x[:, 0] 256 | 257 | def get_last_selfattention(self, x): 258 | x = self.prepare_tokens(x) 259 | for i, blk in enumerate(self.blocks): 260 | if i < len(self.blocks) - 1: 261 | x = blk(x) 262 | else: 263 | # return attention of the last block 264 | return blk(x, return_attention=True) 265 | 266 | def get_intermediate_layers(self, x, n=1): 267 | x = self.prepare_tokens(x) 268 | # we return the output tokens from the `n` last blocks 269 | output = [] 270 | for i, blk in enumerate(self.blocks): 271 | x = blk(x) 272 | if len(self.blocks) - i <= n: 273 | output.append(self.norm(x)) 274 | return output 275 | 276 | 277 | def vit_tiny(patch_size=16, **kwargs): 278 | model = VisionTransformer( 279 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, 280 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 281 | return model 282 | 283 | 284 | def vit_small(patch_size=16, **kwargs): 285 | model = VisionTransformer( 286 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, 287 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 288 | return model 289 | 290 | 291 | def vit_base(patch_size=16, **kwargs): 292 | model = VisionTransformer( 293 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 294 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 295 | return model 296 | 297 | 298 | class DINOHead(nn.Module): 299 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): 300 | super().__init__() 301 | nlayers = max(nlayers, 1) 302 | if nlayers == 1: 303 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 304 | else: 305 | layers = [nn.Linear(in_dim, hidden_dim)] 306 | if use_bn: 307 | layers.append(nn.BatchNorm1d(hidden_dim)) 308 | layers.append(nn.GELU()) 309 | for _ in range(nlayers - 2): 310 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 311 | if use_bn: 312 | layers.append(nn.BatchNorm1d(hidden_dim)) 313 | layers.append(nn.GELU()) 314 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 315 | self.mlp = nn.Sequential(*layers) 316 | self.apply(self._init_weights) 317 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 318 | self.last_layer.weight_g.data.fill_(1) 319 | if norm_last_layer: 320 | self.last_layer.weight_g.requires_grad = False 321 | 322 | def _init_weights(self, m): 323 | if isinstance(m, nn.Linear): 324 | trunc_normal_(m.weight, std=.02) 325 | if isinstance(m, nn.Linear) and m.bias is not None: 326 | nn.init.constant_(m.bias, 0) 327 | 328 | def 
forward(self, x): 329 | x = self.mlp(x) 330 | x = nn.functional.normalize(x, dim=-1, p=2) 331 | x = self.last_layer(x) 332 | return x -------------------------------------------------------------------------------- /my_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mashijie1028/ProtoGCD/8835a4d24662c65be125a42d815b52f62ae1482e/my_utils/__init__.py -------------------------------------------------------------------------------- /my_utils/cluster_and_log_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | #import torch.distributed as dist 3 | import numpy as np 4 | from scipy.optimize import linear_sum_assignment as linear_assignment 5 | 6 | 7 | # def all_sum_item(item): 8 | # item = torch.tensor(item).cuda() 9 | # dist.all_reduce(item) 10 | # return item.item() 11 | 12 | 13 | def old_cluster_acc(y_true, y_pred, return_ind=False): 14 | """ 15 | https://github.com/sgvaze/generalized-category-discovery/blob/main/project_utils/cluster_utils.py#L39 16 | used ONLY for estimating # of novel categories in `estimate_k.py` 17 | 18 | Calculate clustering accuracy. Require scikit-learn installed 19 | 20 | # Arguments 21 | y: true labels, numpy.array with shape `(n_samples,)` 22 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 23 | 24 | # Return 25 | accuracy, in [0,1] 26 | """ 27 | y_true = y_true.astype(int) 28 | assert y_pred.size == y_true.size 29 | D = max(y_pred.max(), y_true.max()) + 1 30 | w = np.zeros((D, D), dtype=int) 31 | for i in range(y_pred.size): 32 | w[y_pred[i], y_true[i]] += 1 33 | 34 | ind = linear_assignment(w.max() - w) 35 | ind = np.vstack(ind).T 36 | 37 | if return_ind: 38 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size, ind, w 39 | else: 40 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size 41 | 42 | 43 | 44 | def split_cluster_acc_v2(y_true, y_pred, mask): 45 | """ 46 | Calculate clustering accuracy. 
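Cluster ids are matched to ground-truth labels with a single Hungarian assignment, so a pure permutation of cluster ids (e.g. predicting [1, 1, 0, 0] for labels [0, 0, 1, 1]) still scores 1.0.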
Require scikit-learn installed 47 | First compute linear assignment on all data, then look at how good the accuracy is on subsets 48 | 49 | # Arguments 50 | mask: Which instances come from old classes (True) and which ones come from new classes (False) 51 | y: true labels, numpy.array with shape `(n_samples,)` 52 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 53 | 54 | # Return 55 | accuracy, in [0,1] 56 | """ 57 | y_true = y_true.astype(int) 58 | 59 | old_classes_gt = set(y_true[mask]) 60 | new_classes_gt = set(y_true[~mask]) 61 | 62 | assert y_pred.size == y_true.size 63 | D = max(y_pred.max(), y_true.max()) + 1 64 | w = np.zeros((D, D), dtype=int) 65 | for i in range(y_pred.size): 66 | w[y_pred[i], y_true[i]] += 1 67 | 68 | ind = linear_assignment(w.max() - w) 69 | ind = np.vstack(ind).T 70 | 71 | ind_map = {j: i for i, j in ind} 72 | total_acc = sum([w[i, j] for i, j in ind]) 73 | total_instances = y_pred.size 74 | # try: 75 | # if dist.get_world_size() > 0: 76 | # total_acc = all_sum_item(total_acc) 77 | # total_instances = all_sum_item(total_instances) 78 | # except: 79 | # pass 80 | total_acc /= total_instances 81 | 82 | old_acc = 0 83 | total_old_instances = 0 84 | for i in old_classes_gt: 85 | old_acc += w[ind_map[i], i] 86 | total_old_instances += sum(w[:, i]) 87 | 88 | # try: 89 | # if dist.get_world_size() > 0: 90 | # old_acc = all_sum_item(old_acc) 91 | # total_old_instances = all_sum_item(total_old_instances) 92 | # except: 93 | # pass 94 | old_acc /= total_old_instances 95 | 96 | new_acc = 0 97 | total_new_instances = 0 98 | for i in new_classes_gt: 99 | new_acc += w[ind_map[i], i] 100 | total_new_instances += sum(w[:, i]) 101 | 102 | # try: 103 | # if dist.get_world_size() > 0: 104 | # new_acc = all_sum_item(new_acc) 105 | # total_new_instances = all_sum_item(total_new_instances) 106 | # except: 107 | # pass 108 | new_acc /= total_new_instances 109 | 110 | return total_acc, old_acc, new_acc 111 | 112 | 113 | def split_cluster_acc_v2_balanced(y_true, y_pred, mask): 114 | """ 115 | Calculate clustering accuracy. 
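Unlike split_cluster_acc_v2, accuracies are computed per class and then averaged (macro accuracy), so rare and frequent classes contribute equally.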
Require scikit-learn installed 116 | First compute linear assignment on all data, then look at how good the accuracy is on subsets 117 | 118 | # Arguments 119 | mask: Which instances come from old classes (True) and which ones come from new classes (False) 120 | y: true labels, numpy.array with shape `(n_samples,)` 121 | y_pred: predicted labels, numpy.array with shape `(n_samples,)` 122 | 123 | # Return 124 | accuracy, in [0,1] 125 | """ 126 | y_true = y_true.astype(int) 127 | 128 | old_classes_gt = set(y_true[mask]) 129 | new_classes_gt = set(y_true[~mask]) 130 | 131 | assert y_pred.size == y_true.size 132 | D = max(y_pred.max(), y_true.max()) + 1 133 | w = np.zeros((D, D), dtype=int) 134 | for i in range(y_pred.size): 135 | w[y_pred[i], y_true[i]] += 1 136 | 137 | ind = linear_assignment(w.max() - w) 138 | ind = np.vstack(ind).T 139 | 140 | ind_map = {j: i for i, j in ind} 141 | 142 | old_acc = np.zeros(len(old_classes_gt)) 143 | total_old_instances = np.zeros(len(old_classes_gt)) 144 | for idx, i in enumerate(old_classes_gt): 145 | old_acc[idx] += w[ind_map[i], i] 146 | total_old_instances[idx] += sum(w[:, i]) 147 | 148 | new_acc = np.zeros(len(new_classes_gt)) 149 | total_new_instances = np.zeros(len(new_classes_gt)) 150 | for idx, i in enumerate(new_classes_gt): 151 | new_acc[idx] += w[ind_map[i], i] 152 | total_new_instances[idx] += sum(w[:, i]) 153 | 154 | # try: 155 | # if dist.get_world_size() > 0: 156 | # old_acc, new_acc = torch.from_numpy(old_acc).cuda(), torch.from_numpy(new_acc).cuda() 157 | # dist.all_reduce(old_acc), dist.all_reduce(new_acc) 158 | # dist.all_reduce(total_old_instances), dist.all_reduce(total_new_instances) 159 | # old_acc, new_acc = old_acc.cpu().numpy(), new_acc.cpu().numpy() 160 | # total_old_instances, total_new_instances = total_old_instances.cpu().numpy(), total_new_instances.cpu().numpy() 161 | # except: 162 | # pass 163 | 164 | total_acc = np.concatenate([old_acc, new_acc]) / np.concatenate([total_old_instances, total_new_instances]) 165 | old_acc /= total_old_instances 166 | new_acc /= total_new_instances 167 | total_acc, old_acc, new_acc = total_acc.mean(), old_acc.mean(), new_acc.mean() 168 | return total_acc, old_acc, new_acc 169 | 170 | 171 | EVAL_FUNCS = { 172 | 'v2': split_cluster_acc_v2, 173 | 'v2b': split_cluster_acc_v2_balanced 174 | } 175 | 176 | def log_accs_from_preds(y_true, y_pred, mask, eval_funcs, save_name, T=None, 177 | print_output=True, args=None): 178 | 179 | """ 180 | Given a list of evaluation functions to use (e.g ['v1', 'v2']) evaluate and log ACC results 181 | 182 | :param y_true: GT labels 183 | :param y_pred: Predicted indices 184 | :param mask: Which instances belong to Old and New classes 185 | :param T: Epoch 186 | :param eval_funcs: Which evaluation functions to use 187 | :param save_name: What are we evaluating ACC on 188 | :param writer: Tensorboard logger 189 | :return: 190 | """ 191 | 192 | mask = mask.astype(bool) 193 | y_true = y_true.astype(int) 194 | y_pred = y_pred.astype(int) 195 | 196 | for i, f_name in enumerate(eval_funcs): 197 | 198 | acc_f = EVAL_FUNCS[f_name] 199 | all_acc, old_acc, new_acc = acc_f(y_true, y_pred, mask) 200 | log_name = f'{save_name}_{f_name}' 201 | 202 | if i == 0: 203 | to_return = (all_acc, old_acc, new_acc) 204 | 205 | if print_output: 206 | print_str = f'Epoch {T}, {log_name}: All {all_acc:.4f} | Old {old_acc:.4f} | New {new_acc:.4f}' 207 | # try: 208 | # if dist.get_rank() == 0: 209 | # try: 210 | # args.logger.info(print_str) 211 | # except: 212 | # print(print_str) 213 | # 
except: 214 | # pass 215 | 216 | return to_return 217 | -------------------------------------------------------------------------------- /my_utils/general_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import inspect 4 | 5 | from datetime import datetime 6 | from loguru import logger 7 | import time 8 | 9 | 10 | class AverageMeter(object): 11 | """Computes and stores the average and current value""" 12 | def __init__(self): 13 | self.reset() 14 | 15 | def reset(self): 16 | 17 | self.val = 0 18 | self.avg = 0 19 | self.sum = 0 20 | self.count = 0 21 | 22 | def update(self, val, n=1): 23 | 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def init_experiment(args, runner_name=None, exp_id=None): 31 | # Get filepath of calling script 32 | if runner_name is None: 33 | runner_name = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:] 34 | 35 | #root_dir = os.path.join(args.exp_root, *runner_name) 36 | root_dir = os.path.join(args.exp_root, args.dataset_name) 37 | 38 | if not os.path.exists(root_dir): 39 | os.makedirs(root_dir) 40 | 41 | # Either generate a unique experiment ID, or use one which is passed 42 | if exp_id is None: 43 | 44 | if args.exp_name is None: 45 | raise ValueError("Need to specify the experiment name") 46 | # Unique identifier for experiment 47 | # now = '{}_({:02d}.{:02d}.{}_|_'.format(args.exp_name, datetime.now().day, datetime.now().month, datetime.now().year) + \ 48 | # datetime.now().strftime("%S.%f")[:-3] + ')' 49 | #now = args.exp_name + '_' + str(time.strftime("%Y%m%d-%H%M%S", time.localtime())) 50 | now = str(time.strftime("%Y%m%d-%H%M%S", time.localtime())) 51 | 52 | #log_dir = os.path.join(root_dir, 'log', now) 53 | log_dir = os.path.join(root_dir, now) 54 | while os.path.exists(log_dir): 55 | # now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \ 56 | # datetime.now().strftime("%S.%f")[:-3] + ')' 57 | #now = args.exp_name + '_' + str(time.strftime("%Y%m%d-%H%M%S", time.localtime())) 58 | now = str(time.strftime("%Y%m%d-%H%M%S", time.localtime())) 59 | 60 | #log_dir = os.path.join(root_dir, 'log', now) 61 | log_dir = os.path.join(root_dir, now) 62 | 63 | else: 64 | 65 | #log_dir = os.path.join(root_dir, 'log', f'{exp_id}') 66 | log_dir = os.path.join(root_dir, f'{exp_id}') 67 | 68 | if not os.path.exists(log_dir): 69 | os.makedirs(log_dir) 70 | 71 | 72 | #logger.add(os.path.join(log_dir, 'log.txt')) 73 | logger.add(os.path.join(log_dir, 'log.txt'), enqueue=True) 74 | args.logger = logger 75 | args.log_dir = log_dir 76 | 77 | # Instantiate directory to save models to 78 | model_root_dir = os.path.join(args.log_dir, 'checkpoints') 79 | if not os.path.exists(model_root_dir): 80 | os.mkdir(model_root_dir) 81 | 82 | args.model_dir = model_root_dir 83 | args.model_path = os.path.join(args.model_dir, 'model.pt') 84 | 85 | print(f'Experiment saved to: {args.log_dir}') 86 | 87 | hparam_dict = {} 88 | 89 | for k, v in vars(args).items(): 90 | if isinstance(v, (int, float, str, bool, torch.Tensor)): 91 | hparam_dict[k] = v 92 | 93 | print(runner_name) 94 | 95 | # print and save args 96 | print(args) 97 | save_args_path = os.path.join(log_dir, 'args.txt') 98 | f_args = open(save_args_path, 'w') 99 | f_args.write('args: \n') 100 | f_args.write(str(vars(args))) 101 | f_args.close() 102 | 103 | return args 104 | 105 | 106 | # estimate # of novel 
class K 107 | def init_experiment_estimate_k(args, runner_name=None, exp_id=None): 108 | # Get filepath of calling script 109 | if runner_name is None: 110 | runner_name = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:] 111 | 112 | #root_dir = os.path.join(args.exp_root, *runner_name) 113 | root_dir = os.path.join(args.exp_root, args.dataset_name) 114 | 115 | if not os.path.exists(root_dir): 116 | os.makedirs(root_dir) 117 | 118 | # Either generate a unique experiment ID, or use one which is passed 119 | if exp_id is None: 120 | 121 | if args.exp_name is None: 122 | raise ValueError("Need to specify the experiment name") 123 | now = str(time.strftime("%Y%m%d-%H%M%S", time.localtime())) 124 | 125 | log_dir = os.path.join(root_dir, now) 126 | now_k = now + '-k' + str(args.estimate_novel_k) 127 | log_dir = os.path.join(root_dir, now_k) 128 | while os.path.exists(log_dir): 129 | now = str(time.strftime("%Y%m%d-%H%M%S", time.localtime())) 130 | 131 | now_k = now + '-k' + str(args.estimate_novel_k) 132 | log_dir = os.path.join(root_dir, now_k) 133 | 134 | else: 135 | 136 | log_dir = os.path.join(root_dir, f'{exp_id}') 137 | 138 | if not os.path.exists(log_dir): 139 | os.makedirs(log_dir) 140 | 141 | 142 | #logger.add(os.path.join(log_dir, 'log.txt')) 143 | logger.add(os.path.join(log_dir, 'log.txt'), enqueue=True) 144 | args.logger = logger 145 | args.log_dir = log_dir 146 | 147 | print(f'Experiment saved to: {args.log_dir}') 148 | 149 | hparam_dict = {} 150 | 151 | for k, v in vars(args).items(): 152 | if isinstance(v, (int, float, str, bool, torch.Tensor)): 153 | hparam_dict[k] = v 154 | 155 | print(runner_name) 156 | 157 | # print and save args 158 | print(args) 159 | save_args_path = os.path.join(log_dir, 'args.txt') 160 | f_args = open(save_args_path, 'w') 161 | f_args.write('args: \n') 162 | f_args.write(str(vars(args))) 163 | f_args.close() 164 | 165 | return args 166 | 167 | 168 | 169 | class DistributedWeightedSampler(torch.utils.data.distributed.DistributedSampler): 170 | 171 | def __init__(self, dataset, weights, num_samples, num_replicas=None, rank=None, 172 | replacement=True, generator=None): 173 | super(DistributedWeightedSampler, self).__init__(dataset, num_replicas, rank) 174 | if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \ 175 | num_samples <= 0: 176 | raise ValueError("num_samples should be a positive integer " 177 | "value, but got num_samples={}".format(num_samples)) 178 | if not isinstance(replacement, bool): 179 | raise ValueError("replacement should be a boolean value, but got " 180 | "replacement={}".format(replacement)) 181 | self.weights = torch.as_tensor(weights, dtype=torch.double) 182 | self.num_samples = num_samples 183 | self.replacement = replacement 184 | self.generator = generator 185 | self.weights = self.weights[self.rank::self.num_replicas] 186 | self.num_samples = self.num_samples // self.num_replicas 187 | 188 | def __iter__(self): 189 | rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator) 190 | rand_tensor = self.rank + rand_tensor * self.num_replicas 191 | yield from iter(rand_tensor.tolist()) 192 | 193 | def __len__(self): 194 | return self.num_samples 195 | -------------------------------------------------------------------------------- /my_utils/ood_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import 
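
A small numeric sketch of the rank-sharding arithmetic inside `DistributedWeightedSampler` above; the weights and replica count are toy values:

```python
# Each replica keeps weights[rank::num_replicas]; a multinomial draw at local
# position i maps back to global dataset index rank + i * num_replicas,
# matching the arithmetic in DistributedWeightedSampler.__iter__ above.
import torch

weights = torch.arange(1, 11, dtype=torch.double)  # toy weights for a 10-sample dataset
num_replicas, rank = 2, 1

local_weights = weights[rank::num_replicas]        # this replica owns global indices 1, 3, 5, 7, 9
draws = torch.multinomial(local_weights, num_samples=3, replacement=True)
global_indices = rank + draws * num_replicas       # undo the sharding

assert all(int(g) % num_replicas == rank for g in global_indices)  # stays on this replica's shard
print(global_indices.tolist())
```
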
sklearn.metrics as sk 5 | from tqdm import tqdm 6 | 7 | recall_level_default = 0.95 8 | 9 | 10 | def stable_cumsum(arr, rtol=1e-05, atol=1e-08): 11 | """Use high precision for cumsum and check that final value matches sum 12 | Parameters 13 | ---------- 14 | arr : array-like 15 | To be cumulatively summed as flat 16 | rtol : float 17 | Relative tolerance, see ``np.allclose`` 18 | atol : float 19 | Absolute tolerance, see ``np.allclose`` 20 | """ 21 | out = np.cumsum(arr, dtype=np.float64) 22 | expected = np.sum(arr, dtype=np.float64) 23 | if not np.allclose(out[-1], expected, rtol=rtol, atol=atol): 24 | raise RuntimeError('cumsum was found to be unstable: ' 25 | 'its last element does not correspond to sum') 26 | return out 27 | 28 | 29 | concat = lambda x: np.concatenate(x, axis=0) 30 | to_np = lambda x: x.data.cpu().numpy() 31 | 32 | 33 | def get_ood_scores_in(loader, model, args): 34 | _score = [] 35 | _right_score = [] 36 | _wrong_score = [] 37 | 38 | with torch.no_grad(): 39 | for batch_idx, (images, labels, _) in enumerate(tqdm(loader)): 40 | 41 | images = images.cuda(non_blocking=True) 42 | feats, _, logits, prototypes = model(images) 43 | 44 | #output = model(data) 45 | #smax = to_np(F.softmax(logits, dim=1)) 46 | smax = to_np(F.softmax(logits / args.temp_logits, dim=1)) # NOTE!!! 47 | output_ = to_np(logits) 48 | 49 | # if args.use_xent: 50 | # _score.append(to_np((output.mean(1) - torch.logsumexp(output, dim=1)))) 51 | # else: 52 | # _score.append(-np.max(smax, axis=1)) 53 | 54 | if args.score == 'energy': 55 | _score.append(-to_np((args.T * torch.logsumexp(logits / args.T, dim=1)))) 56 | elif args.score == 'mls': 57 | _score.append(-np.max(output_, axis=1)) 58 | elif args.score == 'xent': 59 | #_score.append(to_np((logits.mean(1) - torch.logsumexp(logits, dim=1)))) 60 | _score.append(to_np((logits.mean(1) / args.temp_logits - torch.logsumexp(logits / args.temp_logits, dim=1)))) # NOTE!!! 61 | elif args.score == 'proto': 62 | #top2 = np.topk(smax, k=2, dim=-1, largest=True)[0] 63 | #top2_div = top2[:, 0] / (top2[:, 1] + 1e-6) 64 | smax_sort = np.sort(smax, axis=1) 65 | top2_div = smax_sort[:, -1] / (smax_sort[:, -2] + 1e-6) 66 | _score.append(-top2_div) # NOTE!!! 67 | elif args.score == 'margin': 68 | #top2 = np.topk(smax, k=2, dim=-1, largest=True)[0] 69 | #top2_div = top2[:, 0] / (top2[:, 1] + 1e-6) 70 | smax_sort = np.sort(smax, axis=1) 71 | top2_margin = smax_sort[:, -1] - smax_sort[:, -2] 72 | _score.append(-top2_margin) # NOTE!!! 
73 | else: 74 | _score.append(-np.max(smax, axis=1)) 75 | 76 | preds = np.argmax(smax, axis=1) 77 | targets = labels.numpy().squeeze() 78 | right_indices = preds == targets 79 | wrong_indices = np.invert(right_indices) 80 | 81 | if args.score == 'xent': 82 | _right_score.append(to_np((logits.mean(1) / args.temp_logits - torch.logsumexp(logits / args.temp_logits, dim=1)))[right_indices]) 83 | _wrong_score.append(to_np((logits.mean(1) / args.temp_logits - torch.logsumexp(logits / args.temp_logits, dim=1)))[wrong_indices]) 84 | else: 85 | _right_score.append(-np.max(smax[right_indices], axis=1)) 86 | _wrong_score.append(-np.max(smax[wrong_indices], axis=1)) 87 | 88 | return concat(_score).copy(), concat(_right_score).copy(), concat(_wrong_score).copy() 89 | 90 | 91 | def get_ood_scores(loader, model, ood_num_examples, args): 92 | _score = [] 93 | 94 | with torch.no_grad(): 95 | for batch_idx, (images, labels) in enumerate(loader): 96 | if batch_idx >= ood_num_examples // args.batch_size: 97 | break 98 | 99 | images = images.cuda(non_blocking=True) 100 | feats, _, logits, prototypes = model(images) 101 | 102 | #output = model(data) 103 | #smax = to_np(F.softmax(logits, dim=1)) 104 | smax = to_np(F.softmax(logits / args.temp_logits, dim=1)) # NOTE!!! 105 | output_ = to_np(logits) 106 | 107 | # if args.use_xent: 108 | # _score.append(to_np((output.mean(1) - torch.logsumexp(output, dim=1)))) 109 | # else: 110 | # _score.append(-np.max(smax, axis=1)) 111 | 112 | if args.score == 'energy': 113 | _score.append(-to_np((args.T * torch.logsumexp(logits / args.T, dim=1)))) 114 | elif args.score == 'mls': 115 | _score.append(-np.max(output_, axis=1)) 116 | elif args.score == 'xent': 117 | #_score.append(to_np((logits.mean(1) - torch.logsumexp(logits, dim=1)))) 118 | _score.append(to_np((logits.mean(1) / args.temp_logits - torch.logsumexp(logits / args.temp_logits, dim=1)))) # NOTE!!! 119 | elif args.score == 'proto': 120 | #top2 = np.topk(smax, k=2, dim=-1, largest=True)[0] 121 | #top2_div = top2[:, 0] / (top2[:, 1] + 1e-6) 122 | smax_sort = np.sort(smax, axis=1) 123 | top2_div = smax_sort[:, -1] / (smax_sort[:, -2] + 1e-6) 124 | _score.append(-top2_div) # NOTE!!! 125 | elif args.score == 'margin': 126 | #top2 = np.topk(smax, k=2, dim=-1, largest=True)[0] 127 | #top2_div = top2[:, 0] / (top2[:, 1] + 1e-6) 128 | smax_sort = np.sort(smax, axis=1) 129 | top2_margin = smax_sort[:, -1] - smax_sort[:, -2] 130 | _score.append(-top2_margin) # NOTE!!! 131 | else: 132 | _score.append(-np.max(smax, axis=1)) 133 | 134 | return concat(_score)[:ood_num_examples].copy() 135 | 136 | 137 | def fpr_and_fdr_at_recall(y_true, y_score, recall_level=recall_level_default, pos_label=None): 138 | classes = np.unique(y_true) 139 | if (pos_label is None and 140 | not (np.array_equal(classes, [0, 1]) or 141 | np.array_equal(classes, [-1, 1]) or 142 | np.array_equal(classes, [0]) or 143 | np.array_equal(classes, [-1]) or 144 | np.array_equal(classes, [1]))): 145 | raise ValueError("Data is not binary and pos_label is not specified") 146 | elif pos_label is None: 147 | pos_label = 1. 148 | 149 | # make y_true a boolean vector 150 | y_true = (y_true == pos_label) 151 | 152 | # sort scores and corresponding truth values 153 | desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1] 154 | y_score = y_score[desc_score_indices] 155 | y_true = y_true[desc_score_indices] 156 | 157 | # y_score typically has many tied values. Here we extract 158 | # the indices associated with the distinct values. 
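
An aside on the sign convention shared by `get_ood_scores_in` and `get_ood_scores` above: every branch appends a *negated* confidence, so larger scores mean "more OOD-like". A numeric sketch with made-up softmax rows (temperature scaling omitted):

```python
# Sign convention of the OOD scores above: negated confidence, higher = more OOD-like.
import numpy as np

smax = np.array([[0.90, 0.06, 0.04],   # confident, ID-looking prediction
                 [0.40, 0.35, 0.25]])  # uncertain, OOD-looking prediction

msp = -np.max(smax, axis=1)                               # default 'msp' branch
smax_sort = np.sort(smax, axis=1)
margin = -(smax_sort[:, -1] - smax_sort[:, -2])           # 'margin' branch (top-2 gap)
proto = -(smax_sort[:, -1] / (smax_sort[:, -2] + 1e-6))   # 'proto' branch (top-2 ratio)

print(msp)     # [-0.9  -0.4 ]   -> the uncertain row gets the larger score
print(margin)  # [-0.84 -0.05]
print(proto)   # [-15.  -1.14]   (approximately)
```

Downstream, `get_and_print_results` in the two test scripts calls `get_measures(-in_score, -out_score)` so that in-distribution samples form the positive class for AUROC/AUPR(IN)/FPR(IN), and `get_measures(out_score, in_score)` so that OOD samples are positive for AUPR(OUT)/FPR(OUT).
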
We also 159 | # concatenate a value for the end of the curve. 160 | distinct_value_indices = np.where(np.diff(y_score))[0] 161 | threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1] 162 | 163 | # accumulate the true positives with decreasing threshold 164 | tps = stable_cumsum(y_true)[threshold_idxs] 165 | fps = 1 + threshold_idxs - tps # add one because of zero-based indexing 166 | 167 | thresholds = y_score[threshold_idxs] 168 | 169 | recall = tps / tps[-1] 170 | 171 | last_ind = tps.searchsorted(tps[-1]) 172 | sl = slice(last_ind, None, -1) # [last_ind::-1] 173 | recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl] 174 | 175 | cutoff = np.argmin(np.abs(recall - recall_level)) 176 | 177 | return fps[cutoff] / (np.sum(np.logical_not(y_true))) # , fps[cutoff]/(fps[cutoff] + tps[cutoff]) 178 | 179 | 180 | def get_measures(_pos, _neg, recall_level=recall_level_default): 181 | pos = np.array(_pos[:]).reshape((-1, 1)) 182 | neg = np.array(_neg[:]).reshape((-1, 1)) 183 | examples = np.squeeze(np.vstack((pos, neg))) 184 | labels = np.zeros(len(examples), dtype=np.int32) 185 | labels[:len(pos)] += 1 186 | 187 | auroc = sk.roc_auc_score(labels, examples) 188 | aupr = sk.average_precision_score(labels, examples) 189 | fpr = fpr_and_fdr_at_recall(labels, examples, recall_level) 190 | 191 | return auroc, aupr, fpr 192 | 193 | 194 | def print_measures(auroc, aupr_in, aupr_out, fpr_in, fpr_out, recall_level=recall_level_default): 195 | print('FPR(IN){:d}: {:.2f}'.format(int(100 * recall_level), 100 * fpr_in)) 196 | print('FPR(OUT){:d}: {:.2f}'.format(int(100 * recall_level), 100 * fpr_out)) 197 | print('AUROC: {:.2f}'.format(100 * auroc)) 198 | print('AUPR(IN): {:.2f}'.format(100 * aupr_in)) 199 | print('AUPR(OUT): {:.2f}'.format(100 * aupr_out)) 200 | 201 | 202 | def write_measures(auroc, aupr_in, aupr_out, fpr_in, fpr_out, file_path, recall_level=recall_level_default): 203 | with open(file_path, 'a+') as f_log: 204 | f_log.write('FPR(IN){:d}: {:.2f}'.format(int(100 * recall_level), 100 * fpr_in)) 205 | f_log.write('\n') 206 | f_log.write('FPR(OUT){:d}: {:.2f}'.format(int(100 * recall_level), 100 * fpr_out)) 207 | f_log.write('\n') 208 | f_log.write('AUROC: {:.2f}'.format(100 * auroc)) 209 | f_log.write('\n') 210 | f_log.write('AUPR(IN): {:.2f}'.format(100 * aupr_in)) 211 | f_log.write('\n') 212 | f_log.write('AUPR(OUT): {:.2f}'.format(100 * aupr_out)) 213 | f_log.write('\n') 214 | 215 | 216 | def print_measures_with_std(aurocs, auprs_in, auprs_out, fprs_in, fprs_out, recall_level=recall_level_default): 217 | print('FPR(IN){:d}: {:.2f} +/- {:.2f}'.format(int(100*recall_level), 100*np.mean(fprs_in), 100*np.std(fprs_in))) 218 | print('FPR(OUT){:d}: {:.2f} +/- {:.2f}'.format(int(100*recall_level), 100*np.mean(fprs_out), 100*np.std(fprs_out))) 219 | print('AUROC: {:.2f} +/- {:.2f}'.format(100*np.mean(aurocs), 100*np.std(aurocs))) 220 | print('AUPR(IN): {:.2f} +/- {:.2f}'.format(100*np.mean(auprs_in), 100*np.std(auprs_in))) 221 | print('AUPR(OUT): {:.2f} +/- {:.2f}'.format(100*np.mean(auprs_out), 100*np.std(auprs_out))) 222 | 223 | 224 | def write_measures_with_std(aurocs, auprs_in, auprs_out, fprs_in, fprs_out, file_path, recall_level=recall_level_default): 225 | with open(file_path, 'a+') as f_log: 226 | f_log.write('FPR(IN){:d}: {:.2f} +/- {:.2f}'.format(int(100*recall_level), 100*np.mean(fprs_in), 100*np.std(fprs_in))) 227 | f_log.write('\n') 228 | f_log.write('FPR(OUT){:d}: {:.2f} +/- {:.2f}'.format(int(100*recall_level), 
100*np.mean(fprs_out), 100*np.std(fprs_out))) 229 | f_log.write('\n') 230 | f_log.write('AUROC: {:.2f} +/- {:.2f}'.format(100*np.mean(aurocs), 100*np.std(aurocs))) 231 | f_log.write('\n') 232 | f_log.write('AUPR(IN): {:.2f} +/- {:.2f}'.format(100*np.mean(auprs_in), 100*np.std(auprs_in))) 233 | f_log.write('\n') 234 | f_log.write('AUPR(OUT): {:.2f} +/- {:.2f}'.format(100*np.mean(auprs_out), 100*np.std(auprs_out))) 235 | f_log.write('\n') 236 | -------------------------------------------------------------------------------- /test_ood_cifar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.utils.data import Dataset 9 | from torchvision import datasets, transforms 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from data.augmentations import get_transform 14 | from data.get_datasets import get_datasets, get_class_splits 15 | 16 | from config import exp_root 17 | from models.model import DINOHead_k 18 | from models.model import ContrastiveLearningViewGenerator, get_params_groups 19 | from my_utils.ood_utils import get_ood_scores_in, get_ood_scores, get_measures, print_measures, write_measures, print_measures_with_std, write_measures_with_std 20 | 21 | 22 | def get_and_print_results(ood_loader, model, in_score, args): 23 | aurocs, auprs_in, auprs_out, fprs_in, fprs_out = [], [], [], [], [] 24 | 25 | for _ in range(args.num_to_avg): 26 | out_score = get_ood_scores(ood_loader, model, OOD_NUM_EXAMPLES, args) 27 | measures_in = get_measures(-in_score, -out_score) 28 | measures_out = get_measures(out_score, in_score) # OE defines OOD samples as positive 29 | 30 | auroc = measures_in[0]; aupr_in = measures_in[1]; aupr_out = measures_out[1]; fpr_in = measures_in[2]; fpr_out = measures_out[2] 31 | aurocs.append(auroc); auprs_in.append(aupr_in); auprs_out.append(aupr_out); fprs_in.append(fpr_in); fprs_out.append(fpr_out) 32 | 33 | if args.num_to_avg >= 5: 34 | print_measures_with_std(aurocs, auprs_in, auprs_out, fprs_in, fprs_out) 35 | write_measures_with_std(aurocs, auprs_in, auprs_out, fprs_in, fprs_out, file_path=args.ood_log_path) 36 | else: 37 | print_measures(np.mean(aurocs), np.mean(auprs_in), np.mean(auprs_out), np.mean(fprs_in), np.mean(fprs_out)) 38 | write_measures(np.mean(aurocs), np.mean(auprs_in), np.mean(auprs_out), np.mean(fprs_in), np.mean(fprs_out), file_path=args.ood_log_path) 39 | 40 | return (auroc, aupr_in, aupr_out, fpr_in, fpr_out) 41 | 42 | 43 | if __name__ == "__main__": 44 | 45 | parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 46 | parser.add_argument('--batch_size', default=128, type=int) 47 | parser.add_argument('--num_workers', default=4, type=int) 48 | 49 | parser.add_argument('--warmup_model_dir', type=str, default=None) 50 | parser.add_argument('--dataset_name', type=str, default='cifar10', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aricraft, herbarium_19') 51 | parser.add_argument('--ckpts_date', type=str, default=None) 52 | parser.add_argument('--prop_train_labels', type=float, default=0.5) 53 | parser.add_argument('--use_ssb_splits', action='store_true', default=True) 54 | #parser.add_argument('--init_prototypes', action='store_true', default=False) 55 | 56 | #parser.add_argument('--grad_from_block', type=int, default=11) 57 | parser.add_argument('--exp_root', 
type=str, default=exp_root) 58 | parser.add_argument('--ood_log_path', type=str, default='OOD_results') 59 | parser.add_argument('--transform', type=str, default='imagenet') 60 | parser.add_argument('--n_views', default=2, type=int) 61 | 62 | parser.add_argument('--score', type=str, default='msp', help='OOD detection score function: [msp, mls, energy, xent]') 63 | parser.add_argument('--temp_logits', default=0.1, type=float, help='cosine similarity of prototypes to classification logits temperature') 64 | parser.add_argument('--T', default=1., type=float, help='temperature: energy|Odin') 65 | parser.add_argument('--num_to_avg', type=int, default=10, help='Average measures across num_to_avg runs.') 66 | 67 | # ---------------------- 68 | # INIT 69 | # ---------------------- 70 | args = parser.parse_args() 71 | device = torch.device('cuda:0') 72 | args = get_class_splits(args) 73 | 74 | args.num_labeled_classes = len(args.train_classes) 75 | args.num_unlabeled_classes = len(args.unlabeled_classes) 76 | 77 | #init_experiment(args, runner_name=['ProtoGCD']) 78 | #args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results') 79 | args.ood_log_path = os.path.join(args.ood_log_path, args.dataset_name) 80 | if not os.path.exists(args.ood_log_path): 81 | os.makedirs(args.ood_log_path) 82 | args.ood_log_path = os.path.join(args.ood_log_path, args.ckpts_date + '-' + args.score + '-T' + str(args.temp_logits) + '.txt') 83 | 84 | torch.backends.cudnn.benchmark = True 85 | 86 | # ---------------------- 87 | # BASE MODEL 88 | # ---------------------- 89 | args.interpolation = 3 90 | args.crop_pct = 0.875 91 | 92 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') 93 | 94 | # if args.warmup_model_dir is not None: 95 | # args.logger.info(f'Loading weights from {args.warmup_model_dir}') 96 | # backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu')) 97 | 98 | # NOTE: Hardcoded image size as we do not finetune the entire ViT model 99 | args.image_size = 224 100 | args.feat_dim = 768 101 | args.num_mlp_layers = 3 102 | args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes 103 | 104 | 105 | # -------------------- 106 | # CONTRASTIVE TRANSFORM 107 | # -------------------- 108 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args) 109 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views) 110 | # -------------------- 111 | # DATASETS 112 | # -------------------- 113 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets_ = get_datasets(args.dataset_name, 114 | train_transform, 115 | test_transform, 116 | args) 117 | 118 | # -------------------- 119 | # SAMPLER 120 | # Sampler which balances labelled and unlabelled examples in each batch 121 | # -------------------- 122 | label_len = len(train_dataset.labelled_dataset) 123 | unlabelled_len = len(train_dataset.unlabelled_dataset) 124 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))] 125 | sample_weights = torch.DoubleTensor(sample_weights) 126 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset)) 127 | 128 | # -------------------- 129 | # DATALOADERS 130 | # -------------------- 131 | # train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, 132 | # sampler=sampler, drop_last=True, pin_memory=True) 133 | # 
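
A quick check of the balancing weights built just above: labelled samples get weight 1, each unlabelled sample gets `label_len / unlabelled_len`, so the two pools carry equal total mass and `WeightedRandomSampler` draws them ~50/50 in expectation (toy sizes below). Note that in this evaluation script the sampler and the commented-out `train_loader` are never actually used; only the test loaders that follow matter.

```python
# Toy check: labelled and unlabelled pools end up with equal total sampling mass.
label_len, unlabelled_len = 4, 12
sample_weights = [1 if i < label_len else label_len / unlabelled_len
                  for i in range(label_len + unlabelled_len)]

print(sum(sample_weights[:label_len]))  # 4      (labelled mass)
print(sum(sample_weights[label_len:]))  # ~4.0   (unlabelled mass: 12 * (4/12))
```
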
test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers, 134 | # batch_size=256, shuffle=False, pin_memory=False) 135 | test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers, 136 | batch_size=256, shuffle=False, pin_memory=False) 137 | 138 | OOD_NUM_EXAMPLES = len(test_dataset) // 5 # NOTE! NOT test_loader_labelled! 139 | print(OOD_NUM_EXAMPLES) 140 | 141 | # ---------------------- 142 | # PROJECTION HEAD 143 | # ---------------------- 144 | projector = DINOHead_k(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers, 145 | init_prototypes=None, num_labeled_classes=args.num_labeled_classes) 146 | model = nn.Sequential(backbone, projector).to(device) 147 | 148 | ckpts_base_path = '/lustre/home/sjma/GCD-project/protoGCD-v7/dev_outputs_fix/' 149 | ckpts_path = os.path.join(ckpts_base_path, args.dataset_name, args.ckpts_date, 'checkpoints', 'model_best.pt') 150 | ckpts = torch.load(ckpts_path) 151 | ckpts = ckpts['model'] 152 | print('loading ckpts from %s...' % ckpts_path) 153 | model.load_state_dict(ckpts) 154 | print('successfully load ckpts') 155 | model.eval() 156 | 157 | 158 | # ---------------------- 159 | # TEST OOD 160 | # ---------------------- 161 | print('Using %s as typical data' % args.dataset_name) 162 | with open(args.ood_log_path, 'w+') as f_log: 163 | f_log.write('Using %s as typical data' % args.dataset_name) 164 | f_log.write('\n') 165 | 166 | print(test_transform) 167 | 168 | # ID score 169 | #test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) 170 | in_score, right_score, wrong_score = get_ood_scores_in(test_loader_labelled, model, args) 171 | 172 | 173 | # Textures 174 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/dtd/images", transform=test_transform) 175 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 176 | print('\n\nTexture Detection') 177 | with open(args.ood_log_path, 'a+') as f_log: 178 | f_log.write('\n\nTexture Detection') 179 | f_log.write('\n') 180 | get_and_print_results(ood_loader, model, in_score, args) 181 | 182 | 183 | # SVHN 184 | ood_data = datasets.SVHN('/data4/sjma/dataset/SVHN/', split='test', download=False, transform=test_transform) 185 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 186 | print('\n\nSVHN Detection') 187 | with open(args.ood_log_path, 'a+') as f_log: 188 | f_log.write('\n\nSVHN Detection') 189 | f_log.write('\n') 190 | get_and_print_results(ood_loader, model, in_score, args) 191 | 192 | # Places 193 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/places365", transform=test_transform) 194 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 195 | print('\n\nPlaces Detection') 196 | with open(args.ood_log_path, 'a+') as f_log: 197 | f_log.write('\n\nPlaces Detection') 198 | f_log.write('\n') 199 | get_and_print_results(ood_loader, model, in_score, args) 200 | 201 | 202 | # TinyImageNet-R 203 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/Imagenet_resize", transform=test_transform) 204 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 205 | print('\n\nTinyImageNet-resize Detection') 206 | with open(args.ood_log_path, 'a+') as f_log: 207 | 
f_log.write('\n\nTinyImageNet-resize Detection') 208 | f_log.write('\n') 209 | get_and_print_results(ood_loader, model, in_score, args) 210 | 211 | 212 | # LSUN-R 213 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/LSUN_resize", transform=test_transform) 214 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 215 | print('\n\nLSUN-resize Detection') 216 | with open(args.ood_log_path, 'a+') as f_log: 217 | f_log.write('\n\nLSUN-resize Detection') 218 | f_log.write('\n') 219 | get_and_print_results(ood_loader, model, in_score, args) 220 | 221 | 222 | # iSUN 223 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/iSUN", transform=test_transform) 224 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 225 | print('\n\niSUN Detection') 226 | with open(args.ood_log_path, 'a+') as f_log: 227 | f_log.write('\n\niSUN Detection') 228 | f_log.write('\n') 229 | get_and_print_results(ood_loader, model, in_score, args) 230 | 231 | 232 | # CIFAR data 233 | if args.dataset_name == 'cifar10': 234 | ood_data = datasets.CIFAR100('/data4/sjma/dataset/CIFAR/', train=False, transform=test_transform) 235 | else: 236 | ood_data = datasets.CIFAR10('/data4/sjma/dataset/CIFAR/', train=False, transform=test_transform) 237 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 238 | print('\n\nCIFAR-100 Detection') if args.dataset_name == 'cifar10' else print('\n\nCIFAR-10 Detection') 239 | with open(args.ood_log_path, 'a+') as f_log: 240 | f_log.write('\n\nCIFAR-100 Detection') if args.dataset_name == 'cifar10' else f_log.write('\n\nCIFAR-10 Detection') 241 | f_log.write('\n') 242 | get_and_print_results(ood_loader, model, in_score, args) 243 | -------------------------------------------------------------------------------- /test_ood_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.utils.data import Dataset 9 | from torchvision import datasets, transforms 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from data.augmentations import get_transform 14 | from data.get_datasets import get_datasets, get_class_splits 15 | 16 | from config import exp_root 17 | from models.model import DINOHead_k 18 | from models.model import ContrastiveLearningViewGenerator, get_params_groups 19 | from my_utils.ood_utils import get_ood_scores_in, get_ood_scores, get_measures, print_measures, write_measures, print_measures_with_std, write_measures_with_std 20 | 21 | 22 | def get_and_print_results(ood_loader, model, in_score, args): 23 | aurocs, auprs_in, auprs_out, fprs_in, fprs_out = [], [], [], [], [] 24 | 25 | for _ in range(args.num_to_avg): 26 | out_score = get_ood_scores(ood_loader, model, OOD_NUM_EXAMPLES, args) 27 | measures_in = get_measures(-in_score, -out_score) 28 | measures_out = get_measures(out_score, in_score) # OE defines OOD samples as positive 29 | 30 | auroc = measures_in[0]; aupr_in = measures_in[1]; aupr_out = measures_out[1]; fpr_in = measures_in[2]; fpr_out = measures_out[2] 31 | aurocs.append(auroc); auprs_in.append(aupr_in); auprs_out.append(aupr_out); fprs_in.append(fpr_in); fprs_out.append(fpr_out) 32 | 33 | if args.num_to_avg >= 5: 34 | 
print_measures_with_std(aurocs, auprs_in, auprs_out, fprs_in, fprs_out) 35 | write_measures_with_std(aurocs, auprs_in, auprs_out, fprs_in, fprs_out, file_path=args.ood_log_path) 36 | else: 37 | print_measures(np.mean(aurocs), np.mean(auprs_in), np.mean(auprs_out), np.mean(fprs_in), np.mean(fprs_out)) 38 | write_measures(np.mean(aurocs), np.mean(auprs_in), np.mean(auprs_out), np.mean(fprs_in), np.mean(fprs_out), file_path=args.ood_log_path) 39 | 40 | return (auroc, aupr_in, aupr_out, fpr_in, fpr_out) 41 | 42 | 43 | if __name__ == "__main__": 44 | 45 | parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 46 | parser.add_argument('--batch_size', default=128, type=int) 47 | parser.add_argument('--num_workers', default=4, type=int) 48 | 49 | parser.add_argument('--warmup_model_dir', type=str, default=None) 50 | parser.add_argument('--dataset_name', type=str, default='imagenet_100', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aricraft, herbarium_19') 51 | parser.add_argument('--ckpts_date', type=str, default=None) 52 | parser.add_argument('--prop_train_labels', type=float, default=0.5) 53 | parser.add_argument('--use_ssb_splits', action='store_true', default=True) 54 | #parser.add_argument('--init_prototypes', action='store_true', default=False) 55 | 56 | #parser.add_argument('--grad_from_block', type=int, default=11) 57 | parser.add_argument('--exp_root', type=str, default=exp_root) 58 | parser.add_argument('--ood_log_path', type=str, default='OOD_results') 59 | parser.add_argument('--transform', type=str, default='imagenet') 60 | parser.add_argument('--n_views', default=2, type=int) 61 | 62 | parser.add_argument('--score', type=str, default='msp', help='OOD detection score function: [msp, mls, energy, xent]') 63 | parser.add_argument('--temp_logits', default=0.1, type=float, help='cosine similarity of prototypes to classification logits temperature') 64 | parser.add_argument('--T', default=1., type=float, help='temperature: energy|Odin') 65 | parser.add_argument('--num_to_avg', type=int, default=10, help='Average measures across num_to_avg runs.') 66 | 67 | # ---------------------- 68 | # INIT 69 | # ---------------------- 70 | args = parser.parse_args() 71 | device = torch.device('cuda:0') 72 | args = get_class_splits(args) 73 | 74 | args.num_labeled_classes = len(args.train_classes) 75 | args.num_unlabeled_classes = len(args.unlabeled_classes) 76 | 77 | #init_experiment(args, runner_name=['ProtoGCD']) 78 | #args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results') 79 | args.ood_log_path = os.path.join(args.ood_log_path, args.dataset_name) 80 | if not os.path.exists(args.ood_log_path): 81 | os.makedirs(args.ood_log_path) 82 | args.ood_log_path = os.path.join(args.ood_log_path, args.ckpts_date + '-' + args.score + '-T' + str(args.temp_logits) + '.txt') 83 | 84 | torch.backends.cudnn.benchmark = True 85 | 86 | # ---------------------- 87 | # BASE MODEL 88 | # ---------------------- 89 | args.interpolation = 3 90 | args.crop_pct = 0.875 91 | 92 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') 93 | 94 | # if args.warmup_model_dir is not None: 95 | # args.logger.info(f'Loading weights from {args.warmup_model_dir}') 96 | # backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu')) 97 | 98 | # NOTE: Hardcoded image size as we do not finetune the entire ViT model 99 | args.image_size = 224 100 | args.feat_dim = 768 101 | args.num_mlp_layers = 3 102 | 
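
The head width assigned on the next line is the total number of prototypes, one per category, known plus to-be-discovered. A worked example, assuming the standard GCD split for CIFAR-100 (80 labelled, 20 unlabelled classes):

```python
# Worked example of the head sizing below; the 80/20 CIFAR-100 GCD split is assumed.
num_labeled_classes = 80
num_unlabeled_classes = 20
mlp_out_dim = num_labeled_classes + num_unlabeled_classes  # 100 prototypes in total
feat_dim = 768                                             # DINO ViT-B/16 feature dim
print(f'DINOHead_k: {feat_dim}-d features -> {mlp_out_dim} prototype logits')
```
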
args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes 103 | 104 | 105 | # -------------------- 106 | # CONTRASTIVE TRANSFORM 107 | # -------------------- 108 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args) 109 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views) 110 | # -------------------- 111 | # DATASETS 112 | # -------------------- 113 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets_ = get_datasets(args.dataset_name, 114 | train_transform, 115 | test_transform, 116 | args) 117 | 118 | # -------------------- 119 | # SAMPLER 120 | # Sampler which balances labelled and unlabelled examples in each batch 121 | # -------------------- 122 | label_len = len(train_dataset.labelled_dataset) 123 | unlabelled_len = len(train_dataset.unlabelled_dataset) 124 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))] 125 | sample_weights = torch.DoubleTensor(sample_weights) 126 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset)) 127 | 128 | # -------------------- 129 | # DATALOADERS 130 | # -------------------- 131 | # train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, 132 | # sampler=sampler, drop_last=True, pin_memory=True) 133 | # test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers, 134 | # batch_size=256, shuffle=False, pin_memory=False) 135 | test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers, 136 | batch_size=256, shuffle=False, pin_memory=False) 137 | 138 | OOD_NUM_EXAMPLES = len(test_dataset) // 5 # NOTE! NOT test_loader_labelled! 139 | print(OOD_NUM_EXAMPLES) 140 | 141 | # ---------------------- 142 | # PROJECTION HEAD 143 | # ---------------------- 144 | projector = DINOHead_k(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers, 145 | init_prototypes=None, num_labeled_classes=args.num_labeled_classes) 146 | model = nn.Sequential(backbone, projector).to(device) 147 | 148 | ckpts_base_path = '/lustre/home/sjma/GCD-project/protoGCD-v7/dev_outputs_fix/' 149 | ckpts_path = os.path.join(ckpts_base_path, args.dataset_name, args.ckpts_date, 'checkpoints', 'model_best.pt') 150 | ckpts = torch.load(ckpts_path) 151 | ckpts = ckpts['model'] 152 | print('loading ckpts from %s...' 
% ckpts_path) 153 | model.load_state_dict(ckpts) 154 | print('successfully load ckpts') 155 | model.eval() 156 | 157 | 158 | # ---------------------- 159 | # TEST OOD 160 | # ---------------------- 161 | print('Using %s as typical data' % args.dataset_name) 162 | with open(args.ood_log_path, 'w+') as f_log: 163 | f_log.write('Using %s as typical data' % args.dataset_name) 164 | f_log.write('\n') 165 | 166 | print(test_transform) 167 | 168 | # ID score 169 | #test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) 170 | in_score, right_score, wrong_score = get_ood_scores_in(test_loader_labelled, model, args) 171 | 172 | 173 | # Textures 174 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/dtd/images", transform=test_transform) 175 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 176 | print('\n\nTexture Detection') 177 | with open(args.ood_log_path, 'a+') as f_log: 178 | f_log.write('\n\nTexture Detection') 179 | f_log.write('\n') 180 | get_and_print_results(ood_loader, model, in_score, args) 181 | 182 | 183 | # Places 184 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/places365", transform=test_transform) 185 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 186 | print('\n\nPlaces Detection') 187 | with open(args.ood_log_path, 'a+') as f_log: 188 | f_log.write('\n\nPlaces Detection') 189 | f_log.write('\n') 190 | get_and_print_results(ood_loader, model, in_score, args) 191 | 192 | 193 | # iNaturalist 194 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/iNaturalist/", transform=test_transform) 195 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 196 | print('\n\niNaturalist Detection') 197 | with open(args.ood_log_path, 'a+') as f_log: 198 | f_log.write('\n\niNaturalist Detection') 199 | f_log.write('\n') 200 | get_and_print_results(ood_loader, model, in_score, args) 201 | 202 | 203 | # ImageNet-O 204 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/ImageNet-O/", transform=test_transform) 205 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 206 | print('\n\nImageNet-O Detection') 207 | with open(args.ood_log_path, 'a+') as f_log: 208 | f_log.write('\n\nImageNet-O Detection') 209 | f_log.write('\n') 210 | get_and_print_results(ood_loader, model, in_score, args) 211 | 212 | 213 | # OpenImage-O 214 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/OpenImage-O/", transform=test_transform) 215 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 216 | print('\n\nOpenImage-O Detection') 217 | with open(args.ood_log_path, 'a+') as f_log: 218 | f_log.write('\n\nOpenImage-O Detection') 219 | f_log.write('\n') 220 | get_and_print_results(ood_loader, model, in_score, args) 221 | 222 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import math 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.optim import SGD, lr_scheduler 10 | from torch.utils.data import DataLoader 11 | 
from tqdm import tqdm 12 | 13 | from data.augmentations import get_transform 14 | from data.get_datasets import get_datasets, get_class_splits 15 | 16 | from my_utils.general_utils import AverageMeter, init_experiment 17 | from my_utils.cluster_and_log_utils import log_accs_from_preds 18 | from config import exp_root 19 | 20 | from models.model import DINOHead 21 | from models.model import ContrastiveLearningViewGenerator, get_params_groups 22 | from models.loss import info_nce_logits, SupConLoss, DistillLoss_ratio, prototype_separation_loss, entropy_regularization_loss 23 | 24 | 25 | 26 | 27 | def train(student, train_loader, test_loader, unlabelled_train_loader, args): 28 | params_groups = get_params_groups(student) 29 | optimizer = SGD(params_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 30 | 31 | exp_lr_scheduler = lr_scheduler.CosineAnnealingLR( 32 | optimizer, 33 | T_max=args.epochs, 34 | eta_min=args.lr * 1e-3, 35 | ) 36 | 37 | 38 | distill_criterion = DistillLoss_ratio(num_classes=args.num_labeled_classes + args.num_unlabeled_classes, 39 | wait_ratio_epochs=args.wait_ratio_epochs, 40 | ramp_ratio_teacher_epochs=args.ramp_ratio_teacher_epochs, 41 | nepochs=args.epochs, 42 | ncrops=args.n_views, 43 | init_ratio=args.init_ratio, 44 | final_ratio=args.final_ratio, 45 | temp_logits=args.temp_logits, 46 | temp_teacher_logits=args.temp_teacher_logits, 47 | device=device) 48 | 49 | # inductive 50 | #best_test_acc_ubl = 0 51 | best_test_acc_lab = 0 52 | # transductive 53 | best_train_acc_lab = 0 54 | best_train_acc_ubl = 0 55 | best_train_acc_all = 0 56 | 57 | for epoch in range(args.epochs): 58 | loss_record = AverageMeter() 59 | 60 | student.train() 61 | for batch_idx, batch in enumerate(train_loader): 62 | images, class_labels, uq_idxs, mask_lab = batch 63 | mask_lab = mask_lab[:, 0] 64 | 65 | class_labels, mask_lab = class_labels.cuda(non_blocking=True), mask_lab.cuda(non_blocking=True).bool() 66 | images = torch.cat(images, dim=0).cuda(non_blocking=True) 67 | 68 | 69 | student_proj, student_out, prototypes = student(images) 70 | teacher_out = student_out.detach() 71 | 72 | # clustering, sup 73 | sup_logits = torch.cat([f[mask_lab] for f in (student_out / args.temp_logits).chunk(2)], dim=0) 74 | sup_labels = torch.cat([class_labels[mask_lab] for _ in range(2)], dim=0) 75 | cls_loss = nn.CrossEntropyLoss()(sup_logits, sup_labels) 76 | 77 | # clustering, unsup 78 | cluster_loss = 0 79 | distill_loss = distill_criterion(student_out, teacher_out, epoch) # NOTE!!! 
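
`DistillLoss_ratio` is constructed above with a schedule that waits `wait_ratio_epochs` and then ramps the confidence-filter ratio from `init_ratio` to `final_ratio` over `ramp_ratio_teacher_epochs`. Its actual implementation lives in `models/loss.py` (not shown in this excerpt); the sketch below is a hypothetical linear ramp consistent with those argument names, for intuition only. Note the teacher is simply the detached student output (`teacher_out = student_out.detach()`), sharpened with `temp_teacher_logits`, and the distillation term is computed on all data, labelled and unlabelled alike.

```python
# Hypothetical sketch of the filter-ratio schedule implied by the arguments above;
# the real schedule is defined in models/loss.py and may differ in detail.
def filter_ratio(epoch, wait_epochs=0, ramp_epochs=100, init_ratio=0.2, final_ratio=1.0):
    if epoch < wait_epochs:
        return init_ratio
    progress = min(1.0, (epoch - wait_epochs) / max(1, ramp_epochs))
    return init_ratio + progress * (final_ratio - init_ratio)

print([round(filter_ratio(e), 2) for e in (0, 50, 100, 150)])  # [0.2, 0.6, 1.0, 1.0]
```
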
all data 80 | cluster_loss += distill_loss 81 | 82 | entropy_reg_loss = entropy_regularization_loss(student_out, args.temp_logits) 83 | cluster_loss += args.weight_entropy_reg * entropy_reg_loss 84 | 85 | proto_sep_loss = prototype_separation_loss(prototypes=prototypes, temperature=args.temp_logits, device=device) 86 | cluster_loss += args.weight_proto_sep * proto_sep_loss 87 | 88 | # representation learning, unsup 89 | contrastive_logits, contrastive_labels = info_nce_logits(features=student_proj, temperature=args.temp_unsup_con) 90 | contrastive_loss = torch.nn.CrossEntropyLoss()(contrastive_logits, contrastive_labels) 91 | 92 | # representation learning, sup 93 | student_proj = torch.cat([f[mask_lab].unsqueeze(1) for f in student_proj.chunk(2)], dim=1) 94 | student_proj = F.normalize(student_proj, dim=-1) 95 | sup_con_labels = class_labels[mask_lab] 96 | sup_con_loss = SupConLoss(temperature=args.temp_sup_con)(student_proj, labels=sup_con_labels) 97 | 98 | pstr = '' 99 | pstr += f'cls_loss: {cls_loss.item():.4f} ' 100 | pstr += f'cluster_loss: {cluster_loss.item():.4f} ' 101 | pstr += f'distill_loss: {distill_loss.item():.4f} ' 102 | pstr += f'entropy_reg_loss: {entropy_reg_loss.item():.4f} ' 103 | pstr += f'proto_sep_loss: {proto_sep_loss.item():.4f} ' 104 | pstr += f'sup_con_loss: {sup_con_loss.item():.4f} ' 105 | pstr += f'contrastive_loss: {contrastive_loss.item():.4f} ' 106 | 107 | loss = 0 108 | loss += (1 - args.weight_sup) * cluster_loss + args.weight_sup * cls_loss 109 | loss += (1 - args.weight_sup) * contrastive_loss + args.weight_sup * sup_con_loss 110 | 111 | # Train acc 112 | loss_record.update(loss.item(), class_labels.size(0)) 113 | optimizer.zero_grad() 114 | loss.backward() 115 | optimizer.step() 116 | 117 | if batch_idx % args.print_freq == 0: 118 | args.logger.info('Epoch: [{}][{}/{}]\t loss {:.5f}\t {}' 119 | .format(epoch, batch_idx, len(train_loader), loss.item(), pstr)) 120 | 121 | args.logger.info('Train Epoch: {} Avg Loss: {:.4f} '.format(epoch, loss_record.avg)) 122 | 123 | args.logger.info('Testing on unlabelled examples in the training data...') 124 | all_acc, old_acc, new_acc = test(student, unlabelled_train_loader, epoch=epoch, save_name='Train ACC Unlabelled', args=args) 125 | args.logger.info('Testing on disjoint test set...') 126 | all_acc_test, old_acc_test, new_acc_test = test(student, test_loader, epoch=epoch, save_name='Test ACC', args=args) 127 | 128 | 129 | args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc)) 130 | args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test)) 131 | 132 | # Step schedule 133 | exp_lr_scheduler.step() 134 | 135 | save_dict = { 136 | 'model': student.state_dict(), 137 | 'optimizer': optimizer.state_dict(), 138 | 'epoch': epoch + 1, 139 | } 140 | 141 | torch.save(save_dict, args.model_path) 142 | args.logger.info("model saved to {}.".format(args.model_path)) 143 | 144 | #if new_acc_test > best_test_acc_ubl: 145 | #if old_acc_test > best_test_acc_lab and epoch > 100: 146 | if all_acc > best_train_acc_all: 147 | 148 | #args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...') 149 | args.logger.info(f'Best ACC on all Classes on train set: {all_acc:.4f}...') 150 | args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc)) 151 | 152 | torch.save(save_dict, args.model_path[:-3] + f'_best.pt') 153 | args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt'))
154 | 155 | # inductive 156 | #best_test_acc_ubl = new_acc_test 157 | best_test_acc_lab = old_acc_test 158 | # transductive 159 | best_train_acc_lab = old_acc 160 | best_train_acc_ubl = new_acc 161 | best_train_acc_all = all_acc 162 | 163 | args.logger.info(f'Exp Name: {args.exp_name}') 164 | args.logger.info(f'Metrics with best model on train set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}') 165 | 166 | 167 | def test(model, test_loader, epoch, save_name, args): 168 | 169 | model.eval() 170 | 171 | preds, targets = [], [] 172 | mask = np.array([]) 173 | for batch_idx, (images, label, _) in enumerate(tqdm(test_loader)): 174 | images = images.cuda(non_blocking=True) 175 | with torch.no_grad(): 176 | _, logits, _ = model(images) 177 | preds.append(logits.argmax(1).cpu().numpy()) 178 | targets.append(label.cpu().numpy()) 179 | mask = np.append(mask, np.array([True if x.item() in range(len(args.train_classes)) else False for x in label])) 180 | 181 | preds = np.concatenate(preds) 182 | targets = np.concatenate(targets) 183 | all_acc, old_acc, new_acc = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask, 184 | T=epoch, eval_funcs=args.eval_funcs, save_name=save_name, 185 | args=args) 186 | 187 | return all_acc, old_acc, new_acc 188 | 189 | 190 | if __name__ == "__main__": 191 | 192 | parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 193 | parser.add_argument('--batch_size', default=128, type=int) 194 | parser.add_argument('--num_workers', default=4, type=int) 195 | parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2', 'v2p']) 196 | 197 | parser.add_argument('--warmup_model_dir', type=str, default=None) 198 | parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aricraft, herbarium_19') 199 | parser.add_argument('--prop_train_labels', type=float, default=0.5) 200 | parser.add_argument('--use_ssb_splits', action='store_true', default=True) 201 | parser.add_argument('--init_prototypes', action='store_true', default=False) 202 | 203 | parser.add_argument('--grad_from_block', type=int, default=11) 204 | parser.add_argument('--lr', type=float, default=0.1) 205 | parser.add_argument('--gamma', type=float, default=0.1) 206 | parser.add_argument('--momentum', type=float, default=0.9) 207 | parser.add_argument('--weight_decay', type=float, default=1e-4) 208 | parser.add_argument('--epochs', default=200, type=int) 209 | parser.add_argument('--exp_root', type=str, default=exp_root) 210 | parser.add_argument('--transform', type=str, default='imagenet') 211 | parser.add_argument('--n_views', default=2, type=int) 212 | 213 | parser.add_argument('--weight_sup', type=float, default=0.35) 214 | parser.add_argument('--weight_entropy_reg', type=float, default=2) 215 | parser.add_argument('--weight_proto_sep', type=float, default=1) 216 | 217 | parser.add_argument('--temp_logits', default=0.1, type=float, help='cosine similarity of prototypes to classification logits temperature') 218 | parser.add_argument('--temp_teacher_logits', default=0.05, type=float, help='sharpened logits temperature of teacher') 219 | #parser.add_argument('--temp_proto_sep', default=0.1, type=float, help='prototype separation temperature') 220 | parser.add_argument('--temp_sup_con', default=0.07, type=float, help='supervised contrastive loss temperature') 
221 | parser.add_argument('--temp_unsup_con', default=1.0, type=float, help='unsupervised contrastive loss temperature') 222 | 223 | parser.add_argument('--wait_ratio_epochs', default=0, type=int, help='Number of warmup epochs for the confidence filter.') 224 | parser.add_argument('--ramp_ratio_teacher_epochs', default=100, type=int, help='Number of warmup epochs for the confidence filter.') 225 | 226 | parser.add_argument('--init_ratio', default=0.2, type=float, help='initial confidence filter ratio') 227 | parser.add_argument('--final_ratio', default=1.0, type=float, help='final confidence filter ratio') 228 | 229 | parser.add_argument('--print_freq', default=10, type=int) 230 | parser.add_argument('--exp_name', default=None, type=str) 231 | 232 | # ---------------------- 233 | # INIT 234 | # ---------------------- 235 | args = parser.parse_args() 236 | device = torch.device('cuda:0') 237 | args = get_class_splits(args) 238 | 239 | args.num_labeled_classes = len(args.train_classes) 240 | args.num_unlabeled_classes = len(args.unlabeled_classes) 241 | args.exp_root = 'dev_outputs_fix' 242 | 243 | init_experiment(args, runner_name=['ProtoGCD']) 244 | args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results') 245 | 246 | torch.backends.cudnn.benchmark = True 247 | 248 | # ---------------------- 249 | # BASE MODEL 250 | # ---------------------- 251 | args.interpolation = 3 252 | args.crop_pct = 0.875 253 | 254 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') 255 | 256 | if args.warmup_model_dir is not None: 257 | args.logger.info(f'Loading weights from {args.warmup_model_dir}') 258 | backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu')) 259 | 260 | # NOTE: Hardcoded image size as we do not finetune the entire ViT model 261 | args.image_size = 224 262 | args.feat_dim = 768 263 | args.num_mlp_layers = 3 264 | args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes 265 | 266 | # ---------------------- 267 | # HOW MUCH OF BASE MODEL TO FINETUNE 268 | # ---------------------- 269 | for m in backbone.parameters(): 270 | m.requires_grad = False 271 | 272 | # Only finetune layers from block 'args.grad_from_block' onwards 273 | for name, m in backbone.named_parameters(): 274 | if 'block' in name: 275 | block_num = int(name.split('.')[1]) 276 | if block_num >= args.grad_from_block: 277 | m.requires_grad = True 278 | 279 | 280 | args.logger.info('model build') 281 | 282 | # -------------------- 283 | # CONTRASTIVE TRANSFORM 284 | # -------------------- 285 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args) 286 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views) 287 | # -------------------- 288 | # DATASETS 289 | # -------------------- 290 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(args.dataset_name, 291 | train_transform, 292 | test_transform, 293 | args) 294 | 295 | # -------------------- 296 | # SAMPLER 297 | # Sampler which balances labelled and unlabelled examples in each batch 298 | # -------------------- 299 | label_len = len(train_dataset.labelled_dataset) 300 | unlabelled_len = len(train_dataset.unlabelled_dataset) 301 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))] 302 | sample_weights = torch.DoubleTensor(sample_weights) 303 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, 
num_samples=len(train_dataset)) 304 | 305 | # -------------------- 306 | # DATALOADERS 307 | # -------------------- 308 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False, 309 | sampler=sampler, drop_last=True, pin_memory=True) 310 | test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers, 311 | batch_size=256, shuffle=False, pin_memory=False) 312 | test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers, 313 | batch_size=256, shuffle=False, pin_memory=False) 314 | 315 | # -------------------- 316 | # Initialize prototypes 317 | # -------------------- 318 | prototypes_init = None 319 | if args.init_prototypes: 320 | prototype_init_path = './init_prototypes/%s_prototypes.pt' % args.dataset_name 321 | print('load initialized prototypes from: %s' % prototype_init_path) 322 | prototypes_init = torch.load(prototype_init_path) 323 | 324 | # ---------------------- 325 | # PROJECTION HEAD 326 | # ---------------------- 327 | projector = DINOHead(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers, 328 | init_prototypes=prototypes_init, num_labeled_classes=args.num_labeled_classes) 329 | model = nn.Sequential(backbone, projector).to(device) 330 | 331 | # ---------------------- 332 | # TRAIN 333 | # ---------------------- 334 | train(model, train_loader, test_loader_labelled, test_loader_unlabelled, args) 335 | --------------------------------------------------------------------------------
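
For reference, a worked example of the overall objective assembled in `train()` above: supervised and unsupervised terms are mixed by `weight_sup` (default 0.35) at both the classification level and the representation level. The per-term values below are made-up placeholders; in the real code, `cluster_loss` already folds in the distillation, entropy-regularization, and prototype-separation terms with their own weights.

```python
# Worked sketch of the loss combination in train(), default weight_sup = 0.35.
weight_sup = 0.35
cls_loss, cluster_loss = 1.2, 2.5          # supervised CE / unsupervised clustering terms
sup_con_loss, contrastive_loss = 0.8, 3.0  # supervised / unsupervised contrastive terms

loss = (1 - weight_sup) * cluster_loss + weight_sup * cls_loss \
     + (1 - weight_sup) * contrastive_loss + weight_sup * sup_con_loss
print(round(loss, 4))  # 0.65*2.5 + 0.35*1.2 + 0.65*3.0 + 0.35*0.8 = 4.275
```
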