├── LICENSE
├── README.md
├── assets
│   └── method.jpg
├── config.py
├── data
│   ├── augmentations
│   │   ├── __init__.py
│   │   ├── cut_out.py
│   │   └── randaugment.py
│   ├── cifar.py
│   ├── corrupt_data.py
│   ├── cub.py
│   ├── data_utils.py
│   ├── fgvc_aircraft.py
│   ├── get_datasets.py
│   ├── herbarium_19.py
│   ├── imagenet.py
│   └── stanford_cars.py
├── models
│   ├── __init__.py
│   ├── loss.py
│   ├── model.py
│   └── vision_transformer.py
├── my_utils
│   ├── __init__.py
│   ├── cluster_and_log_utils.py
│   ├── general_utils.py
│   └── ood_utils.py
├── test_ood_cifar.py
├── test_ood_imagenet.py
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Shijie Ma
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProtoGCD: Unified and Unbiased Prototype Learning for Generalized Category Discovery
2 |
3 |
4 |
5 | Official implementation of our TPAMI 2025 paper "ProtoGCD: Unified and Unbiased Prototype Learning for Generalized Category Discovery".
6 |
7 | 
8 |
9 | ## :running: Running
10 |
11 | ### Dependencies
12 |
13 | ```
14 | loguru
15 | numpy
16 | pandas
17 | scikit_learn
18 | scipy
19 | torch==1.10.0
20 | torchvision==0.11.1
21 | tqdm
22 | ```
23 |
24 | ### Datasets
25 |
26 | We conduct experiments on 7 datasets:
27 |
28 | * Generic datasets: CIFAR-10, CIFAR-100, ImageNet-100
29 | * Fine-grained datasets: [CUB](https://drive.google.com/drive/folders/1kFzIqZL_pEBVR7Ca_8IKibfWoeZc3GT1), [Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html), [FGVC-Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/), [Herbarium19](https://www.kaggle.com/c/herbarium-2019-fgvc6)
30 |
31 | ### Config
32 |
33 | Set the dataset roots (and the SSB split directory) in `config.py`.
34 |
35 | ### Training ProtoGCD
36 |
37 | CIFAR-100:
38 |
39 | ```shell
40 | CUDA_VISIBLE_DEVICES=0 python train.py --dataset_name 'cifar100' --batch_size 128 --epochs 200 --num_workers 4 --use_ssb_splits --weight_decay 5e-5 --lr 0.1 --eval_funcs 'v2' --weight_sup 0.35 --weight_entropy_reg 2 --weight_proto_sep 0.1 --temp_logits 0.1 --temp_teacher_logits 0.05 --wait_ratio_epochs 0 --ramp_ratio_teacher_epochs 100 --init_ratio 0.0 --final_ratio 1.0 --exp_name cifar100_protogcd
41 | ```
42 |
43 | CUB:
44 |
45 | ```shell
46 | CUDA_VISIBLE_DEVICES=0 python train.py --dataset_name 'cub' --batch_size 128 --epochs 200 --num_workers 2 --use_ssb_splits --weight_decay 5e-5 --lr 0.1 --eval_funcs 'v2' --weight_sup 0.35 --weight_entropy_reg 2 --weight_proto_sep 0.05 --temp_logits 0.1 --temp_teacher_logits 0.05 --wait_ratio_epochs 0 --ramp_ratio_teacher_epochs 100 --init_ratio 0.0 --final_ratio 1.0 --exp_name cub_protogcd
47 | ```
48 |
49 | ### Evaluate OOD detection
50 |
51 | CIFAR:
52 |
53 | ```shell
54 | CUDA_VISIBLE_DEVICES=0 python test_ood_cifar.py --dataset_name 'cifar100' --batch_size 128 --num_workers 4 --use_ssb_splits --num_to_avg 10 --score msp --ckpts_date YOUR_CKPTS_NAME --temp_logits 0.1
55 | ```
56 |
57 | ImageNet:
58 |
59 | ```shell
60 | CUDA_VISIBLE_DEVICES=0 python test_ood_imagenet.py --dataset_name 'imagenet_100' --batch_size 128 --num_workers 4 --use_ssb_splits --num_to_avg 10 --score msp --ckpts_date YOUR_CKPTS_NAME --temp_logits 0.1
61 | ```
62 |
63 |
64 |
65 | ## :clipboard: Citing this work
66 |
67 | ```bibtex
68 | @ARTICLE{10948388,
69 | author={Ma, Shijie and Zhu, Fei and Zhang, Xu-Yao and Liu, Cheng-Lin},
70 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
71 | title={ProtoGCD: Unified and Unbiased Prototype Learning for Generalized Category Discovery},
72 | year={2025},
73 | volume={},
74 | number={},
75 | pages={1-17},
76 | keywords={Prototypes;Adaptation models;Contrastive learning;Training;Magnetic heads;Feature extraction;Estimation;Automobiles;Accuracy;Pragmatics;Generalized category discovery;open-world learning;prototype learning;semi-supervised learning},
77 | doi={10.1109/TPAMI.2025.3557502}
78 | }
79 | ```
80 |
81 |
82 |
83 | ## :gift: Acknowledgements
84 |
85 | In building the ProtoGCD codebase, we referred to [SimGCD](https://github.com/CVMI-Lab/SimGCD).
86 |
87 |
88 |
89 | ## :white_check_mark: License
90 |
91 | This project is licensed under the MIT License - see the [LICENSE](https://github.com/mashijie1028/ProtoGCD/blob/main/LICENSE) file for details.
92 |
93 |
94 |
95 | ## :email: Contact
96 |
97 | If you have further questions or would like to discuss, feel free to contact me:
98 |
99 | Shijie Ma (mashijie2021@ia.ac.cn)
--------------------------------------------------------------------------------
/assets/method.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mashijie1028/ProtoGCD/8835a4d24662c65be125a42d815b52f62ae1482e/assets/method.jpg
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # -----------------
2 | # DATASET ROOTS
3 | # -----------------
4 | cifar_10_root = '/data4/sjma/dataset/CIFAR/'
5 | cifar_100_root = '/data4/sjma/dataset/CIFAR/'
6 | cub_root = '/data4/sjma/dataset/CUB/'
7 | aircraft_root = '/data4/sjma/dataset/FGVC-Aircraft/fgvc-aircraft-2013b/'
8 | car_root = "/data4/sjma/dataset/Stanford-Cars/"
9 | herbarium_dataroot = '/data4/sjma/dataset/Herbarium19-Small/'
10 | #imagenet_root = '/lustre/datasharing/sjma/ImageNet/ILSVRC2012/imagenet/'
11 | imagenet_root = '/data4/sjma/dataset/ImageNet/ILSVRC2012/imagenet/'
12 |
13 | # OSR Split dir
14 | osr_split_dir = '/data4/sjma/dataset/ssb_splits/'
15 |
16 | # -----------------
17 | # OTHER PATHS
18 | # -----------------
19 | exp_root = 'dev_outputs' # All logs and checkpoints will be saved here
20 |
--------------------------------------------------------------------------------
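
The dataset modules import these roots at module load (e.g. `from config import cifar_10_root` in `data/cifar.py` below), so a quick check that the configured paths resolve can save a failed run. A minimal sketch, not part of the repo:

```python
# Sanity-check sketch (hypothetical helper, not in the repo): verify that the
# dataset roots configured in config.py actually exist on this machine.
import os
import config

for name in ('cifar_10_root', 'cub_root', 'aircraft_root', 'osr_split_dir'):
    path = getattr(config, name)
    print(f"{name}: {path} -> exists={os.path.isdir(path)}")
```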
/data/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from data.augmentations.cut_out import *
3 | from data.augmentations.randaugment import RandAugment
4 |
5 | def get_transform(transform_type='imagenet', image_size=32, args=None):
6 |
7 | if transform_type == 'imagenet':
8 |
9 | mean = (0.485, 0.456, 0.406)
10 | std = (0.229, 0.224, 0.225)
11 | interpolation = args.interpolation
12 | crop_pct = args.crop_pct
13 |
14 | train_transform = transforms.Compose([
15 | transforms.Resize(int(image_size / crop_pct), interpolation),
16 | transforms.RandomCrop(image_size),
17 | transforms.RandomHorizontalFlip(p=0.5),
18 | transforms.ColorJitter(),
19 | transforms.ToTensor(),
20 | transforms.Normalize(
21 | mean=torch.tensor(mean),
22 | std=torch.tensor(std))
23 | ])
24 |
25 | test_transform = transforms.Compose([
26 | transforms.Resize(int(image_size / crop_pct), interpolation),
27 | transforms.CenterCrop(image_size),
28 | transforms.ToTensor(),
29 | transforms.Normalize(
30 | mean=torch.tensor(mean),
31 | std=torch.tensor(std))
32 | ])
33 |
34 | elif transform_type == 'pytorch-cifar':
35 |
36 | mean = (0.4914, 0.4822, 0.4465)
37 | std = (0.2023, 0.1994, 0.2010)
38 |
39 | train_transform = transforms.Compose([
40 | transforms.RandomCrop(image_size, padding=4),
41 | transforms.RandomHorizontalFlip(),
42 | transforms.ToTensor(),
43 | transforms.Normalize(mean=mean, std=std),
44 | ])
45 |
46 | test_transform = transforms.Compose([
47 | transforms.Resize((image_size, image_size)),
48 | transforms.ToTensor(),
49 | transforms.Normalize(mean=mean, std=std),
50 | ])
51 |
52 | elif transform_type == 'herbarium_default':
53 |
54 | train_transform = transforms.Compose([
55 | transforms.Resize((image_size, image_size)),
56 | transforms.RandomResizedCrop(image_size, scale=(args.resize_lower_bound, 1)),
57 | transforms.RandomHorizontalFlip(),
58 | transforms.ToTensor(),
59 | ])
60 |
61 | test_transform = transforms.Compose([
62 | transforms.Resize((image_size, image_size)),
63 | transforms.ToTensor(),
64 | ])
65 |
66 | elif transform_type == 'cutout':
67 |
68 | mean = np.array([0.4914, 0.4822, 0.4465])
69 | std = np.array([0.2470, 0.2435, 0.2616])
70 |
71 | train_transform = transforms.Compose([
72 | transforms.RandomCrop(image_size, padding=4),
73 | transforms.RandomHorizontalFlip(),
74 | normalize(mean, std),
75 | cutout(mask_size=int(image_size / 2),
76 | p=1,
77 | cutout_inside=False),
78 | to_tensor(),
79 | ])
80 | test_transform = transforms.Compose([
81 | transforms.Resize((image_size, image_size)),
82 | transforms.ToTensor(),
83 | transforms.Normalize(mean, std),
84 | ])
85 |
86 | elif transform_type == 'rand-augment':
87 |
88 | mean = (0.485, 0.456, 0.406)
89 | std = (0.229, 0.224, 0.225)
90 |
91 | train_transform = transforms.Compose([
92 | transforms.Resize((image_size, image_size)),
93 | transforms.RandomCrop(image_size, padding=4),
94 | transforms.RandomHorizontalFlip(),
95 | transforms.ToTensor(),
96 | transforms.Normalize(mean=mean, std=std),
97 | ])
98 |
99 | train_transform.transforms.insert(0, RandAugment(args.rand_aug_n, args.rand_aug_m, args=None))
100 |
101 | test_transform = transforms.Compose([
102 | transforms.Resize((image_size, image_size)),
103 | transforms.ToTensor(),
104 | transforms.Normalize(mean=mean, std=std),
105 | ])
106 |
107 | elif transform_type == 'random_affine':
108 |
109 | mean = (0.485, 0.456, 0.406)
110 | std = (0.229, 0.224, 0.225)
111 | interpolation = args.interpolation
112 | crop_pct = args.crop_pct
113 |
114 | train_transform = transforms.Compose([
115 | transforms.Resize((image_size, image_size), interpolation),
116 | transforms.RandomAffine(degrees=(-45, 45),
117 | translate=(0.1, 0.1), shear=(-15, 15), scale=(0.7, args.crop_pct)),
118 | transforms.ColorJitter(),
119 | transforms.ToTensor(),
120 | transforms.Normalize(
121 | mean=torch.tensor(mean),
122 | std=torch.tensor(std))
123 | ])
124 |
125 | test_transform = transforms.Compose([
126 | transforms.Resize(int(image_size / crop_pct), interpolation),
127 | transforms.CenterCrop(image_size),
128 | transforms.ToTensor(),
129 | transforms.Normalize(
130 | mean=torch.tensor(mean),
131 | std=torch.tensor(std))
132 | ])
133 |
134 | else:
135 |
136 | raise NotImplementedError
137 |
138 | return (train_transform, test_transform)
--------------------------------------------------------------------------------
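
A hedged usage sketch for `get_transform` above. The `'imagenet'` branch reads `args.interpolation` and `args.crop_pct`; the values below (PIL bicubic, 0.875 crop ratio) are illustrative assumptions, not defaults taken from this file:

```python
from argparse import Namespace

from data.augmentations import get_transform

# interpolation=3 corresponds to PIL bicubic; crop_pct=0.875 is a common
# ViT evaluation crop ratio. Both values are assumptions for this sketch.
args = Namespace(interpolation=3, crop_pct=0.875)
train_transform, test_transform = get_transform('imagenet', image_size=224, args=args)
print(train_transform)
```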
/data/augmentations/cut_out.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/hysts/pytorch_cutout
3 | """
4 |
5 | import torch
6 | import numpy as np
7 |
8 | def cutout(mask_size, p, cutout_inside, mask_color=(0, 0, 0)):
9 | mask_size_half = mask_size // 2
10 | offset = 1 if mask_size % 2 == 0 else 0
11 |
12 | def _cutout(image):
13 | image = np.asarray(image).copy()
14 |
15 | if np.random.random() > p:
16 | return image
17 |
18 | h, w = image.shape[:2]
19 |
20 | if cutout_inside:
21 | cxmin, cxmax = mask_size_half, w + offset - mask_size_half
22 | cymin, cymax = mask_size_half, h + offset - mask_size_half
23 | else:
24 | cxmin, cxmax = 0, w + offset
25 | cymin, cymax = 0, h + offset
26 |
27 | cx = np.random.randint(cxmin, cxmax)
28 | cy = np.random.randint(cymin, cymax)
29 | xmin = cx - mask_size_half
30 | ymin = cy - mask_size_half
31 | xmax = xmin + mask_size
32 | ymax = ymin + mask_size
33 | xmin = max(0, xmin)
34 | ymin = max(0, ymin)
35 | xmax = min(w, xmax)
36 | ymax = min(h, ymax)
37 | image[ymin:ymax, xmin:xmax] = mask_color
38 | return image
39 |
40 | return _cutout
41 |
42 | def to_tensor():
43 | def _to_tensor(image):
44 | if len(image.shape) == 3:
45 | return torch.from_numpy(
46 | image.transpose(2, 0, 1).astype(float))
47 | else:
48 | return torch.from_numpy(image[None, :, :].astype(float))
49 |
50 | return _to_tensor
51 |
52 | def normalize(mean, std):
53 |
54 | mean = np.array(mean)
55 | std = np.array(std)
56 |
57 | def _normalize(image):
58 | image = np.asarray(image).astype(float) / 255.
59 | image = (image - mean) / std
60 | return image
61 |
62 | return _normalize
63 |
--------------------------------------------------------------------------------
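
A small sketch composing the function-style transforms above. The order matters: `normalize` converts the PIL image to a float ndarray, `cutout` masks a square region, and `to_tensor` transposes HWC to CHW:

```python
from PIL import Image
from torchvision import transforms

from data.augmentations.cut_out import cutout, normalize, to_tensor

t = transforms.Compose([
    normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)),
    cutout(mask_size=16, p=1.0, cutout_inside=False),  # always mask a 16x16 square
    to_tensor(),
])
x = t(Image.new('RGB', (32, 32)))  # stand-in image; use a real sample in practice
print(x.shape)  # torch.Size([3, 32, 32])
```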
/data/augmentations/randaugment.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
3 | """
4 |
5 | import random
6 |
7 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
8 | import numpy as np
9 | import torch
10 | from PIL import Image
11 |
12 |
13 | def ShearX(img, v): # [-0.3, 0.3]
14 | assert -0.3 <= v <= 0.3
15 | if random.random() > 0.5:
16 | v = -v
17 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
18 |
19 |
20 | def ShearY(img, v): # [-0.3, 0.3]
21 | assert -0.3 <= v <= 0.3
22 | if random.random() > 0.5:
23 | v = -v
24 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
25 |
26 |
27 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
28 | assert -0.45 <= v <= 0.45
29 | if random.random() > 0.5:
30 | v = -v
31 | v = v * img.size[0]
32 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
33 |
34 |
35 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
36 | assert 0 <= v
37 | if random.random() > 0.5:
38 | v = -v
39 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
40 |
41 |
42 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
43 | assert -0.45 <= v <= 0.45
44 | if random.random() > 0.5:
45 | v = -v
46 | v = v * img.size[1]
47 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
48 |
49 |
50 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
51 | assert 0 <= v
52 | if random.random() > 0.5:
53 | v = -v
54 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
55 |
56 |
57 | def Rotate(img, v): # [-30, 30]
58 | assert -30 <= v <= 30
59 | if random.random() > 0.5:
60 | v = -v
61 | return img.rotate(v)
62 |
63 |
64 | def AutoContrast(img, _):
65 | return PIL.ImageOps.autocontrast(img)
66 |
67 |
68 | def Invert(img, _):
69 | return PIL.ImageOps.invert(img)
70 |
71 |
72 | def Equalize(img, _):
73 | return PIL.ImageOps.equalize(img)
74 |
75 |
76 | def Flip(img, _): # not from the paper
77 | return PIL.ImageOps.mirror(img)
78 |
79 |
80 | def Solarize(img, v): # [0, 256]
81 | assert 0 <= v <= 256
82 | return PIL.ImageOps.solarize(img, v)
83 |
84 |
85 | def SolarizeAdd(img, addition=0, threshold=128):
86 |     img_np = np.array(img).astype(int)  # np.int was removed in NumPy >= 1.24
87 | img_np = img_np + addition
88 | img_np = np.clip(img_np, 0, 255)
89 | img_np = img_np.astype(np.uint8)
90 | img = Image.fromarray(img_np)
91 | return PIL.ImageOps.solarize(img, threshold)
92 |
93 |
94 | def Posterize(img, v): # [4, 8]
95 | v = int(v)
96 | v = max(1, v)
97 | return PIL.ImageOps.posterize(img, v)
98 |
99 |
100 | def Contrast(img, v): # [0.1,1.9]
101 | assert 0.1 <= v <= 1.9
102 | return PIL.ImageEnhance.Contrast(img).enhance(v)
103 |
104 |
105 | def Color(img, v): # [0.1,1.9]
106 | assert 0.1 <= v <= 1.9
107 | return PIL.ImageEnhance.Color(img).enhance(v)
108 |
109 |
110 | def Brightness(img, v): # [0.1,1.9]
111 | assert 0.1 <= v <= 1.9
112 | return PIL.ImageEnhance.Brightness(img).enhance(v)
113 |
114 |
115 | def Sharpness(img, v): # [0.1,1.9]
116 | assert 0.1 <= v <= 1.9
117 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
118 |
119 |
120 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
121 | assert 0.0 <= v <= 0.2
122 | if v <= 0.:
123 | return img
124 |
125 | v = v * img.size[0]
126 | return CutoutAbs(img, v)
127 |
128 |
129 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
130 | # assert 0 <= v <= 20
131 | if v < 0:
132 | return img
133 | w, h = img.size
134 | x0 = np.random.uniform(w)
135 | y0 = np.random.uniform(h)
136 |
137 | x0 = int(max(0, x0 - v / 2.))
138 | y0 = int(max(0, y0 - v / 2.))
139 | x1 = min(w, x0 + v)
140 | y1 = min(h, y0 + v)
141 |
142 | xy = (x0, y0, x1, y1)
143 | color = (125, 123, 114)
144 | # color = (0, 0, 0)
145 | img = img.copy()
146 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
147 | return img
148 |
149 |
150 | def SamplePairing(imgs): # [0, 0.4]
151 | def f(img1, v):
152 | i = np.random.choice(len(imgs))
153 | img2 = PIL.Image.fromarray(imgs[i])
154 | return PIL.Image.blend(img1, img2, v)
155 |
156 | return f
157 |
158 |
159 | def Identity(img, v):
160 | return img
161 |
162 |
163 | def augment_list():  # 16 operations and their ranges
164 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
165 | # l = [
166 | # (Identity, 0., 1.0),
167 | # (ShearX, 0., 0.3), # 0
168 | # (ShearY, 0., 0.3), # 1
169 | # (TranslateX, 0., 0.33), # 2
170 | # (TranslateY, 0., 0.33), # 3
171 | # (Rotate, 0, 30), # 4
172 | # (AutoContrast, 0, 1), # 5
173 | # (Invert, 0, 1), # 6
174 | # (Equalize, 0, 1), # 7
175 | # (Solarize, 0, 110), # 8
176 | # (Posterize, 4, 8), # 9
177 | # # (Contrast, 0.1, 1.9), # 10
178 | # (Color, 0.1, 1.9), # 11
179 | # (Brightness, 0.1, 1.9), # 12
180 | # (Sharpness, 0.1, 1.9), # 13
181 | # # (Cutout, 0, 0.2), # 14
182 | # # (SamplePairing(imgs), 0, 0.4), # 15
183 | # ]
184 |
185 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
186 | l = [
187 | (AutoContrast, 0, 1),
188 | (Equalize, 0, 1),
189 | (Invert, 0, 1),
190 | (Rotate, 0, 30),
191 | (Posterize, 0, 4),
192 | (Solarize, 0, 256),
193 | (SolarizeAdd, 0, 110),
194 | (Color, 0.1, 1.9),
195 | (Contrast, 0.1, 1.9),
196 | (Brightness, 0.1, 1.9),
197 | (Sharpness, 0.1, 1.9),
198 | (ShearX, 0., 0.3),
199 | (ShearY, 0., 0.3),
200 | (CutoutAbs, 0, 40),
201 | (TranslateXabs, 0., 100),
202 | (TranslateYabs, 0., 100),
203 | ]
204 |
205 | return l
206 |
207 | def augment_list_svhn():  # 13 operations and their ranges (no geometric translations or rotation)
208 |
209 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
210 | l = [
211 | (AutoContrast, 0, 1),
212 | (Equalize, 0, 1),
213 | (Invert, 0, 1),
214 | (Posterize, 0, 4),
215 | (Solarize, 0, 256),
216 | (SolarizeAdd, 0, 110),
217 | (Color, 0.1, 1.9),
218 | (Contrast, 0.1, 1.9),
219 | (Brightness, 0.1, 1.9),
220 | (Sharpness, 0.1, 1.9),
221 | (ShearX, 0., 0.3),
222 | (ShearY, 0., 0.3),
223 | (CutoutAbs, 0, 40),
224 | ]
225 |
226 | return l
227 |
228 |
229 | class Lighting(object):
230 | """Lighting noise(AlexNet - style PCA - based noise)"""
231 |
232 | def __init__(self, alphastd, eigval, eigvec):
233 | self.alphastd = alphastd
234 | self.eigval = torch.Tensor(eigval)
235 | self.eigvec = torch.Tensor(eigvec)
236 |
237 | def __call__(self, img):
238 | if self.alphastd == 0:
239 | return img
240 |
241 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
242 | rgb = self.eigvec.type_as(img).clone() \
243 | .mul(alpha.view(1, 3).expand(3, 3)) \
244 | .mul(self.eigval.view(1, 3).expand(3, 3)) \
245 | .sum(1).squeeze()
246 |
247 | return img.add(rgb.view(3, 1, 1).expand_as(img))
248 |
249 |
250 | class CutoutDefault(object):
251 | """
252 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
253 | """
254 | def __init__(self, length):
255 | self.length = length
256 |
257 | def __call__(self, img):
258 | h, w = img.size(1), img.size(2)
259 | mask = np.ones((h, w), float)
260 | y = np.random.randint(h)
261 | x = np.random.randint(w)
262 |
263 | y1 = np.clip(y - self.length // 2, 0, h)
264 | y2 = np.clip(y + self.length // 2, 0, h)
265 | x1 = np.clip(x - self.length // 2, 0, w)
266 | x2 = np.clip(x + self.length // 2, 0, w)
267 |
268 | mask[y1: y2, x1: x2] = 0.
269 | mask = torch.from_numpy(mask)
270 | mask = mask.expand_as(img)
271 | img *= mask
272 | return img
273 |
274 |
275 | class RandAugment:
276 | def __init__(self, n, m, args=None):
277 | self.n = n # [1, 2]
278 | self.m = m # [0...30]
279 |
280 | if args is None:
281 | self.augment_list = augment_list()
282 |
283 | elif args.dataset == 'svhn' or args.dataset == 'mnist':
284 | self.augment_list = augment_list_svhn()
285 |
286 | else:
287 | self.augment_list = augment_list()
288 |
289 | def __call__(self, img):
290 | ops = random.choices(self.augment_list, k=self.n)
291 | for op, minval, maxval in ops:
292 | val = (float(self.m) / 30) * float(maxval - minval) + minval
293 | img = op(img, val)
294 |
295 | return img
296 |
--------------------------------------------------------------------------------
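
A minimal sketch applying `RandAugment` above to a PIL image; `n` ops are sampled per call and each magnitude is scaled by `m / 30` into the op's range:

```python
from PIL import Image

from data.augmentations.randaugment import RandAugment

augmenter = RandAugment(n=2, m=10)   # 2 random ops per image, magnitude 10 out of 30
img = Image.new('RGB', (224, 224))   # stand-in image; load a real one in practice
aug_img = augmenter(img)
print(aug_img.size)
```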
/data/cifar.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import CIFAR10, CIFAR100
2 | from copy import deepcopy
3 | import numpy as np
4 |
5 | from data.data_utils import subsample_instances
6 | from config import cifar_10_root, cifar_100_root
7 |
8 |
9 | class CustomCIFAR10(CIFAR10):
10 |
11 | def __init__(self, *args, **kwargs):
12 |
13 | super(CustomCIFAR10, self).__init__(*args, **kwargs)
14 |
15 | self.uq_idxs = np.array(range(len(self)))
16 |
17 | def __getitem__(self, item):
18 |
19 | img, label = super().__getitem__(item)
20 | uq_idx = self.uq_idxs[item]
21 |
22 | return img, label, uq_idx
23 |
24 | def __len__(self):
25 | return len(self.targets)
26 |
27 |
28 | class CustomCIFAR100(CIFAR100):
29 |
30 | def __init__(self, *args, **kwargs):
31 | super(CustomCIFAR100, self).__init__(*args, **kwargs)
32 |
33 | self.uq_idxs = np.array(range(len(self)))
34 |
35 | def __getitem__(self, item):
36 | img, label = super().__getitem__(item)
37 | uq_idx = self.uq_idxs[item]
38 |
39 | return img, label, uq_idx
40 |
41 | def __len__(self):
42 | return len(self.targets)
43 |
44 |
45 | def subsample_dataset(dataset, idxs):
46 |
47 |     # Allow for the setting in which an empty set of indices is passed
48 |
49 | if len(idxs) > 0:
50 |
51 | dataset.data = dataset.data[idxs]
52 | dataset.targets = np.array(dataset.targets)[idxs].tolist()
53 | dataset.uq_idxs = dataset.uq_idxs[idxs]
54 |
55 | return dataset
56 |
57 | else:
58 |
59 | return None
60 |
61 |
62 | def subsample_classes(dataset, include_classes=(0, 1, 8, 9)):
63 |
64 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]
65 |
66 | target_xform_dict = {}
67 | for i, k in enumerate(include_classes):
68 | target_xform_dict[k] = i
69 |
70 | dataset = subsample_dataset(dataset, cls_idxs)
71 |
72 | # dataset.target_transform = lambda x: target_xform_dict[x]
73 |
74 | return dataset
75 |
76 |
77 | def get_train_val_indices(train_dataset, val_split=0.2):
78 |
79 | train_classes = np.unique(train_dataset.targets)
80 |
81 | # Get train/test indices
82 | train_idxs = []
83 | val_idxs = []
84 | for cls in train_classes:
85 |
86 | cls_idxs = np.where(train_dataset.targets == cls)[0]
87 |
88 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
89 | t_ = [x for x in cls_idxs if x not in v_]
90 |
91 | train_idxs.extend(t_)
92 | val_idxs.extend(v_)
93 |
94 | return train_idxs, val_idxs
95 |
96 |
97 | def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9),
98 | prop_train_labels=0.8, split_train_val=False, seed=0):
99 |
100 | np.random.seed(seed)
101 |
102 | # Init entire training set
103 | whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True)
104 |
105 | # Get labelled training set which has subsampled classes, then subsample some indices from that
106 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
107 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
108 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
109 |
110 | # Split into training and validation sets
111 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
112 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
113 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
114 | val_dataset_labelled_split.transform = test_transform
115 |
116 | # Get unlabelled data
117 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
118 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
119 |
120 | # Get test set for all classes
121 | test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False)
122 |
123 | # Either split train into train and val or use test set as val
124 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
125 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
126 |
127 | all_datasets = {
128 | 'train_labelled': train_dataset_labelled,
129 | 'train_unlabelled': train_dataset_unlabelled,
130 | 'val': val_dataset_labelled,
131 | 'test': test_dataset,
132 | }
133 |
134 | return all_datasets
135 |
136 |
137 | def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80),
138 | prop_train_labels=0.8, split_train_val=False, seed=0):
139 |
140 | np.random.seed(seed)
141 |
142 | # Init entire training set
143 | whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True)
144 |
145 | # Get labelled training set which has subsampled classes, then subsample some indices from that
146 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
147 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
148 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
149 |
150 | # Split into training and validation sets
151 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
152 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
153 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
154 | val_dataset_labelled_split.transform = test_transform
155 |
156 | # Get unlabelled data
157 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
158 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
159 |
160 | # Get test set for all classes
161 | test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False)
162 |
163 | # Either split train into train and val or use test set as val
164 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
165 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
166 |
167 | all_datasets = {
168 | 'train_labelled': train_dataset_labelled,
169 | 'train_unlabelled': train_dataset_unlabelled,
170 | 'val': val_dataset_labelled,
171 | 'test': test_dataset,
172 | }
173 |
174 | return all_datasets
175 |
176 |
177 | if __name__ == '__main__':
178 |
179 | x = get_cifar_100_datasets(None, None, split_train_val=False,
180 | train_classes=range(80), prop_train_labels=0.5)
181 |
182 | print('Printing lens...')
183 | for k, v in x.items():
184 | if v is not None:
185 | print(f'{k}: {len(v)}')
186 |
187 | print('Printing labelled and unlabelled overlap...')
188 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
189 | print('Printing total instances in train...')
190 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
191 |
192 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
193 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
194 | print(f'Len labelled set: {len(x["train_labelled"])}')
195 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
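
A usage sketch mirroring the `__main__` demo above, but with real transforms (assumes CIFAR-100 already sits at `cifar_100_root` from `config.py`):

```python
from data.augmentations import get_transform
from data.cifar import get_cifar_100_datasets

# The 'pytorch-cifar' branch of get_transform ignores args, so none are needed.
train_t, test_t = get_transform('pytorch-cifar', image_size=32, args=None)
datasets = get_cifar_100_datasets(train_t, test_t, train_classes=range(80),
                                  prop_train_labels=0.5, split_train_val=False)
print({k: len(v) for k, v in datasets.items() if v is not None})
```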
/data/corrupt_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | import torch
5 | from torch.utils.data import Dataset
6 |
7 | import torchvision.transforms.functional as tfunc
8 |
9 | ROOT_DIR_10 = "/data4/sjma/dataset/Corruptions/CIFAR-10-C/"
10 | ROOT_DIR_100 = "/data4/sjma/dataset/Corruptions/CIFAR-100-C/"
11 |
12 |
13 | class DatasetFromTorchTensor(Dataset):
14 | def __init__(self, data, target, transform=None):
15 | # Data type handling must be done beforehand. It is too difficult at this point.
16 | self.data = data
17 | self.target = target
18 | if len(self.target.shape)==1:
19 | self.target = target.long()
20 | self.transform = transform
21 |
22 | def __getitem__(self, index):
23 | x = self.data[index]
24 | y = self.target[index]
25 | if self.transform:
26 | x = tfunc.to_pil_image(x)
27 | x = self.transform(x)
28 | return x, y
29 |
30 | def __len__(self):
31 | return len(self.data)
32 |
33 |
34 | def get_data(data_name, dataset, test_transform=None, severity=1):
35 | if data_name == 'cifar10':
36 | ROOT_DIR = ROOT_DIR_10
37 |     elif data_name == 'cifar100':
38 |         ROOT_DIR = ROOT_DIR_100
39 | data_path = os.path.join(ROOT_DIR, dataset+'.npy')
40 | label_path = os.path.join(ROOT_DIR, 'labels.npy')
41 | data = torch.tensor(np.transpose(np.load(data_path), (0,3,1,2)))
42 | labels = torch.tensor(np.load(label_path))
43 | start = 10000 * (severity - 1)
44 |
45 | data = data[start:start+10000]
46 | labels = labels[start:start+10000]
47 | test_data = DatasetFromTorchTensor(data, labels, transform=test_transform)
48 |
49 | return test_data
50 |
51 |
52 |
53 | if __name__ == '__main__':
54 |     test = get_data('cifar10', 'snow')  # get_data returns a single test dataset
55 |     print(len(test))
56 |
--------------------------------------------------------------------------------
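
A sketch of loading one corruption split (assumes the CIFAR-10-C `.npy` files sit under `ROOT_DIR_10` above; severity runs from 1 to 5, each slice holding 10,000 images):

```python
from data.corrupt_data import get_data

# 'snow' is one of the CIFAR-C corruption names; severity 3 selects
# images 20000..29999 of the stacked .npy file.
test_data = get_data('cifar10', 'snow', test_transform=None, severity=3)
x, y = test_data[0]
print(len(test_data), x.shape, y)
```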
/data/cub.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | from copy import deepcopy
5 |
6 | from torchvision.datasets.folder import default_loader
7 | from torchvision.datasets.utils import download_url
8 | from torch.utils.data import Dataset
9 |
10 | from data.data_utils import subsample_instances
11 | from config import cub_root
12 |
13 |
14 | class CustomCub2011(Dataset):
15 | base_folder = 'CUB_200_2011/images'
16 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
17 | filename = 'CUB_200_2011.tgz'
18 | tgz_md5 = '97eceeb196236b17998738112f37df78'
19 |
20 | def __init__(self, root, train=True, transform=None, target_transform=None, loader=default_loader, download=False):
21 |
22 | self.root = os.path.expanduser(root)
23 | self.transform = transform
24 | self.target_transform = target_transform
25 |
26 | self.loader = loader
27 | self.train = train
28 |
29 |
30 | if download:
31 | self._download()
32 |
33 | if not self._check_integrity():
34 | raise RuntimeError('Dataset not found or corrupted.' +
35 | ' You can use download=True to download it')
36 |
37 | self.uq_idxs = np.array(range(len(self)))
38 |
39 | def _load_metadata(self):
40 | images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
41 | names=['img_id', 'filepath'])
42 | image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
43 | sep=' ', names=['img_id', 'target'])
44 | train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
45 | sep=' ', names=['img_id', 'is_training_img'])
46 |
47 | data = images.merge(image_class_labels, on='img_id')
48 | self.data = data.merge(train_test_split, on='img_id')
49 |
50 | if self.train:
51 | self.data = self.data[self.data.is_training_img == 1]
52 | else:
53 | self.data = self.data[self.data.is_training_img == 0]
54 |
55 | def _check_integrity(self):
56 | try:
57 | self._load_metadata()
58 | except Exception:
59 | return False
60 |
61 | for index, row in self.data.iterrows():
62 | filepath = os.path.join(self.root, self.base_folder, row.filepath)
63 | if not os.path.isfile(filepath):
64 | print(filepath)
65 | return False
66 | return True
67 |
68 | def _download(self):
69 | import tarfile
70 |
71 | if self._check_integrity():
72 | print('Files already downloaded and verified')
73 | return
74 |
75 | download_url(self.url, self.root, self.filename, self.tgz_md5)
76 |
77 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
78 | tar.extractall(path=self.root)
79 |
80 | def __len__(self):
81 | return len(self.data)
82 |
83 | def __getitem__(self, idx):
84 | sample = self.data.iloc[idx]
85 | path = os.path.join(self.root, self.base_folder, sample.filepath)
86 | target = sample.target - 1 # Targets start at 1 by default, so shift to 0
87 | img = self.loader(path)
88 |
89 | if self.transform is not None:
90 | img = self.transform(img)
91 |
92 | if self.target_transform is not None:
93 | target = self.target_transform(target)
94 |
95 | return img, target, self.uq_idxs[idx]
96 |
97 |
98 | def subsample_dataset(dataset, idxs):
99 |
100 | mask = np.zeros(len(dataset)).astype('bool')
101 | mask[idxs] = True
102 |
103 | dataset.data = dataset.data[mask]
104 | dataset.uq_idxs = dataset.uq_idxs[mask]
105 |
106 | return dataset
107 |
108 |
109 | def subsample_classes(dataset, include_classes=range(160)):
110 |
111 | include_classes_cub = np.array(include_classes) + 1 # CUB classes are indexed 1 --> 200 instead of 0 --> 199
112 | cls_idxs = [x for x, (_, r) in enumerate(dataset.data.iterrows()) if int(r['target']) in include_classes_cub]
113 |
114 | # TODO: For now have no target transform
115 | target_xform_dict = {}
116 | for i, k in enumerate(include_classes):
117 | target_xform_dict[k] = i
118 |
119 | dataset = subsample_dataset(dataset, cls_idxs)
120 |
121 | dataset.target_transform = lambda x: target_xform_dict[x]
122 |
123 | return dataset
124 |
125 |
126 | def get_train_val_indices(train_dataset, val_split=0.2):
127 |
128 | train_classes = np.unique(train_dataset.data['target'])
129 |
130 | # Get train/test indices
131 | train_idxs = []
132 | val_idxs = []
133 | for cls in train_classes:
134 |
135 | cls_idxs = np.where(train_dataset.data['target'] == cls)[0]
136 |
137 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
138 | t_ = [x for x in cls_idxs if x not in v_]
139 |
140 | train_idxs.extend(t_)
141 | val_idxs.extend(v_)
142 |
143 | return train_idxs, val_idxs
144 |
145 |
146 | def get_cub_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8,
147 | split_train_val=False, seed=0, download=False):
148 |
149 | np.random.seed(seed)
150 |
151 | # Init entire training set
152 | whole_training_set = CustomCub2011(root=cub_root, transform=train_transform, train=True, download=download)
153 |
154 | # Get labelled training set which has subsampled classes, then subsample some indices from that
155 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
156 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
157 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
158 |
159 | # Split into training and validation sets
160 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
161 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
162 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
163 | val_dataset_labelled_split.transform = test_transform
164 |
165 | # Get unlabelled data
166 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
167 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
168 |
169 | # Get test set for all classes
170 | test_dataset = CustomCub2011(root=cub_root, transform=test_transform, train=False)
171 |
172 | # Either split train into train and val or use test set as val
173 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
174 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
175 |
176 | all_datasets = {
177 | 'train_labelled': train_dataset_labelled,
178 | 'train_unlabelled': train_dataset_unlabelled,
179 | 'val': val_dataset_labelled,
180 | 'test': test_dataset,
181 | }
182 |
183 | return all_datasets
184 |
185 | if __name__ == '__main__':
186 |
187 | x = get_cub_datasets(None, None, split_train_val=False,
188 | train_classes=range(100), prop_train_labels=0.5)
189 |
190 | print('Printing lens...')
191 | for k, v in x.items():
192 | if v is not None:
193 | print(f'{k}: {len(v)}')
194 |
195 | print('Printing labelled and unlabelled overlap...')
196 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
197 | print('Printing total instances in train...')
198 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
199 |
200 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].data["target"].values))}')
201 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].data["target"].values))}')
202 | print(f'Len labelled set: {len(x["train_labelled"])}')
203 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import Dataset
3 |
4 | def subsample_instances(dataset, prop_indices_to_subsample=0.8):
5 |
6 | np.random.seed(0)
7 | subsample_indices = np.random.choice(range(len(dataset)), replace=False,
8 | size=(int(prop_indices_to_subsample * len(dataset)),))
9 |
10 | return subsample_indices
11 |
12 |
13 | class MergedDataset(Dataset):
14 |
15 | """
16 | Takes two datasets (labelled_dataset, unlabelled_dataset) and merges them
17 | Allows you to iterate over them in parallel
18 | """
19 |
20 | def __init__(self, labelled_dataset, unlabelled_dataset):
21 |
22 | self.labelled_dataset = labelled_dataset
23 | self.unlabelled_dataset = unlabelled_dataset
24 | self.target_transform = None
25 |
26 | def __getitem__(self, item):
27 |
28 | if item < len(self.labelled_dataset):
29 | img, label, uq_idx = self.labelled_dataset[item]
30 | labeled_or_not = 1
31 |
32 | else:
33 |
34 | img, label, uq_idx = self.unlabelled_dataset[item - len(self.labelled_dataset)]
35 | labeled_or_not = 0
36 |
37 |
38 | return img, label, uq_idx, np.array([labeled_or_not])
39 |
40 | def __len__(self):
41 | return len(self.unlabelled_dataset) + len(self.labelled_dataset)
42 |
--------------------------------------------------------------------------------
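
A sketch of `MergedDataset` above: labelled items come first and each item carries a labelled/unlabelled flag (assumes CIFAR-10 is available locally):

```python
from data.cifar import get_cifar_10_datasets
from data.data_utils import MergedDataset

d = get_cifar_10_datasets(None, None, train_classes=range(5), prop_train_labels=0.5)
merged = MergedDataset(labelled_dataset=d['train_labelled'],
                       unlabelled_dataset=d['train_unlabelled'])
img, label, uq_idx, mask_lab = merged[0]  # mask_lab == np.array([1]) for labelled items
print(len(merged), label, uq_idx, mask_lab)
```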
/data/fgvc_aircraft.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from copy import deepcopy
4 |
5 | from torchvision.datasets.folder import default_loader
6 | from torch.utils.data import Dataset
7 |
8 | from data.data_utils import subsample_instances
9 | from config import aircraft_root
10 |
11 | def make_dataset(dir, image_ids, targets):
12 | assert(len(image_ids) == len(targets))
13 | images = []
14 | dir = os.path.expanduser(dir)
15 | for i in range(len(image_ids)):
16 | item = (os.path.join(dir, 'data', 'images',
17 | '%s.jpg' % image_ids[i]), targets[i])
18 | images.append(item)
19 | return images
20 |
21 |
22 | def find_classes(classes_file):
23 |
24 | # read classes file, separating out image IDs and class names
25 | image_ids = []
26 | targets = []
27 | f = open(classes_file, 'r')
28 | for line in f:
29 | split_line = line.split(' ')
30 | image_ids.append(split_line[0])
31 | targets.append(' '.join(split_line[1:]))
32 | f.close()
33 |
34 | # index class names
35 | classes = np.unique(targets)
36 | class_to_idx = {classes[i]: i for i in range(len(classes))}
37 | targets = [class_to_idx[c] for c in targets]
38 |
39 | return (image_ids, targets, classes, class_to_idx)
40 |
41 |
42 | class FGVCAircraft(Dataset):
43 |
44 | """`FGVC-Aircraft `_ Dataset.
45 |
46 | Args:
47 | root (string): Root directory path to dataset.
48 | class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
49 | to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
50 | transform (callable, optional): A function/transform that takes in a PIL image
51 | and returns a transformed version. E.g. ``transforms.RandomCrop``
52 | target_transform (callable, optional): A function/transform that takes in the
53 | target and transforms it.
54 | loader (callable, optional): A function to load an image given its path.
55 | download (bool, optional): If true, downloads the dataset from the internet and
56 | puts it in the root directory. If dataset is already downloaded, it is not
57 | downloaded again.
58 | """
59 | url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
60 | class_types = ('variant', 'family', 'manufacturer')
61 | splits = ('train', 'val', 'trainval', 'test')
62 |
63 | def __init__(self, root, class_type='variant', split='train', transform=None,
64 | target_transform=None, loader=default_loader, download=False):
65 | if split not in self.splits:
66 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
67 | split, ', '.join(self.splits),
68 | ))
69 | if class_type not in self.class_types:
70 | raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
71 | class_type, ', '.join(self.class_types),
72 | ))
73 | self.root = os.path.expanduser(root)
74 | self.class_type = class_type
75 | self.split = split
76 | self.classes_file = os.path.join(self.root, 'data',
77 | 'images_%s_%s.txt' % (self.class_type, self.split))
78 |
79 | if download:
80 | self.download()
81 |
82 | (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
83 | samples = make_dataset(self.root, image_ids, targets)
84 |
85 | self.transform = transform
86 | self.target_transform = target_transform
87 | self.loader = loader
88 |
89 | self.samples = samples
90 | self.classes = classes
91 | self.class_to_idx = class_to_idx
92 | self.train = True if split == 'train' else False
93 |
94 | self.uq_idxs = np.array(range(len(self)))
95 |
96 | def __getitem__(self, index):
97 | """
98 | Args:
99 | index (int): Index
100 |
101 | Returns:
102 | tuple: (sample, target) where target is class_index of the target class.
103 | """
104 |
105 | path, target = self.samples[index]
106 | sample = self.loader(path)
107 | if self.transform is not None:
108 | sample = self.transform(sample)
109 | if self.target_transform is not None:
110 | target = self.target_transform(target)
111 |
112 | return sample, target, self.uq_idxs[index]
113 |
114 | def __len__(self):
115 | return len(self.samples)
116 |
117 | def __repr__(self):
118 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
119 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
120 | fmt_str += ' Root Location: {}\n'.format(self.root)
121 | tmp = ' Transforms (if any): '
122 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
123 | tmp = ' Target Transforms (if any): '
124 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
125 | return fmt_str
126 |
127 | def _check_exists(self):
128 | return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
129 | os.path.exists(self.classes_file)
130 |
131 | def download(self):
132 | """Download the FGVC-Aircraft data if it doesn't exist already."""
133 | from six.moves import urllib
134 | import tarfile
135 |
136 | if self._check_exists():
137 | return
138 |
139 | # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
140 | print('Downloading %s ... (may take a few minutes)' % self.url)
141 | parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
142 | tar_name = self.url.rpartition('/')[-1]
143 | tar_path = os.path.join(parent_dir, tar_name)
144 | data = urllib.request.urlopen(self.url)
145 |
146 | # download .tar.gz file
147 | with open(tar_path, 'wb') as f:
148 | f.write(data.read())
149 |
150 | # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
151 |         data_folder = tar_path[:-len('.tar.gz')]  # drop the extension (str.strip removes a character set, not a suffix)
152 | print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
153 | tar = tarfile.open(tar_path)
154 | tar.extractall(parent_dir)
155 |
156 | # if necessary, rename data folder to self.root
157 | if not os.path.samefile(data_folder, self.root):
158 | print('Renaming %s to %s ...' % (data_folder, self.root))
159 | os.rename(data_folder, self.root)
160 |
161 | # delete .tar.gz file
162 | print('Deleting %s ...' % tar_path)
163 | os.remove(tar_path)
164 |
165 | print('Done!')
166 |
167 |
168 | def subsample_dataset(dataset, idxs):
169 |
170 | mask = np.zeros(len(dataset)).astype('bool')
171 | mask[idxs] = True
172 |
173 | dataset.samples = [(p, t) for i, (p, t) in enumerate(dataset.samples) if i in idxs]
174 | dataset.uq_idxs = dataset.uq_idxs[mask]
175 |
176 | return dataset
177 |
178 |
179 | def subsample_classes(dataset, include_classes=range(60)):
180 |
181 | cls_idxs = [i for i, (p, t) in enumerate(dataset.samples) if t in include_classes]
182 |
183 | # TODO: Don't transform targets for now
184 | target_xform_dict = {}
185 | for i, k in enumerate(include_classes):
186 | target_xform_dict[k] = i
187 |
188 | dataset = subsample_dataset(dataset, cls_idxs)
189 |
190 | dataset.target_transform = lambda x: target_xform_dict[x]
191 |
192 | return dataset
193 |
194 |
195 | def get_train_val_indices(train_dataset, val_split=0.2):
196 |
197 | all_targets = [t for i, (p, t) in enumerate(train_dataset.samples)]
198 | train_classes = np.unique(all_targets)
199 |
200 | # Get train/test indices
201 | train_idxs = []
202 | val_idxs = []
203 | for cls in train_classes:
204 | cls_idxs = np.where(all_targets == cls)[0]
205 |
206 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
207 | t_ = [x for x in cls_idxs if x not in v_]
208 |
209 | train_idxs.extend(t_)
210 | val_idxs.extend(v_)
211 |
212 | return train_idxs, val_idxs
213 |
214 |
215 | def get_aircraft_datasets(train_transform, test_transform, train_classes=range(50), prop_train_labels=0.8,
216 | split_train_val=False, seed=0):
217 |
218 | np.random.seed(seed)
219 |
220 | # Init entire training set
221 | whole_training_set = FGVCAircraft(root=aircraft_root, transform=train_transform, split='trainval')
222 |
223 | # Get labelled training set which has subsampled classes, then subsample some indices from that
224 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
225 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
226 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
227 |
228 | # Split into training and validation sets
229 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
230 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
231 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
232 | val_dataset_labelled_split.transform = test_transform
233 |
234 | # Get unlabelled data
235 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
236 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
237 |
238 | # Get test set for all classes
239 | test_dataset = FGVCAircraft(root=aircraft_root, transform=test_transform, split='test')
240 |
241 | # Either split train into train and val or use test set as val
242 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
243 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
244 |
245 | all_datasets = {
246 | 'train_labelled': train_dataset_labelled,
247 | 'train_unlabelled': train_dataset_unlabelled,
248 | 'val': val_dataset_labelled,
249 | 'test': test_dataset,
250 | }
251 |
252 | return all_datasets
253 |
254 | if __name__ == '__main__':
255 |
256 | x = get_aircraft_datasets(None, None, split_train_val=False)
257 |
258 | print('Printing lens...')
259 | for k, v in x.items():
260 | if v is not None:
261 | print(f'{k}: {len(v)}')
262 |
263 | print('Printing labelled and unlabelled overlap...')
264 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
265 | print('Printing total instances in train...')
266 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
267 | print('Printing number of labelled classes...')
268 | print(len(set([i[1] for i in x['train_labelled'].samples])))
269 | print('Printing total number of classes...')
270 | print(len(set([i[1] for i in x['train_unlabelled'].samples])))
271 |
--------------------------------------------------------------------------------
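
A sketch instantiating `FGVCAircraft` directly (assumes `aircraft_root` in `config.py` points at `fgvc-aircraft-2013b/`); `'variant'` labels give the 100-class setting used by `get_aircraft_datasets` above:

```python
from config import aircraft_root
from data.fgvc_aircraft import FGVCAircraft

ds = FGVCAircraft(root=aircraft_root, class_type='variant', split='trainval')
img, target, uq_idx = ds[0]
print(len(ds), len(ds.classes), target)
```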
/data/get_datasets.py:
--------------------------------------------------------------------------------
1 | from data.data_utils import MergedDataset
2 |
3 | from data.cifar import get_cifar_10_datasets, get_cifar_100_datasets
4 | from data.herbarium_19 import get_herbarium_datasets
5 | from data.stanford_cars import get_scars_datasets
6 | from data.imagenet import get_imagenet_100_datasets, get_imagenet_1k_datasets
7 | from data.cub import get_cub_datasets
8 | from data.fgvc_aircraft import get_aircraft_datasets
9 |
10 | from copy import deepcopy
11 | import pickle
12 | import os
13 |
14 | from config import osr_split_dir
15 |
16 |
17 | get_dataset_funcs = {
18 | 'cifar10': get_cifar_10_datasets,
19 | 'cifar100': get_cifar_100_datasets,
20 | 'imagenet_100': get_imagenet_100_datasets,
21 | 'imagenet_1k': get_imagenet_1k_datasets,
22 | 'herbarium_19': get_herbarium_datasets,
23 | 'cub': get_cub_datasets,
24 | 'aircraft': get_aircraft_datasets,
25 | 'scars': get_scars_datasets
26 | }
27 |
28 |
29 | def get_datasets(dataset_name, train_transform, test_transform, args):
30 |
31 | """
32 | :return: train_dataset: MergedDataset which concatenates labelled and unlabelled
33 | test_dataset,
34 | unlabelled_train_examples_test,
35 | datasets
36 | """
37 |
38 | #
39 | if dataset_name not in get_dataset_funcs.keys():
40 |         raise ValueError(f'Unknown dataset_name: {dataset_name}')
41 |
42 | # Get datasets
43 | get_dataset_f = get_dataset_funcs[dataset_name]
44 | datasets = get_dataset_f(train_transform=train_transform, test_transform=test_transform,
45 | train_classes=args.train_classes,
46 | prop_train_labels=args.prop_train_labels,
47 | split_train_val=False)
48 | # Set target transforms:
49 | target_transform_dict = {}
50 | for i, cls in enumerate(list(args.train_classes) + list(args.unlabeled_classes)):
51 | target_transform_dict[cls] = i
52 | target_transform = lambda x: target_transform_dict[x]
53 |
54 |     for ds_name, ds in datasets.items():   # avoid shadowing the dataset_name argument
55 |         if ds is not None:
56 |             ds.target_transform = target_transform
57 |
58 | # Train split (labelled and unlabelled classes) for training
59 | train_dataset = MergedDataset(labelled_dataset=deepcopy(datasets['train_labelled']),
60 | unlabelled_dataset=deepcopy(datasets['train_unlabelled']))
61 |
62 | test_dataset = datasets['test']
63 | unlabelled_train_examples_test = deepcopy(datasets['train_unlabelled'])
64 | unlabelled_train_examples_test.transform = test_transform
65 |
66 | return train_dataset, test_dataset, unlabelled_train_examples_test, datasets
67 |
68 |
69 | def get_class_splits(args):
70 |
71 | # For FGVC datasets, optionally return bespoke splits
72 | if args.dataset_name in ('scars', 'cub', 'aircraft'):
73 | if hasattr(args, 'use_ssb_splits'):
74 | use_ssb_splits = args.use_ssb_splits
75 | else:
76 | use_ssb_splits = False
77 |
78 | # -------------
79 | # GET CLASS SPLITS
80 | # -------------
81 | if args.dataset_name == 'cifar10':
82 |
83 | args.image_size = 32
84 | args.train_classes = range(5)
85 | args.unlabeled_classes = range(5, 10)
86 |
87 | elif args.dataset_name == 'cifar100':
88 |
89 | args.image_size = 32
90 | args.train_classes = range(80)
91 | args.unlabeled_classes = range(80, 100)
92 |
93 | elif args.dataset_name == 'herbarium_19':
94 |
95 | args.image_size = 224
96 | herb_path_splits = os.path.join(osr_split_dir, 'herbarium_19_class_splits.pkl')
97 |
98 | with open(herb_path_splits, 'rb') as handle:
99 | class_splits = pickle.load(handle)
100 |
101 | args.train_classes = class_splits['Old']
102 | args.unlabeled_classes = class_splits['New']
103 |
104 | elif args.dataset_name == 'imagenet_100':
105 |
106 | args.image_size = 224
107 | args.train_classes = range(50)
108 | args.unlabeled_classes = range(50, 100)
109 |
110 | elif args.dataset_name == 'imagenet_1k':
111 |
112 | args.image_size = 224
113 | args.train_classes = range(500)
114 | args.unlabeled_classes = range(500, 1000)
115 |
116 | elif args.dataset_name == 'scars':
117 |
118 | args.image_size = 224
119 |
120 | if use_ssb_splits:
121 |
122 | split_path = os.path.join(osr_split_dir, 'scars_osr_splits.pkl')
123 | with open(split_path, 'rb') as handle:
124 | class_info = pickle.load(handle)
125 |
126 | args.train_classes = class_info['known_classes']
127 | open_set_classes = class_info['unknown_classes']
128 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
129 |
130 | else:
131 |
132 | args.train_classes = range(98)
133 | args.unlabeled_classes = range(98, 196)
134 |
135 | elif args.dataset_name == 'aircraft':
136 |
137 | args.image_size = 224
138 | if use_ssb_splits:
139 |
140 | split_path = os.path.join(osr_split_dir, 'aircraft_osr_splits.pkl')
141 | with open(split_path, 'rb') as handle:
142 | class_info = pickle.load(handle)
143 |
144 | args.train_classes = class_info['known_classes']
145 | open_set_classes = class_info['unknown_classes']
146 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
147 |
148 | else:
149 |
150 | args.train_classes = range(50)
151 | args.unlabeled_classes = range(50, 100)
152 |
153 | elif args.dataset_name == 'cub':
154 |
155 | args.image_size = 224
156 |
157 | if use_ssb_splits:
158 |
159 | split_path = os.path.join(osr_split_dir, 'cub_osr_splits.pkl')
160 | with open(split_path, 'rb') as handle:
161 | class_info = pickle.load(handle)
162 |
163 | args.train_classes = class_info['known_classes']
164 | open_set_classes = class_info['unknown_classes']
165 | args.unlabeled_classes = open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
166 |
167 | else:
168 |
169 | args.train_classes = range(100)
170 | args.unlabeled_classes = range(100, 200)
171 |
172 | else:
173 |
174 | raise NotImplementedError
175 |
176 | return args
177 |
--------------------------------------------------------------------------------
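
A sketch wiring the pieces together: resolve class splits, build transforms, then fetch the merged GCD training set. Argument values are illustrative; the actual `train.py` builds `args` via argparse and may override `image_size` for the ViT backbone:

```python
from argparse import Namespace

from data.augmentations import get_transform
from data.get_datasets import get_class_splits, get_datasets

args = Namespace(dataset_name='cifar100', prop_train_labels=0.5, use_ssb_splits=True,
                 interpolation=3, crop_pct=0.875)   # values assumed for this sketch
args = get_class_splits(args)   # sets args.image_size / train_classes / unlabeled_classes
train_t, test_t = get_transform('imagenet', image_size=args.image_size, args=args)
train_ds, test_ds, unlab_test, all_ds = get_datasets(args.dataset_name, train_t, test_t, args)
print(len(train_ds), len(test_ds), len(unlab_test))
```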
/data/herbarium_19.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torchvision
4 | import numpy as np
5 | from copy import deepcopy
6 |
7 | from data.data_utils import subsample_instances
8 | from config import herbarium_dataroot
9 |
10 | class HerbariumDataset19(torchvision.datasets.ImageFolder):
11 |
12 | def __init__(self, *args, **kwargs):
13 |
14 | # Process metadata json for training images into a DataFrame
15 | super().__init__(*args, **kwargs)
16 |
17 | self.uq_idxs = np.array(range(len(self)))
18 |
19 | def __getitem__(self, idx):
20 |
21 | img, label = super().__getitem__(idx)
22 | uq_idx = self.uq_idxs[idx]
23 |
24 | return img, label, uq_idx
25 |
26 |
27 | def subsample_dataset(dataset, idxs):
28 |
29 | mask = np.zeros(len(dataset)).astype('bool')
30 | mask[idxs] = True
31 |
32 | dataset.samples = np.array(dataset.samples)[mask].tolist()
33 | dataset.targets = np.array(dataset.targets)[mask].tolist()
34 |
35 | dataset.uq_idxs = dataset.uq_idxs[mask]
36 |
37 | dataset.samples = [[x[0], int(x[1])] for x in dataset.samples]
38 | dataset.targets = [int(x) for x in dataset.targets]
39 |
40 | return dataset
41 |
42 |
43 | def subsample_classes(dataset, include_classes=range(250)):
44 |
45 | cls_idxs = [x for x, l in enumerate(dataset.targets) if l in include_classes]
46 |
47 | target_xform_dict = {}
48 | for i, k in enumerate(include_classes):
49 | target_xform_dict[k] = i
50 |
51 | dataset = subsample_dataset(dataset, cls_idxs)
52 |
53 | dataset.target_transform = lambda x: target_xform_dict[x]
54 |
55 | return dataset
56 |
57 |
58 | def get_train_val_indices(train_dataset, val_instances_per_class=5):
59 |
60 | train_classes = list(set(train_dataset.targets))
61 |
62 | # Get train/test indices
63 | train_idxs = []
64 | val_idxs = []
65 | for cls in train_classes:
66 |
67 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0]
68 |
69 | # Have a balanced test set
70 | v_ = np.random.choice(cls_idxs, replace=False, size=(val_instances_per_class,))
71 | t_ = [x for x in cls_idxs if x not in v_]
72 |
73 | train_idxs.extend(t_)
74 | val_idxs.extend(v_)
75 |
76 | return train_idxs, val_idxs
77 |
78 |
79 | def get_herbarium_datasets(train_transform, test_transform, train_classes=range(500), prop_train_labels=0.8,
80 | seed=0, split_train_val=False):
81 |
82 | np.random.seed(seed)
83 |
84 | # Init entire training set
85 | train_dataset = HerbariumDataset19(transform=train_transform,
86 | root=os.path.join(herbarium_dataroot, 'small-train'))
87 |
88 | # Get labelled training set which has subsampled classes, then subsample some indices from that
89 |     # TODO: The unlabelled set is subsampled uniformly at random from the training data, so it will contain many instances of the dominant classes
90 | train_dataset_labelled = subsample_classes(deepcopy(train_dataset), include_classes=train_classes)
91 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
92 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
93 |
94 | # Split into training and validation sets
95 | if split_train_val:
96 |
97 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled,
98 | val_instances_per_class=5)
99 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
100 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
101 | val_dataset_labelled_split.transform = test_transform
102 |
103 | else:
104 |
105 | train_dataset_labelled_split, val_dataset_labelled_split = None, None
106 |
107 | # Get unlabelled data
108 | unlabelled_indices = set(train_dataset.uq_idxs) - set(train_dataset_labelled.uq_idxs)
109 | train_dataset_unlabelled = subsample_dataset(deepcopy(train_dataset), np.array(list(unlabelled_indices)))
110 |
111 | # Get test dataset
112 | test_dataset = HerbariumDataset19(transform=test_transform,
113 | root=os.path.join(herbarium_dataroot, 'small-validation'))
114 |
115 | # Transform dict
116 | unlabelled_classes = list(set(train_dataset.targets) - set(train_classes))
117 | target_xform_dict = {}
118 | for i, k in enumerate(list(train_classes) + unlabelled_classes):
119 | target_xform_dict[k] = i
120 |
121 | test_dataset.target_transform = lambda x: target_xform_dict[x]
122 | train_dataset_unlabelled.target_transform = lambda x: target_xform_dict[x]
123 |
124 | # Either split train into train and val or use test set as val
125 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
126 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
127 |
128 | all_datasets = {
129 | 'train_labelled': train_dataset_labelled,
130 | 'train_unlabelled': train_dataset_unlabelled,
131 | 'val': val_dataset_labelled,
132 | 'test': test_dataset,
133 | }
134 |
135 | return all_datasets
136 |
137 | if __name__ == '__main__':
138 |
139 | np.random.seed(0)
140 |     train_classes = np.random.choice(range(683), size=int(683 / 2), replace=False)
141 |
142 | x = get_herbarium_datasets(None, None, train_classes=train_classes,
143 | prop_train_labels=0.5)
144 |
145 | assert set(x['train_unlabelled'].targets) == set(range(683))
146 |
147 | print('Printing lens...')
148 | for k, v in x.items():
149 | if v is not None:
150 | print(f'{k}: {len(v)}')
151 |
152 | print('Printing labelled and unlabelled overlap...')
153 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
154 | print('Printing total instances in train...')
155 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
156 | print('Printing number of labelled classes...')
157 | print(len(set(x['train_labelled'].targets)))
158 | print('Printing total number of classes...')
159 | print(len(set(x['train_unlabelled'].targets)))
160 |
161 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
162 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
163 | print(f'Len labelled set: {len(x["train_labelled"])}')
164 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
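A hedged usage sketch for the Herbarium19 splits above (the transforms and the `range(341)` class list are placeholders; the actual runs load the `'Old'`/`'New'` class lists from `herbarium_19_class_splits.pkl`):

```python
import torch
from torchvision import transforms

from data.herbarium_19 import get_herbarium_datasets

base_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

datasets = get_herbarium_datasets(train_transform=base_transform,
                                  test_transform=base_transform,
                                  train_classes=range(341),  # placeholder; use the pickled 'Old' split
                                  prop_train_labels=0.5)

loader = torch.utils.data.DataLoader(datasets['train_labelled'],
                                     batch_size=128, shuffle=True, num_workers=4)
images, labels, uq_idxs = next(iter(loader))  # triplets, per HerbariumDataset19.__getitem__
```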
/data/imagenet.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import numpy as np
3 |
4 | import os
5 |
6 | from copy import deepcopy
7 | from data.data_utils import subsample_instances
8 | from config import imagenet_root
9 |
10 |
11 | class ImageNetBase(torchvision.datasets.ImageFolder):
12 |
13 | def __init__(self, root, transform):
14 |
15 | super(ImageNetBase, self).__init__(root, transform)
16 |
17 | self.uq_idxs = np.array(range(len(self)))
18 |
19 | def __getitem__(self, item):
20 |
21 | img, label = super().__getitem__(item)
22 | uq_idx = self.uq_idxs[item]
23 |
24 | return img, label, uq_idx
25 |
26 |
27 | def subsample_dataset(dataset, idxs):
28 |
29 | imgs_ = []
30 | for i in idxs:
31 | imgs_.append(dataset.imgs[i])
32 | dataset.imgs = imgs_
33 |
34 | samples_ = []
35 | for i in idxs:
36 | samples_.append(dataset.samples[i])
37 | dataset.samples = samples_
38 |
39 | # dataset.imgs = [x for i, x in enumerate(dataset.imgs) if i in idxs]
40 | # dataset.samples = [x for i, x in enumerate(dataset.samples) if i in idxs]
41 |
42 | dataset.targets = np.array(dataset.targets)[idxs].tolist()
43 | dataset.uq_idxs = dataset.uq_idxs[idxs]
44 |
45 | return dataset
46 |
47 |
48 | def subsample_classes(dataset, include_classes=list(range(1000))):
49 |
50 | cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]
51 |
52 | target_xform_dict = {}
53 | for i, k in enumerate(include_classes):
54 | target_xform_dict[k] = i
55 |
56 | dataset = subsample_dataset(dataset, cls_idxs)
57 | dataset.target_transform = lambda x: target_xform_dict[x]
58 |
59 | return dataset
60 |
61 |
62 | def get_train_val_indices(train_dataset, val_split=0.2):
63 |
64 | train_classes = list(set(train_dataset.targets))
65 |
66 | # Get train/test indices
67 | train_idxs = []
68 | val_idxs = []
69 | for cls in train_classes:
70 |
71 | cls_idxs = np.where(np.array(train_dataset.targets) == cls)[0]
72 |
73 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
74 | t_ = [x for x in cls_idxs if x not in v_]
75 |
76 | train_idxs.extend(t_)
77 | val_idxs.extend(v_)
78 |
79 | return train_idxs, val_idxs
80 |
81 |
82 | def get_imagenet_100_datasets(train_transform, test_transform, train_classes=range(80),
83 | prop_train_labels=0.8, split_train_val=False, seed=0):
84 |
85 | np.random.seed(seed)
86 |
87 | # Subsample imagenet dataset initially to include 100 classes
88 | subsampled_100_classes = np.random.choice(range(1000), size=(100,), replace=False)
89 | subsampled_100_classes = np.sort(subsampled_100_classes)
90 | print(f'Constructing ImageNet-100 dataset from the following classes: {subsampled_100_classes.tolist()}')
91 | cls_map = {i: j for i, j in zip(subsampled_100_classes, range(100))}
92 |
93 | # Init entire training set
94 | imagenet_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform)
95 | whole_training_set = subsample_classes(imagenet_training_set, include_classes=subsampled_100_classes)
96 |
97 | # Reset dataset
98 | whole_training_set.samples = [(s[0], cls_map[s[1]]) for s in whole_training_set.samples]
99 | whole_training_set.targets = [s[1] for s in whole_training_set.samples]
100 | whole_training_set.uq_idxs = np.array(range(len(whole_training_set)))
101 | whole_training_set.target_transform = None
102 |
103 | # Get labelled training set which has subsampled classes, then subsample some indices from that
104 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
105 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
106 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
107 |
108 | # Split into training and validation sets
109 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
110 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
111 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
112 | val_dataset_labelled_split.transform = test_transform
113 |
114 | # Get unlabelled data
115 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
116 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
117 |
118 | # Get test set for all classes
119 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform)
120 | test_dataset = subsample_classes(test_dataset, include_classes=subsampled_100_classes)
121 |
122 | # Reset test set
123 | test_dataset.samples = [(s[0], cls_map[s[1]]) for s in test_dataset.samples]
124 | test_dataset.targets = [s[1] for s in test_dataset.samples]
125 | test_dataset.uq_idxs = np.array(range(len(test_dataset)))
126 | test_dataset.target_transform = None
127 |
128 | # Either split train into train and val or use test set as val
129 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
130 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
131 |
132 | all_datasets = {
133 | 'train_labelled': train_dataset_labelled,
134 | 'train_unlabelled': train_dataset_unlabelled,
135 | 'val': val_dataset_labelled,
136 | 'test': test_dataset,
137 | }
138 |
139 | return all_datasets
140 |
141 |
142 | def get_imagenet_1k_datasets(train_transform, test_transform, train_classes=range(500),
143 | prop_train_labels=0.5, split_train_val=False, seed=0):
144 |
145 | np.random.seed(seed)
146 |
147 | # Init entire training set
148 | whole_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform)
149 |
150 | # Get labelled training set which has subsampled classes, then subsample some indices from that
151 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
152 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
153 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
154 |
155 | # Split into training and validation sets
156 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
157 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
158 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
159 | val_dataset_labelled_split.transform = test_transform
160 |
161 | # Get unlabelled data
162 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
163 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
164 |
165 | # Get test set for all classes
166 | test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform)
167 |
168 | # Either split train into train and val or use test set as val
169 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
170 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
171 |
172 | all_datasets = {
173 | 'train_labelled': train_dataset_labelled,
174 | 'train_unlabelled': train_dataset_unlabelled,
175 | 'val': val_dataset_labelled,
176 | 'test': test_dataset,
177 | }
178 |
179 | return all_datasets
180 |
181 |
182 |
183 | if __name__ == '__main__':
184 |
185 | x = get_imagenet_100_datasets(None, None, split_train_val=False,
186 | train_classes=range(50), prop_train_labels=0.5)
187 |
188 | print('Printing lens...')
189 | for k, v in x.items():
190 | if v is not None:
191 | print(f'{k}: {len(v)}')
192 |
193 | print('Printing labelled and unlabelled overlap...')
194 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
195 | print('Printing total instances in train...')
196 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
197 |
198 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
199 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
200 | print(f'Len labelled set: {len(x["train_labelled"])}')
201 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
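Because the 100 ImageNet classes are drawn with a seeded `np.random.choice`, the same `seed` reproduces the same class subset and the same labelled/unlabelled partition. A small sanity sketch (assumes ImageNet is on disk at `imagenet_root`):

```python
from data.imagenet import get_imagenet_100_datasets

x1 = get_imagenet_100_datasets(None, None, train_classes=range(50), prop_train_labels=0.5, seed=0)
x2 = get_imagenet_100_datasets(None, None, train_classes=range(50), prop_train_labels=0.5, seed=0)

# identical splits under the same seed, and no instance is both labelled and unlabelled
assert set(x1['train_labelled'].uq_idxs) == set(x2['train_labelled'].uq_idxs)
assert set(x1['train_labelled'].uq_idxs).isdisjoint(x1['train_unlabelled'].uq_idxs)
```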
/data/stanford_cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | from copy import deepcopy
5 | from scipy import io as mat_io
6 |
7 | from torchvision.datasets.folder import default_loader
8 | from torch.utils.data import Dataset
9 |
10 | from data.data_utils import subsample_instances
11 | from config import car_root
12 |
13 | class CarsDataset(Dataset):
14 | """
15 | Cars Dataset
16 | """
17 | def __init__(self, train=True, limit=0, data_dir=car_root, transform=None):
18 |
19 | metas = os.path.join(data_dir, 'devkit/cars_train_annos.mat') if train else os.path.join(data_dir, 'devkit/cars_test_annos_withlabels.mat')
20 | data_dir = os.path.join(data_dir, 'cars_train/') if train else os.path.join(data_dir, 'cars_test/')
21 |
22 | self.loader = default_loader
23 | self.data_dir = data_dir
24 | self.data = []
25 | self.target = []
26 | self.train = train
27 |
28 | self.transform = transform
29 |
30 | if not isinstance(metas, str):
31 |             raise ValueError("Train metas must be a string location!")
32 | labels_meta = mat_io.loadmat(metas)
33 |
34 | for idx, img_ in enumerate(labels_meta['annotations'][0]):
35 | if limit:
36 | if idx > limit:
37 | break
38 |
39 | # self.data.append(img_resized)
40 | self.data.append(data_dir + img_[5][0])
41 | # if self.mode == 'train':
42 | self.target.append(img_[4][0][0])
43 |
44 | self.uq_idxs = np.array(range(len(self)))
45 | self.target_transform = None
46 |
47 | def __getitem__(self, idx):
48 |
49 | image = self.loader(self.data[idx])
50 | target = self.target[idx] - 1
51 |
52 | if self.transform is not None:
53 | image = self.transform(image)
54 |
55 | if self.target_transform is not None:
56 | target = self.target_transform(target)
57 |
58 | idx = self.uq_idxs[idx]
59 |
60 | return image, target, idx
61 |
62 | def __len__(self):
63 | return len(self.data)
64 |
65 |
66 | def subsample_dataset(dataset, idxs):
67 |
68 | dataset.data = np.array(dataset.data)[idxs].tolist()
69 | dataset.target = np.array(dataset.target)[idxs].tolist()
70 | dataset.uq_idxs = dataset.uq_idxs[idxs]
71 |
72 | return dataset
73 |
74 |
75 | def subsample_classes(dataset, include_classes=range(160)):
76 |
77 | include_classes_cars = np.array(include_classes) + 1 # SCars classes are indexed 1 --> 196 instead of 0 --> 195
78 | cls_idxs = [x for x, t in enumerate(dataset.target) if t in include_classes_cars]
79 |
80 | target_xform_dict = {}
81 | for i, k in enumerate(include_classes):
82 | target_xform_dict[k] = i
83 |
84 | dataset = subsample_dataset(dataset, cls_idxs)
85 |
86 |         # dataset.target_transform = lambda x: target_xform_dict[x]  # unnecessary here: __getitem__ already returns 0-indexed targets (target - 1)
87 |
88 | return dataset
89 |
90 | def get_train_val_indices(train_dataset, val_split=0.2):
91 |
92 | train_classes = np.unique(train_dataset.target)
93 |
94 | # Get train/test indices
95 | train_idxs = []
96 | val_idxs = []
97 | for cls in train_classes:
98 |
99 |         cls_idxs = np.where(np.array(train_dataset.target) == cls)[0]  # target is a plain list; wrap in np.array for elementwise comparison
100 |
101 | v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
102 | t_ = [x for x in cls_idxs if x not in v_]
103 |
104 | train_idxs.extend(t_)
105 | val_idxs.extend(v_)
106 |
107 | return train_idxs, val_idxs
108 |
109 |
110 | def get_scars_datasets(train_transform, test_transform, train_classes=range(160), prop_train_labels=0.8,
111 | split_train_val=False, seed=0):
112 |
113 | np.random.seed(seed)
114 |
115 | # Init entire training set
116 | whole_training_set = CarsDataset(data_dir=car_root, transform=train_transform, train=True)
117 |
118 | # Get labelled training set which has subsampled classes, then subsample some indices from that
119 | train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
120 | subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
121 | train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)
122 |
123 | # Split into training and validation sets
124 | train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
125 | train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
126 | val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
127 | val_dataset_labelled_split.transform = test_transform
128 |
129 | # Get unlabelled data
130 | unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
131 | train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))
132 |
133 | # Get test set for all classes
134 | test_dataset = CarsDataset(data_dir=car_root, transform=test_transform, train=False)
135 |
136 | # Either split train into train and val or use test set as val
137 | train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
138 | val_dataset_labelled = val_dataset_labelled_split if split_train_val else None
139 |
140 | all_datasets = {
141 | 'train_labelled': train_dataset_labelled,
142 | 'train_unlabelled': train_dataset_unlabelled,
143 | 'val': val_dataset_labelled,
144 | 'test': test_dataset,
145 | }
146 |
147 | return all_datasets
148 |
149 | if __name__ == '__main__':
150 |
151 | x = get_scars_datasets(None, None, train_classes=range(98), prop_train_labels=0.5, split_train_val=False)
152 |
153 | print('Printing lens...')
154 | for k, v in x.items():
155 | if v is not None:
156 | print(f'{k}: {len(v)}')
157 |
158 | print('Printing labelled and unlabelled overlap...')
159 | print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
160 | print('Printing total instances in train...')
161 | print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))
162 |
163 | print(f'Num Labelled Classes: {len(set(x["train_labelled"].target))}')
164 |     print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].target))}')
165 | print(f'Len labelled set: {len(x["train_labelled"])}')
166 | print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
--------------------------------------------------------------------------------
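A note on the Cars label plumbing, as a tiny self-contained check (no data required): the devkit annotations are 1-indexed, `subsample_classes` shifts `include_classes` up by one to match them, and `__getitem__` shifts the stored label back down to a 0-indexed target:

```python
import numpy as np

include_classes = range(98)                 # 0-indexed 'Old' split
mat_labels = np.array(include_classes) + 1  # what subsample_classes compares against (1..98)
read_targets = mat_labels - 1               # what __getitem__ returns (target - 1)
assert (read_targets == np.arange(98)).all()
```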
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mashijie1028/ProtoGCD/8835a4d24662c65be125a42d815b52f62ae1482e/models/__init__.py
--------------------------------------------------------------------------------
/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import math
6 |
7 |
8 | class SupConLoss(torch.nn.Module):
9 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
10 | It also supports the unsupervised contrastive loss in SimCLR
11 | From: https://github.com/HobbitLong/SupContrast"""
12 | def __init__(self, temperature=0.07, contrast_mode='all',
13 | base_temperature=0.07):
14 | super(SupConLoss, self).__init__()
15 | self.temperature = temperature
16 | self.contrast_mode = contrast_mode
17 | self.base_temperature = base_temperature
18 |
19 | def forward(self, features, labels=None, mask=None):
20 | """Compute loss for model. If both `labels` and `mask` are None,
21 | it degenerates to SimCLR unsupervised loss:
22 | https://arxiv.org/pdf/2002.05709.pdf
23 | Args:
24 | features: hidden vector of shape [bsz, n_views, ...].
25 | labels: ground truth of shape [bsz].
26 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
27 | has the same class as sample i. Can be asymmetric.
28 | Returns:
29 | A loss scalar.
30 | """
31 |
32 | device = (torch.device('cuda')
33 | if features.is_cuda
34 | else torch.device('cpu'))
35 |
36 | if len(features.shape) < 3:
37 | raise ValueError('`features` needs to be [bsz, n_views, ...],'
38 | 'at least 3 dimensions are required')
39 | if len(features.shape) > 3:
40 | features = features.view(features.shape[0], features.shape[1], -1)
41 |
42 | batch_size = features.shape[0]
43 | if labels is not None and mask is not None:
44 | raise ValueError('Cannot define both `labels` and `mask`')
45 | elif labels is None and mask is None:
46 | mask = torch.eye(batch_size, dtype=torch.float32).to(device)
47 | elif labels is not None:
48 | labels = labels.contiguous().view(-1, 1)
49 | if labels.shape[0] != batch_size:
50 | raise ValueError('Num of labels does not match num of features')
51 | mask = torch.eq(labels, labels.T).float().to(device)
52 | else:
53 | mask = mask.float().to(device)
54 |
55 | contrast_count = features.shape[1]
56 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
57 | if self.contrast_mode == 'one':
58 | anchor_feature = features[:, 0]
59 | anchor_count = 1
60 | elif self.contrast_mode == 'all':
61 | anchor_feature = contrast_feature
62 | anchor_count = contrast_count
63 | else:
64 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
65 |
66 | # compute logits
67 | anchor_dot_contrast = torch.div(
68 | torch.matmul(anchor_feature, contrast_feature.T),
69 | self.temperature)
70 |
71 | # for numerical stability
72 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
73 | logits = anchor_dot_contrast - logits_max.detach()
74 |
75 | # tile mask
76 | mask = mask.repeat(anchor_count, contrast_count)
77 | # mask-out self-contrast cases
78 | logits_mask = torch.scatter(
79 | torch.ones_like(mask),
80 | 1,
81 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
82 | 0
83 | )
84 | mask = mask * logits_mask
85 |
86 | # compute log_prob
87 | exp_logits = torch.exp(logits) * logits_mask
88 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
89 |
90 | # compute mean of log-likelihood over positive
91 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
92 |
93 | # loss
94 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
95 | loss = loss.view(anchor_count, batch_size).mean()
96 |
97 | return loss
98 |
99 |
100 |
101 | def info_nce_logits(features, n_views=2, temperature=1.0, device='cuda'):
102 |
103 |     b_ = features.size(0) // 2  # assumes two concatenated views, each contributing half the batch
104 |
105 | labels = torch.cat([torch.arange(b_) for i in range(n_views)], dim=0)
106 | labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
107 | labels = labels.to(device)
108 |
109 | features = F.normalize(features, dim=1)
110 |
111 | similarity_matrix = torch.matmul(features, features.T)
112 |
113 | # discard the main diagonal from both: labels and similarities matrix
114 | mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
115 | labels = labels[~mask].view(labels.shape[0], -1)
116 | similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
117 |
118 | # select and combine multiple positives
119 | positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
120 |
121 | # select only the negatives the negatives
122 | negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
123 |
124 | logits = torch.cat([positives, negatives], dim=1)
125 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
126 |
127 | logits = logits / temperature
128 | return logits, labels
129 |
130 |
131 | def entropy_regularization_loss(logits, temperature):
132 | avg_probs = (logits / temperature).softmax(dim=1).mean(dim=0)
133 |     entropy_reg_loss = - torch.sum(torch.log(avg_probs**(-avg_probs))) + math.log(float(len(avg_probs)))  # equals log(K) - H(avg_probs); minimised when the class marginal is uniform
134 | return entropy_reg_loss
135 |
136 |
137 | def prototype_separation_loss(prototypes, temperature=0.1, base_temperature=0.1, device='cuda'):
138 | num_classes = prototypes.size(0)
139 | labels = torch.arange(0, num_classes).to(device)
140 | labels = labels.contiguous().view(-1, 1)
141 |
142 |     mask = (1 - torch.eq(labels, labels.T).float()).to(device)  # mask out self-pairs; all distinct prototype pairs are negatives
143 |
144 | logits = torch.div(torch.matmul(prototypes, prototypes.T), temperature)
145 |
146 | mean_prob_neg = torch.log((mask * torch.exp(logits)).sum(1) / mask.sum(1))
147 | mean_prob_neg = mean_prob_neg[~torch.isnan(mean_prob_neg)]
148 |
149 | # loss
150 | loss = temperature / base_temperature * mean_prob_neg.mean()
151 |
152 | return loss
153 |
154 |
155 |
156 | class DistillLoss_ratio(nn.Module):
157 | def __init__(self, num_classes=100, wait_ratio_epochs=0, ramp_ratio_teacher_epochs=100,
158 | nepochs=200, ncrops=2, init_ratio=0.0, final_ratio=1.0,
159 | temp_logits=0.1, temp_teacher_logits=0.05, device='cuda'):
160 | super().__init__()
161 | self.device = device
162 | self.num_classes = num_classes
163 | self.temp_logits = temp_logits
164 | self.temp_teacher_logits = temp_teacher_logits
165 | self.ncrops = ncrops
166 | self.ratio_schedule = np.concatenate((
167 | np.zeros(wait_ratio_epochs),
168 | np.linspace(init_ratio,
169 | final_ratio, ramp_ratio_teacher_epochs),
170 | np.ones(nepochs - wait_ratio_epochs - ramp_ratio_teacher_epochs) * final_ratio
171 | ))
172 |
173 | def forward(self, student_output, teacher_output, epoch):
174 | """
175 | Cross-entropy between softmax outputs of the teacher and student networks.
176 | """
177 | student_out = student_output / self.temp_logits
178 | student_out = student_out.chunk(self.ncrops)
179 |
180 | # confidence filtering
181 | ratio_epoch = self.ratio_schedule[epoch]
182 | teacher_out = F.softmax(teacher_output / self.temp_teacher_logits, dim=-1)
183 | teacher_out = teacher_out.detach().chunk(self.ncrops)
184 |
185 | teacher_label = []
186 | for i in range(self.ncrops):
187 | top2 = torch.topk(teacher_out[i], k=2, dim=-1, largest=True)[0]
188 | top2_div = top2[:, 0] / (top2[:, 1] + 1e-6)
189 | filter_number = int(len(teacher_out[i]) * ratio_epoch)
190 | topk_filter = torch.topk(top2_div, k=filter_number, largest=True)[1]
191 | pseudo_label = F.one_hot(teacher_out[i].argmax(dim=-1), num_classes=self.num_classes)
192 | pseudo_label = pseudo_label.float()
193 | teacher_out[i][topk_filter] = pseudo_label[topk_filter]
194 | teacher_label.append(teacher_out[i])
195 |
196 | total_loss = 0
197 | n_loss_terms = 0
198 | for iq, q in enumerate(teacher_label):
199 | #for v in range(len(student_out)):
200 | for iv, v in enumerate(student_out):
201 | if iv == iq:
202 | # we skip cases where student and teacher operate on the same view
203 | continue
204 | #loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
205 | loss = torch.sum(-q * F.log_softmax(v, dim=-1), dim=-1)
206 | total_loss += loss.mean()
207 | n_loss_terms += 1
208 | total_loss /= n_loss_terms
209 | return total_loss
210 |
211 |
212 |
213 | class DistillLoss_ratio_ramp(nn.Module):
214 | def __init__(self, num_classes=100, wait_ratio_epochs=0, ramp_ratio_teacher_epochs=100,
215 | nepochs=200, ncrops=2, init_ratio=0.2, final_ratio=1.0,
216 | temp_logits=0.1, temp_teacher_logits_init=0.07, temp_teacher_logits_final=0.04, ramp_temp_teacher_epochs=30,
217 | device='cuda'):
218 | super().__init__()
219 | self.device = device
220 | self.num_classes = num_classes
221 | self.temp_logits = temp_logits
222 | # self.temp_teacher_logits_init = temp_teacher_logits_init
223 | # self.temp_teacher_logits_final = temp_teacher_logits_final
224 | self.ncrops = ncrops
225 | self.teacher_temp_schedule = np.concatenate((
226 | np.linspace(temp_teacher_logits_init,
227 | temp_teacher_logits_final, ramp_temp_teacher_epochs),
228 | np.ones(nepochs - ramp_temp_teacher_epochs) * temp_teacher_logits_final
229 | ))
230 | self.ratio_schedule = np.concatenate((
231 | np.zeros(wait_ratio_epochs),
232 | np.linspace(init_ratio,
233 | final_ratio, ramp_ratio_teacher_epochs),
234 | np.ones(nepochs - wait_ratio_epochs - ramp_ratio_teacher_epochs) * final_ratio
235 | ))
236 |
237 | def forward(self, student_output, teacher_output, epoch):
238 | """
239 | Cross-entropy between softmax outputs of the teacher and student networks.
240 | """
241 | student_out = student_output / self.temp_logits
242 | student_out = student_out.chunk(self.ncrops)
243 |
244 | # confidence filtering
245 | temp_teacher_epoch = self.teacher_temp_schedule[epoch]
246 | ratio_epoch = self.ratio_schedule[epoch]
247 | teacher_out = F.softmax(teacher_output / temp_teacher_epoch, dim=-1)
248 | teacher_out = teacher_out.detach().chunk(self.ncrops)
249 |
250 | teacher_label = []
251 | for i in range(self.ncrops):
252 | top2 = torch.topk(teacher_out[i], k=2, dim=-1, largest=True)[0]
253 | top2_div = top2[:, 0] / (top2[:, 1] + 1e-6)
254 | filter_number = int(len(teacher_out[i]) * ratio_epoch)
255 | topk_filter = torch.topk(top2_div, k=filter_number, largest=True)[1]
256 | pseudo_label = F.one_hot(teacher_out[i].argmax(dim=-1), num_classes=self.num_classes)
257 | pseudo_label = pseudo_label.float()
258 | teacher_out[i][topk_filter] = pseudo_label[topk_filter]
259 | teacher_label.append(teacher_out[i])
260 |
261 | total_loss = 0
262 | n_loss_terms = 0
263 | for iq, q in enumerate(teacher_label):
264 | #for v in range(len(student_out)):
265 | for iv, v in enumerate(student_out):
266 | if iv == iq:
267 | # we skip cases where student and teacher operate on the same view
268 | continue
269 | #loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
270 | loss = torch.sum(-q * F.log_softmax(v, dim=-1), dim=-1)
271 | total_loss += loss.mean()
272 | n_loss_terms += 1
273 | total_loss /= n_loss_terms
274 | return total_loss
275 |
--------------------------------------------------------------------------------
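A hedged sketch of `DistillLoss_ratio` on random tensors: the two augmented views are concatenated along the batch dimension, so student and teacher outputs both have shape `[n_views * B, num_classes]`, and the epoch index selects the confidence-filtering ratio from the schedule:

```python
import torch
from models.loss import DistillLoss_ratio

B, K = 8, 100
criterion = DistillLoss_ratio(num_classes=K, wait_ratio_epochs=0,
                              ramp_ratio_teacher_epochs=100, nepochs=200,
                              ncrops=2, init_ratio=0.0, final_ratio=1.0)

student_out = torch.randn(2 * B, K)  # prototype logits for both views
teacher_out = torch.randn(2 * B, K)  # teacher logits (e.g. from the other view)

loss = criterion(student_out, teacher_out, epoch=50)  # ~50% most-confident samples hardened to one-hot
print(loss.item())
```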
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 |
7 | class DINOHead(nn.Module):
8 | def __init__(self, in_dim, out_dim, use_bn=False, init_prototypes=None,
9 | nlayers=3, hidden_dim=2048, bottleneck_dim=256, num_labeled_classes=50):
10 | super().__init__()
11 | nlayers = max(nlayers, 1)
12 | if nlayers == 1:
13 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
14 | elif nlayers != 0:
15 | layers = [nn.Linear(in_dim, hidden_dim)]
16 | if use_bn:
17 | layers.append(nn.BatchNorm1d(hidden_dim))
18 | layers.append(nn.GELU())
19 | for _ in range(nlayers - 2):
20 | layers.append(nn.Linear(hidden_dim, hidden_dim))
21 | if use_bn:
22 | layers.append(nn.BatchNorm1d(hidden_dim))
23 | layers.append(nn.GELU())
24 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
25 | self.mlp = nn.Sequential(*layers)
26 | self.apply(self._init_weights)
27 |
28 | # prototypes
29 | self.prototype_layer = nn.utils.weight_norm(nn.Linear(in_dim, out_dim, bias=False))
30 | self.prototype_layer.weight_g.data.fill_(1)
31 | self.prototype_layer.weight_g.requires_grad = False
32 | print('prototype size: ', self.prototype_layer.weight_v.size())
33 |
34 | if init_prototypes is not None:
35 |             print('initialize prototypes with labeled class means and k-means centroids...')
36 | print(init_prototypes.size())
37 | print(init_prototypes)
38 | #self.prototype_layer.weight_v.data.copy_(init_prototypes)
39 | self.prototype_layer.weight_v.data[:num_labeled_classes].copy_(init_prototypes[:num_labeled_classes])
40 | print(self.prototype_layer.weight_v)
41 | else:
42 | print('randomly initialize prototypes...')
43 | print(self.prototype_layer.weight_v)
44 |
45 |
46 | def _init_weights(self, m):
47 | if isinstance(m, nn.Linear):
48 | torch.nn.init.trunc_normal_(m.weight, std=.02)
49 | if isinstance(m, nn.Linear) and m.bias is not None:
50 | nn.init.constant_(m.bias, 0)
51 |
52 | def forward(self, x):
53 | x_proj = self.mlp(x)
54 | x = F.normalize(x, dim=-1, p=2)
55 | # x = x.detach()
56 | logits = self.prototype_layer(x)
57 |
58 | prototypes = self.prototype_layer.weight_v.clone()
59 | normed_prototypes = F.normalize(prototypes, dim=-1, p=2)
60 |
61 | return x_proj, logits, normed_prototypes
62 |
63 |
64 |
65 | class DINOHead_k(nn.Module):
66 | '''
67 | DINOHead for estimating k.
68 | difference with DINOHead: `forward()`, return one more `x`
69 | date: 20230515
70 | '''
71 | def __init__(self, in_dim, out_dim, use_bn=False, init_prototypes=None,
72 | nlayers=3, hidden_dim=2048, bottleneck_dim=256, num_labeled_classes=50):
73 | super().__init__()
74 | nlayers = max(nlayers, 1)
75 | if nlayers == 1:
76 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
77 | elif nlayers != 0:
78 | layers = [nn.Linear(in_dim, hidden_dim)]
79 | if use_bn:
80 | layers.append(nn.BatchNorm1d(hidden_dim))
81 | layers.append(nn.GELU())
82 | for _ in range(nlayers - 2):
83 | layers.append(nn.Linear(hidden_dim, hidden_dim))
84 | if use_bn:
85 | layers.append(nn.BatchNorm1d(hidden_dim))
86 | layers.append(nn.GELU())
87 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
88 | self.mlp = nn.Sequential(*layers)
89 | self.apply(self._init_weights)
90 |
91 | # prototypes
92 | self.prototype_layer = nn.utils.weight_norm(nn.Linear(in_dim, out_dim, bias=False))
93 | self.prototype_layer.weight_g.data.fill_(1)
94 | self.prototype_layer.weight_g.requires_grad = False
95 | print('prototype size: ', self.prototype_layer.weight_v.size())
96 |
97 | if init_prototypes is not None:
98 |             print('initialize prototypes with labeled class means and k-means centroids...')
99 | print(init_prototypes.size())
100 | print(init_prototypes)
101 | #self.prototype_layer.weight_v.data.copy_(init_prototypes)
102 | self.prototype_layer.weight_v.data[:num_labeled_classes].copy_(init_prototypes[:num_labeled_classes])
103 | print(self.prototype_layer.weight_v)
104 | else:
105 | print('randomly initialize prototypes...')
106 | print(self.prototype_layer.weight_v)
107 |
108 |
109 | def _init_weights(self, m):
110 | if isinstance(m, nn.Linear):
111 | torch.nn.init.trunc_normal_(m.weight, std=.02)
112 | if isinstance(m, nn.Linear) and m.bias is not None:
113 | nn.init.constant_(m.bias, 0)
114 |
115 | def forward(self, x):
116 | x_proj = self.mlp(x)
117 | x = F.normalize(x, dim=-1, p=2)
118 | # x = x.detach()
119 | logits = self.prototype_layer(x)
120 |
121 | prototypes = self.prototype_layer.weight_v.clone()
122 | normed_prototypes = F.normalize(prototypes, dim=-1, p=2)
123 |
124 | return x, x_proj, logits, normed_prototypes
125 |
126 |
127 |
128 | class ContrastiveLearningViewGenerator(object):
129 | """Take two random crops of one image as the query and key."""
130 |
131 | def __init__(self, base_transform, n_views=2):
132 | self.base_transform = base_transform
133 | self.n_views = n_views
134 |
135 | def __call__(self, x):
136 | if not isinstance(self.base_transform, list):
137 | return [self.base_transform(x) for i in range(self.n_views)]
138 | else:
139 | return [self.base_transform[i](x) for i in range(self.n_views)]
140 |
141 |
142 | def get_params_groups(model):
143 | regularized = []
144 | not_regularized = []
145 | for name, param in model.named_parameters():
146 | if not param.requires_grad:
147 | continue
148 | # we do not regularize biases nor Norm parameters
149 | if name.endswith(".bias") or len(param.shape) == 1:
150 | not_regularized.append(param)
151 | else:
152 | regularized.append(param)
153 | return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
154 |
--------------------------------------------------------------------------------
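A minimal sketch wiring a backbone feature into the prototype head above: `DINOHead` consumes `[CLS]` features (768-d for ViT-B) and returns the MLP projection, cosine-style prototype logits, and the L2-normalised prototypes:

```python
import torch
from models.model import DINOHead

num_classes = 100  # |old classes| + (estimated) |new classes|
head = DINOHead(in_dim=768, out_dim=num_classes)

feats = torch.randn(4, 768)  # e.g. backbone(images), the ViT [CLS] embeddings
x_proj, logits, prototypes = head(feats)

print(x_proj.shape)      # torch.Size([4, 256])   bottleneck projection
print(logits.shape)      # torch.Size([4, 100])   normalised-feature vs. prototype similarities
print(prototypes.shape)  # torch.Size([100, 768]) unit-norm prototype vectors
```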
/models/vision_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Mostly copy-paste from timm library.
16 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17 | """
18 | import math
19 | from functools import partial
20 |
21 | import torch
22 | import torch.nn as nn
23 |
24 | import warnings
25 |
26 |
27 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
28 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
29 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
30 | def norm_cdf(x):
31 | # Computes standard normal cumulative distribution function
32 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
33 |
34 | if (mean < a - 2 * std) or (mean > b + 2 * std):
35 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
36 | "The distribution of values may be incorrect.",
37 | stacklevel=2)
38 |
39 | with torch.no_grad():
40 | # Values are generated by using a truncated uniform distribution and
41 | # then using the inverse CDF for the normal distribution.
42 | # Get upper and lower cdf values
43 | l = norm_cdf((a - mean) / std)
44 | u = norm_cdf((b - mean) / std)
45 |
46 | # Uniformly fill tensor with values from [l, u], then translate to
47 | # [2l-1, 2u-1].
48 | tensor.uniform_(2 * l - 1, 2 * u - 1)
49 |
50 | # Use inverse cdf transform for normal distribution to get truncated
51 | # standard normal
52 | tensor.erfinv_()
53 |
54 | # Transform to proper mean, std
55 | tensor.mul_(std * math.sqrt(2.))
56 | tensor.add_(mean)
57 |
58 | # Clamp to ensure it's in the proper range
59 | tensor.clamp_(min=a, max=b)
60 | return tensor
61 |
62 |
63 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
64 | # type: (Tensor, float, float, float, float) -> Tensor
65 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
66 |
67 |
68 | def drop_path(x, drop_prob: float = 0., training: bool = False):
69 | if drop_prob == 0. or not training:
70 | return x
71 | keep_prob = 1 - drop_prob
72 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
73 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
74 | random_tensor.floor_() # binarize
75 | output = x.div(keep_prob) * random_tensor
76 | return output
77 |
78 |
79 | class DropPath(nn.Module):
80 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
81 | """
82 | def __init__(self, drop_prob=None):
83 | super(DropPath, self).__init__()
84 | self.drop_prob = drop_prob
85 |
86 | def forward(self, x):
87 | return drop_path(x, self.drop_prob, self.training)
88 |
89 |
90 | class Mlp(nn.Module):
91 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
92 | super().__init__()
93 | out_features = out_features or in_features
94 | hidden_features = hidden_features or in_features
95 | self.fc1 = nn.Linear(in_features, hidden_features)
96 | self.act = act_layer()
97 | self.fc2 = nn.Linear(hidden_features, out_features)
98 | self.drop = nn.Dropout(drop)
99 |
100 | def forward(self, x):
101 | x = self.fc1(x)
102 | x = self.act(x)
103 | x = self.drop(x)
104 | x = self.fc2(x)
105 | x = self.drop(x)
106 | return x
107 |
108 |
109 | class Attention(nn.Module):
110 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
111 | super().__init__()
112 | self.num_heads = num_heads
113 | head_dim = dim // num_heads
114 | self.scale = qk_scale or head_dim ** -0.5
115 |
116 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
117 | self.attn_drop = nn.Dropout(attn_drop)
118 | self.proj = nn.Linear(dim, dim)
119 | self.proj_drop = nn.Dropout(proj_drop)
120 |
121 | def forward(self, x):
122 | B, N, C = x.shape
123 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124 | q, k, v = qkv[0], qkv[1], qkv[2]
125 |
126 | attn = (q @ k.transpose(-2, -1)) * self.scale
127 | attn = attn.softmax(dim=-1)
128 | attn = self.attn_drop(attn)
129 |
130 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
131 | x = self.proj(x)
132 | x = self.proj_drop(x)
133 | return x, attn
134 |
135 |
136 | class Block(nn.Module):
137 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
138 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
139 | super().__init__()
140 | self.norm1 = norm_layer(dim)
141 | self.attn = Attention(
142 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
143 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
144 | self.norm2 = norm_layer(dim)
145 | mlp_hidden_dim = int(dim * mlp_ratio)
146 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
147 |
148 | def forward(self, x, return_attention=False):
149 | y, attn = self.attn(self.norm1(x))
150 | if return_attention:
151 | return attn
152 | x = x + self.drop_path(y)
153 | x = x + self.drop_path(self.mlp(self.norm2(x)))
154 | return x
155 |
156 |
157 | class PatchEmbed(nn.Module):
158 | """ Image to Patch Embedding
159 | """
160 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
161 | super().__init__()
162 | num_patches = (img_size // patch_size) * (img_size // patch_size)
163 | self.img_size = img_size
164 | self.patch_size = patch_size
165 | self.num_patches = num_patches
166 |
167 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
168 |
169 | def forward(self, x):
170 | B, C, H, W = x.shape
171 | x = self.proj(x).flatten(2).transpose(1, 2)
172 | return x
173 |
174 |
175 | class VisionTransformer(nn.Module):
176 | """ Vision Transformer """
177 | def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
178 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
179 | drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
180 | super().__init__()
181 | self.num_features = self.embed_dim = embed_dim
182 |
183 | self.patch_embed = PatchEmbed(
184 | img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
185 | num_patches = self.patch_embed.num_patches
186 |
187 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
188 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
189 | self.pos_drop = nn.Dropout(p=drop_rate)
190 |
191 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
192 | self.blocks = nn.ModuleList([
193 | Block(
194 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
195 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
196 | for i in range(depth)])
197 | self.norm = norm_layer(embed_dim)
198 |
199 | # Classifier head
200 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
201 |
202 | trunc_normal_(self.pos_embed, std=.02)
203 | trunc_normal_(self.cls_token, std=.02)
204 | self.apply(self._init_weights)
205 |
206 | def _init_weights(self, m):
207 | if isinstance(m, nn.Linear):
208 | trunc_normal_(m.weight, std=.02)
209 | if isinstance(m, nn.Linear) and m.bias is not None:
210 | nn.init.constant_(m.bias, 0)
211 | elif isinstance(m, nn.LayerNorm):
212 | nn.init.constant_(m.bias, 0)
213 | nn.init.constant_(m.weight, 1.0)
214 |
215 | def interpolate_pos_encoding(self, x, w, h):
216 | npatch = x.shape[1] - 1
217 | N = self.pos_embed.shape[1] - 1
218 | if npatch == N and w == h:
219 | return self.pos_embed
220 | class_pos_embed = self.pos_embed[:, 0]
221 | patch_pos_embed = self.pos_embed[:, 1:]
222 | dim = x.shape[-1]
223 | w0 = w // self.patch_embed.patch_size
224 | h0 = h // self.patch_embed.patch_size
225 | # we add a small number to avoid floating point error in the interpolation
226 | # see discussion at https://github.com/facebookresearch/dino/issues/8
227 | w0, h0 = w0 + 0.1, h0 + 0.1
228 | patch_pos_embed = nn.functional.interpolate(
229 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
230 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
231 | mode='bicubic',
232 | )
233 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
234 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
235 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
236 |
237 | def prepare_tokens(self, x):
238 | B, nc, w, h = x.shape
239 | x = self.patch_embed(x) # patch linear embedding
240 |
241 | # add the [CLS] token to the embed patch tokens
242 | cls_tokens = self.cls_token.expand(B, -1, -1)
243 | x = torch.cat((cls_tokens, x), dim=1)
244 |
245 | # add positional encoding to each token
246 | x = x + self.interpolate_pos_encoding(x, w, h)
247 |
248 | return self.pos_drop(x)
249 |
250 | def forward(self, x):
251 | x = self.prepare_tokens(x)
252 | for blk in self.blocks:
253 | x = blk(x)
254 | x = self.norm(x)
255 | return x[:, 0]
256 |
257 | def get_last_selfattention(self, x):
258 | x = self.prepare_tokens(x)
259 | for i, blk in enumerate(self.blocks):
260 | if i < len(self.blocks) - 1:
261 | x = blk(x)
262 | else:
263 | # return attention of the last block
264 | return blk(x, return_attention=True)
265 |
266 | def get_intermediate_layers(self, x, n=1):
267 | x = self.prepare_tokens(x)
268 | # we return the output tokens from the `n` last blocks
269 | output = []
270 | for i, blk in enumerate(self.blocks):
271 | x = blk(x)
272 | if len(self.blocks) - i <= n:
273 | output.append(self.norm(x))
274 | return output
275 |
276 |
277 | def vit_tiny(patch_size=16, **kwargs):
278 | model = VisionTransformer(
279 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
280 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
281 | return model
282 |
283 |
284 | def vit_small(patch_size=16, **kwargs):
285 | model = VisionTransformer(
286 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
287 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
288 | return model
289 |
290 |
291 | def vit_base(patch_size=16, **kwargs):
292 | model = VisionTransformer(
293 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
294 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
295 | return model
296 |
297 |
298 | class DINOHead(nn.Module):
299 | def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
300 | super().__init__()
301 | nlayers = max(nlayers, 1)
302 | if nlayers == 1:
303 | self.mlp = nn.Linear(in_dim, bottleneck_dim)
304 | else:
305 | layers = [nn.Linear(in_dim, hidden_dim)]
306 | if use_bn:
307 | layers.append(nn.BatchNorm1d(hidden_dim))
308 | layers.append(nn.GELU())
309 | for _ in range(nlayers - 2):
310 | layers.append(nn.Linear(hidden_dim, hidden_dim))
311 | if use_bn:
312 | layers.append(nn.BatchNorm1d(hidden_dim))
313 | layers.append(nn.GELU())
314 | layers.append(nn.Linear(hidden_dim, bottleneck_dim))
315 | self.mlp = nn.Sequential(*layers)
316 | self.apply(self._init_weights)
317 | self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
318 | self.last_layer.weight_g.data.fill_(1)
319 | if norm_last_layer:
320 | self.last_layer.weight_g.requires_grad = False
321 |
322 | def _init_weights(self, m):
323 | if isinstance(m, nn.Linear):
324 | trunc_normal_(m.weight, std=.02)
325 | if isinstance(m, nn.Linear) and m.bias is not None:
326 | nn.init.constant_(m.bias, 0)
327 |
328 | def forward(self, x):
329 | x = self.mlp(x)
330 | x = nn.functional.normalize(x, dim=-1, p=2)
331 | x = self.last_layer(x)
332 | return x
--------------------------------------------------------------------------------
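A quick sanity sketch for the ViT definitions above: `forward` returns the `[CLS]` token embedding that ProtoGCD's head consumes, and `get_last_selfattention` exposes the final block's attention map:

```python
import torch
from models.vision_transformer import vit_base

model = vit_base(patch_size=16)
x = torch.randn(2, 3, 224, 224)

cls_emb = model(x)
print(cls_emb.shape)  # torch.Size([2, 768])

attn = model.get_last_selfattention(x)
print(attn.shape)     # torch.Size([2, 12, 197, 197]), i.e. heads x (1 + 14*14) tokens
```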
/my_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mashijie1028/ProtoGCD/8835a4d24662c65be125a42d815b52f62ae1482e/my_utils/__init__.py
--------------------------------------------------------------------------------
/my_utils/cluster_and_log_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | #import torch.distributed as dist
3 | import numpy as np
4 | from scipy.optimize import linear_sum_assignment as linear_assignment
5 |
6 |
7 | # def all_sum_item(item):
8 | # item = torch.tensor(item).cuda()
9 | # dist.all_reduce(item)
10 | # return item.item()
11 |
12 |
13 | def old_cluster_acc(y_true, y_pred, return_ind=False):
14 | """
15 | https://github.com/sgvaze/generalized-category-discovery/blob/main/project_utils/cluster_utils.py#L39
16 | used ONLY for estimating # of novel categories in `estimate_k.py`
17 |
18 |     Calculate clustering accuracy. Requires scipy (linear_sum_assignment) installed
19 |
20 | # Arguments
21 | y: true labels, numpy.array with shape `(n_samples,)`
22 | y_pred: predicted labels, numpy.array with shape `(n_samples,)`
23 |
24 | # Return
25 | accuracy, in [0,1]
26 | """
27 | y_true = y_true.astype(int)
28 | assert y_pred.size == y_true.size
29 | D = max(y_pred.max(), y_true.max()) + 1
30 | w = np.zeros((D, D), dtype=int)
31 | for i in range(y_pred.size):
32 | w[y_pred[i], y_true[i]] += 1
33 |
34 | ind = linear_assignment(w.max() - w)
35 | ind = np.vstack(ind).T
36 |
37 | if return_ind:
38 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size, ind, w
39 | else:
40 | return sum([w[i, j] for i, j in ind]) * 1.0 / y_pred.size
41 |
42 |
43 |
44 | def split_cluster_acc_v2(y_true, y_pred, mask):
45 | """
46 |     Calculate clustering accuracy. Requires scipy (linear_sum_assignment) installed
47 | First compute linear assignment on all data, then look at how good the accuracy is on subsets
48 |
49 | # Arguments
50 | mask: Which instances come from old classes (True) and which ones come from new classes (False)
51 | y: true labels, numpy.array with shape `(n_samples,)`
52 | y_pred: predicted labels, numpy.array with shape `(n_samples,)`
53 |
54 | # Return
55 | accuracy, in [0,1]
56 | """
57 | y_true = y_true.astype(int)
58 |
59 | old_classes_gt = set(y_true[mask])
60 | new_classes_gt = set(y_true[~mask])
61 |
62 | assert y_pred.size == y_true.size
63 | D = max(y_pred.max(), y_true.max()) + 1
64 | w = np.zeros((D, D), dtype=int)
65 | for i in range(y_pred.size):
66 | w[y_pred[i], y_true[i]] += 1
67 |
68 | ind = linear_assignment(w.max() - w)
69 | ind = np.vstack(ind).T
70 |
71 | ind_map = {j: i for i, j in ind}
72 | total_acc = sum([w[i, j] for i, j in ind])
73 | total_instances = y_pred.size
74 | # try:
75 | # if dist.get_world_size() > 0:
76 | # total_acc = all_sum_item(total_acc)
77 | # total_instances = all_sum_item(total_instances)
78 | # except:
79 | # pass
80 | total_acc /= total_instances
81 |
82 | old_acc = 0
83 | total_old_instances = 0
84 | for i in old_classes_gt:
85 | old_acc += w[ind_map[i], i]
86 | total_old_instances += sum(w[:, i])
87 |
88 | # try:
89 | # if dist.get_world_size() > 0:
90 | # old_acc = all_sum_item(old_acc)
91 | # total_old_instances = all_sum_item(total_old_instances)
92 | # except:
93 | # pass
94 | old_acc /= total_old_instances
95 |
96 | new_acc = 0
97 | total_new_instances = 0
98 | for i in new_classes_gt:
99 | new_acc += w[ind_map[i], i]
100 | total_new_instances += sum(w[:, i])
101 |
102 | # try:
103 | # if dist.get_world_size() > 0:
104 | # new_acc = all_sum_item(new_acc)
105 | # total_new_instances = all_sum_item(total_new_instances)
106 | # except:
107 | # pass
108 | new_acc /= total_new_instances
109 |
110 | return total_acc, old_acc, new_acc
111 |
112 |
113 | def split_cluster_acc_v2_balanced(y_true, y_pred, mask):
114 | """
115 |     Calculate clustering accuracy. Requires scipy (linear_sum_assignment) installed
116 | First compute linear assignment on all data, then look at how good the accuracy is on subsets
117 |
118 | # Arguments
119 | mask: Which instances come from old classes (True) and which ones come from new classes (False)
120 | y: true labels, numpy.array with shape `(n_samples,)`
121 | y_pred: predicted labels, numpy.array with shape `(n_samples,)`
122 |
123 | # Return
124 | accuracy, in [0,1]
125 | """
126 | y_true = y_true.astype(int)
127 |
128 | old_classes_gt = set(y_true[mask])
129 | new_classes_gt = set(y_true[~mask])
130 |
131 | assert y_pred.size == y_true.size
132 | D = max(y_pred.max(), y_true.max()) + 1
133 | w = np.zeros((D, D), dtype=int)
134 | for i in range(y_pred.size):
135 | w[y_pred[i], y_true[i]] += 1
136 |
137 | ind = linear_assignment(w.max() - w)
138 | ind = np.vstack(ind).T
139 |
140 | ind_map = {j: i for i, j in ind}
141 |
142 | old_acc = np.zeros(len(old_classes_gt))
143 | total_old_instances = np.zeros(len(old_classes_gt))
144 | for idx, i in enumerate(old_classes_gt):
145 | old_acc[idx] += w[ind_map[i], i]
146 | total_old_instances[idx] += sum(w[:, i])
147 |
148 | new_acc = np.zeros(len(new_classes_gt))
149 | total_new_instances = np.zeros(len(new_classes_gt))
150 | for idx, i in enumerate(new_classes_gt):
151 | new_acc[idx] += w[ind_map[i], i]
152 | total_new_instances[idx] += sum(w[:, i])
153 |
154 | # try:
155 | # if dist.get_world_size() > 0:
156 | # old_acc, new_acc = torch.from_numpy(old_acc).cuda(), torch.from_numpy(new_acc).cuda()
157 | # dist.all_reduce(old_acc), dist.all_reduce(new_acc)
158 | # dist.all_reduce(total_old_instances), dist.all_reduce(total_new_instances)
159 | # old_acc, new_acc = old_acc.cpu().numpy(), new_acc.cpu().numpy()
160 | # total_old_instances, total_new_instances = total_old_instances.cpu().numpy(), total_new_instances.cpu().numpy()
161 | # except:
162 | # pass
163 |
164 | total_acc = np.concatenate([old_acc, new_acc]) / np.concatenate([total_old_instances, total_new_instances])
165 | old_acc /= total_old_instances
166 | new_acc /= total_new_instances
167 | total_acc, old_acc, new_acc = total_acc.mean(), old_acc.mean(), new_acc.mean()
168 | return total_acc, old_acc, new_acc
169 |
170 |
171 | EVAL_FUNCS = {
172 | 'v2': split_cluster_acc_v2,
173 | 'v2b': split_cluster_acc_v2_balanced
174 | }
175 |
176 | def log_accs_from_preds(y_true, y_pred, mask, eval_funcs, save_name, T=None,
177 | print_output=True, args=None):
178 |
179 | """
180 |     Given a list of evaluation functions to use (e.g. ['v2', 'v2b']), evaluate and log ACC results
181 |
182 | :param y_true: GT labels
183 | :param y_pred: Predicted indices
184 | :param mask: Which instances belong to Old and New classes
185 | :param T: Epoch
186 | :param eval_funcs: Which evaluation functions to use
187 | :param save_name: What are we evaluating ACC on
188 |     :param print_output: Whether to print/log the formatted accuracies
189 |     :return: (all_acc, old_acc, new_acc) from the first evaluation function
190 | """
191 |
192 | mask = mask.astype(bool)
193 | y_true = y_true.astype(int)
194 | y_pred = y_pred.astype(int)
195 |
196 | for i, f_name in enumerate(eval_funcs):
197 |
198 | acc_f = EVAL_FUNCS[f_name]
199 | all_acc, old_acc, new_acc = acc_f(y_true, y_pred, mask)
200 | log_name = f'{save_name}_{f_name}'
201 |
202 | if i == 0:
203 | to_return = (all_acc, old_acc, new_acc)
204 |
205 |         if print_output:
206 |             print_str = f'Epoch {T}, {log_name}: All {all_acc:.4f} | Old {old_acc:.4f} | New {new_acc:.4f}'
207 |             # Log via args.logger when available (set by init_experiment); fall back to print.
208 |             if args is not None and getattr(args, 'logger', None) is not None:
209 |                 args.logger.info(print_str)
210 |             else:
211 |                 print(print_str)
212 | 
213 |     return to_return
214 | 
--------------------------------------------------------------------------------
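
A minimal sketch (not part of the repository) of how the v2 protocol above behaves on toy labels; it assumes the repo root is on `PYTHONPATH` so that `my_utils` is importable, and the label arrays are hypothetical:

```python
import numpy as np
from my_utils.cluster_and_log_utils import split_cluster_acc_v2

# Six samples: ground-truth classes 0/1 are "old", class 2 is "new".
y_true = np.array([0, 0, 1, 1, 2, 2])
# Predicted cluster ids are arbitrary; the Hungarian assignment maps them to classes.
y_pred = np.array([1, 1, 0, 0, 2, 2])
mask = np.array([True, True, True, True, False, False])  # True = drawn from old classes

all_acc, old_acc, new_acc = split_cluster_acc_v2(y_true, y_pred, mask)
print(all_acc, old_acc, new_acc)  # 1.0 1.0 1.0: clusters match classes up to a permutation
```
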
/my_utils/general_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import inspect
4 |
5 | from datetime import datetime
6 | from loguru import logger
7 | import time
8 |
9 |
10 | class AverageMeter(object):
11 | """Computes and stores the average and current value"""
12 | def __init__(self):
13 | self.reset()
14 |
15 | def reset(self):
16 |
17 | self.val = 0
18 | self.avg = 0
19 | self.sum = 0
20 | self.count = 0
21 |
22 | def update(self, val, n=1):
23 |
24 | self.val = val
25 | self.sum += val * n
26 | self.count += n
27 | self.avg = self.sum / self.count
28 |
29 |
30 | def init_experiment(args, runner_name=None, exp_id=None):
31 | # Get filepath of calling script
32 | if runner_name is None:
33 | runner_name = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:]
34 |
35 | #root_dir = os.path.join(args.exp_root, *runner_name)
36 | root_dir = os.path.join(args.exp_root, args.dataset_name)
37 |
38 | if not os.path.exists(root_dir):
39 | os.makedirs(root_dir)
40 |
41 | # Either generate a unique experiment ID, or use one which is passed
42 | if exp_id is None:
43 |
44 | if args.exp_name is None:
45 | raise ValueError("Need to specify the experiment name")
46 | # Unique identifier for experiment
47 | # now = '{}_({:02d}.{:02d}.{}_|_'.format(args.exp_name, datetime.now().day, datetime.now().month, datetime.now().year) + \
48 | # datetime.now().strftime("%S.%f")[:-3] + ')'
49 | #now = args.exp_name + '_' + str(time.strftime("%Y%m%d-%H%M%S", time.localtime()))
50 | now = str(time.strftime("%Y%m%d-%H%M%S", time.localtime()))
51 |
52 | #log_dir = os.path.join(root_dir, 'log', now)
53 | log_dir = os.path.join(root_dir, now)
54 | while os.path.exists(log_dir):
55 | # now = '({:02d}.{:02d}.{}_|_'.format(datetime.now().day, datetime.now().month, datetime.now().year) + \
56 | # datetime.now().strftime("%S.%f")[:-3] + ')'
57 | #now = args.exp_name + '_' + str(time.strftime("%Y%m%d-%H%M%S", time.localtime()))
58 | now = str(time.strftime("%Y%m%d-%H%M%S", time.localtime()))
59 |
60 | #log_dir = os.path.join(root_dir, 'log', now)
61 | log_dir = os.path.join(root_dir, now)
62 |
63 | else:
64 |
65 | #log_dir = os.path.join(root_dir, 'log', f'{exp_id}')
66 | log_dir = os.path.join(root_dir, f'{exp_id}')
67 |
68 | if not os.path.exists(log_dir):
69 | os.makedirs(log_dir)
70 |
71 |
72 | #logger.add(os.path.join(log_dir, 'log.txt'))
73 | logger.add(os.path.join(log_dir, 'log.txt'), enqueue=True)
74 | args.logger = logger
75 | args.log_dir = log_dir
76 |
77 | # Instantiate directory to save models to
78 | model_root_dir = os.path.join(args.log_dir, 'checkpoints')
79 | if not os.path.exists(model_root_dir):
80 | os.mkdir(model_root_dir)
81 |
82 | args.model_dir = model_root_dir
83 | args.model_path = os.path.join(args.model_dir, 'model.pt')
84 |
85 | print(f'Experiment saved to: {args.log_dir}')
86 |
87 | hparam_dict = {}
88 |
89 | for k, v in vars(args).items():
90 | if isinstance(v, (int, float, str, bool, torch.Tensor)):
91 | hparam_dict[k] = v
92 |
93 | print(runner_name)
94 |
95 | # print and save args
96 | print(args)
97 | save_args_path = os.path.join(log_dir, 'args.txt')
98 |     with open(save_args_path, 'w') as f_args:
99 |         f_args.write('args: \n')
100 |         f_args.write(str(vars(args)))
101 | 
102 |
103 | return args
104 |
105 |
106 | # estimate # of novel class K
107 | def init_experiment_estimate_k(args, runner_name=None, exp_id=None):
108 | # Get filepath of calling script
109 | if runner_name is None:
110 | runner_name = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))).split(".")[-2:]
111 |
112 | #root_dir = os.path.join(args.exp_root, *runner_name)
113 | root_dir = os.path.join(args.exp_root, args.dataset_name)
114 |
115 | if not os.path.exists(root_dir):
116 | os.makedirs(root_dir)
117 |
118 | # Either generate a unique experiment ID, or use one which is passed
119 | if exp_id is None:
120 |
121 | if args.exp_name is None:
122 | raise ValueError("Need to specify the experiment name")
123 | now = str(time.strftime("%Y%m%d-%H%M%S", time.localtime()))
124 |
125 | log_dir = os.path.join(root_dir, now)
126 | now_k = now + '-k' + str(args.estimate_novel_k)
127 | log_dir = os.path.join(root_dir, now_k)
128 | while os.path.exists(log_dir):
129 | now = str(time.strftime("%Y%m%d-%H%M%S", time.localtime()))
130 |
131 | now_k = now + '-k' + str(args.estimate_novel_k)
132 | log_dir = os.path.join(root_dir, now_k)
133 |
134 | else:
135 |
136 | log_dir = os.path.join(root_dir, f'{exp_id}')
137 |
138 | if not os.path.exists(log_dir):
139 | os.makedirs(log_dir)
140 |
141 |
142 | #logger.add(os.path.join(log_dir, 'log.txt'))
143 | logger.add(os.path.join(log_dir, 'log.txt'), enqueue=True)
144 | args.logger = logger
145 | args.log_dir = log_dir
146 |
147 | print(f'Experiment saved to: {args.log_dir}')
148 |
149 | hparam_dict = {}
150 |
151 | for k, v in vars(args).items():
152 | if isinstance(v, (int, float, str, bool, torch.Tensor)):
153 | hparam_dict[k] = v
154 |
155 | print(runner_name)
156 |
157 | # print and save args
158 | print(args)
159 | save_args_path = os.path.join(log_dir, 'args.txt')
160 |     with open(save_args_path, 'w') as f_args:
161 |         f_args.write('args: \n')
162 |         f_args.write(str(vars(args)))
163 | 
164 |
165 | return args
166 |
167 |
168 |
169 | class DistributedWeightedSampler(torch.utils.data.distributed.DistributedSampler):
170 |
171 | def __init__(self, dataset, weights, num_samples, num_replicas=None, rank=None,
172 | replacement=True, generator=None):
173 | super(DistributedWeightedSampler, self).__init__(dataset, num_replicas, rank)
174 | if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
175 | num_samples <= 0:
176 | raise ValueError("num_samples should be a positive integer "
177 | "value, but got num_samples={}".format(num_samples))
178 | if not isinstance(replacement, bool):
179 | raise ValueError("replacement should be a boolean value, but got "
180 | "replacement={}".format(replacement))
181 | self.weights = torch.as_tensor(weights, dtype=torch.double)
182 | self.num_samples = num_samples
183 | self.replacement = replacement
184 | self.generator = generator
185 | self.weights = self.weights[self.rank::self.num_replicas]
186 | self.num_samples = self.num_samples // self.num_replicas
187 |
188 | def __iter__(self):
189 | rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
190 | rand_tensor = self.rank + rand_tensor * self.num_replicas
191 | yield from iter(rand_tensor.tolist())
192 |
193 | def __len__(self):
194 | return self.num_samples
195 |
--------------------------------------------------------------------------------
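
A minimal sketch (not part of the repository) of the rank-sharding scheme in `DistributedWeightedSampler` above, simulated without `torch.distributed`; the weights and replica count are hypothetical:

```python
import torch

weights = torch.as_tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=torch.double)
num_replicas, num_samples = 2, 4

for rank in range(num_replicas):
    # Each rank keeps every num_replicas-th weight, starting at its own offset,
    local_weights = weights[rank::num_replicas]
    # draws its share of samples from that shard,
    local_idx = torch.multinomial(local_weights, num_samples // num_replicas, replacement=True)
    # and maps local indices back to global ones: rank + local * num_replicas.
    global_idx = rank + local_idx * num_replicas
    print(rank, global_idx.tolist())  # rank 0 draws even dataset indices, rank 1 odd ones
```
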
/my_utils/ood_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import sklearn.metrics as sk
5 | from tqdm import tqdm
6 |
7 | recall_level_default = 0.95
8 |
9 |
10 | def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
11 | """Use high precision for cumsum and check that final value matches sum
12 | Parameters
13 | ----------
14 | arr : array-like
15 |         Array to be cumulatively summed (flattened)
16 | rtol : float
17 | Relative tolerance, see ``np.allclose``
18 | atol : float
19 | Absolute tolerance, see ``np.allclose``
20 | """
21 | out = np.cumsum(arr, dtype=np.float64)
22 | expected = np.sum(arr, dtype=np.float64)
23 | if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
24 | raise RuntimeError('cumsum was found to be unstable: '
25 | 'its last element does not correspond to sum')
26 | return out
27 |
28 |
29 | concat = lambda x: np.concatenate(x, axis=0)
30 | to_np = lambda x: x.data.cpu().numpy()
31 |
32 |
33 | def get_ood_scores_in(loader, model, args):
34 | _score = []
35 | _right_score = []
36 | _wrong_score = []
37 |
38 | with torch.no_grad():
39 | for batch_idx, (images, labels, _) in enumerate(tqdm(loader)):
40 |
41 | images = images.cuda(non_blocking=True)
42 | feats, _, logits, prototypes = model(images)
43 |
44 | #output = model(data)
45 | #smax = to_np(F.softmax(logits, dim=1))
46 | smax = to_np(F.softmax(logits / args.temp_logits, dim=1)) # NOTE!!!
47 | output_ = to_np(logits)
48 |
49 | # if args.use_xent:
50 | # _score.append(to_np((output.mean(1) - torch.logsumexp(output, dim=1))))
51 | # else:
52 | # _score.append(-np.max(smax, axis=1))
53 |
54 | if args.score == 'energy':
55 | _score.append(-to_np((args.T * torch.logsumexp(logits / args.T, dim=1))))
56 | elif args.score == 'mls':
57 | _score.append(-np.max(output_, axis=1))
58 | elif args.score == 'xent':
59 | #_score.append(to_np((logits.mean(1) - torch.logsumexp(logits, dim=1))))
60 | _score.append(to_np((logits.mean(1) / args.temp_logits - torch.logsumexp(logits / args.temp_logits, dim=1)))) # NOTE!!!
61 | elif args.score == 'proto':
62 | #top2 = np.topk(smax, k=2, dim=-1, largest=True)[0]
63 | #top2_div = top2[:, 0] / (top2[:, 1] + 1e-6)
64 | smax_sort = np.sort(smax, axis=1)
65 | top2_div = smax_sort[:, -1] / (smax_sort[:, -2] + 1e-6)
66 | _score.append(-top2_div) # NOTE!!!
67 | elif args.score == 'margin':
68 | #top2 = np.topk(smax, k=2, dim=-1, largest=True)[0]
69 | #top2_div = top2[:, 0] / (top2[:, 1] + 1e-6)
70 | smax_sort = np.sort(smax, axis=1)
71 | top2_margin = smax_sort[:, -1] - smax_sort[:, -2]
72 | _score.append(-top2_margin) # NOTE!!!
73 | else:
74 | _score.append(-np.max(smax, axis=1))
75 |
76 | preds = np.argmax(smax, axis=1)
77 | targets = labels.numpy().squeeze()
78 | right_indices = preds == targets
79 | wrong_indices = np.invert(right_indices)
80 |
81 | if args.score == 'xent':
82 | _right_score.append(to_np((logits.mean(1) / args.temp_logits - torch.logsumexp(logits / args.temp_logits, dim=1)))[right_indices])
83 | _wrong_score.append(to_np((logits.mean(1) / args.temp_logits - torch.logsumexp(logits / args.temp_logits, dim=1)))[wrong_indices])
84 | else:
85 | _right_score.append(-np.max(smax[right_indices], axis=1))
86 | _wrong_score.append(-np.max(smax[wrong_indices], axis=1))
87 |
88 | return concat(_score).copy(), concat(_right_score).copy(), concat(_wrong_score).copy()
89 |
90 |
91 | def get_ood_scores(loader, model, ood_num_examples, args):
92 | _score = []
93 |
94 | with torch.no_grad():
95 | for batch_idx, (images, labels) in enumerate(loader):
96 | if batch_idx >= ood_num_examples // args.batch_size:
97 | break
98 |
99 | images = images.cuda(non_blocking=True)
100 | feats, _, logits, prototypes = model(images)
101 |
102 | #output = model(data)
103 | #smax = to_np(F.softmax(logits, dim=1))
104 | smax = to_np(F.softmax(logits / args.temp_logits, dim=1)) # NOTE!!!
105 | output_ = to_np(logits)
106 |
107 | # if args.use_xent:
108 | # _score.append(to_np((output.mean(1) - torch.logsumexp(output, dim=1))))
109 | # else:
110 | # _score.append(-np.max(smax, axis=1))
111 |
112 | if args.score == 'energy':
113 | _score.append(-to_np((args.T * torch.logsumexp(logits / args.T, dim=1))))
114 | elif args.score == 'mls':
115 | _score.append(-np.max(output_, axis=1))
116 | elif args.score == 'xent':
117 | #_score.append(to_np((logits.mean(1) - torch.logsumexp(logits, dim=1))))
118 | _score.append(to_np((logits.mean(1) / args.temp_logits - torch.logsumexp(logits / args.temp_logits, dim=1)))) # NOTE!!!
119 | elif args.score == 'proto':
120 | #top2 = np.topk(smax, k=2, dim=-1, largest=True)[0]
121 | #top2_div = top2[:, 0] / (top2[:, 1] + 1e-6)
122 | smax_sort = np.sort(smax, axis=1)
123 | top2_div = smax_sort[:, -1] / (smax_sort[:, -2] + 1e-6)
124 | _score.append(-top2_div) # NOTE!!!
125 | elif args.score == 'margin':
126 | #top2 = np.topk(smax, k=2, dim=-1, largest=True)[0]
127 | #top2_div = top2[:, 0] / (top2[:, 1] + 1e-6)
128 | smax_sort = np.sort(smax, axis=1)
129 | top2_margin = smax_sort[:, -1] - smax_sort[:, -2]
130 | _score.append(-top2_margin) # NOTE!!!
131 | else:
132 | _score.append(-np.max(smax, axis=1))
133 |
134 | return concat(_score)[:ood_num_examples].copy()
135 |
136 |
137 | def fpr_and_fdr_at_recall(y_true, y_score, recall_level=recall_level_default, pos_label=None):
138 | classes = np.unique(y_true)
139 | if (pos_label is None and
140 | not (np.array_equal(classes, [0, 1]) or
141 | np.array_equal(classes, [-1, 1]) or
142 | np.array_equal(classes, [0]) or
143 | np.array_equal(classes, [-1]) or
144 | np.array_equal(classes, [1]))):
145 | raise ValueError("Data is not binary and pos_label is not specified")
146 | elif pos_label is None:
147 | pos_label = 1.
148 |
149 | # make y_true a boolean vector
150 | y_true = (y_true == pos_label)
151 |
152 | # sort scores and corresponding truth values
153 | desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
154 | y_score = y_score[desc_score_indices]
155 | y_true = y_true[desc_score_indices]
156 |
157 | # y_score typically has many tied values. Here we extract
158 | # the indices associated with the distinct values. We also
159 | # concatenate a value for the end of the curve.
160 | distinct_value_indices = np.where(np.diff(y_score))[0]
161 | threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
162 |
163 | # accumulate the true positives with decreasing threshold
164 | tps = stable_cumsum(y_true)[threshold_idxs]
165 | fps = 1 + threshold_idxs - tps # add one because of zero-based indexing
166 |
167 | thresholds = y_score[threshold_idxs]
168 |
169 | recall = tps / tps[-1]
170 |
171 | last_ind = tps.searchsorted(tps[-1])
172 | sl = slice(last_ind, None, -1) # [last_ind::-1]
173 | recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl]
174 |
175 | cutoff = np.argmin(np.abs(recall - recall_level))
176 |
177 | return fps[cutoff] / (np.sum(np.logical_not(y_true))) # , fps[cutoff]/(fps[cutoff] + tps[cutoff])
178 |
179 |
180 | def get_measures(_pos, _neg, recall_level=recall_level_default):
181 | pos = np.array(_pos[:]).reshape((-1, 1))
182 | neg = np.array(_neg[:]).reshape((-1, 1))
183 | examples = np.squeeze(np.vstack((pos, neg)))
184 | labels = np.zeros(len(examples), dtype=np.int32)
185 | labels[:len(pos)] += 1
186 |
187 | auroc = sk.roc_auc_score(labels, examples)
188 | aupr = sk.average_precision_score(labels, examples)
189 | fpr = fpr_and_fdr_at_recall(labels, examples, recall_level)
190 |
191 | return auroc, aupr, fpr
192 |
193 |
194 | def print_measures(auroc, aupr_in, aupr_out, fpr_in, fpr_out, recall_level=recall_level_default):
195 | print('FPR(IN){:d}: {:.2f}'.format(int(100 * recall_level), 100 * fpr_in))
196 | print('FPR(OUT){:d}: {:.2f}'.format(int(100 * recall_level), 100 * fpr_out))
197 | print('AUROC: {:.2f}'.format(100 * auroc))
198 | print('AUPR(IN): {:.2f}'.format(100 * aupr_in))
199 | print('AUPR(OUT): {:.2f}'.format(100 * aupr_out))
200 |
201 |
202 | def write_measures(auroc, aupr_in, aupr_out, fpr_in, fpr_out, file_path, recall_level=recall_level_default):
203 | with open(file_path, 'a+') as f_log:
204 | f_log.write('FPR(IN){:d}: {:.2f}'.format(int(100 * recall_level), 100 * fpr_in))
205 | f_log.write('\n')
206 | f_log.write('FPR(OUT){:d}: {:.2f}'.format(int(100 * recall_level), 100 * fpr_out))
207 | f_log.write('\n')
208 | f_log.write('AUROC: {:.2f}'.format(100 * auroc))
209 | f_log.write('\n')
210 | f_log.write('AUPR(IN): {:.2f}'.format(100 * aupr_in))
211 | f_log.write('\n')
212 | f_log.write('AUPR(OUT): {:.2f}'.format(100 * aupr_out))
213 | f_log.write('\n')
214 |
215 |
216 | def print_measures_with_std(aurocs, auprs_in, auprs_out, fprs_in, fprs_out, recall_level=recall_level_default):
217 | print('FPR(IN){:d}: {:.2f} +/- {:.2f}'.format(int(100*recall_level), 100*np.mean(fprs_in), 100*np.std(fprs_in)))
218 | print('FPR(OUT){:d}: {:.2f} +/- {:.2f}'.format(int(100*recall_level), 100*np.mean(fprs_out), 100*np.std(fprs_out)))
219 | print('AUROC: {:.2f} +/- {:.2f}'.format(100*np.mean(aurocs), 100*np.std(aurocs)))
220 | print('AUPR(IN): {:.2f} +/- {:.2f}'.format(100*np.mean(auprs_in), 100*np.std(auprs_in)))
221 | print('AUPR(OUT): {:.2f} +/- {:.2f}'.format(100*np.mean(auprs_out), 100*np.std(auprs_out)))
222 |
223 |
224 | def write_measures_with_std(aurocs, auprs_in, auprs_out, fprs_in, fprs_out, file_path, recall_level=recall_level_default):
225 | with open(file_path, 'a+') as f_log:
226 | f_log.write('FPR(IN){:d}: {:.2f} +/- {:.2f}'.format(int(100*recall_level), 100*np.mean(fprs_in), 100*np.std(fprs_in)))
227 | f_log.write('\n')
228 | f_log.write('FPR(OUT){:d}: {:.2f} +/- {:.2f}'.format(int(100*recall_level), 100*np.mean(fprs_out), 100*np.std(fprs_out)))
229 | f_log.write('\n')
230 | f_log.write('AUROC: {:.2f} +/- {:.2f}'.format(100*np.mean(aurocs), 100*np.std(aurocs)))
231 | f_log.write('\n')
232 | f_log.write('AUPR(IN): {:.2f} +/- {:.2f}'.format(100*np.mean(auprs_in), 100*np.std(auprs_in)))
233 | f_log.write('\n')
234 | f_log.write('AUPR(OUT): {:.2f} +/- {:.2f}'.format(100*np.mean(auprs_out), 100*np.std(auprs_out)))
235 | f_log.write('\n')
236 |
--------------------------------------------------------------------------------
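
A minimal sketch (not part of the repository) of the metric helpers above on synthetic scores; it assumes the repo root is on `PYTHONPATH`, and the Gaussian scores are hypothetical. By convention the first argument holds the positive class and larger values mean "more positive":

```python
import numpy as np
from my_utils.ood_utils import get_measures

rng = np.random.default_rng(0)
pos = rng.normal(1.0, 1.0, 1000)   # positive-class scores (e.g. -in_score in the test scripts)
neg = rng.normal(-1.0, 1.0, 1000)  # negative-class scores (e.g. -out_score)

auroc, aupr, fpr95 = get_measures(pos, neg)  # fpr95 = FPR at 95% recall (recall_level_default)
print(f'AUROC {100 * auroc:.2f} | AUPR {100 * aupr:.2f} | FPR95 {100 * fpr95:.2f}')
```
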
/test_ood_cifar.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.utils.data import Dataset
9 | from torchvision import datasets, transforms
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 |
13 | from data.augmentations import get_transform
14 | from data.get_datasets import get_datasets, get_class_splits
15 |
16 | from config import exp_root
17 | from models.model import DINOHead_k
18 | from models.model import ContrastiveLearningViewGenerator, get_params_groups
19 | from my_utils.ood_utils import get_ood_scores_in, get_ood_scores, get_measures, print_measures, write_measures, print_measures_with_std, write_measures_with_std
20 |
21 |
22 | def get_and_print_results(ood_loader, model, in_score, args):
23 | aurocs, auprs_in, auprs_out, fprs_in, fprs_out = [], [], [], [], []
24 |
25 | for _ in range(args.num_to_avg):
26 | out_score = get_ood_scores(ood_loader, model, OOD_NUM_EXAMPLES, args)
27 | measures_in = get_measures(-in_score, -out_score)
28 |         measures_out = get_measures(out_score, in_score)  # OE defines OOD samples as positive
29 |
30 | auroc = measures_in[0]; aupr_in = measures_in[1]; aupr_out = measures_out[1]; fpr_in = measures_in[2]; fpr_out = measures_out[2]
31 | aurocs.append(auroc); auprs_in.append(aupr_in); auprs_out.append(aupr_out); fprs_in.append(fpr_in); fprs_out.append(fpr_out)
32 |
33 | if args.num_to_avg >= 5:
34 | print_measures_with_std(aurocs, auprs_in, auprs_out, fprs_in, fprs_out)
35 | write_measures_with_std(aurocs, auprs_in, auprs_out, fprs_in, fprs_out, file_path=args.ood_log_path)
36 | else:
37 | print_measures(np.mean(aurocs), np.mean(auprs_in), np.mean(auprs_out), np.mean(fprs_in), np.mean(fprs_out))
38 | write_measures(np.mean(aurocs), np.mean(auprs_in), np.mean(auprs_out), np.mean(fprs_in), np.mean(fprs_out), file_path=args.ood_log_path)
39 |
40 | return (auroc, aupr_in, aupr_out, fpr_in, fpr_out)
41 |
42 |
43 | if __name__ == "__main__":
44 |
45 | parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
46 | parser.add_argument('--batch_size', default=128, type=int)
47 | parser.add_argument('--num_workers', default=4, type=int)
48 |
49 | parser.add_argument('--warmup_model_dir', type=str, default=None)
50 |     parser.add_argument('--dataset_name', type=str, default='cifar10', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aircraft, herbarium_19')
51 | parser.add_argument('--ckpts_date', type=str, default=None)
52 | parser.add_argument('--prop_train_labels', type=float, default=0.5)
53 | parser.add_argument('--use_ssb_splits', action='store_true', default=True)
54 | #parser.add_argument('--init_prototypes', action='store_true', default=False)
55 |
56 | #parser.add_argument('--grad_from_block', type=int, default=11)
57 | parser.add_argument('--exp_root', type=str, default=exp_root)
58 | parser.add_argument('--ood_log_path', type=str, default='OOD_results')
59 | parser.add_argument('--transform', type=str, default='imagenet')
60 | parser.add_argument('--n_views', default=2, type=int)
61 |
62 |     parser.add_argument('--score', type=str, default='msp', help='OOD detection score function: [msp, mls, energy, xent, proto, margin]')
63 |     parser.add_argument('--temp_logits', default=0.1, type=float, help='temperature for converting prototype cosine similarities into classification logits')
64 | parser.add_argument('--T', default=1., type=float, help='temperature: energy|Odin')
65 | parser.add_argument('--num_to_avg', type=int, default=10, help='Average measures across num_to_avg runs.')
66 |
67 | # ----------------------
68 | # INIT
69 | # ----------------------
70 | args = parser.parse_args()
71 | device = torch.device('cuda:0')
72 | args = get_class_splits(args)
73 |
74 | args.num_labeled_classes = len(args.train_classes)
75 | args.num_unlabeled_classes = len(args.unlabeled_classes)
76 |
77 | #init_experiment(args, runner_name=['ProtoGCD'])
78 | #args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results')
79 | args.ood_log_path = os.path.join(args.ood_log_path, args.dataset_name)
80 | if not os.path.exists(args.ood_log_path):
81 | os.makedirs(args.ood_log_path)
82 | args.ood_log_path = os.path.join(args.ood_log_path, args.ckpts_date + '-' + args.score + '-T' + str(args.temp_logits) + '.txt')
83 |
84 | torch.backends.cudnn.benchmark = True
85 |
86 | # ----------------------
87 | # BASE MODEL
88 | # ----------------------
89 | args.interpolation = 3
90 | args.crop_pct = 0.875
91 |
92 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
93 |
94 | # if args.warmup_model_dir is not None:
95 | # args.logger.info(f'Loading weights from {args.warmup_model_dir}')
96 | # backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu'))
97 |
98 | # NOTE: Hardcoded image size as we do not finetune the entire ViT model
99 | args.image_size = 224
100 | args.feat_dim = 768
101 | args.num_mlp_layers = 3
102 | args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes
103 |
104 |
105 | # --------------------
106 | # CONTRASTIVE TRANSFORM
107 | # --------------------
108 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args)
109 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)
110 | # --------------------
111 | # DATASETS
112 | # --------------------
113 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets_ = get_datasets(args.dataset_name,
114 | train_transform,
115 | test_transform,
116 | args)
117 |
118 | # --------------------
119 | # SAMPLER
120 | # Sampler which balances labelled and unlabelled examples in each batch
121 | # --------------------
122 | label_len = len(train_dataset.labelled_dataset)
123 | unlabelled_len = len(train_dataset.unlabelled_dataset)
124 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))]
125 | sample_weights = torch.DoubleTensor(sample_weights)
126 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset))
127 |
128 | # --------------------
129 | # DATALOADERS
130 | # --------------------
131 | # train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False,
132 | # sampler=sampler, drop_last=True, pin_memory=True)
133 | # test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers,
134 | # batch_size=256, shuffle=False, pin_memory=False)
135 | test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers,
136 | batch_size=256, shuffle=False, pin_memory=False)
137 |
138 | OOD_NUM_EXAMPLES = len(test_dataset) // 5 # NOTE! NOT test_loader_labelled!
139 | print(OOD_NUM_EXAMPLES)
140 |
141 | # ----------------------
142 | # PROJECTION HEAD
143 | # ----------------------
144 | projector = DINOHead_k(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers,
145 | init_prototypes=None, num_labeled_classes=args.num_labeled_classes)
146 | model = nn.Sequential(backbone, projector).to(device)
147 |
148 | ckpts_base_path = '/lustre/home/sjma/GCD-project/protoGCD-v7/dev_outputs_fix/'
149 | ckpts_path = os.path.join(ckpts_base_path, args.dataset_name, args.ckpts_date, 'checkpoints', 'model_best.pt')
150 | ckpts = torch.load(ckpts_path)
151 | ckpts = ckpts['model']
152 | print('loading ckpts from %s...' % ckpts_path)
153 | model.load_state_dict(ckpts)
154 |     print('successfully loaded ckpts')
155 | model.eval()
156 |
157 |
158 | # ----------------------
159 | # TEST OOD
160 | # ----------------------
161 | print('Using %s as typical data' % args.dataset_name)
162 | with open(args.ood_log_path, 'w+') as f_log:
163 | f_log.write('Using %s as typical data' % args.dataset_name)
164 | f_log.write('\n')
165 |
166 | print(test_transform)
167 |
168 | # ID score
169 | #test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)
170 | in_score, right_score, wrong_score = get_ood_scores_in(test_loader_labelled, model, args)
171 |
172 |
173 | # Textures
174 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/dtd/images", transform=test_transform)
175 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
176 | print('\n\nTexture Detection')
177 | with open(args.ood_log_path, 'a+') as f_log:
178 | f_log.write('\n\nTexture Detection')
179 | f_log.write('\n')
180 | get_and_print_results(ood_loader, model, in_score, args)
181 |
182 |
183 | # SVHN
184 | ood_data = datasets.SVHN('/data4/sjma/dataset/SVHN/', split='test', download=False, transform=test_transform)
185 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
186 | print('\n\nSVHN Detection')
187 | with open(args.ood_log_path, 'a+') as f_log:
188 | f_log.write('\n\nSVHN Detection')
189 | f_log.write('\n')
190 | get_and_print_results(ood_loader, model, in_score, args)
191 |
192 | # Places
193 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/places365", transform=test_transform)
194 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
195 | print('\n\nPlaces Detection')
196 | with open(args.ood_log_path, 'a+') as f_log:
197 | f_log.write('\n\nPlaces Detection')
198 | f_log.write('\n')
199 | get_and_print_results(ood_loader, model, in_score, args)
200 |
201 |
202 | # TinyImageNet-R
203 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/Imagenet_resize", transform=test_transform)
204 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
205 | print('\n\nTinyImageNet-resize Detection')
206 | with open(args.ood_log_path, 'a+') as f_log:
207 | f_log.write('\n\nTinyImageNet-resize Detection')
208 | f_log.write('\n')
209 | get_and_print_results(ood_loader, model, in_score, args)
210 |
211 |
212 | # LSUN-R
213 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/LSUN_resize", transform=test_transform)
214 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
215 | print('\n\nLSUN-resize Detection')
216 | with open(args.ood_log_path, 'a+') as f_log:
217 | f_log.write('\n\nLSUN-resize Detection')
218 | f_log.write('\n')
219 | get_and_print_results(ood_loader, model, in_score, args)
220 |
221 |
222 | # iSUN
223 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/iSUN", transform=test_transform)
224 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
225 | print('\n\niSUN Detection')
226 | with open(args.ood_log_path, 'a+') as f_log:
227 | f_log.write('\n\niSUN Detection')
228 | f_log.write('\n')
229 | get_and_print_results(ood_loader, model, in_score, args)
230 |
231 |
232 | # CIFAR data
233 | if args.dataset_name == 'cifar10':
234 | ood_data = datasets.CIFAR100('/data4/sjma/dataset/CIFAR/', train=False, transform=test_transform)
235 | else:
236 | ood_data = datasets.CIFAR10('/data4/sjma/dataset/CIFAR/', train=False, transform=test_transform)
237 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
238 |     print('\n\nCIFAR-100 Detection' if args.dataset_name == 'cifar10' else '\n\nCIFAR-10 Detection')
239 | with open(args.ood_log_path, 'a+') as f_log:
240 |         f_log.write('\n\nCIFAR-100 Detection' if args.dataset_name == 'cifar10' else '\n\nCIFAR-10 Detection')
241 | f_log.write('\n')
242 | get_and_print_results(ood_loader, model, in_score, args)
243 |
--------------------------------------------------------------------------------
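
One subtlety in the script above: every branch of `get_ood_scores_*` appends a score where larger means "more OOD", so `get_and_print_results` negates both arrays to make the in-distribution set positive for AUROC/AUPR(IN)/FPR(IN), and passes them unnegated (OOD as positive) for the OUT metrics. A minimal sketch (not part of the repository) with hypothetical MSP-style scores:

```python
import numpy as np
from my_utils.ood_utils import get_measures

in_score = np.array([-0.95, -0.90, -0.85])   # confident ID inputs: very negative -max(softmax)
out_score = np.array([-0.40, -0.35, -0.30])  # OOD inputs: scores closer to zero

print(get_measures(-in_score, -out_score))  # ID as positive: AUROC, AUPR(IN), FPR(IN)
print(get_measures(out_score, in_score))    # OOD as positive: AUROC, AUPR(OUT), FPR(OUT)
```
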
/test_ood_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.utils.data import Dataset
9 | from torchvision import datasets, transforms
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 |
13 | from data.augmentations import get_transform
14 | from data.get_datasets import get_datasets, get_class_splits
15 |
16 | from config import exp_root
17 | from models.model import DINOHead_k
18 | from models.model import ContrastiveLearningViewGenerator, get_params_groups
19 | from my_utils.ood_utils import get_ood_scores_in, get_ood_scores, get_measures, print_measures, write_measures, print_measures_with_std, write_measures_with_std
20 |
21 |
22 | def get_and_print_results(ood_loader, model, in_score, args):
23 | aurocs, auprs_in, auprs_out, fprs_in, fprs_out = [], [], [], [], []
24 |
25 | for _ in range(args.num_to_avg):
26 | out_score = get_ood_scores(ood_loader, model, OOD_NUM_EXAMPLES, args)
27 | measures_in = get_measures(-in_score, -out_score)
28 |         measures_out = get_measures(out_score, in_score)  # OE defines OOD samples as positive
29 |
30 | auroc = measures_in[0]; aupr_in = measures_in[1]; aupr_out = measures_out[1]; fpr_in = measures_in[2]; fpr_out = measures_out[2]
31 | aurocs.append(auroc); auprs_in.append(aupr_in); auprs_out.append(aupr_out); fprs_in.append(fpr_in); fprs_out.append(fpr_out)
32 |
33 | if args.num_to_avg >= 5:
34 | print_measures_with_std(aurocs, auprs_in, auprs_out, fprs_in, fprs_out)
35 | write_measures_with_std(aurocs, auprs_in, auprs_out, fprs_in, fprs_out, file_path=args.ood_log_path)
36 | else:
37 | print_measures(np.mean(aurocs), np.mean(auprs_in), np.mean(auprs_out), np.mean(fprs_in), np.mean(fprs_out))
38 | write_measures(np.mean(aurocs), np.mean(auprs_in), np.mean(auprs_out), np.mean(fprs_in), np.mean(fprs_out), file_path=args.ood_log_path)
39 |
40 | return (auroc, aupr_in, aupr_out, fpr_in, fpr_out)
41 |
42 |
43 | if __name__ == "__main__":
44 |
45 | parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
46 | parser.add_argument('--batch_size', default=128, type=int)
47 | parser.add_argument('--num_workers', default=4, type=int)
48 |
49 | parser.add_argument('--warmup_model_dir', type=str, default=None)
50 |     parser.add_argument('--dataset_name', type=str, default='imagenet_100', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aircraft, herbarium_19')
51 | parser.add_argument('--ckpts_date', type=str, default=None)
52 | parser.add_argument('--prop_train_labels', type=float, default=0.5)
53 | parser.add_argument('--use_ssb_splits', action='store_true', default=True)
54 | #parser.add_argument('--init_prototypes', action='store_true', default=False)
55 |
56 | #parser.add_argument('--grad_from_block', type=int, default=11)
57 | parser.add_argument('--exp_root', type=str, default=exp_root)
58 | parser.add_argument('--ood_log_path', type=str, default='OOD_results')
59 | parser.add_argument('--transform', type=str, default='imagenet')
60 | parser.add_argument('--n_views', default=2, type=int)
61 |
62 |     parser.add_argument('--score', type=str, default='msp', help='OOD detection score function: [msp, mls, energy, xent, proto, margin]')
63 |     parser.add_argument('--temp_logits', default=0.1, type=float, help='temperature for converting prototype cosine similarities into classification logits')
64 | parser.add_argument('--T', default=1., type=float, help='temperature: energy|Odin')
65 | parser.add_argument('--num_to_avg', type=int, default=10, help='Average measures across num_to_avg runs.')
66 |
67 | # ----------------------
68 | # INIT
69 | # ----------------------
70 | args = parser.parse_args()
71 | device = torch.device('cuda:0')
72 | args = get_class_splits(args)
73 |
74 | args.num_labeled_classes = len(args.train_classes)
75 | args.num_unlabeled_classes = len(args.unlabeled_classes)
76 |
77 | #init_experiment(args, runner_name=['ProtoGCD'])
78 | #args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results')
79 | args.ood_log_path = os.path.join(args.ood_log_path, args.dataset_name)
80 | if not os.path.exists(args.ood_log_path):
81 | os.makedirs(args.ood_log_path)
82 | args.ood_log_path = os.path.join(args.ood_log_path, args.ckpts_date + '-' + args.score + '-T' + str(args.temp_logits) + '.txt')
83 |
84 | torch.backends.cudnn.benchmark = True
85 |
86 | # ----------------------
87 | # BASE MODEL
88 | # ----------------------
89 | args.interpolation = 3
90 | args.crop_pct = 0.875
91 |
92 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
93 |
94 | # if args.warmup_model_dir is not None:
95 | # args.logger.info(f'Loading weights from {args.warmup_model_dir}')
96 | # backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu'))
97 |
98 | # NOTE: Hardcoded image size as we do not finetune the entire ViT model
99 | args.image_size = 224
100 | args.feat_dim = 768
101 | args.num_mlp_layers = 3
102 | args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes
103 |
104 |
105 | # --------------------
106 | # CONTRASTIVE TRANSFORM
107 | # --------------------
108 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args)
109 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)
110 | # --------------------
111 | # DATASETS
112 | # --------------------
113 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets_ = get_datasets(args.dataset_name,
114 | train_transform,
115 | test_transform,
116 | args)
117 |
118 | # --------------------
119 | # SAMPLER
120 | # Sampler which balances labelled and unlabelled examples in each batch
121 | # --------------------
122 | label_len = len(train_dataset.labelled_dataset)
123 | unlabelled_len = len(train_dataset.unlabelled_dataset)
124 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))]
125 | sample_weights = torch.DoubleTensor(sample_weights)
126 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset))
127 |
128 | # --------------------
129 | # DATALOADERS
130 | # --------------------
131 | # train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False,
132 | # sampler=sampler, drop_last=True, pin_memory=True)
133 | # test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers,
134 | # batch_size=256, shuffle=False, pin_memory=False)
135 | test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers,
136 | batch_size=256, shuffle=False, pin_memory=False)
137 |
138 | OOD_NUM_EXAMPLES = len(test_dataset) // 5 # NOTE! NOT test_loader_labelled!
139 | print(OOD_NUM_EXAMPLES)
140 |
141 | # ----------------------
142 | # PROJECTION HEAD
143 | # ----------------------
144 | projector = DINOHead_k(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers,
145 | init_prototypes=None, num_labeled_classes=args.num_labeled_classes)
146 | model = nn.Sequential(backbone, projector).to(device)
147 |
148 | ckpts_base_path = '/lustre/home/sjma/GCD-project/protoGCD-v7/dev_outputs_fix/'
149 | ckpts_path = os.path.join(ckpts_base_path, args.dataset_name, args.ckpts_date, 'checkpoints', 'model_best.pt')
150 | ckpts = torch.load(ckpts_path)
151 | ckpts = ckpts['model']
152 | print('loading ckpts from %s...' % ckpts_path)
153 | model.load_state_dict(ckpts)
154 |     print('successfully loaded ckpts')
155 | model.eval()
156 |
157 |
158 | # ----------------------
159 | # TEST OOD
160 | # ----------------------
161 | print('Using %s as typical data' % args.dataset_name)
162 | with open(args.ood_log_path, 'w+') as f_log:
163 | f_log.write('Using %s as typical data' % args.dataset_name)
164 | f_log.write('\n')
165 |
166 | print(test_transform)
167 |
168 | # ID score
169 | #test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)
170 | in_score, right_score, wrong_score = get_ood_scores_in(test_loader_labelled, model, args)
171 |
172 |
173 | # Textures
174 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/dtd/images", transform=test_transform)
175 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
176 | print('\n\nTexture Detection')
177 | with open(args.ood_log_path, 'a+') as f_log:
178 | f_log.write('\n\nTexture Detection')
179 | f_log.write('\n')
180 | get_and_print_results(ood_loader, model, in_score, args)
181 |
182 |
183 | # Places
184 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/places365", transform=test_transform)
185 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
186 | print('\n\nPlaces Detection')
187 | with open(args.ood_log_path, 'a+') as f_log:
188 | f_log.write('\n\nPlaces Detection')
189 | f_log.write('\n')
190 | get_and_print_results(ood_loader, model, in_score, args)
191 |
192 |
193 | # iNaturalist
194 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/iNaturalist/", transform=test_transform)
195 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
196 | print('\n\niNaturalist Detection')
197 | with open(args.ood_log_path, 'a+') as f_log:
198 | f_log.write('\n\niNaturalist Detection')
199 | f_log.write('\n')
200 | get_and_print_results(ood_loader, model, in_score, args)
201 |
202 |
203 | # ImageNet-O
204 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/ImageNet-O/", transform=test_transform)
205 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
206 | print('\n\nImageNet-O Detection')
207 | with open(args.ood_log_path, 'a+') as f_log:
208 | f_log.write('\n\nImageNet-O Detection')
209 | f_log.write('\n')
210 | get_and_print_results(ood_loader, model, in_score, args)
211 |
212 |
213 | # OpenImage-O
214 | ood_data = datasets.ImageFolder(root="/data4/sjma/dataset/OOD/OpenImage-O/", transform=test_transform)
215 | ood_loader = torch.utils.data.DataLoader(ood_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
216 | print('\n\nOpenImage-O Detection')
217 | with open(args.ood_log_path, 'a+') as f_log:
218 | f_log.write('\n\nOpenImage-O Detection')
219 | f_log.write('\n')
220 | get_and_print_results(ood_loader, model, in_score, args)
221 |
222 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import math
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.optim import SGD, lr_scheduler
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 |
13 | from data.augmentations import get_transform
14 | from data.get_datasets import get_datasets, get_class_splits
15 |
16 | from my_utils.general_utils import AverageMeter, init_experiment
17 | from my_utils.cluster_and_log_utils import log_accs_from_preds
18 | from config import exp_root
19 |
20 | from models.model import DINOHead
21 | from models.model import ContrastiveLearningViewGenerator, get_params_groups
22 | from models.loss import info_nce_logits, SupConLoss, DistillLoss_ratio, prototype_separation_loss, entropy_regularization_loss
23 |
24 |
25 |
26 |
27 | def train(student, train_loader, test_loader, unlabelled_train_loader, args):
28 | params_groups = get_params_groups(student)
29 | optimizer = SGD(params_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
30 |
31 | exp_lr_scheduler = lr_scheduler.CosineAnnealingLR(
32 | optimizer,
33 | T_max=args.epochs,
34 | eta_min=args.lr * 1e-3,
35 | )
36 |
37 |
38 | distill_criterion = DistillLoss_ratio(num_classes=args.num_labeled_classes + args.num_unlabeled_classes,
39 | wait_ratio_epochs=args.wait_ratio_epochs,
40 | ramp_ratio_teacher_epochs=args.ramp_ratio_teacher_epochs,
41 | nepochs=args.epochs,
42 | ncrops=args.n_views,
43 | init_ratio=args.init_ratio,
44 | final_ratio=args.final_ratio,
45 | temp_logits=args.temp_logits,
46 | temp_teacher_logits=args.temp_teacher_logits,
47 | device=device)
48 |
49 | # inductive
50 | #best_test_acc_ubl = 0
51 | best_test_acc_lab = 0
52 | # transductive
53 | best_train_acc_lab = 0
54 | best_train_acc_ubl = 0
55 | best_train_acc_all = 0
56 |
57 | for epoch in range(args.epochs):
58 | loss_record = AverageMeter()
59 |
60 | student.train()
61 | for batch_idx, batch in enumerate(train_loader):
62 | images, class_labels, uq_idxs, mask_lab = batch
63 | mask_lab = mask_lab[:, 0]
64 |
65 | class_labels, mask_lab = class_labels.cuda(non_blocking=True), mask_lab.cuda(non_blocking=True).bool()
66 | images = torch.cat(images, dim=0).cuda(non_blocking=True)
67 |
68 |
69 | student_proj, student_out, prototypes = student(images)
70 | teacher_out = student_out.detach()
71 |
72 | # clustering, sup
73 | sup_logits = torch.cat([f[mask_lab] for f in (student_out / args.temp_logits).chunk(2)], dim=0)
74 | sup_labels = torch.cat([class_labels[mask_lab] for _ in range(2)], dim=0)
75 | cls_loss = nn.CrossEntropyLoss()(sup_logits, sup_labels)
76 |
77 | # clustering, unsup
78 | cluster_loss = 0
79 | distill_loss = distill_criterion(student_out, teacher_out, epoch) # NOTE!!! all data
80 | cluster_loss += distill_loss
81 |
82 | entropy_reg_loss = entropy_regularization_loss(student_out, args.temp_logits)
83 | cluster_loss += args.weight_entropy_reg * entropy_reg_loss
84 |
85 | proto_sep_loss = prototype_separation_loss(prototypes=prototypes, temperature=args.temp_logits, device=device)
86 | cluster_loss += args.weight_proto_sep * proto_sep_loss
87 |
88 |             # representation learning, unsup
89 | contrastive_logits, contrastive_labels = info_nce_logits(features=student_proj, temperature=args.temp_unsup_con)
90 | contrastive_loss = torch.nn.CrossEntropyLoss()(contrastive_logits, contrastive_labels)
91 |
92 | # representation learning, sup
93 | student_proj = torch.cat([f[mask_lab].unsqueeze(1) for f in student_proj.chunk(2)], dim=1)
94 | student_proj = F.normalize(student_proj, dim=-1)
95 | sup_con_labels = class_labels[mask_lab]
96 | sup_con_loss = SupConLoss(temperature=args.temp_sup_con)(student_proj, labels=sup_con_labels)
97 |
98 | pstr = ''
99 | pstr += f'cls_loss: {cls_loss.item():.4f} '
100 | pstr += f'cluster_loss: {cluster_loss.item():.4f} '
101 | pstr += f'distill_loss: {distill_loss.item():.4f} '
102 | pstr += f'entropy_reg_loss: {entropy_reg_loss.item():.4f} '
103 | pstr += f'proto_sep_loss: {proto_sep_loss.item():.4f} '
104 | pstr += f'sup_con_loss: {sup_con_loss.item():.4f} '
105 | pstr += f'contrastive_loss: {contrastive_loss.item():.4f} '
106 |
107 | loss = 0
108 | loss += (1 - args.weight_sup) * cluster_loss + args.weight_sup * cls_loss
109 | loss += (1 - args.weight_sup) * contrastive_loss + args.weight_sup * sup_con_loss
110 |
111 | # Train acc
112 | loss_record.update(loss.item(), class_labels.size(0))
113 | optimizer.zero_grad()
114 | loss.backward()
115 | optimizer.step()
116 |
117 | if batch_idx % args.print_freq == 0:
118 | args.logger.info('Epoch: [{}][{}/{}]\t loss {:.5f}\t {}'
119 | .format(epoch, batch_idx, len(train_loader), loss.item(), pstr))
120 |
121 | args.logger.info('Train Epoch: {} Avg Loss: {:.4f} '.format(epoch, loss_record.avg))
122 |
123 | args.logger.info('Testing on unlabelled examples in the training data...')
124 | all_acc, old_acc, new_acc = test(student, unlabelled_train_loader, epoch=epoch, save_name='Train ACC Unlabelled', args=args)
125 | args.logger.info('Testing on disjoint test set...')
126 | all_acc_test, old_acc_test, new_acc_test = test(student, test_loader, epoch=epoch, save_name='Test ACC', args=args)
127 |
128 |
129 | args.logger.info('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
130 | args.logger.info('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test))
131 |
132 | # Step schedule
133 | exp_lr_scheduler.step()
134 |
135 | save_dict = {
136 | 'model': student.state_dict(),
137 | 'optimizer': optimizer.state_dict(),
138 | 'epoch': epoch + 1,
139 | }
140 |
141 | torch.save(save_dict, args.model_path)
142 | args.logger.info("model saved to {}.".format(args.model_path))
143 |
144 | #if new_acc_test > best_test_acc_ubl:
145 | #if old_acc_test > best_test_acc_lab and epoch > 100:
146 | if all_acc > best_train_acc_all:
147 |
148 | #args.logger.info(f'Best ACC on old Classes on disjoint test set: {old_acc_test:.4f}...')
149 | args.logger.info(f'Best ACC on all Classes on train set: {all_acc:.4f}...')
150 | args.logger.info('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
151 |
152 | torch.save(save_dict, args.model_path[:-3] + f'_best.pt')
153 | args.logger.info("model saved to {}.".format(args.model_path[:-3] + f'_best.pt'))
154 |
155 | # inductive
156 | #best_test_acc_ubl = new_acc_test
157 | best_test_acc_lab = old_acc_test
158 | # transductive
159 | best_train_acc_lab = old_acc
160 | best_train_acc_ubl = new_acc
161 | best_train_acc_all = all_acc
162 |
163 | args.logger.info(f'Exp Name: {args.exp_name}')
164 |     args.logger.info(f'Metrics with best model on unlabelled train set: All: {best_train_acc_all:.4f} Old: {best_train_acc_lab:.4f} New: {best_train_acc_ubl:.4f}')
165 |
166 |
167 | def test(model, test_loader, epoch, save_name, args):
168 |
169 | model.eval()
170 |
171 | preds, targets = [], []
172 | mask = np.array([])
173 | for batch_idx, (images, label, _) in enumerate(tqdm(test_loader)):
174 | images = images.cuda(non_blocking=True)
175 | with torch.no_grad():
176 | _, logits, _ = model(images)
177 | preds.append(logits.argmax(1).cpu().numpy())
178 | targets.append(label.cpu().numpy())
179 |             mask = np.append(mask, np.array([x.item() in range(len(args.train_classes)) for x in label]))
180 |
181 | preds = np.concatenate(preds)
182 | targets = np.concatenate(targets)
183 | all_acc, old_acc, new_acc = log_accs_from_preds(y_true=targets, y_pred=preds, mask=mask,
184 | T=epoch, eval_funcs=args.eval_funcs, save_name=save_name,
185 | args=args)
186 |
187 | return all_acc, old_acc, new_acc
188 |
189 |
190 | if __name__ == "__main__":
191 |
192 | parser = argparse.ArgumentParser(description='cluster', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
193 | parser.add_argument('--batch_size', default=128, type=int)
194 | parser.add_argument('--num_workers', default=4, type=int)
195 |     parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v2', 'v2b'])
196 |
197 | parser.add_argument('--warmup_model_dir', type=str, default=None)
198 |     parser.add_argument('--dataset_name', type=str, default='scars', help='options: cifar10, cifar100, imagenet_100, cub, scars, fgvc_aircraft, herbarium_19')
199 | parser.add_argument('--prop_train_labels', type=float, default=0.5)
200 | parser.add_argument('--use_ssb_splits', action='store_true', default=True)
201 | parser.add_argument('--init_prototypes', action='store_true', default=False)
202 |
203 | parser.add_argument('--grad_from_block', type=int, default=11)
204 | parser.add_argument('--lr', type=float, default=0.1)
205 | parser.add_argument('--gamma', type=float, default=0.1)
206 | parser.add_argument('--momentum', type=float, default=0.9)
207 | parser.add_argument('--weight_decay', type=float, default=1e-4)
208 | parser.add_argument('--epochs', default=200, type=int)
209 | parser.add_argument('--exp_root', type=str, default=exp_root)
210 | parser.add_argument('--transform', type=str, default='imagenet')
211 | parser.add_argument('--n_views', default=2, type=int)
212 |
213 | parser.add_argument('--weight_sup', type=float, default=0.35)
214 | parser.add_argument('--weight_entropy_reg', type=float, default=2)
215 | parser.add_argument('--weight_proto_sep', type=float, default=1)
216 |
217 |     parser.add_argument('--temp_logits', default=0.1, type=float, help='temperature for converting prototype cosine similarities into classification logits')
218 |     parser.add_argument('--temp_teacher_logits', default=0.05, type=float, help='temperature for sharpening the teacher logits')
219 | #parser.add_argument('--temp_proto_sep', default=0.1, type=float, help='prototype separation temperature')
220 | parser.add_argument('--temp_sup_con', default=0.07, type=float, help='supervised contrastive loss temperature')
221 | parser.add_argument('--temp_unsup_con', default=1.0, type=float, help='unsupervised contrastive loss temperature')
222 |
223 |     parser.add_argument('--wait_ratio_epochs', default=0, type=int, help='Number of epochs to wait before ramping the confidence filter ratio.')
224 |     parser.add_argument('--ramp_ratio_teacher_epochs', default=100, type=int, help='Number of epochs over which the confidence filter ratio ramps from init_ratio to final_ratio.')
225 |
226 | parser.add_argument('--init_ratio', default=0.2, type=float, help='initial confidence filter ratio')
227 | parser.add_argument('--final_ratio', default=1.0, type=float, help='final confidence filter ratio')
228 |
229 | parser.add_argument('--print_freq', default=10, type=int)
230 | parser.add_argument('--exp_name', default=None, type=str)
231 |
232 | # ----------------------
233 | # INIT
234 | # ----------------------
235 | args = parser.parse_args()
236 | device = torch.device('cuda:0')
237 | args = get_class_splits(args)
238 |
239 | args.num_labeled_classes = len(args.train_classes)
240 | args.num_unlabeled_classes = len(args.unlabeled_classes)
241 | args.exp_root = 'dev_outputs_fix'
242 |
243 | init_experiment(args, runner_name=['ProtoGCD'])
244 | args.logger.info(f'Using evaluation function {args.eval_funcs[0]} to print results')
245 |
246 | torch.backends.cudnn.benchmark = True
247 |
248 | # ----------------------
249 | # BASE MODEL
250 | # ----------------------
251 | args.interpolation = 3
252 | args.crop_pct = 0.875
253 |
254 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
255 |
256 | if args.warmup_model_dir is not None:
257 | args.logger.info(f'Loading weights from {args.warmup_model_dir}')
258 | backbone.load_state_dict(torch.load(args.warmup_model_dir, map_location='cpu'))
259 |
260 | # NOTE: Hardcoded image size as we do not finetune the entire ViT model
261 | args.image_size = 224
262 | args.feat_dim = 768
263 | args.num_mlp_layers = 3
264 | args.mlp_out_dim = args.num_labeled_classes + args.num_unlabeled_classes
265 |
266 | # ----------------------
267 | # HOW MUCH OF BASE MODEL TO FINETUNE
268 | # ----------------------
269 | for m in backbone.parameters():
270 | m.requires_grad = False
271 |
272 | # Only finetune layers from block 'args.grad_from_block' onwards
273 | for name, m in backbone.named_parameters():
274 | if 'block' in name:
275 | block_num = int(name.split('.')[1])
276 | if block_num >= args.grad_from_block:
277 | m.requires_grad = True
278 |
279 |
280 |     args.logger.info('model built')
281 |
282 | # --------------------
283 | # CONTRASTIVE TRANSFORM
284 | # --------------------
285 | train_transform, test_transform = get_transform(args.transform, image_size=args.image_size, args=args)
286 | train_transform = ContrastiveLearningViewGenerator(base_transform=train_transform, n_views=args.n_views)
287 | # --------------------
288 | # DATASETS
289 | # --------------------
290 | train_dataset, test_dataset, unlabelled_train_examples_test, datasets = get_datasets(args.dataset_name,
291 | train_transform,
292 | test_transform,
293 | args)
294 |
295 | # --------------------
296 | # SAMPLER
297 | # Sampler which balances labelled and unlabelled examples in each batch
298 | # --------------------
299 | label_len = len(train_dataset.labelled_dataset)
300 | unlabelled_len = len(train_dataset.unlabelled_dataset)
301 | sample_weights = [1 if i < label_len else label_len / unlabelled_len for i in range(len(train_dataset))]
302 | sample_weights = torch.DoubleTensor(sample_weights)
303 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(train_dataset))
304 |
305 | # --------------------
306 | # DATALOADERS
307 | # --------------------
308 | train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False,
309 | sampler=sampler, drop_last=True, pin_memory=True)
310 | test_loader_unlabelled = DataLoader(unlabelled_train_examples_test, num_workers=args.num_workers,
311 | batch_size=256, shuffle=False, pin_memory=False)
312 | test_loader_labelled = DataLoader(test_dataset, num_workers=args.num_workers,
313 | batch_size=256, shuffle=False, pin_memory=False)
314 |
315 | # --------------------
316 | # Initialize prototypes
317 | # --------------------
318 | prototypes_init = None
319 | if args.init_prototypes:
320 | prototype_init_path = './init_prototypes/%s_prototypes.pt' % args.dataset_name
321 | print('load initialized prototypes from: %s' % prototype_init_path)
322 | prototypes_init = torch.load(prototype_init_path)
323 |
324 | # ----------------------
325 | # PROJECTION HEAD
326 | # ----------------------
327 | projector = DINOHead(in_dim=args.feat_dim, out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers,
328 | init_prototypes=prototypes_init, num_labeled_classes=args.num_labeled_classes)
329 | model = nn.Sequential(backbone, projector).to(device)
330 |
331 | # ----------------------
332 | # TRAIN
333 | # ----------------------
334 | train(model, train_loader, test_loader_labelled, test_loader_unlabelled, args)
335 |
--------------------------------------------------------------------------------
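
For orientation, a minimal sketch (not part of the repository) of how `train()` above mixes its four objectives: each supervised/unsupervised pair is blended with `weight_sup`, so with the README's `--weight_sup 0.35` the unsupervised terms carry weight 0.65. Here `cluster_loss` already folds in the distillation, entropy-regularization, and prototype-separation terms, and the loss values are hypothetical:

```python
import torch

weight_sup = 0.35  # the README setting
cls_loss, cluster_loss = torch.tensor(0.8), torch.tensor(1.2)          # hypothetical values
sup_con_loss, contrastive_loss = torch.tensor(0.5), torch.tensor(2.0)  # hypothetical values

loss = (1 - weight_sup) * cluster_loss + weight_sup * cls_loss
loss = loss + (1 - weight_sup) * contrastive_loss + weight_sup * sup_con_loss
print(round(loss.item(), 3))  # 0.65*1.2 + 0.35*0.8 + 0.65*2.0 + 0.35*0.5 = 2.535
```
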